In [None]:
from collections import OrderedDict
from torch import nn
from skorch import NeuralNet
from skorch.utils import to_numpy
from sklearn.base import TransformerMixin
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from moabb.paradigms import MotorImagery
from moabb.datasets import Cho2017
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation


import numpy as np
import torch
import pandas as pd

In [None]:
def remove_clf_layers(model: nn.Sequential):
    """
    Return a new ``nn.Sequential`` holding the direct children of ``model``
    whose names contain neither ``'classif'`` nor ``'softmax'``.

    The kept layers are the same module objects (not copies), in their
    original order. Tested on braindecode's EEGNetv4, Deep4Net
    (i.e. DeepConvNet), and EEGResNet.
    """
    dropped_markers = ('classif', 'softmax')
    kept_children = OrderedDict(
        (child_name, child)
        for child_name, child in model.named_children()
        if not any(marker in child_name for marker in dropped_markers)
    )
    return nn.Sequential(kept_children)


def freeze_model(model):
    """Put ``model`` in eval mode and disable gradients on all its parameters.

    The module is modified in place and returned, so the call can be chained.
    """
    model.eval()
    # Equivalent to setting ``p.requires_grad = False`` on every parameter.
    model.requires_grad_(False)
    return model


In [None]:
class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    """Scikit-learn transformer wrapping a frozen, pre-initialized torch module.

    ``fit`` is a no-op and ``transform`` only runs the wrapped module, so the
    network acts as a fixed feature extractor inside a sklearn ``Pipeline``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``skorch.NeuralNet`` (in this notebook the
        torch module is passed as the first positional argument).
    criterion : type, default=nn.MSELoss
        Required by ``NeuralNet`` but never used, since the net is never trained.
    unique_name : str or None, default=None
        Appended to ``repr(self)`` so that pipelines wrapping different
        pre-trained weights get distinct digests in MOABB (presumably MOABB
        derives its digest from the repr — confirm against MOABB version).
    """

    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        # Initialize up front: fit() below never runs, so transform() must
        # work on a net that has only been constructed.
        self.initialize()
        self.unique_name = unique_name

    def fit(self, X, y=None, **fit_params):
        """No-op: the wrapped module is frozen and never trained. Returns self."""
        return self  # do nothing

    def transform(self, X):
        """Run the wrapped module on ``X`` and return its output as a NumPy array."""
        X = self.infer(X)
        return to_numpy(X)

    def __repr__(self):
        """Return the skorch repr suffixed with ``unique_name`` (if set).

        ``unique_name`` defaults to None, and ``str + None`` raises TypeError,
        so a missing name contributes an empty suffix instead. ``getattr``
        also keeps repr usable if it is called before ``__init__`` has
        assigned ``unique_name``.
        """
        unique_name = getattr(self, 'unique_name', None)
        suffix = '' if unique_name is None else str(unique_name)
        return super().__repr__() + suffix
    
def flatten_batched(X):
    """Flatten every dimension of ``X`` except the first (batch) one.

    Parameters
    ----------
    X : array-like with ``shape`` and ``reshape`` (e.g. NumPy array or torch
        tensor) of shape ``(n, d1, d2, ...)``.

    Returns
    -------
    Same type as ``X`` with shape ``(n, d1 * d2 * ...)``. A 1-D input of
    shape ``(n,)`` becomes ``(n, 1)``.
    """
    # Compute the feature size explicitly rather than passing -1: reshape
    # cannot infer -1 for an empty batch (size-0 array) and raises.
    n_features = int(np.prod(X.shape[1:]))
    return X.reshape(X.shape[0], n_features)

In [None]:
import numpy as np
from sklearn.linear_model import LogisticRegression

class RandomLogisticRegression(LogisticRegression):
    """Logistic-regression "classifier" whose parameters are drawn at random.

    ``fit`` only reads the set of labels and the number of features; it then
    samples ``coef_`` and ``intercept_`` from a standard normal distribution.
    The training data therefore has no influence on predictions, which makes
    this estimator a chance-level baseline.

    Randomness follows the inherited ``random_state`` parameter:
    ``None`` (default) uses NumPy's global RNG (same stream as
    ``np.random.randn``), an int seeds a fresh ``RandomState``, and an
    existing ``RandomState`` / ``Generator`` is used as-is.
    """

    def fit(self, X, y, sample_weight=None):
        """Sample random parameters shaped like a fitted ``LogisticRegression``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Only the number of columns is used.
        y : array-like of shape (n_samples,)
            Only its unique values are used (stored in ``classes_``).
        sample_weight : ignored
            Accepted for signature compatibility with ``LogisticRegression.fit``.

        Returns
        -------
        self
        """
        self.classes_ = np.unique(y)
        n_classes = len(self.classes_)
        n_features = np.shape(X)[1]
        # sklearn stores a single coefficient row for binary problems.
        n_rows = 1 if n_classes == 2 else n_classes

        rng = self.random_state
        if rng is None:
            rng = np.random  # global RNG, identical to the previous behaviour
        elif not isinstance(rng, (np.random.RandomState, np.random.Generator)):
            rng = np.random.RandomState(rng)

        self.coef_ = rng.standard_normal((n_rows, n_features))
        self.intercept_ = rng.standard_normal(n_rows)
        # Lets sklearn's predict-time validation check the feature count.
        self.n_features_in_ = n_features

        return self

In [None]:
import pickle

# Download the EEGNetv4 artefacts pre-trained on Cho2017 from the Hugging Face Hub:
# a pickled dict with the module class and its constructor kwargs, and the
# pickled trained parameters.
path_kwargs = hf_hub_download(
    repo_id='PierreGtch/EEGNetv4',
    filename='EEGNetv4_Cho2017/kwargs.pkl',
)
path_params = hf_hub_download(
    repo_id='PierreGtch/EEGNetv4',
    filename='EEGNetv4_Cho2017/model-params.pkl',
)

