In [None]:
################################################################################
# Comments describing the iCaRL methods were removed from the following code
# in order to keep it concise.
# For further details on the implementation of the iCaRL methods,
# see the commented version (iCaRL.ipynb).
#
################################################################################
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.backends import cudnn
from torch.autograd import Variable

import numpy as np
import numpy.ma as ma
from math import floor
from copy import deepcopy
import random

from sklearn.svm import SVC


# Shared activation modules.
# `sigmoid` is applied to the previous network's outputs in iCaRL_svm.do_batch
# to build the distillation targets.
sigmoid = nn.Sigmoid() # Sigmoid function
# NOTE(review): `softmax` and `logsoftmax` are not referenced in the visible
# part of this file. Both rely on PyTorch choosing the dimension implicitly
# (dim=None), which PyTorch reports as deprecated -- pass an explicit `dim`
# before using them.
softmax = nn.Softmax(dim=None)
logsoftmax = nn.LogSoftmax()

class Exemplars(torch.utils.data.Dataset):
    """Flat dataset view over a list of per-class exemplar sets.

    ``exemplars[y]`` holds the stored samples of class ``y``; the position of
    a set in the outer list is used as the label of every sample inside it.
    """

    def __init__(self, exemplars, transform=None):
        """Flatten ``exemplars`` into parallel sample/label lists.

        Args:
            exemplars: Sequence of per-class sample sequences.
            transform: Optional callable applied to a sample on access.
        """
        self.transform = transform

        # All samples, class after class, in the order they were stored.
        self.dataset = [
            sample for class_set in exemplars for sample in class_set
        ]
        # One label per sample: the index of the set the sample came from.
        self.targets = [
            class_idx
            for class_idx, class_set in enumerate(exemplars)
            for _ in class_set
        ]

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index``, transformed if configured."""
        sample, label = self.dataset[index], self.targets[index]
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Return the total number of stored exemplars."""
        return len(self.targets)


