In [2]:
from avalanche.benchmarks import PermutedMNIST, RotatedMNIST
from avalanche.models import SimpleCNN
import avalanche
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from Approaches.MemoryApproach import ER,A_GEM
from Approaches.RegularizationApproach import MAS,EWC
from Approaches.Approach import Naive

In [None]:
torch.manual_seed(100)

<torch._C.Generator at 0x2d8600b9510>

In [4]:
class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = SimpleCNN(num_classes=num_classes)
        self.model.features[0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    def forward(self, x):
        return self.model(x)

In [5]:
mas_model=CNN(num_classes=10)
ewc_model = CNN(num_classes=10)
naive_model = CNN(num_classes=10)
er_model = CNN(num_classes=10)
agem_model = CNN(num_classes=10)

permuted_mnist = PermutedMNIST(n_experiences=3,return_task_id=True)
rotated_mnist = RotatedMNIST(n_experiences=3,return_task_id=True,rotations_list=[0,90,180])


In [None]:
mas=MAS(mas_model,lambda_reg=1)
ewc=EWC(ewc_model,lambda_reg=10)
er = ER(er_model,sampling_freq=0.3,n_samples=100,max_size=150)
naive = Naive(naive_model)
agem = A_GEM(agem_model,n_samples=100,max_size=1000)

device=torch.device("cuda:0")
epochs = 5

In [None]:
def train_one_epoch(approach, dataloader, optimizer, device):
    approach.train()
    total_loss = 0


    for images, labels, _ in dataloader:
        images, labels = images.to(device), labels.to(device) 
        optimizer.zero_grad() 
        
        loss = approach(images,labels) 
    
        optimizer.step()  
    
        total_loss += loss.item()
    
        
    return total_loss / len(dataloader) 


In [8]:
def evaluate_stream(approach, R, row, test_stream):
    approach.eval()
    for test_exp in test_stream:
        correct = 0
        total = 0
        col = test_exp.current_experience
        test_loader = DataLoader(test_exp.dataset,batch_size=64,shuffle=False)
        for X, y, _ in test_loader:
            X, y = X.to(device), y.to(device)
            preds=approach(X)
            correct += torch.sum(preds.argmax(dim=1)==y)
        R[row][col] = correct/len(test_loader.dataset)
    return R


In [9]:
def train_stream(benchmark, approach, epochs):
    optimizer = optim.Adam(approach.model.parameters(),lr=0.001)
    n_experieneces=benchmark.n_experiences
    R=np.zeros((n_experieneces,n_experieneces))
    for train_exp,test_exp in zip(benchmark.train_stream,benchmark.test_stream):
        train_loader = DataLoader(train_exp.dataset,batch_size=64,shuffle=True)
        test_loader = DataLoader(test_exp.dataset,batch_size=64,shuffle=False)
        
        for epoch in range(epochs):
            train_loss = train_one_epoch(approach=approach, dataloader=train_loader, optimizer=optimizer, device=device)
            
            
        loaders = {"train":train_loader,"test":test_loader}
        approach.adapt(loaders)
        R = evaluate_stream(approach=approach, R=R, row=train_exp._current_experience, test_stream=benchmark.test_stream)
    return approach, R

In [12]:
er, R_er = train_stream(benchmark=permuted_mnist,approach=er,epochs=5)
agem, R_agem = train_stream(benchmark=permuted_mnist,approach=agem,epochs=5)
mas, R_mas = train_stream(benchmark=permuted_mnist,approach=mas,epochs=5)
ewc, R_ewc = train_stream(benchmark=permuted_mnist,approach=ewc,epochs=5)
naive, R_naive = train_stream(benchmark=permuted_mnist,approach=naive,epochs=5)



Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples


In [13]:
R_agem

array([[0.94959998, 0.63809997, 0.64749998],
       [0.82489997, 0.9479    , 0.4962    ],
       [0.69679999, 0.83489996, 0.95389998]])

In [14]:
R_naive

array([[0.93219995, 0.14219999, 0.1114    ],
       [0.29620001, 0.93430001, 0.1078    ],
       [0.1259    , 0.32909998, 0.94169998]])

In [15]:
R_er

array([[0.93409997, 0.11979999, 0.0896    ],
       [0.81569999, 0.9346    , 0.1275    ],
       [0.71869999, 0.79530001, 0.93669999]])

In [16]:
R_ewc

array([[0.93610001, 0.13169999, 0.10399999],
       [0.2669    , 0.93619996, 0.1183    ],
       [0.1344    , 0.33589998, 0.93539995]])

In [17]:
R_mas

array([[0.9436    , 0.16069999, 0.127     ],
       [0.84029996, 0.8854    , 0.13849999],
       [0.53670001, 0.61439997, 0.87      ]])

In [18]:
def get_BWT(R):
    n_experieneces=R.shape[0]
    BWT=0
    for i in range(n_experieneces):
        BWT+=R[n_experieneces-1][i]-R[i][i]
    BWT/=n_experieneces-1
    return BWT

In [None]:
def get_ACC(R):
    n_experieneces=R.shape[0]
    ACC=np.sum(R,axis=1)[-1]/n_experieneces
    return ACC

In [20]:
print("Naive BWT:",get_BWT(R_naive))
print("EWC BWT:",get_BWT(R_ewc))
print("MAS BWT:",get_BWT(R_mas))
print("ER BWT:",get_BWT(R_er))
print("A_GEM BWT:",get_BWT(R_agem))

Naive BWT: -0.7057499885559082
EWC BWT: -0.7009999975562096
MAS BWT: -0.3389500081539154
ER BWT: -0.1773499846458435
A_GEM BWT: -0.18290001153945923


In [23]:
print("Naive ACC:",get_ACC(R_naive))
print("EWC ACC:",get_ACC(R_ewc))
print("MAS ACC:",get_ACC(R_mas))
print("ER ACC:",get_ACC(R_er))
print("A_GEM ACC:",get_ACC(R_agem))

Naive ACC: 0.46556665500005084
EWC ACC: 0.4685666412115097
MAS ACC: 0.6736999948819479
ER ACC: 0.8168999950091044
A_GEM ACC: 0.8285333116849264
