In [49]:
from collections import OrderedDict
from torch import nn
from skorch import NeuralNet
from skorch.utils import to_numpy
from sklearn.base import TransformerMixin
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from moabb.paradigms import MotorImagery
from moabb.datasets import Shin2017B
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation

import torch
import pandas as pd

In [50]:
def remove_clf_layers(model: nn.Sequential, drop_keywords=('classif', 'softmax')):
    """
    Remove the classification layers from braindecode models.

    Tested on EEGNetv4, Deep4Net (i.e. DeepConvNet), and EEGResNet.

    A direct child of ``model`` is dropped when its registered name contains
    any of ``drop_keywords`` (substring match). Every other child is kept,
    in its original order and under its original name.

    Parameters
    ----------
    model : nn.Sequential
        Model whose direct children are filtered. Only the top level is
        inspected; nested submodules are not searched.
    drop_keywords : iterable of str, default=('classif', 'softmax')
        Substrings identifying the children to drop. Extend this if a model
        version names its classification head differently.

    Returns
    -------
    nn.Sequential
        A new container holding the kept layer objects themselves (they are
        shared with ``model``, not copied).
    """
    # Materialize once so a generator argument is not exhausted after the first layer.
    drop_keywords = tuple(drop_keywords)
    kept_layers = [
        (name, layer)
        for name, layer in model.named_children()
        if not any(keyword in name for keyword in drop_keywords)
    ]
    return nn.Sequential(OrderedDict(kept_layers))


def freeze_model(model):
    """Turn ``model`` into a fixed feature extractor, in place.

    Every parameter stops tracking gradients and the module (with all of its
    submodules) is switched to evaluation mode. The same object is returned
    so the call can be chained.
    """
    frozen = model.eval()
    frozen.requires_grad_(False)
    return frozen


In [51]:
class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    """Scikit-learn transformer that embeds data with a frozen torch module.

    The module is never trained: ``fit`` is a no-op and ``transform`` only
    runs a forward pass. This lets a pre-trained network act as a feature
    extractor inside a scikit-learn ``Pipeline``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to ``skorch.NeuralNet``; the first positional argument is
        the torch module to wrap.
    criterion : default=nn.MSELoss
        Passed to ``NeuralNet`` but unused, since nothing is trained.
    unique_name : str or None, default=None
        Appended to ``repr(self)`` so that MOABB computes a unique digest for
        pipelines wrapping different pre-trained weights. ``None`` leaves the
        repr unchanged.
    """

    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        # Set before initialize() so __repr__ is usable at any point afterwards.
        self.unique_name = unique_name
        self.initialize()

    def fit(self, X, y=None, **fit_params):
        """No-op: the wrapped module is already trained. Returns ``self``."""
        return self  # do nothing

    def transform(self, X):
        """Return the wrapped module's output for ``X`` as a numpy array.

        The module is switched to eval mode and run without autograd, so
        dropout/batch-norm behave as at inference time and no computation
        graph is kept in memory.
        """
        self.module_.eval()
        with torch.no_grad():
            X = self.infer(X)
        return to_numpy(X)

    def __repr__(self):
        base = super().__repr__()
        # unique_name defaults to None; concatenating str + None would raise TypeError.
        unique_name = getattr(self, 'unique_name', None)
        if unique_name is None:
            return base
        return base + str(unique_name)
    
def flatten_batched(X):
    """Collapse every axis after the first one.

    An input of shape ``(n, d1, d2, ...)`` comes back with shape
    ``(n, d1 * d2 * ...)``, i.e. one flat feature vector per leading entry.
    """
    n_rows = X.shape[0]
    flat = X.reshape(n_rows, -1)
    return flat

In [52]:
import pickle

# Pre-trained EEGNetv4 published on the Hugging Face Hub.
HF_REPO_ID = 'PierreGtch/EEGNetv4'
HF_MODEL_DIR = 'EEGNetv4_Lee2019_MI'


def download_pretrained_file(filename, repo_id=HF_REPO_ID, model_dir=HF_MODEL_DIR):
    """Download ``<model_dir>/<filename>`` from the Hub repo and return its local path."""
    return hf_hub_download(repo_id=repo_id, filename=f'{model_dir}/{filename}')


# download the model from the hub:
path_kwargs = download_pretrained_file('kwargs.pkl')
path_params = download_pretrained_file('model-params.pkl')

# SECURITY NOTE: pickle.load (and torch.load on torch versions where
# weights_only defaults to False) can execute arbitrary code embedded in the
# file. Only load these files from a repository you trust; passing a pinned
# `revision=` to hf_hub_download would also make the download reproducible.
with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']
module_kwargs = kwargs['module_kwargs']

