In [2]:
import pickle
from collections import OrderedDict

import pandas as pd
import torch
from torch import nn
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from moabb.datasets import Zhou2016
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation
from moabb.paradigms import MotorImagery
from sklearn.base import TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from skorch import NeuralNet
from skorch.utils import to_numpy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def remove_clf_layers(model: nn.Sequential):
    """
    Build a new ``nn.Sequential`` from the direct children of ``model``,
    dropping every child whose name contains 'classif' or 'softmax'.

    Intended to strip the classification head from braindecode models;
    tested on EEGNetv4, Deep4Net (i.e. DeepConvNet), and EEGResNet.
    The retained layers are the original module objects (not copies), and
    their names and order are preserved.
    """
    excluded_markers = ('classif', 'softmax')
    kept_layers = OrderedDict(
        (child_name, child)
        for child_name, child in model.named_children()
        if not any(marker in child_name for marker in excluded_markers)
    )
    return nn.Sequential(kept_layers)


def freeze_model(model):
    """
    Switch ``model`` to evaluation mode and turn off gradient tracking for
    all of its parameters.

    The model is modified in place; the same object is returned so the call
    can be chained (e.g. ``freeze_model(m).double()``).
    """
    model.eval()
    model.requires_grad_(False)
    return model


In [4]:
class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    """Scikit-learn transformer around a pre-trained, frozen torch module.

    ``fit`` does nothing, so the wrapped weights are never updated, and
    ``transform`` runs the module forward and returns numpy outputs. This lets
    a pre-trained network act as a fixed feature extractor inside a
    scikit-learn ``Pipeline``.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to :class:`skorch.NeuralNet` (e.g. the torch module).
    criterion
        Required by skorch's constructor; unused because nothing is trained.
    unique_name : str or None
        Suffix appended to ``repr``; needed for a unique digest in MOABB.
        When ``None``, the plain skorch repr is returned.
    """

    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        # `fit` is a no-op, so skorch's fit-time initialization never runs;
        # initialize the net right away so `transform` can be used directly.
        self.initialize()
        self.unique_name = unique_name

    def fit(self, X, y=None, **fit_params):
        """No-op: the wrapped module is already trained. Returns ``self``."""
        return self

    def transform(self, X):
        """Embed ``X`` with the frozen module and return a numpy array.

        Uses skorch's ``forward`` with ``training=False`` rather than a single
        ``infer`` call. The module therefore runs in evaluation mode, without
        gradient tracking, and batch by batch instead of on all of ``X`` at once.
        """
        embeddings = self.forward(X, training=False)
        return to_numpy(embeddings)

    def __repr__(self):
        base = super().__repr__()
        # getattr: repr may be requested before __init__ has set the attribute.
        name = getattr(self, 'unique_name', None)
        # Concatenating None would raise TypeError; only append a given name.
        if name is None:
            return base
        return base + str(name)
    
def flatten_batched(X):
    """Keep the leading (batch) axis and merge all remaining axes into one.

    An array of shape ``(n, d1, d2, ...)`` becomes ``(n, d1 * d2 * ...)``.
    """
    batch_size = X.shape[0]
    return X.reshape(batch_size, -1)

In [None]:
# Location of the pre-trained EEGNetv4 checkpoint on the Hugging Face Hub.
HF_REPO_ID = 'PierreGtch/EEGNetv4'
HF_MODEL_DIR = 'EEGNetv4_Lee2019_MI'

# download the model from the hub:
path_kwargs = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=f'{HF_MODEL_DIR}/kwargs.pkl',
)
path_params = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=f'{HF_MODEL_DIR}/model-params.pkl',
)

# SECURITY: unpickling can execute arbitrary code. kwargs.pkl stores a class
# object (module_cls), so plain pickle is required here; only run this cell
# if you trust the repository the file was downloaded from.
with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']
module_kwargs = kwargs['module_kwargs']

# load the model with pre-trained weights. The params file is only a
# state_dict of tensors, so `weights_only=True` is sufficient and avoids
# executing arbitrary pickled code.
torch_module = module_cls(**module_kwargs)
state_dict = torch.load(path_params, map_location='cpu', weights_only=True)
torch_module.load_state_dict(state_dict)
# Drop the classification head, freeze the remaining layers, and cast them to
# float64 (`.double()`).
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Integrate the model in a Scikit-learn pipeline:
# frozen embedding -> flatten per trial -> logistic regression.
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Lee2019')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', LogisticRegression()),
])



In [6]:
# Same channels as the ones used to pre-train the embedding.
mi_channels = ['C3', 'Cz', 'C4']
# Classes to decode; n_classes is derived from this list so the two stay in sync.
mi_events = ['left_hand', 'right_hand', 'feet']

paradigm = MotorImagery(
    channels=mi_channels,
    events=mi_events,
    n_classes=len(mi_events),
    fmin=0.5,
    fmax=40,
    tmin=0,
    tmax=3,
    resample=128,
)
datasets = [Zhou2016()]
evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    overwrite=True,
    suffix='demo',
)

