Fix case where no epochs in run #491

Merged
merged 13 commits on Sep 25, 2023
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -77,6 +77,7 @@ Bugs
- Fix ``dataset_list`` to include deprecated datasets (PR :gh:`464` by `Bruno Aristimunha`_)
- Fixed bug in :func:`moabb.analysis.results.get_string_rep` to handle addresses such as 0x__0A as well (PR :gh:`468` by `Anton Andreev`_)
- Removing joblib Parallel (:gh:`488` by `Igor Carrara`_)
- Fix the case where events are specified via ``raw.annotations`` but none are present (:gh:`491` by `Pierre Guetschel`_)

API changes
~~~~~~~~~~~
6 changes: 3 additions & 3 deletions moabb/datasets/bids_interface.py
@@ -174,7 +174,7 @@ def load(self, preload=False):
log.info("Attempting to retrieve cache of %s...", repr(self))
self.lock_file.mkdir(exist_ok=True)
if not self.lock_file.fpath.exists():
log.info("No cache found at %s.", {str(self.lock_file.directory)})
log.info("No cache found at %s.", str(self.lock_file.directory))
return None
paths = mne_bids.find_matching_paths(
root=self.root,
@@ -191,7 +191,7 @@ def load(self, preload=False):
session = sessions_data.setdefault(session_moabb, {})
run = self._load_file(path, preload=preload)
session[run_bids_to_moabb(path.run)] = run
log.info("Finished reading cache of %s", {repr(self)})
log.info("Finished reading cache of %s", repr(self))
return sessions_data

def save(self, sessions_data):
@@ -207,7 +207,7 @@ def save(self, sessions_data):

The type of the ``run`` object can vary (see the subclasses).
"""
log.info("Starting caching %s", {repr(self)})
log.info("Starting caching %s", repr(self))
mne_bids.BIDSPath(root=self.root).mkdir(exist_ok=True)
mne_bids.make_dataset_description(
path=str(self.root),
103 changes: 80 additions & 23 deletions moabb/datasets/fake.py
@@ -3,7 +3,7 @@
from pathlib import Path

import numpy as np
from mne import create_info, get_config, set_config
from mne import Annotations, annotations_from_events, create_info, get_config, set_config
from mne.channels import make_standard_montage
from mne.io import RawArray

@@ -31,6 +31,16 @@ class FakeDataset(BaseDataset):
Defines what sort of dataset this is
channels: list or tuple of str
List of channels to generate, default ("C3", "Cz", "C4")
duration: float or list of float
Duration of each run in seconds. If float, same duration for all
runs. If list, duration for each run.
n_events: int or list of int
Number of events per run. If int, same number of events
for all runs. If list, number of events for each run.
stim: bool
If True, events are delivered through a stim channel.
annotations: bool
If True, events are delivered through ``Annotations``.

.. versionadded:: 0.4.3
"""
@@ -46,16 +46,27 @@ def __init__(
channels=("C3", "Cz", "C4"),
seed=None,
sfreq=128,
duration=120,
n_events=60,
stim=True,
annotations=False,
):
self.n_runs = n_runs
self.n_events = n_events if isinstance(n_events, list) else [n_events] * n_runs
self.duration = duration if isinstance(duration, list) else [duration] * n_runs
assert len(self.n_events) == n_runs
assert len(self.duration) == n_runs
self.sfreq = sfreq
event_id = {ev: ii + 1 for ii, ev in enumerate(event_list)}
self.channels = channels
self.stim = stim
self.annotations = annotations
self.seed = seed
code = (
f"{code}-{paradigm.lower()}-{n_subjects}-{n_sessions}-{n_runs}-"
f"{''.join([re.sub('[^A-Za-z0-9]', '', e).lower() for e in event_list])}-"
f"{''.join([c.lower() for c in channels])}"
f"{code}-{paradigm.lower()}-{n_subjects}-{n_sessions}--"
f"{'-'.join([str(n) for n in self.n_events])}--"
f"{'-'.join([str(int(n)) for n in self.duration])}--"
f"{'-'.join([re.sub('[^A-Za-z0-9]', '', e).lower() for e in event_list])}--"
f"{'-'.join([c.lower() for c in channels])}"
)
super().__init__(
subjects=list(range(1, n_subjects + 1)),
@@ -77,29 +77,57 @@ def _get_single_subject_data(self, subject):
data = dict()
for session in range(self.n_sessions):
data[f"session_{session}"] = {
f"run_{ii}": self._generate_raw() for ii in range(self.n_runs)
f"run_{ii}": self._generate_raw(n, d)
for ii, (n, d) in enumerate(zip(self.n_events, self.duration))
}
return data

def _generate_raw(self):
montage = make_standard_montage("standard_1005")
sfreq = self.sfreq
duration = len(self.event_id) * 60
eeg_data = 2e-5 * np.random.randn(duration * sfreq, len(self.channels))
y = np.zeros((duration * sfreq))
def _generate_events(self, n_events, duration):
start = max(0, int(self.interval[0] * self.sfreq)) + 1
stop = (
min(
int((duration - self.interval[1]) * self.sfreq),
int(duration * self.sfreq),
)
- 1
)
onset = np.linspace(start, stop, n_events)
events = np.zeros((n_events, 3), dtype="int32")
events[:, 0] = onset
for ii, ev in enumerate(self.event_id):
start_idx = (1 + 5 * ii) * 128
jump = 5 * len(self.event_id) * 128
y[start_idx::jump] = self.event_id[ev]

ch_types = ["eeg"] * len(self.channels) + ["stim"]
ch_names = list(self.channels) + ["stim"]
events[ii :: len(self.event_id), 2] = self.event_id[ev]
return events
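
As a rough check of the onset spacing computed above, with purely illustrative numbers (``interval`` comes from the paradigm attached to the dataset):

import numpy as np

# Assumed values: interval=(0, 3), sfreq=128, duration=10 s, n_events=4.
sfreq, duration, n_events, interval = 128, 10, 4, (0, 3)
start = max(0, int(interval[0] * sfreq)) + 1  # 1
stop = min(int((duration - interval[1]) * sfreq), int(duration * sfreq)) - 1  # 895
onsets = np.linspace(start, stop, n_events).astype(int)  # array([  1, 299, 597, 895])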

eeg_data = np.c_[eeg_data, y]
def _generate_raw(self, n_events, duration):
montage = make_standard_montage("standard_1005")
sfreq = self.sfreq
eeg_data = 2e-5 * np.random.randn(int(duration * sfreq), len(self.channels))
events = self._generate_events(n_events, duration)
ch_types = ["eeg"] * len(self.channels)
ch_names = list(self.channels)

if self.stim:
y = np.zeros(eeg_data.shape[0])
y[events[:, 0]] = events[:, 2]
ch_types += ["stim"]
ch_names += ["stim"]
eeg_data = np.c_[eeg_data, y]

info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
raw = RawArray(data=eeg_data.T, info=info, verbose=False)
raw.set_montage(montage)

if self.annotations:
event_desc = {v: k for k, v in self.event_id.items()}
if len(events) != 0:
annotations = annotations_from_events(
events, sfreq=sfreq, event_desc=event_desc
)
annotations.set_durations(self.interval[1] - self.interval[0])
else:
annotations = Annotations([], [], [])
raw.set_annotations(annotations)

return raw

def data_path(
@@ -117,6 +117,8 @@ class FakeVirtualRealityDataset(FakeDataset):
def __init__(self, seed=None):
self.n_blocks = 5
self.n_repetitions = 12
self.n_events_rep = [60] * self.n_repetitions
self.duration_rep = [120] * self.n_repetitions
super().__init__(
n_sessions=1,
n_runs=self.n_blocks * self.n_repetitions,
@@ -125,6 +176,10 @@ def __init__(self, seed=None):
event_list=dict(Target=2, NonTarget=1),
paradigm="p300",
seed=seed,
duration=self.duration_rep * self.n_blocks,
n_events=self.n_events_rep * self.n_blocks,
stim=True,
annotations=False,
)

def _get_single_subject_data(self, subject):
@@ -134,10 +189,12 @@ def _get_single_subject_data(self, subject):
for session in range(self.n_sessions):
data[f"{session}"] = {}
for block in range(self.n_blocks):
for repetition in range(self.n_repetitions):
data[f"{session}"][
block_rep(block, repetition)
] = self._generate_raw()
for repetition, (n, d) in enumerate(
zip(self.n_events_rep, self.duration_rep)
):
data[f"{session}"][block_rep(block, repetition)] = self._generate_raw(
n, d
)
return data

def get_block_repetition(self, paradigm, subjects, block_list, repetition_list):
107 changes: 61 additions & 46 deletions moabb/datasets/preprocessing.py
@@ -21,6 +21,15 @@ def _is_none_pipeline(pipeline):
)


def _unsafe_pick_events(events, include):
try:
return mne.pick_events(events, include=include)
except RuntimeError as e:
if str(e) == "No events found":
return np.zeros((0, 3), dtype="int32")
raise e
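
A small sketch of the intended contract of this helper, assuming it can be imported from moabb.datasets.preprocessing (it is a private module-level function):

import numpy as np
from moabb.datasets.preprocessing import _unsafe_pick_events

no_events = np.zeros((0, 3), dtype="int32")
picked = _unsafe_pick_events(no_events, include=[1, 2])
assert picked.shape == (0, 3)  # an empty selection is returned instead of raising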


class ForkPipelines(TransformerMixin, BaseEstimator):
def __init__(self, transformers: List[Tuple[str, Union[Pipeline, TransformerMixin]]]):
for _, t in transformers:
@@ -41,6 +50,10 @@ def fit(self, X, y=None):


class SetRawAnnotations(FixedTransformer):
"""
Always sets the annotations, even if the events list is empty
"""

def __init__(self, event_id, durations: Union[float, Dict[str, float]]):
assert isinstance(event_id, dict) # not None
self.event_id = event_id
@@ -54,70 +67,72 @@ def transform(self, raw, y=None):
return raw
stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
if len(stim_channels) == 0:
raise ValueError("Need either a stim channel or annotations")
log.warning(
"No stim channel nor annotations found, skipping setting annotations."
)
return raw
events = mne.find_events(raw, shortest_event=0, verbose=False)
# we don't catch the error if no event found:
events = mne.pick_events(events, include=list(self.event_id.values()))
annotations = mne.annotations_from_events(
events,
raw.info["sfreq"],
self.event_desc,
first_samp=raw.first_samp,
verbose=False,
)
annotations.set_durations(self.durations)
raw.set_annotations(annotations)
events = _unsafe_pick_events(events, include=list(self.event_id.values()))
if len(events) != 0:
annotations = mne.annotations_from_events(
events,
raw.info["sfreq"],
self.event_desc,
first_samp=raw.first_samp,
verbose=False,
)
annotations.set_durations(self.durations)
raw.set_annotations(annotations)
else:
log.warning("No events found, skipping setting annotations.")
return raw
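
A hedged sketch of the new non-raising behaviour (``raw`` is an assumed mne.io.Raw with neither a stim channel nor annotations):

from moabb.datasets.preprocessing import SetRawAnnotations

transformer = SetRawAnnotations(event_id={"Target": 2, "NonTarget": 1}, durations=1.0)
raw = transformer.transform(raw)  # logs a warning and returns raw unchanged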


class RawToEvents(FixedTransformer):
"""
Always returns an array of shape (n_events, 3), even if no events are found
"""

def __init__(self, event_id):
assert isinstance(event_id, dict) # not None
self.event_id = event_id

def transform(self, raw, y=None):
def _find_events(self, raw):
stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
if len(stim_channels) > 0:
# returns empty array if none found
events = mne.find_events(raw, shortest_event=0, verbose=False)
else:
events, _ = mne.events_from_annotations(
raw, event_id=self.event_id, verbose=False
)
try:
events = mne.pick_events(events, include=list(self.event_id.values()))
except RuntimeError:
# skip raw if no event found
return
try:
events, _ = mne.events_from_annotations(
raw, event_id=self.event_id, verbose=False
)
except ValueError as e:
if str(e) == "Could not find any of the events you specified.":
return np.zeros((0, 3), dtype="int32")
raise e
return events

def transform(self, raw, y=None):
events = self._find_events(raw)
return _unsafe_pick_events(events, list(self.event_id.values()))
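
Illustrative use of the hardened transformer (``raw`` is an assumed mne.io.Raw whose annotations contain none of the requested labels):

from moabb.datasets.preprocessing import RawToEvents

transformer = RawToEvents(event_id={"left_hand": 1, "right_hand": 2})
events = transformer.transform(raw)  # shape (n_matching, 3); (0, 3) when nothing matches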

class RawToEventsP300(FixedTransformer):
def __init__(self, event_id):
assert isinstance(event_id, dict) # not None
self.event_id = event_id

class RawToEventsP300(RawToEvents):
def transform(self, raw, y=None):
events = self._find_events(raw)
event_id = self.event_id
stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
if len(stim_channels) > 0:
events = mne.find_events(raw, shortest_event=0, verbose=False)
else:
events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False)
try:
if "Target" in event_id and "NonTarget" in event_id:
if (
type(event_id["Target"]) is list
and type(event_id["NonTarget"]) == list
):
event_id_new = dict(Target=1, NonTarget=0)
events = mne.merge_events(events, event_id["Target"], 1)
events = mne.merge_events(events, event_id["NonTarget"], 0)
event_id = event_id_new
events = mne.pick_events(events, include=list(event_id.values()))
except RuntimeError:
# skip raw if no event found
return
return events
if (
"Target" in event_id
and "NonTarget" in event_id
and type(event_id["Target"]) is list
and type(event_id["NonTarget"]) is list
):
event_id_new = dict(Target=1, NonTarget=0)
events = mne.merge_events(events, event_id["Target"], 1)
events = mne.merge_events(events, event_id["NonTarget"], 0)
event_id = event_id_new
return _unsafe_pick_events(events, list(event_id.values()))
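
The list-valued ``event_id`` branch above merges raw codes onto the canonical Target/NonTarget labels before picking, roughly as in this sketch (codes are illustrative):

import numpy as np
import mne

events = np.array([[100, 0, 10], [200, 0, 21], [300, 0, 11]])
events = mne.merge_events(events, [10, 11], 1)  # all Target codes -> 1
events = mne.merge_events(events, [20, 21], 0)  # all NonTarget codes -> 0
# events[:, 2] is now [1, 0, 1]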


class RawToFixedIntervalEvents(FixedTransformer):
@@ -195,7 +210,7 @@ def __init__(
def transform(self, X, y=None):
raw = X["raw"]
events = X["events"]
if events is None or len(events) == 0:
if len(events) == 0:
raise ValueError("No events found")
if not isinstance(raw, mne.io.BaseRaw):
raise ValueError("raw must be a mne.io.BaseRaw")
8 changes: 4 additions & 4 deletions moabb/tests/benchmark.py
@@ -21,10 +21,10 @@ def test_benchmark_strdataset(self):
pipelines=str(self.pp_dir),
evaluations=["WithinSession"],
include_datasets=[
"FakeDataset-imagery-10-2-2-lefthandrighthand-c3czc4",
"FakeDataset-p300-10-2-2-targetnontarget-c3czc4",
"FakeDataset-ssvep-10-2-2-1315-c3czc4",
"FakeDataset-cvep-10-2-2-1000-c3czc4",
"FakeDataset-imagery-10-2--60-60--120-120--lefthand-righthand--c3-cz-c4",
"FakeDataset-p300-10-2--60-60--120-120--target-nontarget--c3-cz-c4",
"FakeDataset-ssvep-10-2--60-60--120-120--13-15--c3-cz-c4",
"FakeDataset-cvep-10-2--60-60--120-120--10-00--c3-cz-c4",
],
overwrite=True,
)