Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Braindecode pipeline #328

Merged
merged 54 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
4c59cfd
Adding braindecode object as a pipeline.
bruAristimunha Feb 7, 2023
4d0b169
Adding in init file and changing the name.
bruAristimunha Feb 7, 2023
2037230
Adding new pipeline.
bruAristimunha Feb 7, 2023
3756a76
Adding new dependence, braindecode
bruAristimunha Feb 7, 2023
2eb5a4c
Adding new dependence, torch
bruAristimunha Feb 7, 2023
e5a0534
Merge branch 'develop' into braindecode
bruAristimunha Feb 7, 2023
42f8189
Merge branch 'develop' into braindecode
bruAristimunha Feb 8, 2023
b42657c
Updating the dependencies, set as optional.
bruAristimunha Feb 9, 2023
187954a
Setting the Valid Split to the new pipeline.
bruAristimunha Feb 9, 2023
0706b48
Merge remote-tracking branch 'origin/braindecode' into braindecode
bruAristimunha Feb 9, 2023
3334dfd
restoring the file
bruAristimunha Feb 12, 2023
6fb5b45
Merge branch 'develop' into braindecode
bruAristimunha Feb 13, 2023
eb3b460
Merge branch 'develop' into braindecode
bruAristimunha Feb 14, 2023
4d22b40
Merge branch 'develop' into braindecode
bruAristimunha Mar 3, 2023
1eaf054
Merge branch 'develop' into braindecode
bruAristimunha Mar 3, 2023
0697459
Updating __init__
bruAristimunha Mar 11, 2023
bea1d6b
Removing the BraindecodeClassifierModel
bruAristimunha Mar 11, 2023
555bf0c
Updating EEGClassifier to use the max_epochs
bruAristimunha Mar 11, 2023
76a0752
Adding braindecode as depedencies
bruAristimunha Mar 11, 2023
014162c
Moving the file to other nome
bruAristimunha Mar 11, 2023
df77fda
Merge branch 'develop' into braindecode
bruAristimunha Mar 11, 2023
039c73d
Adding support ot braindecode classifier
bruAristimunha Mar 11, 2023
a801d67
Adding y as None value
bruAristimunha Mar 11, 2023
bcbeed1
first iteration to use the ShallowNet from braindecode with yaml file
bruAristimunha Mar 11, 2023
ff470cb
Adding as example
bruAristimunha Mar 11, 2023
c9ff708
To discuss
bruAristimunha Mar 11, 2023
d70e1fe
Merge branch 'develop' into braindecode
bruAristimunha Mar 19, 2023
c557d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2023
eafc342
Removing braindecode file
bruAristimunha Mar 19, 2023
2fe6cd2
Adding yaml files
bruAristimunha Mar 19, 2023
47bd4d5
Removing ShallowFBCSPNet, duplicate
bruAristimunha Mar 19, 2023
846de60
Updating Braindecode_ShallowFBCSPNET.yml to add the inputShapeSetterEEG
bruAristimunha Mar 19, 2023
37ba988
working on the parser
bruAristimunha Mar 21, 2023
32d89b5
removing braindecode keyword
bruAristimunha Mar 21, 2023
df2f933
working more on the parser
bruAristimunha Mar 22, 2023
afcdc10
Merge branch 'develop' into braindecode
bruAristimunha Mar 28, 2023
adb2914
Adding one more test
bruAristimunha Mar 28, 2023
f0a3479
Adding check module
bruAristimunha Mar 28, 2023
51c11a0
Merge remote-tracking branch 'origin/braindecode' into braindecode
bruAristimunha Mar 28, 2023
6851075
Improving the BraindecodeDatasetLoader
bruAristimunha Mar 28, 2023
934f407
Removing the yaml file for braindecode object
bruAristimunha Mar 28, 2023
2353869
Improving the examples
bruAristimunha Mar 28, 2023
7eeae68
Naming the variable
bruAristimunha Mar 28, 2023
21fa441
Returning the old parser
bruAristimunha Mar 28, 2023
a72555e
Merge branch 'develop' into braindecode
bruAristimunha Mar 28, 2023
4e1b691
adding test folder
bruAristimunha Mar 28, 2023
094fe42
Merge remote-tracking branch 'origin/braindecode' into braindecode
bruAristimunha Mar 28, 2023
87a5c28
fix: correct error when multiple pipelines
Mar 29, 2023
e42fa17
fix: correct doc error
Mar 29, 2023
c77eefb
revert: leave braindecode order
Mar 29, 2023
ee7b3a7
fix: benchmark unit test passed
Mar 29, 2023
13fd28a
fix: doc building error
Mar 29, 2023
4d5cab5
Update moabb/benchmark.py
bruAristimunha Apr 4, 2023
f4e3b36
Merge branch 'develop' into braindecode
sylvchev Apr 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions moabb/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Pipelines are typically a chain of sklearn compatible transformers and end
with a sklearn compatible estimator.
"""
from .braindecode import BraindecodeClassifierModel, CreateBraindecodeDataset

bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
# flake8: noqa
from .classification import SSVEP_CCA, SSVEP_TRCA
from .features import FM, ExtendedSSVEPSignal, LogVariance
Expand Down
76 changes: 76 additions & 0 deletions moabb/pipelines/braindecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from braindecode.datasets import WindowsDataset, create_from_X_y
from mne.epochs import BaseEpochs
from numpy import array, unique
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin


class CreateBraindecodeDataset(BaseEstimator, TransformerMixin):
"""
Wrapper to create a Braindecode Dataset from a mne Epoched
bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
object.