In [7]:
# Map each pipeline name (used as the 'pipeline' column in results) to its estimator.
pipelines = {'demo_pipeline': sklearn_pipeline}
results = evaluation.process(pipelines=pipelines)

Zhou2016-WithinSession:   0%|          | 0/4 [00:00<?, ?it/s]

Reading 0 ... 305029  =      0.000 ...  1220.116 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 430479  =      0.000 ...  1721.916 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 252599  =      0.000 ...  1010.396 secs...
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 296649  =      0.000 ...  1186.596 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 233249  =      0.000 ...   932.996 secs...
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 226219  =      0.000 ...   904.876 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']


 'left_hand': 30
 'right_hand': 30
 'feet': 30>
  warn(f"warnEpochs {epochs}")
 'left_hand': 30
 'right_hand': 29
 'feet': 30>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.


Zhou2016-WithinSession:  25%|██▌       | 1/4 [00:03<00:09,  3.08s/it]

Reading 0 ... 227539  =      0.000 ...   910.156 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 216079  =      0.000 ...   864.316 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 213939  =      0.000 ...   855.756 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 175269  =      0.000 ...   701.076 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 213209  =      0.000 ...   852.836 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 217659  =      0.000 ...   870.636 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']


 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 20
 'right_hand': 20
 'feet': 20>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.


Zhou2016-WithinSession:  50%|█████     | 2/4 [00:05<00:05,  2.62s/it]

Reading 0 ... 219849  =      0.000 ...   879.396 secs...
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Reading 0 ... 216709  =      0.000 ...   866.836 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 226609  =      0.000 ...   906.436 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 266929  =      0.000 ...  1067.716 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 227989  =      0.000 ...   911.956 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 222459  =      0.000 ...   889.836 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']


 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 26>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.


Zhou2016-WithinSession:  75%|███████▌  | 3/4 [00:08<00:02,  2.67s/it]

Reading 0 ... 181339  =      0.000 ...   725.356 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 217139  =      0.000 ...   868.556 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 215399  =      0.000 ...   861.596 secs...
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 212209  =      0.000 ...   848.836 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 209799  =      0.000 ...   839.196 secs...
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']
Reading 0 ... 217109  =      0.000 ...   868.436 secs...


  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])
  raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"])


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand']


 'left_hand': 20
 'right_hand': 20
 'feet': 20>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")
 'left_hand': 25
 'right_hand': 25
 'feet': 25>
  warn(f"warnEpochs {epochs}")


No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.
No hdf5_path provided, models will not be saved.


Zhou2016-WithinSession: 100%|██████████| 4/4 [00:12<00:00,  3.08s/it]


In [8]:
# Per-session evaluation results: one row per (subject, session) pair.
results

Unnamed: 0,score,f1,recall,specificity,precision,time,samples,subject,session,channels,n_sessions,dataset,pipeline
0,0.530635,0.527746,0.530303,0.764655,0.535995,0.040278,179.0,1,0,3,3,Zhou2016,demo_pipeline
1,0.606667,0.606178,0.606667,0.803333,0.623576,0.021214,150.0,1,1,3,3,Zhou2016,demo_pipeline
2,0.646667,0.634515,0.646667,0.823333,0.662476,0.021983,150.0,1,2,3,3,Zhou2016,demo_pipeline
3,0.566667,0.554476,0.566667,0.783333,0.574671,0.02204,150.0,2,0,3,3,Zhou2016,demo_pipeline
4,0.674074,0.673069,0.674074,0.837037,0.686258,0.02024,135.0,2,1,3,3,Zhou2016,demo_pipeline
5,0.66,0.658994,0.66,0.83,0.668599,0.027064,150.0,2,2,3,3,Zhou2016,demo_pipeline
6,0.686667,0.662662,0.686667,0.843333,0.698312,0.022528,150.0,3,0,3,3,Zhou2016,demo_pipeline
7,0.603226,0.585731,0.603636,0.799631,0.605606,0.059109,151.0,3,1,3,3,Zhou2016,demo_pipeline
8,0.653333,0.649564,0.653333,0.826667,0.653056,0.055564,150.0,3,2,3,3,Zhou2016,demo_pipeline
9,0.777778,0.775674,0.777778,0.888889,0.791163,0.215994,135.0,4,0,3,3,Zhou2016,demo_pipeline


In [9]:

# Average each metric over all evaluated (subject, session) rows.
# Maps the results column to the label shown in the summary; the 'score'
# column is reported as 'accuracy'.
metric_labels = {
    "score": "accuracy",
    "f1": "f1",
    "recall": "recall",
    "specificity": "specificity",
    "precision": "precision",
}
data = {label: [results[column].mean()] for column, label in metric_labels.items()}
df = pd.DataFrame(data)
df

   accuracy        f1    recall  specificity  precision
0  0.658809  0.651609  0.658816     0.829184   0.667478
