# Experiment 1

## Overview

This notebook implements Experiment 1, which analyses the extent of supermask overlap when training our classic model on variants of the MNIST dataset.

The key components are:

* **Model Training**:
    * Implements modular network training by looping through tasks, training the MultitaskFC model and extracting supermasks.
    * Three dataset variants are used - Permuted MNIST, Rotated MNIST and Partitioned MNIST.

* **Supermask Analysis**:
    * Calculates supermask overlaps using the Jaccard similarity index.
    * Visually inspects supermask patterns by plotting them in grids.
    * Statistically compares supermask similarity using t-tests.

* **Results Analysis**:
    * Supermask overlaps are visualised in 3D scatter plots.
    * Statistical comparison results are collected in DataFrames and sorted.
    * Key observations are made about supermask patterns for each dataset variant.

The notebook demonstrates our comprehensive experiment workflow for analysing and determining the extent of overlap between supermasks for different tasks - implementing training, extracting supermasks, evaluating similarities, and visualising and comparing results. The analysis provides insights into how supermasks overlap based on the nature of tasks.

##  Importing Required Libraries

In [None]:
import os

import torch
import numpy as np
import pandas as pd
import torch.optim as optim
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utilities.train_funcs import train
from utilities.eval_pred_funcs import evaluate
from utilities.models import MultitaskFC, MultitaskMaskLinear
from utilities.data import MNISTPerm, PartitionMNIST, RotatingMNIST
from utilities.similarity_funcs import jaccard_index, plot_supermask
from utilities.utils import cache_masks, set_model_task, set_num_tasks_learned

## Permuted MNIST

In [None]:
# Permuted-MNIST data source; update_task() is called on it once per task below
mnist = MNISTPerm()

# Number of tasks trained in sequence
num_tasks = 2

# A single MultitaskFC network (hidden size 300) constructed for num_tasks tasks
model = MultitaskFC(hidden_size=300, num_tasks=num_tasks)

# Receives one supermask per task, in task order
supermasks_perm = []

# Train the tasks one after another and record each task's supermask
for task_id in range(num_tasks):
    print(f"Training for task {task_id}")

    # Point the model and the dataset at the task being trained
    set_model_task(model, task_id)
    mnist.update_task(task_id)

    # A fresh RMSprop optimiser per task, limited to parameters with requires_grad set
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.RMSprop(trainable_params, lr=1e-4)

    # One training epoch, followed by a pass over the validation loader
    for epoch in range(1):
        train(model, mnist.train_loader, optimizer, epoch, task_id)

        print("Validation")
        print("============")
        acc1 = evaluate(model, mnist.val_loader, epoch)

    # Fetch the masks for layer_index=3 and keep the entry belonging to this task
    # NOTE(review): presumably get_masks returns one mask per task - confirm in utilities.models
    supermasks_perm.append(model.get_masks(layer_index=3)[task_id])

    # Cache the masks in the model before the learned-task counter is advanced
    cache_masks(model)
    print()

    # Record that tasks 0..task_id have now been learned
    set_num_tasks_learned(model, task_id + 1)
    print()

In [None]:
# Convert every supermask to a NumPy array once, up front, instead of inside the pair loop.
# .cpu() keeps the conversion working when the masks live on a GPU (Tensor.numpy() only
# accepts CPU tensors); for tensors already on the CPU it changes nothing.
masks_np = [mask.detach().cpu().numpy() for mask in supermasks_perm]

# Iterate over the masks actually collected rather than the notebook-level num_tasks,
# which other cells reassign and which may be stale if cells are re-run out of order.
n_masks = len(masks_np)

# Rows of [task i, task j, overlap %]; both (i, j) and (j, i) are recorded
data = []

for i in range(n_masks):
    for j in range(n_masks):
        # A task is never compared with itself
        if i != j:
            # Jaccard index expressed as a percentage, rounded to two decimal places
            overlap = round(jaccard_index(masks_np[i], masks_np[j]) * 100, 2)
            data.append([i, j, overlap])
            print(f"Task: {i}, Task: {j}, Overlap: {overlap}%")

