# Experiment 2

## Overview

This notebook implements Experiment 2, which evaluates classic and novel continual learning approaches on three variants of the MNIST dataset: permuted, rotated and partitioned MNIST.

The key components are:

* **Classic Approach**:
    * Implements classic approach training by looping through tasks, training the model and evaluating performance.
    * Calculates task similarities using the batch normalisation layer means.
    * Plots task similarity heatmaps.

* **Novel Approach**:
    * Implements novel training using our proposed method.
    * Loads pretrained classic models to initialise task-specific batch normalisation means.
    * Calculates task similarities and soft parameter sharing alphas.
    * Trains the model using the resulting soft parameter sharing.

* **Post-Training Procedures**:
    * Results are collected into DataFrames.
    * Accuracy, training time and loss are plotted for the classic and novel approaches.

The notebook demonstrates our comprehensive experiment workflow - implementing baselines, proposing a novel method, training models, evaluating performance and comparing results.

## Importing Required Libraries

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from time import time
import torch.optim as optim
import matplotlib.pyplot as plt

from utilities.eval_pred_funcs import evaluate
from utilities.train_funcs import train, trainV2
from utilities.models import MultitaskFC, MultitaskFCV2
from utilities.data import MNISTPerm, PartitionMNIST, RotatingMNIST
from utilities.utils import cache_masks, set_model_task, set_num_tasks_learned
from utilities.similarity_funcs import calculate_task_similarityE1, calculate_task_similarityE2, determine_alphas

## Classic Approach

In [None]:
# Define a function for classic training over a specified number of epochs and tasks
# with the hidden_size set as 300
def _train_classic_variant(mnist, variant_key, file_tag, dataset_title, epochs, num_tasks, hidden_size):
    """Train a fresh MultitaskFC model task by task on one MNIST variant.

    Args:
        mnist: Dataset object exposing ``update_task``, ``train_loader`` and ``val_loader``.
        variant_key: Short key used in the heatmap file name (e.g. ``"perm"``).
        file_tag: Tag used in the checkpoint file names (e.g. ``"permuted"``).
        dataset_title: Human-readable dataset name used in the heatmap title.
        epochs: Number of training epochs per task (at least 1).
        num_tasks: Number of tasks to train sequentially.
        hidden_size: Hidden layer width of the model.

    Returns:
        dict: Per-task ``{"time", "acc", "loss"}`` entries keyed by task id, plus
        ``"average_acc"``, ``"average_time"`` and ``"average_loss"``.
    """
    # Initialise the MultitaskFC model for the tasks (after the dataset, as before)
    model = MultitaskFC(hidden_size=hidden_size, num_tasks=num_tasks)

    variant_results = {}
    accs, times, losses = [], [], []

    # Batch normalisation means for each task, used for the similarity heatmap
    bn_means_dict = {}

    # The per-task checkpoints are read back later, so make sure their folder exists
    evaluation_directory = os.path.join(os.getcwd(), 'models', 'evaluation')
    os.makedirs(evaluation_directory, exist_ok=True)

    for task_id in range(num_tasks):
        variant_results[task_id] = {}

        print(f"Training for task {task_id}")

        # Point both the model and the dataset at the current task
        set_model_task(model, task_id)
        mnist.update_task(task_id)

        # A fresh RMSprop optimiser over the trainable parameters for every task
        optimizer = optim.RMSprop(
            [p for p in model.parameters() if p.requires_grad], lr=1e-4
        )

        start = time()
        for e in range(epochs):
            average_loss, bn_means = train(model, mnist.train_loader, optimizer, e, task_id)

            # Overwritten every epoch, so the means from the final epoch are kept
            bn_means_dict[task_id] = bn_means

            print("Validation")
            print("============")
            acc1 = evaluate(model, mnist.val_loader, e)
        time_taken = time() - start

        print(f"Time taken: {time_taken}")
        print()

        # Record the last epoch's accuracy and loss together with the task's training time
        variant_results[task_id]["time"] = time_taken
        variant_results[task_id]["acc"] = acc1
        variant_results[task_id]["loss"] = average_loss
        accs.append(acc1)
        times.append(time_taken)
        losses.append(average_loss)

        # Cache the current state of the masks in the model
        cache_masks(model)
        print()

        # Update the number of learned tasks in the model
        set_num_tasks_learned(model, task_id + 1)
        print()

        # Save the model after the current task
        file_path = os.path.join(evaluation_directory, f'ev_{file_tag}_model_task_{task_id}.pth')
        torch.save(model.state_dict(), file_path)

    # Compute the cosine similarity matrix for tasks based on the batch normalisation means
    similarities_matrix = calculate_task_similarityE1(bn_means_dict, num_tasks)

    # Plot the similarity matrix as a heatmap
    os.makedirs('figures', exist_ok=True)
    plt.figure(figsize=(10, 8))
    sns.heatmap(similarities_matrix, cmap='Blues', annot=True, linewidths=.5)
    plt.title(f'Task Similarity Heatmap - {dataset_title}')
    plt.xlabel('Task ID')
    plt.ylabel('Task ID')
    plt.savefig(f'figures/{variant_key}_similarity_heatmap.png', dpi=300)
    plt.show()

    # Average accuracy, training time and loss over all tasks of this variant
    variant_results["average_acc"] = np.average(np.array(accs))
    variant_results["average_time"] = np.average(np.array(times))
    variant_results["average_loss"] = np.average(np.array(losses))

    return variant_results


