In [None]:
import torch
import torch.nn as nn
from torch.backends import cudnn
from copy import deepcopy

class Manager():
    """Train, validate and test a classifier whose output head is ``net.fc``.

    The head can be grown with :meth:`increment_classes`. Labels are one-hot
    encoded over the current number of head outputs before being passed to
    ``criterion``. The notes in this class describe the loss as BCE with
    sigmoids over the outputs (presumably ``nn.BCEWithLogitsLoss`` — confirm
    against the caller).
    """

    def __init__(self, device, net, criterion, optimizer, scheduler, train_dataloader, val_dataloader, test_dataloader):
        """Store the training components.

        Args:
            device: device the network and every batch are moved to.
            net: model exposing a linear classification head as ``net.fc``.
            criterion: loss, called as ``criterion(outputs, one_hot_labels)``.
            optimizer: optimizer, stepped once per training batch.
            scheduler: LR scheduler, stepped once per training epoch.
            train_dataloader: iterable of ``(images, labels)`` batches.
            val_dataloader: iterable of ``(images, labels)`` batches.
            test_dataloader: iterable of ``(images, labels)`` batches.
        """
        self.device = device

        self.net = net
        # Until train() selects a best model, best_net is the live network
        # itself (an alias, not a copy).
        self.best_net = self.net

        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

    def increment_classes(self, n=10):
        """Grow the ``net.fc`` head by ``n`` output units.

        The weights and bias of the existing outputs are copied into the new
        layer, so predictions for already-known classes are preserved. Only
        the ``n`` new rows get a fresh initialisation.

        NOTE(review): ``self.optimizer`` still references the parameters of
        the replaced head. Unless the caller rebuilds the optimizer, the new
        ``fc`` is not updated during training.
        """
        old_fc = self.net.fc
        in_features = old_fc.in_features    # size of each input sample
        out_features = old_fc.out_features  # current number of outputs
        has_bias = old_fc.bias is not None

        # Create the new head with the old head's device and dtype, so the
        # copy below does not cross devices and the net stays consistent.
        new_fc = nn.Linear(in_features, out_features + n, bias=has_bias).to(
            device=old_fc.weight.device, dtype=old_fc.weight.dtype
        )
        with torch.no_grad():
            new_fc.weight[:out_features] = old_fc.weight
            if has_bias:
                new_fc.bias[:out_features] = old_fc.bias

        self.net.fc = new_fc

    def to_onehot(self, targets):
        """Return float one-hot rows for ``targets`` on ``self.device``.

        The encoding width is the current number of head outputs
        (``net.fc.out_features``). Both the identity matrix and the index are
        placed on ``self.device``: indexing a CPU tensor with indices that
        live on another device is not allowed.
        """
        num_classes = self.net.fc.out_features
        index = torch.as_tensor(targets, device=self.device)
        return torch.eye(num_classes, device=self.device)[index]

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs, validating after each one.

        The network with the highest validation accuracy is deep-copied into
        ``self.best_net``. ``self.best_accuracy`` stores that accuracy and
        ``self.best_epoch`` stores its 0-based epoch index.

        Returns:
            ``(train_loss, train_accuracy, val_loss, val_accuracy)`` of the
            LAST epoch (not of the best one).

        Raises:
            ValueError: if ``num_epochs`` is less than 1.
        """
        if num_epochs < 1:
            raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

        self.net.to(self.device)
        # Let cuDNN benchmark kernels and pick the fastest ones. The flag must
        # be assigned; merely reading it does nothing.
        cudnn.benchmark = True

        # Start below any reachable accuracy so the first epoch's model is
        # always captured. Otherwise a run whose validation accuracy stays at
        # 0 would keep a stale best_net (the live alias, or a copy from an
        # earlier call, possibly with a smaller head).
        self.best_accuracy = -1.0
        self.best_epoch = 0

        for epoch in range(num_epochs):
            # Epochs are reported 1-based, hence the +1.
            train_loss, train_accuracy = self.do_epoch(epoch + 1)

            # Validate after each epoch.
            val_loss, val_accuracy = self.validate()

            # Snapshot the network whenever validation improves.
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch

        return (train_loss, train_accuracy,
                val_loss, val_accuracy)

    def do_epoch(self, current_epoch):
        """Run one training pass over ``train_dataloader``, then step the scheduler.

        Args:
            current_epoch: 1-based epoch number, used only for logging.

        Returns:
            ``(train_loss, train_accuracy)``: the mean per-batch loss and the
            fraction of correctly classified training samples.

        Raises:
            ValueError: if the training dataloader yields no batches.
        """
        self.net.train()  # Set network in training mode

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            # do_batch performs forward, backward and optimizer step for one batch.
            loss, corrects = self.do_batch(images, labels)
            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        if batch_idx == 0:
            raise ValueError("train_dataloader yielded no batches")

        # Step once per epoch; without this the LR stays at its initial value.
        self.scheduler.step()

        # Average loss over batches, accuracy over samples.
        train_loss = running_train_loss / batch_idx
        train_accuracy = running_corrects / float(total)

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return (train_loss, train_accuracy)

    def do_batch(self, batch, labels):
        """Run one optimisation step on a single batch.

        Returns:
            ``(loss, corrects)``: the loss tensor for the batch and the
            number of samples whose arg-max prediction equals the label.
        """
        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Clear gradients left over from the previous step.
        self.optimizer.zero_grad()

        # One-hot targets over the current number of outputs
        # (shape (batch, out_features), assuming labels is 1-D).
        one_hot_labels = self.to_onehot(labels)

        outputs = self.net(batch)
        loss = self.criterion(outputs, one_hot_labels)

        # Predicted class = index of the largest output.
        _, preds = torch.max(outputs.data, 1)
        running_corrects = torch.sum(preds == labels.data).data.item()

        loss.backward()
        self.optimizer.step()

        return (loss, running_corrects)

    def validate(self):
        """Evaluate ``self.net`` on ``val_dataloader`` without tracking gradients.

        Returns:
            ``(val_loss, val_accuracy)``: the mean per-batch loss and the
            fraction of correctly classified validation samples.

        Raises:
            ValueError: if the validation dataloader yields no batches.
        """
        self.net.train(False)  # evaluation mode

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        # No backward pass happens here, so skip building the autograd graph.
        with torch.no_grad():
            for images, labels in self.val_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                total += labels.size(0)

                one_hot_labels = self.to_onehot(labels)

                outputs = self.net(images)
                loss = self.criterion(outputs, one_hot_labels)
                running_val_loss += loss.item()

                _, preds = torch.max(outputs.data, 1)
                running_corrects += torch.sum(preds == labels.data).data.item()

                batch_idx += 1

        if batch_idx == 0:
            raise ValueError("val_dataloader yielded no batches")

        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return (val_loss, val_accuracy)

    def test(self):
        """Evaluate ``self.best_net`` on ``test_dataloader``.

        Returns:
            ``(accuracy, all_targets, all_preds)``: the test accuracy, plus
            1-D long tensors (on ``self.device``) holding every label and
            every prediction, in loader order.

        Raises:
            ValueError: if the test dataloader yields no batches.
        """
        self.best_net.to(self.device)
        self.best_net.train(False)  # evaluation mode; best net seen in train()

        running_corrects = 0
        total = 0
        # Collect per-batch tensors and concatenate once at the end; repeated
        # torch.cat inside the loop would copy the accumulated results on
        # every batch.
        target_chunks = []
        pred_chunks = []

        with torch.no_grad():
            for images, labels in self.test_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                total += labels.size(0)

                outputs = self.best_net(images)
                _, preds = torch.max(outputs.data, 1)

                running_corrects += torch.sum(preds == labels.data).data.item()

                target_chunks.append(labels)
                pred_chunks.append(preds)

        if not target_chunks:
            raise ValueError("test_dataloader yielded no batches")

        all_targets = torch.cat(target_chunks, dim=0).long()
        all_preds = torch.cat(pred_chunks, dim=0).long()

        accuracy = running_corrects / float(total)

        print(f"Test accuracy: {accuracy}")

        return accuracy, all_targets, all_preds