In [None]:
import torch
from utils import *
from matplotlib import pyplot as plt
from ContinualLearning import ContinualLearningModel
# Seed the RNGs before any random data is generated in the cells below.
# NOTE(review): set_seed comes from `utils` via the star import; presumably it
# seeds Python's `random`, NumPy and torch — confirm in utils.
set_seed(123)

In [None]:
# Prefer the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
def normal_distribution(mean, std, amount=100):
    """Sample ``amount`` points from an axis-aligned Gaussian.

    Coordinate ``i`` of every sample is drawn independently from a normal
    distribution with mean ``mean[i]`` and standard deviation ``std[i]``.

    Args:
        mean: sequence of per-dimension means.
        std: sequence of per-dimension standard deviations; must have the
            same length as ``mean``.
        amount: number of samples (rows) to draw.

    Returns:
        Tensor of shape ``(amount, len(mean))`` allocated on the module-level
        ``device``.

    Raises:
        ValueError: if ``mean`` and ``std`` have different lengths.
    """
    if len(mean)!=len(std):
        # Raise a real exception type: raising a plain str is itself an error
        # in Python 3 and would surface as an unrelated TypeError.
        raise ValueError("Different Dim!")
    dim=len(mean)
    data=torch.empty((amount, dim), device=device)
    for i, (mu, sigma) in enumerate(zip(mean, std)):
        # data[:, i] is a view, so normal_ fills the column of `data` in place.
        data[:, i].normal_(mean=mu, std=sigma)
    return data

In [None]:
# Synthetic data set: CLASS_NUM Gaussian clusters in INPUT_DIM dimensions,
# SAMPLE_NUM points per cluster.
INPUT_DIM=2
CLASS_NUM=15
SAMPLE_NUM=5
MEAN_SCALE=100  # cluster-mean components are drawn uniformly from [0, MEAN_SCALE)
STD_SCALE=2     # per-dimension std components are drawn uniformly from [0, STD_SCALE)
TEST_SIZE=0.2   # fraction of each class held out for testing

# RNG draw order: for each class, INPUT_DIM mean components then INPUT_DIM
# std components; points are sampled only after all parameters are drawn.
# This works for any INPUT_DIM, not only 2 or 3.
mean_list=[]
std_list=[]
for _ in range(CLASS_NUM):
    mean_list.append([random.random()*MEAN_SCALE for _dim in range(INPUT_DIM)])
    std_list.append([random.random()*STD_SCALE for _dim in range(INPUT_DIM)])

data={
    tag: normal_distribution(mean, std, SAMPLE_NUM)
    for tag, (mean, std) in enumerate(zip(mean_list, std_list))
}
assert len(data)==CLASS_NUM
assert all(cluster.shape==(SAMPLE_NUM, INPUT_DIM) for cluster in data.values())

# Visualise the raw clusters: only for 2D/3D inputs and at most 25 classes.
if CLASS_NUM<=25:
    if INPUT_DIM==2:
        for tag, cluster in data.items():
            pts=cluster.cpu()
            plt.scatter(pts[:, 0], pts[:, 1], label=tag)
        plt.title("DATA")
    elif INPUT_DIM==3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        for tag, cluster in data.items():
            pts=cluster.cpu()
            ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], label=tag)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        plt.title("DATA")

# Per-class train/test split. train_pipe is the flat (sample, tag) stream fed
# to the continual learner; it is shuffled so classes are interleaved.
train_pipe=[]
train_dict={}
test_dict={}
for tag, cluster in data.items():
    train, test = train_test_split(list(cluster), test_size=TEST_SIZE)
    train_pipe.extend((sample, tag) for sample in train)
    train_dict[tag]=train
    test_dict[tag]=test

random.shuffle(train_pipe)

In [None]:
def continual_train():
    """Shuffle ``train_pipe`` in place, then stream it sample-by-sample into ``Model.continual_forward``."""
    random.shuffle(train_pipe)
    for sample, tag in tqdm(train_pipe):
        Model.continual_forward(sample, tag)

def continual_train_without_generator():
    """Shuffle ``train_pipe`` in place, then stream it into ``Model.continual_forward_without_generator``."""
    random.shuffle(train_pipe)
    for sample, tag in tqdm(train_pipe):
        Model.continual_forward_without_generator(sample, tag)

def _split_accuracy(split_dict):
    """Return the fraction of samples whose true tag is among the model's predicted tags.

    Args:
        split_dict: mapping true tag -> list of samples (``train_dict`` or
            ``test_dict``).

    Returns:
        ``correct / total`` as a float, or ``nan`` when ``split_dict`` holds
        no samples.
    """
    total=0
    total_correct=0
    for key, value in tqdm(split_dict.items()):
        total+=len(value)
        pred_tags_list=Model.predict(value)
        # A sample counts as correct if its true tag appears anywhere in its
        # predicted tag list, not only in first position.
        total_correct+=sum(key in pred_tags for pred_tags in pred_tags_list)
    if total==0:
        # Avoid ZeroDivisionError on an empty split.
        return float("nan")
    return total_correct/total