def training_classic_ev(epochs, num_tasks, hidden_size=300):
    """Train and evaluate the classic approach on permuted, rotated and partitioned MNIST.

    For each dataset variant, a fresh ``MultitaskFC`` model is trained task by task. After
    each task, the model's state dict is saved to
    ``models/evaluation/ev_<variant>_model_task_<task_id>.pth``. After all tasks, a task
    similarity heatmap is computed from the batch normalisation means and saved under
    ``figures/``.

    Args:
        epochs: Number of training epochs per task; must be at least 1 so that an
            accuracy and a loss are recorded for every task.
        num_tasks: Number of tasks to train sequentially.
        hidden_size: Hidden layer width of the ``MultitaskFC`` model.

    Returns:
        dict: ``{"perm": ..., "rotate": ..., "part": ...}`` (in that order). Each entry
        maps every task id to ``{"time", "acc", "loss"}`` and also holds
        ``"average_acc"``, ``"average_time"`` and ``"average_loss"``.

    Raises:
        ValueError: If ``epochs`` is less than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # (results key, dataset class, checkpoint file tag, heatmap title)
    variants = (
        ("perm", MNISTPerm, "permuted", "Permuted MNIST"),
        ("rotate", RotatingMNIST, "rotated", "Rotated MNIST"),
        ("part", PartitionMNIST, "partitioned", "Partitioned MNIST"),
    )

    results = {}
    for variant_key, dataset_cls, file_tag, dataset_title in variants:
        # Load the dataset before the model is built, matching the original ordering
        mnist = dataset_cls()
        results[variant_key] = _train_classic_variant(
            mnist, variant_key, file_tag, dataset_title, epochs, num_tasks, hidden_size
        )

    # Return the results dictionary
    return results

In [None]:
# Define the number of training rounds for the models
n_training_rounds = 10

# Start main program execution
if __name__ == "__main__":
    # Dataset keys returned by training_classic_ev, in the same order as the columns below
    dataset_keys = ("perm", "rotate", "part")

    # Define column names for the resulting DataFrame that will hold accuracy, time, and loss data
    columns_acc = ["permutation_acc", "rotation_acc", "partition_acc"]
    columns_time = ["permutation_time", "rotation_time", "partition_time"]
    columns_loss = ["permutation_loss", "rotation_loss", "partition_loss"]

    # One row per training round, one column per dataset variant
    classic_acc_arr = []
    classic_time_arr = []
    classic_loss_arr = []

    for _ in range(n_training_rounds):
        # Get results from the training function for classic methods
        results_classic = training_classic_ev(1, 10, 300)

        # Look results up by key rather than relying on the dict's insertion order
        classic_acc_arr.append([results_classic[k]["average_acc"] for k in dataset_keys])
        classic_time_arr.append([results_classic[k]["average_time"] for k in dataset_keys])
        classic_loss_arr.append([results_classic[k]["average_loss"] for k in dataset_keys])

    # Create DataFrames from classic approach results
    df_classic_acc = pd.DataFrame(classic_acc_arr, columns=columns_acc)
    df_classic_time = pd.DataFrame(classic_time_arr, columns=columns_time)
    df_classic_loss = pd.DataFrame(classic_loss_arr, columns=columns_loss)

    # Join accuracy, training time, and loss DataFrames
    df_classic = df_classic_acc.join(df_classic_time).join(df_classic_loss)

    # Label the index column so the visualisation cell can merge classic and novel
    # results on "n_training_rounds" (an unlabeled index is read back as "Unnamed: 0")
    os.makedirs("outputs/evaluation", exist_ok=True)
    df_classic.to_csv("outputs/evaluation/classic_results.csv", index_label="n_training_rounds")

## Novel Approach

In [None]:
# Define a function for novel training over a specified number of epochs and tasks
# with the hidden_size set as 300
def _train_novel_variant(mnist, file_tag, epochs, num_tasks, hidden_size):
    """Train a MultitaskFCV2 model task by task on one MNIST variant with soft sharing.

    The per-task batch normalisation means are first taken from the classic checkpoints
    ``models/evaluation/ev_<file_tag>_model_task_<task_id>.pth``. Before each task, the
    alphas are recomputed from the task similarities.

    Args:
        mnist: Dataset object exposing ``update_task``, ``train_loader`` and ``val_loader``.
        file_tag: Tag used in the classic checkpoint file names (e.g. ``"permuted"``).
        epochs: Number of training epochs per task (at least 1).
        num_tasks: Number of tasks to train sequentially.
        hidden_size: Hidden layer width of the models.

    Returns:
        dict: Per-task ``{"time", "acc", "loss"}`` entries keyed by task id, plus
        ``"average_acc"``, ``"average_time"`` and ``"average_loss"``.

    Raises:
        FileNotFoundError: If a classic checkpoint for some task is missing.
    """
    # Initialise the MultitaskFCV2 model for the tasks (after the dataset, as before)
    model = MultitaskFCV2(hidden_size=hidden_size, num_tasks=num_tasks)

    variant_results = {}
    accs, times, losses = [], [], []

    # Load the saved classic models and copy their batch normalisation means into the model
    evaluation_directory = os.path.join(os.getcwd(), 'models', 'evaluation')
    for task_id in range(num_tasks):
        file_path = os.path.join(evaluation_directory, f'ev_{file_tag}_model_task_{task_id}.pth')
        if not os.path.isfile(file_path):
            raise FileNotFoundError(
                f"Classic checkpoint not found: {file_path}. "
                "Run training_classic_ev first to create it."
            )
        classic_model = MultitaskFC(hidden_size=hidden_size, num_tasks=num_tasks)
        # map_location lets GPU-saved checkpoints load on CPU-only machines;
        # load_state_dict copies the values into classic_model's own tensors.
        # NOTE(review): strict=False silently ignores missing/unexpected keys.
        classic_model.load_state_dict(torch.load(file_path, map_location="cpu"), strict=False)
        bn_mean = classic_model.get_bn_means(task_id)
        model.bn_means[task_id] = bn_mean
        print(f"Task {task_id}: bn_mean =", bn_mean)

    for task in range(num_tasks):
        print(f"Current task: {task}")
        print(f"Updated bn_means (length: {len(model.bn_means)}): {model.bn_means}")

        # Derive the soft parameter sharing alphas from the current task similarities
        similarities_matrix = calculate_task_similarityE2(model.bn_means, num_tasks)
        alphas = determine_alphas(similarities_matrix, task)
        model.set_alphas(alphas)

        variant_results[task] = {}

        print(f"Training for task {task}")

        # Point both the model and the dataset at the current task
        set_model_task(model, task)
        mnist.update_task(task)

        # A fresh RMSprop optimiser over the trainable parameters for every task
        optimizer = optim.RMSprop(
            [p for p in model.parameters() if p.requires_grad], lr=1e-4
        )

        start = time()
        for e in range(epochs):
            average_loss = trainV2(model, mnist.train_loader, optimizer, e, bn_means=model.bn_means)

            print("Validation")
            print("============")
            acc1 = evaluate(model, mnist.val_loader, e)
        time_taken = time() - start

        print(f"Time taken: {time_taken}")
        print()

        # Record the last epoch's accuracy and loss together with the task's training time
        variant_results[task]["time"] = time_taken
        variant_results[task]["acc"] = acc1
        variant_results[task]["loss"] = average_loss
        accs.append(acc1)
        times.append(time_taken)
        losses.append(average_loss)

        # Cache the mask states of the model
        cache_masks(model)
        print()

        # Update the number of learned tasks in the model
        set_num_tasks_learned(model, task + 1)
        print()

    # Average accuracy, training time and loss over all tasks of this variant
    variant_results["average_acc"] = np.average(np.array(accs))
    variant_results["average_time"] = np.average(np.array(times))
    variant_results["average_loss"] = np.average(np.array(losses))

    return variant_results


def training_novel_ev(epochs, num_tasks, hidden_size=300):
    """Train and evaluate the novel approach on permuted, rotated and partitioned MNIST.

    Requires the classic checkpoints written by ``training_classic_ev`` under
    ``models/evaluation/``.

    Args:
        epochs: Number of training epochs per task; must be at least 1 so that an
            accuracy and a loss are recorded for every task.
        num_tasks: Number of tasks to train sequentially.
        hidden_size: Hidden layer width of the models.

    Returns:
        dict: ``{"perm": ..., "rotate": ..., "part": ...}`` (in that order). Each entry
        maps every task id to ``{"time", "acc", "loss"}`` and also holds
        ``"average_acc"``, ``"average_time"`` and ``"average_loss"``.

    Raises:
        ValueError: If ``epochs`` is less than 1.
        FileNotFoundError: If a required classic checkpoint is missing.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # (results key, dataset class, classic checkpoint file tag)
    variants = (
        ("perm", MNISTPerm, "permuted"),
        ("rotate", RotatingMNIST, "rotated"),
        ("part", PartitionMNIST, "partitioned"),
    )

    results = {}
    for variant_key, dataset_cls, file_tag in variants:
        # Load the dataset before the model is built, matching the original ordering
        mnist = dataset_cls()
        results[variant_key] = _train_novel_variant(mnist, file_tag, epochs, num_tasks, hidden_size)

    # Return the results dictionary
    return results

