In [None]:
import torch
from utils import *
from matplotlib import pyplot as plt

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

In [None]:
def normal_distribution(mean, std, amount=100):
    """Sample points from an axis-aligned Gaussian, one dimension at a time.

    Args:
        mean: Sequence holding one mean per dimension.
        std: Sequence holding one standard deviation per dimension; must
            have the same length as `mean`.
        amount: Number of points (rows) to draw.

    Returns:
        A tensor of shape (amount, len(mean)) allocated on the module-level
        `device`. Column i is drawn from a normal distribution with mean
        mean[i] and standard deviation std[i].

    Raises:
        ValueError: If `mean` and `std` differ in length.
    """
    if len(mean) != len(std):
        # A real exception class is required here: raising a plain string is
        # itself a TypeError in Python 3 and hides the actual problem.
        raise ValueError(
            f"Different Dim! mean has {len(mean)} entries but std has {len(std)}"
        )
    dim = len(mean)
    data = torch.empty((amount, dim), device=device)
    for i in range(dim):
        # normal_ fills the column view in place, so no re-assignment is needed.
        data[:, i].normal_(mean=mean[i], std=std[i])
    return data

In [None]:
# Number of coordinates per sample (also the input width of the Linear layers).
INPUT_DIM = 3
# Number of classes; one random Gaussian is generated per class.
CLASS_NUM = 20
# Nominal number of samples drawn per class.
SAMPLE_NUM = 10
# Relative jitter of the per-class sample count: each class draws
# int((1 +/- r * SAMPLE_NUM_RANGE) * SAMPLE_NUM) samples for a random r.
SAMPLE_NUM_RANGE = 0.5
# Every class-mean coordinate is random.random() * MEAN_ARRANGE.
MEAN_ARRANGE = 100
# Every class-std coordinate is 0.5 + random.random() * STD_ARRANGE.
STD_ARRANGE = 10
# When True, each class's training samples are sorted by their norm.
SAMPLE_SHIFTING = True

# Build the synthetic data set: one random axis-aligned Gaussian per class.
# The same code path serves every INPUT_DIM; only the preview plot below is
# dimension specific. The order of the random draws (mean coordinates, then
# std coordinates, class by class) is the same for every dimensionality.
mean_list = []
std_list = []
for _class in range(CLASS_NUM):
    mean_list.append([random.random() * MEAN_ARRANGE for _axis in range(INPUT_DIM)])
    std_list.append([0.5 + random.random() * STD_ARRANGE for _axis in range(INPUT_DIM)])

# data maps the class tag (the class index as a string) to its sample tensor.
data = {}
for class_index, (class_mean, class_std) in enumerate(zip(mean_list, std_list)):
    # Jitter the class size by a random sign and a random fraction of
    # SAMPLE_NUM_RANGE so classes end up with different sample counts.
    sign = random.sample([1, -1], 1)[0]
    sample_count = int((1 + sign * random.random() * SAMPLE_NUM_RANGE) * SAMPLE_NUM)
    data[str(class_index)] = normal_distribution(class_mean, class_std, sample_count)

# Preview the raw clusters; only 2-D and 3-D data can be drawn, and the plot
# is skipped when there are more than 25 classes.
if CLASS_NUM <= 25:
    if INPUT_DIM == 2:
        for tag in data:
            plt.scatter(data[tag][:, 0].cpu(), data[tag][:, 1].cpu())
    elif INPUT_DIM == 3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        for tag in data:
            ax.scatter(data[tag][:, 0].cpu(), data[tag][:, 1].cpu(), data[tag][:, 2].cpu())

# train_pipe: flat, ordered list of (sample, tag) pairs fed to the trainers.
# train_dict / test_dict: tag -> list of samples for that class.
train_pipe = []
train_dict = {}
test_dict = {}
for tag, class_samples in data.items():
    train, test = train_test_split(list(class_samples), test_size=0.2)
    # To test continual-learning performance, the sample distribution inside
    # a class should change over time (drift).
    # Here the samples of each class are ordered by their distance to the origin.
    if SAMPLE_SHIFTING:
        train = sorted(train, key=torch.linalg.norm)
    train_pipe.extend((sample, tag) for sample in train)
    train_dict[tag] = train
    test_dict[tag] = test

