Skip to content

Commit

Permalink
get events before RawToEpochs pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreGtch committed Jul 14, 2023
1 parent dca3903 commit 0514ce0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 15 deletions.
6 changes: 3 additions & 3 deletions moabb/datasets/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def __init__(
self.baseline = baseline
self.channels = channels

def transform(self, raw, y=None):
def transform(self, X, y=None):
raw = X["raw"]
events = X["events"]
if not isinstance(raw, mne.io.BaseRaw):
raise ValueError("raw must be a mne.io.BaseRaw")

Expand All @@ -98,8 +100,6 @@ def transform(self, raw, y=None):
raw.info["ch_names"], include=self.channels, ordered=True
)

events = RawToEvents(self.event_id).transform(raw)

epochs = mne.Epochs(
raw,
events,
Expand Down
21 changes: 15 additions & 6 deletions moabb/paradigms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from moabb.datasets.preprocessing import (
EpochsToEvents,
EventsToLabels,
ForkPipelines,
RawToEpochs,
RawToEvents,
get_crop_pipeline,
Expand Down Expand Up @@ -307,12 +308,20 @@ def _get_epochs_pipeline(self, return_epochs, return_raws, dataset):
steps.append(
(
"epoching",
RawToEpochs(
event_id=self.used_events(dataset),
tmin=bmin,
tmax=bmax,
baseline=baseline,
channels=self.channels,
make_pipeline(
ForkPipelines(
[
("raw", make_pipeline(None)),
("events", self._get_events_pipeline(dataset)),
]
),
RawToEpochs(
event_id=self.used_events(dataset),
tmin=bmin,
tmax=bmax,
baseline=baseline,
channels=self.channels,
),
),
)
)
Expand Down
21 changes: 15 additions & 6 deletions moabb/tests/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from operator import methodcaller

import mne
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

from moabb.datasets import Shin2017A, Shin2017B, VirtualReality
from moabb.datasets.compound_dataset import CompoundDataset
from moabb.datasets.fake import FakeDataset, FakeVirtualRealityDataset
from moabb.datasets.preprocessing import RawToEpochs
from moabb.datasets.preprocessing import ForkPipelines, RawToEpochs, RawToEvents
from moabb.datasets.utils import block_rep
from moabb.paradigms import P300

Expand Down Expand Up @@ -74,11 +75,19 @@ def test_cache_dataset(self):
pipelines_list = [
dict(), # test BIDSInterfaceRawEDF
dict(
epochs_pipeline=RawToEpochs( # test BIDSInterfaceEpochs
event_id=dataset.event_id,
tmin=dataset.interval[0],
tmax=dataset.interval[1],
baseline=tuple(dataset.interval),
epochs_pipeline=make_pipeline(
ForkPipelines(
[
("raw", make_pipeline(None)),
("events", RawToEvents(dataset.event_id)),
]
),
RawToEpochs( # test BIDSInterfaceEpochs
event_id=dataset.event_id,
tmin=dataset.interval[0],
tmax=dataset.interval[1],
baseline=tuple(dataset.interval),
),
)
),
dict(
Expand Down

0 comments on commit 0514ce0

Please sign in to comment.