**Import libraries**


In [None]:
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torch.backends import cudnn
from torch.nn.functional import one_hot


import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet

import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sns

from PIL import Image
from tqdm import tqdm

# Fetch resnet_cifar.py (the module `resnet32` is imported from further down) unless it already
# exists in the working directory, so that re-running this cell does not download it again.
# NOTE(review): --no-check-certificate disables TLS certificate verification for this download — confirm this is acceptable.
import os
if not os.path.exists("./resnet_cifar.py"):
    !wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=14ugdr3UoIWHmRCRS9KrJiQmCRK9WvCVj' -O resnet_cifar.py

--2020-12-14 22:18:18--  https://docs.google.com/uc?export=download&id=14ugdr3UoIWHmRCRS9KrJiQmCRK9WvCVj
Resolving docs.google.com (docs.google.com)... 172.217.164.142, 2607:f8b0:4004:814::200e
Connecting to docs.google.com (docs.google.com)|172.217.164.142|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-10-1g-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/7i93udod75itm3il2gbejdslrvcnu2t8/1607984250000/05588580840073678935/*/14ugdr3UoIWHmRCRS9KrJiQmCRK9WvCVj?e=download [following]
--2020-12-14 22:18:19--  https://doc-10-1g-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/7i93udod75itm3il2gbejdslrvcnu2t8/1607984250000/05588580840073678935/*/14ugdr3UoIWHmRCRS9KrJiQmCRK9WvCVj?e=download
Resolving doc-10-1g-docs.googleusercontent.com (doc-10-1g-docs.googleusercontent.com)... 172.217.9.193, 2607:f8b0:4004:806::2001
Connecting to doc-10-1g-docs.googleusercontent.com (doc-10-1g-docs.goog

**Set Arguments**


In [None]:
DEVICE = 'cuda'     # device the network and the batches are moved to with .to(DEVICE)

BATCH_SIZE = 128    # batch size of both the training and the test dataloaders

K = 2000            # exemplar budget: m = ceil(K / number of seen classes) exemplars are kept per class
NUM_EPOCHS = 70     # training epochs for each task (i.e. for each set of 10 new classes)

LR = 2.0                # initial learning rate of the SGD optimizer (re-created at every task)
MOMENTUM = 0.9          # SGD momentum
STEP_SIZE = [49,63]     # milestones passed to MultiStepLR
GAMMA = 0.2             # gamma passed to MultiStepLR
WEIGHT_DECAY = 1e-5     # SGD weight decay
LOG_FREQUENCY = 10      # the training loss is printed every LOG_FREQUENCY training steps

# Random seeds (numpy drives the class permutation; torch drives weight init and shuffling)
np.random.seed(653)
torch.manual_seed(653)
torch.cuda.manual_seed(653)

**Some utility functions**

In [None]:
def compute_means_of_exemplars(net, dataset, exemplars_idx, n):
    """
    Compute the L2-normalized mean-of-exemplars (AKA average feature vectors) for the first n classes seen.

    Params:
        net: network exposing get_features(images), which returns one feature vector per image
        dataset: indexable dataset; dataset[idx][0] must be the image tensor of sample idx
        exemplars_idx: list of lists; the i-th inner list holds the dataset indices of the exemplars of the i-th seen class
        n: number of classes seen so far (only the first n lists of exemplars_idx are used)
    Returns:
        a (n, feature_dim) numpy array whose i-th row is the L2-normalized mean of the L2-normalized
        feature vectors of the exemplars of the i-th seen class. If n == 0, an empty (0, 64) array is returned.
    """
    if n == 0:    # nothing seen yet: return an empty array with the expected number of columns
        return np.array([]).reshape((0,64))

    # Run on whatever device the network lives on instead of assuming a CUDA device is available
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    class_means = []    # i-th entry: L2-normalized mean-of-exemplars of the i-th seen class

    # Features are only read, never back-propagated through: do not build an autograd graph
    with torch.no_grad():
        for i in range(n):   # for each i-th seen class

            # Retrieve the exemplars of the i-th seen class as a single (m, ...) tensor
            exemplars_images = torch.stack([dataset[exemplar_idx][0] for exemplar_idx in exemplars_idx[i]]).to(device)

            # Feature vectors of the m exemplars, L2-normalized row by row
            features = net.get_features(exemplars_images).detach().cpu().numpy()
            features = features / np.linalg.norm(features, axis=1, keepdims=True)

            # Average feature vector of the class, L2-normalized again
            class_mean = features.mean(axis=0)
            class_mean = class_mean / np.linalg.norm(class_mean)
            class_means.append(class_mean)

    return np.array(class_means)    # (n, feature_dim)


def show_heatmap_CM(labels, predictions, order_of_labels):
    """
    Plot the confusion matrix as a heat map, given ground truth labels and the model predictions

    Params:
        labels: ground truth labels
        predictions: model predictions of the labels
        order_of_labels: class labels to include, in the order in which they are laid out on both axes

    Return:
        Nothing; the heatmap is shown
        x axis: predicted class
        y axis: true class
    """
    fig, ax = plt.subplots(figsize=(9,9))

    # Build the confusion matrix; rows/columns follow order_of_labels (the argument, not a global variable)
    cm = confusion_matrix(labels, predictions, labels=order_of_labels)

    # Wrap it in a dataframe so that the heatmap ticks show the class labels
    df_cm = pd.DataFrame(cm, index=order_of_labels, columns=order_of_labels)

    # Visualize the confusion matrix as a heat map on the axes created above
    sns.heatmap(df_cm, cbar=False, square=False, ax=ax)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')

    plt.show()


def CIFAR100_with_indices(cls):
    """
    Build a subclass of cls (same name) whose items are (data, target, index) triples
    instead of (data, target) pairs; index is the position that was requested.
    Ref: https://discuss.pytorch.org/t/how-to-retrieve-the-sample-indices-of-a-mini-batch/7948/19
    """
    def __getitem__(self, index):
        raw_image, label = self.data[index], self.targets[index]
        sample = Image.fromarray(raw_image)

        # Apply the optional transforms, image first and target second
        image_transform = self.transform
        if image_transform is not None:
            sample = image_transform(sample)
        label_transform = self.target_transform
        if label_transform is not None:
            label = label_transform(label)

        return sample, label, index

    overrides = {'__getitem__': __getitem__}
    return type(cls.__name__, (cls,), overrides)


def flatten(l):
    """Return a flat list with the items of every sub-iterable of l, in order (one nesting level only)."""
    return [item for sublist in l for item in sublist]


def subset_indices(dataset, classes, exemplars_idx=None, no_exemplars=False):
    """
    Returns the indices for the creation of a subset of a dataset containing only the images of the specified classes AND the exemplars of the previous classes

    Params:
        dataset: iterable yielding (image, label, index) triples
        classes: collection of labels to keep (anything supporting the `in` operator)
        exemplars_idx: optional list of lists of indices; all of them are appended after the indices of the selected classes
        no_exemplars: if True, exemplars_idx is ignored and only the indices of the selected classes are returned
    Returns:
        a list with the indices of the samples whose label is in classes, followed by the exemplar indices
    """

    # Current classes: keep the index carried by every sample whose label is among the requested ones
    indices = [img_index for _, img_label, img_index in dataset if img_label in classes]

    # Exemplars of previous classes: flatten the list of lists and append it
    if exemplars_idx is not None and not no_exemplars:
        indices.extend(idx for class_exemplars in exemplars_idx for idx in class_exemplars)

    return indices


**Define Data Processing**

In [None]:
# Per-channel mean and standard deviation used by Normalize in BOTH pipelines (defined once so they cannot drift apart)
NORMALIZE_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
NORMALIZE_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

# Built once and reused, instead of being re-instantiated for every crop of every image
_to_tensor = transforms.ToTensor()      # Turn PIL Image to torch.Tensor
_normalize = transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)   # Normalizes tensor with mean and standard deviation

