# Models

## Overview

This notebook implements subnet extraction and weight initialisation for our fully connected models, which are built from masked linear layers.

The key contents are as follows:

* The **GetSubnet** class is a custom autograd function that turns a score tensor into a binary mask. Its forward pass returns 1 where a score is greater than or equal to zero and 0 otherwise. Its backward pass returns the incoming gradient unchanged (a straight-through estimator), so the scores can still be trained even though the thresholding step has zero gradient almost everywhere.

* The **mask_init** function creates a scores tensor with the same shape as a module's weight tensor. It fills the tensor with values drawn from a Kaiming uniform distribution.

* The **signed_constant** function replaces each weight of a module with its sign multiplied by a constant: the ReLU gain divided by the square root of the layer's fan-in.

* The **MultitaskMaskLinear** class is a custom masked linear layer used by the classic model. Its weights are frozen, and it learns a separate score tensor, and therefore a separate mask, for each task. The mask determines which weights are active during the forward pass.

* The **MultitaskFC** class is a fully connected multitask network built from three of these masked linear layers, with batch normalisation and a ReLU activation after each of the first two. It has methods to retrieve the batch normalisation running means for a task and the per-task subnet masks of a given layer.

* The **MultitaskMaskLinearV2** class is functionally identical to MultitaskMaskLinear but is used as part of the novel approach.

* The **MultitaskFCV2** class is an upgraded version of MultitaskFC that uses MultitaskMaskLinearV2 layers. It introduces a method for setting the alpha values of each masked linear layer.

Throughout the implementation, each task keeps its own mask (or subnet) over a shared set of frozen weights. Training a new task only updates that task's scores in the masked layers, so the masks learned for earlier tasks are left untouched. This is what allows the networks to be trained continually across multiple tasks without catastrophic forgetting.

## Importing Required Libraries

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

## Implementing Subnet Extraction and Weight Initialisation for Network Layers

In [None]:
# Define a class to get the subnet of a network layer
class GetSubnet(autograd.Function):
    # Override the static method forward in the parent class.
    @staticmethod
    def forward(ctx, scores):
        # Return a binary tensor where each element is 1 if the corresponding 
        # score is greater than or equal to 0, otherwise 0
        return (scores >= 0).float()

    # Override the static method 'backward' in the parent class
    @staticmethod
    def backward(ctx, g):
        # Return the gradient g as it is, meaning this function does not alter 
        # the gradient
        return g

# Function to initialise the scores tensor, which has the same shape as the 
# weight tensor of a given module
def mask_init(module):
    # Create a tensor scores of the same size as the weight tensor in the module
    scores = torch.Tensor(module.weight.size())
    # Fill the scores tensor in place with Kaiming uniform values (a = sqrt(5),
    # the same scheme nn.Linear uses for its default weight initialisation)
    nn.init.kaiming_uniform_(scores, a=math.sqrt(5))
    return scores

# Function to replace each weight of a module with its sign times a constant
def signed_constant(module):
    # Calculate the correct fan-in for the given module
    fan = nn.init._calculate_correct_fan(module.weight, "fan_in")
    # Calculate the gain for a ReLU activation function
    gain = nn.init.calculate_gain("relu")
    # Calculate the standard deviation using the gain and the fan-in
    std = gain / math.sqrt(fan)
    # Update the weights of the module by multiplying the sign of the weight by the 
    # calculated standard deviation
    module.weight.data = module.weight.data.sign() * std
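As a sanity check on the constants used above, the sketch below recomputes them in plain Python. It assumes PyTorch's documented formulas. In its default `fan_in` mode, `kaiming_uniform_` draws from U(-b, b) with b = gain·sqrt(3/fan_in), and for negative slope a it uses the leaky-ReLU gain sqrt(2/(1+a²)). `calculate_gain("relu")` is sqrt(2). The helper names are hypothetical and only mirror `mask_init`, `signed_constant` and the two passes of `GetSubnet`.

```python
import math

def kaiming_uniform_bound(fan_in, a=math.sqrt(5)):
    # Leaky-ReLU gain used by kaiming_uniform_ for negative slope a
    gain = math.sqrt(2.0 / (1.0 + a ** 2))
    # Samples are drawn from U(-bound, bound) with bound = gain * sqrt(3 / fan_in)
    return gain * math.sqrt(3.0 / fan_in)

def signed_constant_value(fan_in):
    # calculate_gain("relu") is sqrt(2), so std = sqrt(2) / sqrt(fan_in)
    return math.sqrt(2.0) / math.sqrt(fan_in)

def binarise(scores):
    # Same rule as GetSubnet.forward: 1.0 where score >= 0, else 0.0
    return [1.0 if s >= 0 else 0.0 for s in scores]

def straight_through_grad(upstream_grad):
    # GetSubnet.backward passes the incoming gradient through unchanged
    return list(upstream_grad)

fan_in = 784  # input size of the first masked layer
print(kaiming_uniform_bound(fan_in))   # approximately 0.0357 (= 1/28)
print(signed_constant_value(fan_in))   # approximately 0.0505 (= sqrt(2)/28)
print(binarise([-0.3, 0.0, 0.2]))      # [0.0, 1.0, 1.0]
```

With a = sqrt(5), the score bound simplifies to 1/sqrt(fan_in). The fixed weights are about sqrt(2) times larger in magnitude than that bound.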

## Multitask Masked Linear Layer (Classic)

In [None]:
# Define a class to create a multitask linear layer with masks
class MultitaskMaskLinear(nn.Linear):
    def __init__(self, *args, num_tasks=1, **kwargs):
        # Initialise the parent class with the provided arguments and keyword arguments
        super().__init__(*args, **kwargs)
        # Store the number of tasks
        self.num_tasks = num_tasks
        # Create a list of parameters for the scores, initialised with the mask_init 
        # function for each task
        self.scores = nn.ParameterList(
            [nn.Parameter(mask_init(self)) for _ in range(num_tasks)]
        )
        
        # Disable gradients for the weight tensor
        self.weight.requires_grad = False
        # Adjust the weights using the signed_constant function
        signed_constant(self)

    # Define a method to cache the subnet masks with no gradient tracking
    @torch.no_grad()
    def cache_masks(self):
        # Register a buffer for the stacked masks
        self.register_buffer(
            "stacked",
            # Stack the subnets for all tasks
            torch.stack(
                [GetSubnet.apply(self.scores[j]) for j in range(self.num_tasks)]
            ),
        )

    # Override the forward pass method of the parent class
    def forward(self, x):
        # self.task, self.alphas and self.num_tasks_learned are not created in
        # __init__; they must be assigned on the layer before forward is called.
        # A negative task index triggers a superimposed forward pass that
        # combines the masks cached in self.stacked by cache_masks
        if self.task < 0:
            alpha_weights = self.alphas[: self.num_tasks_learned]
            # Create a boolean index mask selecting tasks with positive alpha weights
            idxs = (alpha_weights > 0).squeeze().view(self.num_tasks_learned)
            if len(idxs.shape) == 0:
                idxs = idxs.view(1)
            # Compute the subnet as the alpha-weighted sum of the stacked masks
            # of the selected tasks
            subnet = (
                alpha_weights[idxs] * self.stacked[: self.num_tasks_learned][idxs]
            ).sum(dim=0)
        else:
            # For single task, get subnet using GetSubnet class
            subnet = GetSubnet.apply(self.scores[self.task])
        # Multiply the weight by the subnet
        w = self.weight * subnet

        # Apply a linear transformation to the input tensor x using the computed weight 
        # and bias
        x = F.linear(x, w, self.bias)
        return x

    # Override the representation method to display the class name and dimensions
    def __repr__(self):
        # nn.Linear stores its dimensions as in_features and out_features
        return f"MultitaskMaskLinear({self.in_features}, {self.out_features})"
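The superimposed branch of `forward` (negative `self.task`) is a weighted sum of the cached task masks. The plain-Python sketch below reproduces that arithmetic on tiny 2×2 masks. It assumes `alphas` holds one scalar per task. In the layer itself, `alphas` is a tensor assigned externally and broadcast against `self.stacked`. The function name `superimpose` is hypothetical.

```python
def superimpose(stacked_masks, alphas, num_tasks_learned):
    # Only the tasks learned so far take part, as in self.stacked[:num_tasks_learned]
    masks = stacked_masks[:num_tasks_learned]
    weights = alphas[:num_tasks_learned]
    rows, cols = len(masks[0]), len(masks[0][0])
    subnet = [[0.0] * cols for _ in range(rows)]
    for mask, alpha in zip(masks, weights):
        # Mirrors idxs = alpha_weights > 0: tasks with non-positive alpha are skipped
        if alpha <= 0:
            continue
        for r in range(rows):
            for c in range(cols):
                subnet[r][c] += alpha * mask[r][c]
    return subnet

stacked = [
    [[1.0, 0.0], [1.0, 1.0]],  # task 0 mask
    [[0.0, 1.0], [1.0, 0.0]],  # task 1 mask
    [[1.0, 1.0], [0.0, 0.0]],  # task 2 mask (not learned yet)
]
print(superimpose(stacked, [0.5, 0.5, 0.0], num_tasks_learned=2))
# [[0.5, 0.5], [1.0, 0.5]]
```

The layer then multiplies its frozen weights element-wise by this combined subnet before calling `F.linear`. Entries shared by several tasks' masks therefore receive a larger scale.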

## Multitask Fully Connected Network (Classic)

In [None]:
# Define a class to create a fully connected multitask network
class MultitaskFC(nn.Module):
    def __init__(self, hidden_size, num_tasks):
        # Initialise the parent class
        super().__init__()
        # Define the model as a sequence of layers
        self.model = nn.Sequential(
            # Add a multitask masked linear layer with input size 784 and output size hidden_size
            # Each task has its own mask. Bias is disabled in this layer
            MultitaskMaskLinear(784, hidden_size, num_tasks=num_tasks, bias=False),
            # Normalise the output of the previous layer across the feature dimension
            nn.BatchNorm1d(hidden_size),
            # Apply the ReLU activation function
            nn.ReLU(),
            # Add a multitask masked linear layer with input size hidden_size and output size hidden_size
            # Each task has its own mask. Bias is disabled in this layer
            MultitaskMaskLinear(hidden_size, hidden_size, num_tasks=num_tasks, bias=False),
            # Normalise the output of the previous layer across the feature dimension
            nn.BatchNorm1d(hidden_size),
            # Apply the ReLU activation function
            nn.ReLU(),
            # Add a multitask masked linear layer with input size hidden_size and output size 100
            # Each task has its own mask. Bias is disabled in this layer
            MultitaskMaskLinear(hidden_size, 100, num_tasks=num_tasks, bias=False),
        )
        
        # Initialise means for each task's batch normalisation layer as a dictionary
        self.bn_means = {i: {} for i in range(num_tasks)}

    # Method to retrieve the running means of batch normalisation layers for a specific task
    def get_bn_means(self, task_id):
        # Iterate through the layers of the model
        for i, layer in enumerate(self.model):
            # Check if the layer is a batch normalisation layer
            if isinstance(layer, nn.BatchNorm1d):
                # Store the running mean of the batch normalisation layer for the specific task
                self.bn_means[task_id][i] = layer.running_mean.detach().clone()
        # Return the running means for the specified task
        return self.bn_means[task_id]
    
    def get_masks(self, layer_index):
        # Retrieve the layer by index from the model
        layer = self.model[layer_index]
        # Ensure the layer is an instance of MultitaskMaskLinear
        if isinstance(layer, MultitaskMaskLinear):
            # Return the masks for the specific layer
            return [GetSubnet.apply(score) for score in layer.scores]
        else:
            # Handle cases where the layer is not an instance of MultitaskMaskLinear
            raise ValueError(f"Layer at index {layer_index} is not an instance of MultitaskMaskLinear.")

    
    # Forward pass method to compute the model's output
    def forward(self, x):
        # Flatten all dimensions after the batch dimension and pass the result through the sequential model
        return self.model(x.flatten(1))
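Every `MultitaskMaskLinear` keeps one score tensor per task with the same shape as its weight, so the number of trainable scores grows linearly with `num_tasks`. The sketch below counts these scores for the three masked layers of `MultitaskFC` (784 → hidden_size → hidden_size → 100), ignoring the batch normalisation parameters. `hidden_size=256` and `num_tasks=5` are only example values, and the helper names are hypothetical.

```python
def scores_per_task(hidden_size, in_features=784, out_features=100):
    # Weight (and hence score) shapes of the three masked layers, as (out, in)
    layer_shapes = [
        (hidden_size, in_features),
        (hidden_size, hidden_size),
        (out_features, hidden_size),
    ]
    return sum(rows * cols for rows, cols in layer_shapes)

def total_scores(hidden_size, num_tasks):
    # One full set of scores per task (an nn.ParameterList of length num_tasks)
    return num_tasks * scores_per_task(hidden_size)

per_task = scores_per_task(256)
print(per_task)                # 291840
print(total_scores(256, 5))    # 1459200
# A binarised mask needs at least one bit per score if stored packed
print(per_task / 8 / 1024, "KiB per binary mask")  # 35.625 KiB per binary mask
```

The one-bit figure is a lower bound. `GetSubnet` returns `.float()` tensors (32-bit by default), so the `stacked` buffer built by `cache_masks` uses 32 bits per entry.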

## Multitask Masked Linear Layer (Novel)

In [None]:
# Define a class to create a multitask linear layer with masks
class MultitaskMaskLinearV2(nn.Linear):
    def __init__(self, *args, num_tasks=1, **kwargs):
        # Initialise the parent class with the provided arguments and keyword arguments
        super().__init__(*args, **kwargs)
        # Store the number of tasks
        self.num_tasks = num_tasks
        # Create a list of parameters for the scores, initialised with the mask_init 
        # function for each task
        self.scores = nn.ParameterList(
            [nn.Parameter(mask_init(self)) for _ in range(num_tasks)]
        )
        
        # Disable gradients for the weight tensor to keep the weights untrained
        self.weight.requires_grad = False
        # Adjust the weights using the signed_constant function
        signed_constant(self)

    # Define a method to cache the subnet masks with no gradient tracking
    @torch.no_grad()
    def cache_masks(self):
        # Register a buffer for the stacked masks
        self.register_buffer(
            "stacked",
            # Stack the subnets for all tasks
            torch.stack(
                [GetSubnet.apply(self.scores[j]) for j in range(self.num_tasks)]
            ),
        )

    # Override the forward pass method of the parent class
    def forward(self, x):
        # self.task, self.alphas and self.num_tasks_learned are not created in
        # __init__; they must be assigned on the layer before forward is called.
        # A negative task index triggers a superimposed forward pass that
        # combines the masks cached in self.stacked by cache_masks
        if self.task < 0:
            alpha_weights = self.alphas[: self.num_tasks_learned]
            # Create a boolean index mask selecting tasks with positive alpha_weights
            idxs = (alpha_weights > 0).squeeze().view(self.num_tasks_learned)
            if len(idxs.shape) == 0:
                idxs = idxs.view(1)
            # Compute the subnet as the alpha-weighted sum of the stacked masks
            # of the selected tasks
            subnet = (
                alpha_weights[idxs] * self.stacked[: self.num_tasks_learned][idxs]
            ).sum(dim=0)
        else:
            # For single task, get subnet using GetSubnet class
            subnet = GetSubnet.apply(self.scores[self.task])
        # Multiply the weight by the subnet
        w = self.weight * subnet

        # Apply a linear transformation to the input tensor x using the computed weight 
        # and bias
        x = F.linear(x, w, self.bias)
        return x

    # Override the representation method to display the class name and dimensions
    def __repr__(self):
        # nn.Linear stores its dimensions as in_features and out_features
        return f"MultitaskMaskLinearV2({self.in_features}, {self.out_features})"

## Multitask Fully Connected Network (Novel)

In [None]:
# Define a subclass to create a fully connected multitask network
class MultitaskFCV2(nn.Module):
    def __init__(self, hidden_size, num_tasks):
        # Initialise the parent class
        super().__init__()
        # Define the model as a sequence of layers
        self.model = nn.Sequential(
            # Add a multitask masked linear layer with input size 784 and output size 'hidden_size'
            # Each task has its own mask. Bias is disabled in this layer
            MultitaskMaskLinearV2(784, hidden_size, num_tasks=num_tasks, bias=False),
            # Normalise the output of the previous layer across the feature dimension (dimension 1)
            nn.BatchNorm1d(hidden_size),
            # Apply the ReLU activation function
            nn.ReLU(),
            # Add another multitask masked linear layer with input size 'hidden_size' and output size 'hidden_size'
            # Each task has its own mask. Bias is disabled in this layer
            MultitaskMaskLinearV2(hidden_size, hidden_size, num_tasks=num_tasks, bias=False),
            # Normalise the output of the previous layer across the feature dimension (dimension 1)
            nn.BatchNorm1d(hidden_size),
            # Apply the ReLU activation function
            nn.ReLU(),
            # Add a multitask masked linear layer with input size 'hidden_size' and output size 100
            # Each task has its own mask. Bias is disabled in this layer
            MultitaskMaskLinearV2(hidden_size, 100, num_tasks=num_tasks, bias=False),
        )
        
        # Initialise means for each task's batch normalisation layer as a dictionary
        self.bn_means = {i: None for i in range(num_tasks)}
        # Initialise current_task attribute, to identify the current task being processed
        self.current_task = None

    def set_alphas(self, alphas_per_layer, verbose=True):
        # Iterate over the dictionary mapping each layer index in self.model to its alphas
        for layer_index, alphas in alphas_per_layer.items():
            # Iterate over named modules in the model to find MultitaskMaskLinearV2 instances
            for n, m in self.named_modules():
                # Masked layers are named "model.<index>"; only the layer whose index
                # matches layer_index receives these alphas
                if isinstance(m, MultitaskMaskLinearV2) and int(n.split('.')[1]) == layer_index:
                    # If verbose is True, print information about the alpha setting
                    if verbose:
                        print(f"=> Setting alphas for {n}")
                    # Set the alphas value for the current module
                    m.alphas = alphas


    # Override the forward pass method of the parent class
    def forward(self, x):
        # Flatten all dimensions after the batch dimension (dimension 0 is kept)
        x = x.flatten(1)
        # Loop through each layer of the model
        for i, layer in enumerate(self.model):
            # Apply the layer on the input
            x = layer(x)
        return x # Return the processed tensor
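The dictionary passed to `set_alphas` is keyed by layer index. In `MultitaskFCV2`, the masked layers sit at positions 0, 3 and 6 of `self.model`. `named_modules()` therefore reports them as `model.0`, `model.3` and `model.6`, and `int(n.split('.')[1])` recovers the index. The plain-Python sketch below builds such a dictionary and pairs each masked layer name with the entry for its index. The helper names and the equal weighting of learned tasks are illustrative assumptions. In the real model, each value would have to be a tensor that broadcasts against `stacked`, for example of shape `(num_tasks, 1, 1)`.

```python
# Layer types of MultitaskFCV2.model in order (names only, no PyTorch needed)
LAYER_TYPES = [
    "MultitaskMaskLinearV2", "BatchNorm1d", "ReLU",
    "MultitaskMaskLinearV2", "BatchNorm1d", "ReLU",
    "MultitaskMaskLinearV2",
]

def masked_layer_names(layer_types, prefix="model"):
    # named_modules() reports the children of self.model as "model.<index>"
    return [f"{prefix}.{i}" for i, t in enumerate(layer_types)
            if t == "MultitaskMaskLinearV2"]

def uniform_alphas(num_tasks_learned):
    # Hypothetical choice: equal weight for every learned task
    return [1.0 / num_tasks_learned] * num_tasks_learned

def assign_alphas(names, alphas_per_layer):
    # Pair each masked layer name with the alphas stored under its index
    assigned = {}
    for layer_index, alphas in alphas_per_layer.items():
        for n in names:
            if int(n.split(".")[1]) == layer_index:
                assigned[n] = alphas
    return assigned

names = masked_layer_names(LAYER_TYPES)
print(names)  # ['model.0', 'model.3', 'model.6']
alphas_per_layer = {int(n.split(".")[1]): uniform_alphas(4) for n in names}
print(assign_alphas(names, alphas_per_layer))
```

Keying by index lets different layers receive different task weightings. A key that matches no masked layer is silently ignored.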

-------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/pytorch
* https://github.com/RAIVNLab/supsup
* https://github.com/allenai/hidden-networks