In [33]:
from huggingface_hub import hf_hub_download
import torch
from braindecode import EEGClassifier
from braindecode.models import EEGNetv4
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from moabb.paradigms import MotorImagery
from moabb.datasets import Zhou2016
from braindecode import EEGClassifier
from braindecode.models import EEGNetv4

from collections import OrderedDict
from torch import nn
from skorch import NeuralNet
from skorch.utils import to_numpy
from sklearn.base import TransformerMixin
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from moabb.paradigms import MotorImagery
from moabb.datasets import Zhou2016
from moabb.evaluations import WithinSessionEvaluation, CrossSessionEvaluation
from sklearn.metrics import make_scorer, accuracy_score, f1_score

import pickle

import torch
import pandas as pd


In [34]:
file_names = dict(
    torch='torch_params.pkl',
    f_params='skorch_params.pkl',
    f_optimizer='skorch_opt.pkl',
    f_history='skorch_history.json',
)
local_paths = {
    k: hf_hub_download(
        repo_id='PierreGtch/EEGNetv4',
        filename='toy/' + name,
    )
    for k, name in file_names.items()
}

# load the pure pytorch module:
torch_module = EEGNetv4(in_chans=3, n_classes=2, input_window_samples=200)
torch_module.load_state_dict(torch.load(local_paths['torch']))

# load the pure pytorch module:
skorch_module = EEGNetv4(in_chans=3, n_classes=2, input_window_samples=200)
skorch_classifier = EEGClassifier(skorch_module, max_epochs=5)
skorch_classifier.initialize()
skorch_classifier.load_params(
    f_params=local_paths['f_params'],
    f_optimizer=local_paths['f_optimizer'],
    f_history=local_paths['f_history'],
)

In [35]:
# _ = skorch_classifier.partial_fit(X_np, y_np)

In [None]:
paradigm = MotorImagery(
    channels=['C3', 'Cz', 'C4'],  # Same channels as used during pre-training.
    events=['left_hand', 'right_hand', 'feet'],
    n_classes=3,
    fmin=0.5,
    fmax=40,
    tmin=0,
    tmax=3,
    resample=128,
)
datasets = [Zhou2016()]

# --------------------------
# Load data from MOABB
# --------------------------
# Note: get_data returns a dictionary with a key per subject.
data = paradigm.get_data(datasets[0])

In [37]:
# For illustration, if you have a single subject's data returned as a tuple:
X, y, metadata = data
print(f"X shape = {X.shape}, y shape = {y.shape}")


X shape = (1800, 3, 385), y shape = (1800,)


In [38]:
# Change y shape to one hot encoding
one_hot_y = []
i = 0
for label in y:
    if(y[i] == "feet"):
        one_hot_y.append(0)
    if(y[i] == "left_hand"):
        one_hot_y.append(1)
    if(y[i] == "right_hand"):
        one_hot_y.append(2)
    i+=1

one_hot_y = np.array(one_hot_y, dtype=np.int64)

In [39]:
X_finetune, X_test, y_finetune, y_test = train_test_split(
    X, one_hot_y, train_size=0.8, random_state=42, stratify=y)


In [40]:
print(f"Fine-tuning set: {X_finetune.shape}, Test set: {X_test.shape}")

# --------------------------
# Assume you have a SKorch-based EEGClassifier
# that loads a model (e.g. EEGNetv4) and its pre-trained parameters.
# For demonstration, we assume skorch_classifier is already constructed,
# initialized, and loaded with pre-trained parameters.

# Here we build the SKorch model (adjust the parameters as needed).
# This example builds a new instance. In your actual code, you might load it as you did earlier.
skorch_module = EEGNetv4(in_chans=3, n_classes=3, input_window_samples=X.shape[-1])
skorch_classifier = EEGClassifier(skorch_module, max_epochs=20)

# Initialize the SKorch model (and optionally load pre-trained parameters)
skorch_classifier.initialize()
# If you have pre-trained parameters, load them here:
# skorch_classifier.load_params(f_params=..., f_optimizer=..., f_history=...)

# Optional: Ensure that the entire model is trainable (full fine tuning)
# for param in skorch_classifier.module_.parameters():
#     param.requires_grad = True

# --------------------------
# Fine Tune on 20% of the Data Using partial_fit
# --------------------------
# NEW: Fine tuning the entire model on only 20% of the data.
_ = skorch_classifier.partial_fit(X_finetune, y_finetune)

# --------------------------
# Evaluate on the Remaining 80%
# --------------------------
y_pred = skorch_classifier.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy after fine tuning on 20% of the data: {accuracy:.3f}")