# SECURITY: pickle.load (and torch.load, which also unpickles) can execute
# arbitrary code. Only run this if you trust the 'PierreGtch/EEGNetv4'
# repository; consider pinning `revision=` in hf_hub_download to a known commit.
with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']
module_kwargs = kwargs['module_kwargs']

# Rebuild the network, load the pre-trained weights on CPU, drop the
# classification layers and freeze what remains. `.double()` casts the
# parameters to float64 (presumably to match the dtype of the MOABB epochs —
# confirm).
torch_module = module_cls(**module_kwargs)
torch_module.load_state_dict(torch.load(path_params, map_location='cpu'))
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Integrate the frozen embedding in a scikit-learn pipeline:
# embedding -> flatten per-trial features -> randomly-initialized classifier.
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Cho2017')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', RandomLogisticRegression()),
])

In [15]:
# Two-class (left vs right hand) motor-imagery paradigm. fmin/fmax are the
# band-pass limits in Hz, the signal is resampled to 128 Hz, and epochs span
# tmin=0 to tmax=3 seconds.
paradigm = MotorImagery(
    n_classes=2,
    events=['left_hand', 'right_hand'],
    channels=['C3', 'Cz', 'C4'],  # Same as the ones used to pre-train the embedding
    resample=128,
    fmin=0.5,
    fmax=40,
    tmin=0,
    tmax=3,
)
datasets = [Cho2017()]

# Within-session evaluation on Cho2017; overwrite=True discards previously
# stored results for this suffix instead of reusing them.
evaluation = WithinSessionEvaluation(
    datasets=datasets,
    paradigm=paradigm,
    suffix='demo',
    overwrite=True,
)

In [None]:
# Run the evaluation defined above on the single named pipeline and show the
# per-subject/session results table.
results = evaluation.process(pipelines={'demo_pipeline': sklearn_pipeline})
results

Cho2017-WithinSession:   0%|          | 0/52 [00:00<?, ?it/s]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   2%|▏         | 1/52 [00:03<02:50,  3.35s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   4%|▍         | 2/52 [00:06<02:44,  3.29s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
Cho2017-WithinSession:   6%|▌         | 3/52 [00:09<02:37,  3.21s/it]

No hdf5_path provided, models will not be saved.


Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:   8%|▊         | 4/52 [00:12<02:33,  3.19s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  10%|▉         | 5/52 [00:16<02:29,  3.18s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  12%|█▏        | 6/52 [00:19<02:25,  3.17s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 120
 'right_hand': 120>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  13%|█▎        | 7/52 [00:23<02:33,  3.42s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  15%|█▌        | 8/52 [00:26<02:28,  3.37s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 120
 'right_hand': 120>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  17%|█▋        | 9/52 [00:30<02:29,  3.48s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  19%|█▉        | 10/52 [00:33<02:22,  3.38s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
Cho2017-WithinSession:  21%|██        | 11/52 [00:36<02:17,  3.35s/it]

No hdf5_path provided, models will not be saved.


Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  23%|██▎       | 12/52 [00:39<02:12,  3.31s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  25%|██▌       | 13/52 [00:43<02:09,  3.33s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  27%|██▋       | 14/52 [00:46<02:06,  3.33s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  29%|██▉       | 15/52 [00:49<02:04,  3.37s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  31%|███       | 16/52 [00:53<02:01,  3.38s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  33%|███▎      | 17/52 [00:56<01:59,  3.40s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  35%|███▍      | 18/52 [01:00<01:56,  3.42s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  37%|███▋      | 19/52 [01:03<01:51,  3.38s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  38%|███▊      | 20/52 [01:06<01:46,  3.33s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  40%|████      | 21/52 [01:09<01:41,  3.28s/it]Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  42%|████▏     | 22/52 [01:13<01:38,  3.29s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s23.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s23.mat'.
100%|███████████████████████████████████████| 205M/205M [00:00<00:00, 87.9GB/s]
SHA256 hash of downloaded file: c4a01839bc60d44eff43675492c3056c1338d0202f2cb6c72d37ecaaa95c81b0
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Cho2017-WithinSession:  44%|████▍     | 23/52 [02:13<09:48, 20.30s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s24.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s24.mat'.
100%|████████████████████████████████████████| 203M/203M [00:00<00:00, 171GB/s]
SHA256 hash of downloaded file: b28db34626242be4975e19bea742e486d8fd0c4f9eaefd3878077745816f4d35
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Trials demeaned and stacked with zero buffer to create continuous data -- edge effects present
 'left_hand': 100
 'right_hand': 100>
  warn(f"warnEpochs {epochs}")
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f

No hdf5_path provided, models will not be saved.


Cho2017-WithinSession:  46%|████▌     | 24/52 [04:01<21:44, 46.60s/it]Downloading data from 'https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s25.mat' to file '/Users/andresalvarezolmo/mne_data/MNE-gigadb-data/gigadb-datasets/live/pub/10.5524/100001_101000/100295/mat_data/s25.mat'.


In [None]:
# Average each metric over every row (subject/session) of `results`.
# NOTE(review): for a 2-class MotorImagery paradigm, standard MOABB scores
# with ROC-AUC, so `score` may not be accuracy — confirm which scorer this
# MOABB build uses before reporting the column as "accuracy".
metric_columns = {
    "accuracy": "score",
    "f1": "f1",
    "recall": "recall",
    "specificity": "specificity",
    "precision": "precision",
}
data = {label: [results[column].mean()] for label, column in metric_columns.items()}
df = pd.DataFrame(data)
df