Skip to content

Commit

Permalink
Add FixedIntervalWindowsProcessing, Closes NeuroTechX#424
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreGtch committed Jul 14, 2023
1 parent 0514ce0 commit 39a165e
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 7 deletions.
37 changes: 35 additions & 2 deletions moabb/datasets/preprocessing.py
@@ -1,9 +1,10 @@
import logging
from collections import OrderedDict
from operator import methodcaller
from typing import Type
from typing import Union

import mne
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import FunctionTransformer, Pipeline

Expand All @@ -17,7 +18,7 @@ def _is_none_pipeline(pipeline):


class ForkPipelines(TransformerMixin, BaseEstimator):
def __init__(self, transformers: list[tuple[str, Type[TransformerMixin]]]):
def __init__(self, transformers: list[tuple[str, Union[Pipeline, TransformerMixin]]]):
for _, t in transformers:
assert hasattr(t, "transform")
self.transformers = transformers
Expand Down Expand Up @@ -56,6 +57,38 @@ def transform(self, raw, y=None):
return events


class RawToFixedIntervalEvents(FixedTransformer):
def __init__(
self,
length_samples,
stride_samples,
start_offset_samples,
stop_offset_samples,
marker=1,
):
self.length_samples = length_samples
self.stride_samples = stride_samples
self.start_offset_samples = start_offset_samples
self.stop_offset_samples = stop_offset_samples
self.marker = marker

def transform(self, raw: mne.io.BaseRaw, y=None):
stop_offset_samples = (
raw.n_times if self.stop_offset_samples is None else self.stop_offset_samples
)
stop_samples = stop_offset_samples - self.length_samples + raw.first_samp
onset = np.arange(
raw.first_samp + self.start_offset_samples,
stop_samples + 1,
self.window_stride_samples,
)
events = np.empty((len(onset), 3), dtype=int)
events[:, 0] = onset
events[:, 1] = self.length_samples
events[:, 2] = self.marker
return events


class EpochsToEvents(FixedTransformer):
def transform(self, epochs, y=None):
return epochs.events
Expand Down
10 changes: 5 additions & 5 deletions moabb/datasets/utils.py
Expand Up @@ -22,7 +22,7 @@ def _init_dataset_list():


def dataset_search( # noqa: C901
paradigm,
paradigm=None,
multi_session=False,
events=None,
has_all_events=False,
Expand All @@ -35,8 +35,8 @@ def dataset_search( # noqa: C901
Parameters
----------
paradigm: str
'imagery', 'p300', 'ssvep'
paradigm: str | None
'imagery', 'p300', 'ssvep', None
multi_session: bool
if True only returns datasets with more than one session per subject.
Expand Down Expand Up @@ -67,7 +67,7 @@ def dataset_search( # noqa: C901
n_classes = len(events)
else:
n_classes = None
assert paradigm in ["imagery", "p300", "ssvep"]
assert paradigm in ["imagery", "p300", "ssvep", None]

for type_d in dataset_list:
d = type_d()
Expand All @@ -78,7 +78,7 @@ def dataset_search( # noqa: C901
if len(d.subject_list) < min_subjects:
continue

if paradigm != d.paradigm:
if paradigm is not None and paradigm != d.paradigm:
continue

if interval is not None and d.interval[1] - d.interval[0] < interval:
Expand Down
113 changes: 113 additions & 0 deletions moabb/paradigms/fixed_interval_windows.py
@@ -0,0 +1,113 @@
from moabb.datasets import utils
from moabb.datasets.preprocessing import RawToFixedIntervalEvents
from moabb.paradigms.base import BaseProcessing


class FixedIntervalWindowsProcessing(BaseProcessing):
"""Paradigm for creating epochs at fixed interval,
ignoring the stim channel and events of the dataset.
Parameters
----------
filters: list of list (default [[7, 45]])
bank of bandpass filter to apply.
baseline: None | tuple of length 2
The time interval to consider as “baseline” when applying baseline
correction. If None, do not apply baseline correction.
If a tuple (a, b), the interval is between a and b (in seconds),
including the endpoints.
Correction is applied by computing the mean of the baseline period
and subtracting it from the data (see mne.Epochs)
channels: list of str | None (default None)
list of channel to select. If None, use all EEG channels available in
the dataset.
resample: float | None (default None)
If not None, resample the eeg data with the sampling rate provided.
length: float (default 5.0)
Length of the epochs in seconds.
stride: float (default 10.0)
Stride between epochs in seconds.
start_offset: float (default 0.0)
Start from the beginning of the raw recordings in seconds.
stop_offset: float | None (default None)
Stop offset from beginning of raw recordings in seconds.
If None, set to be the end of the recording.
marker: int (default 1)
Marker to use for the events created.
"""

def __init__(
self,
filters=((7, 45)),
baseline=None,
channels=None,
resample=None,
length: float = 5.0,
stride: float = 10.0,
start_offset=0.0,
stop_offset=None,
marker=1,
):
tmin = 0.0
tmax = length
super().__init__(
filters=filters,
channels=channels,
baseline=baseline,
resample=resample,
tmin=tmin,
tmax=tmax,
)
self.length = length
self.stride = stride
self.start_offset = start_offset
self.stop_offset = stop_offset
self.marker = marker

def _to_samples(self, key, dataset=None):
value = getattr(self, key)
if dataset is None and self.resample is None:
raise ValueError(f"{key}_samples: dataset or resample must be specified")
return int(value * self.resample)

def length_samples(self, dataset=None):
return self._to_samples("length", dataset)

def stride_samples(self, dataset=None):
return self._to_samples("stride", dataset)

def start_offset_samples(self, dataset=None):
return self._to_samples("start_offset", dataset)

def stop_offset_samples(self, dataset=None):
if self.stop_offset is None:
return None
return self._to_samples("stop_offset", dataset)

def used_events(self, dataset):
return {"Window": self.marker}

def is_valid(self, dataset):
return True

@property
def datasets(self):
return utils.dataset_search(paradigm=None)

def _get_events_pipeline(self, dataset):
return RawToFixedIntervalEvents(
length_samples=self.length_samples(dataset),
stride_samples=self.stride_samples(dataset),
start_offset_samples=self.start_offset_samples(dataset),
stop_offset_samples=self.stop_offset_samples(dataset),
marker=self.marker,
)

0 comments on commit 39a165e

Please sign in to comment.