Fine-tuning set: (1440, 3, 385), Test set: (360, 3, 385)




  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m1.2196[0m       [32m0.2986[0m        [35m1.1155[0m  1.9757
      2        [36m1.1597[0m       [32m0.3194[0m        [35m1.1078[0m  2.3092
      3        [36m1.1445[0m       [32m0.3333[0m        [35m1.1035[0m  1.9662
      4        [36m1.1191[0m       [32m0.3368[0m        [35m1.1006[0m  1.8968
      5        [36m1.1038[0m       [32m0.3507[0m        [35m1.0990[0m  1.9140
      6        [36m1.0931[0m       0.3507        [35m1.0978[0m  1.9681
      7        [36m1.0798[0m       [32m0.3646[0m        [35m1.0965[0m  1.8291
      8        [36m1.0685[0m       [32m0.3715[0m        [35m1.0956[0m  2.0896
      9        [36m1.0643[0m       [32m0.3785[0m        [35m1.0948[0m  2.0557
     10        [36m1.0517[0m       [32m0.3854[0m        [35m1.0935[0m  1.8978
     11        [36m1.0342[0m       [32m0.3924[0m   

In [41]:
def remove_clf_layers(model: nn.Sequential):
    """
    Remove the classification layers from braindecode models.
    Tested on EEGNetv4, Deep4Net (i.e. DeepConvNet), and EEGResNet.
    """
    new_layers = []
    for name, layer in model.named_children():
        if 'classif' in name:
            continue
        if 'softmax' in name:
            continue
        new_layers.append((name, layer))
    return nn.Sequential(OrderedDict(new_layers))


def freeze_model(model):
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model

class FrozenNeuralNetTransformer(NeuralNet, TransformerMixin):
    def __init__(
            self,
            *args,
            criterion=nn.MSELoss,  # should be unused
            unique_name=None,  # needed for a unique digest in MOABB
            **kwargs
    ):
        super().__init__(
            *args,
            criterion=criterion,
            **kwargs
        )
        self.initialize()
        self.unique_name = unique_name

    def fit(self, X, y=None, **fit_params):
        return self  # do nothing

    def transform(self, X):
        X = self.infer(X)
        return to_numpy(X)

    def __repr__(self):
        return super().__repr__() + self.unique_name
    
def flatten_batched(X):
    return X.reshape(X.shape[0], -1)

In [42]:
# First, create the embedding from the pre-trained model
path_kwargs = hf_hub_download(
    repo_id='PierreGtch/EEGNetv4',
    filename='EEGNetv4_Lee2019_MI/kwargs.pkl',
)
path_params = hf_hub_download(
    repo_id='PierreGtch/EEGNetv4',
    filename='EEGNetv4_Lee2019_MI/model-params.pkl',
)

with open(path_kwargs, 'rb') as f:
    kwargs = pickle.load(f)
module_cls = kwargs['module_cls']
module_kwargs = kwargs['module_kwargs']

# Load the pre-trained model
torch_module = module_cls(**module_kwargs)
torch_module.load_state_dict(torch.load(path_params, map_location='cpu'))
embedding = freeze_model(remove_clf_layers(torch_module)).double()

# Create a wrapper for your trained classifier to make it compatible with the pipeline
class TrainedClassifierWrapper:
    def __init__(self, trained_classifier):
        self.trained_classifier = trained_classifier
        
    def fit(self, X, y):
        # The classifier is already trained, so we just return self
        return self
        
    def predict(self, X):
        return self.trained_classifier.predict(X)
    
    def predict_proba(self, X):
        return self.trained_classifier.predict_proba(X)

# Create the pipeline using your trained classifier
sklearn_pipeline = Pipeline([
    ('embedding', FrozenNeuralNetTransformer(embedding, unique_name='pretrained_Lee2019')),
    ('flatten', FunctionTransformer(flatten_batched)),
    ('classifier', TrainedClassifierWrapper(skorch_classifier)),
])

# Optional: Add data validation
def validate_data(X, y):
    print(f"Input shape before embedding: {X.shape}")
    # You can add more validation steps here
    return X, y

# Use the pipeline
evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    hdf5_path='results.h5',
    n_jobs=1
)
s
results = evaluation.process(
    pipelines=dict(demo_pipeline=sklearn_pipeline),
    param_grid=None,
    postprocess_pipeline=validate_data
)

<moabb.datasets.alex_mi.AlexMI object at 0x166e26140> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.bnci.BNCI2014_002 object at 0x166e27be0> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.bnci.BNCI2014_004 object at 0x166e26d70> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.bnci.BNCI2015_001 object at 0x166e26500> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.bnci.BNCI2015_004 object at 0x166e27880> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.gigadb.Cho2017 object at 0x166e269e0> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.mpi_mi.GrosseWentrup2009 object at 0x166e27b50> not compatible with paradigm. Removing this dataset from the list.
<moabb.datasets.Lee2019.Lee2019_MI object at 0x166e25c90> not compatible with paradigm. Removing this dataset from the list.
<moabb

AssertionError: 