**Import libraries**


In [None]:
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torch.backends import cudnn
from torch.nn.functional import one_hot


import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sns

from PIL import Image
from tqdm import tqdm

# Load resnet_cifar.py
import os
if not os.path.exists("./resnet_cifar.py"):
    !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=14ugdr3UoIWHmRCRS9KrJiQmCRK9WvCVj' -O resnet_cifar.py


**Set hyperparameters**


In [None]:
# Use the GPU when one is available; fall back to the CPU so the notebook still runs without CUDA
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BATCH_SIZE = 128

K = 2000            # NOTE(review): not referenced by the cells below — presumably the iCaRL exemplar memory size
NUM_EPOCHS = 70     # training epochs per set of 10 classes

# SGD with a multi-step learning-rate decay (re-created for every set of classes)
LR = 2.0
MOMENTUM = 0.9
STEP_SIZE = [49,63]  # epochs at which the learning rate is multiplied by GAMMA
GAMMA = 0.2
WEIGHT_DECAY = 1e-5
LOG_FREQUENCY = 10   # print the training loss every LOG_FREQUENCY training steps

# Random seeds
np.random.seed(653)
torch.manual_seed(653)
torch.cuda.manual_seed(653)   # silently ignored when CUDA is not available

**Utility functions**

In [None]:
def _label_key(label):
    """Normalize a label (Python int, numpy scalar or 0-d tensor) to a plain hashable Python value."""
    return label.item() if hasattr(label, 'item') else label


def subset_indices(dataset, classes):
    """
    Returns the indices for the creation of a subset of a dataset containing only the images of the specified classes

    Params:
        dataset: indexable dataset yielding (image, label) pairs. When it exposes a `targets`
                 sequence and no `target_transform` (as torchvision's CIFAR100 does here), labels
                 are read from `targets` directly instead of loading and transforming every image.
        classes: iterable of the class labels to keep (e.g. a list or a 1-D numpy array).

    Returns:
        List of the indices (in increasing order) of the samples whose label is in `classes`.
    """
    # Set membership: O(1) per sample instead of scanning `classes` every time
    wanted = {_label_key(c) for c in classes}

    targets = getattr(dataset, 'targets', None)
    if targets is None or getattr(dataset, 'target_transform', None) is not None:
        # Generic path: iterate the samples themselves (this loads and transforms each image)
        targets = (label for _, label in dataset)

    return [img_index for img_index, img_label in enumerate(targets) if _label_key(img_label) in wanted]


def show_heatmap_CM(labels, predictions, class_order=None):
    """
    Plot the confusion matrix as a heat map, given ground truth labels and the model predictions

    Params:
        labels: ground truth labels
        predictions: model predictions of the labels
        class_order: classes to include, in the order used for both rows and columns.
                     Defaults to the global `seen_classes` (the classes seen so far, in learning order).

    Return:
        Show the heatmap
        x axis: predicted class
        y axis: true class
        Every 20th row/column is labelled with its (0-based) position in `class_order`,
        not with the CIFAR-100 class id.
    """
    if class_order is None:
        class_order = seen_classes
    n_classes = len(class_order)

    # Set the font scale before drawing so the figure looks the same on the first run and on re-runs
    sns.set(font_scale = 2)
    fig, ax = plt.subplots(figsize=(9,9))

    # Build the (n_classes x n_classes) confusion matrix, rows/columns ordered as in `class_order`
    cm = confusion_matrix(labels, predictions, labels=class_order)

    # Index rows/columns by position + 20 (for plotting reasons: together with the tick shift
    # below, the tick labelled 20, 40, ... lands on the row/column at that 0-based position)
    positions = np.arange(n_classes) + 20
    df_cm = pd.DataFrame(cm, index=positions, columns=positions)

    # Visualize the confusion matrix as a heat map on the axes created above
    ax = sns.heatmap(df_cm, xticklabels=20, yticklabels=20, cbar=False, square=False, cmap='OrRd', ax=ax)
    ax.set(xlabel='Predicted class', ylabel='True class')
    pos, textvals = plt.xticks()
    plt.xticks(np.array(pos)+20, textvals, va="center")
    pos, textvals = plt.yticks()
    plt.yticks(np.array(pos)+20, textvals, va="center")
    ax.tick_params(axis='x', pad=15)

    plt.show()

**Define Data Processing**

In [None]:
# Per-channel CIFAR-100 normalization statistics. These are the values the training transform uses
# (they match the commonly published CIFAR-100 training-set mean/std). The evaluation transform must
# use the same statistics: the network only ever sees inputs normalized this way, and statistics
# computed on the test set would leak test information into preprocessing.
CIFAR100_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
CIFAR100_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

# Define transforms for training phase (augmentation + normalization)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),   # random 32x32 crop of the 4-pixel padded image
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.ToTensor(), # Turn PIL Image to torch.Tensor
                                      transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD) # Normalizes tensor with mean and standard deviation
])

# Define transforms for the evaluation phase (no augmentation, same normalization as training)
eval_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)
])


**Prepare Dataset**

In [None]:
# CIFAR-100 train/test splits, both stored under ./data
# (the download triggered by the training split also provides the test split)
cifar100_training = CIFAR100(root='./data', train=True, download=True, transform=train_transform)
cifar100_test = CIFAR100(root='./data', train=False, transform=eval_transform)

# Sanity check: report how many images each split contains
for split_name, split in (('Training', cifar100_training), ('Test', cifar100_test)):
    print(f'{split_name} set size: {len(split)}')

# Random order in which the 100 classes are learned: row i holds the 10 classes of task i+1
classes = np.random.permutation(np.arange(100)).reshape(10, -1)

**Prepare Network**

In [None]:
from resnet_cifar import resnet32

# ResNet-32 with 100 outputs, one per CIFAR-100 class
net = resnet32(num_classes=100)

**Define loss function**

In [None]:
# Classification loss + distillation loss
# NOTE: the iCaRL paper sums the per-output terms (reduction='sum'); 'mean' averages them over
# all batch elements and outputs, which only rescales the loss (and hence the gradients).
criterion = nn.BCEWithLogitsLoss(reduction='mean')

# All the losses are BCEWithLogitsLoss
def compute_loss(outputs, one_hot_labels, classes, task_step_counter, criterion = criterion, old_outputs = None):
    """
    Computes the loss presented in the iCaRL paper (classification loss + distillation loss).

    It is a binary cross-entropy over all network outputs whose "true distribution" vector y
    is the one-hot ground truth, except for the columns of the previously learned classes,
    where y is the sigmoid of the pre-update network outputs (q_i values).

    Params:
        outputs: (BATCH_SIZE, NUM_CLASSES) matrix of logits of the current batch of images
        one_hot_labels: (BATCH_SIZE, NUM_CLASSES) float matrix; each row is the one-hot-encoded
            ground-truth label of an image of the current batch. It is NOT modified.
        classes: the array with the random permutation of the 100 classes (10 rows with 10 classes
            each); row i holds the classes learned in task i+1.
        task_step_counter: 1-based index of the current task; the classes in the first
            task_step_counter - 1 rows of `classes` are the previously learned ones.
        criterion: loss applied to (logits, targets); BCEWithLogitsLoss by default.
        old_outputs: (BATCH_SIZE, NUM_CLASSES) matrix of the pre-update network logits for all the
            nodes, AKA q_i values before the sigmoid. None when the first set of classes is encountered.
    Returns:
        The fraction of the total loss associated with the current batch of images
    """
    if old_outputs is None:
        return criterion(outputs, one_hot_labels)

    # Columns of the classes learned in the previous tasks
    previous_classes = np.asarray(classes)[:task_step_counter - 1].ravel()
    old_columns = torch.as_tensor(previous_classes, dtype=torch.long)

    # Build the target on a copy, so the caller's labels are left untouched. The old outputs are
    # fixed targets: detach them so no gradient flows back into the pre-update network.
    targets = one_hot_labels.clone()
    targets[:, old_columns] = torch.sigmoid(old_outputs[:, old_columns].detach())

    return criterion(outputs, targets)

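**Training and testing**

The network is trained incrementally on 10 tasks of 10 classes each; after every task it is evaluated on the test images of all the classes seen so far.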

In [None]:
# By default, everything is loaded to cpu
net = net.to(DEVICE) # this will bring the network to GPU if DEVICE is cuda

# Let cuDNN benchmark convolution algorithms and reuse the fastest one (input size is fixed, 32x32).
# The flag has to be assigned: merely evaluating `cudnn.benchmark` has no effect.
cudnn.benchmark = True

test_acc_history = []  # one test accuracy per set of classes (10 values in total)

global_step_counter = 0
task_step_counter = 1 # Incremented by one every time one set of 10 classes is seen

seen_classes_counter = 0
seen_classes = []     # all the classes seen so far, in learning order (also read by show_heatmap_CM)
old_outputs = None

for current_classes in classes:     # 10 cycles (over the 10 sets of classes)

    # Create a dataloader for the new 10 classes to train on
    training_subset_idx = subset_indices(cifar100_training, current_classes)
    training_subset = Subset(cifar100_training, training_subset_idx)
    training_dataloader = DataLoader(training_subset, shuffle=True, num_workers=4, batch_size=BATCH_SIZE, drop_last=True)

    # Frozen snapshot of the network as it was at the end of the previous task (same weights as
    # the checkpoint saved then). Its outputs are the distillation targets (q_i values), so it is
    # copied once per task and kept in evaluation mode, instead of being re-loaded from disk for
    # every batch.
    old_net = deepcopy(net).train(False) if task_step_counter >= 2 else None

    # Initialize the optimizer and the scheduler (fresh for every set of classes)
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, STEP_SIZE, gamma=GAMMA)

    # TRAINING PHASE
    # train on the next set of 10 unseen classes

    # 10 previously unseen classes are now being seen
    seen_classes_counter += 10
    seen_classes += current_classes.tolist()

    epoch_step_counter = 0

    for epoch in range(NUM_EPOCHS):
        print(f'Starting epoch {epoch+1}/{NUM_EPOCHS}, LR = {scheduler.get_last_lr()}')
        t = time.time()

        net.train(True)

        total_training_corrects = 0

        for images, labels in training_dataloader:

            # Bring data over the device of choice
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad() # Zero-ing the gradients

            # Distillation targets from the frozen previous network: no gradient is needed
            old_outputs = None
            if old_net is not None:
                with torch.no_grad():
                    old_outputs = old_net(images)

            # Forward pass to the network
            outputs = net(images)

            # Get predictions and update the number of correct ones
            _, preds = torch.max(outputs.detach(), 1)
            total_training_corrects += torch.sum(preds == labels).item()

            # Compute loss based on output and ground truth (one-hot over all the network outputs,
            # built directly on the labels' device)
            one_hot_labels = one_hot(labels, num_classes=outputs.size(1)).float()
            loss = compute_loss(outputs, one_hot_labels, classes, task_step_counter, criterion, old_outputs)

            # Log loss
            if epoch_step_counter % LOG_FREQUENCY == 0:
                print(f'Step: {epoch_step_counter}, training loss: {loss.item()}')

            # Compute gradients for each layer and update weights
            loss.backward()  # backward pass: computes gradients
            optimizer.step() # update weights based on accumulated gradients

            # Update counters
            epoch_step_counter += 1
            global_step_counter += 1

        # drop_last=True, so every batch holds exactly BATCH_SIZE images
        current_classes_training_accuracy = total_training_corrects / (float(len(training_dataloader)) * BATCH_SIZE)
        print(f'------ Epoch {epoch+1}/{NUM_EPOCHS} of the training on the {task_step_counter}° set of classes has ended.\n------ Training accuracy (only on the current 10 classes): {current_classes_training_accuracy}\n------ Elapsed time for this epoch: {time.time() - t}')

        # Step the scheduler
        scheduler.step()

    # Checkpoint the network trained on this set of classes
    torch.save(net, 'resNet_task{0}.pt'.format(task_step_counter))

    task_step_counter += 1

    # The snapshot is no longer needed: release its memory before testing
    del old_net

    # TEST PHASE
    # evaluate the network on (all) the test images of the classes seen so far

    # Create a dataloader for the test images of all the classes seen so far
    test_subset_idx = subset_indices(cifar100_test, seen_classes)
    test_subset = Subset(cifar100_test, test_subset_idx)
    test_dataloader = DataLoader(test_subset, shuffle=False, num_workers=4, batch_size=BATCH_SIZE)

    net.train(False)    # Set Network to evaluation mode

    test_labels_batches = []
    test_preds_batches = []
    total_test_corrects = 0

    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            # Forward Pass
            outputs = net(images)

            # Get predictions and update the running amount of correct ones
            _, preds = torch.max(outputs, 1)
            total_test_corrects += torch.sum(preds == labels).item()

            # Store the current batch labels and preds; needed later for plotting the heatmap
            test_labels_batches.append(labels)
            test_preds_batches.append(preds)

    # Concatenate once at the end (calling torch.cat inside the loop re-copies everything every batch).
    # These module-level tensors are read by the heatmap cell after the last set of classes.
    all_test_labels = torch.cat(test_labels_batches) if test_labels_batches else torch.empty(0, dtype=torch.long, device=DEVICE)
    all_test_preds = torch.cat(test_preds_batches) if test_preds_batches else torch.empty(0, dtype=torch.long, device=DEVICE)

    test_accuracy = total_test_corrects / float(len(test_subset))
    test_acc_history.append(test_accuracy)

    print(f'---------- Training ended on the {int(seen_classes_counter / 10)}° set of classes\n---------- Test accuracy: {test_accuracy}')


torch.save(net, './SavedNet') # SavedNet on drive folder


In [None]:
print(f'{test_acc_history}')   # test accuracy after each set of classes, in learning order

**Confusion matrix heatmap**

In [None]:
# Confusion matrix (as a heatmap) of the final network on the test images of all the seen classes;
# the tensors are moved to the CPU for scikit-learn / seaborn
show_heatmap_CM(labels=all_test_labels.cpu(), predictions=all_test_preds.cpu())