# MSWC FSCIL Benchmark Tutorial

This tutorial shows how the NeuroBench framework is organized and how you can use it to benchmark your own models!

## About MSWC FSCIL (Keyword Few-Shot Class-Incremental Learning)

Learning new tasks from a small number of experiences while retaining knowledge of prior tasks is a hallmark of biological intelligence and a long-standing goal of general AI. Endowing edge devices with the ability to adapt to their environments and users is a particularly important challenge. This benchmark thus evaluates the capacity of a model to successively incorporate new keywords over multiple sessions (class-incremental), with only a handful of samples from the new classes to train with (few-shot). The FSCIL task is a recently established benchmark in the computer vision domain, but it has not yet been adapted to other data modalities.

### Dataset:
Aligning with a neuromorphic interest in temporal data modalities, this benchmark introduces an FSCIL task with streaming audio data using the large Multilingual Spoken Word Corpus (MSWC) keyword classification dataset. The task is designed to be approached in two phases: pre-training and incremental learning.

First, for pre-training, a set of 100 words spanning 5 base languages (English, German, Catalan, French, Kinyarwanda), with 500 training samples per word, is made available to train an initial model.

Next, for incremental learning, the model undergoes 10 successive sessions to learn words from 10 new languages (Persian, Spanish, Russian, Welsh, Italian, Basque, Polish, Esperanto, Portuguese, Dutch) in a few-shot learning scenario. Each incremental session adds 10 words of the corresponding session language, with only 5 training samples available per word.

After each session, the model is tested on classification accuracy over all previously learned classes, including the 100 base pre-training classes and the few-shot-learned classes. This evaluates the FSCIL solution on its ability to learn new classes while retaining knowledge of the previously learned ones. Each session learns a new language, for a total knowledge base of 200 keywords by the end of the benchmark.
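
As a quick sanity check on the numbers in the dataset description, the session schedule can be sketched in a few lines of plain Python. The constant and function names here are illustrative and are not part of NeuroBench.

```python
# Session schedule implied by the dataset description (names are illustrative).
BASE_CLASSES = 100       # keywords available for pre-training
NUM_SESSIONS = 10        # incremental sessions, one new language each
WAYS_PER_SESSION = 10    # new keywords added per session
SHOTS_PER_CLASS = 5      # training samples per new keyword


def classes_after_session(session: int) -> int:
    """Number of classes the model is tested on after `session` (0 = base only)."""
    return BASE_CLASSES + WAYS_PER_SESSION * session


schedule = [classes_after_session(s) for s in range(NUM_SESSIONS + 1)]
support_samples_per_session = WAYS_PER_SESSION * SHOTS_PER_CLASS

print("Classes evaluated after each session:", schedule)
print("Support samples per incremental session:", support_samples_per_session)
```

Session 0 is the evaluation on the base classes alone, so the list has one more entry than there are incremental sessions.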


### Benchmark Task:



First, load your model. Note that it should not have a classification layer at the end, as this will be added by the benchmark.

In [None]:
model = ...

Then, convert it to a NeuroBench TorchModel:

In [None]:
from neurobench.models import TorchModel

model = TorchModel(model)

Adjust the following constants to suit your setup:

In [None]:
ROOT = "./FSCIL_subset/"
NUM_WORKERS = 8
BATCH_SIZE = 256
NUM_REPEATS = 5 # How many times to repeat the experiment to get aggregate statistics
SPIKING = False
EVAL_SHOTS = 5 # How many shots to use for evaluation

Import the modules required for running the benchmark:

In [None]:
import copy

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset

from neurobench.benchmarks import Benchmark
from neurobench.datasets import MSWC
from neurobench.datasets.MSWC_IncrementalLoader import IncrementalFewShot

from mswc_fscil_proto import to_device, squeeze, examples_per_class, out2pred

Select the desired device:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pinned host memory only helps when batches are copied to a GPU
PIN_MEMORY = device.type == "cuda"

Define the pre-processing: MFCC features for non-spiking models, or the `S2SProcessor` spike encoding when `SPIKING` is set:

In [None]:
from neurobench.preprocessing import MFCCProcessor, S2SProcessor

n_fft = 512
win_length = None
hop_length = 240
n_mels = 20
n_mfcc = 20

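if SPIKING:
    # `args` is not defined in this notebook, so the soft-delta option is set here.
    # Assumption: False is only a placeholder; use the value your spiking model was trained with.
    SOFT_DELTA = False
    encode = S2SProcessor(device, transpose=False, soft_delta=SOFT_DELTA)
    config_change = {"sample_rate": 48000,
                     "hop_length": 240}
    encode.configure(threshold=1.0, **config_change)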
else:
    encode = MFCCProcessor(
        sample_rate=48000,
        n_mfcc=n_mfcc,
        melkwargs={
            "n_fft": n_fft,
            "n_mels": n_mels,
            "hop_length": hop_length,
            "mel_scale": "htk",
            "f_min": 20,
            "f_max": 4000,
        },
        device=device,
    )
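
To get a feel for what these settings mean in time, the arithmetic below converts them to milliseconds and a frame count. It is only a sketch: it assumes one-second clips and a centered STFT whose window length equals `n_fft` (the torchaudio defaults), which may not match how `MFCCProcessor` is configured internally.

```python
# Timing implied by the MFCC settings above (plain arithmetic, no audio needed).
SAMPLE_RATE = 48_000
N_FFT = 512
HOP_LENGTH = 240
CLIP_SECONDS = 1.0  # assumption: one-second keyword clips

hop_ms = 1000 * HOP_LENGTH / SAMPLE_RATE   # time step between consecutive frames
window_ms = 1000 * N_FFT / SAMPLE_RATE     # span of one analysis window
num_samples = int(CLIP_SECONDS * SAMPLE_RATE)

# Assumption: centered STFT, which yields 1 + num_samples // hop_length frames
num_frames = 1 + num_samples // HOP_LENGTH

print(f"hop: {hop_ms:.2f} ms, window: {window_ms:.2f} ms, frames per clip: {num_frames}")
```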

Now you can run the benchmark! Note that the code below supports both spiking and non-spiking networks, which is why the `SPIKING` flag was set in an earlier cell. For details on how the spiking network is constructed, please refer to [`mswc_fscil_proto.py`](./mswc_fscil_proto.py).
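
The benchmark cell sets each output row to `2*mean` and each bias to `-mean·mean` (additionally divided by `features.shape[1]` in the spiking branch), where `mean` is the average feature vector of a class. The sketch below, using NumPy on random toy data rather than the real model, illustrates why the non-spiking form works: the resulting linear scores equal the negative squared Euclidean distance to each class mean, up to a term that is identical for all classes. It also shows the effect of subtracting an `inf` mask for classes not yet seen. All sizes and variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
feat_dim, num_classes = 8, 6   # toy sizes, not the benchmark's 512 and 200
prototypes = rng.normal(size=(num_classes, feat_dim))  # per-class mean features
x = rng.normal(size=(16, feat_dim))                    # batch of feature vectors

# Same parameterization as the benchmark cell: W = 2 * mean, b = -||mean||^2
weight = 2 * prototypes
bias = -np.sum(prototypes ** 2, axis=1)
logits = x @ weight.T + bias   # the computation a linear layer performs

# Negative squared Euclidean distance from every sample to every prototype
diff = x[:, None, :] - prototypes[None, :, :]
neg_sq_dist = -np.sum(diff ** 2, axis=2)

# The two differ only by ||x||^2, which is the same for every class
x_sq = np.sum(x ** 2, axis=1, keepdims=True)
same_scores = np.allclose(logits, neg_sq_dist + x_sq)
same_prediction = np.array_equal(logits.argmax(axis=1), neg_sq_dist.argmax(axis=1))

# Session mask: subtracting inf pushes classes not seen yet to -inf
seen = 4
mask = np.full(num_classes, np.inf)
mask[:seen] = 0
masked_pred = (logits - mask).argmax(axis=1)

print(same_scores, same_prediction, bool(masked_pred.max() < seen))
```

Because the prototype of a class depends only on that class's own samples, adding a new keyword amounts to writing one more row and bias entry, which is what each incremental session does.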

In [None]:
all_evals = []
all_query = []
all_act_sparsity = []
all_syn_ops_dense = []
all_syn_ops_macs = []
all_syn_ops_acs = []

base_train_set = MSWC(root=ROOT, subset="base", procedure="training")
train_loader = DataLoader(base_train_set, batch_size=500, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

if SPIKING:
    output = model.net.snn[-1].W
    proto_out = nn.Linear(output.weight.shape[1], 200, bias=True).to(device)
    proto_out.weight.data = output.weight.data
else:
    output = model.net.output
    proto_out = nn.Linear(512, 200, bias=True).to(device)
    proto_out.weight.data = output.weight.data

for data, target in train_loader:
    data, target = encode((data.to(device), target.to(device)))
    data = data.squeeze()
    class_id = target[0]

    if SPIKING:
        features = data
        for layer in model.net.snn[:-1]:
            features = layer(features)

        mean = torch.sum(features, dim=[0,1])/500
        proto_out.weight.data[class_id] = 2*mean
        proto_out.bias.data[class_id] = -torch.matmul(mean, mean.t())/features.shape[1]

    else:
        features = model.net(data, features_out=True)

        mean = torch.sum(features, dim=0)/500
        proto_out.weight.data[class_id] = 2*mean
        proto_out.bias.data[class_id] = -torch.matmul(mean, mean.t())

    del data
    del features
    del mean

if SPIKING:
    model.net.snn[-1].W = proto_out
else:
    model.net.output = proto_out

del base_train_set
del train_loader

for eval_iter in range(NUM_REPEATS):
    print(f"Evaluation Iteration: {eval_iter}")
    ### Evaluation phase ###

    eval_model = copy.deepcopy(model)

    eval_accs = []
    query_accs = []
    act_sparsity = []
    syn_ops_dense = []
    syn_ops_macs = []
    syn_ops_acs = []

    # Get base test set for evaluation
    base_test_set = MSWC(root=ROOT, subset="base", procedure="testing")
    test_loader = DataLoader(base_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # Put the network in evaluation mode before benchmarking
    eval_model.net.eval()

    # Metrics
    static_metrics = ["model_size", "connection_sparsity"]
    workload_metrics = ["classification_accuracy", "activation_sparsity", "synaptic_operations"]

    # Define benchmark object
    benchmark_all_test = Benchmark(eval_model, metric_list=[static_metrics, workload_metrics], dataloader=test_loader, 
                        preprocessors=[to_device, encode, squeeze], postprocessors=[])

    benchmark_new_classes = Benchmark(eval_model, metric_list=[[],["classification_accuracy"]], dataloader=test_loader,
                        preprocessors=[to_device, encode, squeeze], postprocessors=[])

    # Define specific post-processing with masking on the base classes
    mask = torch.full((200,), float('inf')).to(device)
    mask[torch.arange(0,100, dtype=int)] = 0
    out_mask = lambda x: x - mask

    # Run session 0 benchmark on base classes
    print(f"Session: 0")

    pre_train_results = benchmark_all_test.run(postprocessors=[out_mask, F.softmax, out2pred, torch.squeeze])
    
    print("Base results:", pre_train_results)
    
    eval_accs.append(pre_train_results['classification_accuracy'])
    act_sparsity.append(pre_train_results['activation_sparsity'])
    syn_ops_dense.append(pre_train_results['synaptic_operations']['Dense'])
    syn_ops_macs.append(pre_train_results['synaptic_operations']['Effective_MACs'])
    syn_ops_acs.append(pre_train_results['synaptic_operations']['Effective_ACs'])
    
    print(f"The base accuracy is {eval_accs[-1]*100}%")

    # IncrementalFewShot Dataloader used in incremental mode to generate class-incremental sessions
    few_shot_dataloader = IncrementalFewShot(n_way=10, k_shot=EVAL_SHOTS, 
                                root = ROOT,
                                query_shots=100,
                                support_query_split=(100,100),
                                samples_per_class=200)

    # Iteration over incremental sessions
    for session, (support, query, query_classes) in enumerate(few_shot_dataloader):
        print(f"Session: {session+1}")

        # Define benchmark object
        benchmark_all_test = Benchmark(eval_model, metric_list=[static_metrics, workload_metrics], dataloader=test_loader, 
                            preprocessors=[to_device, encode, squeeze], postprocessors=[])

        benchmark_new_classes = Benchmark(eval_model, metric_list=[[],["classification_accuracy"]], dataloader=test_loader,
                            preprocessors=[to_device, encode, squeeze], postprocessors=[])
        
        cur_class = support[0][1].tolist()
        eval_model.net.cur_j = examples_per_class(cur_class, 200, 5)

        ### Computing new Prototypical Weights ###
        data = None
        
        for X_shot, y_shot in support:
            if data is None:
                data = X_shot
                target = y_shot
            else:
                data = torch.cat((data,X_shot), 0)
                target = torch.cat((target,y_shot), 0)

        data, target = encode((data.to(device), target.to(device)))
        data = data.squeeze()

        if SPIKING:
            features = eval_model.net.snn[0](data)
            features = eval_model.net.snn[1](features)

        else:
            features = eval_model.net(data, features_out=True)

        if SPIKING:
            for index, class_id in enumerate(query_classes[-10:]):
                mean = torch.sum(features[[i*10+index for i in range(EVAL_SHOTS)]], dim=[0,1])/EVAL_SHOTS
                eval_model.net.snn[-1].W.weight.data[class_id] = 2*mean
                eval_model.net.snn[-1].W.bias.data[class_id] = -torch.matmul(mean, mean.t())/(features.shape[1])
        else:
            for index, class_id in enumerate(query_classes[-10:]):
                mean = torch.sum(features[[i*10+index for i in range(EVAL_SHOTS)]], dim=0)/EVAL_SHOTS
                eval_model.net.output.weight.data[class_id] = 2*mean
                eval_model.net.output.bias.data[class_id] = -torch.matmul(mean, mean.t())

        ### Testing phase ###
        eval_model.net.eval()

        # Define session dataloaders for query and query + base_test samples
        query_loader = DataLoader(query, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
        
        full_session_test_set = ConcatDataset([base_test_set, query])
        full_session_test_loader = DataLoader(full_session_test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

        # Create a mask function to only consider accuracy on classes presented so far
        session_classes = torch.cat((torch.arange(0,100, dtype=int), torch.IntTensor(query_classes))) 
        mask = torch.full((200,), float('inf')).to(device)
        mask[session_classes] = 0
        out_mask = lambda x: x - mask

        # Run benchmark to evaluate accuracy of this specific session
        session_results = benchmark_all_test.run(dataloader = full_session_test_loader, postprocessors=[out_mask, F.softmax, out2pred, torch.squeeze])
        print("Session results:", session_results)
        
        eval_accs.append(session_results['classification_accuracy'])
        act_sparsity.append(session_results['activation_sparsity'])
        syn_ops_dense.append(session_results['synaptic_operations']['Dense'])
        syn_ops_macs.append(session_results['synaptic_operations']['Effective_MACs'])
        syn_ops_acs.append(session_results['synaptic_operations']['Effective_ACs'])
        print(f"Session accuracy: {session_results['classification_accuracy']*100} %")

        # Run benchmark on query classes only
        query_results = benchmark_new_classes.run(dataloader = query_loader, postprocessors=[out_mask, F.softmax, out2pred, torch.squeeze])
        print(f"Accuracy on new classes: {query_results['classification_accuracy']*100} %")
        query_accs.append(query_results['classification_accuracy'])

    all_evals.append(eval_accs)
    all_query.append(query_accs)
    all_act_sparsity.append(act_sparsity)
    all_syn_ops_dense.append(syn_ops_dense)
    all_syn_ops_macs.append(syn_ops_macs)
    all_syn_ops_acs.append(syn_ops_acs)

    mean_accuracy = np.mean(eval_accs)
    print(f"The total mean accuracy is {mean_accuracy*100}%")

    # Print all data
    print(f"Eval Accs: {eval_accs}")
    print(f"Query Accs: {query_accs}")
    print(f"Act Sparsity: {act_sparsity}")
    print(f"Syn Ops Dense: {syn_ops_dense}")
    print(f"Syn Ops MACs: {syn_ops_macs}")
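
Once all repeats have finished, lists such as `all_evals` hold one list of per-session values for each repeat. One possible way to turn them into aggregate statistics is sketched below with NumPy. The numbers are placeholders for illustration only, not benchmark results, and the sketch assumes every repeat recorded the same number of sessions.

```python
import numpy as np

# Placeholder accuracies, NOT real results: 3 repeats x 4 sessions
all_evals_demo = [
    [0.90, 0.80, 0.70, 0.60],
    [0.90, 0.82, 0.74, 0.66],
    [0.90, 0.84, 0.72, 0.63],
]

evals = np.asarray(all_evals_demo)       # shape: (repeats, sessions)
per_session_mean = evals.mean(axis=0)    # average over repeats
per_session_std = evals.std(axis=0)      # spread over repeats
overall_mean = per_session_mean.mean()   # average over sessions

for s, (m, sd) in enumerate(zip(per_session_mean, per_session_std)):
    print(f"Session {s}: {100 * m:.2f}% +/- {100 * sd:.2f}%")
print(f"Mean over sessions: {100 * overall_mean:.2f}%")
```

The same pattern applies to `all_query`, `all_act_sparsity`, and the synaptic-operation lists.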