In [1]:
from collections import OrderedDict
from torch import nn
from skorch import NeuralNet
from skorch.utils import to_numpy
from sklearn.base import TransformerMixin
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from moabb.paradigms import MotorImagery
from moabb.datasets import Zhou2016
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation

import torch
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def remove_clf_layers(model: nn.Sequential):
    """
    Build a copy of the layer list of ``model`` without its classification head.

    Every direct child whose name contains ``'classif'`` or ``'softmax'`` is
    dropped; all other children keep their original order and names. The kept
    layers are the same module objects as in ``model`` (they are not copied),
    so the result shares weights with the original network.
    Tested on EEGNetv4, Deep4Net (i.e. DeepConvNet), and EEGResNet.
    """
    head_markers = ('classif', 'softmax')
    kept_children = OrderedDict(
        (child_name, child)
        for child_name, child in model.named_children()
        if not any(marker in child_name for marker in head_markers)
    )
    return nn.Sequential(kept_children)


def freeze_model(model):
    """
    Freeze ``model`` in place for pure inference.

    Switches the module to evaluation mode and turns off gradient tracking for
    every one of its parameters (recursively, through all submodules).

    Returns the very same ``model`` object so the call can be chained.
    """
    model.requires_grad_(False)  # flips requires_grad on every parameter
    return model.eval()  # nn.Module.eval() returns the module itself


In [3]:
class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    """
    Scikit-learn transformer that runs a fixed (already-trained) torch module.

    The skorch net is initialized right away in ``__init__`` and is never
    trained: ``fit`` is a no-op and ``transform`` returns the module's output
    on ``X`` as a NumPy array. This lets a frozen feature extractor be used as
    a step of a scikit-learn ``Pipeline``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``skorch.NeuralNet`` (e.g. the torch module).
    criterion : loss class, default ``nn.MSELoss``
        Only supplied because ``NeuralNet`` takes one; it should be unused
        since the net is never trained.
    unique_name : str or None, default None
        Appended to ``repr(self)``; needed so MOABB computes a unique digest
        for each frozen model. ``None`` appends nothing.
    """

    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        # Initialize immediately: fit() never does it, and transform() needs it.
        self.initialize()
        self.unique_name = unique_name

    def fit(self, X, y=None, **fit_params):
        """No-op: the wrapped module is frozen. Returns ``self``."""
        return self  # do nothing

    def transform(self, X):
        """Pass all of ``X`` through skorch's ``infer`` in one call and return a NumPy array."""
        X = self.infer(X)
        return to_numpy(X)

    def __repr__(self):
        # unique_name defaults to None; concatenating None to a str raises
        # TypeError, so a missing name contributes an empty suffix instead.
        suffix = '' if self.unique_name is None else str(self.unique_name)
        return super().__repr__() + suffix
    
def flatten_batched(X):
    """
    Flatten every sample of a batch into one feature vector.

    Parameters
    ----------
    X : array-like exposing ``.shape`` and ``.reshape`` (e.g. a NumPy array or
        torch tensor). The first axis indexes samples; any number of trailing
        axes is allowed.

    Returns
    -------
    ``X`` reshaped to ``(X.shape[0], prod(X.shape[1:]))``.
    """
    import math

    # Spell out the feature count instead of passing -1: reshape(0, -1) is
    # ambiguous for an empty batch and raises.
    n_features = math.prod(X.shape[1:])
    return X.reshape(X.shape[0], n_features)

In [None]:
import numpy as np
from sklearn.linear_model import LogisticRegression

class RandomLogisticRegression(LogisticRegression):
    """
    Chance-level baseline: a ``LogisticRegression`` whose "fitted" parameters
    are random draws rather than learned values.

    ``fit`` only records the class labels and draws ``coef_`` and
    ``intercept_`` from a standard normal distribution. Prediction methods are
    inherited unchanged, so the resulting classifier ignores the training data.

    The draws honour the inherited ``random_state`` parameter:
    ``None`` (the default) uses NumPy's global RNG, an int seeds a fresh
    ``RandomState``, and a ``RandomState``/``Generator`` instance is used as-is.
    """

    def fit(self, X, y, sample_weight=None):
        """
        Draw random parameters shaped for ``X`` and the classes found in ``y``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)
        sample_weight : ignored
            Accepted only for signature compatibility with
            ``LogisticRegression.fit``.

        Returns
        -------
        self
        """
        self.classes_ = np.unique(y)
        n_classes = len(self.classes_)
        n_features = X.shape[1]
        # Binary problems use a single decision row, like LogisticRegression.
        n_rows = 1 if n_classes == 2 else n_classes

        rng = self._random_source()
        # Same draw order as before: coefficients first, then intercepts.
        self.coef_ = rng.standard_normal((n_rows, n_features))
        self.intercept_ = rng.standard_normal((n_rows,))
        self.n_features_in_ = n_features

        return self

    def _random_source(self):
        """Resolve ``self.random_state`` into an object with ``standard_normal``."""
        seed = self.random_state
        if seed is None:
            # Module-level functions draw from NumPy's global RNG.
            return np.random
        if isinstance(seed, (np.random.RandomState, np.random.Generator)):
            return seed
        return np.random.RandomState(seed)

