In [None]:
import torch
from utils import *
from matplotlib import pyplot as plt
set_seed(123)  # project helper from utils; presumably seeds the RNGs used below — verify

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

In [None]:
def normal_distribution(mean, std, amount=100, target_device=None):
    """Draw `amount` samples from an axis-aligned Gaussian.

    Column i of the result is filled in place with samples of N(mean[i], std[i])
    via `Tensor.normal_`, one dimension at a time.

    Args:
        mean: per-dimension means (sequence of numbers).
        std: per-dimension standard deviations; must have the same length as `mean`.
        amount: number of samples (rows) to draw.
        target_device: device of the returned tensor; defaults to the notebook-level
            `device` chosen in the setup cell.

    Returns:
        Tensor of shape (amount, len(mean)).

    Raises:
        ValueError: if `mean` and `std` have different lengths.
    """
    if len(mean) != len(std):
        # A real exception class is required: raising a plain string is itself a TypeError.
        raise ValueError("Different Dim!")
    if target_device is None:
        target_device = device
    dim = len(mean)
    data = torch.empty((amount, dim), device=target_device)
    for i in range(dim):
        # normal_ writes into the column view, i.e. directly into `data`.
        data[:, i].normal_(mean=mean[i], std=std[i])
    return data

In [None]:
# Synthetic data: CLASS_NUM classes, each an axis-aligned Gaussian blob in INPUT_DIM dimensions.
INPUT_DIM = 3       # dimensionality of every sample
CLASS_NUM = 20      # number of classes (tags "0" .. str(CLASS_NUM - 1))
SAMPLE_NUM = 10     # samples drawn per class, before the train/test split
MEAN_ARRANGE = 50   # each mean coordinate is drawn from [0, MEAN_ARRANGE)
STD_ARRANGE = 6     # each std coordinate is drawn from [0.5, 0.5 + STD_ARRANGE)
TEST_SIZE = 0.2     # fraction of every class held out for testing

# For each class, its INPUT_DIM mean coordinates are drawn before its INPUT_DIM std coordinates.
mean_list = []
std_list = []
for _ in range(CLASS_NUM):
    mean_list.append([random.random() * MEAN_ARRANGE for _dim in range(INPUT_DIM)])
    std_list.append([0.5 + random.random() * STD_ARRANGE for _dim in range(INPUT_DIM)])

# data[tag] is a (SAMPLE_NUM, INPUT_DIM) tensor holding the samples of class `tag`.
data = {
    str(class_idx): normal_distribution(mean_list[class_idx], std_list[class_idx], SAMPLE_NUM)
    for class_idx in range(CLASS_NUM)
}

# Preview the classes; skipped for more than 25 classes or when INPUT_DIM is not 2 or 3.
if CLASS_NUM <= 25 and INPUT_DIM == 2:
    fig, ax = plt.subplots()
    for tag, samples in data.items():
        ax.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), label=tag)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend()
elif CLASS_NUM <= 25 and INPUT_DIM == 3:
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    for tag, samples in data.items():
        ax.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), samples[:, 2].cpu(), label=tag)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()

# Split every class separately. train_pipe flattens the training split into
# (sample, tag) pairs, grouped class by class in tag order.
train_pipe = []
train_dict = {}
test_dict = {}
for tag, samples in data.items():
    train, test = train_test_split(list(samples), test_size=TEST_SIZE)
    train_pipe.extend((sample, tag) for sample in train)
    train_dict[tag] = train
    test_dict[tag] = test