# load the model with pre-trained weights:
torch_module = module_cls(**module_kwargs)
torch_module.load_state_dict(torch.load(path_params, map_location='cpu'))
# Cast to float64 — presumably to match the dtype of the epochs MOABB
# feeds the pipeline (TODO confirm).
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Integrate the model in a Scikit-learn pipeline:
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Lee2019')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', LogisticRegression()),
])



In [54]:
from moabb.datasets import Cho2017

# Binary left- vs right-hand motor imagery, band-passed to 0.5-40 Hz,
# epochs from tmin=0 to tmax=3, resampled to 128 Hz.
paradigm = MotorImagery(
    n_classes=2,
    events=['left_hand', 'right_hand'],
    channels=['C3', 'Cz', 'C4'],  # Same as the ones used to pre-train the embedding
    fmin=0.5,
    fmax=40,
    tmin=0,
    tmax=3,
    resample=128
)
datasets = [Cho2017()]

# Evaluate within each session; overwrite=True recomputes instead of reusing
# previously stored results.
evaluation = WithinSessionEvaluation(
    datasets=datasets,
    paradigm=paradigm,
    suffix='demo',
    overwrite=True,
)

In [None]:
# The pipeline is registered as 'demo_pipeline'; that name appears in the `pipeline` column of `results`.
results = evaluation.process(pipelines={'demo_pipeline': sklearn_pipeline})

Cho2017-WithinSession:   0%|          | 0/52 [00:00<?, ?it/s]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   2%|▏         | 1/52 [00:03<03:07,  3.67s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s02.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s02.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   4%|▍         | 2/52 [00:41<19:37, 23.56s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s03.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s03.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   6%|▌         | 3/52 [01:30<28:42, 35.15s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s04.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s04.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   8%|▊         | 4/52 [02:01<26:47, 33.49s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s05.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s05.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  10%|▉         | 5/52 [02:28<24:38, 31.45s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s06.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s06.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  12%|█▏        | 6/52 [03:00<24:04, 31.40s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s07.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s07.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  13%|█▎        | 7/52 [03:36<24:40, 32.90s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s08.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s08.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  15%|█▌        | 8/52 [04:08<24:05, 32.85s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s09.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s09.mat'.




[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  17%|█▋        | 9/52 [04:45<24:23, 34.03s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s10.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s10.mat'.


In [None]:
# One row per evaluated (subject, session); preview instead of dumping every row.
results.head(10)

Unnamed: 0,score,f1,recall,specificity,precision,time,samples,subject,session,channels,n_sessions,dataset,pipeline
0,0.524762,0.521941,0.524242,0.761108,0.528578,0.051888,179.0,1,0,3,3,Zhou2016,demo_pipeline
1,0.62,0.619226,0.62,0.81,0.632209,0.022646,150.0,1,1,3,3,Zhou2016,demo_pipeline
2,0.633333,0.614599,0.633333,0.816667,0.648614,0.02706,150.0,1,2,3,3,Zhou2016,demo_pipeline
3,0.573333,0.561358,0.573333,0.786667,0.564913,0.036932,150.0,2,0,3,3,Zhou2016,demo_pipeline
4,0.666667,0.665119,0.666667,0.833333,0.689021,0.065634,135.0,2,1,3,3,Zhou2016,demo_pipeline
5,0.626667,0.615576,0.626667,0.813333,0.640647,0.034364,150.0,2,2,3,3,Zhou2016,demo_pipeline
6,0.68,0.669079,0.68,0.84,0.689935,0.030962,150.0,3,0,3,3,Zhou2016,demo_pipeline
7,0.596129,0.581496,0.598182,0.79768,0.587382,0.021299,151.0,3,1,3,3,Zhou2016,demo_pipeline
8,0.653333,0.653312,0.653333,0.826667,0.657391,0.020853,150.0,3,2,3,3,Zhou2016,demo_pipeline
9,0.785185,0.780627,0.785185,0.892593,0.784161,0.020561,135.0,4,0,3,3,Zhou2016,demo_pipeline


In [None]:

# Average each metric over all evaluated (subject, session) rows into a one-row summary.
# NOTE(review): MOABB's `score` column uses `paradigm.scoring`, which for a
# 2-class MotorImagery paradigm is likely ROC AUC rather than accuracy —
# confirm before reporting it under the "accuracy" label.
METRIC_COLUMNS = ["score", "f1", "recall", "specificity", "precision"]
df = (
    results[METRIC_COLUMNS]
    .mean()
    .rename({"score": "accuracy"})
    .to_frame()
    .T
    .reset_index(drop=True)
)
df

   accuracy        f1    recall  specificity  precision
0  0.657173  0.650345  0.657301     0.828448   0.663703
