# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.backends import cudnn

import numpy as np
import numpy.ma as ma
from math import floor
from copy import deepcopy
import random

# Module-level activation helpers shared by the classes below.
#
# The softmax is applied to 2-D network outputs whose maximum is then taken
# along dimension 1, so that dimension is named explicitly. Passing dim=None
# relies on PyTorch's deprecated implicit dimension choice and emits a warning
# on every call.
softmax = nn.Softmax(dim=1)

sigmoid = nn.Sigmoid()

class Exemplars(torch.utils.data.Dataset):
    """Dataset view over per-class exemplar lists.

    ``exemplars`` is a sequence whose position ``y`` holds the stored samples
    of class ``y``; every sample is paired with that position as its target.
    An optional ``transform`` is applied to the sample on access.
    """

    def __init__(self, exemplars, transform=None):
        self.transform = transform

        # Flatten the per-class groups into two parallel sequences.
        self.dataset = []
        self.targets = []
        for class_idx, group in enumerate(exemplars):
            self.dataset += list(group)
            self.targets += [class_idx] * len(group)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index``, transformed if configured."""
        image, target = self.dataset[index], self.targets[index]

        if self.transform is None:
            return image, target

        return self.transform(image), target

    def __len__(self):
        """Return the total number of stored exemplars."""
        return len(self.targets)

class iCaRLWithRejectFC:

    def __init__(self, device, net, lr, momentum, weight_decay, milestones, gamma, num_epochs, batch_size, train_transform, test_transform):
        self.device = device
        self.net = net

        # Set hyper-parameters
        self.LR = lr
        self.MOMENTUM = momentum
        self.WEIGHT_DECAY = weight_decay
        self.MILESTONES = milestones
        self.GAMMA = gamma
        self.NUM_EPOCHS = num_epochs
        self.BATCH_SIZE = batch_size
        
        # Set transformations
        self.train_transform = train_transform
        self.test_transform = test_transform

        # List of exemplar sets. Each set contains memory_size/num_classes exemplars
        # with num_classes the number of classes seen until now by the network.
        self.exemplars = []

        # Initialize the copy of the old network, used to compute outputs of the
        # previous network for the distillation loss, to None. This is useful to
        # correctly apply the first function when training the network for the
        # first time.
        self.old_net = None

        # Maximum number of exemplars
        self.memory_size = 2000
    
        # Loss function
        self.criterion = nn.BCEWithLogitsLoss()

        # If True, test on the best model found (e.g., minimize loss). If False,
        # test on the last model build (of the last epoch).
        self.VALIDATE = False
        self.activations_var = []


    def test_without_classifier(self, test_dataset, threshold):
        """Evaluate the network on ``test_dataset`` with threshold-based rejection.

        Each sample is assigned the class with the highest softmax probability;
        when that probability is below ``threshold`` the prediction is replaced
        by the label 101 ("unknown").

        Args:
            test_dataset: dataset yielding ``(image, label)`` pairs.
            threshold: minimum top probability for a prediction to be kept.

        Returns:
            ``(accuracy, unknown_ratio, all_targets, all_preds)``. Both ratios
            are 0.0 when the dataset yields no samples.

        Raises:
            RuntimeError: if ``self.VALIDATE`` is set but no best network exists.
        """
        unknown_label = 101

        # `best_net` is created by `train`; tolerate evaluation before that.
        best_net = getattr(self, "best_net", None)

        self.net.train(False)
        if best_net is not None: best_net.train(False)  # Set Network to evaluation mode
        if self.old_net is not None: self.old_net.train(False)

        self.test_dataloader = DataLoader(test_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4)

        self.threshold = threshold

        if self.VALIDATE:
            if best_net is None:
                raise RuntimeError("VALIDATE is set but no best network is available; call train() first")
            model = best_net
        else:
            model = self.net

        running_corrects = 0
        total = 0
        running_unknown = 0

        # Per-batch results, concatenated once after the loop.
        targets_per_batch = []
        preds_per_batch = []

        for images, labels in self.test_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)

            # Forward Pass
            with torch.no_grad():
                outputs = model(images)

            probabilities = softmax(outputs)

            # Get predictions
            max_probabilities, preds = torch.max(probabilities.data, 1)

            # OWR: relabel low-confidence predictions as unknown in one masked
            # assignment instead of a per-sample Python loop.
            rejected = max_probabilities < self.threshold
            running_unknown += rejected.sum().item()
            preds[rejected] = unknown_label

            # Update Corrects
            running_corrects += torch.sum(preds == labels.data).data.item()

            targets_per_batch.append(labels)
            preds_per_batch.append(preds)

        empty = torch.empty(0, dtype=torch.long, device=self.device)
        all_targets = torch.cat([empty] + targets_per_batch, dim=0)
        all_preds = torch.cat([empty] + preds_per_batch, dim=0)

        # Calculate accuracy and unknown ratio (guarding the empty-dataset case)
        accuracy = running_corrects / float(total) if total else 0.0
        unknown_ratio = running_unknown / float(total) if total else 0.0

        print(f"Test accuracy (hybrid1-FC): {accuracy}")
        print(f"Unknown Ratio: {unknown_ratio} ", end="")

        return accuracy, unknown_ratio, all_targets, all_preds
    
    
    def extract_features(self, sample, batch=True, transform=None):

        assert not (batch is False and transform is None)

        self.net.train(False)
        if self.best_net is not None: self.best_net.train(False)
        if self.old_net is not None: self.old_net.train(False)

        if batch is False: 
            sample = transform(sample)
            sample = sample.unsqueeze(0) 

        sample = sample.to(self.device)

        if self.VALIDATE: 
            features = self.best_net.features(sample)
        else:
            features = self.net.features(sample)

        if batch is False:
            features = features[0] 

        return features

    def incremental_train(self, split, train_dataset, val_dataset):
        """Run one incremental learning step on a new split of classes.

        For every split other than ``0`` the classifier is first widened by 10
        outputs. The network is then trained through ``update_representation``,
        the existing exemplar sets are trimmed to the new per-class budget and
        randomly selected exemplars of the new classes are appended.

        Args:
            split: index of the current split; ``0`` means the first one.
            train_dataset: training data of the current split.
            val_dataset: validation data of the current split.

        Returns:
            Whatever ``update_representation`` returns.
        """
        # Compare by value: `is not 0` tests object identity, which gives the
        # wrong answer for e.g. NumPy integers and triggers a SyntaxWarning on
        # recent Python versions.
        if split != 0:
            # Increment the number of output nodes for the new network by 10
            self.increment_classes(10)

        # Improve network parameters upon receiving new classes. Effectively
        # train a new network starting from the current network parameters.
        train_logs = self.update_representation(train_dataset, val_dataset)

        # Compute the number of exemplars per class
        num_classes = self.output_neurons_count()
        m = floor(self.memory_size / num_classes)

        print(f"Target number of exemplars per class: {m}")
        print(f"Target total number of exemplars: {m*num_classes}")

        # Reduce pre-existing exemplar sets in order to fit new exemplars
        for y, exemplar_set in enumerate(self.exemplars):
            self.exemplars[y] = self.reduce_exemplar_set(exemplar_set, m)

        # Construct exemplar set for new classes
        new_exemplars = self.construct_exemplar_set_rand(train_dataset, m)
        self.exemplars.extend(new_exemplars)

        return train_logs

    def update_representation(self, train_dataset, val_dataset):
        """Train the network on the new data together with the stored exemplars.

        After training, a deep copy of the network is kept in ``self.old_net``
        so that it can provide distillation targets in the next round.

        Returns:
            The value returned by ``train``.
        """
        stored = sum(len(exemplar_set) for exemplar_set in self.exemplars)
        print(f"Length of exemplars set: {stored}")

        # Exemplars come first, followed by the data of the current split.
        combined = ConcatDataset(
            [Exemplars(self.exemplars, self.train_transform), train_dataset]
        )

        train_logs = self.train(combined, val_dataset)

        # Snapshot of the freshly trained network for the distillation loss.
        self.old_net = deepcopy(self.net)

        return train_logs


    def construct_exemplar_set_rand(self, dataset, m):

        dataset.dataset.disable_transform()
        
        #storing images of a split
        samples = [[] for _ in range(10)]
        for image, label in dataset:
            label = label % 10 # Map labels to 0-9 range
            samples[label].append(image)

        dataset.dataset.enable_transform()

        exemplars = [[] for _ in range(10)]
        
        #sampling
        for y in range(10):
            print(f"Randomly extracting exemplars from class {y} of current split... ", end="")

            # Randomly choose m samples from samples[y] without replacement
            exemplars[y] = random.sample(samples[y], m)

            print(f"Extracted {len(exemplars[y])} exemplars.")

        return exemplars

    def construct_exemplar_set_herding(self, dataset, m): 

        dataset.dataset.disable_transform()

        samples = [[] for _ in range(10)]
        for image, label in dataset:
            label = label % 10 # Map labels to 0-9 range
            samples[label].append(image)

        dataset.dataset.enable_transform()

        # Initialize exemplar sets
        exemplars = [[] for _ in range(10)]

        # Iterate over classes
        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split... ", end="")

            # Transform samples to tensors and apply normalization
            transformed_samples = torch.zeros((len(samples[y]), 3, 32, 32)).to(self.device)
            for i in range(len(transformed_samples)):  
                transformed_samples[i] = self.test_transform(samples[y][i])

            # Extract features from samples
            samples_features = self.extract_features(transformed_samples).to(self.device)

            # Compute the feature mean of the current class
            features_mean = samples_features.mean(dim=0)

            # Initializes indices vector, containing the index of each exemplar chosen
            idx = []

            # See iCaRL algorithm 4
            for k in range(1, m+1): 
                if k == 1: 
                    f_sum = torch.zeros(64).to(self.device)
                else: 
                    f_sum = samples_features[idx].sum(dim=0)

                # Compute argument of argmin function
                f_arg = torch.norm(features_mean - 1/k * (samples_features + f_sum), dim=1) 
                

                # Mask exemplars that were already taken, as we do not want to store the
                # same exemplar more than once
                mask = np.zeros(len(f_arg), int)
                mask[idx] = 1
                f_arg_masked = ma.masked_array(f_arg.cpu().detach().numpy(), mask=mask) 

                # Compute the nearest available exemplar
                exemplar_idx = np.argmin(f_arg_masked)

                idx.append(exemplar_idx)
            
            # Save exemplars to exemplar set
            for i in idx:
                exemplars[y].append(samples[y][i])
            
            print(f"Extracted {len(exemplars[y])} exemplars.")
            
        return exemplars

    def reduce_exemplar_set(self, exemplar_set, m):

        return exemplar_set[:m]

    def train(self, train_dataset, val_dataset):

        # Define the optimization algorithm
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize, 
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)
        
        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        # Create DataLoaders for training and validation
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

        # Send networks to chosen device
        self.net = self.net.to(self.device)
        if self.old_net is not None: self.old_net = self.old_net.to(self.device)

        cudnn.benchmark  # Calling this optimizes runtime

        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0
        self.best_train_loss = float('inf')
        self.best_train_accuracy = 0
        
        self.best_net = None
        self.best_epoch = -1

        for epoch in range(self.NUM_EPOCHS):
            # Run an epoch (start counting form 1)
            train_loss, train_accuracy = self.do_epoch(epoch+1)
        
            # Validate after each epoch 
            val_loss, val_accuracy = self.validate()    

            # Validation criterion: best net is the one that minimizes the loss
            # on the validation set.
            if self.VALIDATE and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_accuracy = val_accuracy
                self.best_train_loss = train_loss
                self.best_train_accuracy = train_accuracy

                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch
                print("Best model updated")

        if self.VALIDATE:
            val_loss = self.best_val_loss
            val_accuracy = self.best_val_accuracy
            train_loss = self.best_train_loss
            train_accuracy = self.best_train_accuracy

            print(f"Best model found at epoch {self.best_epoch+1}")

        return train_loss, train_accuracy, val_loss, val_accuracy
    
    def do_epoch(self, current_epoch):

        # Set the current network in training mode
        self.net.train()
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            loss, corrects = self.do_batch(images, labels)

            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        self.scheduler.step()

        # Calculate average scores
        train_loss = running_train_loss / batch_idx 
        train_accuracy = running_corrects / float(total) 

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return train_loss, train_accuracy

    def do_batch(self, batch, labels):

        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Zero-ing the gradients
        self.optimizer.zero_grad()

        num_classes = self.output_neurons_count() 
        one_hot_labels = self.to_onehot(labels)[:, num_classes-10:num_classes]

        if self.old_net is None:

            targets = one_hot_labels

        else:
            
            old_net_outputs = sigmoid(self.old_net(batch))[:, :num_classes-10] 

            targets = torch.cat((old_net_outputs, one_hot_labels), dim=1)     

        # Forward pass
        outputs = self.net(batch)
        loss = self.criterion(outputs, targets)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Accuracy over new images
        running_corrects = torch.sum(preds == labels.data).data.item() 

        # Backward pass: computes gradients
        loss.backward()

        self.optimizer.step()

        return loss, running_corrects

    def validate(self): 

        self.net.train(False)
        if self.old_net is not None: self.old_net.train(False)
        if self.best_net is not None: self.best_net.train(False)

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        for images, labels in self.val_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)

            # One hot encoding of new task labels 
            one_hot_labels = self.to_onehot(labels)

            # New net forward pass
            outputs = self.net(images)  
            loss = self.criterion(outputs, one_hot_labels) # BCE Loss with sigmoids over outputs

            running_val_loss += loss.item()

            # Get predictions
            _, preds = torch.max(outputs.data, 1)

            # Update the number of correctly classified validation samples
            running_corrects += torch.sum(preds == labels.data).data.item()

            batch_idx += 1

        # Calculate scores
        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return val_loss, val_accuracy

    
    def increment_classes(self, n=10):

        in_features = self.net.fc.in_features  
        out_features = self.net.fc.out_features  
        weight = self.net.fc.weight.data
        bias = self.net.fc.bias.data

        self.net.fc = nn.Linear(in_features, out_features+n)
        self.net.fc.weight.data[:out_features] = weight
        self.net.fc.bias.data[:out_features] = bias
    

    def output_neurons_count(self):

        return self.net.fc.out_features
    

    def feature_neurons_count(self):

        return self.net.fc.in_features
    

    def to_onehot(self, targets):
 
        num_classes = self.net.fc.out_features
        one_hot_targets = torch.eye(num_classes)[targets]

        return one_hot_targets.to(self.device)