This is a transformer function that allow used to use the
dataset as a sklearn pipeline.
"""

def __init__(self, kw_args: dict = None):
"""

Parameters
----------
kw_args: dict

"""
self.kw_args = kw_args

def fit(self, X: BaseEpochs, y=None):
self.y = y
return self

def transform(self, X: BaseEpochs, y=None) -> WindowsDataset:
"""

Parameters
----------
X: BaseEpochs object from mne
y: list|array of labels

Returns
-------
WindowsDataset: Braindecode Dataset
"""
dataset = create_from_X_y(
X.get_data(),
y=self.y,
window_size_samples=X.get_data().shape[2],
window_stride_samples=X.get_data().shape[2],
drop_last_window=False,
sfreq=X.info["sfreq"],
)

return dataset

def __sklearn_is_fitted__(self) -> bool:
"""
Return True since CreateBraindecodeDataset is stateless.
"""
return True


class BraindecodeClassifierModel(BaseEstimator, ClassifierMixin):
def __init__(self, clf: BaseEstimator, kw_args: dict = None):
self.clf = clf
self.classes_ = None
self.kw_args = kw_args

def fit(self, X: WindowsDataset, y=None) -> BaseEstimator:
self.clf.fit(X, y=y, **self.kw_args)
self.classes_ = unique(y)

return self.clf

def predict(self, X: WindowsDataset) -> array:
return self.clf.predict(X)

def predict_proba(self, X: WindowsDataset) -> array:
return self.clf.predict_proba(X)
62 changes: 62 additions & 0 deletions pipelines/braindecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from braindecode.classifier import EEGClassifier
from braindecode.models import ShallowFBCSPNet
from sklearn.pipeline import Pipeline
from skorch.callbacks import LRScheduler
from skorch.dataset import ValidSplit
from torch.cuda import is_available
from torch.nn import NLLLoss
from torch.optim import AdamW

from moabb.pipelines.braindecode import (
BraindecodeClassifierModel,
CreateBraindecodeDataset,
)


# hard-coded for now
n_classes = 2
n_chans = 22
input_window_samples = 1001
# These values we found good for shallow network:
lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 4

model = ShallowFBCSPNet(
n_chans,
n_classes,
input_window_samples=input_window_samples,
final_conv_length="auto",
)
device = "cuda" if is_available() else "cpu"

clf = EEGClassifier(
bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
model,
criterion=NLLLoss,
optimizer=AdamW,
train_split=ValidSplit(0.2), # using valid_set for validation
optimizer__lr=lr,
optimizer__weight_decay=weight_decay,
batch_size=batch_size,
callbacks=[
"accuracy",
("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
],
device=device,
)

create_dataset = CreateBraindecodeDataset()
fit_params = {"epochs": 10}

clf_braindecode = BraindecodeClassifierModel(clf, fit_params)

pipe = Pipeline([("Braindecode_dataset", create_dataset), ("Net", clf_braindecode)])

pipes = {"ShallowFBCSPNet": pipe}

PIPELINE = {
bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
"name": "ShallowFBCSPNet",
"paradigms": ["LeftRightImagery"],
"pipeline": pipe,
}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pre-commit = "^2.21.0"
m2r2 = "^0.3.3"
tdlda = {git = "https://github.com/jsosulski/tdlda.git", rev = "0.1.0"}

bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved


bruAristimunha marked this conversation as resolved.
Show resolved Hide resolved
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Expand Down