# CE-40959: Advanced Machine Learning
## HW5 - Continual Learning (90 points)

In this notebook, you are going to see the `catastrophic forgetting` phenomenon in continual learning scenarios and then alleviate this problem by implementing [Gradient Episodic Memory(GEM)](https://arxiv.org/abs/1706.08840) on the `MNIST` dataset.


Please write your code in specified sections and do not change anything else. If you have a question regarding this homework, please ask it on the Quera.

Also, it is recommended to use Google Colab to do this homework. You can connect to your drive using the code below:

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Import Required libraries

In [2]:
!pip install quadprog

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting quadprog
  Downloading quadprog-0.1.11.tar.gz (121 kB)
[K     |████████████████████████████████| 121 kB 5.1 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: quadprog
  Building wheel for quadprog (PEP 517) ... [?25l[?25hdone
  Created wheel for quadprog: filename=quadprog-0.1.11-cp37-cp37m-linux_x86_64.whl size=290752 sha256=803bc49036b98da72aa385611c273727be59cb748a5fc68677d16f5b52754a29
  Stored in directory: /root/.cache/pip/wheels/4a/4e/d7/41034ea11aeef1266df3cae546116cb6094e955c41ae3e2589
Successfully built quadprog
Installing collected packages: quadprog
Successfully installed quadprog-0.1.11


In [3]:
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torchvision
import random
import torch.nn as nn
import math
import quadprog
from tqdm import tqdm
import pickle
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import torch.utils.data as data

## Learning parameters

In [4]:
num_classes = 10
class_per_task = 2
number_of_data_per_class = 3000
num_tasks = int(num_classes // class_per_task)
batch_size = 10
memory_size_per_task = 10

## Prepare dataset (5 points)

To compare different benchmarks fairly, define all of your dataloaders for each task and save them in an array.

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}

In [6]:
def get_all_dataloaders(num_classes, class_per_task, number_of_data_per_class, num_tasks):
    #################################################################################
    #                  COMPLETE THE FOLLOWING SECTION (2.5 points)                   #
    #################################################################################
    # complete the function to get all dataloaders for all tasks
    train_loader = {}

    classes= np.arange(num_classes)
    np.random.shuffle(classes)
    
    for i in range(num_tasks):
        dataset = torchvision.datasets.MNIST('./dataset/', train=True, download=True, transform = transforms.ToTensor())
        idx=[dataset.targets == a for a in classes[class_per_task*i:class_per_task*(i+1)]]
        indexes = dataset.targets == 1000000 #want all to be False

        for k in range(len(idx)):
          indexes = torch.logical_or(idx[k],indexes)

        dataset.targets = dataset.targets[indexes][:class_per_task*number_of_data_per_class]
        dataset.data = dataset.data[indexes][:class_per_task*number_of_data_per_class]

        train_loader[i] = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size)
        
    return train_loader,classes

    #################################################################################


def get_testloader(classes,num_task,class_per_task,number_of_data_per_class):
    #################################################################################
    #                  COMPLETE THE FOLLOWING SECTION (2.5 points)                   #
    #################################################################################
    # complete the function to get MNIST test dataloader

    test_loader = {}

    for i in range(num_tasks):
        dataset = torchvision.datasets.MNIST('./dataset/', train=False, download=True, transform = transforms.ToTensor())
        idx=[dataset.targets == a for a in classes[class_per_task*i:class_per_task*(i+1)]]
        indexes = dataset.targets == 1000000

        for k in range(len(idx)):
          indexes = torch.logical_or(idx[k],indexes)

        dataset.targets = dataset.targets[indexes]
        dataset.data = dataset.data[indexes]

        test_loader[i] = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size)
        
    return test_loader
    #################################################################################

In [7]:
train_loader,classes = get_all_dataloaders(num_classes,class_per_task,number_of_data_per_class,num_tasks)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw



In [8]:
test_loader = get_testloader(classes,num_tasks,class_per_task,number_of_data_per_class)

## Network (5 points)

In [9]:
# define a 3 layer fc with relu activation functions between layers
# your fc layers dimensions are as follows:
# 784, 150, 150, 10


#################################################################################
#                  COMPLETE THE FOLLOWING SECTION (5 points)                   #
#################################################################################
# define above mentioned model and needed variables

class FC(torch.nn.Module):
    def __init__(self):
        super(FC, self).__init__()   
        self.fc1 = torch.nn.Linear(784, 150)
        self.fc2 = torch.nn.Linear(150, 150)
        self.fc3 = torch.nn.Linear(150, 10)

    def forward(self,inp):
        x = torch.flatten(inp,1,-1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        out = self.fc3(x)
        return out

#################################################################################

## Naive Learning (20 points)

In this section, you will learn a network in its natural state, without considering any strategy for learning it continually. You will see that learning data in a such fashion causes a phenomenon called catastrophic forgetting.

As `GEM` is a task-incremental method, like the paper, evaluate your trained model performance for each task and then report the average accuracy of tasks. In addition to Accuracy metric, report the `backward transfer` and `forward transfer` metric based on the defination in the paper.

In [10]:
def train(model, trainloader, optimizer, criterion,task_id):
    model.train()
    print('Training on task',task_id)
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader[task_id]), total=len(trainloader[task_id])):
        counter += 1
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(image)
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == labels).sum().item()
        loss.backward()
        optimizer.step()
    
    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(trainloader[task_id].dataset))
    return epoch_loss, epoch_acc