In [None]:
# Define the number of training rounds for the models
n_training_rounds = 10

# Start main program execution
if __name__ == "__main__":
    # Dataset keys returned by training_novel_ev, in the same order as the columns below
    dataset_keys = ("perm", "rotate", "part")

    # Define column names for the resulting DataFrame that will hold accuracy, training time, and loss data
    columns_acc = ["permutation_acc", "rotation_acc", "partition_acc"]
    columns_time = ["permutation_time", "rotation_time", "partition_time"]
    columns_loss = ["permutation_loss", "rotation_loss", "partition_loss"]

    # One row per training round, one column per dataset variant
    novel_acc_arr = []
    novel_time_arr = []
    novel_loss_arr = []

    for _ in range(n_training_rounds):
        # Get results from the training functions for novel approach
        results_novel = training_novel_ev(1, 10, 300)

        # Look results up by key rather than relying on the dict's insertion order
        novel_acc_arr.append([results_novel[k]["average_acc"] for k in dataset_keys])
        novel_time_arr.append([results_novel[k]["average_time"] for k in dataset_keys])
        novel_loss_arr.append([results_novel[k]["average_loss"] for k in dataset_keys])

    # Create DataFrames from novel approach results
    df_novel_acc = pd.DataFrame(novel_acc_arr, columns=columns_acc)
    df_novel_time = pd.DataFrame(novel_time_arr, columns=columns_time)
    df_novel_loss = pd.DataFrame(novel_loss_arr, columns=columns_loss)

    # Join accuracy, training time, and loss DataFrames
    df_novel = df_novel_acc.join(df_novel_time).join(df_novel_loss)

    # Label the index column so the visualisation cell can merge classic and novel
    # results on "n_training_rounds" (an unlabeled index is read back as "Unnamed: 0")
    os.makedirs("outputs/evaluation", exist_ok=True)
    df_novel.to_csv("outputs/evaluation/novel_results.csv", index_label="n_training_rounds")

