In [1]:
from collections import OrderedDict
from torch import nn
from skorch import NeuralNet
from skorch.utils import to_numpy
from sklearn.base import TransformerMixin
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from moabb.paradigms import MotorImagery
from moabb.datasets import Zhou2016
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation
from sklearn.metrics import make_scorer, accuracy_score, f1_score

import torch
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def remove_clf_layers(model: nn.Sequential):
    """
    Strip the classification head from a braindecode model.

    Every direct child of ``model`` whose name contains ``'classif'`` or
    ``'softmax'`` is dropped. The remaining children are kept in their
    original order and under their original names inside a new
    ``nn.Sequential``; the layer objects themselves are reused, not copied.

    Tested on EEGNetv4, Deep4Net (i.e. DeepConvNet), and EEGResNet.
    """
    head_markers = ('classif', 'softmax')
    kept_layers = OrderedDict(
        (child_name, child)
        for child_name, child in model.named_children()
        if not any(marker in child_name for marker in head_markers)
    )
    return nn.Sequential(kept_layers)


def freeze_model(model):
    """
    Switch ``model`` to evaluation mode and disable gradients on its parameters.

    The module is modified in place and the same object is returned, so the
    call can be chained.
    """
    model.eval()
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model


In [3]:
class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    """
    Expose an already-trained skorch ``NeuralNet`` as a scikit-learn transformer.

    ``fit`` does nothing, and ``transform`` returns the result of the net's
    ``infer`` call converted to NumPy, so an instance can serve as a fixed
    feature-extraction step inside a ``Pipeline``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``NeuralNet.__init__``.
    criterion
        Forwarded to ``NeuralNet.__init__``; ``fit`` never trains, so it
        should be unused.
    unique_name : str or None
        Text appended to ``repr(self)``; needed for a unique digest in MOABB.
        With ``None`` (the default) nothing is appended.
    """

    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        # fit() is a no-op, so the net is initialized once here instead.
        self.initialize()
        self.unique_name = unique_name

    def fit(self, X, y=None, **fit_params):
        """Return ``self`` unchanged; ``X``, ``y`` and ``fit_params`` are ignored."""
        return self  # do nothing

    def transform(self, X):
        """Run ``X`` through ``self.infer`` and return the result as NumPy."""
        X = self.infer(X)
        return to_numpy(X)

    def __repr__(self):
        """Return the parent repr followed by ``unique_name`` when one is set."""
        base_repr = super().__repr__()
        # unique_name defaults to None; `str + None` would raise TypeError.
        # getattr also covers a repr requested before __init__ has finished.
        suffix = getattr(self, 'unique_name', None)
        if suffix is None:
            return base_repr
        return base_repr + suffix
    
def flatten_batched(X):
    """
    Collapse every axis after the first into a single axis.

    The first axis of ``X`` is kept as is and ``X.reshape`` does the work,
    so the result has shape ``(X.shape[0], -1)``.
    """
    n_rows = X.shape[0]
    return X.reshape(n_rows, -1)

In [4]:
import pickle

# Location of the pre-trained checkpoint on the Hugging Face Hub.
HF_REPO_ID = 'PierreGtch/EEGNetv4'
HF_MODEL_DIR = 'EEGNetv4_Lee2019_MI'

# download the model from the hub:
path_kwargs = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=HF_MODEL_DIR + '/kwargs.pkl',
)
path_params = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=HF_MODEL_DIR + '/model-params.pkl',
)
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only run this against a hub repository you trust.
with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']
module_kwargs = kwargs['module_kwargs']

# load the model with pre-trained weights:
torch_module = module_cls(**module_kwargs)
# NOTE(review): torch.load is pickle-based as well; consider passing
# weights_only=True if the installed torch version supports it.
torch_module.load_state_dict(torch.load(path_params, map_location='cpu'))
# Drop the classification head, freeze the remaining layers, and cast them to
# double precision.
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Integrate the model in a Scikit-learn pipeline:
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Lee2019')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', LogisticRegression()),
])



In [5]:
# Metrics wrapped as scikit-learn scorers, keyed by metric name.
scoring = dict(
    accuracy=make_scorer(accuracy_score),
    f1=make_scorer(f1_score, average='weighted'),
)

In [6]:
# Three-class motor-imagery paradigm restricted to three channels, with the
# filter band, epoch window and resampling rate given below.
paradigm = MotorImagery(
    channels=['C3', 'Cz', 'C4'],  # Same as the ones used to pre-train the embedding
    events=['left_hand', 'right_hand', 'feet'],
    n_classes=3,
    fmin=0.5,
    fmax=40,
    tmin=0,
    tmax=3,
    resample=128
)

In [7]:
# Within-session evaluation of the paradigm on the Zhou2016 dataset.
datasets = [Zhou2016()]
evaluation = WithinSessionEvaluation(
    datasets=datasets, paradigm=paradigm, suffix='demo', overwrite=True
)

In [None]:
# Run the evaluation on the single pipeline, registered as 'demo_pipeline'.
results = evaluation.process(pipelines={'demo_pipeline': sklearn_pipeline})

In [9]:
# Display the table returned by evaluation.process.
results

Unnamed: 0,score,f1,recall,specificity,precision,time,samples,subject,session,channels,n_sessions,dataset,pipeline
0,0.535714,0.530302,0.534849,0.767184,0.540674,0.036486,179.0,1,0,3,3,Zhou2016,demo_pipeline
1,0.58,0.573749,0.58,0.79,0.591464,0.022084,150.0,1,1,3,3,Zhou2016,demo_pipeline
2,0.653333,0.63651,0.653333,0.826667,0.66032,0.022127,150.0,1,2,3,3,Zhou2016,demo_pipeline
3,0.566667,0.552825,0.566667,0.783333,0.592555,0.020669,150.0,2,0,3,3,Zhou2016,demo_pipeline
4,0.674074,0.675479,0.674074,0.837037,0.69596,0.023178,135.0,2,1,3,3,Zhou2016,demo_pipeline
5,0.626667,0.61296,0.626667,0.813333,0.639764,0.021616,150.0,2,2,3,3,Zhou2016,demo_pipeline
6,0.666667,0.653769,0.666667,0.833333,0.67864,0.020088,150.0,3,0,3,3,Zhou2016,demo_pipeline
7,0.596989,0.595606,0.598182,0.797296,0.61272,0.025167,151.0,3,1,3,3,Zhou2016,demo_pipeline
8,0.64,0.635542,0.64,0.82,0.635907,0.021002,150.0,3,2,3,3,Zhou2016,demo_pipeline
9,0.777778,0.77085,0.777778,0.888889,0.801849,0.021314,135.0,4,0,3,3,Zhou2016,demo_pipeline


In [10]:

# Average each metric column over all rows of `results`; the 'score' column
# is reported under the label "accuracy".
metric_sources = {
    "accuracy": "score",
    "f1": "f1",
    "recall": "recall",
    "specificity": "specificity",
    "precision": "precision",
}
data = {
    label: [results[column].mean()]
    for label, column in metric_sources.items()
}
df = pd.DataFrame(data)
print(df)

   accuracy        f1    recall  specificity  precision
0  0.652046  0.644337  0.652073     0.825867   0.664542
