# Final extension classifier

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import Dataset, Subset, DataLoader, ConcatDataset

from math import floor

from torchvision import transforms
from torchvision.models import resnet34

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import numpy as np
import numpy.ma as ma

from PIL import Image
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.backends import cudnn
from torch.autograd import Variable

import numpy as np
import numpy.ma as ma
from math import floor
from copy import deepcopy
import random

# Module-level sigmoid activation; iCaRL_ext.do_batch applies it to the previous
# network's logits to build the old-class (distillation) targets.
sigmoid = torch.nn.Sigmoid()



class Exemplars(torch.utils.data.Dataset):
    """Dataset that flattens per-class exemplar lists into (sample, label) pairs.

    ``exemplars[y]`` is the list of stored samples of class ``y``; every one of
    them is given target ``y``. ``transform``, when not None, is applied to a
    sample each time it is fetched.
    """

    def __init__(self, exemplars, transform=None):
        samples, labels = [], []
        for label, class_samples in enumerate(exemplars):
            samples += class_samples
            labels += [label] * len(class_samples)  # one label per stored sample
        self.dataset = samples
        self.targets = labels
        self.transform = transform

    def __getitem__(self, index):
        sample, target = self.dataset[index], self.targets[index]
        if self.transform is None:
            return sample, target
        return self.transform(sample), target

    def __len__(self):
        return len(self.targets)