def predict_trainset():
    """Print the accuracy (as computed by ``_split_accuracy``) on the training split."""
    print(_split_accuracy(train_dict))


def predict_testset():
    """Print the accuracy (as computed by ``_split_accuracy``) on the held-out test split."""
    print(_split_accuracy(test_dict))

def _group_by_predicted_tag(split_dict):
    """Bucket every sample of ``split_dict`` under the first tag the model predicts for it.

    Args:
        split_dict: mapping tag -> list of sample tensors (``train_dict`` or
            ``test_dict``).

    Returns:
        dict mapping predicted tag -> list of samples, ordered by tag.
    """
    predict_result={}
    for _tag, samples in tqdm(split_dict.items()):
        pred_tags_list=Model.predict(samples)
        for pred_tags, sample in zip(pred_tags_list, samples):
            # NOTE(review): assumes Model.predict never returns an empty tag
            # list for a sample; pred_tags[0] would raise IndexError otherwise.
            predict_result.setdefault(pred_tags[0], []).append(sample)
    return dict(sorted(predict_result.items(), key=lambda item: item[0]))


def _paint_predictions(predict_result, title):
    """Scatter-plot samples coloured by predicted tag on a fresh figure (2D/3D inputs only)."""
    if INPUT_DIM==2:
        # Open a new figure so consecutive calls do not draw on top of each other.
        plt.figure()
        for tag, samples in predict_result.items():
            points=torch.stack(samples).cpu()
            plt.scatter(points[:, 0], points[:, 1], label=tag)
        plt.title(title)
    elif INPUT_DIM==3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        for tag, samples in predict_result.items():
            points=torch.stack(samples).cpu()
            ax.scatter(points[:, 0], points[:, 1], points[:, 2], label=tag)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        plt.title(title)


def predict_train_and_paint():
    """Plot the training split, each sample coloured by its first predicted tag."""
    _paint_predictions(_group_by_predicted_tag(train_dict), "TRAIN")


def predict_test_and_paint():
    """Plot the test split, each sample coloured by its first predicted tag."""
    _paint_predictions(_group_by_predicted_tag(test_dict), "TEST")

In [None]:
# Width of the intermediate vector: Classifier maps INPUT_DIM -> INNER_DIM and
# Generator maps INNER_DIM -> INPUT_DIM.
INNER_DIM=20

class Classifier(torch.nn.Module):
    """Single linear layer mapping an INPUT_DIM feature vector to an INNER_DIM vector."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=INPUT_DIM, out_features=INNER_DIM)
        # Parameter initialisation is delegated to utils.init_network.
        init_network(self)

    def forward(self, feature_vecs):
        return self.fc(feature_vecs)

class Generator(torch.nn.Module):
    """Single linear layer mapping an INNER_DIM vector to an INPUT_DIM vector."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=INNER_DIM, out_features=INPUT_DIM)
        # Parameter initialisation is delegated to utils.init_network.
        init_network(self)

    def forward(self, tag_vecs):
        return self.fc(tag_vecs)

class EmptyLM(torch.nn.Module):
    """Parameter-free placeholder module that only flattens its input into one row."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, t):
        # Result has shape (1, t.numel()).
        return torch.reshape(t, (1, -1))

In [None]:
# EmptyLM / Classifier / Generator are passed as classes.
# NOTE(review): the two 0.05 arguments are presumably learning rates — confirm
# against ContinualLearningModel's signature.
Model=ContinualLearningModel(EmptyLM, Classifier, Generator, 0.05, 0.05)

# Run 1: baseline — continual training without the generator.
continual_train_without_generator()
plt.figure()
plt.plot(Model.loss_dict["Classifier Continual Baseline Train Loss"], label="Baseline Train Loss")
plt.ylabel("loss")
plt.title("Baseline (without generator)")
plt.legend()
plt.show()
print()
predict_trainset()
predict_testset()

print()
# Run 2: train again with the generator attached.
# `initilize` (sic) is the method name exposed by ContinualLearningModel;
# presumably it resets weights and loss history before the second run — confirm.
Model.initilize()
continual_train()
plt.figure()
plt.plot(Model.loss_dict["Classifier Continual Attach Train Loss"], label="Attach Train Loss")
plt.plot(Model.loss_dict["Generator Continual Single Train Loss"], label="Generator")
plt.ylabel("loss")
plt.title("With generator")
plt.legend()
plt.show()
print()
predict_trainset()
predict_testset()

In [None]:
# Visualise the training split, each sample coloured by the first tag the model predicts for it.
predict_train_and_paint()

In [None]:
# Visualise the held-out test split, each sample coloured by the first tag the model predicts for it.
predict_test_and_paint()