# Replay-Based Class-Incremental Learning (CIL) with PyTorch

This notebook walks through a simplified implementation of class-incremental learning (CIL) using replay-based techniques. The goal of CIL is to train a model that learns new classes over time without forgetting the classes it has already learned. Here, we store a subset of the data from previous classes and replay it while training on new classes in order to mitigate this forgetting.

To do this, we follow these steps:

1. Train a base model on the first two classes (e.g. classes 0 and 1) for 100 epochs using Stochastic Gradient Descent (SGD) with a learning rate of 0.01 and momentum of 0.9.
2. After training the base model, store a subset of the data from those two classes in a replay buffer. Let $R_i$ be the replay buffer after class $i$ has been learned, and let $D_i$ be the training data for class $i$ together with the data for all previous classes.

We denote: $R_i \subseteq D_i$

And: $D_i = \bigcup_{x=0}^{i} d_x$

Where $d_x$ is the data for class $x$.

That is, the replay buffer holds only a subset of the data from the previously seen classes rather than all of it. The size of this subset is a parameter we can vary to measure how the replay buffer size affects the model's performance.
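A minimal, framework-agnostic sketch of building $R_i$ from the per-class data $d_x$ follows. It assumes the data is held in a dict mapping a class key to a list of samples (the notebook's `train_split` is keyed by class name, which works the same way). The function name `build_replay_buffer` and its `per_class` parameter are hypothetical, and uniform random sampling is only one simple way to choose which samples to keep, not necessarily the rule used later in this notebook.

```python
import random
from typing import Dict, Hashable, List, Sequence, Tuple


def build_replay_buffer(
    class_data: Dict[Hashable, Sequence],
    seen_classes: Sequence[Hashable],
    per_class: int,
    seed: int = 0,
) -> List[Tuple[object, Hashable]]:
    """Return a replay buffer R_i as a list of (sample, label) pairs.

    Up to `per_class` samples are drawn uniformly at random, without replacement,
    from each class in `seen_classes`, so the result is a subset of
    D_i = union of d_x over the seen classes.
    """
    rng = random.Random(seed)  # fixed seed keeps the buffer reproducible
    buffer = []
    for cls in seen_classes:
        samples = class_data[cls]
        k = min(per_class, len(samples))  # a class may have fewer samples than requested
        for idx in rng.sample(range(len(samples)), k):
            buffer.append((samples[idx], cls))
    return buffer


# toy example: two classes with 10 "samples" each
toy = {0: list(range(10)), 1: list(range(100, 110))}
r_1 = build_replay_buffer(toy, seen_classes=[0, 1], per_class=3)
print(len(r_1))  # 6
```

Varying `per_class` is one way to run the buffer-size experiment described above; another option is a fixed total budget split evenly across all seen classes.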

From this point onwards, we train the model incrementally: at each step it is trained on the data for a new class together with replayed data from the previous classes.
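The incremental schedule can be sketched in the same framework-agnostic style. This assumes each step after the first adds exactly one class (the text only fixes the first step at two classes) and that the training set at each step is the new class's full data plus a uniformly sampled replay buffer of all previously seen classes. The names `incremental_schedule` and `sample_per_class` are hypothetical, and the model update itself is left to the caller, for example by wrapping each `train_set` in a PyTorch `DataLoader`.

```python
import random


def sample_per_class(class_data, classes, per_class, rng):
    """Uniformly sample up to `per_class` (sample, label) pairs from each class."""
    out = []
    for cls in classes:
        items = class_data[cls]
        for idx in rng.sample(range(len(items)), min(per_class, len(items))):
            out.append((items[idx], cls))
    return out


def incremental_schedule(class_data, class_order, per_class, seed=0):
    """Yield (step, new_classes, train_set) for each class-incremental step.

    Step 0 covers the first two classes; each later step adds one new class.
    The caller is expected to train the model on `train_set` before asking
    for the next step, at which point the replay buffer is refreshed.
    """
    rng = random.Random(seed)
    steps = [list(class_order[:2])] + [[c] for c in class_order[2:]]
    seen, buffer = [], []
    for step, new_classes in enumerate(steps):
        new_data = [(x, c) for c in new_classes for x in class_data[c]]
        train_set = new_data + buffer  # new-class data plus replayed old data (empty at step 0)
        rng.shuffle(train_set)         # interleave old and new samples
        yield step, new_classes, train_set
        seen.extend(new_classes)
        buffer = sample_per_class(class_data, seen, per_class, rng)  # buffer for the next step


toy = {c: list(range(100 * c, 100 * c + 20)) for c in range(4)}
for step, new_classes, train_set in incremental_schedule(toy, [0, 1, 2, 3], per_class=5):
    print(step, new_classes, len(train_set))
# 0 [0, 1] 40
# 1 [2] 30
# 2 [3] 35
```

Because the buffer is rebuilt after every step, its size grows with the number of seen classes here; capping the total buffer size instead would trade memory for more forgetting of early classes.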

The model is a standard ResNet18 CNN, with its input layers adapted to 32x32 images and a single fully connected layer at the end to classify them. It is trained on the CIFAR-10 dataset, which contains 60,000 32x32 colour images in 10 classes, with 6,000 images per class (50,000 for training and 10,000 for testing).
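For reference, the first step trains the base model with SGD at a learning rate of 0.01 and momentum of 0.9. Assuming this is PyTorch's `torch.optim.SGD` with its defaults `dampening=0`, `nesterov=False` and `weight_decay=0`, each optimiser step updates the parameters $\theta$ using the mini-batch gradient $g_t$ as:

$$
\begin{aligned}
b_t &= \mu\, b_{t-1} + g_t, \qquad b_0 = 0,\\
\theta_t &= \theta_{t-1} - \eta\, b_t, \qquad \eta = 0.01,\ \mu = 0.9.
\end{aligned}
$$

Under a constant gradient $g$, $b_t$ approaches $g/(1-\mu) = 10g$, so the effective step size approaches $\eta/(1-\mu) = 0.1$.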

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import time
from cifar import load_cifar10_data, split_into_classes, get_class_indexes
from torch.utils.data import DataLoader

# Path to the dataset
DATASET_PATH = 'cifar-10-batches-py'
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
```

### Define a custom ResNet18 model for CIFAR-10

```python
class ResNet18CIFAR(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # start from an untrained torchvision ResNet18 (built for 224x224 ImageNet images)
        self.resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        # inspect the original input and output layers before replacing them
        print("|| conv1 weight size: ", self.resnet.conv1.weight.size())
        print("|| fc weight size: ", self.resnet.fc.weight.size())
        # replace the 7x7 stride-2 first convolution with a 3x3 stride-1 one,
        # so small 32x32 images are not downsampled too aggressively
        self.resnet.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # replace the 1000-class ImageNet head with a 10-class head for CIFAR-10
        self.resnet.fc = torch.nn.Linear(512, 10)
        # the initial max pooling is unnecessary for small images and worsens performance
        self.resnet.maxpool = torch.nn.Identity()

    def forward(self, x):
        return self.resnet(x)


# build the model and put it in evaluation mode
model = ResNet18CIFAR()
model.eval()
```

```
|| conv1 weight size:  torch.Size([64, 3, 7, 7])
|| fc weight size:  torch.Size([1000, 512])
```




```
ResNet18CIFAR(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): Identity()
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        ...
```
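The two changes to the ResNet18 stem (a 3x3 stride-1 `conv1` and an `Identity` in place of `maxpool`) are easiest to justify by tracing the feature-map size with the usual output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$. This sketch assumes the torchvision ResNet18 layout, where `conv1` is 7x7 with stride 2 and padding 3, `maxpool` is 3x3 with stride 2 and padding 1, and the first block of each of `layer2`, `layer3` and `layer4` downsamples with a stride-2 3x3 convolution. The helper names `conv_out` and `resnet18_spatial_sizes` are hypothetical.

```python
def conv_out(n, kernel, stride, padding):
    """Output side length of a conv/pool layer along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet18_spatial_sizes(n, cifar_stem):
    """Trace the feature-map side length through ResNet18 for an n x n input."""
    sizes = {"input": n}
    if cifar_stem:
        n = conv_out(n, kernel=3, stride=1, padding=1)  # modified conv1
        sizes["conv1"] = n
        # maxpool replaced by Identity: size unchanged
    else:
        n = conv_out(n, kernel=7, stride=2, padding=3)  # original conv1
        sizes["conv1"] = n
        n = conv_out(n, kernel=3, stride=2, padding=1)  # original maxpool
    sizes["maxpool"] = n
    sizes["layer1"] = n  # layer1 keeps stride 1
    for name in ("layer2", "layer3", "layer4"):
        n = conv_out(n, kernel=3, stride=2, padding=1)  # first block downsamples
        sizes[name] = n
    return sizes


print(resnet18_spatial_sizes(32, cifar_stem=False)["layer4"])  # 1
print(resnet18_spatial_sizes(32, cifar_stem=True)["layer4"])   # 4
```

With the original stem a 32x32 image is reduced to a 1x1 map before the final average pooling, while the modified stem keeps a 4x4 map, which plausibly explains why the original stem performs worse on CIFAR-10.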

### Loading the data and splitting it into per-class datasets

The CIFAR-10 data is loaded with a prebuilt loader (Krishi's code) and then split into separate per-class datasets. The cell then loads a previously trained ResNet18 checkpoint and checks its accuracy on the test set.
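The internals of `load_cifar10_data` (from the local `cifar` module) are not shown here. As a hedged illustration only: the official "CIFAR-10 python version" stores each batch (`data_batch_1` to `data_batch_5`, plus `test_batch`) as a pickled dict whose `b'data'` entry is an N x 3072 `uint8` array (1024 red values, then 1024 green, then 1024 blue, each in row-major order) and whose `b'labels'` entry is a list of ints. A hypothetical loader that returns images in the N x 3 x 32 x 32 layout the model expects might look like this; `load_cifar_batch` and `load_cifar10` are illustrative names, not the `cifar` module's API.

```python
import os
import pickle

import numpy as np


def load_cifar_batch(path):
    """Load one CIFAR-10 python batch file as (images, labels).

    images: uint8 array of shape (N, 3, 32, 32); labels: int64 array of shape (N,).
    """
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")  # the official files were pickled under Python 2
    images = np.asarray(batch[b"data"], dtype=np.uint8).reshape(-1, 3, 32, 32)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels


def load_cifar10(root):
    """Concatenate the five training batches and load the test batch."""
    parts = [load_cifar_batch(os.path.join(root, f"data_batch_{i}")) for i in range(1, 6)]
    train_x = np.concatenate([p[0] for p in parts])
    train_y = np.concatenate([p[1] for p in parts])
    test_x, test_y = load_cifar_batch(os.path.join(root, "test_batch"))
    return train_x, train_y, test_x, test_y
```

The reshape works because each 3072-value row is laid out channel by channel, so splitting it into 3 x 32 x 32 recovers the (channel, height, width) order that PyTorch convolutions use.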

```python
train_data, train_labels, test_data, test_labels = load_cifar10_data(DATASET_PATH)

# group the images by class name: train_split['airplane'] holds every airplane image, etc.
train_split = {cls: [] for cls in CLASSES}
for img, label in zip(train_data, train_labels):
    train_split[CLASSES[label]].append(img)

# same grouping for the test set (1,000 images per class)
test_split = {cls: [] for cls in CLASSES}
for img, label in zip(test_data, test_labels):
    test_split[CLASSES[label]].append(img)

model = ResNet18CIFAR()

# load a previously trained full model (saved with torch.save(model)), replacing the one above;
# weights_only=False is needed to unpickle a whole model object, so only use it on trusted files
model = torch.load('resnet18_cifar77ACC.pth', map_location=torch.device('cpu'), weights_only=False)
model.eval()

# sanity check: evaluate the loaded model on the whole test set, one class at a time
correct = 0
total = 0
with torch.no_grad():
    for class_idx, class_name in enumerate(CLASSES):
        # stack this class's test images into a single float tensor
        img = torch.tensor(np.stack(test_split[class_name])).float()

        # optional normalisation with ImageNet statistics, currently disabled
        # img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        # the predicted class is the index of the largest logit
        outputs = model(img)
        _, predicted = torch.max(outputs, 1)

        # every image in this group has the true label class_idx
        correct += (predicted == class_idx).sum().item()
        total += len(test_split[class_name])

print('Accuracy of the network on the %d test images: %d %%' % (total, 100 * correct / total))
```

```
|| conv1 weight size:  torch.Size([64, 3, 7, 7])
|| fc weight size:  torch.Size([1000, 512])
Accuracy of the network on the 10000 test images: 77 %
```
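In CIL, forgetting shows up as a drop in accuracy on the earlier classes, which a single overall figure can hide. Since the evaluation loop already runs class by class, it is natural to also report per-class accuracy. A minimal sketch, assuming predictions and labels are available as plain integer sequences (for example via `predicted.tolist()`); `per_class_accuracy` is a hypothetical helper, not part of the notebook.

```python
from collections import Counter


def per_class_accuracy(predictions, labels, num_classes):
    """Return a list whose entry c is the fraction of class-c samples predicted as c.

    Classes with no samples get None rather than a misleading 0.0.
    """
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    totals = Counter(labels)
    hits = Counter(y for p, y in zip(predictions, labels) if p == y)
    return [hits[c] / totals[c] if totals[c] else None for c in range(num_classes)]


preds = [0, 0, 1, 1, 2, 0]
labels = [0, 0, 1, 2, 2, 2]
print(per_class_accuracy(preds, labels, num_classes=4))
# [1.0, 1.0, 0.3333333333333333, None]
```

Recording this list after every incremental step gives a simple view of how accuracy on each old class changes as new classes are added.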
