# Continual Learning with PyTorch

This notebook is a homework assignment for the course [CS182/282A](https://inst.eecs.berkeley.edu/~cs182/fa22/). The goal of this assignment is to become familiar with the concept of continual learning and how to implement it with PyTorch. We will use the MNIST benchmark for this assignment. Many parts of this notebook are based on material from [ContinualAI](https://github.com/ContinualAI).

---


**Requisites**

*   Python 3.x
*   Jupyter
*   PyTorch >= 1.8
*   torchvision
*   NumPy
*   Matplotlib
---

In [None]:
!free -m
!df -h
!nvidia-smi

In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

## Downloading the dataset

We will use the MNIST dataset for this assignment. It is available through `torchvision.datasets`, so we only need to download it.

In [None]:
# download mnist
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# preprocess mnist
train_dataset.data = train_dataset.data.float() / 255
train_dataset.data = train_dataset.data.reshape(-1, 1, 28, 28)
test_dataset.data = test_dataset.data.float() / 255
test_dataset.data = test_dataset.data.reshape(-1, 1, 28, 28)

print('Train dataset shape: ', train_dataset.data.shape)
print('Test dataset shape: ', test_dataset.data.shape)

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.manual_seed(1)

### Define Network

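We will use a simple 4-layer convolutional neural network for this assignment, defined in the `Net` class below. The network is composed of 2 convolutional layers (each followed by 2x2 max pooling and a ReLU) and 2 fully connected layers, with dropout applied to the output of the second convolution and after the first fully connected layer.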
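
The input size of `fc1` (320) follows from the spatial arithmetic of unpadded, stride-1 convolutions and 2x2 max pooling. A small sketch of that calculation, assuming the layer settings in `Net` below; `conv_out` and `pool_out` are hypothetical helper names used only for illustration:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution (square input and kernel)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2):
    """Output size of non-overlapping max pooling (stride equal to kernel)."""
    return size // kernel

size = 28                               # MNIST images are 28x28
size = pool_out(conv_out(size, 5))      # conv1 (5x5): 28 -> 24, pool: 24 -> 12
size = pool_out(conv_out(size, 5))      # conv2 (5x5): 12 -> 8,  pool: 8 -> 4
flat_features = 20 * size * size        # conv2 has 20 output channels
print(size, flat_features)              # 4 320
```

If the kernel sizes or channel counts in `Net` are changed, the argument of `x.view(-1, 320)` and the input size of `fc1` have to be updated accordingly.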

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

### Training and Testing

We will use the `train` and `test` functions to train and evaluate the network. `train(model, device, x_train, t_train, optimizer, epoch)` runs one epoch of mini-batch training with a cross-entropy loss over the NumPy arrays `x_train` (images) and `t_train` (labels); `epoch` is only used for logging, and the function prints the loss of the last mini-batch without returning anything. `test(model, device, x_test, t_test)` prints the average test loss and accuracy and returns the test accuracy as a percentage.

Note that we are not using DataLoaders for simplicity in this assignment.
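
The loops in `train` and `test` below slice the arrays into consecutive mini-batches of 256 and visit them in the same order every epoch. If per-epoch shuffling is wanted without a DataLoader, one option is to draw a fresh permutation of indices each epoch. A minimal NumPy sketch; `minibatch_indices` is a hypothetical helper, not part of the assignment code:

```python
import numpy as np

def minibatch_indices(n, batch_size, rng):
    """Yield index arrays that cover range(n) exactly once, in a random order."""
    order = rng.permutation(n)
    for start in range(0, n, batch_size):
        # the last batch may be smaller than batch_size
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(minibatch_indices(1000, 256, rng))
print([len(b) for b in batches])  # [256, 256, 256, 232]
```

Inside `train`, `x_train[idx]` and `t_train[idx]` for each `idx` would then replace the contiguous slices `x_train[start:end]` and `t_train[start:end]`.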

In [None]:
def train(model, device, x_train, t_train, optimizer, epoch, batch_size=256):
    """Run one epoch of mini-batch training (batches are taken in order, no shuffling)."""
    model.train()

    for start in range(0, len(t_train), batch_size):
        end = start + batch_size
        x = torch.from_numpy(x_train[start:end]).to(device)
        y = torch.from_numpy(t_train[start:end]).long().to(device)

        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()
    # loss of the last mini-batch of the epoch
    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, loss.item()))

def test(model, device, x_test, t_test, batch_size=256):
    """Evaluate the model; print average loss and accuracy, return accuracy in %."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for start in range(0, len(t_test), batch_size):
            end = start + batch_size
            x = torch.from_numpy(x_test[start:end]).to(device)
            y = torch.from_numpy(t_test[start:end]).long().to(device)
            output = model(x)
            # sum (not mean) over the batch, so dividing by len(t_test) gives the average loss
            test_loss += F.cross_entropy(output, y, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct += pred.eq(y.view_as(pred)).sum().item()

    test_loss /= len(t_test)
    accuracy = 100. * correct / len(t_test)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(t_test), accuracy))
    return accuracy

Let's instantiate the network, the optimizer, and then train and test the network.

In [None]:
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# train and test
for epoch in range(3): 
    train(model, device, train_dataset.data.numpy(), train_dataset.targets.numpy(), optimizer, epoch)
    test(model, device, test_dataset.data.numpy(), test_dataset.targets.numpy())

## Permuted MNIST

Permuted MNIST is one of the basic benchmarks for continual learning. Each task applies its own fixed random permutation to the pixels of every MNIST image, while the labels stay the same. The network has to learn to classify the images despite the scrambled pixel order. This benchmark is an example of domain-incremental continual learning: the input distribution changes from task to task, but the set of classes does not.
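
Because a single fixed permutation is applied to every image of a task, no information is lost: the permutation can be undone with its inverse, which `np.argsort` provides. The pixels are only rearranged, so the network has to relearn which input positions matter. A small NumPy sketch on random stand-in images (not real MNIST data), using the same reshape layout as `permute_mnist` below:

```python
import numpy as np

rng = np.random.default_rng(0)
h = w = 28
perm = rng.permutation(h * w)      # one permutation shared by all images of a task
inv_perm = np.argsort(perm)        # inverse permutation: perm[inv_perm] == arange(h * w)

images = rng.random((3, 1, h, w)).astype(np.float32)    # stand-in batch of 3 images
flat = images.reshape(len(images), h * w)
permuted = flat[:, perm].reshape(len(images), 1, h, w)   # scramble the pixel order

# applying the inverse permutation recovers the original images exactly
restored = permuted.reshape(len(images), h * w)[:, inv_perm].reshape(images.shape)
print(np.array_equal(restored, images))  # True
```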

In [None]:
def permute_mnist(mnist, seed):
    """Given a list of image arrays of shape (N, 1, 28, 28), permute the pixels
    of every image with the same random permutation (determined by `seed`)."""

    np.random.seed(seed)
    print("starting permutation...")
    h = w = 28
    perm_inds = list(range(h*w))
    np.random.shuffle(perm_inds)
    perm_mnist = []
    for imgs in mnist:  # e.g. [train_images, test_images]
        num_img = imgs.shape[0]
        flat_imgs = imgs.reshape(num_img, h * w)
        perm_mnist.append(flat_imgs[:, perm_inds].reshape(num_img, 1, h, w))
    print("done.")
    return perm_mnist

In [None]:
x_train2, x_test2 = permute_mnist([train_dataset.data.numpy(), test_dataset.data.numpy()], 0)

In [None]:
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(train_dataset.data.numpy()[1, 0], cmap="gray")
axarr[0].set_title("original")
axarr[1].imshow(x_train2[1, 0], cmap="gray")  # same image after permutation
axarr[1].set_title("permuted")
for ax in axarr:
    ax.axis('off')

Let's check whether our trained model still works on both the original and the permuted MNIST test sets.

In [None]:
print("Testing on the first task:")
test(model, device, test_dataset.data.numpy(), test_dataset.targets.numpy())

print("Testing on the second task:")
test(model, device, x_test2, test_dataset.targets.numpy())

The network is unable to classify the permuted MNIST images. This is not unexpected, since it was never trained on permuted images. Now let's fine-tune the network on the permuted MNIST dataset.

In [None]:
for epoch in range(1, 3):
    train(model, device, x_train2, train_dataset.targets.numpy(), optimizer, epoch)
    test(model, device, x_test2, test_dataset.targets.numpy())

In [None]:
print("Testing on the first task:")
test(model, device, test_dataset.data.numpy(), test_dataset.targets.numpy())

print("Testing on the second task:")
test(model, device, x_test2, test_dataset.targets.numpy())

We observe that the network now performs well on the new task but poorly on the original MNIST task. This is catastrophic forgetting: training on the permuted MNIST task overwrites what the network learned on the original task. Now let's see how we can mitigate catastrophic forgetting.

## Continual Learning Strategies

Continual learning strategies are methods that let a network learn multiple tasks in sequence while limiting forgetting of the previous tasks. There are many different strategies; in this assignment we will implement the following 3:

*   **Naive**: Naive fine-tuning. Train the same network on each task in turn, with no mechanism against forgetting (a baseline).
*   **EWC**: Elastic Weight Consolidation. Penalize changes to the weights that were important for previous tasks.
*   **Rehearsal**: Store examples from previous tasks and train on them together with the current task's data.

Let's implement the strategies. We will use the `train` and `test` functions defined above to train and test the network. 

In [None]:
# task 1
x_train = train_dataset.data.numpy()
t_train = train_dataset.targets.numpy()
x_test = test_dataset.data.numpy()
t_test = test_dataset.targets.numpy()

task_1 = [(x_train, t_train), (x_test, t_test)]

# task 2
x_train2, x_test2 = permute_mnist([x_train, x_test], 1)
task_2 = [(x_train2, t_train), (x_test2, t_test)]

# task 3
x_train3, x_test3 = permute_mnist([x_train, x_test], 2)
task_3 = [(x_train3, t_train), (x_test3, t_test)]

# task list
tasks = [task_1, task_2, task_3]

### Naive

The naive strategy is the simplest one: we train the same network on one task after another, without doing anything to prevent forgetting. Let's see how well the network performs on each task and how much it forgets about the previous tasks.

In [None]:
# Define the model and optimizer
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
naive_accs = []
num_tasks = len(tasks)

for task_id, task in enumerate(tasks):
    avg_acc = 0  # sum of test accuracies over all tasks, averaged below
    (x_train, t_train), _ = task
    print("Training on task: ", task_id + 1)

    for epoch in range(3):
        train(model, device, x_train, t_train, optimizer, epoch)

    for id_test, test_task in enumerate(tasks):
        print('Test on task {}:'.format(id_test + 1))
        _, (x_test, t_test) = test_task
        acc = test(model, device, x_test, t_test)
        avg_acc += acc

    naive_accs.append(avg_acc / num_tasks)
    print('Average accuracy over all tasks: ', avg_acc / num_tasks)
    print('-----------------------------------')

Qa1: What do you observe? How much does the network forget about the previous tasks? Why do you think this happens?
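
The loop above records only one average accuracy per training stage. To quantify forgetting per task, one option is to keep an accuracy matrix `R`, where `R[i, j]` is the test accuracy on task `j` after training on task `i`, and report for each earlier task the drop from its best accuracy to its final accuracy. A NumPy sketch; `R` would have to be filled by storing each `acc` inside the loop, and `forgetting` and `avg_seen_accuracy` are hypothetical helpers:

```python
import numpy as np

def forgetting(R):
    """Per-task forgetting after the final task.

    R[i, j] = test accuracy on task j after training on task i (square matrix).
    For each task j except the last, returns the best accuracy it reached once
    trained (rows j .. T-2) minus its accuracy at the end (row T-1).
    """
    R = np.asarray(R, dtype=float)
    T = R.shape[0]
    return np.array([R[j:T - 1, j].max() - R[T - 1, j] for j in range(T - 1)])

def avg_seen_accuracy(R):
    """Average accuracy over the tasks seen so far (row i averages tasks 0..i)."""
    R = np.asarray(R, dtype=float)
    return np.array([R[i, : i + 1].mean() for i in range(R.shape[0])])
```

Note that `avg_acc / num_tasks` in the loop averages over all tasks, including tasks the network has not been trained on yet; `avg_seen_accuracy` averages only over the tasks seen so far, which is another common way to report results.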

Qa2: (Open-ended question) We are using a CNN. Does an MLP perform better or worse than the CNN? Try it out and report your results.
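
As a starting point for Qa2, a minimal MLP with the same input and output interface as `Net` might look like the sketch below. The class name `MLP` and the hidden size of 256 are arbitrary choices for illustration, not values prescribed by the assignment:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Two-hidden-layer perceptron; accepts (N, 1, 28, 28) inputs like Net."""

    def __init__(self, hidden=256):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)      # flatten each image to 784 features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)             # raw logits, as expected by F.cross_entropy

model_mlp = MLP()
out = model_mlp(torch.zeros(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```

Replacing `Net()` with `MLP()` where the model is instantiated is enough to rerun the experiments, since `train` and `test` only rely on the model mapping image batches to 10 logits.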

### EWC

The Elastic Weight Consolidation (EWC) strategy was proposed in the paper "[Overcoming catastrophic forgetting in neural networks](https://arxiv.org/abs/1612.00796)". It is a regularization strategy that penalizes the network for changing the weights that were important for the previous tasks.

It is based on computing the importance of each weight (its Fisher information) and adding a squared regularization loss that penalizes changes to the weights that were most important for the previous tasks.

$$\mathcal{L}_{\text{EWC}}(\theta) = \mathcal{L}(\theta) + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta_i^{\text{old}}\right)^2$$

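where $\theta$ are the current network parameters, $\mathcal{L}(\theta)$ is the loss on the current task, $\theta^{\text{old}}$ are the network parameters learned on the previous task, $F_i$ is the $i$-th diagonal entry of the Fisher information matrix, and $\lambda$ is a hyperparameter that sets the strength of the penalty. Informally speaking, near a minimum of the previous task's loss, the Fisher information approximates the Hessian of that loss with respect to the weights. The penalty term can therefore be read as a second-order Taylor expansion of the previous task's loss around $\theta^{\text{old}}$ (up to a constant): the first-order term vanishes because $\theta^{\text{old}}$ is (approximately) a minimum, and the Hessian is replaced by its diagonal Fisher approximation.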
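
The Taylor-expansion reading can be checked numerically on a toy problem. The sketch below assumes a made-up previous-task loss $\sum_i \log\cosh(\theta_i - \theta_i^{\text{old}})$, whose minimum is at $\theta^{\text{old}}$ and whose Hessian there is the identity, so $F_i = 1$ plays the role of the Hessian diagonal. With $\lambda = 1$ the penalty closely matches the true loss for small parameter changes and overestimates it for large ones, which shows that the approximation is only local. The helper names are illustrative assumptions:

```python
import numpy as np

def ewc_penalty(theta, theta_old, fisher, lam):
    """(lam / 2) * sum_i F_i * (theta_i - theta_old_i)**2, the EWC regularizer."""
    return 0.5 * lam * np.sum(fisher * (theta - theta_old) ** 2)

def old_task_loss(theta, theta_old):
    """Toy previous-task loss with its minimum at theta_old: sum_i log cosh(d_i)."""
    return np.sum(np.log(np.cosh(theta - theta_old)))

theta_old = np.array([0.5, -1.0, 2.0])
hessian_diag = np.ones(3)   # second derivative of log cosh(d) is sech^2(d) = 1 at d = 0

for step in (0.01, 0.1, 1.0):
    theta = theta_old + step                      # move every weight by `step`
    approx = ewc_penalty(theta, theta_old, hessian_diag, lam=1.0)
    exact = old_task_loss(theta, theta_old)
    print(f"step={step:<5} penalty={approx:.6f} true old-task loss={exact:.6f}")
```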

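However, the full Fisher information matrix has one row and one column per parameter, so it is too large to compute or store even for this small network. We therefore keep only its diagonal. The diagonal Fisher is the expected squared gradient of the log-likelihood with respect to each weight, evaluated at the old weights; here we use a cheap approximation of it: the square of the loss gradient with respect to the weights, accumulated over the training data of the task just learned.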
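
For comparison, a common alternative estimate averages squared gradients over mini-batches instead of squaring one accumulated gradient. The sketch below does this for any `nn.Module`; `estimate_fisher_diag`, the batch size, and the tiny linear model and random data used to exercise it are assumptions for illustration. Using a batch size of 1 (per-sample squared gradients) is closest to the empirical Fisher, but slower:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def estimate_fisher_diag(model, x, y, batch_size=256):
    """Average of squared mini-batch gradients of the cross-entropy loss.

    Returns a dict {param_name: tensor} with the same shapes as the parameters.
    Leaves the model in eval mode; call model.train() before further training.
    """
    model.eval()  # disable dropout so the estimate is deterministic
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    num_batches = 0
    for start in range(0, len(y), batch_size):
        xb, yb = x[start:start + batch_size], y[start:start + batch_size]
        model.zero_grad()
        loss = F.cross_entropy(model(xb), yb)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        num_batches += 1
    return {n: f / num_batches for n, f in fisher.items()}

torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(512, 1, 28, 28)          # random stand-in images
y = torch.randint(0, 10, (512,))        # random stand-in labels
fisher = estimate_fisher_diag(toy_model, x, y, batch_size=128)
print({n: tuple(f.shape) for n, f in fisher.items()})
```

This estimate has a different scale from the squared accumulated gradient used in `on_task_update` below, so `ewc_lambda` would need to be retuned if one estimate is swapped for the other. Note also that `on_task_update` computes gradients in train mode (dropout active), whereas this sketch uses eval mode.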

In [None]:
fisher_dict = {}
optpar_dict = {}
ewc_lambda = 0.4

In [None]:
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

The helper function below estimates the diagonal Fisher information for each parameter and stores a copy of the current parameters. It is called after training on each task.

In [None]:
def on_task_update(task_id, x_mem, t_mem):
    """Store the current parameters and a diagonal Fisher estimate for `task_id`.
    Uses the global `model` and `optimizer`."""

    model.train()
    optimizer.zero_grad()

    # accumulate gradients of the batch losses over the whole task dataset
    for start in range(0, len(t_mem), 256):
        end = start + 256
        x, y = torch.from_numpy(x_mem[start:end]), torch.from_numpy(t_mem[start:end]).long()
        x, y = x.to(device), y.to(device)
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()

    fisher_dict[task_id] = {}
    optpar_dict[task_id] = {}

    # the squared accumulated gradient serves as the (approximate) diagonal Fisher
    for name, param in model.named_parameters():
        optpar_dict[task_id][name] = param.detach().clone()
        fisher_dict[task_id][name] = param.grad.detach().clone().pow(2)

We also need a modified training function, `train_ewc`, that adds the EWC penalty to the cross-entropy loss. It uses the parameters and Fisher estimates that `on_task_update` stored for every previous task.
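
The penalty loop inside `train_ewc` can also be factored into a standalone function, which makes it easier to log the regularization term separately from the cross-entropy. A sketch, assuming the `fisher_dict`/`optpar_dict` layout used in this notebook (one dict per task id, mapping parameter names to tensors); `ewc_penalty_term` and the `toy_*` objects are hypothetical names chosen so they do not overwrite the notebook's globals:

```python
import torch
import torch.nn as nn

def ewc_penalty_term(model, fisher_dict, optpar_dict, task_id, ewc_lambda):
    """Sum over previous tasks t < task_id of ewc_lambda * sum(F_t * (theta_t - theta)^2)."""
    penalty = torch.zeros((), device=next(model.parameters()).device)
    for task in range(task_id):
        for name, param in model.named_parameters():
            fisher = fisher_dict[task][name]
            optpar = optpar_dict[task][name]
            penalty = penalty + (fisher * (optpar - param).pow(2)).sum() * ewc_lambda
    return penalty

# Tiny check: right after storing a task's parameters, the penalty is zero.
toy_model = nn.Linear(3, 2)
toy_fisher = {0: {n: torch.ones_like(p) for n, p in toy_model.named_parameters()}}
toy_optpar = {0: {n: p.detach().clone() for n, p in toy_model.named_parameters()}}
pen = ewc_penalty_term(toy_model, toy_fisher, toy_optpar, task_id=1, ewc_lambda=0.4)
print(pen.item())  # 0.0
```

Inside `train_ewc`, `loss = F.cross_entropy(output, y) + ewc_penalty_term(model, fisher_dict, optpar_dict, task_id, ewc_lambda)` would then be equivalent to the explicit loop.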

In [None]:
def train_ewc(model, device, task_id, x_train, t_train, optimizer, epoch):
    model.train()

    for start in range(0, len(t_train), 256):
        end = start + 256
        x, y = torch.from_numpy(x_train[start:end]), torch.from_numpy(t_train[start:end]).long()
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        output = model(x)
        loss = F.cross_entropy(output, y)

        # EWC penalty for every previous task; the factor 1/2 from the
        # formula is absorbed into ewc_lambda
        for task in range(task_id):
            for name, param in model.named_parameters():
                fisher = fisher_dict[task][name]
                optpar = optpar_dict[task][name]
                loss = loss + (fisher * (optpar - param).pow(2)).sum() * ewc_lambda

        loss.backward()
        optimizer.step()
    # last mini-batch loss, including the EWC penalty
    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, loss.item()))

In [None]:
ewc_accs = []
num_tasks = len(tasks)

for task_id, task in enumerate(tasks):
    avg_acc = 0  # sum of test accuracies over all tasks, averaged below
    (x_train, t_train), _ = task
    print("Training on task: ", task_id + 1)

    for epoch in range(3):
        train_ewc(model, device, task_id, x_train, t_train, optimizer, epoch)
    # store parameters and Fisher estimates for this task (used by later tasks)
    on_task_update(task_id, x_train, t_train)

    for id_test, test_task in enumerate(tasks):
        print('Test on task {}:'.format(id_test + 1))
        _, (x_test, t_test) = test_task
        acc = test(model, device, x_test, t_test)
        avg_acc += acc

    ewc_accs.append(avg_acc / num_tasks)
    print('Average accuracy over all tasks: ', avg_acc / num_tasks)
    print('-----------------------------------')

Qb1. The hyperparameter $\lambda$ (`ewc_lambda`) has not been tuned in this assignment. Try different values of $\lambda$ and report your results.

Qb2. What is the role of $\lambda$? What happens if $\lambda$ is too small or too large? Explain the results in terms of the plasticity and stability of the network.
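
A toy one-dimensional model can help build intuition here (it is an illustration under strong simplifying assumptions, not a model of the CNN). Suppose the new task's loss is $(\theta - b)^2$ and the EWC penalty is $\frac{\lambda}{2} F (\theta - a)^2$, where $a$ is the old-task optimum. Setting the derivative to zero gives the minimizer $\theta^* = \frac{2b + \lambda F a}{2 + \lambda F}$, which moves from the new optimum $b$ toward the old optimum $a$ as $\lambda$ grows. The function name `ewc_minimizer` is hypothetical:

```python
def ewc_minimizer(a, b, fisher, lam):
    """argmin over theta of (theta - b)**2 + (lam / 2) * fisher * (theta - a)**2."""
    return (2 * b + lam * fisher * a) / (2 + lam * fisher)

a, b, fisher = 0.0, 1.0, 1.0   # old optimum, new optimum, importance of the weight
for lam in (0.0, 0.4, 4.0, 400.0):
    theta = ewc_minimizer(a, b, fisher, lam)
    print(f"lambda={lam:<6} theta*={theta:.3f}")
```

With $\lambda = 0$ the weight goes all the way to the new optimum (fully plastic), while a very large $\lambda$ pins it near the old optimum (fully stable).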

### Rehearsal

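Another strategy to mitigate catastrophic forgetting is to store examples from previous tasks and use them, together with the current task's data, to train the network. This strategy is called "rehearsal". Storing all examples from every previous task usually works best, but it does not scale in memory or compute as the number of tasks grows, so practical rehearsal methods keep only a small memory buffer. For simplicity, the code below replays the complete training sets of all previous tasks, which corresponds to this upper-bound setting; a subset-based buffer can be obtained by subsampling each previous task before concatenation.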
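
A subset-based buffer could be built by sampling a fixed number of examples from each previous task. A NumPy sketch; `sample_memory` and the buffer size of 1,000 are hypothetical choices, and the stand-in arrays only mimic MNIST's shapes:

```python
import numpy as np

def sample_memory(x, t, n_samples, seed=0):
    """Uniformly sample up to n_samples (x, t) pairs without replacement."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(t), size=min(n_samples, len(t)), replace=False)
    return x[idx], t[idx]

# stand-in data with MNIST-like shapes (not the real dataset)
x_prev = np.zeros((5000, 1, 28, 28), dtype=np.float32)
t_prev = np.arange(5000) % 10
x_mem, t_mem = sample_memory(x_prev, t_prev, n_samples=1000)
print(x_mem.shape, t_mem.shape)  # (1000, 1, 28, 28) (1000,)
```

In the rehearsal loop, calling `past_x_train, past_t_train = sample_memory(past_x_train, past_t_train, 1000, seed=i)` right before the concatenation would replay 1,000 examples per previous task instead of the full training set.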

In [None]:
def shuffle_in_unison(dataset, seed, in_place=False):
    """ Shuffle two (or more) lists in unison, i.e. with the same permutation. """

    np.random.seed(seed)
    rng_state = np.random.get_state()
    new_dataset = []
    for x in dataset:
        if in_place:
            np.random.shuffle(x)
        else:
            new_dataset.append(np.random.permutation(x))
        np.random.set_state(rng_state)

    if not in_place:
        return new_dataset

In [None]:
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
rehe_accs = []
num_tasks = len(tasks)

for task_id, task in enumerate(tasks):
    avg_acc = 0
    print("Training on task: ", task_id + 1)

    (x_train, t_train), _ = task

    # add the full training sets of all previous tasks (rehearsal data)
    for i in range(task_id):
        (past_x_train, past_t_train), _ = tasks[i]
        x_train = np.concatenate((x_train, past_x_train))
        t_train = np.concatenate((t_train, past_t_train))

    # mix current and past examples so that each batch contains both
    x_train, t_train = shuffle_in_unison([x_train, t_train], 0)

    for epoch in range(3):
        train(model, device, x_train, t_train, optimizer, epoch)

    for id_test, test_task in enumerate(tasks):
        print("Testing on task: ", id_test + 1)
        _, (x_test, t_test) = test_task
        acc = test(model, device, x_test, t_test)
        avg_acc = avg_acc + acc

    print("Average accuracy over all tasks: ", avg_acc / num_tasks)
    rehe_accs.append(avg_acc / num_tasks)

Qc1. What would be the pros and cons of rehearsal?

## Conclusion

Let's compare the performance of the 3 strategies on the permuted MNIST dataset.

In [None]:
task_axis = range(1, len(naive_accs) + 1)  # 1, 2, ..., number of tasks
plt.plot(task_axis, naive_accs, '-o', label="Naive")
plt.plot(task_axis, rehe_accs, '-o', label="Rehearsal")
plt.plot(task_axis, ewc_accs, '-o', label="EWC")
plt.xlabel('Tasks Encountered', fontsize=14)
plt.ylabel('Average Accuracy (%)', fontsize=14)
plt.title('CL Strategies Comparison on Permuted MNIST', fontsize=14);
plt.xticks(task_axis)
plt.legend(prop={'size': 16});