In [12]:
import pickle

# Download the pre-trained model artefacts from the Hugging Face Hub.
# The folder name 'EEGNetv4_Lee2019_MI' suggests EEGNetv4 pre-trained on
# Lee2019 motor imagery — confirm on the model card.
path_kwargs = hf_hub_download(
    repo_id='PierreGtch/EEGNetv4',
    filename='EEGNetv4_Lee2019_MI/kwargs.pkl',
)
path_params = hf_hub_download(
    repo_id='PierreGtch/EEGNetv4',
    filename='EEGNetv4_Lee2019_MI/model-params.pkl',
)
# SECURITY NOTE(review): pickle.load can execute arbitrary code from the file;
# only run this against a Hub repo you trust.
# The unpickled dict must provide 'module_cls' (the model class) and
# 'module_kwargs' (the keyword arguments to construct it).
with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']
module_kwargs = kwargs['module_kwargs']

# Rebuild the architecture and load the pre-trained weights onto the CPU.
# NOTE(review): torch.load also unpickles; consider weights_only=True if the
# installed torch version supports it.
torch_module = module_cls(**module_kwargs)
torch_module.load_state_dict(torch.load(path_params, map_location='cpu'))
# Drop the classification head, freeze the remaining layers, and cast them to
# float64 (presumably to match the float64 epochs fed in by MOABB — confirm).
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Scikit-learn pipeline: frozen embedding -> one flat feature vector per
# trial -> classifier with random (untrained) weights as a baseline.
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Lee2019')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', RandomLogisticRegression()),
])

In [13]:
# Motor-imagery paradigm shaped to match the embedding's expected input:
# 3 channels, 3 classes, 0.5–40 Hz filter band, 0–3 s epoch window,
# resampled to 128 Hz.
paradigm = MotorImagery(
    channels=['C3', 'Cz', 'C4'],  # Same as the ones used to pre-train the embedding
    events=['left_hand', 'right_hand', 'feet'],
    n_classes=3,
    fmin=0.5,
    fmax=40,
    tmin=0,
    tmax=3,
    resample=128
)
datasets = [Zhou2016()]
evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    overwrite=True,
    suffix='demo',
    # Seed MOABB's cross-validation splitting so the fold assignment, and
    # therefore the reported scores, are reproducible across runs.
    random_state=42,
)

In [None]:
# The classifier in `sklearn_pipeline` draws its random weights from NumPy's
# global RNG; seed it so the chance-level baseline scores are reproducible.
np.random.seed(42)
results = evaluation.process(pipelines=dict(demo_pipeline=sklearn_pipeline))

In [15]:
# One row per (subject, session) with the evaluation's scores for `demo_pipeline`.
results

Unnamed: 0,score,f1,recall,specificity,precision,time,samples,subject,session,channels,n_sessions,dataset,pipeline
0,0.295873,0.215698,0.29596,0.646881,0.181648,0.033342,179.0,1,0,3,3,Zhou2016,demo_pipeline
1,0.32,0.280766,0.32,0.66,0.282301,0.028085,150.0,1,1,3,3,Zhou2016,demo_pipeline
2,0.373333,0.323245,0.373333,0.686667,0.373028,0.022446,150.0,1,2,3,3,Zhou2016,demo_pipeline
3,0.38,0.301229,0.38,0.69,0.276819,0.023878,150.0,2,0,3,3,Zhou2016,demo_pipeline
4,0.237037,0.173103,0.237037,0.618519,0.146648,0.023591,135.0,2,1,3,3,Zhou2016,demo_pipeline
5,0.34,0.287576,0.34,0.67,0.364279,0.026319,150.0,2,2,3,3,Zhou2016,demo_pipeline
6,0.233333,0.162743,0.233333,0.616667,0.202112,0.029106,150.0,3,0,3,3,Zhou2016,demo_pipeline
7,0.252688,0.201869,0.252727,0.626836,0.177861,0.023882,151.0,3,1,3,3,Zhou2016,demo_pipeline
8,0.3,0.234055,0.3,0.65,0.304353,0.031908,150.0,3,2,3,3,Zhou2016,demo_pipeline
9,0.288889,0.200733,0.288889,0.644444,0.18556,0.024413,135.0,4,0,3,3,Zhou2016,demo_pipeline


In [16]:

# Average each metric over every (subject, session) row of `results`.
# The main metric column 'score' is reported under the label 'accuracy'.
summary_columns = {
    "accuracy": "score",
    "f1": "f1",
    "recall": "recall",
    "specificity": "specificity",
    "precision": "precision",
}
data = {label: [results[column].mean()] for label, column in summary_columns.items()}
df = pd.DataFrame(data)
# Last expression -> rich (HTML) table display instead of plain-text print().
df

   accuracy        f1    recall  specificity  precision
0  0.303985  0.237529  0.303996     0.651946   0.249544
