# Naive Replay-Based Class-Incremental Learning (CIL) with PyTorch

This notebook covers a simplified implementation of class-incremental learning (CIL) using replay. The goal of CIL is to train a model that learns new classes incrementally without forgetting the classes it has already learned. To reduce forgetting, we store some of the data from previous classes and replay it while training on new ones.

To do this, we use the following steps.

First, we train a base model on the first set of classes (the first five CIFAR-10 classes in the code below) for up to 100 epochs using stochastic gradient descent (SGD) with a learning rate of 0.001.

After training the base model, we store a subset of the data from these classes in a replay buffer. Let $D_i$ be the training data for all classes seen up to and including class $i$, and let $R_i$ be the replay buffer at step $i$.

By construction,

$$R_i \subseteq D_i, \qquad D_i = \bigcup_{x=0}^{i} d_x,$$

where $d_x$ is the training data for class $x$.

That is, the replay buffer contains only a subset of the data from the previous classes rather than all of it. The size of this subset can be varied to measure the effect of the replay buffer size on the model's performance.
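A minimal sketch of such a buffer, assuming a fixed budget of `per_class` samples per class drawn uniformly at random (the code later in this notebook keeps a single pool of old-task samples instead). `ReplayBuffer`, `add_class` and `as_dataset` are hypothetical names used only for illustration; `as_dataset` returns the contents of $R_i$ as parallel sample and label lists.

```python
import random


class ReplayBuffer:
    """Hypothetical per-class replay buffer keeping at most `per_class` random samples per class."""

    def __init__(self, per_class, seed=None):
        self.per_class = per_class
        self.rng = random.Random(seed)
        self.samples = {}  # class label -> stored samples

    def add_class(self, label, data):
        """Store a uniformly random subset of `data`, the training samples of class `label`."""
        data = list(data)
        self.samples[label] = self.rng.sample(data, min(self.per_class, len(data)))

    def as_dataset(self):
        """Return the buffer contents (R_i) as parallel lists of samples and labels."""
        all_data, all_labels = [], []
        for label, items in self.samples.items():
            all_data.extend(items)
            all_labels.extend([label] * len(items))
        return all_data, all_labels


buffer = ReplayBuffer(per_class=3, seed=0)
buffer.add_class(0, ["a0", "a1", "a2", "a3", "a4"])  # five samples of class 0: keep three
buffer.add_class(1, ["b0", "b1"])                    # only two samples of class 1: keep both
buffer_data, buffer_labels = buffer.as_dataset()
print(len(buffer_data), buffer_labels)  # 5 [0, 0, 0, 1, 1]
```

Sampling per class keeps the buffer balanced across old classes regardless of how the underlying data are ordered, and `per_class` is the knob for studying the effect of buffer size.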

From this point onwards, we train the model incrementally: at each step we train on the new class(es) while replaying the buffered data from the previous classes.

The model is a ResNet18 CNN, modified for 32x32 inputs, with a single fully connected layer at the end to classify the images. It is trained on the CIFAR-10 dataset, which contains 60,000 32x32 colour images in 10 classes, with 6,000 images per class (50,000 training and 10,000 test images).

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import time
from cifar import load_cifar10_data, split_into_classes, get_class_indexes 
from torch.utils.data import DataLoader
import random

# Path to the dataset
DATASET_PATH = 'cifar-10-batches-py' 
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


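### Define custom ResNet18 model for CIFAR-10

Compared with the standard ResNet18, we change the first convolution to a 3x3 kernel with stride 1 (and 128 output channels), remove the initial max-pooling layer, and replace the four residual stages with wider plain blocks of two 3x3 convolutions each (128, 256, 512 and 1024 channels, without skip connections). The final fully connected layer maps the 1024 pooled features to the 10 CIFAR-10 classes. It outputs raw logits; the softmax is applied inside the cross-entropy loss.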
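To see why the stem is changed, the standard output-size formula for a convolution or pooling layer, $\lfloor (H + 2p - k)/s \rfloor + 1$, can be traced through both stems for a 32x32 input. This is a pure-Python sketch: `conv_out` is a hypothetical helper, and the ImageNet stem parameters (7x7 conv with stride 2 and padding 3, then a 3x3 max pool with stride 2 and padding 1) are those of torchvision's ResNet18.

```python
def conv_out(size, kernel, stride, padding):
    """Output height/width of a conv or pooling layer: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# original ImageNet stem: 7x7 conv (stride 2, padding 3), then 3x3 max pool (stride 2, padding 1)
imagenet_stem = conv_out(conv_out(32, kernel=7, stride=2, padding=3), kernel=3, stride=2, padding=1)

# CIFAR stem used here: 3x3 conv (stride 1, padding 1) and no max pool
cifar_stem = conv_out(32, kernel=3, stride=1, padding=1)

# sizes[0]: output of the stem and layer1 (stride 1); layer2-layer4 each start with a stride-2 conv
sizes = [cifar_stem]
for _ in range(3):
    sizes.append(conv_out(sizes[-1], kernel=3, stride=2, padding=1))

print(imagenet_stem)  # 8
print(sizes)          # [32, 16, 8, 4]
```

With the original stem the feature map is already 8x8 before the first stage and would shrink to 1x1 by the last one; with the 3x3, stride-1 stem and no max pool, the four stages work at 32, 16, 8 and 4 pixels.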

In [2]:
from torchvision.models import resnet18


class ResNet18CIFAR(torch.nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # randomly initialised torchvision ResNet18 (no pretrained weights)
        self.resnet = resnet18(num_classes=num_classes)
        # replace the ImageNet stem (7x7 conv, stride 2) with a 3x3, stride-1 conv
        # so that 32x32 CIFAR-10 images are not downsampled straight away
        self.resnet.conv1 = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.bn1 = torch.nn.BatchNorm2d(128)
        # max pooling worsens performance and is unnecessary for small image sizes
        self.resnet.maxpool = torch.nn.Identity()

        # replace each residual stage with a wider two-convolution block
        # (note: these plain blocks have no skip connections)
        self.resnet.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(128)
        )
        self.resnet.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(256)
        )
        self.resnet.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(512)
        )
        self.resnet.layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            torch.nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(1024)
        )
        # classifier: 1024 pooled features -> one raw logit per class (softmax is applied by the loss)
        self.resnet.fc = torch.nn.Linear(1024, num_classes)

    def forward(self, x):
        return self.resnet(x)


model = ResNet18CIFAR()
print(model)  # display the architecture

### Data loading and splitting into per-class datasets

We use the CIFAR-10 dataset with a prebuilt data loader (Krishi's code) to load the data and split it into per-class datasets.
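The split itself is done by the prebuilt helpers. As an illustration only, selecting the samples of one task could look like the NumPy sketch below. `select_classes` is a hypothetical stand-in, not the actual `split_into_classes`, and it assumes labels keep their original CIFAR-10 indices (0-9), which the logit masking used later in the notebook relies on.

```python
import numpy as np


def select_classes(data, labels, class_ids):
    """Hypothetical helper: keep only the samples whose label is in `class_ids`.

    Labels keep their original CIFAR-10 indices (0-9), so no remapping is done.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    mask = np.isin(labels, class_ids)  # True where the label belongs to the requested classes
    return data[mask], labels[mask]


# toy example: six "images" with two values each
toy_labels = np.array([0, 5, 1, 9, 3, 5])
toy_data = np.arange(12).reshape(6, 2)
toy_task1_data, toy_task1_labels = select_classes(toy_data, toy_labels, [0, 1, 2, 3, 4])
print(toy_task1_labels)  # [0 1 3]
```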

In [3]:
train_data, train_labels, test_data, test_labels = load_cifar10_data(DATASET_PATH)
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


# The Algorithm

Train an initial model on the first set of classes (the first five CIFAR-10 classes in the code below).

Then, for each subsequent class $i$:

1. Train the model on the current class $i$ and replay the data from the previous classes.

2. Store a subset of the data from the current class $i$ in the replay buffer $R_i$.

3. Evaluate the model on the test set.

4. Increment the number of classes seen.

5. Repeat steps 1-4 for each class in the dataset.
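The loop above can be sketched in plain Python as follows. This is a minimal, framework-agnostic sketch under stated assumptions: `class_incremental_training` is a hypothetical helper, `train_step` and `evaluate` are placeholder callbacks standing in for the PyTorch training and evaluation code, and the buffer keeps a random subset of at most `buffer_size` samples from each step.

```python
import random


def class_incremental_training(tasks, train_step, evaluate, buffer_size, seed=0):
    """Hypothetical sketch of the replay loop: train, store a subset, evaluate, repeat.

    tasks:       list of (data, labels) pairs, one per incremental step
    train_step:  placeholder callback that trains the model on a list of (x, y) pairs
    evaluate:    placeholder callback that returns a test metric after each step
    buffer_size: number of samples kept for replay from each step
    """
    rng = random.Random(seed)
    replay = []   # (x, y) pairs kept from earlier steps, i.e. the replay buffer
    history = []  # one evaluation result per step
    for data, labels in tasks:
        current = list(zip(data, labels))
        # step 1: train on the new data mixed with the replayed samples
        mixed = current + replay
        rng.shuffle(mixed)
        train_step(mixed)
        # step 2: keep a random subset of the current data for future replay
        replay += rng.sample(current, min(buffer_size, len(current)))
        # steps 3 and 4: evaluate; len(history) counts the steps seen so far
        history.append(evaluate())
    return history


# toy run: two steps of ten samples each and a buffer of three samples per step
seen = []
toy_tasks = [([f"x{k}" for k in range(10)], [0] * 10),
             ([f"y{k}" for k in range(10)], [1] * 10)]
history = class_incremental_training(
    toy_tasks,
    train_step=lambda batch: seen.append(len(batch)),  # record how many samples were trained on
    evaluate=lambda: len(seen),                        # stand-in metric: number of steps trained
    buffer_size=3,
)
print(seen)     # [10, 13]
print(history)  # [1, 2]
```

In the toy run the second step trains on 13 samples: its 10 new samples plus the 3 replayed from the first step.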

Note: the size of the replay buffer is a hyperparameter that can be tuned to determine the effect of the replay buffer size on the model's performance.

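The following hyperparameters are used in this notebook:

- Epochs: up to 100 per task, stopping early once training accuracy exceeds 85%
- Optimiser: SGD with a learning rate of 0.001
- Momentum: 0.9
- Batch size: 10
- Replay buffer size: variable (250 samples in total in the run below; larger buffers, e.g. 1000 to 5000 per class, can be used to study its effect)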

### References

1. [Continual Learning with Deep Architectures: A Review](https://arxiv.org/abs/1907.04471)

2. [RECALL: Replay-Based Continual Learning in Semantic Segmentation](https://openaccess.thecvf.com/content/ICCV2021/papers/Maracani_RECALL_Replay-Based_Continual_Learning_in_Semantic_Segmentation_ICCV_2021_paper.pdf) (ICCV 2021)

In [4]:
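def calculate_accuracy(model, accuracies, input_data, input_labels, task):
    """Append the accuracy (%) of `model` on the given data to `accuracies`.

    task=0 restricts predictions to classes 0-4, task=1 to classes 5-9;
    any other value (e.g. -1) uses all ten outputs.
    """
    was_training = model.training
    model.eval()  # use BatchNorm running statistics while evaluating
    correct = 0
    total = len(input_data)
    batch = 100
    with torch.no_grad():  # no gradients are needed for evaluation
        for j in range(0, total, batch):
            images = input_data[j:j+batch]
            labels = input_labels[j:j+batch]
            outputs = model(images)
            if task == 1:
                outputs[:, :5] = float("-inf")  # hide classes 0-4
            elif task == 0:
                outputs[:, 5:] = float("-inf")  # hide classes 5-9
            # the index of the highest remaining logit is the predicted class
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
    model.train(was_training)  # restore the previous mode

    accuracy = correct / total * 100
    accuracies.append(accuracy)
    accuracy_string = f"Accuracy: {accuracy:.2f}% "
    return accuracy_string, accuracies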
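The `task` argument of `calculate_accuracy` hides the logits of the other task before taking the argmax. The NumPy sketch below reproduces that logic on a toy logit vector; `masked_predictions` is a hypothetical helper, and it uses `-inf` in place of a large negative constant.

```python
import numpy as np


def masked_predictions(logits, task, classes_per_task=5):
    """Argmax restricted to one task's logits (mirrors the `task` argument of calculate_accuracy).

    task=0 keeps classes 0-4, task=1 keeps classes 5-9, and any other value keeps all ten.
    """
    logits = np.array(logits, dtype=float)  # copy so the caller's array is untouched
    if task in (0, 1):
        keep = np.zeros(logits.shape[1], dtype=bool)
        keep[task * classes_per_task:(task + 1) * classes_per_task] = True
        logits[:, ~keep] = -np.inf  # masked classes can never win the argmax
    return logits.argmax(axis=1)


toy_logits = [[0.1, 0.2, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0]]
print(masked_predictions(toy_logits, task=-1))  # [5]  unmasked: class 5 wins
print(masked_predictions(toy_logits, task=0))   # [1]  only classes 0-4 compete
```

Because the mask relies on knowing which task a test image belongs to, accuracies computed with `task=0` or `task=1` measure within-task performance; the unmasked setting (`task=-1`) corresponds to the class-incremental evaluation.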

In [5]:
# free up GPU memory
torch.cuda.empty_cache()
memory = torch.cuda.memory_allocated()
print(f"Memory allocated: {memory/1e9} GB")

replay_memory = {cls: [] for cls in CLASSES}
accuracies = []
learning_rate = 0.001
momentum = 0.9
batch_size = 10
epochs = 100
# the model outputs raw logits and CrossEntropyLoss applies log-softmax internally,
# so no explicit softmax layer is needed
loss_function = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
losses = []

# first task: train on the first five classes using SGD
current_train_data, current_train_labels = split_into_classes(train_data, train_labels, ['airplane', 'automobile', 'bird', 'cat', 'deer'])

# convert to tensors
current_train_data = torch.tensor(current_train_data).float()
current_train_labels = torch.tensor(current_train_labels)

test_data_per_class_1, test_labels_per_class_1 = split_into_classes(test_data, test_labels, ['airplane', 'automobile', 'bird', 'cat', 'deer'])
test_data_per_class_2, test_labels_per_class_2 = split_into_classes(test_data, test_labels, ['dog', 'frog', 'horse', 'ship', 'truck'])

test_data_per_class_1 = torch.tensor(test_data_per_class_1).float()
test_labels_per_class_1 = torch.tensor(test_labels_per_class_1)

if torch.cuda.is_available():
    model = model.cuda()
    current_train_data = current_train_data.cuda()
    current_train_labels = current_train_labels.cuda()
    test_data_per_class_1 = test_data_per_class_1.cuda()
    test_labels_per_class_1 = test_labels_per_class_1.cuda()

# training loop for the first task
for i in range(epochs):
    # training accuracy at the start of the epoch, restricted to classes 0-4 (task=0)
    accuracy_string, accuracies = calculate_accuracy(model, accuracies, current_train_data, current_train_labels, 0)

    # reshuffle every epoch so that each mini-batch mixes classes
    permutation = torch.randperm(len(current_train_data), device=current_train_data.device)
    for j in range(0, len(current_train_data), batch_size):
        optimiser.zero_grad()
        batch_indices = permutation[j:j+batch_size]
        images = current_train_data[batch_indices]
        labels = current_train_labels[batch_indices]

        outputs = model(images)
        loss = loss_function(outputs, labels)
        if j == 0:
            # record the loss of the first mini-batch of each epoch
            losses.append(loss.item())
        loss.backward()
        optimiser.step()
    print(f"Epoch {i+1}/{epochs}, " + accuracy_string + f"Loss: {losses[-1]}")
    # stop training once training accuracy exceeds 85%
    if accuracies[-1] > 85:
        break
    
# task 1 test accuracy with predictions restricted to classes 0-4 (task=0)
accuracy_string_1, _ = calculate_accuracy(model, accuracies, test_data_per_class_1, test_labels_per_class_1, 0)
# calculate_accuracy appends to `accuracies`; drop the test entry so the list holds training accuracies only
accuracies.pop()

print("Task 1 test accuracy: " + accuracy_string_1)

test_data_per_class_2 = torch.tensor(test_data_per_class_2).float()
test_labels_per_class_2 = torch.tensor(test_labels_per_class_2)

if torch.cuda.is_available():
    test_data_per_class_2 = test_data_per_class_2.cuda()
    test_labels_per_class_2 = test_labels_per_class_2.cuda()

# task 2 test accuracy with predictions restricted to classes 5-9 (task=1)
accuracy_string_2, _ = calculate_accuracy(model, accuracies, test_data_per_class_2, test_labels_per_class_2, 1)
accuracies.pop()

print("Task 2 test accuracy: " + accuracy_string_2)


### Initial Results for five classes

In [6]:
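# plot the loss of the first mini-batch of each epoch
plt.plot(losses)
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Cross-entropy loss')
plt.show()

# overall training-set accuracy over epochs (task 1 classes)
plt.plot(accuracies, label='Train Set Accuracy')
plt.title('Train Set Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()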

# extract numerical values from accuracy strings
accuracy_value_1 = float(accuracy_string_1.split(': ')[1].replace('%', ''))
accuracy_value_2 = float(accuracy_string_2.split(': ')[1].replace('%', ''))

# plot task 1 and task 2 test accuracies on a bar chart
plt.bar(['Task 1', 'Task 2'], [accuracy_value_1, accuracy_value_2])
plt.title('Test Set Accuracy for Task 1 and Task 2 (Post Task 1 Training)')
plt.ylabel('Accuracy (%)')
plt.show()

In [7]:
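def calculate_accuracy_class(model, accuracies, data, labels):
    """Append the per-class accuracy (%) of `model` to each list in `accuracies`.

    `accuracies` maps class names (in CLASSES order) to lists; class index k
    corresponds to the k-th key. Predictions use all ten outputs (no task mask).
    """
    num_classes = len(accuracies)
    correct = [0] * num_classes
    totals = [0] * num_classes
    batch_size = 1000
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for j in range(0, len(data), batch_size):
            images = data[j:j+batch_size]
            batch_labels = labels[j:j+batch_size]
            predicted = model(images).argmax(dim=1)
            for k in range(num_classes):
                in_class = batch_labels == k
                totals[k] += in_class.sum().item()
                correct[k] += (predicted[in_class] == k).sum().item()
    model.train(was_training)

    accuracy_string = "Accuracy: "
    for k, cls in enumerate(accuracies):
        # divide by the real number of samples of this class, since classes may be imbalanced
        accuracy = correct[k] / totals[k] * 100 if totals[k] > 0 else 0.0
        accuracies[cls].append(accuracy)
        accuracy_string += f"{cls}: {accuracy:.2f}%, "

    return accuracy_string, accuracies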
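When per-class accuracy is reported, each class's correct count should be divided by that class's own number of samples. With replay, the task-2 training set is far from balanced (5,000 images for each new class against 250 replayed images in total), so dividing by a fixed share of the dataset would misstate the per-class accuracy. A NumPy sketch, with `per_class_accuracy` as a hypothetical helper:

```python
import numpy as np


def per_class_accuracy(predicted, labels, num_classes=10):
    """Accuracy (%) for each class, dividing by the true number of samples of that class.

    Classes that do not appear in `labels` get NaN instead of a division by zero.
    """
    predicted = np.asarray(predicted)
    labels = np.asarray(labels)
    totals = np.bincount(labels, minlength=num_classes)                       # samples per class
    correct = np.bincount(labels[predicted == labels], minlength=num_classes)  # hits per class
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(totals > 0, 100.0 * correct / totals, np.nan)


toy_labels = [0, 0, 0, 0, 5, 5]
toy_predicted = [0, 0, 1, 1, 5, 0]
acc = per_class_accuracy(toy_predicted, toy_labels)
print(acc[0], acc[5])  # 50.0 50.0
```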

# Train on next 5 classes

In [None]:
from cifar import split_into_classes

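MEMORY_BUFFER = 250  # total number of task-1 samples kept for replay

replay_memory_data, replay_memory_labels = split_into_classes(train_data, train_labels, ['airplane', 'automobile', 'bird', 'cat', 'deer'])
# draw the buffer uniformly at random from all task-1 samples; taking the first
# MEMORY_BUFFER samples could cover only one class if the split is ordered by class
replay_indices = random.sample(range(len(replay_memory_data)), MEMORY_BUFFER)
replay_memory_data = [replay_memory_data[k] for k in replay_indices]
replay_memory_labels = [replay_memory_labels[k] for k in replay_indices]

# the classifier already has 10 outputs; copy its weights and bias into a fresh layer
# (a hook for growing the head when later tasks add new classes)
old_fc = model.resnet.fc
new_10_output_layer = torch.nn.Linear(1024, 10).to(old_fc.weight.device)
with torch.no_grad():
    new_10_output_layer.weight[:old_fc.out_features] = old_fc.weight
    new_10_output_layer.bias[:old_fc.out_features] = old_fc.bias
model.resnet.fc = new_10_output_layer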

# free up gpu memory
torch.cuda.empty_cache()
memory = torch.cuda.memory_allocated()

accuracies_class = {cls: [] for cls in CLASSES}
accuracies_class_test = {cls: [] for cls in CLASSES}

accuracies = []
learning_rate = 0.001
momentum = 0.9
batch_size = 10
epochs = 100
# the model outputs raw logits and CrossEntropyLoss applies log-softmax internally,
# so no explicit softmax layer is needed
loss_function = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
losses = []

# second task: train on the remaining five classes together with the replay buffer
current_train_data, current_train_labels = split_into_classes(train_data, train_labels, ['dog', 'frog', 'horse', 'ship', 'truck'])

# combine the replayed task-1 samples with the new task's data
current_train_data = list(replay_memory_data) + list(current_train_data)
current_train_labels = list(replay_memory_labels) + list(current_train_labels)

# shuffle so that replayed and new samples are mixed
shuffled = list(zip(current_train_data, current_train_labels))
random.shuffle(shuffled)
current_train_data, current_train_labels = zip(*shuffled)

# convert to tensors (stacking into a single NumPy array first is much faster
# than building a tensor from a sequence of arrays)
current_train_data = torch.tensor(np.array(current_train_data)).float()
current_train_labels = torch.tensor(np.array(current_train_labels))

if torch.cuda.is_available():
    model = model.cuda()
    current_train_data = current_train_data.cuda()
    current_train_labels = current_train_labels.cuda()

# training loop for the second task (new classes plus replayed samples)
for i in range(epochs):
    # overall and per-class training accuracy at the start of the epoch, using all ten outputs (task=-1)
    accuracy_string, accuracies = calculate_accuracy(model, accuracies, current_train_data, current_train_labels, -1)
    _, accuracies_class = calculate_accuracy_class(model, accuracies_class, current_train_data, current_train_labels)

    # reshuffle every epoch
    permutation = torch.randperm(len(current_train_data), device=current_train_data.device)
    for j in range(0, len(current_train_data), batch_size):
        optimiser.zero_grad()
        batch_indices = permutation[j:j+batch_size]
        images = current_train_data[batch_indices]
        labels = current_train_labels[batch_indices]

        outputs = model(images)
        loss = loss_function(outputs, labels)
        if j == 0:
            # record the loss of the first mini-batch of each epoch
            losses.append(loss.item())
        loss.backward()
        optimiser.step()
    print(f"Epoch {i+1}/{epochs}, " + accuracy_string + f"Loss: {losses[-1]}")
    # stop training once training accuracy exceeds 85%
    if accuracies[-1] > 85:
        break

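# task 1 test accuracy with predictions restricted to classes 0-4 (task=0); since the task
# identity is used at test time, this is within-task rather than class-incremental accuracy
accuracy_string_1, _ = calculate_accuracy(model, accuracies, test_data_per_class_1, test_labels_per_class_1, 0)
# calculate_accuracy appends to `accuracies`; drop the test entry so the list holds training accuracies only
accuracies.pop()

print("Task 1 test accuracy: " + accuracy_string_1)

# task 2 test accuracy with predictions restricted to classes 5-9 (task=1)
accuracy_string_2, _ = calculate_accuracy(model, accuracies, test_data_per_class_2, test_labels_per_class_2, 1)
accuracies.pop()

print("Task 2 test accuracy: " + accuracy_string_2)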

current_test_data = torch.tensor(test_data).float()
current_test_labels = torch.tensor(test_labels)

if torch.cuda.is_available():
    current_test_data = current_test_data.cuda()
    current_test_labels = current_test_labels.cuda()
    
# per-class test accuracy using all ten outputs (no task mask), i.e. the class-incremental view
_, accuracies_class_test = calculate_accuracy_class(model, accuracies_class_test, current_test_data, current_test_labels)

# print all classes in accuracies_class_test
#for cls in CLASSES:
#    print(f"Test accuracy for {cls}: {accuracies_class_test[cls]}")


In [9]:
# plot the loss of the first mini-batch of each epoch
plt.plot(losses)
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Cross-entropy loss')
plt.show()

# for i in range(len(accuracies)):
#     print(f"Accuracy at {i} : {accuracies[i]}")
# 
# for cls in CLASSES:
#     plt.plot(accuracies_class[cls], label=f'{cls}')
# 
# plt.title('Train Set Accuracy over Epochs')
# plt.legend()
# plt.show()

# plot the test set accuracy as a bar chart
#plt.bar(CLASSES, [accuracies_class_test[cls][0] for cls in CLASSES])
#plt.title('Test Set Accuracy for Each Class')
#plt.xlabel('Classes')
#plt.ylabel('Accuracy (%)') 
#plt.show()
#for cls in CLASSES:
#    print(f"Test accuracy for {cls}: {accuracies_class_test[cls]}")
    
# extract numerical values from accuracy strings
accuracy_value_1 = float(accuracy_string_1.split(': ')[1].replace('%', ''))
accuracy_value_2 = float(accuracy_string_2.split(': ')[1].replace('%', ''))
    
# plot task 1 and task 2 test accuracies on a bar chart
plt.bar(['Task 1', 'Task 2'], [accuracy_value_1, accuracy_value_2])
plt.title('Test Set Accuracy for Task 1 and Task 2 (Post Task 2 Training)')
plt.ylabel('Accuracy (%)')
plt.show()