# Define transforms for training phase: random 32x32 crop of the 4-pixel-padded image + random horizontal flip
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      _to_tensor,
                                      _normalize,
])

# Define transforms for the evaluation phase: every image becomes a stack of 10 normalized 32x32 crops
eval_transform = transforms.Compose([transforms.Pad(4, fill=0, padding_mode='constant'),
                                     transforms.TenCrop(32),     # TenCrop returns 10 crops: the four corners and the center, plus the flipped version of each
                                     transforms.Lambda(lambda crops: torch.stack([_normalize(_to_tensor(crop)) for crop in crops]))])


**Prepare Dataset**

In [None]:
# Rebind the name CIFAR100 to a subclass whose items are (img, target, index) instead of (img, target),
# so that looping on the datasets/dataloaders below also yields the index of every sample
CIFAR100 = CIFAR100_with_indices(CIFAR100)

# Load CIFAR100 training and test datasets
# Training samples go through train_transform (random crop + random flip); test samples go through eval_transform (stack of crops)
cifar100_training = CIFAR100(root='./data', train=True, download=True, transform=train_transform)
#cifar100_training_exemplars = CIFAR100(root='./data', train=True, transform=eval_transform)
cifar100_test = CIFAR100(root='./data', train=False, transform=eval_transform)   # no download=True: reuses the files placed in ./data by the call above

