In [1]:
from avalanche.benchmarks import PermutedMNIST, RotatedMNIST
from avalanche.models import SimpleCNN
import avalanche
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torch.utils.data import DataLoader
import numpy as np
from copy import deepcopy
from Approaches.MemoryApproach import ER
from Approaches.RegularizationApproach import MAS,EWC
from Approaches.Approach import Naive

In [2]:
torch.manual_seed(100)

<torch._C.Generator at 0x191fe8ba490>

In [3]:
class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = SimpleCNN(num_classes=num_classes)
        self.model.features[0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    def forward(self, x):
        return self.model(x)

In [4]:
mas_model=CNN(num_classes=10)
ewc_model = CNN(num_classes=10)
naive_model = CNN(num_classes=10)
er_model = CNN(num_classes=10)

permuted_mnist = PermutedMNIST(n_experiences=3,return_task_id=True)
rotated_mnist = RotatedMNIST(n_experiences=3,return_task_id=True,rotations_list=[0,90,180])


In [5]:
mas=MAS(mas_model,lambda_reg=1)
ewc=EWC(ewc_model,lambda_reg=10)
er = ER(er_model,sampling_freq=0.3,n_samples=100,max_size=4000)
naive = Naive(naive_model)
mas_optimizer = optim.Adam(mas.model.parameters(),lr=0.001)
ewc_optimizer = optim.Adam(ewc.model.parameters(),lr=0.001)
er_optimizer = optim.Adam(er.model.parameters(),lr=0.001)
naive_optimizer = optim.Adam(naive.model.parameters(),lr=0.001)


device=torch.device("cuda:0")
epochs = 5

In [6]:
import matplotlib.pyplot as plt
def train_one_epoch(approach, dataloader, optimizer, device):
    approach.train()
    total_loss = 0
    losses=[]
    penalties=[]

    for images, labels, _ in dataloader:
        images, labels = images.to(device), labels.to(device) 

        
        loss = approach(images,labels) 
        

        optimizer.zero_grad()  # Reset gradients
        loss.backward()  # Backpropagation (compute gradients)
        optimizer.step()  # Update model parameters
        # losses.append(loss.item())
        # penalties.append(approach.penalty().item())

        

        # Collect the loss
        total_loss += loss.item()
    # plt.plot(losses)
    # plt.show()
    # plt.plot(penalties)
    # plt.show()
        
    return total_loss / len(dataloader)  # Return average loss


In [7]:
def evaluate_stream(approach, R, row, test_stream):
    approach.eval()
    for test_exp in test_stream:
        correct = 0
        total = 0
        col = test_exp.current_experience
        test_loader = DataLoader(test_exp.dataset,batch_size=64,shuffle=False)
        for X, y, _ in test_loader:
            X, y = X.to(device), y.to(device)
            preds=approach(X)
            correct += torch.sum(preds.argmax(dim=1)==y)
        R[row][col] = correct/len(test_loader.dataset)
    return R


In [8]:
def train_stream(benchmark, approach, optimizer, epochs):
    n_experieneces=benchmark.n_experiences
    R=np.zeros((n_experieneces,n_experieneces))
    for train_exp,test_exp in zip(benchmark.train_stream,benchmark.test_stream):
        train_loader = DataLoader(train_exp.dataset,batch_size=64,shuffle=True)
        test_loader = DataLoader(test_exp.dataset,batch_size=64,shuffle=False)
        
        for epoch in range(epochs):
            train_loss = train_one_epoch(approach=approach, dataloader=train_loader, optimizer=optimizer, device=device)
            
            
        loaders = {"train":train_loader,"test":test_loader}
        approach.adapt(loaders)
        R = evaluate_stream(approach=approach, R=R, row=train_exp._current_experience, test_stream=benchmark.test_stream)
    return approach, R

In [9]:
er, R_er=train_stream(benchmark=permuted_mnist,approach=er,optimizer=er_optimizer,epochs=5)
mas, R_mas = train_stream(benchmark=permuted_mnist,approach=mas,optimizer=mas_optimizer,epochs=5)
ewc, R_ewc = train_stream(benchmark=permuted_mnist,approach=ewc,optimizer=ewc_optimizer,epochs=5)
naive, R_naive = train_stream(benchmark=permuted_mnist,approach=naive,optimizer=naive_optimizer,epochs=5)



Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples
Importance estimation completed using 10000 samples


In [10]:
R_naive

array([[0.93469995, 0.121     , 0.1006    ],
       [0.27129999, 0.93359995, 0.1087    ],
       [0.13159999, 0.24699999, 0.93289995]])

In [11]:
R_er

array([[0.93759996, 0.121     , 0.0869    ],
       [0.77429998, 0.93349999, 0.1131    ],
       [0.7058    , 0.81439996, 0.93659997]])

In [12]:
R_ewc

array([[0.93529999, 0.13159999, 0.1069    ],
       [0.28979999, 0.93689996, 0.1043    ],
       [0.15789999, 0.31709999, 0.9368    ]])

In [13]:
R_mas

array([[0.93399996, 0.1162    , 0.08899999],
       [0.67429996, 0.87459999, 0.114     ],
       [0.46739998, 0.60939997, 0.86849999]])

In [15]:
def get_BWT(R):
    n_experieneces=R.shape[0]
    BWT=0
    for i in range(n_experieneces):
        BWT+=R[n_experieneces-1][i]-R[i][i]
    BWT/=n_experieneces-1
    return BWT

In [16]:
def  get_ACC(R):
    n_experieneces=R.shape[0]
    ACC=np.sum(R,axis=1)[-1]/n_experieneces
    return ACC

In [20]:
print("Naive BWT:",get_BWT(R_naive))
print("EWC BWT:",get_BWT(R_ewc))
print("MAS BWT:",get_BWT(R_mas))
print("ER BWT:",get_BWT(R_er))

Naive BWT: -0.7448499575257301
EWC BWT: -0.6985999867320061
MAS BWT: -0.365899994969368
ER BWT: -0.17544999718666077


In [21]:
print("Naive ACC:",get_ACC(R_naive))
print("EWC ACC:",get_ACC(R_ewc))
print("MAS ACC:",get_ACC(R_mas))
print("ER ACC:",get_ACC(R_er))

Naive ACC: 0.43716664612293243
EWC ACC: 0.4705999940633774
MAS ACC: 0.6484333177407583
ER ACC: 0.8189333081245422


In [18]:
batch_X,batch_y,t=next(iter(train_loader))
batch_X=batch_X.to(device)
batch_y=batch_y.to(device)
yhat=mas.model(batch_X)
criterion(yhat,batch_y)

NameError: name 'train_loader' is not defined