## Post-Training Results Visualisation

In [None]:
# Read classic and novel approaches' results from CSV files
classic_df = pd.read_csv('outputs/evaluation/classic_results.csv')
novel_df = pd.read_csv('outputs/evaluation/novel_results.csv')

# The training round number lives in the CSV's index column. Files written by a plain
# `to_csv` store it under the header "Unnamed: 0" rather than "n_training_rounds", which
# made the merge below fail with a KeyError. Recover the column, or rebuild it from row order.
for results_df in (classic_df, novel_df):
    if 'n_training_rounds' in results_df.columns:
        continue
    if 'Unnamed: 0' in results_df.columns:
        results_df.rename(columns={'Unnamed: 0': 'n_training_rounds'}, inplace=True)
    else:
        results_df.insert(0, 'n_training_rounds', range(len(results_df)))

# Prefix column names with classic and novel
classic_df.columns = ['classic_' + col if col != 'n_training_rounds' else col for col in classic_df.columns]
novel_df.columns = ['novel_' + col if col != 'n_training_rounds' else col for col in novel_df.columns]

# Merge the DataFrames on the shared training round column
merged_df = classic_df.merge(novel_df, on='n_training_rounds')

# Save merged DataFrame to CSV
merged_df.to_csv('outputs/evaluation/merged_results.csv', index=False)

