In [1]:
from avalanche.benchmarks import PermutedMNIST, RotatedMNIST
from avalanche.models import SimpleCNN
import avalanche
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torch.utils.data import DataLoader
from Approaches.DistillationApproach import LwF
import numpy as np
from copy import deepcopy
from Approaches.MemoryApproach import ER,A_GEM
from Approaches.RegularizationApproach import MAS,EWC
from Approaches.Approach import Naive

In [2]:
torch.manual_seed(100)

<torch._C.Generator at 0x27342e3e490>

In [3]:
class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = SimpleCNN(num_classes=num_classes)
        self.model.features[0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    def forward(self, x):
        return self.model(x)

In [4]:
mas_model=CNN(num_classes=10)
ewc_model = CNN(num_classes=10)
naive_model = CNN(num_classes=10)
er_model = CNN(num_classes=10)
agem_model = CNN(num_classes=10)
lwf_model = CNN(num_classes=10)

permuted_mnist = PermutedMNIST(n_experiences=3,return_task_id=True)
rotated_mnist = RotatedMNIST(n_experiences=3,return_task_id=True,rotations_list=[0,90,180])


In [5]:
device=torch.device("cuda:0")
epochs = 5

mas=MAS(mas_model, lambda_reg=1, device=device)
ewc=EWC(ewc_model, lambda_reg=100, device=device)
er = ER(er_model, sampling_freq=0.3, n_samples=100, max_size=1000, device=device)
naive = Naive(naive_model, device=device)
agem = A_GEM(agem_model, n_samples=100, max_size=1000, device=device)
lwf = LwF(lwf_model, lambda_distill=2.5, device=device)

In [6]:
def train_one_epoch(approach, dataloader, optimizer, device):
    approach.train()
    total_loss = 0


    for images, labels, _ in dataloader:
        images, labels = images.to(device), labels.to(device) 
        optimizer.zero_grad() 
        
        loss = approach(images,labels) 
    
        optimizer.step()  
    
        total_loss += loss.item()
    
        
    return total_loss / len(dataloader) 


In [7]:
def evaluate_stream(approach, R, row, test_stream):
    approach.eval()
    for test_exp in test_stream:
        correct = 0
        total = 0
        col = test_exp.current_experience
        test_loader = DataLoader(test_exp.dataset,batch_size=64,shuffle=False)
        for X, y, _ in test_loader:
            X, y = X.to(device), y.to(device)
            preds=approach(X)
            correct += torch.sum(preds.argmax(dim=1)==y)
        R[row][col] = correct/len(test_loader.dataset)
    return R


In [8]:
def train_stream(benchmark, approach, epochs):
    optimizer = optim.Adam(approach.model.parameters(),lr=0.001)
    train_stream = benchmark.train_stream
    test_stream = benchmark.test_stream
    approach.set_stream(train_stream, test_stream)
    n_experieneces=benchmark.n_experiences
    R=np.zeros((n_experieneces,n_experieneces))
    for train_exp in train_stream:
        train_loader = DataLoader(train_exp.dataset,batch_size=64,shuffle=True)
        
        for epoch in range(epochs):
            train_loss = train_one_epoch(approach=approach, dataloader=train_loader, optimizer=optimizer, device=device)
            
        approach.adapt()
        R = evaluate_stream(approach=approach, R=R, row=train_exp._current_experience, test_stream=benchmark.test_stream)
    return approach, R

In [15]:
lwf, R_lwf = train_stream(benchmark=permuted_mnist,approach=lwf,epochs=1)



In [9]:
er, R_er = train_stream(benchmark=permuted_mnist,approach=er,epochs=epochs)
agem, R_agem = train_stream(benchmark=permuted_mnist,approach=agem,epochs=epochs)
mas, R_mas = train_stream(benchmark=permuted_mnist,approach=mas,epochs=epochs)
ewc, R_ewc = train_stream(benchmark=permuted_mnist,approach=ewc,epochs=epochs)
naive, R_naive = train_stream(benchmark=permuted_mnist,approach=naive,epochs=epochs)

  if grad.T @ ref_grad >= 0:


In [16]:
R_lwf

array([[0.91099995, 0.2969    , 0.76709998],
       [0.76379997, 0.91929996, 0.66219997],
       [0.62369996, 0.63999999, 0.94139999]])

In [10]:
R_agem

array([[0.93479997, 0.123     , 0.1001    ],
       [0.69349998, 0.93979996, 0.1085    ],
       [0.54210001, 0.6656    , 0.94389999]])

In [11]:
R_naive

array([[0.93269998, 0.1558    , 0.15789999],
       [0.2791    , 0.93769997, 0.1162    ],
       [0.11319999, 0.27779999, 0.93179995]])

In [12]:
R_er

array([[0.93869996, 0.1293    , 0.1108    ],
       [0.79329997, 0.93149996, 0.1133    ],
       [0.71669996, 0.8136    , 0.93759996]])

In [20]:
R_ewc

array([[0.9393    , 0.1223    , 0.3585    ],
       [0.40789998, 0.94220001, 0.2367    ],
       [0.2674    , 0.37689999, 0.94379997]])

In [14]:
R_mas

array([[0.9368    , 0.1064    , 0.1029    ],
       [0.81149995, 0.86799997, 0.11939999],
       [0.61849999, 0.60679996, 0.85689998]])

In [17]:
def get_BWT(R):
    n_experieneces=R.shape[0]
    BWT=0
    for i in range(n_experieneces):
        BWT+=R[n_experieneces-1][i]-R[i][i]
    BWT/=n_experieneces-1
    return BWT

In [18]:
def get_ACC(R):
    n_experieneces=R.shape[0]
    ACC=np.sum(R,axis=1)[-1]/n_experieneces
    return ACC

In [20]:
print("LwF BWT:",get_BWT(R_lwf))

LwF BWT: -0.28329998254776


In [21]:
print("LwF BWT:",get_ACC(R_lwf))

LwF BWT: 0.7350333134333292


In [None]:
print("Naive BWT:",get_BWT(R_naive))
print("EWC BWT:",get_BWT(R_ewc))
print("MAS BWT:",get_BWT(R_mas))
print("ER BWT:",get_BWT(R_er))
print("A_GEM BWT:",get_BWT(R_agem))
print("LwF BWT:",get_BWT(lwf))

Naive BWT: -0.5343999862670898
EWC BWT: -0.5350499898195267
MAS BWT: -0.2999999672174454
ER BWT: -0.13374999165534973
A_GEM BWT: -0.3622500151395798


In [23]:
print("Naive ACC:",get_ACC(R_naive))
print("EWC ACC:",get_ACC(R_ewc))
print("MAS ACC:",get_ACC(R_mas))
print("ER ACC:",get_ACC(R_er))
print("A_GEM ACC:",get_ACC(R_agem))

Naive ACC: 0.46556665500005084
EWC ACC: 0.4685666412115097
MAS ACC: 0.6736999948819479
ER ACC: 0.8168999950091044
A_GEM ACC: 0.8285333116849264
