# Supermask Overlap and Task Similarity Functions

## Overview

This notebook implements functions that measure how similar a neural network's activations are across tasks and that turn those similarities into per-layer alpha (importance) values.

The key functions are:

* **jaccard_index**: Calculates the Jaccard similarity index between two supermasks, i.e. the number of positions active in both divided by the number of positions active in at least one. This measures how much the supermasks overlap.
* **plot_supermask**: Draws a supermask as an image on a given Matplotlib axis, adds a colourbar, and saves the current figure to disk.
* **calculate_task_similarityE1**: Classic approach. Builds a task-by-task matrix of absolute cosine similarities between the means of batch normalisation layer 1 for every pair of tasks. This serves as a measure of supermask overlap, or task similarity.
* **calculate_task_similarityE2**: Novel approach. Computes the same kind of similarity matrix for every batch normalisation layer and returns a dictionary mapping each layer index to its matrix.
* **determine_alphas**: Novel approach. Uses the per-layer similarity matrices to build, for a given task, one alpha vector per layer that is 1 for the task with the highest similarity to the current task and 0 for all others. The alpha values therefore select which task should influence the training of the current task.
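
For reference, the quantities behind these functions can be written as follows. The symbols are introduced here for illustration and do not appear in the code: $M_a$ and $M_b$ are the sets of non-zero positions of two supermasks, $\mu_t^{(\ell)}$ is the batch normalisation mean of layer $\ell$ for task $t$ (the classic approach uses only $\ell = 1$), $S^{(\ell)}$ is the resulting similarity matrix and $c$ is the current task.

$$
J(M_a, M_b) = \frac{|M_a \cap M_b|}{|M_a \cup M_b|}, \qquad J(M_a, M_b) = 0 \ \text{if} \ M_a \cup M_b = \emptyset
$$

$$
S^{(\ell)}_{t,i} = \left| \frac{\mu_t^{(\ell)} \cdot \mu_i^{(\ell)}}{\lVert \mu_t^{(\ell)} \rVert \, \lVert \mu_i^{(\ell)} \rVert} \right| \quad \text{(taken as 0 when either mean is the zero vector)}
$$

$$
\alpha^{(\ell)}_i = \begin{cases} 1 & \text{if } i = \operatorname{arg\,max}_j S^{(\ell)}_{c,j} \\ 0 & \text{otherwise} \end{cases}
$$

Ties in the arg max resolve to the smallest index, as `np.argmax` does. As a worked example, the masks $(1, 1, 0, 1)$ and $(1, 1, 1, 0)$ are both active at two positions and at least one of them is active at four, so $J = 2/4 = 0.5$.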

The similarity analysis quantifies how related different tasks are, based on the network's activations. Plotting the supermasks gives a visual picture of the neural representations. The resulting alpha values are then used when training the novel approach.

Together, these functions support analysing and visualising how a network's representations relate across tasks in continual learning. In particular, the similarity analysis and alpha calculation quantify mask overlap between tasks and feed directly into the training procedure of the novel approach.

## Importing Required Libraries

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from mpl_toolkits.axes_grid1 import make_axes_locatable
```
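
The similarity functions in this notebook read `bn_means_dict[task_id][layer_index]` and call `.cpu().numpy()` on each entry, so they expect a dictionary of dictionaries of PyTorch tensors. How those means are collected is not shown here. As a minimal sketch, assuming the means are the `running_mean` buffers of a model's BatchNorm layers and that layers are numbered from 1 (consistent with the classic approach reading key `1` as the first layer), a hypothetical helper `collect_bn_means` could look like this:

```python
import torch
import torch.nn as nn


def collect_bn_means(model):
    """Hypothetical helper: map 1-based BatchNorm layer index -> copy of its running mean."""
    bn_means = {}
    layer_index = 1
    # model.modules() visits layers in definition order
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # Detach and clone so later forward passes do not overwrite the stored mean
            bn_means[layer_index] = module.running_mean.detach().clone()
            layer_index += 1
    return bn_means


# Toy model with two BatchNorm layers (illustrative only)
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(),
    nn.Linear(16, 4), nn.BatchNorm1d(4),
)

bn_means_dict = {}
for task_id in range(2):
    model.train()
    with torch.no_grad():
        # A forward pass in training mode updates each BatchNorm layer's running mean
        model(torch.randn(32, 8) + task_id)
    bn_means_dict[task_id] = collect_bn_means(model)

print({t: {l: tuple(m.shape) for l, m in d.items()} for t, d in bn_means_dict.items()})
```

The `clone()` matters because `running_mean` is updated in place on every training-mode forward pass; without it, every task's entry would share the same, latest values.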

## Classic Approach

```python
def jaccard_index(supermask1, supermask2):
    """Return the Jaccard index (intersection size / union size) of two supermasks."""
    # Intersection: count of positions where both supermasks are non-zero (active)
    intersection = np.sum(np.logical_and(supermask1, supermask2))

    # Union: count of positions where at least one supermask is non-zero
    union = np.sum(np.logical_or(supermask1, supermask2))

    # Jaccard index, defined as 0 when both supermasks are empty
    return float(intersection / union) if union != 0 else 0.0


def plot_supermask(ax, supermask, task_id):
    # Display the supermask on the given axis using a reversed greyscale colourmap
    im = ax.imshow(supermask, cmap="gray_r")

    # Set title and axis labels
    ax.set_title(f"Neural Activity for Task {task_id}")
    ax.set_xlabel("Neuron Index")
    ax.set_ylabel("Neuron Index")

    # Create space for the colourbar on the right of the plot
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # Attach the colourbar
    plt.colorbar(im, cax=cax)

    # Save the current figure; the "figures" directory must already exist,
    # and every call overwrites the same file
    plt.savefig('figures/perm_neural_acitivity.png', bbox_inches='tight', dpi=300)


def calculate_task_similarityE1(bn_means_dict, num_tasks):
    """Absolute cosine similarity between the layer-1 batch normalisation
    means of every pair of tasks (classic approach)."""
    # Initialise a similarity matrix with zeros
    similarities_matrix = np.zeros((num_tasks, num_tasks))

    # Loop through each task
    for task_id in range(num_tasks):
        # Retrieve the mean dictionary for the current task
        current_task_mean_dict = bn_means_dict.get(task_id)

        # Skip the task if its mean dictionary is missing or has no entry for layer 1
        if current_task_mean_dict is None or 1 not in current_task_mean_dict:
            continue

        # Layer-1 batch normalisation mean as a (1, num_features) NumPy array
        current_task_mean_numpy = current_task_mean_dict[1].cpu().numpy().reshape(1, -1)

        # Compare against every task, indexing by task id rather than dictionary order
        for i in range(num_tasks):
            mean_dict = bn_means_dict.get(i)

            # Only compare tasks whose mean dictionary exists and contains layer 1
            if mean_dict is not None and 1 in mean_dict:
                mean_numpy = mean_dict[1].cpu().numpy().reshape(1, -1)

                # cosine_similarity returns a (1, 1) array; store its absolute value
                similarity = cosine_similarity(current_task_mean_numpy, mean_numpy)
                similarities_matrix[task_id, i] = abs(similarity[0, 0])

    # Return the similarities matrix
    return similarities_matrix
```
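
Because cosine similarity only depends on normalised vectors, the pairwise loop in `calculate_task_similarityE1` can also be written as a single matrix product. The sketch below is an alternative formulation rather than the notebook's code: it assumes every task has a layer-1 mean of the same length, stacks those means into a `(num_tasks, num_features)` array, and uses NumPy only. For rows with non-zero norm, `cosine_similarity(toy_means)` from scikit-learn gives the same values before the absolute value is taken. The toy means are made up for illustration.

```python
import numpy as np


def abs_cosine_matrix(means):
    """Absolute cosine similarity between every pair of rows of `means`."""
    means = np.asarray(means, dtype=float)
    norms = np.linalg.norm(means, axis=1, keepdims=True)
    # Zero-norm rows stay zero, so they get similarity 0 with every task
    unit = np.divide(means, norms, out=np.zeros_like(means), where=norms > 0)
    return np.abs(unit @ unit.T)


def abs_cosine_loop(means):
    """Reference pairwise loop with the same task-by-task structure as the E1 function."""
    means = np.asarray(means, dtype=float)
    n = len(means)
    out = np.zeros((n, n))
    for t in range(n):
        for i in range(n):
            denom = np.linalg.norm(means[t]) * np.linalg.norm(means[i])
            out[t, i] = abs(means[t] @ means[i]) / denom if denom > 0 else 0.0
    return out


# Hypothetical layer-1 means for three tasks (rows) with four features each
toy_means = np.array([
    [1.0, 0.0, 2.0, 0.0],
    [2.0, 0.0, 4.0, 0.0],   # same direction as task 0
    [0.0, -3.0, 0.0, 1.0],  # orthogonal to tasks 0 and 1
])

S = abs_cosine_matrix(toy_means)
# Tasks 0 and 1 score 1.0 against each other; task 2 scores 0.0 against both
print(np.round(S, 3))
```

The vectorised form computes each pair twice, as the loop does, but replaces the Python-level double loop with one matrix multiplication.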

## Novel Approach

```python
def calculate_task_similarityE2(bn_means_dict, num_tasks):
    """Per-layer matrices of absolute cosine similarity between the batch
    normalisation means of every pair of tasks (novel approach)."""
    # Dictionary mapping layer index -> (num_tasks, num_tasks) similarity matrix
    layer_similarities = {}

    # Task 0 defines which layer indices exist; without it there is nothing to compare
    if bn_means_dict.get(0) is None:
        return layer_similarities

    # Iterate through each layer index recorded for task 0
    for layer_index in bn_means_dict[0].keys():
        # Initialise a similarity matrix with zeros
        similarities_matrix = np.zeros((num_tasks, num_tasks))

        # Loop through each task
        for task_id in range(num_tasks):
            current_task_mean = bn_means_dict.get(task_id)

            # Skip tasks whose mean dictionary, or mean for this layer, is missing
            if current_task_mean is None or current_task_mean.get(layer_index) is None:
                continue

            # Mean for the current layer as a (1, num_features) NumPy array
            current_task_mean_numpy = current_task_mean[layer_index].cpu().numpy().reshape(1, -1)

            # Compare against every task, indexing by task id rather than dictionary order
            for i in range(num_tasks):
                mean = bn_means_dict.get(i)

                # Only compare tasks whose mean for the current layer exists
                if mean is not None and mean.get(layer_index) is not None:
                    mean_numpy = mean[layer_index].cpu().numpy().reshape(1, -1)

                    # cosine_similarity returns a (1, 1) array; store its absolute value
                    similarity = cosine_similarity(current_task_mean_numpy, mean_numpy)
                    similarities_matrix[task_id, i] = abs(similarity[0, 0])

        # Store the similarities matrix for the current layer
        layer_similarities[layer_index] = similarities_matrix

    # Return the dictionary containing similarities per layer
    return layer_similarities


def determine_alphas(similarities, current_task):
    """One-hot alpha vector per layer selecting the task most similar to current_task."""
    # Dictionary to store alpha values per layer
    alphas_per_layer = {}

    # Iterate through each layer index and corresponding similarity matrix
    for layer_index, layer_similarity in similarities.items():
        # Row of similarities between the current task and every task
        task_similarities = layer_similarity[current_task]

        # Index of the most similar task. Note: the row includes the current task's
        # own column, whose cosine self-similarity is 1 (up to rounding) for any non-zero mean
        most_similar_task = np.argmax(task_similarities)

        # One-hot alphas: 1 for the most similar task, 0 elsewhere
        alphas = np.zeros_like(task_similarities)
        alphas[most_similar_task] = 1

        # Convert alphas to a float tensor
        alphas_per_layer[layer_index] = torch.tensor(alphas, dtype=torch.float)

    # Return the dictionary containing alphas per layer
    return alphas_per_layer
```
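
`determine_alphas` takes the arg max over the whole row `layer_similarity[current_task]`, which includes the current task's own column. Whenever the current task's mean is non-zero, that diagonal entry is 1 (up to rounding), so the current task tends to select itself. If the intent is to pick the most similar earlier task, one option is to restrict the arg max to indices below `current_task`. The sketch below contrasts the two behaviours on a made-up similarity matrix. `prior_task_alphas` is a hypothetical variant, not part of the notebook, and both helpers return NumPy arrays to stay self-contained; `torch.tensor(alphas, dtype=torch.float)`, as used in `determine_alphas`, converts them to tensors.

```python
import numpy as np


def one_hot_argmax(row):
    """One-hot vector with a 1 at the first maximum of `row`, as determine_alphas builds."""
    alphas = np.zeros_like(row, dtype=float)
    alphas[np.argmax(row)] = 1.0
    return alphas


def prior_task_alphas(row, current_task):
    """Hypothetical variant: pick the most similar task among 0..current_task-1 only."""
    alphas = np.zeros_like(row, dtype=float)
    if current_task == 0:
        # No earlier task exists; leaving all alphas at 0 is an assumed convention
        return alphas
    alphas[np.argmax(row[:current_task])] = 1.0
    return alphas


# Made-up similarity matrix for one layer and four tasks (symmetric, ones on the diagonal)
layer_similarity = np.array([
    [1.00, 0.35, 0.80, 0.10],
    [0.35, 1.00, 0.60, 0.20],
    [0.80, 0.60, 1.00, 0.45],
    [0.10, 0.20, 0.45, 1.00],
])

current_task = 2
row = layer_similarity[current_task]
print(one_hot_argmax(row))                   # [0. 0. 1. 0.]: task 2 selects itself
print(prior_task_alphas(row, current_task))  # [1. 0. 0. 0.]: task 0 is the closest earlier task
```

What the first task should receive when no earlier task exists is a design decision this sketch does not settle.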