# Check dataset sizes
print(f'Training set size: {len(cifar100_training)}')
print(f'Test set size: {len(cifar100_test)}')


# Create an array with a random permutation of the 100 classes (0,1,...,99), with 10 rows of 10 classes each
# Each row is the set of new classes of one task; the permutation depends on the numpy seed set in the arguments cell
classes = np.random.permutation(np.arange(100)).reshape((10,10))

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-100-python.tar.gz to ./data
Training set size: 50000
Test set size: 10000


**Prepare Network**

In [None]:
# resnet_cifar.py is the file downloaded (if missing) in the imports cell, so this import must run after it
from resnet_cifar import resnet32
net = resnet32(num_classes = 100)   # Loading ResNet32 model with 100 outputs; the rest of the notebook also relies on its get_features() method

**Define loss function**

In [None]:
# Classification loss + distillation loss
criterion = nn.BCEWithLogitsLoss(reduction='mean')   # but: reduction='sum' is as seen in iCaRL paper (mean is not applied anywhere in the loss)

# All the loss are BCEwithLogitsLoss
def compute_loss(outputs, one_hot_labels, classes, task_step_counter, criterion = criterion, old_outputs = None):
    """
    Computes the loss presented in the iCaRL paper (classification loss + distillation loss),
    which is basically a binary cross-entropy loss, in which the "true distribution" vector y in the formula also includes the values of some pre-updated network outputs.
    Params:
        outputs: (BATCH_SIZE, NUM_CLASSES) matrix of logits of the current batch of images
        one_hot_labels: (BATCH_SIZE, NUM_CLASSES) each row is the one-hot-encoded ground-truth label of the images of the current batch. It is NOT modified.
        classes: the array with the random permutation of the classes, one row per task (10 rows with 10 classes each in this notebook)
        task_step_counter: 1-based index of the task being trained; the classes in the first (task_step_counter - 1) rows of `classes` are the previously seen ones
        criterion: BCEWithLogitsLoss
        old_outputs: (BATCH_SIZE, NUM_CLASSES) matrix of the pre-updated network outputs of all the nodes (not only those associated with the previously seen classes), AKA q_i values. It is None when the first set of classes is encountered.
    Returns:
        The loss of the current batch of images, as computed by criterion
    """

    # First task: no previous network, plain classification loss
    if old_outputs is None:
        return criterion(outputs, one_hot_labels)

    # Columns of the previously seen classes, derived from the arguments instead of a global variable
    n_previous_tasks = max(task_step_counter - 1, 0)
    previous_classes = np.asarray(classes)[:n_previous_tasks].reshape(-1)
    columns_idx_of_pre_updated_outputs = torch.as_tensor(previous_classes, dtype=torch.long, device=outputs.device)

    # Work on a copy so that the caller's one-hot matrix is left untouched
    targets = one_hot_labels.clone()

    # Distillation targets are constants: detach them so that nothing is back-propagated into the old network
    targets[:, columns_idx_of_pre_updated_outputs] = torch.sigmoid(old_outputs[:, columns_idx_of_pre_updated_outputs]).detach()

    # Each row of targets is now the target "true" distribution of the binary cross-entropy
    return criterion(outputs, targets)