In [None]:
def _build_pipe(shuffle):
    """Return a fresh list of the (sample, tag) training pairs, shuffled on request."""
    if shuffle:
        return random.sample(train_pipe, len(train_pipe))
    return train_pipe.copy()

def _evaluate_if_due(Model):
    """Record train and test metrics when Model.iteration hits the evaluation interval."""
    # Clamp to at least 1: with SAMPLE_NUM * (1 - SAMPLE_NUM_RANGE) < 1 the
    # interval would be 0 and the modulo below would raise ZeroDivisionError.
    interval = max(1, int(SAMPLE_NUM * (1 - SAMPLE_NUM_RANGE)))
    if Model.iteration % interval == 0:
        predict_trainset(Model, silent=True)
        predict_testset(Model, silent=True)

def batch_train(Model, batch_size, shuffle=False, evaluate=True):
    """Run one pass over the training pairs in mini-batches.

    Args:
        Model: Object exposing `batch_train(batch)` and an `iteration` counter.
        batch_size: Number of (sample, tag) pairs handed to each call.
        shuffle: Shuffle the training pairs before batching.
        evaluate: Record train/test metrics at the evaluation interval.

    Training stops as soon as Model.iteration reaches len(train_pipe).
    NOTE(review): presumably Model.batch_train advances Model.iteration once
    per call -- confirm in ContinualLearning.
    """
    pipe = _build_pipe(shuffle)
    for start in range(0, len(pipe), batch_size):
        if Model.iteration >= len(train_pipe):
            # Nothing below can change the counter once the cap is reached.
            break
        Model.batch_train(pipe[start:start + batch_size])
        if evaluate:
            _evaluate_if_due(Model)

def continual_forward_baseline(Model, shuffle=False, evaluate=True):
    """Feed the training pairs one by one to Model.continual_forward_baseline.

    Args:
        Model: Object exposing `continual_forward_baseline(sample, tag)` and
            an `iteration` counter.
        shuffle: Shuffle the training pairs first.
        evaluate: Record train/test metrics at the evaluation interval.
    """
    for sample, tag in tqdm(_build_pipe(shuffle)):
        Model.continual_forward_baseline(sample, tag)
        if evaluate:
            _evaluate_if_due(Model)

def continual_train(Model, shuffle=False, evaluate=True):
    """Feed the training pairs one by one to Model.continual_forward.

    Args:
        Model: Object exposing `continual_forward(sample, tag)` and an
            `iteration` counter.
        shuffle: Shuffle the training pairs first.
        evaluate: Record train/test metrics at the evaluation interval.
    """
    for sample, tag in tqdm(_build_pipe(shuffle)):
        Model.continual_forward(sample, tag)
        if evaluate:
            _evaluate_if_due(Model)

def _evaluate_split(Model, split, samples_by_tag, silent):
    """Score Model on one split and append the results to metric_dict[split].

    Args:
        Model: Object exposing `tag_dict`, `predict(samples)` and `iteration`.
        split: "Train" or "Test"; selects the sub-dictionary of metric_dict.
        samples_by_tag: Mapping from class tag to that class's samples.
        silent: When False, print the total and average accuracy.

    For every class a sample counts as correct when the class tag is among
    the tags predicted for it. Three summary series are updated:
    "Total ACC" (correct / all samples), "Average ACC" (mean of per-class
    accuracies, with 0 for classes the model does not know) and
    "Forgetting Rate" (mean of first-recorded minus latest accuracy over
    classes that have more than one record).
    """
    history = metric_dict[split]
    per_class_acc = []
    forgetting = []
    total = 0
    total_correct = 0
    for tag, samples in samples_by_tag.items():
        total += len(samples)
        # NOTE(review): this is a truthiness test, so a tag whose tag_dict
        # entry is falsy (e.g. 0) is treated as unknown -- confirm intended.
        if not Model.tag_dict.get(tag):
            per_class_acc.append(0)
            continue
        if len(samples) == 0:
            # No samples to score; avoids dividing by zero below.
            continue
        correct = sum(1 for pred_tags in Model.predict(samples) if tag in pred_tags)
        total_correct += correct
        acc = correct / len(samples)
        history[tag].append([Model.iteration, acc])
        per_class_acc.append(acc)
        if len(history[tag]) > 1:
            forgetting.append(history[tag][0][1] - history[tag][-1][1])

    # An empty split records 0.0 instead of raising ZeroDivisionError.
    total_acc = total_correct / total if total else 0.0
    history["Total ACC"].append([Model.iteration, total_acc])

    average_acc = sum(per_class_acc) / len(per_class_acc) if per_class_acc else 0.0
    history["Average ACC"].append([Model.iteration, average_acc])

    if forgetting:
        forgetting_rate = sum(forgetting) / len(forgetting)
        history["Forgetting Rate"].append([Model.iteration, forgetting_rate])

    if not silent:
        print("Total Acc", total_acc)
        print("Average Acc", average_acc)