In [None]:
def batch_train(batch_size):
    """Run one epoch of mini-batch training of the global `Model` over `train_pipe`.

    The (sample, tag) pairs of `train_pipe` are shuffled and cut into consecutive
    batches of `batch_size` (the last one may be smaller). Each batch is passed to
    `Model.batch_train`. Afterwards `Model.epoch` is incremented, and the loss returned
    for the last batch is appended to
    `Model.loss_dict["Classifier Batch Train Loss"]` as `(epoch, loss)`.

    Args:
        batch_size: number of pairs per batch; must be >= 1.

    Raises:
        ValueError: if `batch_size` < 1 or `train_pipe` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not train_pipe:
        raise ValueError("train_pipe is empty; nothing to train on")
    shuffled = random.sample(train_pipe, len(train_pipe))
    loss = None
    for start in range(0, len(shuffled), batch_size):
        # Slice the shuffled copy: train_pipe itself is grouped class by class.
        loss = Model.batch_train(shuffled[start:start + batch_size])
    Model.epoch += 1
    Model.loss_dict["Classifier Batch Train Loss"].append((Model.epoch, loss))

def _run_continual(step, shuffle):
    """Stream every (sample, tag) pair of `train_pipe` through `step`, one pair at a time.

    Args:
        step: callable invoked as `step(sample, tag)` for each training pair.
        shuffle: if True, visit the pairs in a random order (`random.sample`);
            otherwise use `train_pipe` order.

    After each step, train and test accuracy are evaluated silently whenever
    `Model.epoch % SAMPLE_NUM == SAMPLE_NUM // 2`.
    """
    pipe = random.sample(train_pipe, len(train_pipe)) if shuffle else list(train_pipe)
    for sample, tag in tqdm(pipe):
        step(sample, tag)
        # NOTE(review): this relies on Model's continual_forward* methods advancing
        # Model.epoch; if they do not, the check fires on every step or never — confirm.
        if Model.epoch % SAMPLE_NUM == SAMPLE_NUM // 2:
            predict_trainset(silent=True)
            predict_testset(silent=True)


def continual_forward_baseline(shuffle=False):
    """One continual pass over `train_pipe` using `Model.continual_forward_baseline`."""
    _run_continual(lambda sample, tag: Model.continual_forward_baseline(sample, tag), shuffle)


def continual_train(shuffle=False):
    """One continual pass over `train_pipe` using `Model.continual_forward`."""
    _run_continual(lambda sample, tag: Model.continual_forward(sample, tag), shuffle)

def _predict_split(split_dict, suffix, silent):
    """Evaluate `Model` on one split and log the accuracies into the global `acc_dict`.

    Args:
        split_dict: mapping tag -> list of samples (`train_dict` or `test_dict`).
        suffix: "train" or "test"; selects the acc_dict keys `<tag>_<suffix>` and
            `total_<suffix>`.
        silent: if False, print the overall accuracy.

    Only tags for which `Model.tag_dict.get(tag)` is truthy are evaluated. A sample
    counts as correct when its true tag is contained in the prediction `Model.predict`
    returns for it. Per-tag and overall accuracies are appended as [Model.epoch, acc].
    Tags with no samples are skipped. When nothing is evaluated, the overall accuracy
    is not logged, which avoids a division by zero.
    """
    total = 0
    total_correct = 0
    for key, samples in split_dict.items():
        # NOTE(review): a tag whose tag_dict value is falsy (e.g. index 0) is skipped;
        # `key in Model.tag_dict` may be what was intended — confirm in ContinualLearning.
        if not Model.tag_dict.get(key) or len(samples) == 0:
            continue
        pred_tags_list = Model.predict(samples)
        correct = sum(1 for pred_tags in pred_tags_list if key in pred_tags)
        total += len(samples)
        total_correct += correct
        acc_dict[f"{key}_{suffix}"].append([Model.epoch, correct / len(samples)])
    if total == 0:
        if not silent:
            print(f"No known tags in the {suffix} set; nothing to evaluate.")
        return
    total_acc = total_correct / total
    acc_dict[f"total_{suffix}"].append([Model.epoch, total_acc])
    if not silent:
        print(total_acc)


def predict_trainset(silent=False):
    """Evaluate `Model` on `train_dict` and log into acc_dict; print total accuracy unless `silent`."""
    _predict_split(train_dict, "train", silent)


def predict_testset(silent=False):
    """Evaluate `Model` on `test_dict` and log into acc_dict; print total accuracy unless `silent`."""
    _predict_split(test_dict, "test", silent)

def _predict_and_paint(split_dict, title):
    """Scatter-plot the samples of `split_dict`, one colour per predicted tag.

    Every sample list in `split_dict` (tag -> samples) is passed to `Model.predict`.
    The first entry of each returned prediction is used as that sample's predicted tag.
    Samples are grouped by predicted tag, with groups ordered by sorted tag, and drawn
    on a new figure: a 2-D scatter for INPUT_DIM == 2 and a 3-D scatter for
    INPUT_DIM == 3. Other dimensionalities are not plotted.

    Args:
        split_dict: mapping tag -> list of sample tensors (`train_dict` / `test_dict`).
        title: figure title.
    """
    predict_result = {}
    for _, samples in tqdm(split_dict.items()):
        for pred_tags, sample in zip(Model.predict(samples), samples):
            predict_result.setdefault(pred_tags[0], []).append(sample)
    predict_result = dict(sorted(predict_result.items(), key=lambda item: item[0]))

    if INPUT_DIM == 2:
        # Own figure, so the plot never lands on whatever figure happens to be current.
        fig, ax = plt.subplots()
        for tag, group in predict_result.items():
            points = torch.stack(group).cpu()
            ax.scatter(points[:, 0], points[:, 1], label=tag)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_title(title)
    elif INPUT_DIM == 3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        for tag, group in predict_result.items():
            points = torch.stack(group).cpu()
            ax.scatter(points[:, 0], points[:, 1], points[:, 2], label=tag)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(title)


def predict_train_and_paint():
    """Plot the training samples, coloured by the tag `Model` predicts for them."""
    _predict_and_paint(train_dict, "TRAIN")


def predict_test_and_paint():
    """Plot the test samples, coloured by the tag `Model` predicts for them."""
    _predict_and_paint(test_dict, "TEST")

def plot_loss():
    """Plot every non-empty loss history in `Model.loss_dict` on a new figure and show it.

    Each history is presumably a list of (x, loss) pairs, as `batch_train` appends.
    Column 0 is the x-axis and column 1 the loss. Dict keys label the legend entries.
    """
    fig, ax = plt.subplots()
    plotted_names = []
    for name, history in Model.loss_dict.items():
        if not history:
            continue
        curve = torch.tensor(history)
        ax.plot(curve[:, 0], curve[:, 1])
        plotted_names.append(name)
    ax.legend(plotted_names)
    plt.show()

def _plot_acc_panel(ax, suffix, title):
    """Draw one accuracy curve per tag from the `<tag>_<suffix>` entries of `acc_dict`.

    `total_*` entries are excluded. Histories are [epoch, accuracy] pairs; the tag
    (the key without `_<suffix>`) is used as the line label.
    """
    ending = "_" + suffix
    for key, history in acc_dict.items():
        if key.startswith("total") or not key.endswith(ending):
            continue
        if not history:
            # Tag never evaluated: torch.tensor([]) is 1-D and [:, 0] would raise.
            continue
        curve = torch.tensor(history)
        ax.plot(curve[:, 0], curve[:, 1], label=key[:-len(ending)])
    ax.set_ylim(top=1.1, bottom=-0.1)
    ax.set_title(title)
    ax.legend()


def plot_acc():
    """Show per-tag accuracy curves side by side: train (left) and test (right)."""
    fig, (train_ax, test_ax) = plt.subplots(1, 2, figsize=(16, 8))
    _plot_acc_panel(train_ax, "train", "Train ACC")
    _plot_acc_panel(test_ax, "test", "Test ACC")
    plt.show()

def initialize():
    """(Re)create the global `acc_dict` that logs accuracy curves.

    For every tag in `train_dict` it holds empty `<tag>_test` and `<tag>_train` lists,
    followed by `total_train` and `total_test`. The predict_* functions append
    [epoch, accuracy] pairs to these lists.
    """
    global acc_dict
    acc_dict = {}
    for tag in train_dict:
        acc_dict[tag + "_test"] = []
        acc_dict[tag + "_train"] = []
    acc_dict["total_train"] = []
    acc_dict["total_test"] = []

In [None]:
INNER_DIM = 100  # width of the Classifier output and of the Generator input


class Classifier(torch.nn.Module):
    """Single linear layer mapping INPUT_DIM features to INNER_DIM outputs."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(INPUT_DIM, INNER_DIM)
        init_network(self)  # project helper from utils; presumably (re)initialises weights

    def forward(self, feature_vecs):
        return self.fc(feature_vecs)