# Tabulate the pairwise overlaps
df = pd.DataFrame(data, columns=["supermask1", "supermask2", "overlap"])

# 3D scatter: the two task IDs on the horizontal axes, overlap percentage on the vertical axis
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(df['supermask1'], df['supermask2'], df['overlap'])
ax.set_xlabel('Supermask (Task ID)')
ax.set_ylabel('Supermask (Task ID)')
ax.set_zlabel('Overlap (%)')
plt.title('Jaccard Index Overlap - MNIST Permuted')

# Make sure the output directory exists so savefig does not fail after a long training run
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/perm_supermasks_overlap_jac_index.png', bbox_inches='tight', dpi=300)

# Display the plot
plt.show()

In [None]:
# Keep the notebook-level task count in step with the number of supermasks collected
num_tasks = len(supermasks_perm)

# NumPy copies of the masks for plotting. .cpu() keeps this working when the masks live on
# a GPU (Tensor.numpy() only accepts CPU tensors); it changes nothing for CPU tensors.
masks_np = [mask.detach().cpu().numpy() for mask in supermasks_perm]

# Five rows, filled column by column. Two columns as before for up to ten masks;
# extra columns are added only when there are more than ten, instead of raising IndexError.
n_rows = 5
n_cols = max(2, -(-len(masks_np) // n_rows))  # -(-a // b) is ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(4 * n_cols, 20))

# Mask i goes to row i % n_rows, column i // n_rows
for i, mask in enumerate(masks_np):
    plot_supermask(axes[i % n_rows][i // n_rows], mask, i)

# Hide grid cells that received no mask so the figure does not show empty axes
for k in range(len(masks_np), n_rows * n_cols):
    axes[k % n_rows][k // n_rows].set_visible(False)

# Adjust the layout of the plots to prevent overlap
plt.tight_layout()

# Display the entire grid of plots
plt.show()

In [None]:
# Stack the permuted-MNIST supermasks into one NumPy array indexed by task.
# .cpu() keeps the conversion working when the masks live on a GPU (Tensor.numpy() only
# accepts CPU tensors); for tensors already on the CPU it changes nothing.
all_supermasks = np.array([mask.detach().cpu().numpy() for mask in supermasks_perm])

# Define a function to statistically compare two tasks based on their supermasks
def compare_tasks_statistically(task1, task2, supermasks=None):
    """Run an independent two-sample t-test on the supermasks of two tasks.

    Both masks are flattened to 1D and passed to ``scipy.stats.ttest_ind``,
    which tests whether the two samples have the same mean.

    Args:
        task1: Index of the first task's supermask.
        task2: Index of the second task's supermask.
        supermasks: Optional indexable collection of array-like supermasks.
            When omitted, the notebook-level ``all_supermasks`` is looked up at
            call time, so existing two-argument calls behave as before.

    Returns:
        A dict with the keys "Task Pair" (the ``(task1, task2)`` tuple),
        "t-statistic" and "p-value".
    """
    # Passing the masks explicitly avoids depending on whichever dataset
    # last rebound the global all_supermasks.
    source = all_supermasks if supermasks is None else supermasks

    # Flatten each mask so the t-test sees one 1D sample per task
    data1 = np.asarray(source[task1]).ravel()
    data2 = np.asarray(source[task2]).ravel()

    t_stat, p_value = ttest_ind(data1, data2)

    return {"Task Pair": (task1, task2), "t-statistic": t_stat, "p-value": p_value}

# Run the t-test once for every unordered pair of tasks (i < j)
n_masks = len(all_supermasks)
results = [
    compare_tasks_statistically(i, j)
    for i in range(n_masks)
    for j in range(i + 1, n_masks)
]

# One row per task pair
df_results = pd.DataFrame(results)

# Show the full table without the row index
print(df_results.to_string(index=False))

# Order the pairs from largest to smallest p-value and renumber the rows.
# A larger p-value means the t-test found less evidence that the two masks' means differ.
df_sorted_similarity = (
    df_results
    .sort_values(by='p-value', ascending=False)
    .reset_index(drop=True)
)

# Show the sorted table without the row index
print(df_sorted_similarity.to_string(index=False))

## Rotated MNIST

In [None]:
# Rotated-MNIST data source; update_task() is called on it once per task below
mnist = RotatingMNIST()

# Number of tasks trained in sequence
num_tasks = 10

# A single MultitaskFC network (hidden size 300) constructed for num_tasks tasks
model = MultitaskFC(hidden_size=300, num_tasks=num_tasks)

# Receives one supermask per task, in task order
supermasks_rotate = []

# Train the tasks one after another and record each task's supermask
for task_id in range(num_tasks):
    print(f"Training for task {task_id}")

    # Point the model and the dataset at the task being trained
    set_model_task(model, task_id)
    mnist.update_task(task_id)

    # A fresh RMSprop optimiser per task, limited to parameters with requires_grad set
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.RMSprop(trainable_params, lr=1e-4)

    # One training epoch, followed by a pass over the validation loader
    for epoch in range(1):
        train(model, mnist.train_loader, optimizer, epoch, task_id)

        print("Validation")
        print("============")
        acc1 = evaluate(model, mnist.val_loader, epoch)

    # Fetch the masks for layer_index=3 and keep the entry belonging to this task
    # NOTE(review): presumably get_masks returns one mask per task - confirm in utilities.models
    supermasks_rotate.append(model.get_masks(layer_index=3)[task_id])

    # Cache the masks in the model before the learned-task counter is advanced
    cache_masks(model)
    print()

    # Record that tasks 0..task_id have now been learned
    set_num_tasks_learned(model, task_id + 1)
    print()

In [None]:
# Convert every supermask to a NumPy array once, up front, instead of inside the pair loop.
# .cpu() keeps the conversion working when the masks live on a GPU (Tensor.numpy() only
# accepts CPU tensors); for tensors already on the CPU it changes nothing.
masks_np = [mask.detach().cpu().numpy() for mask in supermasks_rotate]

# Iterate over the masks actually collected rather than the notebook-level num_tasks,
# which other cells reassign and which may be stale if cells are re-run out of order.
n_masks = len(masks_np)

# Rows of [task i, task j, overlap %]; both (i, j) and (j, i) are recorded
data = []

for i in range(n_masks):
    for j in range(n_masks):
        # A task is never compared with itself
        if i != j:
            # Jaccard index expressed as a percentage, rounded to two decimal places
            overlap = round(jaccard_index(masks_np[i], masks_np[j]) * 100, 2)
            data.append([i, j, overlap])
            print(f"Task: {i}, Task: {j}, Overlap: {overlap}%")

# Tabulate the pairwise overlaps
df = pd.DataFrame(data, columns=["supermask1", "supermask2", "overlap"])

# 3D scatter: the two task IDs on the horizontal axes, overlap percentage on the vertical axis
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(df['supermask1'], df['supermask2'], df['overlap'])
ax.set_xlabel('Supermask (Task ID)')
ax.set_ylabel('Supermask (Task ID)')
ax.set_zlabel('Overlap (%)')
plt.title('Jaccard Index Overlap - MNIST Rotated')

# Make sure the output directory exists so savefig does not fail after a long training run
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/rotate_supermasks_overlap_jac_index.png', bbox_inches='tight', dpi=300)

# Display the plot
plt.show()

In [None]:
# Keep the notebook-level task count in step with the number of supermasks collected
num_tasks = len(supermasks_rotate)

# NumPy copies of the masks for plotting. .cpu() keeps this working when the masks live on
# a GPU (Tensor.numpy() only accepts CPU tensors); it changes nothing for CPU tensors.
masks_np = [mask.detach().cpu().numpy() for mask in supermasks_rotate]

# Five rows, filled column by column. Two columns as before for up to ten masks;
# extra columns are added only when there are more than ten, instead of raising IndexError.
n_rows = 5
n_cols = max(2, -(-len(masks_np) // n_rows))  # -(-a // b) is ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(4 * n_cols, 20))

# Mask i goes to row i % n_rows, column i // n_rows
for i, mask in enumerate(masks_np):
    plot_supermask(axes[i % n_rows][i // n_rows], mask, i)

# Hide grid cells that received no mask so the figure does not show empty axes
for k in range(len(masks_np), n_rows * n_cols):
    axes[k % n_rows][k // n_rows].set_visible(False)

# Adjust the layout of the plots to prevent overlap
plt.tight_layout()

# Display the entire grid of plots
plt.show()

In [None]:
# Stack the rotated-MNIST supermasks into one NumPy array indexed by task.
# .cpu() keeps the conversion working when the masks live on a GPU (Tensor.numpy() only
# accepts CPU tensors); for tensors already on the CPU it changes nothing.
# compare_tasks_statistically (defined in the Permuted MNIST section) reads the
# notebook-level all_supermasks, so it must be rebound here before the calls below.
all_supermasks = np.array([mask.detach().cpu().numpy() for mask in supermasks_rotate])

# Run the t-test once for every unordered pair of tasks (i < j)
n_masks = len(all_supermasks)
results = [
    compare_tasks_statistically(i, j)
    for i in range(n_masks)
    for j in range(i + 1, n_masks)
]

# One row per task pair
df_results = pd.DataFrame(results)

# Show the full table without the row index
print(df_results.to_string(index=False))

# Order the pairs from largest to smallest p-value and renumber the rows.
# A larger p-value means the t-test found less evidence that the two masks' means differ.
df_sorted_similarity = (
    df_results
    .sort_values(by='p-value', ascending=False)
    .reset_index(drop=True)
)

# Show the sorted table without the row index
print(df_sorted_similarity.to_string(index=False))

## Partitioned MNIST

In [None]:
# Partitioned-MNIST data source; update_task() is called on it once per task below
mnist = PartitionMNIST()

# Number of tasks trained in sequence
num_tasks = 10

# A single MultitaskFC network (hidden size 300) constructed for num_tasks tasks
model = MultitaskFC(hidden_size=300, num_tasks=num_tasks)

# Receives one supermask per task, in task order
supermasks_part = []

# Train the tasks one after another and record each task's supermask
for task_id in range(num_tasks):
    print(f"Training for task {task_id}")

    # Point the model and the dataset at the task being trained
    set_model_task(model, task_id)
    mnist.update_task(task_id)

    # A fresh RMSprop optimiser per task, limited to parameters with requires_grad set
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.RMSprop(trainable_params, lr=1e-4)

    # One training epoch, followed by a pass over the validation loader
    for epoch in range(1):
        train(model, mnist.train_loader, optimizer, epoch, task_id)

        print("Validation")
        print("============")
        acc1 = evaluate(model, mnist.val_loader, epoch)

    # Fetch the masks for layer_index=3 and keep the entry belonging to this task
    # NOTE(review): presumably get_masks returns one mask per task - confirm in utilities.models
    supermasks_part.append(model.get_masks(layer_index=3)[task_id])

    # Cache the masks in the model before the learned-task counter is advanced
    cache_masks(model)
    print()

    # Record that tasks 0..task_id have now been learned
    set_num_tasks_learned(model, task_id + 1)
    print()

In [None]:
# Convert every supermask to a NumPy array once, up front, instead of inside the pair loop.
# .cpu() keeps the conversion working when the masks live on a GPU (Tensor.numpy() only
# accepts CPU tensors); for tensors already on the CPU it changes nothing.
masks_np = [mask.detach().cpu().numpy() for mask in supermasks_part]

# Iterate over the masks actually collected rather than the notebook-level num_tasks,
# which other cells reassign and which may be stale if cells are re-run out of order.
n_masks = len(masks_np)

# Rows of [task i, task j, overlap %]; both (i, j) and (j, i) are recorded
data = []

for i in range(n_masks):
    for j in range(n_masks):
        # A task is never compared with itself
        if i != j:
            # Jaccard index expressed as a percentage, rounded to two decimal places
            overlap = round(jaccard_index(masks_np[i], masks_np[j]) * 100, 2)
            data.append([i, j, overlap])
            print(f"Task: {i}, Task: {j}, Overlap: {overlap}%")

# Tabulate the pairwise overlaps
df = pd.DataFrame(data, columns=["supermask1", "supermask2", "overlap"])

# 3D scatter: the two task IDs on the horizontal axes, overlap percentage on the vertical axis
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(df['supermask1'], df['supermask2'], df['overlap'])
ax.set_xlabel('Supermask (Task ID)')
ax.set_ylabel('Supermask (Task ID)')
ax.set_zlabel('Overlap (%)')
plt.title('Jaccard Index Overlap - MNIST Partitioned')

# Make sure the output directory exists so savefig does not fail after a long training run
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/part_supermasks_overlap_jac_index.png', bbox_inches='tight', dpi=300)

# Display the plot
plt.show()

In [None]:
# Keep the notebook-level task count in step with the number of supermasks collected
num_tasks = len(supermasks_part)

# NumPy copies of the masks for plotting. .cpu() keeps this working when the masks live on
# a GPU (Tensor.numpy() only accepts CPU tensors); it changes nothing for CPU tensors.
masks_np = [mask.detach().cpu().numpy() for mask in supermasks_part]

# Five rows, filled column by column. Two columns as before for up to ten masks;
# extra columns are added only when there are more than ten, instead of raising IndexError.
n_rows = 5
n_cols = max(2, -(-len(masks_np) // n_rows))  # -(-a // b) is ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(4 * n_cols, 20))

# Mask i goes to row i % n_rows, column i // n_rows
for i, mask in enumerate(masks_np):
    plot_supermask(axes[i % n_rows][i // n_rows], mask, i)

# Hide grid cells that received no mask so the figure does not show empty axes
for k in range(len(masks_np), n_rows * n_cols):
    axes[k % n_rows][k // n_rows].set_visible(False)

# Adjust the layout of the plots to prevent overlap
plt.tight_layout()

# Display the entire grid of plots
plt.show()

In [None]:
# Stack the partitioned-MNIST supermasks into one NumPy array indexed by task.
# .cpu() keeps the conversion working when the masks live on a GPU (Tensor.numpy() only
# accepts CPU tensors); for tensors already on the CPU it changes nothing.
# compare_tasks_statistically (defined in the Permuted MNIST section) reads the
# notebook-level all_supermasks, so it must be rebound here before the calls below.
all_supermasks = np.array([mask.detach().cpu().numpy() for mask in supermasks_part])

# Run the t-test once for every unordered pair of tasks (i < j)
n_masks = len(all_supermasks)
results = [
    compare_tasks_statistically(i, j)
    for i in range(n_masks)
    for j in range(i + 1, n_masks)
]

# One row per task pair
df_results = pd.DataFrame(results)

# Show the full table without the row index
print(df_results.to_string(index=False))

# Order the pairs from largest to smallest p-value and renumber the rows.
# A larger p-value means the t-test found less evidence that the two masks' means differ.
df_sorted_similarity = (
    df_results
    .sort_values(by='p-value', ascending=False)
    .reset_index(drop=True)
)

# Show the sorted table without the row index
print(df_sorted_similarity.to_string(index=False))

-------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/pytorch
* https://github.com/RAIVNLab/supsup