def validate(model, testloaders, criterion,task_id,R):
    model.eval()
    print('Validation')

    for t,(k,testloader) in enumerate(testloaders.items()):
        valid_running_loss = 0.0
        valid_running_correct = 0
        counter = 0
        if t<=task_id:
          with torch.no_grad():
              for i, data in enumerate(testloader):
                  counter += 1
                  image, labels = data
                  image = image.to(device)
                  labels = labels.to(device)
                  outputs = model(image)
                  loss = criterion(outputs, labels)
                  valid_running_loss += loss.item()
                  _, preds = torch.max(outputs.data, 1)
                  valid_running_correct += (preds == labels).sum().item()
              if task_id == t:
                  epoch_loss = valid_running_loss / counter
              R[task_id][t] = 100. * (valid_running_correct / len(testloader.dataset))
    return epoch_loss, R



In [11]:
def save_model(epoch, model, optimizer, criterion,checkpoint_path,stage):
    """
    Function to save the trained model 
    """
    print(f"Saving model...")
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion,
                }, checkpoint_path+'model_stage'+str(stage)+'.pth')
    
def compute_metrics(R,task_id,baseline):
    diag = R.diagonal()
    acc = np.mean(R[task_id])
    BWT=0
    FWT=0
    if task_id != num_tasks-1:
        BWT = np.mean(R[task_id-1]-diag[task_id-1])
    if task_id != 0:
        FWT = np.mean([R[i-1][i] for i in range(task_id-1)])
    return acc,BWT,FWT

def base_line(model,testloaders):
    baseline=[]
    valid_running_loss = 0.0
    valid_running_correct = 0
    for t,(k,testloader) in enumerate(testloaders.items()):
          model.eval()
          with torch.no_grad():
              for i, data in enumerate(testloader):
                  image, labels = data
                  image = image.to(device)
                  labels = labels.to(device)
                  outputs = model(image)
                  loss = criterion(outputs, labels)
                  valid_running_loss += loss.item()
                  _, preds = torch.max(outputs.data, 1)
                  valid_running_correct += (preds == labels).sum().item()
              baseline.append(100. * (valid_running_correct / len(testloader.dataset)))

    return baseline

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = FC()
criterion = torch.nn.CrossEntropyLoss()
lr = 1e-3
optimizer  = torch.optim.SGD(model.parameters(), lr=lr)
epochs = 1
checkpoint_path = "/content/drive/MyDrive/MSC1400_1/AML/HW6/Practical/Q6/checkpoints"

#################################################################################
#                  COMPLETE THE FOLLOWING SECTION (20 points)                   #
#################################################################################
# complete code for sequentially training and then
# evaluate your model with test data
R = np.empty((num_tasks,num_tasks))

for i in range(num_tasks):

    train_loss, valid_loss = np.empty((5,3),dtype=np.float64),np.empty((5,3),dtype=np.float64)
    train_acc = np.empty((5,3),dtype=np.float64)

    for epoch in range(epochs):
        print(f"[INFO]: Epoch {epoch+1} of {epochs} --- Task {i}")

        train_epoch_loss, train_epoch_acc = train(model, train_loader, 
                                                optimizer, criterion,i)
        valid_epoch_loss, R = validate(model, test_loader,  
                                                    criterion,i,R)
        train_loss[i][0] = train_epoch_loss
        valid_loss[i][1] = valid_epoch_loss
        train_acc[i][2] = train_epoch_acc
        baseline = base_line(model,test_loader)
        print(R)
        acc,bwt,fwt = compute_metrics(R,i,baseline)
        print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")

        print(f"Acc: {acc:.3f}, BWT: {bwt:.3f}, FWT: {fwt: .3f}")


        if epoch == 2:

            save_model(epoch, model, optimizer, criterion,checkpoint_path,i)
            with open('{}train_losses.pickle'.format(checkpoint_path),'wb') as f:
                    pickle.dump(train_loss, f)
            with open('{}valid_losses.pickle'.format(checkpoint_path),'wb') as f:
                    pickle.dump(valid_loss, f)
            with open('{}train_accuracies.pickle'.format(checkpoint_path),'wb') as f:
                    pickle.dump(train_acc, f)
            with open('{}R.pickle'.format(checkpoint_path),'wb') as f:
                    pickle.dump(R, f)

        print('-'*50)
    print('TRAINING COMPLETE')

