In [12]:
from collections import OrderedDict
from torch import nn
from skorch import NeuralNet
from skorch.utils import to_numpy
from sklearn.base import TransformerMixin
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from moabb.paradigms import MotorImagery
from moabb.datasets import Cho2017
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation

import torch
import pandas as pd

In [13]:
def remove_clf_layers(model: nn.Sequential):
    """
    Strip the classification head from a braindecode model.

    Every direct child whose name contains ``'classif'`` or ``'softmax'`` is
    dropped. The remaining children are re-assembled, in their original order
    and under their original names, into a new ``nn.Sequential``.
    Tested on EEGNetv4, Deep4Net (i.e. DeepConvNet), and EEGResNet.

    Parameters
    ----------
    model : nn.Sequential
        Model whose direct children are filtered. The kept child modules are
        shared with (not copied into) the returned container.

    Returns
    -------
    nn.Sequential
        The model without its classification layers.
    """
    dropped_markers = ('classif', 'softmax')
    kept_children = OrderedDict(
        (child_name, child)
        for child_name, child in model.named_children()
        if not any(marker in child_name for marker in dropped_markers)
    )
    return nn.Sequential(kept_children)


def freeze_model(model):
    """
    Disable gradients on all parameters of ``model`` and switch it to eval mode.

    The module is modified in place and also returned, so the call can be
    chained with other module methods.

    Parameters
    ----------
    model : torch.nn.Module
        Module to freeze.

    Returns
    -------
    torch.nn.Module
        The same ``model`` object, now frozen.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    model.eval()
    return model


In [14]:
class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    """
    Scikit-learn transformer around a skorch ``NeuralNet`` whose module is
    intended to be pretrained and frozen.

    ``fit`` is a no-op. ``transform`` feeds ``X`` through the wrapped module
    and returns a numpy array. This lets the network act as a fixed feature
    extractor inside a ``sklearn.pipeline.Pipeline``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``skorch.NeuralNet`` (e.g. the torch module).
    criterion : default=nn.MSELoss
        Passed on to ``NeuralNet``. It should be unused, because ``fit``
        never trains.
    unique_name : str or None, default=None
        Appended to ``repr(self)``. It is needed for a unique digest in MOABB.
        When None, the plain ``NeuralNet`` repr is used.
    """

    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        # Initialize right away: there is no training step that would
        # otherwise do it before transform() is called.
        self.initialize()
        self.unique_name = unique_name

    def fit(self, X, y=None, **fit_params):
        """No-op: the wrapped network is not trained. Returns ``self``."""
        return self  # do nothing

    def transform(self, X):
        """Pass ``X`` to ``NeuralNet.infer`` in a single call and return the output as numpy."""
        X = self.infer(X)
        return to_numpy(X)

    def __repr__(self):
        base_repr = super().__repr__()
        # unique_name defaults to None. Concatenating None to a str raised a
        # TypeError, which made repr() fail for a default-constructed instance.
        if self.unique_name is None:
            return base_repr
        return base_repr + str(self.unique_name)
    
def flatten_batched(X):
    """
    Collapse every axis after the first into a single feature axis.

    Parameters
    ----------
    X : array-like with ``shape`` and ``reshape`` (e.g. a numpy array)
        Axis 0 is kept as the batch axis.

    Returns
    -------
    Array of shape ``(X.shape[0], prod(X.shape[1:]))``.
    """
    batch_size = X.shape[0]
    return X.reshape(batch_size, -1)

In [15]:
import numpy as np
from sklearn.linear_model import LogisticRegression

class RandomLogisticRegression(LogisticRegression):
    """
    Chance-level baseline with the interface of a fitted ``LogisticRegression``.

    ``fit`` learns nothing from the data. It records the class labels and
    draws ``coef_`` / ``intercept_`` from a standard normal distribution. The
    inherited prediction methods then score samples with these random weights.

    Randomness is taken from the inherited ``random_state`` parameter:
    ``None`` (the default) draws from numpy's global RNG, as before, while an
    int or ``RandomState`` makes the baseline reproducible.
    """

    def fit(self, X, y):
        """
        Draw random weights sized for ``X`` and the classes found in ``y``.

        Parameters
        ----------
        X : array-like with a 2-D ``shape``
            Only its number of columns (``X.shape[1]``) is used.
        y : array-like
            Only its distinct values are used.

        Returns
        -------
        self
        """
        # Local import keeps this cell self-contained.
        from sklearn.utils import check_random_state

        # check_random_state(None) returns numpy's global RandomState
        # singleton, so the default draws are identical to np.random.randn.
        rng = check_random_state(self.random_state)

        self.classes_ = np.unique(y)
        n_classes = len(self.classes_)
        n_features = X.shape[1]
        # Record the input width like a fitted sklearn estimator, so the
        # inherited predict path can validate feature counts.
        self.n_features_in_ = n_features

        # Binary problems use a single decision function (one row); multiclass
        # uses one row per class.
        n_outputs = 1 if n_classes == 2 else n_classes
        self.coef_ = rng.randn(n_outputs, n_features)
        self.intercept_ = rng.randn(n_outputs)

        return self

In [16]:
import pickle

# Download the pre-trained model ('EEGNetv4_Lee2019_MI') from the Hugging Face
# Hub: one pickle with the constructor arguments, one with the weights.
HF_REPO_ID = 'PierreGtch/EEGNetv4'
HF_MODEL_DIR = 'EEGNetv4_Lee2019_MI'

path_kwargs = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=f'{HF_MODEL_DIR}/kwargs.pkl',
)
path_params = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=f'{HF_MODEL_DIR}/model-params.pkl',
)

# SECURITY: pickle.load executes arbitrary code embedded in the file. Only run
# this against a repository you trust. Consider pinning `revision=` in
# hf_hub_download so the downloaded file cannot silently change.
with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']        # model class to instantiate
module_kwargs = kwargs['module_kwargs']  # its constructor arguments

# Load the model with pre-trained weights.
# NOTE(review): torch.load is also pickle-based. If the installed torch
# supports it, `weights_only=True` restricts loading to tensors/containers.
torch_module = module_cls(**module_kwargs)
torch_module.load_state_dict(torch.load(path_params, map_location='cpu'))
# Drop the classification head, freeze the rest, and cast to float64
# (presumably to match the dtype of the data MOABB yields; confirm).
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Integrate the model in a Scikit-learn pipeline.
# NOTE(review): the final step is RandomLogisticRegression, whose weights are
# drawn at random and never trained. This pipeline therefore measures a
# chance-level baseline on top of the embedding, not the embedding's quality.
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Lee2019')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', RandomLogisticRegression()),
])



In [17]:
# Seed passed to MOABB's evaluation as random_state so its cross-validation
# splits are reproducible. (RandomLogisticRegression draws its own weights
# separately.)
SEED = 42

paradigm = MotorImagery(
    channels=['C3', 'Cz', 'C4'],  # Same as the ones used to pre-train the embedding
    events=['left_hand', 'right_hand'],
    n_classes=2,
    fmin=0.5,      # band-pass low edge, Hz
    fmax=40,       # band-pass high edge, Hz
    tmin=0,        # epoch start, s
    tmax=3,        # epoch end, s
    resample=128,  # Hz
)
datasets = [Cho2017()]
evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    overwrite=True,  # overwrite previously stored results for this evaluation
    suffix='demo',
    random_state=SEED,
)

In [None]:
# Run the within-session evaluation of the pipeline, keyed by its display name.
results = evaluation.process(pipelines={'demo_pipeline': sklearn_pipeline})

Cho2017-WithinSession:   0%|          | 0/52 [00:00<?, ?it/s]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   2%|▏         | 1/52 [00:03<03:16,  3.86s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   4%|▍         | 2/52 [00:07<02:56,  3.54s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   6%|▌         | 3/52 [00:10<02:52,  3.51s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   8%|▊         | 4/52 [00:14<02:49,  3.54s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  10%|▉         | 5/52 [00:17<02:43,  3.49s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  12%|█▏        | 6/52 [00:21<02:39,  3.46s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 120
 'right_hand': 120>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  13%|█▎        | 7/52 [00:25<02:55,  3.91s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s08.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s08.mat'.

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

In [None]:
# Display the raw results returned by evaluation.process.
results

In [None]:

# Summary table: mean of each metric over all rows of `results`.
# NOTE(review): in upstream MOABB the 'score' column holds the paradigm's
# default scoring, which for a 2-class MotorImagery paradigm is ROC AUC rather
# than accuracy. Confirm 'score' really is accuracy before labelling it so.
METRIC_COLUMNS = {  # display name -> column in `results`
    "accuracy": "score",
    "f1": "f1",
    "recall": "recall",
    "specificity": "specificity",
    "precision": "precision",
}
data = {metric: [results[column].mean()] for metric, column in METRIC_COLUMNS.items()}
df = pd.DataFrame(data)
df

   accuracy        f1  recall  specificity  precision
0     0.515  0.497032   0.515        0.515   0.507688
