In [None]:
import math
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm import tqdm

sys.path.append('./codes/')
from GEM import GEM

%matplotlib inline
import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8 removed the
# old names, so pick whichever spelling this installation provides.
_plot_style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
plt.style.use(_plot_style)

In [None]:
def split_data(data, labels, split, bs):
    """Partition a labelled dataset into class-incremental tasks.

    Task ``i`` holds every sample whose label value lies in ``[i*split, (i+1)*split)``.
    Labels are bucketed by value, so they are assumed to be integers numbered from 0
    (TODO confirm for non-CIFAR inputs): a label >= ``n_tasks * split`` lands in no task.

    Args:
        data: 2-D array-like (DataFrame or ndarray) with one row per sample.
        labels: sequence of per-row labels, same length as ``data``.
        split: number of consecutive label values per task.
        bs: batch size of each task's DataLoader.

    Returns:
        ``(datasets, dataloaders)``: two dicts keyed by task index ``0..n_tasks-1``.
        ``datasets[i]`` is a ``TensorDataset(float32 features, labels)``, and
        ``dataloaders[i]`` is a shuffling DataLoader over it. ``n_tasks`` is
        ``ceil(n_distinct_labels / split)``, so any leftover classes form a final,
        smaller task.

    Raises:
        ValueError: if ``split`` is not positive or ``labels`` and ``data`` differ in length.
    """
    if split <= 0:
        raise ValueError(f"split must be a positive integer, got {split!r}")

    features = np.asarray(data)
    # pd.Index normalises integer labels to int64 on every platform (torch losses expect int64).
    label_arr = pd.Index(labels).to_numpy()
    if len(label_arr) != len(features):
        raise ValueError(
            f"got {len(label_arr)} labels for {len(features)} samples; lengths must match"
        )

    n_labels = len(np.unique(label_arr))
    # Round up so a final, smaller task keeps the leftover classes instead of dropping them.
    n_tasks = math.ceil(n_labels / split)

    task_datasets = {}
    task_loaders = {}
    for i in range(n_tasks):
        in_task = (label_arr >= i * split) & (label_arr < (i + 1) * split)
        x = torch.tensor(features[in_task], dtype=torch.float32)
        y = torch.tensor(label_arr[in_task])
        task_datasets[i] = TensorDataset(x, y)
        task_loaders[i] = DataLoader(task_datasets[i], batch_size=bs, shuffle=True)

    return task_datasets, task_loaders

In [None]:
def accu(model, dataloader):
    """Return the top-1 classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: an ``nn.Module`` mapping an input batch to per-class scores; the
            predicted class is the argmax over dim 1 (presumably shape (batch, n_classes)).
        dataloader: iterable of ``(inputs, target)`` batches.

    Returns:
        0-dim float tensor with the fraction of samples whose predicted class equals
        the target.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    count = 0
    try:
        # Inference only: don't build an autograd graph for every evaluation batch.
        with torch.no_grad():
            for inputs, target in dataloader:
                logits = model(inputs)
                correct += (logits.argmax(dim=1) == target).sum().item()
                count += len(target)
    finally:
        # Restore the caller's mode so evaluation doesn't leak eval() into later training.
        model.train(was_training)

    if count == 0:
        raise ValueError("accu: dataloader yielded no samples")
    return torch.tensor(correct / count)

In [None]:
# Pre-extracted ResNet-18 features for CIFAR-100 (one row per image) and their class labels.
PROCESSED_DIR = './processed_data'


def load_features(split):
    """Load the feature frame and label list for ``split`` ('train' or 'test').

    Returns:
        ``(data, labels)``: the header-less feature DataFrame, and the first column of
        the matching label CSV as a plain Python list.
    """
    data = pd.read_csv(f'{PROCESSED_DIR}/cifar100_{split}_resnet18.csv', header=None)
    label_frame = pd.read_csv(f'{PROCESSED_DIR}/cifar100_{split}_label_resnet18.csv', header=None)
    # Flatten the single label column in one vectorised call instead of a per-row Python loop.
    return data, label_frame.iloc[:, 0].tolist()


train_data, train_label = load_features('train')
test_data, test_label = load_features('test')

In [None]:
# Class-incremental setup: each task covers CLASSES_PER_TASK consecutive label values.
CLASSES_PER_TASK = 4
BATCH_SIZE = 64

train_datasets, trainloader = split_data(train_data, train_label, CLASSES_PER_TASK, BATCH_SIZE)
test_datasets, testloader = split_data(test_data, test_label, CLASSES_PER_TASK, BATCH_SIZE)

In [None]:
# One task per class group produced by split_data. Derive the count instead of hard-coding
# it, so it stays in sync if the number of classes per task changes.
n_task = len(trainloader)
assert len(testloader) == n_task, "train and test splits produced different numbers of tasks"

# Random seeds for the repeated runs of the experiment.
seeds = [1, 2, 3, 5, 8]

In [None]:
N_INPUTS = train_data.shape[1]  # feature dimension of the pre-extracted embeddings
N_CLASSES = 100                 # CIFAR-100
N_EPOCHS = 1                    # passes over each task's training data
RESULTS_DIR = './accuracy'

# Create the output directory up front so the CSV write cannot fail after training.
os.makedirs(RESULTS_DIR, exist_ok=True)

for seed in seeds:
    # Seed every global RNG: DataLoader(shuffle=True) and default weight init draw from
    # torch's global generator, so passing `seed` to GEM alone does not make runs repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    args = {
        'n_layers': 1,
        'n_hiddens': 20000,
        'memory_strength': 0.5,
        'lr': 1e-3,
        'n_memories': 256,
        'cuda': False,
        'seed': seed,
    }
    gem = GEM(N_INPUTS, N_CLASSES, n_task, args)
    # acc_matrix[i, j]: test accuracy on task j after training through task i (upper triangle stays 0).
    acc_matrix = np.zeros([n_task, n_task])

    for i in range(n_task):
        # accu() switches the net to eval mode; make sure it trains in training mode.
        gem.net.train()
        for _ in range(N_EPOCHS):
            for inputs, target in tqdm(trainloader[i], desc=f'seed {seed} task {i}'):
                gem.observe(inputs, i, target)

        for j in range(i + 1):
            acc_matrix[i, j] = accu(gem.net, testloader[j]).item()

    accuracy = pd.DataFrame(acc_matrix)
    out_path = os.path.join(
        RESULTS_DIR,
        'GEM_cifar100_resnet18_all_result_lr1e-3_randseed_' + str(seed) + '.csv',
    )
    accuracy.to_csv(out_path, index=False, header=False)