def predict_trainset(Model, silent=False):
    """Evaluate Model on train_dict and log the results under metric_dict["Train"]."""
    _evaluate_split(Model, "Train", train_dict, silent)

def predict_testset(Model, silent=False):
    """Evaluate Model on test_dict and log the results under metric_dict["Test"]."""
    _evaluate_split(Model, "Test", test_dict, silent)

def _predict_and_paint(Model, samples_by_tag, title):
    """Group samples by their first predicted tag and scatter-plot the groups.

    Args:
        Model: Object exposing `predict(samples)`, returning one sequence of
            predicted tags per sample.
        samples_by_tag: Mapping from true class tag to that class's samples.
        title: Title given to the plot.

    Only 2-D and 3-D inputs are drawn; for any other INPUT_DIM the
    predictions are computed but nothing is plotted. The 2-D case draws on
    the current pyplot axes, the 3-D case opens a new figure.
    """
    grouped = {}
    for _true_tag, samples in tqdm(samples_by_tag.items()):
        for pred_tags, sample in zip(Model.predict(samples), samples):
            # Only the first predicted tag decides the group (and colour).
            grouped.setdefault(pred_tags[0], []).append(sample)
    # Sort by predicted tag so the plotting order is stable between calls.
    grouped = dict(sorted(grouped.items(), key=lambda item: item[0]))

    if INPUT_DIM == 2:
        for tag, members in grouped.items():
            points = torch.tensor([member.tolist() for member in members])
            plt.scatter(points[:, 0], points[:, 1], label=tag)
        # plt.legend()
        plt.title(title)
    elif INPUT_DIM == 3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        for tag, members in grouped.items():
            points = torch.tensor([member.tolist() for member in members])
            ax.scatter(points[:, 0].cpu(), points[:, 1].cpu(), points[:, 2].cpu(), label=tag)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        # ax.legend()
        plt.title(title)

def predict_train_and_paint(Model):
    """Plot the training samples grouped by the tag Model predicts for them."""
    _predict_and_paint(Model, train_dict, "Train")

def predict_test_and_paint(Model):
    """Plot the test samples grouped by the tag Model predicts for them."""
    _predict_and_paint(Model, test_dict, "Test")

def plot_loss(Model):
    """Plot every non-empty loss history stored in Model.loss_dict.

    Each history is converted to a tensor; its first column is used as the
    x axis and its second column as the y axis. Histories that are empty are
    left out of both the plot and the legend.
    """
    plt.figure()
    plotted_names = []
    for name, history in Model.loss_dict.items():
        if not history:
            continue
        curve = torch.tensor(history)
        plt.plot(curve[:, 0], curve[:, 1])
        plotted_names.append(name)
    plt.legend(plotted_names)
    plt.show()

