# MSWC FSCIL NeuroBench Tutorial

This tutorial provides insight into the MSWC FSCIL NeuroBench task and shows how you can use the corresponding NeuroBench harness to benchmark your own models and solutions! In particular, we show how to apply the prototypical network approach to both a convolutional and a recurrent spiking network.

## Introduction

### About FSCIL (Few-Shot Class-Incremental Learning)

Learning new tasks from a small number of experiences while retaining knowledge of prior tasks is a hallmark of biological intelligence and a long-standing goal of general AI. It is a particularly important challenge for edge devices, which need to adapt to their environments and users. This benchmark thus evaluates the capacity of a learning solution to successively incorporate new classes over multiple sessions (class-incremental), with only a handful of samples from the new classes to train with (few-shot). The FSCIL task is a recently established benchmark in the computer vision domain (https://arxiv.org/abs/2004.10956), but it has not yet been adapted to other data modalities.

### The MSWC FSCIL NeuroBench Task
Aligning with a neuromorphic interest in temporal data modalities, this benchmark introduces a FSCIL task for streaming audio keyword classification using the large Multilingual Spoken Word Corpus (MSWC) dataset (https://mlcommons.org/datasets/multilingual-spoken-words/). The task is designed to be approached in two phases: pre-training and incremental learning:
* First, for pre-training, a set of 100 words spanning 5 base languages (English, German, Catalan, French, Kinyarwanda), with 500 training samples each, is made available to train an initial model. We provide 2 pre-trained models here, a convolutional one and a recurrent spiking one, both trained with gradient descent on the training samples of the 100 base keywords.

* Next, for incremental learning, the model undergoes 10 successive sessions to learn words from 10 new languages (Persian, Spanish, Russian, Welsh, Italian, Basque, Polish, Esperanto, Portuguese, Dutch) in a few-shot learning scenario. Each incremental session adds 10 words of the corresponding session language with only 5 training samples available per word. Here we give a tutorial for the prototypical network solution (https://arxiv.org/abs/1703.05175), as presented in the NeuroBench paper.
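
The two phases above imply a simple class-count schedule. The sketch below is only bookkeeping derived from the numbers in the text (100 base classes, 10 sessions of 10 new words, 5 shots per word). The helper names `classes_seen` and `support_samples` are hypothetical and not part of NeuroBench.

```python
BASE_CLASSES = 100        # keywords available for pre-training
SESSIONS = 10             # incremental sessions, one language each
WAYS_PER_SESSION = 10     # new keywords added per session
SHOTS_PER_WAY = 5         # training samples per new keyword


def classes_seen(session: int) -> int:
    """Number of classes to distinguish after `session` (0 = pre-training only)."""
    if not 0 <= session <= SESSIONS:
        raise ValueError("session must be between 0 and 10")
    return BASE_CLASSES + WAYS_PER_SESSION * session


def support_samples(session: int) -> int:
    """Number of few-shot training samples provided in one incremental session."""
    return 0 if session == 0 else WAYS_PER_SESSION * SHOTS_PER_WAY


# Class count after pre-training (session 0) and after each incremental session.
schedule = [classes_seen(s) for s in range(SESSIONS + 1)]
```

The final entry of `schedule` is 200, which is why the readout layer later in this tutorial is created with 200 outputs.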

## Benchmark Task

Import the modules required for running the benchmark:

In [None]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset

from neurobench.benchmarks import Benchmark
from neurobench.datasets import MSWC
from neurobench.datasets.MSWC_IncrementalLoader import IncrementalFewShot
from tqdm import tqdm

We fix the default settings. Redefine them to your liking:

In [None]:
# data in repo root dir
ROOT = "../../../data/"

NUM_WORKERS = 8
BATCH_SIZE = 256
NUM_SHOTS = 5 # How many shots to use for evaluation

Select the desired device:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pinned host memory only helps when batches are transferred to a GPU
PIN_MEMORY = device.type == "cuda"
device

First, decide whether you want to go through the tutorial with the spiking neural network (SNN) or the convolutional one (CNN). Specify this by setting `SPIKING` to `True` (SNN) or `False` (CNN).

In [None]:
SPIKING = False

#### Pre-trained model loading

We don't cover the pre-training here, as it follows standard gradient descent and can take quite some time.

The pre-training step is nevertheless significant for the FSCIL performance. The models are pre-trained on the MSWC base training subset (in code: `MSWC(root=..., subset="base", procedure="training")`), which has 100 classes with 500 samples per class. The detailed pre-training procedure can be found in the _mswc_fscil.py_ code.

In [None]:
from examples.mswc_fscil.M5 import M5

from examples.mswc_fscil.sparchSNNs import SNN
from examples.mswc_fscil.sparchSNNs import RadLIFLayer

MODEL_SAVE_DIR = "model_data/"  #Folder where pre-trained models are stored

We load the corresponding pre-trained model, `mswc_rsnn_proto` (SNN) or `mswc_cnn_proto` (CNN), both of which are available directly in the NeuroBench GitHub repo under the `examples/mswc_fscil/model_data/` folder.

In [None]:
if SPIKING:
    model = SNN(
        input_shape=(256, 201, 20),
        neuron_type="RadLIF",
        layer_sizes=[1024, 1024, 200],
        normalization="batchnorm",
        dropout=0.1,
        bidirectional=False,
        use_readout_layer=True,
        ).to(device)
    
    state_dict = torch.load(os.path.join(MODEL_SAVE_DIR, "mswc_rsnn_proto"),
                        map_location=device)
    model.load_state_dict(state_dict)
else:
    model = M5(n_input=20, stride=2, n_channel=256, 
            n_output=200, input_kernel=4, pool_kernel=2, drop=True).to(device)

    state_dict = torch.load(os.path.join(MODEL_SAVE_DIR, "mswc_cnn_proto"),
                        map_location=device)
    model.load_state_dict(state_dict)

We display the model below. As you can see:

- The CNN model follows the multilayer M5 architecture defined in https://arxiv.org/abs/1610.00087 with a tuned kernel size to match the employed pre-processing.
- The SNN model consists of 2 recurrent spiking neuron layers and a linear readout layer, adapted from the sparchSNN library https://github.com/idiap/sparch. The spiking neurons are leaky integrate-and-fire neurons with an extra adaptive variable to mitigate the impact of average activity. All neuron parameters are trained heterogeneously during pre-training.

In [None]:
model

Then, we convert the model to a NeuroBench TorchModel to allow for computational metric benchmarking. This creates hooks to the model activity functions. The neural network itself is now stored in `model.net`.

In [None]:
from neurobench.models import TorchModel

model = TorchModel(model)

For manually defined activation modules, like the adaptive LIF neuron used for the SNN model, we need to add this hook manually.

In [None]:
if SPIKING: 
    model.add_activation_module(RadLIFLayer)
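
To give an intuition for what "hooks on activity functions" can measure, here is a minimal sketch using plain PyTorch forward hooks to estimate activation sparsity. This is an illustration of the general mechanism only, not NeuroBench's implementation; the toy network and the names `toy_net`, `count_zeros`, and `activation_stats` are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical toy network; the tutorial's models are much larger.
toy_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

activation_stats = {"zeros": 0, "total": 0}


def count_zeros(module, inputs, output):
    """Forward hook: accumulate how many activation outputs are exactly zero."""
    activation_stats["zeros"] += int((output == 0).sum())
    activation_stats["total"] += output.numel()


# Attach the hook to every ReLU layer of the network.
handles = [m.register_forward_hook(count_zeros)
           for m in toy_net.modules() if isinstance(m, nn.ReLU)]

with torch.no_grad():
    toy_net(torch.randn(16, 4))

# Fraction of activation outputs that were zero during the forward pass.
sparsity = activation_stats["zeros"] / activation_stats["total"]

# Remove the hooks once the measurement is done.
for handle in handles:
    handle.remove()
```

A custom neuron layer is not an instance of a built-in activation class, so in a scheme like this it has to be registered explicitly, which is the role `add_activation_module` plays above.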

### Pre-processing

For the proposed solution, we employ a state-of-the-art pre-processing method, namely Mel-Frequency Cepstral Coefficients (MFCC), to extract relevant frequency-based coefficients. We use the torchaudio MFCC processor (https://pytorch.org/audio/main/generated/torchaudio.transforms.MFCC.html) and tune the hop length to fix the time resolution to 200 Hz (a hop of 240 samples at the 48 kHz sample rate), and set the number of mel coefficients to 20 for a reasonable number of input channels to the network.
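
The relation between hop length and time resolution can be checked with a few lines of arithmetic. The sketch below assumes one-second clips (48,000 samples) and a centered STFT, which is the torchaudio default; the helper `num_frames` is hypothetical.

```python
SAMPLE_RATE = 48_000   # Hz, as passed to the pre-processors in this tutorial
HOP_LENGTH = 240       # samples between consecutive analysis frames
N_MFCC = 20            # coefficients kept per frame (input channels)

# Frame rate of the MFCC sequence: one frame every HOP_LENGTH samples.
frame_rate_hz = SAMPLE_RATE / HOP_LENGTH


def num_frames(num_samples: int, hop_length: int = HOP_LENGTH) -> int:
    """Frame count of a centered STFT (center=True)."""
    return 1 + num_samples // hop_length


# Assumption: a clip is one second long (48,000 samples).
frames_per_clip = num_frames(SAMPLE_RATE)
```

Under these assumptions a clip yields 201 frames of 20 coefficients, consistent with the `(256, 201, 20)` input shape used when building the SNN above.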

For the _spiking_ solution, a delta encoding is added on top of MFCC to convert the signals to spikes. This is done with the Speech2Spikes pipeline (https://dl.acm.org/doi/abs/10.1145/3584954.3584995), which has been integrated directly in NeuroBench. Note that this adds a spiking threshold as an extra parameter, which is fixed to 1 following the initial Speech2Spikes observations.
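
To illustrate the idea of delta encoding and the role of the threshold, here is a deliberately simplified sketch. It is not the Speech2Spikes algorithm or the `S2SPreProcessor` implementation, which involve additional steps; the function `delta_encode` and the toy input are hypothetical.

```python
import torch


def delta_encode(features: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Simplified delta encoding along the time axis (dim 0).

    Emits +1 where a channel rose by at least `threshold` since the previous
    frame, -1 where it fell by at least `threshold`, and 0 otherwise.
    """
    # Prepend the first frame so that the first difference is zero.
    deltas = torch.diff(features, dim=0, prepend=features[:1])
    spikes = torch.zeros_like(features)
    spikes[deltas >= threshold] = 1.0
    spikes[deltas <= -threshold] = -1.0
    return spikes


# Toy input: 5 time steps, 2 channels.
toy = torch.tensor([[0.0, 0.0],
                    [1.5, 0.2],
                    [1.6, -1.0],
                    [0.1, -1.1],
                    [0.1, 0.5]])
toy_spikes = delta_encode(toy)
```

In this simplified view, a larger threshold produces fewer spikes (sparser input) at the cost of discarding small changes in the coefficients.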

In [None]:
from neurobench.preprocessing import MFCCPreProcessor, S2SPreProcessor

n_fft = 512
win_length = None
hop_length = 240
n_mels = 20
n_mfcc = 20

if SPIKING:
    encode = S2SPreProcessor(device, transpose=True)
    config_change = {"sample_rate": 48000,
                     "hop_length": 240}
    encode.configure(threshold=1.0, **config_change)
else:
    encode = MFCCPreProcessor(
        sample_rate=48000,
        n_mfcc=n_mfcc,
        melkwargs={
            "n_fft": n_fft,
            "n_mels": n_mels,
            "hop_length": hop_length,
            "mel_scale": "htk",
            "f_min": 20,
            "f_max": 4000,
        },
        device = device
    )

### Preparation for Prototypical Continual Learning

Before we can use the prototypical network approach to learn incremental classes, we need to align the pre-trained model with this approach. The prototypical network approach (https://arxiv.org/abs/1703.05175) implements a clustering protocol on top of the pre-trained feature extractor in the form of a linear readout layer, which requires all parameters of this readout layer to be defined accordingly. We therefore first redefine the readout layer for the 100 base classes following the prototypical network approach, so that its parameters are consistent with the prototypical readout parameters of the incremental classes.

Before continuing, we load the base training dataset, which provides the data used to generate the prototypical representations for the base classes. If the MSWC FSCIL dataset is not already available at `ROOT`, the entire dataset will first be downloaded from _Hugging Face_ at the following address: https://huggingface.co/datasets/NeuroBench/mswc_fscil_subset.

In [None]:
base_train_set = MSWC(root=ROOT, subset="base", procedure="training")

Then we create a dataloader **without shuffling** and with a `batch_size` of 500 which, given the ordering of the dataset, provides all samples of exactly one class in each batch.

In [None]:
train_loader = DataLoader(base_train_set, batch_size=500, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
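
The one-class-per-batch property can be illustrated on a toy dataset. The sketch below assumes, as the tutorial does for the base training set, that samples are stored class by class with the same number of samples per class; `toy_set` and `toy_loader` are hypothetical stand-ins.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the base training set: 3 classes, 4 samples each,
# stored class by class (the property the tutorial relies on).
SAMPLES_PER_CLASS = 4
labels = torch.arange(3).repeat_interleave(SAMPLES_PER_CLASS)  # 0,0,0,0,1,1,1,1,2,...
inputs = torch.randn(len(labels), 5)
toy_set = TensorDataset(inputs, labels)

# shuffle defaults to False, so batches follow the storage order.
toy_loader = DataLoader(toy_set, batch_size=SAMPLES_PER_CLASS)

batch_classes = []
for _, target in toy_loader:
    # Every batch should contain exactly one distinct label.
    assert target.unique().numel() == 1
    batch_classes.append(int(target[0]))
```

With shuffling enabled or a different batch size, batches would mix classes and `target[0]` would no longer identify the class of the whole batch.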

The prototypical readout parameters are defined from the mean extracted feature $c_k$ over all training samples of the corresponding class $k$, which we get by passing all input samples through the backbone of the pre-trained network (all layers except the readout one). The prototypical weights and biases for class $k$ are then: $W_k = 2c_k, \ \ b_k=-c_kc_k^T$, so that the readout output $W_k x^T + b_k$ equals the negative squared distance $-\lVert x - c_k \rVert^2$ up to a term that does not depend on $k$.
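
The link between this linear readout and nearest-prototype classification follows from expanding the squared distance: $-\lVert x - c_k \rVert^2 = 2c_kx^T - c_kc_k^T - xx^T$, where the last term is the same for every class. A minimal numerical check, with hypothetical toy sizes and random prototypes, could look like this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

NUM_CLASSES, FEATURE_DIM = 4, 6                      # hypothetical toy sizes
prototypes = torch.randn(NUM_CLASSES, FEATURE_DIM)   # one mean feature c_k per class
x = torch.randn(3, FEATURE_DIM)                      # three query features

# Linear readout with W_k = 2 c_k and b_k = -c_k . c_k
readout = nn.Linear(FEATURE_DIM, NUM_CLASSES)
with torch.no_grad():
    readout.weight.copy_(2 * prototypes)
    readout.bias.copy_(-(prototypes * prototypes).sum(dim=1))
    logits = readout(x)

# Negative squared Euclidean distance to every prototype.
neg_sq_dist = -torch.cdist(x, prototypes).pow(2)

# The two scores differ only by the per-sample constant x . x,
# so they rank the classes identically.
offset = (x * x).sum(dim=1, keepdim=True)
```

Because the offset does not depend on the class, taking the argmax of the readout output selects the closest prototype.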

To do so, we first define a new readout layer supporting 200 classes (100 base classes + 100 incrementally learned classes) that will replace the pre-trained one:

In [None]:
# Set-up new proto readout layer
if SPIKING:
    output = model.net.snn[-1].W
    proto_out = nn.Linear(output.weight.shape[1], 200, bias=True).to(device)
    proto_out.weight.data = output.weight.data
else:
    output = model.net.output
    proto_out = nn.Linear(512, 200, bias=True).to(device)
    proto_out.weight.data = output.weight.data

Then we iterate over the base training classes, extract the features of all 500 samples of each class, average them, and define the weights and biases accordingly.

Note that for the _spiking_ solution, the features are summed over time, and thus the bias is also divided by the total number of timesteps.
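
The reason for dividing the bias by the number of timesteps can be sketched as follows. The example assumes that the readout is applied at every timestep and that its outputs are summed over time, so the bias is added once per timestep; this is an assumption about the readout behaviour made for illustration, and all sizes and tensors are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

T, FEATURE_DIM, NUM_CLASSES = 7, 5, 3                 # hypothetical toy sizes
prototypes = torch.rand(NUM_CLASSES, FEATURE_DIM)     # c_k, built from time-summed features
spikes = (torch.rand(T, FEATURE_DIM) > 0.5).float()   # backbone output for one sample

readout = nn.Linear(FEATURE_DIM, NUM_CLASSES)
with torch.no_grad():
    readout.weight.copy_(2 * prototypes)
    # Bias divided by T: it is added once per timestep, i.e. T times in total.
    readout.bias.copy_(-(prototypes * prototypes).sum(dim=1) / T)
    summed_logits = readout(spikes).sum(dim=0)        # accumulate the readout over time

# Prototypical score computed directly on the time-summed features.
summed_features = spikes.sum(dim=0)
expected = 2 * prototypes @ summed_features - (prototypes * prototypes).sum(dim=1)
```

Under this assumption, the accumulated output matches the prototypical score $2c_kx^T - c_kc_k^T$ computed on the time-summed feature $x$.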

_Note_: This procedure can take a bit of time.

In [None]:
# Compute prototype weights for base classes

for data, target in tqdm(train_loader):
    data, target = encode((data.to(device), target.to(device)))
    data = data.squeeze()
    class_id = target[0]

    if SPIKING:
        features = data
        for layer in model.net.snn[:-1]:
            features = layer(features)

        mean = torch.sum(features, dim=[0,1])/500
        proto_out.weight.data[class_id] = 2*mean
        proto_out.bias.data[class_id] = -torch.matmul(mean, mean.t())/features.shape[1]

    else:
        features = model.net(data, features_out=True)

        mean = torch.sum(features, dim=0)/500
        proto_out.weight.data[class_id] = 2*mean
        proto_out.bias.data[class_id] = -torch.matmul(mean, mean.t())

    del data
    del features
    del mean

Finally, we replace the pre-trained readout layer with the newly defined prototypical one:

In [None]:

# Replace pre-trained readout with prototypical layer
if SPIKING:
    model.net.snn[-1].W = proto_out
else:
    model.net.output = proto_out

del base_train_set
del train_loader

Next, we test the performance of the prototypical representations on the base test set using a NeuroBench Benchmark:

In [None]:
# Copy model for evaluation
eval_model = copy.deepcopy(model)

# Get base test set for evaluation
base_test_set = MSWC(root=ROOT, subset="base", procedure="testing")
test_loader = DataLoader(base_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Put the model in evaluation mode
eval_model.net.eval()

As NeuroBench Benchmarks encapsulate the whole testing procedure, they require pre- and post-processors to manipulate data before and after a network pass. We thus define the following utility functions:

In [None]:
squeeze = lambda x: (x[0].squeeze(), x[1])
out2pred = lambda x: torch.argmax(x, dim=-1)
to_device = lambda x: (x[0].to(device), x[1].to(device))

We also define a mask function for this evaluation, as the network is defined with 200 output neurons but, for now, we evaluate the performance solely on the 100 base classes:

In [None]:
# Define specific post-processing with masking on the base classes
mask = torch.full((200,), float('inf')).to(device)
mask[torch.arange(0,100, dtype=int)] = 0
out_mask = lambda x: x - mask
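
The effect of this masking can be seen on a toy example. The sketch below uses a hypothetical 6-way output where only the first 3 classes have been seen; the names `toy_mask` and `scores` are illustrative.

```python
import torch

NUM_OUTPUTS = 6                          # hypothetical; the tutorial uses 200
seen_classes = torch.tensor([0, 1, 2])   # classes learned so far

# +inf everywhere except on the seen classes.
toy_mask = torch.full((NUM_OUTPUTS,), float("inf"))
toy_mask[seen_classes] = 0

# Raw scores where an unseen class (index 4) has the largest value.
scores = torch.tensor([[0.1, 2.0, -1.0, 0.5, 9.0, 3.0]])

masked = scores - toy_mask               # unseen classes become -inf
probs = torch.softmax(masked, dim=-1)    # and receive zero probability
prediction = int(probs.argmax(dim=-1))
```

Even though class 4 has the highest raw score, the prediction falls on class 1, the best-scoring class among those seen so far.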

Now we can define the Benchmark object with the desired metrics:

In [None]:
# Metrics
static_metrics = ["footprint", "connection_sparsity"]
workload_metrics = ["classification_accuracy", "activation_sparsity", "synaptic_operations"]

# Define benchmark object
benchmark_all_test = Benchmark(eval_model, metric_list=[static_metrics, workload_metrics], 
                               dataloader=test_loader, 
                               preprocessors=[to_device, encode, squeeze], postprocessors=[])

We now run the Benchmark on the base test set:

In [None]:
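# Pass dim explicitly: calling F.softmax without it relies on a deprecated implicit choice
pre_train_results = benchmark_all_test.run(postprocessors=[out_mask, lambda x: F.softmax(x, dim=-1), out2pred, torch.squeeze])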

print("Base results:", pre_train_results)

print(f"The base accuracy is {pre_train_results['classification_accuracy']*100}%")

This is the performance of session 0.

Note that the obtained accuracy, after conversion to prototypes, is below the original performance of the pre-trained model. This is a price to pay to allow for the prototypical network to work effectively in the incremental sessions. This could nevertheless still be improved upon, especially for the _spiking_ solution, where the conversion accuracy drop is significant (from 93% to 84%).

### Incremental Learning

We can now proceed with the few-shot incremental sessions. New sessions are learned following the prototypical network approach on the corresponding session classes and with the limited number of samples available.

We first initialize the FSCIL dataloader. It will generate 10 sessions from a random ordering of the 10 incremental languages. Each session consists of:
- One `support` _list_ of `NUM_SHOTS` shots, each shot being a tuple of tensors `(X_shot, y_shot)` with one sample for each of the 10 session classes.  
- One `query` _dataset_ with all the current and prior incremental session classes and `query_shots` samples per class.
- One `query_classes` list that contains each unique incremental class index following their order of appearance.

Note that `support_query_split` defines a pre-sampling split between the samples available for support and for query, in this order. In the proposed set-up, the few-shot dataloader thus fixes the 100 query samples per class from the start and samples 5 shots out of the 100 support samples for each incremental class:

In [None]:
# IncrementalFewShot Dataloader used in incremental mode to generate class-incremental sessions
few_shot_dataloader = IncrementalFewShot(k_shot=NUM_SHOTS, 
                            root = ROOT,
                            query_shots=100,
                            support_query_split=(100,100))
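
The split-then-sample logic described above can be sketched in plain Python. This illustrates the idea only; the internals of `IncrementalFewShot` may differ, and the function `split_and_sample` and the sample indices are hypothetical.

```python
import random

SUPPORT_POOL, QUERY_POOL = 100, 100   # mirrors support_query_split=(100, 100)
K_SHOT, QUERY_SHOTS = 5, 100


def split_and_sample(sample_ids, rng):
    """Split one class's samples into support and query pools, then draw the shots."""
    support_pool = sample_ids[:SUPPORT_POOL]                          # first 100 samples
    query_pool = sample_ids[SUPPORT_POOL:SUPPORT_POOL + QUERY_POOL]   # next 100 samples
    shots = rng.sample(support_pool, K_SHOT)    # 5 training shots, drawn at random
    query_ids = query_pool[:QUERY_SHOTS]        # fixed evaluation samples
    return shots, query_ids


rng = random.Random(0)
ids = list(range(200))   # hypothetical sample indices for one incremental class
toy_shots, toy_query = split_and_sample(ids, rng)
```

Splitting before sampling guarantees that a training shot can never also appear among the query samples used for evaluation.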

We then run a single incremental learning session as an example:

In [None]:
support, query, query_classes = next(iter(few_shot_dataloader))

The support data - which is generated in a shot-by-shot way for universality to different methods - is here concatenated to gather all training samples per class for the prototypical approach:

In [None]:
data = None

for X_shot, y_shot in support:
    if data is None:
        data = X_shot
        target = y_shot
    else:
        data = torch.cat((data,X_shot), 0)
        target = torch.cat((target,y_shot), 0)

data, target = encode((data.to(device), target.to(device)))
data = data.squeeze()

new_classes = y_shot.tolist()
Nways = len(y_shot) # Number of ways of one batch, should always be 10
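
After concatenation, the samples are laid out shot after shot, so the samples of one class sit `Nways` rows apart. The sketch below checks this index pattern on toy labels, assuming that every shot lists the session classes in the same order; the `_TOY` constants and `class_ids` are hypothetical.

```python
NUM_SHOTS_TOY, NWAYS_TOY = 3, 4     # hypothetical; the tutorial uses 5 shots, 10 ways
class_ids = [100, 101, 102, 103]    # labels of one shot, in a fixed order

# Concatenating shots stacks them shot after shot:
# [shot0: c100 c101 c102 c103, shot1: c100 c101 ..., ...]
targets = [c for _ in range(NUM_SHOTS_TOY) for c in class_ids]

rows_per_class = {}
for index, class_id in enumerate(class_ids):
    # Same index expression as used when computing the prototypes.
    rows = [i * NWAYS_TOY + index for i in range(NUM_SHOTS_TOY)]
    rows_per_class[class_id] = rows
    # All selected rows must carry the label of this class.
    assert all(targets[r] == class_id for r in rows)
```

If the class order differed between shots, this index pattern would mix samples of different classes, and the labels in `target` would have to be used to group samples instead.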

We then apply the prototypical network approach on the corresponding incremental classes:

In [None]:
if SPIKING:
    features = eval_model.net.snn[0](data)
    features = eval_model.net.snn[1](features)

    for index, class_id in enumerate(new_classes):
        mean = torch.sum(features[[i*Nways+index for i in range(NUM_SHOTS)]], dim=[0,1])/NUM_SHOTS
        eval_model.net.snn[-1].W.weight.data[class_id] = 2*mean
        eval_model.net.snn[-1].W.bias.data[class_id] = -torch.matmul(mean, mean.t())/(features.shape[1])
else:
    features = eval_model.net(data, features_out=True)

    for index, class_id in enumerate(new_classes):
        mean = torch.sum(features[[i*Nways+index for i in range(NUM_SHOTS)]], dim=0)/NUM_SHOTS
        eval_model.net.output.weight.data[class_id] = 2*mean
        eval_model.net.output.bias.data[class_id] = -torch.matmul(mean, mean.t())

Then we evaluate the performance after one FSCIL session. The default FSCIL benchmarking evaluates accuracy on all classes seen so far, including the base classes used for pre-training. To this, we add an evaluation of the performance solely on the incremental few-shot classes, corresponding to only the `query` dataset.

Note that the dataloaders used for the benchmarking are redefined when running the Benchmark object. This keeps the code aligned with the general case of multiple sessions (see the cell below), since the test data changes over sessions in a FSCIL task.
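
The two evaluation sets can be illustrated with toy datasets: one loader restricted to the new classes, and one built with `ConcatDataset` that covers every class seen so far. All datasets and names below are hypothetical stand-ins for `base_test_set` and `query`.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Hypothetical stand-ins: 6 base-class test samples and 4 query samples.
base_test_toy = TensorDataset(torch.randn(6, 5), torch.tensor([0, 0, 1, 1, 2, 2]))
query_toy = TensorDataset(torch.randn(4, 5), torch.tensor([100, 100, 101, 101]))

# Loader restricted to the new classes only.
query_loader_toy = DataLoader(query_toy, batch_size=4)

# Loader covering every class seen so far (base + incremental).
full_loader_toy = DataLoader(ConcatDataset([base_test_toy, query_toy]), batch_size=4)

# Collect the labels each loader yields.
new_labels = torch.cat([y for _, y in query_loader_toy])
all_labels = torch.cat([y for _, y in full_loader_toy])
```

Since the query dataset grows with every session, both loaders have to be rebuilt at each session before being passed to the benchmark.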

In [None]:
# Define benchmark object for incremental classes
benchmark_new_classes = Benchmark(eval_model, metric_list=[[],["classification_accuracy"]], 
                                  dataloader=None,
                                  preprocessors=[to_device, encode, squeeze], postprocessors=[])    

### Testing phase ###
eval_model.net.eval()

# Define session dataloaders for query and query + base_test samples
query_loader = DataLoader(query, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

full_session_test_set = ConcatDataset([base_test_set, query])
full_session_test_loader = DataLoader(full_session_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Create a mask function to only consider accuracy on classes presented so far
session_classes = torch.cat((torch.arange(0,100, dtype=int), torch.IntTensor(query_classes))) 
mask = torch.full((200,), float('inf')).to(device)
mask[session_classes] = 0
out_mask = lambda x: x - mask


# Run benchmark on query classes only
query_results = benchmark_new_classes.run(dataloader = query_loader, 
                                          postprocessors=[out_mask, lambda x: F.softmax(x, dim=-1), out2pred, torch.squeeze])
print(f"Accuracy on new classes: {query_results['classification_accuracy']*100} %")

# Run benchmark to evaluate accuracy of this specific session
session_results = benchmark_all_test.run(dataloader = full_session_test_loader, 
                                         postprocessors=[out_mask, lambda x: F.softmax(x, dim=-1), out2pred, torch.squeeze])
print(f"Session accuracy: {session_results['classification_accuracy']*100} %")

Finally, we can run the full FSCIL setup by looping over the code as presented above for all 10 sessions:

_Note_: This can take a bit of time, as FSCIL requires increasingly heavy datasets to be loaded in memory.

In [None]:
# Iteration over incremental sessions
for session, (support, query, query_classes) in enumerate(few_shot_dataloader):
    print(f"Session: {session+1}")

    ### Computing new Prototypical Weights ###
    data = None
    
    for X_shot, y_shot in support:
        if data is None:
            data = X_shot
            target = y_shot
        else:
            data = torch.cat((data,X_shot), 0)
            target = torch.cat((target,y_shot), 0)

    data, target = encode((data.to(device), target.to(device)))
    data = data.squeeze()

    new_classes = y_shot.tolist()
    Nways = len(y_shot) # Number of ways, should always be 10

    if SPIKING:
        features = eval_model.net.snn[0](data)
        features = eval_model.net.snn[1](features)

        for index, class_id in enumerate(new_classes):
            mean = torch.sum(features[[i*Nways+index for i in range(NUM_SHOTS)]], dim=[0,1])/NUM_SHOTS
            eval_model.net.snn[-1].W.weight.data[class_id] = 2*mean
            eval_model.net.snn[-1].W.bias.data[class_id] = -torch.matmul(mean, mean.t())/(features.shape[1])
    else:
        features = eval_model.net(data, features_out=True)

        for index, class_id in enumerate(new_classes):
            mean = torch.sum(features[[i*Nways+index for i in range(NUM_SHOTS)]], dim=0)/NUM_SHOTS
            eval_model.net.output.weight.data[class_id] = 2*mean
            eval_model.net.output.bias.data[class_id] = -torch.matmul(mean, mean.t())

    ### Testing phase ###
    eval_model.net.eval()

    # Define session dataloaders for query and query + base_test samples
    query_loader = DataLoader(query, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    
    full_session_test_set = ConcatDataset([base_test_set, query])
    full_session_test_loader = DataLoader(full_session_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # Create a mask function to only consider accuracy on classes presented so far
    session_classes = torch.cat((torch.arange(0,100, dtype=int), torch.IntTensor(query_classes))) 
    mask = torch.full((200,), float('inf')).to(device)
    mask[session_classes] = 0
    out_mask = lambda x: x - mask

    # Run benchmark on query classes only
    query_results = benchmark_new_classes.run(dataloader = query_loader, postprocessors=[out_mask, lambda x: F.softmax(x, dim=-1), out2pred, torch.squeeze])
    print(f"Accuracy on new classes: {query_results['classification_accuracy']*100} %")

    # Run benchmark to evaluate accuracy of this specific session
    session_results = benchmark_all_test.run(dataloader = full_session_test_loader, postprocessors=[out_mask, lambda x: F.softmax(x, dim=-1), out2pred, torch.squeeze])
    print(f"Session accuracy: {session_results['classification_accuracy']*100} %")
    print("Session results:", session_results)


You should obtain a performance within the bounds presented in the results plot below. The shaded area represents the $5^{th}$ and $95^{th}$ percentiles over 100 runs.

![title](img/FSCIL_proto_results.png)