From 39a165ec20a4de37488673ad52d296cbeac7bc56 Mon Sep 17 00:00:00 2001
From: Pierre Guetschel
Date: Fri, 14 Jul 2023 11:32:11 +0200
Subject: [PATCH] Add FixedIntervalWindowsProcessing, Closes #424

---
 moabb/datasets/preprocessing.py           |  37 ++++++-
 moabb/datasets/utils.py                   |  10 +-
 moabb/paradigms/fixed_interval_windows.py | 113 ++++++++++++++++++++++
 3 files changed, 153 insertions(+), 7 deletions(-)
 create mode 100644 moabb/paradigms/fixed_interval_windows.py

diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py
index 020b3fa4a..202cbb569 100644
--- a/moabb/datasets/preprocessing.py
+++ b/moabb/datasets/preprocessing.py
@@ -1,9 +1,10 @@
 import logging
 from collections import OrderedDict
 from operator import methodcaller
-from typing import Type
+from typing import Union
 
 import mne
+import numpy as np
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.pipeline import FunctionTransformer, Pipeline
 
@@ -17,7 +18,7 @@ def _is_none_pipeline(pipeline):
 
 
 class ForkPipelines(TransformerMixin, BaseEstimator):
-    def __init__(self, transformers: list[tuple[str, Type[TransformerMixin]]]):
+    def __init__(self, transformers: list[tuple[str, Union[Pipeline, TransformerMixin]]]):
         for _, t in transformers:
             assert hasattr(t, "transform")
         self.transformers = transformers
@@ -56,6 +57,38 @@ def transform(self, raw, y=None):
         return events
 
 
+class RawToFixedIntervalEvents(FixedTransformer):
+    def __init__(
+        self,
+        length_samples,
+        stride_samples,
+        start_offset_samples,
+        stop_offset_samples,
+        marker=1,
+    ):
+        self.length_samples = length_samples
+        self.stride_samples = stride_samples
+        self.start_offset_samples = start_offset_samples
+        self.stop_offset_samples = stop_offset_samples
+        self.marker = marker
+
+    def transform(self, raw: mne.io.BaseRaw, y=None):
+        stop_offset_samples = (
+            raw.n_times if self.stop_offset_samples is None else self.stop_offset_samples
+        )
+        stop_samples = stop_offset_samples - self.length_samples + raw.first_samp
+        onset = np.arange(
+            raw.first_samp + self.start_offset_samples,
+            stop_samples + 1,
+            self.stride_samples,
+        )
+        events = np.empty((len(onset), 3), dtype=int)
+        events[:, 0] = onset
+        events[:, 1] = self.length_samples
+        events[:, 2] = self.marker
+        return events
+
+
 class EpochsToEvents(FixedTransformer):
     def transform(self, epochs, y=None):
         return epochs.events
diff --git a/moabb/datasets/utils.py b/moabb/datasets/utils.py
index 8d004642a..17af2c682 100644
--- a/moabb/datasets/utils.py
+++ b/moabb/datasets/utils.py
@@ -22,7 +22,7 @@ def _init_dataset_list():
 
 
 def dataset_search(  # noqa: C901
-    paradigm,
+    paradigm=None,
    multi_session=False,
    events=None,
    has_all_events=False,
@@ -35,8 +35,8 @@
 
     Parameters
     ----------
-    paradigm: str
-        'imagery', 'p300', 'ssvep'
+    paradigm: str | None
+        'imagery', 'p300', 'ssvep', None
 
     multi_session: bool
         if True only returns datasets with more than one session per subject.
@@ -67,7 +67,7 @@
         n_classes = len(events)
     else:
         n_classes = None
-    assert paradigm in ["imagery", "p300", "ssvep"]
+    assert paradigm in ["imagery", "p300", "ssvep", None]
 
     for type_d in dataset_list:
         d = type_d()
@@ -78,7 +78,7 @@
         if len(d.subject_list) < min_subjects:
             continue
 
-        if paradigm != d.paradigm:
+        if paradigm is not None and paradigm != d.paradigm:
             continue
 
         if interval is not None and d.interval[1] - d.interval[0] < interval:
diff --git a/moabb/paradigms/fixed_interval_windows.py b/moabb/paradigms/fixed_interval_windows.py
new file mode 100644
index 000000000..7108f8673
--- /dev/null
+++ b/moabb/paradigms/fixed_interval_windows.py
@@ -0,0 +1,113 @@
+from moabb.datasets import utils
+from moabb.datasets.preprocessing import RawToFixedIntervalEvents
+from moabb.paradigms.base import BaseProcessing
+
+
+class FixedIntervalWindowsProcessing(BaseProcessing):
+    """Paradigm for creating epochs at fixed intervals,
+    ignoring the stim channel and events of the dataset.
+
+    Parameters
+    ----------
+
+    filters: list of list (default ((7, 45),))
+        bank of bandpass filters to apply.
+
+    baseline: None | tuple of length 2
+        The time interval to consider as “baseline” when applying baseline
+        correction. If None, do not apply baseline correction.
+        If a tuple (a, b), the interval is between a and b (in seconds),
+        including the endpoints.
+        Correction is applied by computing the mean of the baseline period
+        and subtracting it from the data (see mne.Epochs).
+
+    channels: list of str | None (default None)
+        list of channels to select. If None, use all EEG channels available in
+        the dataset.
+
+    resample: float | None (default None)
+        If not None, resample the EEG data with the sampling rate provided.
+
+    length: float (default 5.0)
+        Length of the epochs in seconds.
+
+    stride: float (default 10.0)
+        Stride between consecutive epoch onsets, in seconds.
+
+    start_offset: float (default 0.0)
+        Start offset from the beginning of the recording, in seconds.
+
+    stop_offset: float | None (default None)
+        Stop offset from the beginning of the recording, in seconds.
+        If None, set to the end of the recording.
+
+    marker: int (default 1)
+        Marker to use for the events created.
+    """
+
+    def __init__(
+        self,
+        filters=((7, 45),),
+        baseline=None,
+        channels=None,
+        resample=None,
+        length: float = 5.0,
+        stride: float = 10.0,
+        start_offset=0.0,
+        stop_offset=None,
+        marker=1,
+    ):
+        tmin = 0.0
+        tmax = length
+        super().__init__(
+            filters=filters,
+            channels=channels,
+            baseline=baseline,
+            resample=resample,
+            tmin=tmin,
+            tmax=tmax,
+        )
+        self.length = length
+        self.stride = stride
+        self.start_offset = start_offset
+        self.stop_offset = stop_offset
+        self.marker = marker
+
+    def _to_samples(self, key, dataset=None):
+        value = getattr(self, key)
+        if dataset is None and self.resample is None:
+            raise ValueError(f"{key}_samples: dataset or resample must be specified")
+        return int(value * self.resample)
+
+    def length_samples(self, dataset=None):
+        return self._to_samples("length", dataset)
+
+    def stride_samples(self, dataset=None):
+        return self._to_samples("stride", dataset)
+
+    def start_offset_samples(self, dataset=None):
+        return self._to_samples("start_offset", dataset)
+
+    def stop_offset_samples(self, dataset=None):
+        if self.stop_offset is None:
+            return None
+        return self._to_samples("stop_offset", dataset)
+
+    def used_events(self, dataset):
+        return {"Window": self.marker}
+
+    def is_valid(self, dataset):
+        return True
+
+    @property
+    def datasets(self):
+        return utils.dataset_search(paradigm=None)
+
+    def _get_events_pipeline(self, dataset):
+        return RawToFixedIntervalEvents(
+            length_samples=self.length_samples(dataset),
+            stride_samples=self.stride_samples(dataset),
+            start_offset_samples=self.start_offset_samples(dataset),
+            stop_offset_samples=self.stop_offset_samples(dataset),
+            marker=self.marker,
+        )
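
Quick sanity check of the windowing arithmetic in RawToFixedIntervalEvents. This is a minimal sketch, not part of the patch: the 60 s single-channel recording and the 100 Hz sampling rate are made up purely for illustration.

import mne
import numpy as np

from moabb.datasets.preprocessing import RawToFixedIntervalEvents

# 60 s of zeroed single-channel "EEG" at 100 Hz, only to exercise the transformer.
sfreq = 100.0
info = mne.create_info(ch_names=["C3"], sfreq=sfreq, ch_types="eeg")
raw = mne.io.RawArray(np.zeros((1, int(60 * sfreq))), info)

# 5 s windows (500 samples) with a 10 s stride (1000 samples), starting at the
# beginning of the recording and stopping at its end (stop_offset_samples=None).
events = RawToFixedIntervalEvents(
    length_samples=500,
    stride_samples=1000,
    start_offset_samples=0,
    stop_offset_samples=None,
    marker=1,
).transform(raw)

print(events[:, 0])  # onsets in samples: [   0 1000 2000 3000 4000 5000]
print(events[:, 2])  # every event carries the same marker: [1 1 1 1 1 1]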
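
And an end-to-end sketch of the new paradigm, assuming the get_data interface inherited from BaseProcessing; Zhou2016 stands in for any MOABB dataset and is used here purely for illustration. resample is passed explicitly because _to_samples above converts seconds to samples via self.resample only.

from moabb.datasets import Zhou2016
from moabb.paradigms.fixed_interval_windows import FixedIntervalWindowsProcessing

# 5 s windows cut every 10 s over the whole recording, band-passed 7-45 Hz
# and resampled to 100 Hz, ignoring the dataset's stim channel entirely.
processing = FixedIntervalWindowsProcessing(length=5.0, stride=10.0, resample=100.0)

dataset = Zhou2016()
X, labels, metadata = processing.get_data(dataset=dataset, subjects=[1])
print(X.shape)      # (n_windows, n_channels, n_times)
print(set(labels))  # {'Window'}: all windows share the single marker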