def plot_acc():
    """Plot the per-class accuracy histories of metric_dict, train next to test.

    Nothing is drawn when there are more than 25 classes. The summary series
    ("Average ACC", "Total ACC", "Forgetting Rate") are excluded, and so are
    classes without any recorded accuracy.
    """
    if CLASS_NUM > 25:
        return
    summary_names = ("Average ACC", "Total ACC", "Forgetting Rate")
    plt.figure(figsize=(16, 8))
    for position, mode in enumerate(("Train", "Test"), start=1):
        ax = plt.subplot(1, 2, position)
        for key, history in metric_dict[mode].items():
            if key in summary_names:
                continue
            if not history:
                # A class that was never evaluated has an empty history;
                # torch.tensor([]) is 1-D and cannot be column-indexed.
                continue
            t = torch.tensor(history)
            ax.plot(t[:, 0], t[:, 1], label=key)
        ax.set_ylim(top=1.1, bottom=-0.1)
        ax.set_title(f"{mode} ACC")
        ax.legend()

    plt.show()

def initialize():
    """Reset the module-level metric_dict to empty histories.

    metric_dict["Train"] gets one empty list per tag in train_dict and
    metric_dict["Test"] one per tag in test_dict (the keys predict_testset
    indexes), followed in both cases by the summary series "Total ACC",
    "Average ACC" and "Forgetting Rate".
    """
    global metric_dict
    summary_names = ("Total ACC", "Average ACC", "Forgetting Rate")
    metric_dict = {
        "Train": {tag: [] for tag in train_dict},
        "Test": {tag: [] for tag in test_dict},
    }
    for split_history in metric_dict.values():
        for name in summary_names:
            split_history[name] = []

In [None]:
# Output width of Classifier.fc and input width of Generator.fc.
INNER_DIM = 9

class Classifier(torch.nn.Module):
    """Single linear layer mapping INPUT_DIM features to INNER_DIM outputs."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=INPUT_DIM, out_features=INNER_DIM)
        # init_network comes from utils; it is handed the whole module.
        init_network(self)

    def forward(self, feature_vecs):
        """Apply the linear layer to `feature_vecs` and return the result."""
        return self.fc(feature_vecs)

class Generator(torch.nn.Module):
    """Single linear layer mapping INNER_DIM inputs back to INPUT_DIM features."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=INNER_DIM, out_features=INPUT_DIM)
        # init_network comes from utils; it is handed the whole module.
        init_network(self)

    def forward(self, tag_vecs):
        """Apply the linear layer to `tag_vecs` and return the result."""
        return self.fc(tag_vecs)