**Training and testing**

In [None]:
net = net.to(DEVICE)        # this will bring the network to GPU if DEVICE is cuda

cudnn.benchmark = True      # the flag must be ASSIGNED: a bare `cudnn.benchmark` expression has no effect

test_acc_history = []       # one test accuracy (NME classifier) is appended after each set of 10 classes
test_acc_history_CNN = []   # one test accuracy (CNN outputs) is appended after each set of 10 classes

global_step_counter = 0     # Incremented by one every time a training iteration inside an epoch ends
task_step_counter = 1       # 1-based index of the current task; incremented by one every time one set of 10 classes has been processed

seen_classes_counter = 0    # Incremented by 10 every time one set of 10 classes is seen
seen_classes = []           # Is extended with the 10 classes currently seen, each time they are seen

old_outputs = None          # outputs of the previous-task network on the current batch; stays None during the first task

exemplars_idx = [ [] for _ in range(100) ]     # a list of 100 lists which will contain m int values each; the i-th inner list will contain cifar100_training indices of exemplar images for the i-th class seen

for current_classes in classes:     # 10 cycles (over the 10 sets of classes)

    # Create a dataloader for the new 10 classes to train on
    augm_training_subset_idx = subset_indices(cifar100_training, current_classes, exemplars_idx)   # augm stands for augmented with exemplars of previous classes
    augm_training_subset = Subset(cifar100_training, augm_training_subset_idx)
    augm_training_dataloader = DataLoader(augm_training_subset, shuffle=True, num_workers=4, batch_size=BATCH_SIZE, drop_last=True)

    # Initialize the optimizer and the scheduler (both are re-created from scratch at every task)
    parameters_to_optimize = net.parameters()
    optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, STEP_SIZE, gamma=GAMMA)



    #################### UPDATE REPRESENTATION ####################
    # train on the next set of 10 unseen classes



    # 10 previously unseen classes are now being seen
    seen_classes_counter += 10
    seen_classes += current_classes.tolist()

    # Load the network saved at the end of the previous task ONCE per task (not once per batch):
    # its outputs are the distillation targets and it is never updated
    old_net = None
    if task_step_counter >= 2:
        old_net = torch.load('resNet_task' + str(task_step_counter - 1) + '.pt').train(False)

    epoch_step_counter = 0      # Incremented by one at every training iteration of the current task (it is not reset between epochs)

    for epoch in range(NUM_EPOCHS):
        print(f'Starting epoch {epoch+1}/{NUM_EPOCHS}, LR = {scheduler.get_last_lr()}')
        t = time.time()

        net.train(True)

        total_training_corrects = 0
        running_loss = 0

        for images, labels, _ in augm_training_dataloader:

            # Bring data over the device of choice
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad() # Zero-ing the gradients

            # If this is not the first time classes are seen, compute the output logits of the network obtained previously
            # (no_grad: the old network only provides constant targets, nothing is back-propagated through it)
            if old_net is not None:
                with torch.no_grad():
                    old_outputs = old_net(images)

            # Forward pass to the network
            outputs = net(images)

            # Compute loss (the identity matrix is created on the same device as the labels used to index it)
            one_hot_labels = torch.eye(100, device=DEVICE)[labels]
            loss = compute_loss(outputs, one_hot_labels,classes,task_step_counter, criterion, old_outputs)

            # Get predictions
            _, preds = torch.max(outputs.data, 1)

            # Update the running amount of correct predictions
            total_training_corrects += torch.sum(preds == labels.data).data.item()

            # Log loss
            if epoch_step_counter % LOG_FREQUENCY == 0:
                print(f'Step: {epoch_step_counter}, training loss: {loss.item()}')

            # Compute gradients for each layer and update weights
            loss.backward()  # backward pass: computes gradients
            optimizer.step() # update weights based on accumulated gradients

            # Update the local and global step counter
            epoch_step_counter += 1
            global_step_counter += 1

        # Compute the training accuracy of this epoch (only for reporting purposes);
        # drop_last=True makes every batch full, so the number of samples seen is len(dataloader) * BATCH_SIZE
        current_classes_training_accuracy = total_training_corrects / (float(len(augm_training_dataloader)) * BATCH_SIZE)
        print(f'------ Epoch {epoch+1}/{NUM_EPOCHS} of the training on the {task_step_counter}° set of classes has ended')
        print(f'------ Training accuracy (only on the current 10 classes): {current_classes_training_accuracy}\n------ Elapsed time for this epoch: {time.time() - t}')

        # Step the scheduler
        scheduler.step()

    # Save the network trained on this task: it is reloaded as the "old" network of the next task
    torch.save(net, f'resNet_task{task_step_counter}.pt')



    #################### EXEMPLARS MANAGEMENT ####################
    # exemplar management part

    # Features must be extracted in evaluation mode: in training mode the batch-norm layers would normalize
    # with the statistics of the single-class batch and would also update their running statistics
    net.train(False)

    ###### Reduce exemplar sets of previously seen classes

    m = int(np.ceil(K / float(seen_classes_counter)))   # NB: this can result in a number of exemplars slightly larger than K, but that's how it's done in the original iCaRL code (see main_resnet_tf.py)
    for i in range(seen_classes_counter - 10):
        exemplars_idx[i] = exemplars_idx[i][:m]

    ###### Construct exemplar set for each of the current 10 classes seen

    training_subset_idx = subset_indices(cifar100_training, current_classes)   # list of 5000 indices
    training_subset = Subset(cifar100_training, training_subset_idx)    # subset of all training images which label is in current_classes

    current_classes_means = []  # at the end of the next for loop, this will be (10,64)

    for i, cls in enumerate(current_classes): # for each current class

        # Retrieve the 500 images of class cls
        # NOTE: images are read through cifar100_training, i.e. with the (random) training transform applied
        class_subset_idx = subset_indices(training_subset, [cls])   # absolute indices of the 500 images of class cls
        class_subset_images = torch.stack([cifar100_training[j][0] for j in class_subset_idx]).to(DEVICE)  # tensor of the 500 images of class cls

        # Compute the (exact) class mean (AKA mean feature vector) μ of class cls
        with torch.no_grad():   # features are only read: no autograd graph is needed
            features = net.get_features(class_subset_images).cpu().data.numpy()     # (len(class_subset), net.fc.in_features) = (500, 64)-d matrix
        features = features / np.array([np.linalg.norm(features, axis=1)]).T    # normalize features (the paper mentions an L2-normalization of features)
        class_mean = features.mean(axis=0)  # current class mean of features (64)
        class_mean = class_mean / np.linalg.norm(class_mean)    # L2 normalization x / ||x||
        current_classes_means.append(class_mean)

        # Construct exemplar set for class cls (herding): at step k pick the not-yet-selected image which brings
        # the mean of the selected features closest to the class mean
        local_exemplars_idx = []    # relative indices (rows of features) of the exemplars selected so far
        for k in range(len(class_subset_idx)):  # ALL the images are ranked, not only m: this is a TRICK, so that the means of the currently seen classes computed in the test phase are the exact class means
            temp = features[local_exemplars_idx]    # features of the k exemplars already found
            temp = np.sum(temp, axis=0)             # their sum (64)
            temp = features + temp                  # broadcasting! (500, 64): sum obtained by adding each candidate
            temp = temp / float(k+1)                # a tiny broadcasting, (500, 64): candidate means
            temp = class_mean - temp                # broadcasting! (500, 64)
            temp = np.linalg.norm(temp, axis=1)     # final norm, (500)
            argsort = np.argsort(temp)

            # Take the best candidate which has not been selected yet
            for j in argsort:
                if j not in local_exemplars_idx:
                    exemplars_idx[(seen_classes_counter-10) + i].append(class_subset_idx[j])    # append the absolute index of the exemplar p_k to the exemplar set P_i of the i-th class
                    local_exemplars_idx.append(j)
                    break



    #################### TEST PHASE ####################
    # NMoE classifier (see paper for mathematical details on the algorithm)

    # Create a dataloader for the test images of all (and only) the classes seen so far
    test_subset_idx = subset_indices(cifar100_test, seen_classes)
    test_subset = Subset(cifar100_test, test_subset_idx)
    test_dataloader = DataLoader(test_subset, shuffle=False, num_workers=4, batch_size=BATCH_SIZE)

    net.train(False)    # Set Network to evaluation mode

    # Compute the L2-normalized means-of-exemplars for the previously seen (seen_classes_counter-10) classes, and the L2-normalized (exact) class means for the currently seen 10 classes (TRICK)
    with torch.no_grad():
        class_means = compute_means_of_exemplars(net, cifar100_training, exemplars_idx, seen_classes_counter)    # (seen_classes_counter, 64); USE ONLY WITH TRICK: means of exemplars of the previously seen classes and (exact) class means of the currently seen classes

    # The i-th list of exemplars_idx (hence the i-th row of class_means) belongs to the i-th class seen,
    # so the label of a mean is read here instead of loading one exemplar image per prediction
    seen_classes_array = np.asarray(seen_classes, dtype=np.int64)

    all_test_labels = torch.LongTensor([]).to(DEVICE)
    all_test_preds = torch.LongTensor([]).to(DEVICE)
    all_test_labels_CNN = torch.LongTensor([]).to(DEVICE)
    all_test_preds_CNN = torch.LongTensor([]).to(DEVICE)

    total_test_corrects = 0
    total_test_corrects_CNN = 0

    with torch.no_grad():
        for images, labels, _ in tqdm(test_dataloader):    # images is a (128,10,3,32,32) tensor, labels is a (128) tensor
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            bs, ncrops, c, h, w = images.size()



            ##### CLASSIFICATION WITH NME
            # Compute the L2-normalized feature vectors of the images of the current batch of test images
            features = net.get_features(images.view(-1, c, h, w)).cpu().data.numpy()    # (1280, 64)-d matrix; each row is a ϕ(x)
            features = features / np.array([np.linalg.norm(features, axis=1)]).T    # L2 normalization of the ϕ(x)'s

            # Get predictions with the NMoE classifier: distance of every crop from every class mean, then closest mean
            distances = np.linalg.norm(features[:, np.newaxis, :] - class_means[np.newaxis, :, :], axis=2)   # (1280, seen_classes_counter)
            nearest_mean_idx = np.argmin(distances, axis=1)     # relative index of the closest mean, (1280)
            preds = torch.from_numpy(seen_classes_array[nearest_mean_idx]).view(bs, ncrops)    # labels of the closest means, from (1280) to (128,10)
            # Hard major voting
            class_counts = torch.stack([torch.bincount(row, minlength=100) for row in preds])   # (128,100)
            preds = torch.argmax(class_counts, dim=1)   # (128)

            preds = preds.to(DEVICE)     # same device as labels (needed for the next line of code)
            current_batch_corrects = torch.sum(preds == labels.data).data.item()


            # Update the running statistics (for plotting purposes)
            total_test_corrects += current_batch_corrects

            # Store the current batch labels and preds; needed later for plotting the heatmap
            all_test_labels = torch.cat((all_test_labels, labels), 0)
            all_test_preds = torch.cat((all_test_preds, preds), 0)



            ##### CLASSIFICATION WITH CNN
            # Forward pass to the network
            outputs = net(images.view(-1, c, h, w))    # fuse batch size and ncrops (in input: (1280,3,32,32), in output: (1280,100))
            outputs = outputs.view(bs, ncrops, -1).mean(axis=1)      # avg over crops (128,10,100)----after mean---->(128,100)

            # Get predictions
            _, preds_CNN = torch.max(outputs.data, 1)
            current_batch_corrects_CNN = torch.sum(preds_CNN == labels.data).data.item()

            # Update the running amount of correct predictions
            total_test_corrects_CNN += current_batch_corrects_CNN

            # Store the current batch labels and CNN preds (accumulating the CNN tensors, not the NME ones)
            all_test_labels_CNN = torch.cat((all_test_labels_CNN, labels), 0)
            all_test_preds_CNN = torch.cat((all_test_preds_CNN, preds_CNN), 0)



        test_accuracy = total_test_corrects / float(len(test_subset))
        test_acc_history.append(test_accuracy)
        test_accuracy_CNN = total_test_corrects_CNN / float(len(test_subset))
        test_acc_history_CNN.append(test_accuracy_CNN)





    ### End of training on the 10 currently seen classes


    # 10 new classes are to be seen next: update the counter
    task_step_counter += 1

    print(f'--------------- Training ended on the {int(seen_classes_counter / 10)}° set of classes')
    print(f'--------------- Test accuracy (NME): {test_accuracy}')
    print(f'--------------- Test accuracy (CNN): {test_accuracy_CNN}')

    ###### Reduce exemplar sets of currently seen classes (this is because of the trick)
    for i in range( seen_classes_counter-10, seen_classes_counter ):
        exemplars_idx[i] = exemplars_idx[i][:m]