class iCaRL_svm:

    def __init__(self, device, net, lr, momentum, weight_decay, milestones, gamma, num_epochs, batch_size, train_transform, test_transform, criterion, class_params, all_data):
        """Store the network, the training hyper-parameters and the SVM setup.

        Args:
            device: Device the network and the batches are moved to.
            net: Network to train; the class uses its ``features`` method and
                its ``fc`` layer.
            lr: Learning rate of the SGD optimizer.
            momentum: Momentum of the SGD optimizer.
            weight_decay: Weight decay of the SGD optimizer.
            milestones: Epoch milestones of the MultiStepLR scheduler.
            gamma: Decay factor of the MultiStepLR scheduler.
            num_epochs: Number of epochs run by ``train``.
            batch_size: Batch size of every data loader built by the class.
            train_transform: Transform applied to exemplars during training.
            test_transform: Transform applied to raw images before feature
                extraction.
            criterion: Loss called as ``criterion(outputs, targets)``.
            class_params: Keyword parameters forwarded to ``SVC.set_params``.
            all_data: If true, the SVM is also fitted on the training set
                handed to ``test`` (when one is given), not only on exemplars.
        """
        self.device = device
        self.net = net

        # Set hyper-parameters
        self.LR = lr
        self.MOMENTUM = momentum
        self.WEIGHT_DECAY = weight_decay
        self.MILESTONES = milestones
        self.GAMMA = gamma
        self.NUM_EPOCHS = num_epochs
        self.BATCH_SIZE = batch_size

        # Set transformations
        self.train_transform = train_transform
        self.test_transform = test_transform

        # List of exemplar sets. Each set contains memory_size/num_classes exemplars
        # with num_classes the number of classes seen until now by the network.
        self.exemplars = []

        # Copy of the previous network, used to compute the distillation
        # targets. None until the first training run has finished, so that the
        # first run only uses the classification loss.
        self.old_net = None

        # Best network found during validation. Initialised here because
        # extract_features, do_epoch, validate and test read it, possibly
        # before train() had a chance to set it.
        self.best_net = None

        # Maximum number of exemplars
        self.memory_size = 2000

        # Loss function
        self.criterion = criterion

        # If True, test on the best model found (e.g., minimize loss). If False,
        # test on the last model build (of the last epoch).
        self.VALIDATE = False

        # classifier with parameters
        self.clf = SVC()
        self.parameters = class_params
        self.all_data = all_data

        # Cached SVM training set (features / labels). None means "the
        # classifier has to be (re)fitted on the next classify() call".
        self.X = None
        self.y = None



    def classify(self, batch, train_dataset=None):
        
        batch_features = self.extract_features(batch) # (batch size, 64) 
        for i in range(batch_features.size(0)):       
            batch_features[i] = batch_features[i]/batch_features[i].norm() # Normalize sample feature representation
        batch_features = batch_features.to(self.device)
        
        ########################################

        # train classifier part:
        # this part of code is executed just once since we need to train the classifier
        # just one time and not at each invokation of classify method

        if self.X is None: #check if we already computed the training dataset for the classifier
             # start building the dataset
            self.X = []
            num_classes = len(self.exemplars)
            if train_dataset is not None and self.all_data == True:
                # initialize:
                # - X will contain current trainset + exemplars features
                # - y will contain the labels 
                self.X = []
                self.y = []
                #####retrieve data from train_set (if we have it):#####
                for train_sample, label in train_dataset:                                     
                    features = self.extract_features(train_sample, batch=False, transform=self.test_transform) 
                    features = features/features.norm()
                    self.X.append(features.to("cpu").numpy())
                    self.y.append(label)

            #####retrieve exemplars and concatenate them ######       
            for y in range(num_classes):  
                for exemplar in self.exemplars[y]: 
                    features = self.extract_features(exemplar, batch=False, transform=self.test_transform)
                    features = features/features.norm() # Normalize the feature representation of the exemplar
                    
                    self.X.append(features.to("cpu").numpy())
                    self.y.append(y)


            if self.all_data == True:
              print('Training SVM classifier with all data...')
            else: 
              print('Training SVM classifier with only exemplars...')

            # Initialize classifier
            self.clf.set_params(**self.parameters)
            self.X = np.vstack(self.X)
            self.y = np.array(self.y)
            self.clf.fit(self.X,self.y)
            print("Classifier training ended. ")
  
        #########################################################
          
        preds = self.clf.predict(batch_features.to("cpu").numpy())
        
        return torch.from_numpy(preds).to("cuda")
    
    def extract_features(self, sample, batch=True, transform=None):
 
        assert not (batch is False and transform is None), "if a PIL image is passed to extract_features, a transform must be defined"

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)
        if self.old_net is not None: self.old_net.train(False)

        if batch is False: # Treat sample as single PIL image 
            sample = transform(sample)
            sample = sample.unsqueeze(0) # https://stackoverflow.com/a/59566009/6486336, (3, 32, 32) --> (1, 3, 32, 32)

        sample = sample.to(self.device)


        if self.VALIDATE: 
            features = self.best_net.features(sample)
        else:
            features = self.net.features(sample)   

        if batch is False:
            features = features[0] 

        return features

    def incremental_train(self, split, train_dataset, val_dataset):
    
        if split is not 0:

            self.increment_classes(10)

  
        train_logs = self.update_representation(train_dataset, val_dataset)
   
        
  
        num_classes = self.output_neurons_count()
        m = floor(self.memory_size / num_classes)

        print(f"Target number of exemplars per class: {m}")
        print(f"Target total number of exemplars: {m*num_classes}")

        for y in range(len(self.exemplars)):
            self.exemplars[y] = self.reduce_exemplar_set(self.exemplars[y], m)

  
        new_exemplars = self.construct_exemplar_set_rand(train_dataset, m) 
        self.exemplars.extend(new_exemplars)

        return train_logs

    def update_representation(self, train_dataset, val_dataset):
        """Train the network on the new data together with the exemplars.

        After training, a deep copy of the network is stored in
        ``self.old_net``.

        Args:
            train_dataset: Training data of the current split.
            val_dataset: Validation data handed to ``train``.

        Returns:
            The value returned by ``train``.
        """
        stored = sum(len(exemplar_set) for exemplar_set in self.exemplars)
        print(f"Length of exemplars set: {stored}")

        # Exemplars first, new samples second -- same order as before.
        rehearsal_data = Exemplars(self.exemplars, self.train_transform)
        combined_data = ConcatDataset([rehearsal_data, train_dataset])

        logs = self.train(combined_data, val_dataset)

        # Snapshot the freshly trained network for the next incremental step.
        self.old_net = deepcopy(self.net)

        return logs

    def construct_exemplar_set_rand(self, dataset, m):
        """Pick up to ``m`` random exemplars for each class of the split.

        Args:
            dataset: Iterable of ``(image, label)`` pairs whose ``dataset``
                attribute offers ``disable_transform`` / ``enable_transform``.
            m: Target number of exemplars per class.

        Returns:
            List of 10 lists; entry ``y`` holds the exemplars of the class
            whose label is congruent to ``y`` modulo 10. A class with fewer
            than ``m`` samples contributes all of its samples.
        """
        dataset.dataset.disable_transform()
        try:
            # Group the raw images of the split by class.
            samples = [[] for _ in range(10)]
            for image, label in dataset:
                # Map labels to 0-9 range: example: 21 -> 1, 22 -> 2 etc.
                samples[label % 10].append(image)
        finally:
            # Always restore the transforms, even if iteration failed.
            dataset.dataset.enable_transform()

        exemplars = []
        for y, class_samples in enumerate(samples):
            print(f"Randomly extracting exemplars from class {y} of current split... ", end="")
            # Sample without replacement; never ask for more samples than the
            # class has, which would make random.sample raise.
            count = max(0, min(m, len(class_samples)))
            exemplars.append(random.sample(class_samples, count))
            print(f"Extracted {len(exemplars[y])} exemplars.")

        return exemplars

    def construct_exemplar_set_herding(self, dataset, m): 

        dataset.dataset.disable_transform()

        samples = [[] for _ in range(10)]
        for image, label in dataset:
            label = label % 10 # Map labels to 0-9 range
            samples[label].append(image)

        dataset.dataset.enable_transform()

        # Initialize exemplar sets
        exemplars = [[] for _ in range(10)]

        # Iterate over classes
        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split... ", end="")

            # Transform samples to tensors and apply normalization
            transformed_samples = torch.zeros((len(samples[y]), 3, 32, 32)).to(self.device)
            for i in range(len(transformed_samples)):  
                transformed_samples[i] = self.test_transform(samples[y][i])

            # Extract features from samples
            samples_features = self.extract_features(transformed_samples).to(self.device)

            # Compute the feature mean of the current class
            features_mean = samples_features.mean(dim=0)

            # Initializes indices vector, containing the index of each exemplar chosen
            idx = []

            # See iCaRL algorithm 4
            for k in range(1, m+1): # k = 1, ..., m -- Choose m exemplars
                if k == 1: # No exemplars chosen yet, sum to 0 vector
                    f_sum = torch.zeros(64).to(self.device)
                else: # Sum of features of all exemplars chosen until now (j = 1, ..., k-1)
                    f_sum = samples_features[idx].sum(dim=0)

                # Compute argument of argmin function
                f_arg = torch.norm(features_mean - 1/k *(samples_features + f_sum), dim=1) 
                
        
                mask = np.zeros(len(f_arg), int)
                mask[idx] = 1
                f_arg_masked = ma.masked_array(f_arg.cpu().detach().numpy(), mask=mask) 

                # Compute the nearest available exemplar
                exemplar_idx = np.argmin(f_arg_masked)

                idx.append(exemplar_idx)
            
            # Save exemplars to exemplar set
            for i in idx:
                exemplars[y].append(samples[y][i])
            
            print(f"Extracted {len(exemplars[y])} exemplars.")
            
        return exemplars

    def reduce_exemplar_set(self, exemplar_set, m):
        """Return the leading ``m`` entries of ``exemplar_set``.

        Args:
            exemplar_set: Sequence of stored exemplars of one class.
            m: Number of exemplars to keep.
        """
        kept = exemplar_set[0:m]
        return kept
    
  
    def train(self, train_dataset, val_dataset):
     
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize, 
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)
        
        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        self.train_dataloader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

        self.net = self.net.to(self.device)
        if self.old_net is not None: self.old_net = self.old_net.to(self.device)

        cudnn.benchmark  
        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0
        self.best_train_loss = float('inf')
        self.best_train_accuracy = 0
        
        self.best_net = None
        self.best_epoch = -1

        for epoch in range(self.NUM_EPOCHS):
           
            train_loss, train_accuracy = self.do_epoch(epoch+1)
        
            # Validate after each epoch 
            val_loss, val_accuracy = self.validate()    

            # Validation criterion: best net is the one that minimizes the loss
            # on the validation set.
            if self.VALIDATE and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_accuracy = val_accuracy
                self.best_train_loss = train_loss
                self.best_train_accuracy = train_accuracy

                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch
                print("Best model updated")

        if self.VALIDATE:
            val_loss = self.best_val_loss
            val_accuracy = self.best_val_accuracy
            train_loss = self.best_train_loss
            train_accuracy = self.best_train_accuracy

            print(f"Best model found at epoch {self.best_epoch+1}")

        return train_loss, train_accuracy, val_loss, val_accuracy
    
    def do_epoch(self, current_epoch):
      

        # Set the current network in training mode
        self.net.train()
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            loss, corrects = self.do_batch(images, labels)

            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        self.scheduler.step()

        # Calculate average scores
        train_loss = running_train_loss / batch_idx # Average over all batches
        train_accuracy = running_corrects / float(total) # Average over all samples

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return train_loss, train_accuracy

    def do_batch(self, batch, labels):

        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Zero-ing the gradients
        self.optimizer.zero_grad()
        
       

        num_classes = self.output_neurons_count() # Number of classes seen until now, including new classes
        one_hot_labels = self.to_onehot(labels)[:, num_classes-10:num_classes]

        if self.old_net is None:
            # Network is training for the first time, so we only apply the
            # classification loss.
            targets = one_hot_labels

        else:
       
            old_net_outputs = sigmoid(self.old_net(batch))[:, :num_classes-10] 

            
            targets = torch.cat((old_net_outputs, one_hot_labels), dim=1)

        # Forward pass
        outputs = self.net(batch)
        loss = self.criterion(outputs, targets)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Accuracy over NEW IMAGES, not over all images
        running_corrects = torch.sum(preds == labels.data).data.item() 

        # Backward pass: computes gradients
        loss.backward()

        self.optimizer.step()

        return loss, running_corrects

    def validate(self): 

        self.net.train(False)
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        for images, labels in self.val_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)

            # One hot encoding of new task labels 
            one_hot_labels = self.to_onehot(labels)

            # New net forward pass
            outputs = self.net(images)  
            loss = self.criterion(outputs, one_hot_labels) # loss type: BCE Loss with sigmoids over outputs

            running_val_loss += loss.item()

            # Get predictions
            _, preds = torch.max(outputs.data, 1)

            # Update the number of correctly classified validation samples
            running_corrects += torch.sum(preds == labels.data).data.item()

            batch_idx += 1

        # Calculate scores
        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return val_loss, val_accuracy

    def test(self, test_dataset, train_dataset=None):
      

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)  # Set Network to evaluation mode
        if self.old_net is not None: self.old_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        running_corrects = 0
        total = 0

        # To store all predictions
        all_preds = torch.tensor([])
        all_preds = all_preds.type(torch.LongTensor)
        all_targets = torch.tensor([])
        all_targets = all_targets.type(torch.LongTensor)

        # List on which train SVM for current split
        self.X = None 
        
        # Disable transformations for train_dataset, if available, as we will
        # need original PIL images from which to extract features.
        if train_dataset is not None: train_dataset.dataset.disable_transform()

        for images, labels in self.test_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            total += labels.size(0)
            
            with torch.no_grad():
                preds = self.classify(images, train_dataset)

            running_corrects += torch.sum(preds == labels.data).data.item()

            all_targets = torch.cat(
                (all_targets.to(self.device), labels.to(self.device)), dim=0
            )

            all_preds = torch.cat(
                (all_preds.to(self.device), preds.to(self.device)), dim=0
            )

        if train_dataset is not None: train_dataset.dataset.enable_transform()

        # Calculate accuracy
        accuracy = running_corrects / float(total)  

        print(f"Test accuracy (iCaRL): {accuracy} ", end="")

        if train_dataset is None:
            print("(only exemplars)")
        else:
            print("(exemplars and training data)")

        return accuracy, all_targets, all_preds

    
    
    def increment_classes(self, n=10):
   

        in_features = self.net.fc.in_features  # size of each input sample
        out_features = self.net.fc.out_features  # size of each output sample
        weight = self.net.fc.weight.data
        bias = self.net.fc.bias.data

        self.net.fc = nn.Linear(in_features, out_features+n)
        self.net.fc.weight.data[:out_features] = weight
        self.net.fc.bias.data[:out_features] = bias
    
    def output_neurons_count(self):

        return self.net.fc.out_features
    

    
    def to_onehot(self, targets):

        num_classes = self.net.fc.out_features
        one_hot_targets = torch.eye(num_classes)[targets]

        return one_hot_targets.to(self.device)