#################################################################################

[INFO]: Epoch 1 of 1 --- Task 0
Training on task 0


100%|██████████| 600/600 [00:01<00:00, 316.89it/s]


Validation
[[93.64583333  0.          0.          0.          0.        ]
 [ 0.         98.66939611  0.          0.          0.        ]
 [ 0.          0.         98.69477912  0.          0.        ]
 [ 0.          0.          0.         98.98322318  0.        ]
 [ 0.          0.          0.          0.         98.01568989]]
Training loss: 2.026, training acc: 58.167
Acc: 18.729, BWT: -78.413, FWT:  0.000
--------------------------------------------------
TRAINING COMPLETE
[INFO]: Epoch 1 of 1 --- Task 1
Training on task 1


100%|██████████| 600/600 [00:01<00:00, 368.75it/s]


Validation


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


[[93.64583333  0.          0.          0.          0.        ]
 [ 0.         96.67349028  0.          0.          0.        ]
 [ 0.          0.         98.69477912  0.          0.        ]
 [ 0.          0.          0.         98.98322318  0.        ]
 [ 0.          0.          0.          0.         98.01568989]]
Training loss: 1.578, training acc: 52.733
Acc: 19.335, BWT: -74.917, FWT:  nan
--------------------------------------------------
TRAINING COMPLETE
[INFO]: Epoch 1 of 1 --- Task 2
Training on task 2


100%|██████████| 600/600 [00:01<00:00, 376.89it/s]


Validation
[[93.64583333  0.          0.          0.          0.        ]
 [ 0.         96.67349028  0.          0.          0.        ]
 [ 0.          0.         93.47389558  0.          0.        ]
 [ 0.          0.          0.         98.98322318  0.        ]
 [ 0.          0.          0.          0.         98.01568989]]
Training loss: 1.614, training acc: 37.783
Acc: 18.695, BWT: -77.339, FWT:  0.000
--------------------------------------------------
TRAINING COMPLETE
[INFO]: Epoch 1 of 1 --- Task 3
Training on task 3


100%|██████████| 600/600 [00:01<00:00, 398.49it/s]


Validation
[[93.64583333  0.          0.          0.          0.        ]
 [ 0.         96.67349028  0.          0.          0.        ]
 [ 0.          0.         93.47389558  0.          0.        ]
 [ 0.          0.          0.         98.11896289  0.        ]
 [ 0.          0.          0.          0.         98.01568989]]
Training loss: 1.771, training acc: 55.117
Acc: 19.624, BWT: -74.779, FWT:  0.000
--------------------------------------------------
TRAINING COMPLETE
[INFO]: Epoch 1 of 1 --- Task 4
Training on task 4


100%|██████████| 600/600 [00:01<00:00, 390.03it/s]


Validation
[[93.64583333  0.          0.          0.          0.        ]
 [ 0.         96.67349028  0.          0.          0.        ]
 [ 0.          0.         93.47389558  0.          0.        ]
 [ 0.          0.          0.         98.11896289  0.        ]
 [ 0.          0.          0.          0.         84.26395939]]
Training loss: 1.894, training acc: 30.767
Acc: 16.853, BWT: 0.000, FWT:  0.000
--------------------------------------------------
TRAINING COMPLETE


## Continually Learning using GEM (50 points)

In this section, you will complete the codes for the GEM method using the beforementioned parameters. Read the procedure explained in the paper. We pre-defined some functions for you. Complete them and use them in training.

In [12]:
# define your main class for continually learning with GEM
# define all needed variables and functions, all inside this class