torch.save(net, 'final_net')  # save the final net on Colab's file system


Starting epoch 1/70, LR = [2.0]
Step: 0, training loss: 0.7913661003112793
Step: 10, training loss: 0.04946229234337807
Step: 20, training loss: 0.035522427409887314
Step: 30, training loss: 0.030085602775216103
------ Epoch 1/70 of the training on the 1° set of classes has ended
------ Training accuracy (only on the current 10 classes): 0.17447916666666666
------ Elapsed time for this epoch: 2.9646830558776855
Starting epoch 2/70, LR = [2.0]
Step: 40, training loss: 0.029591649770736694
Step: 50, training loss: 0.02491151914000511
Step: 60, training loss: 0.026038672775030136
Step: 70, training loss: 0.024393809959292412
------ Epoch 2/70 of the training on the 1° set of classes has ended
------ Training accuracy (only on the current 10 classes): 0.40965544871794873
------ Elapsed time for this epoch: 2.7888123989105225
Starting epoch 3/70, LR = [2.0]
Step: 80, training loss: 0.022532057017087936
Step: 90, training loss: 0.022642306983470917
Step: 100, training loss: 0.021246695891022

100%|██████████| 8/8 [00:08<00:00,  1.04s/it]


--------------- Training ended on the 1° set of classes
--------------- Test accuracy (NME): 0.862
--------------- Test accuracy (CNN): 0.87
Starting epoch 1/70, LR = [2.0]
Step: 0, training loss: 0.0891089215874672
Step: 10, training loss: 0.056843239814043045
Step: 20, training loss: 0.04061124101281166
Step: 30, training loss: 0.03632091358304024
Step: 40, training loss: 0.03380727395415306
Step: 50, training loss: 0.034197792410850525
------ Epoch 1/70 of the training on the 2° set of classes has ended
------ Training accuracy (only on the current 10 classes): 0.2760416666666667
------ Elapsed time for this epoch: 7.841755390167236
Starting epoch 2/70, LR = [2.0]
Step: 60, training loss: 0.032131828367710114
Step: 70, training loss: 0.0278193186968565
Step: 80, training loss: 0.028527608141303062
Step: 90, training loss: 0.027653750032186508
Step: 100, training loss: 0.02910902164876461
------ Epoch 2/70 of the training on the 2° set of classes has ended
------ Training accuracy (o

KeyboardInterrupt: ignored

In [None]:
# Print the test accuracies obtained: one value is appended to each list after every set of 10 classes
# (first list: NME classifier, second list: CNN outputs), so there are ten values only after a complete run
print(test_acc_history)
print(test_acc_history_CNN)

**Heatmap Confusion Matrix**

In [None]:
# Show the confusion matrix for the test set as a heatmap
# all_test_labels / all_test_preds are re-created at every task in the training cell, so what is plotted here
# are the NME predictions on the test images of the last task that was evaluated
show_heatmap_CM(all_test_labels.cpu(), all_test_preds.cpu(), order_of_labels=seen_classes)