class iCaRL_ext:
    """iCaRL variant with a radius-normalised nearest-class-mean classifier.

    Two exemplar memories are kept for every class seen so far:

    * ``exemplars_mean``: chosen by herding, used to estimate the class mean
      in feature space;
    * ``exemplars_distance``: samples whose distance from the class mean is
      closest to the class's average distance, used to estimate the class
      radius (mean distance of the class features from the class mean).

    A test sample is assigned to the class ``y`` minimising
    ``||f(x) - mu_y|| / r_y`` (see ``classify``). Training uses
    ``BCEWithLogitsLoss`` with one-hot targets for the newest 10 classes and,
    once a previous network exists, its sigmoid outputs as targets for the
    older classes (see ``do_batch``).
    """

    def __init__(self, device, net, lr, momentum, weight_decay, milestones, gamma, num_epochs, batch_size, train_transform, test_transform):
        self.device = device
        self.net = net

        # Set hyper-parameters
        self.LR = lr
        self.MOMENTUM = momentum
        self.WEIGHT_DECAY = weight_decay
        self.MILESTONES = milestones
        self.GAMMA = gamma
        self.NUM_EPOCHS = num_epochs
        self.BATCH_SIZE = batch_size

        # Set transformations
        self.train_transform = train_transform
        self.test_transform = test_transform

        # Per-class exemplar sets, indexed by class. Each memory keeps
        # floor(memory_size / (2 * num_classes)) exemplars per class: one set is
        # used for the class-mean estimate, the other for the radius estimate.
        self.exemplars_mean = []
        self.exemplars_distance = []

        # Copy of the previous network, used to compute distillation targets.
        # None until the first training round has finished.
        self.old_net = None

        # Best network found during validation (only used when VALIDATE is True).
        # extract_features reads it, so it must exist before train() is called.
        self.best_net = None

        # Class means / radii cached by classify(); test() resets them.
        self.cached_means = None
        self.cached_radius = None

        # Maximum number of exemplars (both memories together)
        self.memory_size = 2000

        # Loss function
        self.criterion = nn.BCEWithLogitsLoss()

        # If True, test on the best model found (minimum validation loss).
        # If False, test on the model of the last epoch.
        self.VALIDATE = False

    def classify(self, batch, train_dataset=None):
        """Predict class indices with the radius-normalised nearest-mean rule.

        Args:
            batch: tensor of already-transformed images.
            train_dataset: optional training subset of the latest split. When
                given it must yield raw images (test() disables its transforms);
                its samples are added to the exemplars of the latest 10 classes
                when estimating class means and radii.

        Returns:
            LongTensor with the predicted class index of every sample.
        """
        batch_features = self.extract_features(batch)
        # L2-normalise every sample's feature vector.
        batch_features = batch_features / batch_features.norm(dim=1, keepdim=True)
        batch_features = batch_features.to(self.device)

        if self.cached_means is None:
            print("Computing mean of exemplars and class radius... ", end="")
            self._compute_class_statistics(train_dataset)
            print("done")

        # (batch, num_classes) distances to every class mean, each divided by
        # that class's radius; the prediction is the smallest scaled distance.
        distances = torch.norm(batch_features.unsqueeze(1) - self.cached_means.unsqueeze(0), dim=2)
        return torch.argmin(distances / self.cached_radius, dim=1)

    def _normalized_feature(self, image):
        """Return the L2-normalised feature vector of one raw image."""
        features = self.extract_features(image, batch=False, transform=self.test_transform)
        return features / features.norm()

    def _compute_class_statistics(self, train_dataset=None):
        """Fill ``cached_means`` and ``cached_radius`` for every known class."""
        num_classes = len(self.exemplars_mean)
        # Classes of the latest split, whose training data may be available.
        first_new_class = num_classes - 10

        # Using all training data of the newest classes (not only their
        # exemplars) reduces the noise of their mean and radius estimates.
        train_features_list = None
        if train_dataset is not None:
            train_features_list = [[] for _ in range(10)]
            for train_sample, label in train_dataset:
                train_features_list[label % 10].append(self._normalized_feature(train_sample))

        means = []
        radii = []
        for y in range(num_classes):
            if train_features_list is not None and y >= first_new_class:
                # Copy: the mean and radius estimates add different exemplars
                # and must not leak into each other through a shared list.
                base_features = list(train_features_list[y % 10])
            else:
                base_features = []

            mean_features = base_features + [self._normalized_feature(e) for e in self.exemplars_mean[y]]
            class_mean = torch.stack(mean_features).mean(dim=0)
            class_mean = class_mean / class_mean.norm()  # Normalize the class mean
            means.append(class_mean)

            distance_features = base_features + [self._normalized_feature(e) for e in self.exemplars_distance[y]]
            # Class radius: mean distance of the features from the class mean.
            class_radiuses = torch.norm(torch.stack(distance_features) - class_mean, dim=1)
            radii.append(torch.mean(class_radiuses))

        self.cached_means = torch.stack(means).to(self.device)
        self.cached_radius = torch.stack(radii).to(self.device)

    def extract_features(self, sample, batch=True, transform=None):
        """Return the output of ``net.features`` for ``sample``.

        Args:
            sample: a batch tensor when ``batch`` is True; otherwise a single
                raw image that ``transform`` converts to a tensor.
            batch: whether ``sample`` is already a batch.
            transform: required when ``batch`` is False.

        Returns:
            The batch features or, when ``batch`` is False, the first (only)
            row of the output, i.e. the feature vector of the single sample.

        All networks are put in evaluation mode first; ``best_net`` is used
        instead of ``net`` when ``VALIDATE`` is True.
        """
        assert not (batch is False and transform is None), "if a PIL image is passed to extract_features, a transform must be defined"

        self.net.train(False)
        if self.best_net is not None:
            self.best_net.train(False)
        if self.old_net is not None:
            self.old_net.train(False)

        if batch is False:  # Treat sample as a single PIL image
            sample = transform(sample)
            sample = sample.unsqueeze(0)  # add the batch dimension, e.g. (3, 32, 32) --> (1, 3, 32, 32)

        sample = sample.to(self.device)

        if self.VALIDATE:
            features = self.best_net.features(sample)
        else:
            features = self.net.features(sample)

        if batch is False:
            features = features[0]

        return features

    def incremental_train(self, split, train_dataset, val_dataset):
        """Learn a new split of 10 classes and update both exemplar memories.

        Args:
            split: index of the current split, 0 for the first one.
            train_dataset: training subset of the split (labels are mapped to
                0-9 with ``% 10`` when building exemplars).
            val_dataset: validation data passed to ``train``.

        Returns:
            The ``(train_loss, train_accuracy, val_loss, val_accuracy)`` tuple
            returned by ``train``.
        """
        if split != 0:
            # The network starts with 10 outputs; every later split adds 10.
            self.increment_classes(10)

        # Improve network parameters upon receiving new classes, starting from
        # the current network parameters.
        train_logs = self.update_representation(train_dataset, val_dataset)

        num_classes = self.output_neurons_count()
        # Two exemplar sets share the memory, so each gets half of it.
        m = floor(self.memory_size / (2 * num_classes))

        print(f"Target number of exemplars per class: {m}")
        print(f"Target total number of exemplars for the mean: {m*num_classes}")
        print(f"Target total number of exemplars for the distance: {m*num_classes}")

        # Reduce pre-existing exemplar sets in order to fit new exemplars
        for y, exemplar_set in enumerate(self.exemplars_mean):
            self.exemplars_mean[y] = self.reduce_exemplar_set(exemplar_set, m)
        for y, exemplar_set in enumerate(self.exemplars_distance):
            self.exemplars_distance[y] = self.reduce_exemplar_set(exemplar_set, m)

        # Exemplar sets of the new classes: herding for the mean, mean-distance
        # selection for the radius.
        self.exemplars_mean.extend(self.construct_exemplar_set_herding(train_dataset, m))
        self.exemplars_distance.extend(self.construct_distance_exemplar_set(train_dataset, m))

        return train_logs

    def update_representation(self, train_dataset, val_dataset):
        """Train on the new data plus both exemplar memories, then snapshot the net.

        Returns:
            The tuple returned by ``train``.
        """
        print(f"Length of herding exemplars set: {sum(len(e) for e in self.exemplars_mean)}")
        print(f"Length of mean distance exemplars set: {sum(len(e) for e in self.exemplars_distance)}")

        exemplars_mean_dataset = Exemplars(self.exemplars_mean, self.train_transform)
        exemplars_distance_dataset = Exemplars(self.exemplars_distance, self.train_transform)

        # Both exemplar memories are used for training, together with the new data.
        train_dataset_with_exemplars = ConcatDataset([exemplars_mean_dataset, exemplars_distance_dataset, train_dataset])

        # Train the network on combined dataset
        train_logs = self.train(train_dataset_with_exemplars, val_dataset)

        # Keep a copy of the current network in order to compute its outputs for
        # the distillation loss while the next network is being trained.
        self.old_net = deepcopy(self.net)

        return train_logs

    def _group_samples_by_class(self, dataset):
        """Collect the raw samples of ``dataset`` into 10 per-class lists.

        Labels are mapped to 0-9 with ``label % 10``. The wrapped dataset's
        transforms are disabled while reading (so the stored exemplars are the
        original images) and re-enabled afterwards, even on error.
        """
        dataset.dataset.disable_transform()
        try:
            samples = [[] for _ in range(10)]
            for image, label in dataset:
                samples[label % 10].append(image)
        finally:
            dataset.dataset.enable_transform()
        return samples

    def _class_features(self, class_samples):
        """Return the features (on ``self.device``) of raw samples after ``test_transform``."""
        images = torch.stack([self.test_transform(sample) for sample in class_samples]).to(self.device)
        # No gradients are needed for exemplar selection; skip building the graph.
        with torch.no_grad():
            return self.extract_features(images).to(self.device)

    def construct_exemplar_set_herding(self, dataset, m):
        """Select ``m`` exemplars per new class by herding (iCaRL Algorithm 4).

        Args:
            dataset: training subset of the current split.
            m: exemplars to pick per class (capped at the number of samples,
                since each sample is picked at most once).

        Returns:
            List of 10 lists of raw samples, one per class of the split.
        """
        samples = self._group_samples_by_class(dataset)
        exemplars = [[] for _ in range(10)]

        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split... ", end="")

            samples_features = self._class_features(samples[y])
            features_mean = samples_features.mean(dim=0)

            # Indices of the chosen exemplars, in selection order.
            idx = []
            for k in range(1, min(m, len(samples[y])) + 1):
                if idx:  # Sum of features of the exemplars chosen so far
                    f_sum = samples_features[idx].sum(dim=0)
                else:  # No exemplars chosen yet
                    f_sum = torch.zeros(samples_features.size(1), device=self.device)

                # Distance between the class mean and the mean of the chosen
                # exemplars if each candidate were added next.
                f_arg = torch.norm(features_mean - 1/k * (samples_features + f_sum), dim=1)

                # Mask exemplars already taken so none is stored twice.
                mask = np.zeros(len(f_arg), int)
                mask[idx] = 1
                f_arg_masked = ma.masked_array(f_arg.cpu().detach().numpy(), mask=mask)

                idx.append(int(np.argmin(f_arg_masked)))

            exemplars[y] = [samples[y][i] for i in idx]
            print(f"Extracted {len(exemplars[y])} exemplars.")

        return exemplars

    def construct_distance_exemplar_set(self, dataset, m):
        """Select per class the ``m`` samples whose distance from the class mean
        is closest to the average such distance.

        Args:
            dataset: training subset of the current split.
            m: exemplars to pick per class (at most the number of samples).

        Returns:
            List of 10 lists of raw samples, one per class of the split.
        """
        samples = self._group_samples_by_class(dataset)
        exemplars = [[] for _ in range(10)]

        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split... ", end="")

            samples_features = self._class_features(samples[y])
            features_mean = samples_features.mean(dim=0)

            distances_from_mean = torch.norm(features_mean - samples_features, dim=1)
            distance_mean = distances_from_mean.mean()

            # Deviation of each sample's distance from the mean distance; it does
            # not depend on the picks, so it is computed once.
            f_arg = torch.abs(distance_mean - distances_from_mean).cpu().detach().numpy()

            # The m smallest deviations, ties broken by lowest index (the same
            # picks as repeatedly taking the masked argmin).
            idx = np.argsort(f_arg, kind="stable")[:m]

            exemplars[y] = [samples[y][i] for i in idx]
            print(f"Extracted for distance exemplar set {len(exemplars[y])} exemplars.")

        return exemplars

    def reduce_exemplar_set(self, exemplar_set, m):
        """Keep the first ``m`` exemplars of ``exemplar_set``."""
        return exemplar_set[:m]

    def train(self, train_dataset, val_dataset):
        """Standard training loop with SGD and a MultiStepLR schedule.

        Returns:
            ``(train_loss, train_accuracy, val_loss, val_accuracy)`` of the last
            epoch, or of the best epoch when ``VALIDATE`` is True.
        """
        # Define the optimization algorithm
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize,
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)

        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        self.train_dataloader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

        # Send networks to chosen device
        self.net = self.net.to(self.device)
        if self.old_net is not None:
            self.old_net = self.old_net.to(self.device)

        # Let cuDNN benchmark and pick the fastest convolution algorithms.
        # (The bare attribute access used previously had no effect.)
        cudnn.benchmark = True

        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0
        self.best_train_loss = float('inf')
        self.best_train_accuracy = 0

        self.best_net = None
        self.best_epoch = -1

        for epoch in range(self.NUM_EPOCHS):
            # Run an epoch (start counting from 1)
            train_loss, train_accuracy = self.do_epoch(epoch+1)

            # Validate after each epoch
            val_loss, val_accuracy = self.validate()

            # Validation criterion: best net is the one that minimizes the loss
            # on the validation set.
            if self.VALIDATE and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_accuracy = val_accuracy
                self.best_train_loss = train_loss
                self.best_train_accuracy = train_accuracy

                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch
                print("Best model updated")

        if self.VALIDATE:
            val_loss = self.best_val_loss
            val_accuracy = self.best_val_accuracy
            train_loss = self.best_train_loss
            train_accuracy = self.best_train_accuracy

            print(f"Best model found at epoch {self.best_epoch+1}")

        return train_loss, train_accuracy, val_loss, val_accuracy

    def do_epoch(self, current_epoch):
        """Run one training epoch and step the LR scheduler.

        Returns:
            ``(train_loss, train_accuracy)``: loss averaged over batches and
            accuracy averaged over samples.
        """
        # Set the current network in training mode
        self.net.train()
        if self.old_net is not None:
            self.old_net.train(False)
        if self.best_net is not None:
            self.best_net.train(False)

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            loss, corrects = self.do_batch(images, labels)

            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        self.scheduler.step()

        # Calculate average scores
        train_loss = running_train_loss / batch_idx  # Average over all batches
        train_accuracy = running_corrects / float(total)  # Average over all samples

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return train_loss, train_accuracy

    def do_batch(self, batch, labels):
        """Run one optimisation step on a batch.

        Targets are one-hot labels restricted to the newest 10 outputs, preceded
        (when a previous network exists) by that network's sigmoid outputs for
        the older classes.

        Returns:
            ``(loss, number_of_correct_predictions)``.
        """
        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Zero-ing the gradients
        self.optimizer.zero_grad()

        num_classes = self.output_neurons_count()  # Number of classes seen until now, including new classes
        one_hot_labels = self.to_onehot(labels)[:, num_classes-10:num_classes]

        if self.old_net is None:
            # Network is training for the first time, so we only apply the
            # classification loss.
            targets = one_hot_labels
        else:
            # The previous network only provides targets; no gradient is needed
            # through it.
            with torch.no_grad():
                old_net_outputs = sigmoid(self.old_net(batch))[:, :num_classes-10]
            targets = torch.cat((old_net_outputs, one_hot_labels), dim=1)

        # Forward pass
        outputs = self.net(batch)
        loss = self.criterion(outputs, targets)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        running_corrects = torch.sum(preds == labels.data).data.item()

        # Backward pass: computes gradients
        loss.backward()

        self.optimizer.step()

        return loss, running_corrects

    def validate(self):
        """Evaluate ``net`` on the validation loader with the BCE loss.

        Returns:
            ``(val_loss, val_accuracy)``.
        """
        self.net.train(False)
        if self.old_net is not None:
            self.old_net.train(False)
        if self.best_net is not None:
            self.best_net.train(False)

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for images, labels in self.val_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                total += labels.size(0)

                # One hot encoding of labels over all outputs
                one_hot_labels = self.to_onehot(labels)

                # New net forward pass
                outputs = self.net(images)
                loss = self.criterion(outputs, one_hot_labels)  # BCE loss with sigmoids over outputs

                running_val_loss += loss.item()

                # Get predictions
                _, preds = torch.max(outputs.data, 1)

                # Update the number of correctly classified validation samples
                running_corrects += torch.sum(preds == labels.data).data.item()

                batch_idx += 1

        # Calculate scores
        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return val_loss, val_accuracy

    def test(self, test_dataset, train_dataset=None):
        """Evaluate the radius-normalised classifier on ``test_dataset``.

        Args:
            test_dataset: evaluation data.
            train_dataset: optional training subset of the latest split, used by
                ``classify`` to refine the newest classes' statistics.

        Returns:
            ``(accuracy, all_targets, all_preds)``.
        """
        self.net.train(False)
        if self.best_net is not None:
            self.best_net.train(False)  # Set Network to evaluation mode
        if self.old_net is not None:
            self.old_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        running_corrects = 0
        total = 0

        # Collect per-batch tensors and concatenate once at the end.
        target_batches = []
        pred_batches = []

        # Clear cached class statistics so they are recomputed for this model.
        self.cached_means = None
        self.cached_radius = None

        # Disable transformations for train_dataset, if available, as we need
        # the original PIL images from which to extract features.
        if train_dataset is not None:
            train_dataset.dataset.disable_transform()
        try:
            for images, labels in self.test_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                total += labels.size(0)

                with torch.no_grad():
                    preds = self.classify(images, train_dataset)

                running_corrects += torch.sum(preds == labels.data).data.item()

                target_batches.append(labels)
                pred_batches.append(preds.to(self.device))
        finally:
            if train_dataset is not None:
                train_dataset.dataset.enable_transform()

        all_targets = torch.cat(target_batches) if target_batches else torch.empty(0, dtype=torch.long)
        all_preds = torch.cat(pred_batches) if pred_batches else torch.empty(0, dtype=torch.long)

        # Calculate accuracy
        accuracy = running_corrects / float(total)

        print(f"Test accuracy (iCaRL): {accuracy} ", end="")

        if train_dataset is None:
            print("(only exemplars)")
        else:
            print("(exemplars and training data)")

        return accuracy, all_targets, all_preds

    def increment_classes(self, n=10):
        """Add ``n`` output units to ``net.fc``, keeping the existing weights and biases."""
        in_features = self.net.fc.in_features  # size of each input sample
        out_features = self.net.fc.out_features  # size of each output sample
        weight = self.net.fc.weight.data
        bias = self.net.fc.bias.data

        # Create the enlarged layer where the old one lives, so the network stays
        # on a single device.
        self.net.fc = nn.Linear(in_features, out_features+n).to(device=weight.device, dtype=weight.dtype)
        self.net.fc.weight.data[:out_features] = weight
        self.net.fc.bias.data[:out_features] = bias

    def output_neurons_count(self):
        """Return the number of output units of ``net.fc``."""
        return self.net.fc.out_features

    def feature_neurons_count(self):
        """Return the number of input features of ``net.fc``."""
        return self.net.fc.in_features

    def to_onehot(self, targets):
        """One-hot encode label tensor ``targets`` over all output units, on ``self.device``."""
        num_classes = self.net.fc.out_features
        # Build the identity on the labels' device: indexing a CPU tensor with
        # CUDA indices is rejected by recent PyTorch versions.
        one_hot_targets = torch.eye(num_classes, device=targets.device)[targets]

        return one_hot_targets.to(self.device)

    def network_params(self):
        """Return the ``(weight, bias)`` data tensors of ``net.fc``."""
        weight = self.net.fc.weight.data
        bias = self.net.fc.bias.data

        return weight, bias

