In [None]:
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm import tqdm

%matplotlib inline
import matplotlib.pyplot as plt

# matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names; prefer the new name, fall back on older installs.
plt.style.use("seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white")

In [None]:
def split_data(data, labels, split, bs):
    """Partition a labelled dataset into class-incremental tasks.

    Task ``i`` holds every sample whose label lies in ``[i*split, (i+1)*split)``,
    in their original row order. The number of tasks is
    ``n_distinct_labels // split``, so a trailing group with fewer than ``split``
    labels is dropped. This presumably assumes labels are the integers 0..K-1.

    Args:
        data: 2-D array-like (e.g. a ``pd.DataFrame`` or ndarray) with one row
            per sample.
        labels: sequence of integer class labels, one per row of ``data``.
        split: number of consecutive classes per task.
        bs: batch size of each task's (shuffling) ``DataLoader``.

    Returns:
        ``(datasets, dataloaders)``: two dicts keyed by task index
        ``0..n_split-1``. The first maps each index to a ``TensorDataset`` of
        ``(float32 features, int64 labels)``. The second maps it to a
        ``DataLoader(batch_size=bs, shuffle=True)`` over that dataset.

    Raises:
        ValueError: if ``labels`` and the rows of ``data`` differ in length.
    """
    # Filter rows directly with boolean masks instead of transposing the whole
    # frame and relabelling its columns, which copied the features per task.
    features = np.asarray(data)
    targets = np.asarray(labels)
    if len(targets) != features.shape[0]:
        raise ValueError(
            f"got {len(targets)} labels for {features.shape[0]} rows of data"
        )

    n_split = len(np.unique(targets)) // split
    task_datasets = {}  # not named `datasets`: that would shadow the torchvision import
    task_loaders = {}
    for i in range(n_split):
        in_task = (targets >= i * split) & (targets < (i + 1) * split)
        task_datasets[i] = TensorDataset(
            torch.as_tensor(features[in_task], dtype=torch.float32),
            # CrossEntropyLoss needs integer class indices.
            torch.as_tensor(targets[in_task], dtype=torch.long),
        )
        task_loaders[i] = DataLoader(task_datasets[i], batch_size=bs, shuffle=True)

    return (task_datasets, task_loaders)

In [None]:
class LinearLayer(nn.Module):
    """Fully connected layer followed by an activation and optional BatchNorm.

    ``forward`` computes ``bn(act(lin(x)))`` when ``use_bn`` is true, otherwise
    ``act(lin(x))``. Note that BatchNorm is applied *after* the activation.

    Args:
        input_dim: size of each input feature vector.
        output_dim: size of each output feature vector.
        act: ``'relu'`` (default), ``'tanh'`` or ``'sigmoid'`` (case-insensitive).
            Alternatively any callable, e.g. an ``nn.Module`` instance, or
            ``None`` for no activation.
        use_bn: if true, add ``nn.BatchNorm1d(output_dim)`` as ``self.bn``.

    Raises:
        ValueError: if ``act`` is a string that is not a known activation name.
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}

    def __init__(self, input_dim, output_dim, act='relu', use_bn=False):
        super(LinearLayer, self).__init__()
        self.use_bn = use_bn
        # Create `lin` first, as before, so seeded weight initialisation is unchanged.
        self.lin = nn.Linear(input_dim, output_dim)
        self.act = self._resolve_activation(act)
        if use_bn:
            self.bn = nn.BatchNorm1d(output_dim)

    @staticmethod
    def _resolve_activation(act):
        """Turn the ``act`` argument into a callable activation."""
        if act is None:
            return nn.Identity()
        if isinstance(act, str):
            # Previously any string other than 'relu' was stored as-is and
            # only failed later with "'str' object is not callable".
            try:
                return LinearLayer._ACTIVATIONS[act.lower()]()
            except KeyError:
                raise ValueError(
                    f"unknown activation {act!r}; expected one of "
                    f"{sorted(LinearLayer._ACTIVATIONS)}, a callable, or None"
                ) from None
        return act

    def forward(self, x):
        out = self.act(self.lin(x))
        if self.use_bn:
            return self.bn(out)
        return out

class Flatten(nn.Module):
    """Merge every dimension after the leading (batch) one.

    Maps a tensor of shape ``(N, d1, d2, ...)`` to ``(N, d1*d2*...)`` using
    ``Tensor.view``, so the input must be view-compatible (e.g. contiguous).
    """

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

class BaseModel(nn.Module):
    """Two-layer MLP: a ``LinearLayer`` (without BatchNorm) followed by a linear head.

    Args:
        num_inputs: size of each input feature vector.
        num_hidden: width of the hidden layer.
        num_outputs: number of output scores (logits) per sample.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=False)
        self.lin2 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        return self.lin2(self.lin1(x))

