Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/compound dataset target events #475

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Expand Up @@ -50,6 +50,7 @@ Enhancements
- Add option to plot scores vertically. (:gh:`417` by `Sara Sedlar`_)
- Increase the python version to 3.11 (:gh:`470` by `Bruno Aristimunha`_)
- Add match_all method in paradigm to support CompoundDataset evaluation with MNE epochs (:gh:`473` by `Gregoire Cattan`_)
- Automate setting of event_id in compound dataset and add `data_origin` information to the data (:gh:`475` by `Gregoire Cattan`_)

Bugs
~~~~
Expand Down
13 changes: 8 additions & 5 deletions moabb/datasets/base.py
Expand Up @@ -194,6 +194,13 @@ def __init__(
self.doi = doi
self.unit_factor = unit_factor

def _create_process_pipeline(self):
    """Build the default preprocessing pipeline for this dataset.

    Returns a single-step sklearn ``Pipeline`` whose RAW step stamps this
    dataset's ``event_id`` onto the raw annotations.
    """
    raw_step = (StepType.RAW, SetRawAnnotations(self.event_id))
    return Pipeline(steps=[raw_step])

def get_data(
self,
subjects=None,
Expand Down Expand Up @@ -256,11 +263,7 @@ def get_data(
cache_config = CacheConfig.make(cache_config)

if process_pipeline is None:
process_pipeline = Pipeline(
[
(StepType.RAW, SetRawAnnotations(self.event_id)),
]
)
process_pipeline = self._create_process_pipeline()

data = dict()
for subject in subjects:
Expand Down
50 changes: 42 additions & 8 deletions moabb/datasets/compound_dataset/base.py
@@ -1,5 +1,7 @@
"""Build a custom dataset using subjects from other datasets."""

from sklearn.pipeline import Pipeline

from ..base import BaseDataset


Expand Down Expand Up @@ -28,10 +30,6 @@ class CompoundDataset(BaseDataset):
sessions_per_subject: int
Number of sessions per subject (if varying, take minimum)

events: dict of strings
String codes for events matched with labels in the stim channel.
See `BaseDataset`.

code: string
Unique identifier for dataset, used in all plots

Expand All @@ -42,14 +40,13 @@ class CompoundDataset(BaseDataset):
Defines what sort of dataset this is
"""

def __init__(
self, subjects_list: list, events: dict, code: str, interval: list, paradigm: str
):
def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str):
self._set_subjects_list(subjects_list)
dataset, _, _, _ = self.subjects_list[0]
super().__init__(
subjects=list(range(1, self.count + 1)),
sessions_per_subject=self._get_sessions_per_subject(),
events=events,
events=dataset.event_id,
code=code,
interval=interval,
paradigm=paradigm,
Expand Down Expand Up @@ -81,6 +78,43 @@ def _set_subjects_list(self, subjects_list: list):
for compoundDataset in subjects_list:
self.subjects_list.extend(compoundDataset.subjects_list)

def _with_data_origin(self, data: dict, shopped_subject):
    """Wrap ``data`` so the originating (dataset, subject, sessions, runs)
    tuple is reachable under the key ``"data_origin"``.

    The key is resolved specially in ``__getitem__`` and is never inserted
    into the dict itself, so it does not appear in ``dict.keys()`` —
    otherwise code iterating over sessions and runs would break.
    """
    origin = self.subjects_list[shopped_subject - 1]

    class _OriginAwareDict(dict):
        def __getitem__(self, key):
            # Regular keys go through the normal dict lookup; only the
            # hidden "data_origin" key is intercepted.
            if key != "data_origin":
                return super().__getitem__(key)
            return origin

    return _OriginAwareDict(data)

def _get_single_subject_data_using_cache(
    self, shopped_subject, cache_config, process_pipeline
):
    """Load one subject's data, aligning event_id with its source dataset.

    A CompoundDataset aggregates subjects from several hidden datasets
    whose ``event_id`` mappings can differ, so before delegating to the
    parent implementation this method (1) adopts the source dataset's
    ``event_id`` and (2) propagates it into the processing pipeline.
    The returned dict additionally exposes a hidden ``"data_origin"`` key
    (see ``_with_data_origin``).
    """
    # Point this compound dataset's event_id at the hidden dataset the
    # requested subject comes from, since event_id can vary between datasets.
    dataset, _, _, _ = self.subjects_list[shopped_subject - 1]
    self.event_id = dataset.event_id

    # Rebuild the process_pipeline, overriding `event_id` on every step
    # that carries one (e.g. SetRawAnnotations) so annotations match.
    # NOTE(review): steps are mutated in place, then reassembled into a
    # fresh Pipeline wrapper.
    steps = []
    for step in process_pipeline.steps:
        label, op = step
        if hasattr(op, "event_id"):
            op.event_id = self.event_id
        steps.append((label, op))
    process_pipeline = Pipeline(steps)

    # Delegate the actual loading/caching to the parent class.
    data = super()._get_single_subject_data_using_cache(
        shopped_subject, cache_config, process_pipeline
    )
    return self._with_data_origin(data, shopped_subject)

def _get_single_subject_data(self, shopped_subject):
"""Return data for a single subject."""
dataset, subject, sessions, runs = self.subjects_list[shopped_subject - 1]
Expand Down
1 change: 0 additions & 1 deletion moabb/datasets/compound_dataset/bi_illiteracy.py
Expand Up @@ -13,7 +13,6 @@ def __init__(self, subjects_list, dataset=None, code=None):
CompoundDataset.__init__(
self,
subjects_list=subjects_list,
events=dict(Target=2, NonTarget=1),
code=code,
interval=[0, 1.0],
paradigm="p300",
Expand Down
43 changes: 39 additions & 4 deletions moabb/tests/datasets.py
Expand Up @@ -319,14 +319,19 @@ def test_fake_dataset(self):
subjects_list = [(self.ds, 1, sessions, runs)]
compound_data = CompoundDataset(
subjects_list,
events=dict(Target=2, NonTarget=1),
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

data = compound_data.get_data()

# Check event_id is correctly set
self.assertEqual(compound_data.event_id, self.ds.event_id)

# Check data origin is correctly set
self.assertEqual(data[1]["data_origin"], subjects_list[0])

# Check data type
self.assertTrue(isinstance(data, dict))
self.assertIsInstance(data[1]["session_0"]["run_0"], mne.io.BaseRaw)
Expand All @@ -348,7 +353,6 @@ def test_compound_dataset_composition(self):
subjects_list = [(self.ds, 1, None, None)]
compound_dataset = CompoundDataset(
subjects_list,
events=dict(Target=2, NonTarget=1),
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
Expand All @@ -358,7 +362,6 @@ def test_compound_dataset_composition(self):
subjects_list = [compound_dataset, compound_dataset]
compound_data = CompoundDataset(
subjects_list,
events=dict(Target=2, NonTarget=1),
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
Expand All @@ -382,7 +385,6 @@ def test_get_sessions_per_subject(self):
subjects_list = [(self.ds, 1, None, None), (self.ds2, 1, None, None)]
compound_dataset = CompoundDataset(
subjects_list,
events=dict(Target=2, NonTarget=1),
code="CompoundDataset",
interval=[0, 1],
paradigm=self.paradigm,
Expand All @@ -391,6 +393,39 @@ def test_get_sessions_per_subject(self):
# Test private method _get_sessions_per_subject returns the minimum number of sessions per subjects
self.assertEqual(compound_dataset._get_sessions_per_subject(), self.n_sessions)

def test_event_id_correctly_updated(self):
    """event_id of a CompoundDataset must track the hidden dataset that
    the currently fetched subject belongs to."""
    # A second fake dataset whose event labels differ from self.ds.
    self.ds2 = FakeDataset(
        n_sessions=self.n_sessions,
        n_runs=self.n_runs,
        n_subjects=self.n_subjects,
        event_list=["Target2", "NonTarget2"],
        paradigm=self.paradigm,
    )

    # Compound the two datasets, one subject from each.
    compound_dataset = CompoundDataset(
        [(self.ds, 1, None, None), (self.ds2, 1, None, None)],
        code="CompoundDataset",
        interval=[0, 1],
        paradigm=self.paradigm,
    )

    # Initially the compound event_id mirrors the first dataset's.
    self.assertEqual(compound_dataset.event_id, self.ds.event_id)

    # Fetching subject 2 (sourced from ds2) switches event_id to ds2's.
    data = compound_dataset.get_data(subjects=[2])
    self.assertEqual(compound_dataset.event_id, self.ds2.event_id)
    self.assertEqual(len(data), 1)

    # Fetching subject 1 switches event_id back to the first dataset's.
    data = compound_dataset.get_data(subjects=[1])
    self.assertEqual(compound_dataset.event_id, self.ds.event_id)
    self.assertEqual(len(data), 1)

def test_datasets_init(self):
codes = []
for ds in compound_dataset_list:
Expand Down
3 changes: 0 additions & 3 deletions tutorials/tutorial_5_build_a_custom_dataset.py
Expand Up @@ -63,7 +63,6 @@ def __init__(self):
CompoundDataset.__init__(
self,
subjects_list=subjects_list,
events=dict(Target=2, NonTarget=1),
code="D1",
interval=[0, 1.0],
paradigm="p300",
Expand All @@ -80,7 +79,6 @@ def __init__(self):
CompoundDataset.__init__(
self,
subjects_list=subjects_list,
events=dict(Target=2, NonTarget=1),
code="D2",
interval=[0, 1.0],
paradigm="p300",
Expand All @@ -103,7 +101,6 @@ def __init__(self):
CompoundDataset.__init__(
self,
subjects_list=subjects_list,
events=dict(Target=2, NonTarget=1),
code="D3",
interval=[0, 1.0],
paradigm="p300",
Expand Down