In [1]:
from math import floor
import numpy as np

from torchvision import datasets
from torchvision.transforms import ToTensor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# training MNIST data
train_data = datasets.MNIST(root='data', train=True, download=True, transform=ToTensor())

# testing MNIST data
test_data = datasets.MNIST(root='data', train=False, download=True, transform=ToTensor())

In [35]:
# DEBUG CELLS
print(type(train_data.data))
print(train_data.data.shape)
print(train_data.targets.shape)

<class 'torch.Tensor'>
torch.Size([60000, 28, 28])
torch.Size([60000])


In [3]:
# define the network structure
class Network(nn.Module):
    def __init__(self, input_shape):
        super(Network, self).__init__()
        self.fc1 = nn.Linear(input_shape, 500)
        self.fc2 = nn.Linear(500, 300)
        self.output = nn.Linear(300, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.output(x)
        return x

In [4]:
input_shape = 784

# training hyperparameters
n_epoch = 2
learning_rate = 0.001
minibatch_sz = 64

In [20]:
# create the network, optimizer and define the loss function
network = Network(input_shape)
optimizer = optim.SGD(network.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

##### We will perfom task-wise training. A single task comprises of two classes from the MNIST dataset.

In [10]:
task1 = [0, 1]
task2 = [2, 3]
task3 = [4, 5]
task4 = [6, 7]
task5 = [8, 9]

##### Separate training and testing samples from each task. This is easier to work with.

In [11]:
task1_tr_samples = torch.where(torch.bitwise_or(train_data.targets == task1[0], train_data.targets == task1[1]) == 1)[0]

task2_tr_samples = torch.where(torch.bitwise_or(train_data.targets == task2[0], train_data.targets == task2[1]))[0]

task3_tr_samples = torch.where(torch.bitwise_or(train_data.targets == task3[0], train_data.targets == task3[1]) == 1)[0]

task4_tr_samples = torch.where(torch.bitwise_or(train_data.targets == task4[0], train_data.targets == task4[1]))[0]

task5_tr_samples = torch.where(torch.bitwise_or(train_data.targets == task5[0], train_data.targets == task5[1]))[0]

In [12]:
task1_ts_samples = torch.where(torch.bitwise_or(test_data.targets == task1[0], test_data.targets == task1[1]) == 1)[0]

task2_ts_samples = torch.where(torch.bitwise_or(test_data.targets == task2[0], test_data.targets == task2[1]))[0]

task3_ts_samples = torch.where(torch.bitwise_or(test_data.targets == task3[0], test_data.targets == task3[1]) == 1)[0]

task4_ts_samples = torch.where(torch.bitwise_or(test_data.targets == task4[0], test_data.targets == task4[1]))[0]

task5_ts_samples = torch.where(torch.bitwise_or(test_data.targets == task5[0], test_data.targets == task5[1]))[0]

### **Question 1**: The purpose of this question is to demonstrate the problem of catastrophic forgetting. For this purpose, we will train a single network on two different tasks in a sequence. After training evaluate the performance of the trained network on both tasks. What do you observe?

In [21]:
# train on task 1
for e in range(n_epoch):
    n_batch = floor(task1_tr_samples.shape[0] / minibatch_sz)
    
    for b in range(n_batch):
        x_batch = train_data.data[task1_tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
        y_batch = train_data.targets[task1_tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

        # flatten image before presenting to the network and normalize intensities to the range [0, 1]
        x_batch = torch.flatten(x_batch / 255, start_dim=1)

        # convert label to one hot
        y_batch = F.one_hot(y_batch).float()

        y_hat_batch = network(x_batch)
        loss = criterion(y_hat_batch, y_batch)
        loss.backward()
        optimizer.step()

        print(f'Epoch {e}: {loss.item()}')

Epoch 0: 0.7139849662780762
Epoch 0: 0.7199063301086426
Epoch 0: 0.7150048613548279
Epoch 0: 0.7135699987411499
Epoch 0: 0.716407060623169
Epoch 0: 0.7134637236595154
Epoch 0: 0.7109566330909729
Epoch 0: 0.7090693116188049
Epoch 0: 0.7071647047996521
Epoch 0: 0.7027407288551331
Epoch 0: 0.7034420967102051
Epoch 0: 0.6973006129264832
Epoch 0: 0.700343668460846
Epoch 0: 0.6927923560142517
Epoch 0: 0.6857284903526306
Epoch 0: 0.6867830157279968
Epoch 0: 0.6825870871543884
Epoch 0: 0.6698495149612427
Epoch 0: 0.6691404581069946
Epoch 0: 0.6729548573493958
Epoch 0: 0.6643473505973816
Epoch 0: 0.6586947441101074
Epoch 0: 0.6530790328979492
Epoch 0: 0.6563106179237366
Epoch 0: 0.6547855734825134
Epoch 0: 0.6528515815734863
Epoch 0: 0.6237528920173645
Epoch 0: 0.6257097721099854
Epoch 0: 0.6341657042503357
Epoch 0: 0.6157877445220947
Epoch 0: 0.6068785786628723
Epoch 0: 0.6189902424812317
Epoch 0: 0.5914162993431091
Epoch 0: 0.5930476784706116
Epoch 0: 0.5774946212768555
Epoch 0: 0.57726716995

In [22]:
# test on Task 1
n_batch = floor(task1_ts_samples.shape[0] / minibatch_sz)

n_correct = 0    
for b in range(n_batch):
    x_batch = test_data.data[task1_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
    y_batch = test_data.targets[task1_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

    # flatten image before presenting to the network and normalize intensities to the range [0, 1]
    x_batch = torch.flatten(x_batch / 255, start_dim=1)

    y_hat_batch = network(x_batch)
    _, prediction = torch.max(y_hat_batch, 1)
    n_correct += (prediction == y_batch).sum().item()

print(f'Accuracy = {(n_correct * 100) / task1_ts_samples.shape[0]}')

Accuracy = 99.81087470449172


In [23]:
# train on task 2
for e in range(n_epoch):
    n_batch = floor(task2_tr_samples.shape[0] / minibatch_sz)
    
    for b in range(n_batch):
        x_batch = train_data.data[task2_tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
        y_batch = train_data.targets[task2_tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

        # flatten image before presenting to the network and normalize intensities to the range [0, 1]
        x_batch = torch.flatten(x_batch / 255, start_dim=1)

        # convert label to one hot
        y_batch = y_batch % 2
        y_batch = F.one_hot(y_batch).float()

        y_hat_batch = network(x_batch)
        loss = criterion(y_hat_batch, y_batch)
        loss.backward()
        optimizer.step()

        print(f'Epoch {e}: {loss.item()}')

Epoch 0: 69.4424057006836
Epoch 0: 77.93883514404297
Epoch 0: 95.78223419189453
Epoch 0: 77.31819152832031
Epoch 0: 82.10861206054688
Epoch 0: 97.94888305664062
Epoch 0: 47.53200912475586
Epoch 0: 41.89323425292969
Epoch 0: 37.094451904296875
Epoch 0: 24.8565616607666
Epoch 0: 10.43249797821045
Epoch 0: 17.543197631835938
Epoch 0: 30.247007369995117
Epoch 0: 20.887351989746094
Epoch 0: 14.361005783081055
Epoch 0: 5.889129638671875
Epoch 0: 3.509218692779541
Epoch 0: 1.243325114250183
Epoch 0: 3.5031096935272217
Epoch 0: 13.644874572753906
Epoch 0: 18.154708862304688
Epoch 0: 17.467178344726562
Epoch 0: 17.80025863647461
Epoch 0: 5.460891246795654
Epoch 0: 6.295725345611572
Epoch 0: 7.896490097045898
Epoch 0: 5.619497299194336
Epoch 0: 23.22791290283203
Epoch 0: 14.744108200073242
Epoch 0: 29.070405960083008
Epoch 0: 24.24897575378418
Epoch 0: 9.862651824951172
Epoch 0: 4.7876176834106445
Epoch 0: 1.9906792640686035
Epoch 0: 2.841184139251709
Epoch 0: 3.460263729095459
Epoch 0: 11.09870

In [24]:
# test on Task 2
n_batch = floor(task2_ts_samples.shape[0] / minibatch_sz)

n_correct = 0    
for b in range(n_batch):
    x_batch = test_data.data[task2_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
    y_batch = test_data.targets[task2_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

    # flatten image before presenting to the network and normalize intensities to the range [0, 1]
    x_batch = torch.flatten(x_batch / 255, start_dim=1)

    y_hat_batch = network(x_batch)
    _, prediction = torch.max(y_hat_batch, 1)
    n_correct += (prediction == (y_batch % 2)).sum().item()

print(f'Accuracy = {(n_correct * 100) / task1_ts_samples.shape[0]}')

Accuracy = 89.17257683215131


In [25]:
# test on Task 1
n_batch = floor(task1_ts_samples.shape[0] / minibatch_sz)

n_correct = 0    
for b in range(n_batch):
    x_batch = test_data.data[task1_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
    y_batch = test_data.targets[task1_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

    # flatten image before presenting to the network and normalize intensities to the range [0, 1]
    x_batch = torch.flatten(x_batch / 255, start_dim=1)

    y_hat_batch = network(x_batch)
    _, prediction = torch.max(y_hat_batch, 1)
    n_correct += (prediction == y_batch).sum().item()

print(f'Accuracy = {(n_correct * 100) / task1_ts_samples.shape[0]}')

Accuracy = 57.4468085106383


### **Question 2**: The purpose of this question is to study the effect of replay on catatophic forgetting. In this question also, we will train the network on two tasks in a sequence? When we train the network on the second task, we will also use some samples from the first task for replay. TO keep things simple, select a random proportaion (say 50%) of samples from the first task for replay. After training evaluate the performance of the trained network on both tasks. What do you observe?

In [5]:
# Save some samples from previous tasks for replay
prop_saved = 0.5 # proportion of samples saved from a task for replay

In [6]:
# create a new network
network = Network(input_shape)

In [15]:
# train on task 1
for e in range(n_epoch):
    n_batch = floor(task1_tr_samples.shape[0] / minibatch_sz)
    
    for b in range(n_batch):
        x_batch = train_data.data[task1_tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
        y_batch = train_data.targets[task1_tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

        # flatten image before presenting to the network and normalize intensities to the range [0, 1]
        x_batch = torch.flatten(x_batch / 255, start_dim=1)

        # convert label to one hot
        y_batch = F.one_hot(y_batch).float()

        y_hat_batch = network(x_batch)
        loss = criterion(y_hat_batch, y_batch)
        loss.backward()
        optimizer.step()

        print(f'Epoch {e}: {loss.item()}')

Epoch 0: 0.6861187219619751
Epoch 0: 0.6880649328231812
Epoch 0: 0.6798537969589233
Epoch 0: 0.6838418841362
Epoch 0: 0.6823480129241943
Epoch 0: 0.678693413734436
Epoch 0: 0.6775755286216736
Epoch 0: 0.678235650062561
Epoch 0: 0.677001416683197
Epoch 0: 0.6789224147796631
Epoch 0: 0.6728288531303406
Epoch 0: 0.667654275894165
Epoch 0: 0.6719087362289429
Epoch 0: 0.6661622524261475
Epoch 0: 0.6608288288116455
Epoch 0: 0.6627450585365295
Epoch 0: 0.6563937067985535
Epoch 0: 0.6439087390899658
Epoch 0: 0.6464380025863647
Epoch 0: 0.6473026275634766
Epoch 0: 0.6406174898147583
Epoch 0: 0.6351507902145386
Epoch 0: 0.6325798034667969
Epoch 0: 0.6348806619644165
Epoch 0: 0.6348358392715454
Epoch 0: 0.628645122051239
Epoch 0: 0.5977924466133118
Epoch 0: 0.601351261138916
Epoch 0: 0.6065499782562256
Epoch 0: 0.5875546336174011
Epoch 0: 0.5756452083587646
Epoch 0: 0.5907731056213379
Epoch 0: 0.5582550168037415
Epoch 0: 0.5601009726524353
Epoch 0: 0.5380309820175171
Epoch 0: 0.5446566939353943
E

In [16]:
task1_replay = np.random.choice(task1_tr_samples.numpy(), int(prop_saved * task1_tr_samples.shape[0]))
task1_replay_samples = torch.Tensor(task1_replay).int()

In [17]:
# train on task 2 with replay
tr_samples = torch.concatenate([task2_tr_samples, task1_replay_samples], dim=0) # concatenate samples from task 2 and replay samples from task 1

# randomize the array to mix samples from task 2 and replay
np.random.shuffle(tr_samples.numpy())

n_batch = floor(tr_samples.shape[0] / minibatch_sz)

for e in range(n_epoch):    
    for b in range(n_batch):
        x_batch = train_data.data[tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
        y_batch = train_data.targets[tr_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

        # flatten image before presenting to the network and normalize intensities to the range [0, 1]
        x_batch = torch.flatten(x_batch / 255, start_dim=1)

        # convert label to one hot
        y_batch = y_batch % 2
        y_batch = F.one_hot(y_batch).float()

        y_hat_batch = network(x_batch)
        loss = criterion(y_hat_batch, y_batch)
        loss.backward()
        optimizer.step()

        print(f'Epoch {e}: {loss.item()}')

Epoch 0: 74.47444152832031
Epoch 0: 85.4444351196289
Epoch 0: 52.0087890625
Epoch 0: 56.47416305541992
Epoch 0: 58.91046142578125
Epoch 0: 20.745325088500977
Epoch 0: 61.7367057800293
Epoch 0: 59.45780944824219
Epoch 0: 35.815460205078125
Epoch 0: 31.981496810913086
Epoch 0: 10.560304641723633
Epoch 0: 18.08081817626953
Epoch 0: 23.73309326171875
Epoch 0: 9.656204223632812
Epoch 0: 10.612948417663574
Epoch 0: 13.979422569274902
Epoch 0: 17.6220703125
Epoch 0: 33.5146484375
Epoch 0: 30.247905731201172
Epoch 0: 19.095428466796875
Epoch 0: 9.698869705200195
Epoch 0: 7.056517601013184
Epoch 0: 1.79326331615448
Epoch 0: 8.665701866149902
Epoch 0: 11.909594535827637
Epoch 0: 29.386981964111328
Epoch 0: 34.13988494873047
Epoch 0: 35.49656677246094
Epoch 0: 23.07880401611328
Epoch 0: 18.847185134887695
Epoch 0: 6.588086128234863
Epoch 0: 6.474761962890625
Epoch 0: 4.687563419342041
Epoch 0: 8.245673179626465
Epoch 0: 8.287185668945312
Epoch 0: 15.62210750579834
Epoch 0: 16.157024383544922
Epoc

In [18]:
# test on Task 1
n_batch = floor(task1_ts_samples.shape[0] / minibatch_sz)

n_correct = 0    
for b in range(n_batch):
    x_batch = test_data.data[task1_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
    y_batch = test_data.targets[task1_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

    # flatten image before presenting to the network and normalize intensities to the range [0, 1]
    x_batch = torch.flatten(x_batch / 255, start_dim=1)

    y_hat_batch = network(x_batch)
    _, prediction = torch.max(y_hat_batch, 1)
    n_correct += (prediction == y_batch).sum().item()

print(f'Accuracy = {(n_correct * 100) / task1_ts_samples.shape[0]}')

Accuracy = 98.10874704491725


In [19]:
# test on Task 2
n_batch = floor(task2_ts_samples.shape[0] / minibatch_sz)

n_correct = 0    
for b in range(n_batch):
    x_batch = test_data.data[task2_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]
    y_batch = test_data.targets[task2_ts_samples[(b*minibatch_sz):((b+1)*minibatch_sz)]]

    # flatten image before presenting to the network and normalize intensities to the range [0, 1]
    x_batch = torch.flatten(x_batch / 255, start_dim=1)

    y_hat_batch = network(x_batch)
    _, prediction = torch.max(y_hat_batch, 1)
    n_correct += (prediction == (y_batch % 2)).sum().item()

print(f'Accuracy = {(n_correct * 100) / task1_ts_samples.shape[0]}')

Accuracy = 90.26004728132388


# **Directions for further exploration**
We will not share solutions for these questions.

**Q1**: How does the proportion of samples saved for replay affect the model's performance?

**Q2**: Use replay to train the nentwork on more than two tasks. What is the impact of replay on the memory used by your models? Note that replay-based approach requires that you save the replay samples from previous task forever. This implies that the memory required to store samples contributes to your models memory footprint.

**Q3**: Can we chose replay samples more smartly so that we generate maximal impact while using minimal memory? For instance, can you use the network's prediction on a given task to identify samples stored for replay?