class Generator(torch.nn.Module):
    """Single linear layer mapping INNER_DIM vectors back to INPUT_DIM outputs."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(INNER_DIM, INPUT_DIM)
        init_network(self)  # project helper from utils; presumably (re)initialises weights

    def forward(self, tag_vecs):
        return self.fc(tag_vecs)

class EmptyModel(torch.nn.Module):
    """Parameter-free pass-through module that flattens its input into one (1, N) row.

    Constructor arguments are forwarded unchanged to `torch.nn.Module`.
    """

    def forward(self, t):
        return t.reshape(1, -1)

In [None]:
# Total-accuracy curves kept from each experiment run below.
acc_reservation = dict()
# One seed drawn here and re-applied before every run, so all runs share it.
SEED = random.randint(1, 999999)
SEED

In [None]:
from ContinualLearning import BaseModel

# Reference run: shuffled mini-batch training over the whole training split.
BATCH_EPOCHS = 100  # number of batch_train() passes over train_pipe
BATCH_SIZE = 16

set_seed(SEED)
Model = BaseModel(
    EmptyModel,
    Classifier,
    0.001
)

initialize()
for _ in tqdm(range(BATCH_EPOCHS)):
    batch_train(BATCH_SIZE)
plot_loss()

# batch_train() never evaluates, so acc_dict stays empty until these calls.
# Evaluate first so the reserved curves actually contain the final accuracy.
predict_trainset()
predict_testset()

acc_reservation["total_train_batch"] = acc_dict["total_train"].copy()
acc_reservation["total_test_batch"] = acc_dict["total_test"].copy()

In [None]:
from ContinualLearning import BaseModel

# Baseline continual run: pairs are fed one at a time to Model.continual_forward_baseline,
# re-seeded with the shared SEED like the other runs.
set_seed(SEED)
Model = BaseModel(
    EmptyModel,
    Classifier,
    0.001
)

initialize()
continual_forward_baseline()
plot_loss()
plot_acc()

# Evaluate before reserving so the reserved curves end with the final accuracy.
predict_trainset()
predict_testset()

acc_reservation["total_train_baseline"] = acc_dict["total_train"].copy()
acc_reservation["total_test_baseline"] = acc_dict["total_test"].copy()

In [None]:
from ContinualLearning import ContinualLearningModel_Store

# Continual run with ContinualLearningModel_Store: pairs are fed one at a time to
# Model.continual_forward, re-seeded with the shared SEED like the other runs.
set_seed(SEED)
Model = ContinualLearningModel_Store(
    EmptyModel,
    Classifier,
    0.001, 99  # meaning of these arguments is defined in ContinualLearning (not visible here)
)

initialize()
continual_train()
plot_loss()
plot_acc()

# Evaluate before reserving so the reserved curves end with the final accuracy.
predict_trainset()
predict_testset()

# NOTE(review): the key suffix is empty (cf. "_batch" / "_baseline"); consider a descriptive name.
acc_reservation["total_train_"] = acc_dict["total_train"].copy()
acc_reservation["total_test_"] = acc_dict["total_test"].copy()

In [None]:
# Last logged [epoch, accuracy] of every reserved curve (None if a curve is empty);
# the full histories remain available in `acc_reservation`.
{name: (curve[-1] if curve else None) for name, curve in acc_reservation.items()}

In [None]:
# Training samples, coloured by the tag the most recently trained `Model` predicts.
predict_train_and_paint();

In [None]:
# Test samples, coloured by the tag the most recently trained `Model` predicts.
predict_test_and_paint();