# Set the default seaborn theme
sns.set_theme(style="whitegrid")

# Set the context for plotting
sns.set_context("paper", font_scale=1.2)

# Define a color palette
palette = sns.color_palette("colorblind")

# Define line styles for the classic and novel approaches' results
line_styles_classic = ['-']
line_styles_novel = ['--']

# Define markers for each dataset
markers = ['o', 'x', '^']

# Function to create a plot for each dataset
def create_plot(metric, y_label, file_prefix):
    """Plot classic vs. novel results for one metric, with one subplot per MNIST variant.

    Reads the module-level ``merged_df``, ``palette``, ``markers``, ``line_styles_classic``
    and ``line_styles_novel``. Saves the figure to ``figures/<file_prefix>_<metric>.png``
    and then shows it.

    Args:
        metric: Column suffix to plot, e.g. ``'acc'``, ``'loss'`` or ``'time'``.
        y_label: Figure title; also the y-axis label unless ``metric`` is ``'time'``.
        file_prefix: Prefix of the saved figure's file name.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(y_label, fontsize=16, fontweight='bold')

    rounds = merged_df['n_training_rounds']

    for ax, dataset, title, marker in zip(axes, ['permutation', 'rotation', 'partition'], ['Permutation', 'Rotation', 'Partition'], markers):
        classic_column = f'classic_{dataset}_{metric}'
        novel_column = f'novel_{dataset}_{metric}'

        # Use distinct line styles and markers for each dataset
        ax.plot(rounds, merged_df[classic_column], label='Classic ' + title, linestyle=line_styles_classic[0], marker=marker, color=palette[0], linewidth=2, markersize=6)
        ax.plot(rounds, merged_df[novel_column], label='Novel ' + title, linestyle=line_styles_novel[0], marker=marker, color=palette[1], linewidth=2, markersize=6)

        ax.set_title(title, fontsize=14)
        ax.set_xlabel('Training Round', fontsize=12)
        ax.set_xticks(rounds)
        ax.grid(True, linestyle='--')

        # Pad the y-range by 5%. When every value is equal the range is zero, and equal
        # lower/upper limits give a singular axis, so fall back to a small fixed pad.
        min_val = min(merged_df[classic_column].min(), merged_df[novel_column].min())
        max_val = max(merged_df[classic_column].max(), merged_df[novel_column].max())
        span = max_val - min_val
        pad = span * 0.05 if span > 0 else max(abs(max_val) * 0.05, 1e-6)
        ax.set_ylim(min_val - pad, max_val + pad)
        if dataset == 'permutation':
            ax.set_ylabel(y_label if metric != 'time' else 'Time (seconds)', fontsize=12)

    # Custom legend outside the plot
    fig.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
    plt.tight_layout(rect=[0, 0.1, 1, 0.95])

    # Save the generated figure, creating the output folder if needed
    os.makedirs('figures', exist_ok=True)
    plt.savefig(f'figures/{file_prefix}_{metric}.png', bbox_inches='tight', dpi=300)

    plt.show()

# Plot accuracy, training time and loss
for metric, y_label, file_prefix in zip(['acc', 'loss', 'time'], ['Accuracy', 'Loss', 'Training Time'], ['accuracy', 'loss', 'time']):
    create_plot(metric, y_label, file_prefix)

-------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/pytorch
* https://github.com/RAIVNLab/supsup