class GEM(torch.nn.Module):
    def __init__(self,  n_inputs,n_outputs,n_tasks):
        super(GEM, self).__init__()   
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (5 points)                   #
        #################################################################################
        # define above mentioned model and needed variables
        self.margin = 0
        self.net = FC()

        self.ce = nn.CrossEntropyLoss()
        self.n_outputs = n_outputs

        self.opt = optim.SGD(self.parameters(), 1e-3)

        self.n_memories = 5
        self.gpu = 'yes'

        # allocate episodic memory
        self.memory_data = torch.FloatTensor(
            n_tasks, self.n_memories, n_inputs)
        self.memory_labs = torch.LongTensor(n_tasks, self.n_memories)
        
        self.memory_data = self.memory_data.cuda()
        self.memory_labs = self.memory_labs.cuda()

        # allocate temporary synaptic memory
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        self.grads = torch.Tensor(sum(self.grad_dims), n_tasks)
        self.grads = self.grads.cuda()

        # allocate counters
        self.observed_tasks = []
        self.old_task = -1
        self.mem_cnt = 0
        self.nc_per_task = n_outputs
        #################################################################################


    def calculate_past_classes_gradients(self,x,t,y):
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (10 points)                   #
        #################################################################################
         if len(self.observed_tasks) > 1:
            for tt in range(len(self.observed_tasks) - 1):
                self.zero_grad()
                # fwd/bwd on the examples in the memory
                past_task = self.observed_tasks[tt]
                offset1, offset2 = 0,self.nc_per_task
                ptloss = self.ce(
                    self.forward(
                        self.memory_data[past_task],
                        past_task)[:, offset1: offset2],
                    self.memory_labs[past_task] - offset1)
                ptloss.backward()
                self.grads[:, past_task].fill_(0.0)
                cnt = 0
                for param in self.parameters():
                  if param.grad is not None:
                      beg = 0 if cnt == 0 else sum(self.grad_dims[:cnt])
                      en = sum(self.grad_dims[:cnt + 1])
                      self.grads[beg: en, past_task].copy_(param.grad.data.view(-1))
                  cnt += 1


        #################################################################################

    def calculate_current_task_gradients(self,x,t,y):
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (5 points)                    #
        #################################################################################
        self.zero_grad()

        offset1, offset2 = 0, self.nc_per_task
        loss = self.ce(self.forward(x, t)[:, offset1: offset2], y - offset1)
        loss.backward()

        #################################################################################

    def project_past_Classes_gradients(self,x,t,y):
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (15 points)                   #
        #################################################################################
        
        if len(self.observed_tasks) > 1:
            # copy gradient
            store_grad(self.parameters, self.grads, self.grad_dims, t)
            indx = torch.cuda.LongTensor(self.observed_tasks[:-1]) if self.gpu \
                else torch.LongTensor(self.observed_tasks[:-1])
            dotp = torch.mm(self.grads[:, t].unsqueeze(0),
                            self.grads.index_select(1, indx))
            if (dotp < 0).sum() != 0:
                project2cone2(self.grads[:, t].unsqueeze(1),
                              self.grads.index_select(1, indx), self.margin)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads[:, t],
                               self.grad_dims)
        self.opt.step()
        #################################################################################


    def update_memory(self,x,t,y):
        #################################################################################
        #                  COMPLETE THE FOLLOWING SECTION (5 points)                    #
        #################################################################################
        if t != self.old_task:
            self.observed_tasks.append(t)
            self.old_task = t

        # Update ring buffer storing examples from current task
        bsz = y.data.size(0)
        endcnt = min(self.mem_cnt + bsz, self.n_memories)
        effbsz = endcnt - self.mem_cnt
        self.memory_data[t, self.mem_cnt: endcnt].copy_(
            x.data[: effbsz])
        if bsz == 1:
            self.memory_labs[t, self.mem_cnt] = y.data[0]
        else:
            self.memory_labs[t, self.mem_cnt: endcnt].copy_(
                y.data[: effbsz])
        self.mem_cnt += effbsz
        if self.mem_cnt == self.n_memories:
            self.mem_cnt = 0
        #################################################################################

    def forward(self, x):
        output = self.net(x)
        return output

In [16]:

def store_grad(pp, grads, grad_dims, tid):
    """
        This stores parameter gradients of past tasks.
        pp: parameters
        grads: gradients
        grad_dims: list with number of parameters per layers
        tid: task id
    """
    # store the gradients
    grads[:, tid].fill_(0.0)
    cnt = 0
    for param in pp():
        if param.grad is not None:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            grads[beg: en, tid].copy_(param.grad.data.view(-1))
        cnt += 1


def overwrite_grad(pp, newgrad, grad_dims):
    cnt = 0
    for param in pp():
        if param.grad is not None:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            this_grad = newgrad[beg: en].contiguous().view(
                param.grad.data.size())
            param.grad.data.copy_(this_grad)
        cnt += 1


def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    memories_np = memories.cpu().t().double().numpy()
    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
    t = memories_np.shape[0]
    P = np.dot(memories_np, memories_np.transpose())
    P = 0.5 * (P + P.transpose()) + np.eye(t) * eps
    q = np.dot(memories_np, gradient_np) * -1
    G = np.eye(t)
    h = np.zeros(t) + margin
    v = quadprog.solve_qp(P, q, G, h)[0]
    x = np.dot(v, memories_np) + gradient_np
    gradient.copy_(torch.Tensor(x).view(-1, 1))
