# MSWC FSCIL NeuroBench Tutorial

This tutorial gives an overview of the MSWC FSCIL NeuroBench task and shows how you can use the corresponding NeuroBench harness to benchmark your own models and solutions! In particular, it walks through implementing the prototypical network approach for both a convolutional and a recurrent spiking network.

## Introduction:

### About FSCIL (Few-Shot Class-Incremental Learning)

Learning new tasks from a small number of examples while retaining knowledge of prior tasks is a hallmark of biological intelligence and a long-standing goal of general AI. It is a particularly important challenge for edge devices, which need to adapt to their environments and users. This benchmark therefore evaluates the capacity of a learning solution to successively incorporate new classes over multiple sessions (class-incremental), with only a handful of samples from the new classes to train with (few-shot). FSCIL is a recently established benchmark in the computer vision domain, but it has not yet been adapted to other data modalities.

### The MSWC FSCIL Neurobench Task:
Aligning with the neuromorphic interest in temporal data modalities, this benchmark introduces an FSCIL task for streaming audio keyword classification using the large Multilingual Spoken Word Corpus (MSWC) dataset. The task is designed to be approached in two phases, pre-training and incremental learning:
* First, for pre-training, a set of 100 words spanning 5 base languages (English, German, Catalan, French, Kinyarwanda), with 500 training samples each, is made available to train an initial model. We provide 2 pre-trained models here, a convolutional one and a recurrent spiking one, both trained with gradient descent on the training samples of the 100 base keywords.

* Next, for incremental learning, the model undergoes 10 successive sessions to learn words from 10 new languages (Persian, Spanish, Russian, Welsh, Italian, Basque, Polish, Esperanto, Portuguese, Dutch) in a few-shot learning scenario. Each incremental session adds 10 words of the corresponding session language, with only 5 training samples available per word. Here we give a tutorial for the prototypical network solution, as presented in the NeuroBench paper.
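
As a quick sanity check, the sizes implied by the description above can be tallied in a few lines of Python. This is only arithmetic on the numbers stated in the two phases (100 base words with 500 samples each, then 10 sessions of 10 words with 5 shots each); the variable names are illustrative.

```python
# Numbers taken from the task description above
BASE_CLASSES = 100
BASE_TRAIN_PER_CLASS = 500
SESSIONS = 10
WAYS_PER_SESSION = 10   # new words per incremental session
SHOTS_PER_WAY = 5       # training samples per new word

base_train_samples = BASE_CLASSES * BASE_TRAIN_PER_CLASS
incremental_train_samples = SESSIONS * WAYS_PER_SESSION * SHOTS_PER_WAY
total_classes = BASE_CLASSES + SESSIONS * WAYS_PER_SESSION

# Number of classes the model must distinguish after each session (session 0 = base only)
classes_per_session = [BASE_CLASSES + s * WAYS_PER_SESSION for s in range(SESSIONS + 1)]

print(base_train_samples)         # 50000
print(incremental_train_samples)  # 500
print(total_classes)              # 200
print(classes_per_session)        # [100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200]
```

The 100:1 gap between base and incremental training data is what makes the task hard, and the final count of 200 classes matches the 200-output readout layer used later in this tutorial.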

## Importing the Dataset

The dataset is hosted on _Hugging Face_ at the following address: https://huggingface.co/datasets/NeuroBench/mswc_fscil_subset

The following code downloads the compressed archive (`.tar.gz`) into a _data_ folder of the NeuroBench repository (`../../data`, relative to this notebook).

In [7]:
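import os
import tarfile

import requests

relative_data_folder = '../../data'
os.makedirs(relative_data_folder, exist_ok=True)  # create the folder if it does not exist

# URL of the .tar.gz file
url = 'https://huggingface.co/datasets/NeuroBench/mswc_fscil_subset/resolve/main/mswc_fscil.tar.gz'
filename = os.path.join(relative_data_folder, url.split('/')[-1])

# Download the file: it is less than 1 GB, so it should be quick.
# Streaming writes it to disk in chunks instead of holding it all in memory.
with requests.get(url, stream=True) as response:
    response.raise_for_status()  # stop on HTTP errors instead of saving an error page
    with open(filename, 'wb') as file:
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            file.write(chunk)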

Then we extract the dataset:

In [None]:
from tqdm import tqdm

# Open the tar file
with tarfile.open(filename, 'r:gz') as tar:
    # Get the total number of files within the tar archive (for the progress bar)
    total_files = len(tar.getmembers())

    # Set up the tqdm progress bar
    with tqdm(total=total_files, unit='file', desc='Extracting files') as progress_bar:
        for member in tar.getmembers():
            # Extract each member
            tar.extract(member, path=relative_data_folder)  # Replace with your desired path
            # Update the progress bar
            progress_bar.update(1)

### Benchmark Task:

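First, load your model that is pre-trained on the MSWC base training subset (in code: `MSWC(root=..., subset="base", procedure="training")`). Its trained classification layer is not used for predictions: the prototypical approach below replaces it with a prototype-based readout, which the code accesses as `model.net.output` (convolutional model) or `model.net.snn[-1].W` (spiking model) once the model is wrapped in a `TorchModel`.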

In [None]:
model = ...
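
For the non-spiking path, the cells below only rely on a small interface: `model.net(data, features_out=True)` must return one 512-dimensional feature vector per sample, and the readout must be an `nn.Linear` reachable as `model.net.output` with 200 outputs. The toy network below is a hypothetical stand-in (`ToyKeywordNet` is neither part of NeuroBench nor the provided pre-trained model) that assumes MFCC inputs shaped `(batch, n_mfcc, time)`; it only illustrates that interface, for example for a dry run of the pipeline.

```python
import torch
import torch.nn as nn


class ToyKeywordNet(nn.Module):
    """Hypothetical stand-in exposing the interface used in this tutorial."""

    def __init__(self, n_mfcc: int = 20, n_features: int = 512, n_classes: int = 200):
        super().__init__()
        # Feature extractor: 1D convolution over time on the MFCC frames
        self.features = nn.Sequential(
            nn.Conv1d(n_mfcc, n_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # average over the time axis
            nn.Flatten(),             # -> (batch, n_features)
        )
        # Readout layer that the prototype code later replaces
        self.output = nn.Linear(n_features, n_classes)

    def forward(self, x: torch.Tensor, features_out: bool = False) -> torch.Tensor:
        feats = self.features(x)
        if features_out:
            return feats  # penultimate features used to build prototypes
        return self.output(feats)


net = ToyKeywordNet()
x = torch.randn(4, 20, 201)  # 4 samples, 20 MFCCs, 201 frames
print(net(x, features_out=True).shape)  # torch.Size([4, 512])
print(net(x).shape)                     # torch.Size([4, 200])
```

With the real pre-trained model you would load its weights instead; the wrapping step in the next cell is the same either way.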

Then, convert it to a NeuroBench TorchModel:

In [None]:
from neurobench.models import TorchModel

model = TorchModel(model)

Adjust the following constants as needed:

In [None]:
ROOT = "./FSCIL_subset/" # Where the MSWC dataset is stored
NUM_WORKERS = 8
BATCH_SIZE = 256
NUM_REPEATS = 5 # How many times to repeat the experiment to get aggregate statistics
SPIKING = False
EVAL_SHOTS = 5 # Number of training samples (shots) per new class in each incremental session
EVAL_WAYS = 10 # Number of new classes (ways) added in each incremental session

Import the modules required for running the benchmark:

In [None]:
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset

from neurobench.benchmarks import Benchmark
from neurobench.datasets import MSWC
from neurobench.datasets.MSWC_IncrementalLoader import IncrementalFewShot

# Helper pre-/post-processing functions from the local mswc_fscil module
from mswc_fscil import to_device, squeeze, out2pred

Select the desired device:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pinned (page-locked) host memory speeds up CPU-to-GPU transfers, so only enable it with CUDA
PIN_MEMORY = device.type == "cuda"

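Define the audio pre-processing: Speech2Spikes encoding (`S2SProcessor`) for the spiking model, or MFCC features (`MFCCProcessor`) otherwise: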

In [None]:
from neurobench.preprocessing import MFCCProcessor, S2SProcessor

n_fft = 512
win_length = None
hop_length = 240
n_mels = 20
n_mfcc = 20

if SPIKING:
    # Spike encoding of the raw audio
    encode = S2SProcessor(device, transpose=False)
    config_change = {"sample_rate": 48000,
                     "hop_length": hop_length}
    encode.configure(threshold=1.0, **config_change)
else:
    # MFCC features computed from a mel spectrogram
    encode = MFCCProcessor(
        sample_rate=48000,
        n_mfcc=n_mfcc,
        melkwargs={
            "n_fft": n_fft,
            "n_mels": n_mels,
            "hop_length": hop_length,
            "mel_scale": "htk",
            "f_min": 20,
            "f_max": 4000,
        },
        device=device,
    )
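
To get a feel for the resulting feature shapes, the number of MFCC frames can be computed from the STFT parameters above. This sketch assumes 1-second clips sampled at 48 kHz and a centered STFT that pads the signal by `n_fft // 2` on each side (the torchaudio default, which we assume `MFCCProcessor` keeps); `num_frames` is a hypothetical helper, not a NeuroBench function.

```python
SAMPLE_RATE = 48_000   # Hz, as passed to the pre-processors above
HOP_LENGTH = 240       # samples between consecutive frames
N_FFT = 512            # window / FFT size in samples
CLIP_SECONDS = 1.0     # assumption: 1-second keyword clips


def num_frames(n_samples: int, n_fft: int, hop_length: int) -> int:
    """Frame count of a centered STFT (signal padded by n_fft // 2 on both sides)."""
    padded = n_samples + 2 * (n_fft // 2)
    return 1 + (padded - n_fft) // hop_length


n_samples = int(SAMPLE_RATE * CLIP_SECONDS)
frames = num_frames(n_samples, N_FFT, HOP_LENGTH)
hop_ms = 1000 * HOP_LENGTH / SAMPLE_RATE
window_ms = 1000 * N_FFT / SAMPLE_RATE

print(frames)               # 201
print(hop_ms)               # 5.0
print(round(window_ms, 2))  # 10.67
```

Under these assumptions each clip becomes roughly 201 frames of `n_mfcc = 20` coefficients, with a new frame every 5 ms over a window of about 10.7 ms.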

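Load the base training dataset to generate the prototypical representations for the base classes. The prototype code further below takes the class of each batch from its first label, so it assumes that the base training set is ordered by class with exactly 500 samples per class: with an unshuffled loader and `batch_size=500`, each batch then contains all samples of a single class.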

In [None]:
base_train_set = MSWC(root=ROOT, subset="base", procedure="training")
train_loader = DataLoader(base_train_set, batch_size=500, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

Compute a prototype (the mean feature vector) for each base class and store it in a new linear readout layer:

In [None]:
if SPIKING:
    output = model.net.snn[-1].W
    proto_out = nn.Linear(output.weight.shape[1], 200, bias=True).to(device)
    proto_out.weight.data = output.weight.data
else:
    output = model.net.output
    proto_out = nn.Linear(512, 200, bias=True).to(device)
    proto_out.weight.data = output.weight.data

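# Put the network in evaluation mode and disable gradient tracking while computing prototypes
model.net.eval()
with torch.no_grad():
    for data, target in train_loader:
        data, target = encode((data.to(device), target.to(device)))
        data = data.squeeze()
        class_id = target[0]  # each batch holds a single base class

        if SPIKING:
            # Hidden spiking features (all layers except the readout)
            features = data
            for layer in model.net.snn[:-1]:
                features = layer(features)

            mean = torch.sum(features, dim=[0,1])/features.shape[0]
            proto_out.weight.data[class_id] = 2*mean
            proto_out.bias.data[class_id] = -torch.matmul(mean, mean.t())/features.shape[1]

        else:
            # Penultimate-layer features
            features = model.net(data, features_out=True)

            # Prototype readout: weight = 2*mean, bias = -||mean||^2
            mean = torch.sum(features, dim=0)/features.shape[0]
            proto_out.weight.data[class_id] = 2*mean
            proto_out.bias.data[class_id] = -torch.matmul(mean, mean.t())

        # Free memory before processing the next class
        del data
        del features
        del mean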

if SPIKING:
    model.net.snn[-1].W = proto_out
else:
    model.net.output = proto_out

del base_train_set
del train_loader
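
Why does a linear layer with weight `2*mean` and bias `-mean·mean` act as a prototype classifier? For a feature vector $x$ and class prototype $\mu_c$, the logit is $2\mu_c^\top x - \lVert \mu_c \rVert^2 = \lVert x \rVert^2 - \lVert x - \mu_c \rVert^2$. The term $\lVert x \rVert^2$ is the same for every class, so the largest logit belongs to the prototype closest to $x$ in Euclidean distance. The sketch below checks this numerically for the non-spiking rule, using made-up random features rather than the tutorial's model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_classes, n_features, shots = 5, 16, 8

# Made-up support features: `shots` noisy samples around a random center per class
centers = 3 * torch.randn(n_classes, 1, n_features)
support = centers + torch.randn(n_classes, shots, n_features)
prototypes = support.mean(dim=1)  # class means, shape (n_classes, n_features)

# Readout built with the same rule as above: weight_c = 2 * mu_c, bias_c = -mu_c . mu_c
readout = nn.Linear(n_features, n_classes, bias=True)
with torch.no_grad():
    readout.weight.copy_(2 * prototypes)
    readout.bias.copy_(-(prototypes * prototypes).sum(dim=1))

queries = 3 * torch.randn(100, n_features)
with torch.no_grad():
    logits = readout(queries)

# Squared Euclidean distance from every query to every prototype
sq_dists = ((queries[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=-1)
sq_norms = (queries * queries).sum(dim=1, keepdim=True)

# logit_c = ||x||^2 - ||x - mu_c||^2, so argmax(logits) picks the nearest prototype
print(torch.allclose(logits, sq_norms - sq_dists, atol=1e-2))
linear_pred = logits.argmax(dim=1)
nearest_pred = sq_dists.argmin(dim=1)
print((linear_pred == nearest_pred).float().mean().item())
```

This is why no training is needed for new classes: adding a class only means writing one more row of weights and one bias value.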


Next, test the performance of the prototypical representations on the base test set using a NeuroBench `Benchmark`:

In [None]:
### Evaluation phase ###

eval_model = copy.deepcopy(model)

eval_accs = []
query_accs = []
act_sparsity = []
syn_ops_dense = []
syn_ops_macs = []
syn_ops_acs = []

# Get base test set for evaluation
base_test_set = MSWC(root=ROOT, subset="base", procedure="testing")
test_loader = DataLoader(base_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Put the network in evaluation mode
eval_model.net.eval()

# Metrics
static_metrics = ["model_size", "connection_sparsity"]
workload_metrics = ["classification_accuracy", "activation_sparsity", "synaptic_operations"]

# Define benchmark object
benchmark_all_test = Benchmark(eval_model, metric_list=[static_metrics, workload_metrics], dataloader=test_loader, 
                    preprocessors=[to_device, encode, squeeze], postprocessors=[])

benchmark_new_classes = Benchmark(eval_model, metric_list=[[],["classification_accuracy"]], dataloader=test_loader,
                    preprocessors=[to_device, encode, squeeze], postprocessors=[])

# Post-processing mask: in session 0, only the 100 base classes can be predicted
mask = torch.full((200,), float('inf')).to(device)
mask[torch.arange(0,100, dtype=int)] = 0
out_mask = lambda x: x - mask

# Run session 0 benchmark on base classes
print(f"Session: 0")

pre_train_results = benchmark_all_test.run(postprocessors=[out_mask, F.softmax, out2pred, torch.squeeze])

print("Base results:", pre_train_results)

eval_accs.append(pre_train_results['classification_accuracy'])
act_sparsity.append(pre_train_results['activation_sparsity'])
syn_ops_dense.append(pre_train_results['synaptic_operations']['Dense'])
syn_ops_macs.append(pre_train_results['synaptic_operations']['Effective_MACs'])
syn_ops_acs.append(pre_train_results['synaptic_operations']['Effective_ACs'])

print(f"The base accuracy is {eval_accs[-1]*100}%")
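
The `out_mask` post-processor above subtracts `+inf` from the logits of classes that have not been introduced yet, turning them into `-inf`; after the softmax those classes get probability 0 and can never be predicted. Here is a minimal standalone illustration with made-up logits for 6 classes, of which only the first 3 count as "seen".

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5, 9.0, -1.0, 4.0]])  # unseen class 3 scores highest

# 0 for seen classes, +inf for classes not introduced yet (same trick as `mask` above)
mask = torch.full((6,), float("inf"))
mask[torch.arange(0, 3)] = 0

masked = logits - mask             # unseen classes become -inf
probs = F.softmax(masked, dim=-1)  # exp(-inf) = 0 for unseen classes

print(probs.argmax(dim=-1).item())  # 1
print(probs[0, 3:].sum().item())    # 0.0
```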

The accuracy on the base classes using prototypical representations is expected to be slightly lower than the accuracy of the original model, because the readout classifier trained with backpropagation is replaced by the prototype readout classifier.

Next, the dataloader for the few-shot sessions is initialized:

In [None]:
# IncrementalFewShot Dataloader used in incremental mode to generate class-incremental sessions
few_shot_dataloader = IncrementalFewShot(n_way=EVAL_WAYS, k_shot=EVAL_SHOTS, 
                            root = ROOT,
                            query_shots=100,
                            support_query_split=(100,100),
                            samples_per_class=200)

Finally, we can run the incremental learning benchmark! The code below supports both spiking and non-spiking networks, selected with the `SPIKING` flag defined in an earlier cell. For more details on how the spiking network is constructed, please refer to [`mswc_fscil_proto.py`](./mswc_fscil_proto.py). Note that the code below performs only one repeat, whereas the referenced Python file executes several repeats.

In [None]:
# Iteration over incremental sessions
for session, (support, query, query_classes) in enumerate(few_shot_dataloader):
    print(f"Session: {session+1}")

    # Define benchmark object
    benchmark_all_test = Benchmark(eval_model, metric_list=[static_metrics, workload_metrics], dataloader=test_loader, 
                        preprocessors=[to_device, encode, squeeze], postprocessors=[])

    benchmark_new_classes = Benchmark(eval_model, metric_list=[[],["classification_accuracy"]], dataloader=test_loader,
                        preprocessors=[to_device, encode, squeeze], postprocessors=[])

    ### Computing new Prototypical Weights ###
    data = None
    
    for X_shot, y_shot in support:
        if data is None:
            data = X_shot
            target = y_shot
        else:
            data = torch.cat((data,X_shot), 0)
            target = torch.cat((target,y_shot), 0)

    data, target = encode((data.to(device), target.to(device)))
    data = data.squeeze()

    # The indexing below assumes the support set is ordered shot by shot:
    # row i*EVAL_WAYS + index holds shot i of the index-th new class
    with torch.no_grad():
        if SPIKING:
            # Hidden spiking features (all layers except the readout), as for the base classes
            features = data
            for layer in eval_model.net.snn[:-1]:
                features = layer(features)

            for index, class_id in enumerate(query_classes[-EVAL_WAYS:]):
                mean = torch.sum(features[[i*EVAL_WAYS+index for i in range(EVAL_SHOTS)]], dim=[0,1])/EVAL_SHOTS
                eval_model.net.snn[-1].W.weight.data[class_id] = 2*mean
                eval_model.net.snn[-1].W.bias.data[class_id] = -torch.matmul(mean, mean.t())/(features.shape[1])
        else:
            features = eval_model.net(data, features_out=True)

            for index, class_id in enumerate(query_classes[-EVAL_WAYS:]):
                mean = torch.sum(features[[i*EVAL_WAYS+index for i in range(EVAL_SHOTS)]], dim=0)/EVAL_SHOTS
                eval_model.net.output.weight.data[class_id] = 2*mean
                eval_model.net.output.bias.data[class_id] = -torch.matmul(mean, mean.t())

    ### Testing phase ###
    eval_model.net.eval()

    # Define session dataloaders for query and query + base_test samples
    query_loader = DataLoader(query, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    
    full_session_test_set = ConcatDataset([base_test_set, query])
    full_session_test_loader = DataLoader(full_session_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # Create a mask function to only consider accuracy on classes presented so far
    session_classes = torch.cat((torch.arange(0, 100, dtype=torch.long), torch.as_tensor(query_classes, dtype=torch.long)))
    mask = torch.full((200,), float('inf')).to(device)
    mask[session_classes] = 0
    out_mask = lambda x: x - mask

    # Run benchmark to evaluate accuracy of this specific session
    session_results = benchmark_all_test.run(dataloader = full_session_test_loader, postprocessors=[out_mask, F.softmax, out2pred, torch.squeeze])
    
    print("Session results:", session_results)
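
In the session loop, the prototype of the `index`-th new class is averaged over rows `i*EVAL_WAYS + index` for `i` in `range(EVAL_SHOTS)`. This indexing implies that the concatenated support set is laid out shot by shot (all `EVAL_WAYS` classes for shot 0, then all of them for shot 1, and so on). Under that assumption, the same prototypes can be computed with a single reshape; the sketch below uses made-up non-spiking features to check that both forms agree.

```python
import torch

EVAL_WAYS, EVAL_SHOTS, n_features = 10, 5, 512

torch.manual_seed(0)
# Made-up support features, shot-major: row i*EVAL_WAYS + index is shot i of class `index`
features = torch.randn(EVAL_SHOTS * EVAL_WAYS, n_features)

# Indexing used in the tutorial, one class at a time
loop_means = torch.stack([
    torch.sum(features[[i * EVAL_WAYS + index for i in range(EVAL_SHOTS)]], dim=0) / EVAL_SHOTS
    for index in range(EVAL_WAYS)
])

# Vectorized equivalent: (shots, ways, features) -> mean over the shot axis
reshaped_means = features.view(EVAL_SHOTS, EVAL_WAYS, n_features).mean(dim=0)

print(loop_means.shape)  # torch.Size([10, 512])
print(torch.allclose(loop_means, reshaped_means, atol=1e-6))
```

If the support loader ever returned samples grouped class by class instead, the reshape would become `view(EVAL_WAYS, EVAL_SHOTS, n_features).mean(dim=1)`, and the tutorial's index list would need the same change.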