# Semantic drift compensation

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import Dataset, Subset, DataLoader, ConcatDataset

from math import floor

from torchvision import transforms
from torchvision.models import resnet34

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import numpy as np
import numpy.ma as ma

from PIL import Image
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.backends import cudnn
from torch.autograd import Variable

import numpy as np
import numpy.ma as ma
from math import floor
from copy import deepcopy
import random

# Module-level sigmoid activation, re-created alongside this cell's imports.
sigmoid = torch.nn.Sigmoid()

class Exemplars(torch.utils.data.Dataset):
    """Dataset view over a list of per-class exemplar sets.

    Args:
        exemplars: sequence whose y-th entry holds the samples stored for
            class y; each of those samples is labelled y.
        transform: optional callable applied to a sample when it is fetched.
    """

    def __init__(self, exemplars, transform=None):
        self.dataset, self.targets = [], []
        for class_index, members in enumerate(exemplars):
            self.dataset.extend(members)
            self.targets.extend(class_index for _ in range(len(members)))
        self.transform = transform

    def __getitem__(self, index):
        image = self.dataset[index]
        label = self.targets[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.targets)

class iCaRL_SDC:
    """iCaRL variant whose class prototypes are corrected by semantic drift compensation.

    Training follows iCaRL: BCE loss whose targets combine one-hot labels for
    the newest classes with the old network's sigmoid outputs (distillation),
    plus an exemplar memory of at most ``memory_size`` samples. At test time a
    nearest-class-mean classifier is used. The prototypes of each new split
    are computed from its training data; the prototypes of earlier classes are
    shifted by the drift estimated between the previous and the current
    feature extractor ("Semantic Drift Compensation for Class-Incremental
    Learning"). Every split is assumed to add 10 classes.
    """

    def __init__(self, device, net, lr, momentum, weight_decay, milestones, gamma, num_epochs, batch_size, train_transform, test_transform):
        self.device = device
        self.net = net
        # Prototypes of all known classes, one row per class (filled by classify).
        self.cached_means = None

        # Set hyper-parameters
        self.LR = lr
        self.MOMENTUM = momentum
        self.WEIGHT_DECAY = weight_decay
        self.MILESTONES = milestones
        self.GAMMA = gamma
        self.NUM_EPOCHS = num_epochs
        self.BATCH_SIZE = batch_size

        # Set transformations
        self.train_transform = train_transform
        self.test_transform = test_transform

        # exemplars[y] holds the stored raw samples of class y.
        self.exemplars = []

        self.old_net = None       # copy of the network after the latest training
        self.previous_net = None  # copy of the network after the training before that
        # Best network found during validation. extract_features reads it, so
        # it must exist even before train() is called for the first time.
        self.best_net = None

        # Maximum number of exemplars
        self.memory_size = 2000

        # Loss function
        self.criterion = nn.BCEWithLogitsLoss()

        # When True, train() keeps the epoch with the lowest validation loss.
        self.VALIDATE = False

    def classify(self, batch, train_dataset=None):
        """Classify ``batch`` with the nearest (drift-compensated) class prototype.

        On the first batch of a test run (``self.test_batch == 0``) the
        prototypes are refreshed from ``train_dataset``; see _update_prototypes.

        Args:
            batch: batch of already-transformed images.
            train_dataset: training data of the current split, yielding
                ``(sample, label)`` pairs that ``self.test_transform`` accepts.

        Returns:
            1-D tensor with the index of the nearest prototype for each sample.

        Raises:
            ValueError: if prototypes must be computed and ``train_dataset`` is None.

        NOTE(review): every refresh appends the means of 10 new classes, so
        running ``test`` twice within one split duplicates them.
        """
        batch_features = self.extract_features(batch)
        # L2-normalise every feature vector (row-wise).
        batch_features = batch_features / batch_features.norm(dim=1, keepdim=True)
        batch_features = batch_features.to(self.device)

        # Obtain the prototypes (once per test run, on its first batch).
        if self.test_batch == 0:
            print("Computing mean of exemplars... ", end="")
            self._update_prototypes(train_dataset)
            print("done")

        # Classification: distance of every sample to every prototype,
        # shape (batch size, number of prototypes).
        distances = torch.norm(batch_features.unsqueeze(1) - self.cached_means.unsqueeze(0), dim=2)
        return distances.argmin(dim=1)

    def _update_prototypes(self, train_dataset, sigma=0.2):
        """Drift-compensate the cached old-class means, then append the new-class means.

        ``sigma`` is the Gaussian bandwidth of Eq. 11 (0.2 is the value used for
        CIFAR-100 in the SDC paper).
        """
        if train_dataset is None:
            raise ValueError("SDC prototypes require the current split's train_dataset")

        # Number of known classes
        num_classes = len(self.exemplars)

        new_features = [[] for _ in range(10)]  # current-net features of each new class
        drift_vectors = []  # per sample: current minus previous feature representation
        sq_distances = []   # per sample: squared distance to every cached (old) mean

        for train_sample, label in train_dataset:
            # Current net representation
            features = self.extract_features(train_sample, batch=False, transform=self.test_transform)
            features = features / features.norm()
            new_features[label % 10].append(features)

            if self.cached_means is not None:
                # Previous net representation, normalised by its own norm.
                old_features = self.extract_features(train_sample, batch=False, transform=self.test_transform, new=False)
                old_features = old_features / old_features.norm()

                sq_distances.append(((old_features - self.cached_means) ** 2).sum(dim=1))
                drift_vectors.append(features - old_features)

        if self.cached_means is not None:
            drift_vectors = torch.stack(drift_vectors)  # (num samples, feature dim)
            sq_distances = torch.stack(sq_distances)    # (num samples, num cached means)

            # Eq. 11: Gaussian weight of every sample w.r.t. every old mean.
            weights = torch.exp(-sq_distances / (2 * sigma ** 2))
            # Eq. 10: weighted average of the drift vectors, one row per old mean.
            drifts = (weights.t() @ drift_vectors) / weights.sum(dim=0, keepdim=True).t()

            # Move the old means by the drift estimated in this split.
            self.cached_means = self.cached_means + drifts

        # Computing new means
        new_means = []
        for y in range(num_classes - 10, num_classes):
            class_mean = torch.stack(new_features[y % 10]).mean(dim=0)
            new_means.append(class_mean / class_mean.norm())
        self.new_cached_means = torch.stack(new_means).to(self.device)

        if self.cached_means is None:
            self.cached_means = self.new_cached_means
        else:
            # Inserting new means
            self.cached_means = torch.cat((self.cached_means, self.new_cached_means), dim=0)

    def extract_features(self, sample, batch=True, transform=None, new=True):
        """Return the output of a network's ``features`` module for ``sample``.

        Args:
            sample: a batch tensor when ``batch`` is True, otherwise a single
                raw sample that ``transform`` turns into a tensor.
            batch: whether ``sample`` is already a batch.
            transform: required when ``batch`` is False.
            new: True uses the current network (``best_net`` when VALIDATE is
                set); False uses ``previous_net``.

        Returns:
            Feature batch, or a single feature vector when ``batch`` is False.

        Raises:
            ValueError: if ``batch`` is False and no ``transform`` is given.
        """
        if batch is False and transform is None:
            raise ValueError("if a PIL image is passed to extract_features, a transform must be defined")

        # Evaluation mode for every network that exists.
        for network in (self.net, self.best_net, self.old_net, self.previous_net):
            if network is not None:
                network.train(False)

        if batch is False:
            # (3, 32, 32) --> (1, 3, 32, 32), https://stackoverflow.com/a/59566009/6486336
            sample = transform(sample).unsqueeze(0)

        sample = sample.to(self.device)

        if not new:
            extractor = self.previous_net  # representation of the previous network
        elif self.VALIDATE:
            extractor = self.best_net
        else:
            extractor = self.net
        features = extractor.features(sample)

        if batch is False:
            features = features[0]
        return features

    def incremental_train(self, split, train_dataset, val_dataset):
        """Learn a new split, then rebalance the exemplar memory; return the train logs."""
        # Every split after the first adds 10 output neurons.
        if split != 0:
            self.increment_classes(10)

        train_logs = self.update_representation(train_dataset, val_dataset)

        num_classes = self.output_neurons_count()
        m = floor(self.memory_size / num_classes)

        print(f"Target number of exemplars per class: {m}")
        print(f"Target total number of exemplars: {m*num_classes}")

        # Reduce pre-existing exemplar sets in order to fit new exemplars:
        for y, exemplar_set in enumerate(self.exemplars):
            self.exemplars[y] = self.reduce_exemplar_set(exemplar_set, m)

        # Construct exemplar set for new classes
        # (random sampling; construct_exemplar_set_herding is the iCaRL alternative).
        new_exemplars = self.construct_exemplar_set_rand(train_dataset, m)
        self.exemplars.extend(new_exemplars)

        return train_logs

    def update_representation(self, train_dataset, val_dataset):
        """Train on new data plus exemplars, then snapshot the networks; return the train logs."""
        print(f"Length of exemplars set: {sum(len(exemplar_set) for exemplar_set in self.exemplars)}")

        exemplars_dataset = Exemplars(self.exemplars, self.train_transform)
        train_dataset_with_exemplars = ConcatDataset([exemplars_dataset, train_dataset])

        # Train the network on combined dataset
        train_logs = self.train(train_dataset_with_exemplars, val_dataset)

        # Shift the snapshots: previous_net <- old_net, old_net <- current net.
        if self.old_net is not None:
            self.previous_net = deepcopy(self.old_net)
        self.old_net = deepcopy(self.net)

        return train_logs

    def _group_raw_samples_by_class(self, dataset):
        """Split ``dataset`` (with its underlying transform disabled) into 10 lists by ``label % 10``."""
        dataset.dataset.disable_transform()
        try:
            samples = [[] for _ in range(10)]
            for image, label in dataset:
                samples[label % 10].append(image)  # e.g. label 21 -> slot 1
        finally:
            # Restore the transform even if iteration fails.
            dataset.dataset.enable_transform()
        return samples

    def construct_exemplar_set_rand(self, dataset, m):
        """Pick up to ``m`` random exemplars (without replacement) for each of the 10 new classes."""
        samples = self._group_raw_samples_by_class(dataset)

        exemplars = [[] for _ in range(10)]
        for y in range(10):
            print(f"Randomly extracting exemplars from class {y} of current split... ", end="")
            # Never ask for more exemplars than the class has samples.
            exemplars[y] = random.sample(samples[y], min(m, len(samples[y])))
            print(f"Extracted {len(exemplars[y])} exemplars.")

        return exemplars

    def construct_exemplar_set_herding(self, dataset, m):
        """Pick up to ``m`` exemplars per new class by herding (iCaRL Algorithm 4).

        Exemplars are added greedily so that the mean feature of the chosen
        set stays as close as possible to the class mean feature; each list is
        returned in selection order.
        """
        samples = self._group_raw_samples_by_class(dataset)

        # Initialize exemplar sets
        exemplars = [[] for _ in range(10)]

        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split... ", end="")
            class_samples = samples[y]

            # Transform samples to tensors and apply normalization
            transformed_samples = torch.stack([self.test_transform(s) for s in class_samples]).to(self.device)

            # Extract features from samples (no autograd graph needed)
            with torch.no_grad():
                samples_features = self.extract_features(transformed_samples).to(self.device)

            # Compute the feature mean of the current class
            features_mean = samples_features.mean(dim=0)

            idx = []
            # k = 1, ..., m -- never more exemplars than available samples
            for k in range(1, min(m, len(class_samples)) + 1):
                if k == 1:  # No exemplars chosen yet, sum to 0 vector
                    f_sum = torch.zeros(samples_features.size(1), device=self.device)
                else:  # Sum of features of all exemplars chosen until now (j = 1, ..., k-1)
                    f_sum = samples_features[idx].sum(dim=0)

                # Compute argument of argmin function, masking already-chosen samples
                f_arg = torch.norm(features_mean - 1/k * (samples_features + f_sum), dim=1)
                mask = np.zeros(len(f_arg), int)
                mask[idx] = 1
                f_arg_masked = ma.masked_array(f_arg.detach().cpu().numpy(), mask=mask)

                # Compute the nearest available exemplar
                idx.append(np.argmin(f_arg_masked))

            # Save exemplars to exemplar set
            exemplars[y] = [class_samples[i] for i in idx]

            print(f"Extracted {len(exemplars[y])} exemplars.")

        return exemplars

    def reduce_exemplar_set(self, exemplar_set, m):
        """Keep the first ``m`` exemplars of ``exemplar_set``."""
        return exemplar_set[:m]

    def train(self, train_dataset, val_dataset):
        """Train ``self.net`` for NUM_EPOCHS; return ``(train_loss, train_acc, val_loss, val_acc)``.

        With VALIDATE set, the returned scores (and ``self.best_net``) are those
        of the epoch with the lowest validation loss; otherwise of the last epoch.
        """
        # Define the optimization algorithm
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize,
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)

        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        # Create DataLoaders for training and validation
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
        # Validation must see every sample: no shuffling, no dropped last batch.
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=False, num_workers=4)

        # Send networks to chosen device
        self.net = self.net.to(self.device)
        if self.old_net is not None: self.old_net = self.old_net.to(self.device)
        if self.previous_net is not None: self.previous_net = self.previous_net.to(self.device)

        # Enable cuDNN auto-tuning of convolution algorithms.
        cudnn.benchmark = True

        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0
        self.best_train_loss = float('inf')
        self.best_train_accuracy = 0

        self.best_net = None
        self.best_epoch = -1

        for epoch in range(self.NUM_EPOCHS):
            # Run an epoch (start counting form 1)
            train_loss, train_accuracy = self.do_epoch(epoch+1)

            # Validate after each epoch
            val_loss, val_accuracy = self.validate()

            if self.VALIDATE and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_accuracy = val_accuracy
                self.best_train_loss = train_loss
                self.best_train_accuracy = train_accuracy

                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch
                print("Best model updated")

        if self.VALIDATE:
            val_loss = self.best_val_loss
            val_accuracy = self.best_val_accuracy
            train_loss = self.best_train_loss
            train_accuracy = self.best_train_accuracy

            print(f"Best model found at epoch {self.best_epoch+1}")

        return train_loss, train_accuracy, val_loss, val_accuracy

    def do_epoch(self, current_epoch):
        """Run one training epoch; return ``(mean batch loss, accuracy)``."""
        # Set the current network in training mode
        self.net.train()
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)
        if self.previous_net is not None: self.previous_net.train(False)

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            loss, corrects = self.do_batch(images, labels)

            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        self.scheduler.step()

        # Calculate average scores
        train_loss = running_train_loss / batch_idx # Average over all batches
        train_accuracy = running_corrects / float(total) # Average over all samples

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return train_loss, train_accuracy

    def do_batch(self, batch, labels):
        """Run one optimisation step; return ``(loss, number of correct predictions)``.

        Targets are the one-hot labels restricted to the 10 newest outputs; when
        an old network exists, its sigmoid outputs for the already-known classes
        are prepended as distillation targets.
        """
        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Zero-ing the gradients
        self.optimizer.zero_grad()

        num_classes = self.output_neurons_count()
        one_hot_labels = self.to_onehot(labels)[:, num_classes-10:num_classes]

        if self.old_net is None:
            # Network is training for the first time, so we only apply the
            # classification loss.
            targets = one_hot_labels
        else:
            # Distillation targets are constants: computing them without an
            # autograd graph keeps old_net out of backward().
            with torch.no_grad():
                old_net_outputs = sigmoid(self.old_net(batch))[:, :num_classes-10]
            targets = torch.cat((old_net_outputs, one_hot_labels), dim=1)

        # Forward pass
        outputs = self.net(batch)
        loss = self.criterion(outputs, targets)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)
        running_corrects = torch.sum(preds == labels.data).data.item()

        # Backward pass: computes gradients
        loss.backward()

        self.optimizer.step()

        return loss, running_corrects

    def validate(self):
        """Evaluate ``self.net`` on the validation loader; return ``(mean batch loss, accuracy)``."""
        self.net.train(False)
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        # Evaluation only: no gradients needed.
        with torch.no_grad():
            for images, labels in self.val_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                total += labels.size(0)

                # One hot encoding of new task labels
                one_hot_labels = self.to_onehot(labels)

                # New net forward pass
                outputs = self.net(images)
                loss = self.criterion(outputs, one_hot_labels) # loss type: BCE Loss with sigmoids over outputs

                running_val_loss += loss.item()

                # Get predictions
                _, preds = torch.max(outputs.data, 1)

                # Update the number of correctly classified validation samples
                running_corrects += torch.sum(preds == labels.data).data.item()

                batch_idx += 1

        # Calculate scores
        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return val_loss, val_accuracy

    def test(self, test_dataset, train_dataset=None):
        """Evaluate on ``test_dataset``; return ``(accuracy, all_targets, all_preds)``.

        ``train_dataset`` is handed to classify for computing prototypes. Its
        underlying ``.dataset`` transform is disabled during the evaluation
        (classify applies ``self.test_transform`` itself) and is re-enabled
        afterwards, even if evaluation raises.
        """
        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)
        if self.old_net is not None: self.old_net.train(False)
        if self.previous_net is not None: self.previous_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        running_corrects = 0
        total = 0

        # To store all predictions
        all_preds = torch.tensor([], dtype=torch.long)
        all_targets = torch.tensor([], dtype=torch.long)

        # A new test begins, so we need to compute new class means and update the old ones
        self.test_batch = 0

        if train_dataset is not None: train_dataset.dataset.disable_transform()
        try:
            for images, labels in self.test_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                total += labels.size(0)

                with torch.no_grad():
                    preds = self.classify(images, train_dataset)

                self.test_batch += 1
                running_corrects += torch.sum(preds == labels.data).data.item()

                all_targets = torch.cat(
                    (all_targets.to(self.device), labels.to(self.device)), dim=0
                )

                all_preds = torch.cat(
                    (all_preds.to(self.device), preds.to(self.device)), dim=0
                )
        finally:
            if train_dataset is not None: train_dataset.dataset.enable_transform()

        # Calculate accuracy
        accuracy = running_corrects / float(total)

        print(f"Test accuracy (iCaRL): {accuracy} ", end="")

        if train_dataset is None:
            print("(only exemplars)")
        else:
            print("(exemplars and training data)")

        return accuracy, all_targets, all_preds

    def increment_classes(self, n=10):
        """Grow the ``fc`` head by ``n`` outputs, keeping the weights of the existing ones."""
        in_features = self.net.fc.in_features  # size of each input sample
        out_features = self.net.fc.out_features  # size of each output sample
        weight = self.net.fc.weight.data
        bias = self.net.fc.bias.data

        # Create the new head on the old head's device so the copy stays on-device.
        self.net.fc = nn.Linear(in_features, out_features+n).to(weight.device)
        self.net.fc.weight.data[:out_features] = weight
        self.net.fc.bias.data[:out_features] = bias

    def output_neurons_count(self):
        """Return the number of output units of the classifier head ``net.fc``."""
        return self.net.fc.out_features

    def to_onehot(self, targets):
        """Return one-hot rows for ``targets`` over all ``net.fc`` outputs, on ``self.device``."""
        num_classes = self.net.fc.out_features
        targets = torch.as_tensor(targets)
        # Index an identity that lives on the indices' device: labels are
        # usually already on self.device, and indexing a CPU tensor with CUDA
        # indices may be rejected by PyTorch.
        one_hot_targets = torch.eye(num_classes, device=targets.device)[targets]

        return one_hot_targets.to(self.device)