In [None]:
def accu(model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Puts ``model`` in eval mode (and leaves it there). It then runs inference
    under ``torch.no_grad()`` and counts samples whose ``argmax`` over dim 1 of
    the output equals the target.

    Args:
        model: module mapping a batch of inputs to per-class scores
            (assumed shape ``(batch, n_classes)``).
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        0-d float tensor holding the fraction of correct predictions, or NaN
        when the dataloader yields no samples (accuracy is undefined there).
    """
    model.eval()
    correct = 0
    seen = 0
    # Evaluation needs no gradients; without no_grad every forward pass
    # would build an autograd graph.
    with torch.no_grad():
        for inputs, targets in dataloader:
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == targets).sum().item()
            seen += len(targets)
    if seen == 0:
        return torch.tensor(float("nan"))
    return torch.tensor(correct / seen)

In [None]:
# Load the CIFAR-100 feature matrices and their label columns (no header rows).
# Per the file names these are presumably ResNet-18 features — TODO confirm.
DATA_DIR = './processed_data'
train_data = pd.read_csv(f'{DATA_DIR}/cifar100_train_resnet18.csv', header=None)
# Each label file has a single column; take it as a flat list of labels.
train_label = pd.read_csv(f'{DATA_DIR}/cifar100_train_label_resnet18.csv', header=None).iloc[:, 0].tolist()
test = pd.read_csv(f'{DATA_DIR}/cifar100_test_resnet18.csv', header=None)
test_label = pd.read_csv(f'{DATA_DIR}/cifar100_test_label_resnet18.csv', header=None).iloc[:, 0].tolist()

# split_data pairs the i-th label with the i-th feature row, so the counts must match.
assert len(train_label) == len(train_data), "train features/labels row count mismatch"
assert len(test_label) == len(test), "test features/labels row count mismatch"

In [None]:
# Class-incremental setup: every 4 consecutive labels form one task, and both
# splits are served in mini-batches of 64.
CLASSES_PER_TASK = 4
BATCH_SIZE = 64
split_train, split_train_loader = split_data(train_data, train_label, CLASSES_PER_TASK, BATCH_SIZE)
split_test, split_test_loader = split_data(test, test_label, CLASSES_PER_TASK, BATCH_SIZE)

In [None]:
# Experiment configuration.
seeds = [1, 2, 3, 5, 8]  # one full training run (and one output CSV) per seed
lr = 1e-3                # SGD learning rate
# Derive the task count from the split instead of hard-coding 25: a different
# classes-per-task value would otherwise desynchronise the training loop
# (KeyError on split_train_loader[i]).
n_task = len(split_train_loader)
assert len(split_test_loader) == n_task, "train/test splits produced different task counts"

In [None]:
# Vanilla sequential training over the tasks, repeated once per seed.
# accuracy_matrix[i, j] is the test accuracy on task j right after training on
# task i; only j <= i is filled, the rest stays 0.
input_dim = 512     # feature size fed to BaseModel
hidden_dim = 20000  # width of BaseModel's hidden layer
n_classes = 100     # output size of BaseModel
n_epochs = 1        # passes over each task's training data
out_dir = './accuracy'
os.makedirs(out_dir, exist_ok=True)  # to_csv does not create missing directories

# Loop-invariant setup, hoisted out of the per-seed loop.
criterion = nn.CrossEntropyLoss()
torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True
# Learning-rate tag for the file name derived from `lr` (1e-3 -> "1e-3"),
# so the name cannot drift from the value actually used.
lr_tag = np.format_float_scientific(lr, trim='-', exp_digits=1)


def seed_everything(seed):
    """Seed torch (CPU and CUDA), numpy's global RNG and Python's `random`."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


for seed in seeds:
    # Seed before building the model so weight init and shuffling are reproducible.
    seed_everything(seed)
    model = BaseModel(input_dim, hidden_dim, n_classes)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    accuracy_matrix = np.zeros([n_task, n_task])

    for i in range(n_task):
        model.train()
        for _ in range(n_epochs):
            for inputs, targets in tqdm(split_train_loader[i], desc=f'seed {seed} task {i}'):
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()

        # Evaluate on every task seen so far.
        model.eval()
        for j in range(i + 1):
            accuracy_matrix[i, j] = accu(model, split_test_loader[j]).item()

    out_path = f'{out_dir}/Vanilla_cifar100_resnet18_all_result_lr{lr_tag}_randseed_{seed}.csv'
    pd.DataFrame(accuracy_matrix).to_csv(out_path, index=False, header=False)