class EmptyModel(torch.nn.Module):
    """Parameter-free module whose forward pass reshapes its input to one row."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, t):
        """Return `t` reshaped to shape (1, number_of_elements)."""
        return t.reshape(1, -1)

In [None]:
# Passed as `shuffle=` to the continual (sample-by-sample) training runs.
SHUFFLE = False
# Third constructor argument of every model -- presumably the classifier
# learning rate; confirm in ContinualLearning.
CLS_LR = 0.001
# Last constructor argument of ContinualLearningModel_Generate -- presumably
# the generator learning rate; confirm in ContinualLearning.
GEN_LR = 0.01
# Fourth constructor argument of the Store and Generate models; its meaning
# is defined in ContinualLearning and is not visible in this notebook.
TRAIN_ALONG = 5

In [None]:
def reserve_metric(suffix):
    """Copy the summary series of metric_dict into metric_reservation.

    For both "Train" and "Test", the "Average ACC", "Total ACC" and
    "Forgetting Rate" histories present in metric_dict are shallow-copied
    into metric_reservation under the key "<series name> <suffix>".
    """
    summary_names = ["Average ACC", "Total ACC", "Forgetting Rate"]
    for mode in ("Train", "Test"):
        source = metric_dict[mode]
        for name in [candidate for candidate in source.keys() if candidate in summary_names]:
            metric_reservation[mode][name + " " + suffix] = source[name].copy()

# Import the model classes once instead of re-importing them on every run.
from ContinualLearning import (
    BaseModel,
    ContinualLearningModel_Generate,
    ContinualLearningModel_Store,
)

# Number of independent repetitions; each repetition draws its own seed.
N_RUNS = 5
batch_size = SAMPLE_NUM * 2

def _train_in_batches(model):
    """Batch baseline: repeat a shuffled mini-batch pass batch_size times."""
    for _pass in tqdm(range(batch_size)):
        batch_train(model, batch_size, shuffle=True)

# (suffix, model factory, training routine) for every compared method, in the
# order their curves are reserved; the plotting cell assigns colours by this
# order.
EXPERIMENTS = [
    ("Batch",
     lambda: BaseModel(EmptyModel, Classifier, CLS_LR),
     _train_in_batches),
    ("Baseline",
     lambda: BaseModel(EmptyModel, Classifier, CLS_LR),
     lambda model: continual_forward_baseline(model, shuffle=SHUFFLE)),
    ("Store",
     lambda: ContinualLearningModel_Store(EmptyModel, Classifier, CLS_LR, TRAIN_ALONG),
     lambda model: continual_train(model, shuffle=SHUFFLE)),
    ("Generate",
     lambda: ContinualLearningModel_Generate(EmptyModel, Classifier, CLS_LR, TRAIN_ALONG, Generator, GEN_LR),
     lambda model: continual_train(model, shuffle=SHUFFLE)),
]

metric_reservation_list = []
for _run in range(N_RUNS):
    SEED = random.randint(1, 999999)

    # reserve_metric writes into this module-level dictionary.
    metric_reservation = {
        "Train": {},
        "Test": {},
    }

    for suffix, build_model, train in EXPERIMENTS:
        # Every method of one run starts from the same seed.
        set_seed(SEED)
        Model = build_model()
        initialize()
        train(Model)
        # One final evaluation after training, then keep the summary curves.
        predict_trainset(Model, silent=True)
        predict_testset(Model, silent=True)
        reserve_metric(suffix)

    metric_reservation_list.append(metric_reservation)

In [None]:
# Summary figure: one panel per (criterion, split). Every run is drawn as a
# faint line and the mean over runs as a solid line in the matching colour.
fig = plt.figure(figsize=(16, 16))

# verticals=["Average ACC", "Total ACC", "Forgetting Rate"]
verticals = ["Average ACC", "Forgetting Rate"]
horizontals = ["Train", "Test"]
# Same hues in both lists; the faint variant carries an alpha suffix.
faint_colors = ["#00A2FF20", "#00000020", "#FF000020", "#00FF0020"]
solid_colors = ["#00A2FF", "#000000", "#FF0000", "#00FF00"]

panel = 0
for criterion in verticals:
    for mode in horizontals:
        panel += 1
        ax = plt.subplot(len(verticals), len(horizontals), panel)
        top = 0
        bottom = 1

        # key -> (colour index, list of per-run curves); insertion order is
        # the order in which the methods were reserved.
        curves_by_key = {}
        for run_metrics in metric_reservation_list:
            matching = [key for key in run_metrics[mode] if criterion in key]
            for idx, key in enumerate(matching):
                history = run_metrics[mode][key]
                if not history:
                    # An empty history cannot be column-indexed; skip it.
                    continue
                t = torch.tensor(history)
                curves_by_key.setdefault(key, (idx, []))[1].append(t)
                # Cycle through the palette instead of running out of colours.
                ax.plot(t[:, 0], t[:, 1], color=faint_colors[idx % len(faint_colors)])

        for key, (idx, curves) in curves_by_key.items():
            # NOTE(review): torch.stack needs every run to have recorded the
            # same number of evaluation points for this key.
            t = torch.stack(curves).mean(dim=0)
            bottom = min(torch.min(t[:, 1]).tolist(), bottom)
            top = max(torch.max(t[:, 1]).tolist(), top)
            ax.plot(t[:, 0], t[:, 1], label=key, color=solid_colors[idx % len(solid_colors)])

        ax.set_xlabel("iteration", loc="right")
        ax.set_ylim(top=top + 0.1, bottom=bottom - 0.1)
        ax.set_title(f"{mode} {criterion}")
        ax.legend()
plt.subplots_adjust(wspace=0.3, hspace=0.3)
plt.show()