Additional columns that should be evaluated #127

Merged: 6 commits, Feb 1, 2021
26 changes: 21 additions & 5 deletions moabb/analysis/results.py

@@ -39,7 +39,7 @@ class Results:
     '''
 
     def __init__(self, evaluation_class, paradigm_class, suffix='',
-                 overwrite=False, hdf5_path=None):
+                 overwrite=False, hdf5_path=None, additional_columns=None):
         """
         class that will abstract result storage
         """
@@ -49,6 +49,12 @@ class that will abstract result storage
         assert issubclass(evaluation_class, BaseEvaluation)
         assert issubclass(paradigm_class, BaseParadigm)
 
+        if additional_columns is None:
+            self.additional_columns = []
+        else:
+            assert all([isinstance(ac, str) for ac in additional_columns])
+            self.additional_columns = additional_columns
+
         if hdf5_path is None:
             self.mod_dir = os.path.dirname(
                 os.path.abspath(inspect.getsourcefile(moabb)))
@@ -95,6 +101,7 @@ def to_list(res):
                 dlist = to_list(data_dict)
                 d1 = dlist[0]  # FIXME: handle multiple session ?
                 dname = d1['dataset'].code
+                n_add_cols = len(self.additional_columns)
                 if dname not in ppline_grp.keys():
                     # create dataset subgroup if nonexistant
                     dset = ppline_grp.create_group(dname)
@@ -103,11 +110,12 @@ def to_list(res):
                     dt = h5py.special_dtype(vlen=str)
                     dset.create_dataset('id', (0, 2), dtype=dt,
                                         maxshape=(None, 2))
-                    dset.create_dataset('data', (0, 3),
-                                        maxshape=(None, 3))
+                    dset.create_dataset('data', (0, 3 + n_add_cols),
+                                        maxshape=(None, 3 + n_add_cols))
                     dset.attrs['channels'] = d1['n_channels']
                     dset.attrs.create('columns',
-                                      ['score', 'time', 'samples'],
+                                      ['score', 'time', 'samples',
+                                       *self.additional_columns],
                                       dtype=dt)
                 dset = ppline_grp[dname]
                 for d in dlist:
@@ -117,9 +125,17 @@ def to_list(res):
                     dset['data'].resize(length, 0)
                     dset['id'][-1, :] = np.asarray([str(d['subject']),
                                                     str(d['session'])])
+                    try:
+                        add_cols = [d[ac] for ac in self.additional_columns]
+                    except KeyError:
+                        raise ValueError(
+                            f'Additional columns: {self.additional_columns} '
+                            f'were specified in the evaluation, but results'
+                            f' contain only these keys: {d.keys()}.')
                     dset['data'][-1, :] = np.asarray([d['score'],
                                                       d['time'],
-                                                      d['n_samples']])
+                                                      d['n_samples'],
+                                                      *add_cols])
 
     def to_dataframe(self, pipelines=None):
         df_list = []
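In short, each result row written to the HDF5 'data' dataset grows by one entry per requested column, in the same order as the 'columns' attribute created above. A minimal sketch of that mapping, using made-up column names and values that are not part of this PR:

# Illustration only: column names and values are invented for the example.
additional_columns = ['one', 'two']          # as passed to Results(...)
d = {'score': 0.87, 'time': 1.9, 'n_samples': 144,
     'one': 0.1, 'two': 3.0}                 # one result dict handed to Results.add()

# Mirrors the new code path: a missing key triggers the ValueError shown in the diff.
add_cols = [d[ac] for ac in additional_columns]
row = [d['score'], d['time'], d['n_samples'], *add_cols]
print(row)  # [0.87, 1.9, 144, 0.1, 3.0], same order as ['score', 'time', 'samples', 'one', 'two']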
5 changes: 3 additions & 2 deletions moabb/evaluations/base.py

@@ -34,7 +34,7 @@ class BaseEvaluation(ABC):
 
     def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1,
                  overwrite=False, error_score='raise', suffix='',
-                 hdf5_path=None):
+                 hdf5_path=None, additional_columns=None):
         self.random_state = random_state
         self.n_jobs = n_jobs
         self.error_score = error_score
@@ -85,7 +85,8 @@ def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1,
                               type(self.paradigm),
                               overwrite=overwrite,
                               suffix=suffix,
-                              hdf5_path=self.hdf5_path)
+                              hdf5_path=self.hdf5_path,
+                              additional_columns=additional_columns)
 
     def process(self, pipelines):
         '''Runs all pipelines on all datasets.
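On the evaluation side, the new keyword is accepted by the evaluation constructor and simply forwarded to Results. A hedged usage sketch; the paradigm, dataset slice, and CSP+LDA pipeline below are ordinary moabb/mne/sklearn pieces picked for illustration, not something this PR adds:

from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

paradigm = LeftRightImagery()
pipelines = {'CSP+LDA': make_pipeline(CSP(n_components=8), LDA())}

evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=paradigm.datasets[:1],      # keep the example run small
    additional_columns=['one', 'two'],   # extra per-result values to store
    suffix='examples', overwrite=True)

# process() will raise the ValueError from Results.add unless every result dict
# produced by the evaluation also contains the keys 'one' and 'two'.
results = evaluation.process(pipelines)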
17 changes: 17 additions & 0 deletions moabb/tests/evaluations.py

@@ -40,6 +40,23 @@ def test_eval_results(self):
         self.assertEqual(len(results), 4)
 
 
+class Test_AdditionalColumns(unittest.TestCase):
+
+    def setUp(self):
+        self.eval = ev.WithinSessionEvaluation(
+            paradigm=FakeImageryParadigm(), datasets=[dataset],
+            additional_columns=['one', 'two'])
+
+    def tearDown(self):
+        path = self.eval.results.filepath
+        if os.path.isfile(path):
+            os.remove(path)
+
+    def test_fails_if_nothing_returned(self):
+        self.assertRaises(ValueError, self.eval.process, pipelines)
+        # TODO Add custom evaluation that actually returns additional info
+
+
 class Test_CrossSubj(Test_WithinSess):
 
     def setUp(self):
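The TODO in the new test points at the missing positive-path check: an evaluation whose result dicts actually carry the extra keys. One possible sketch, with the caveat that the evaluate() signature and the injected values are assumptions made for illustration, not part of this PR:

# Hypothetical test helper: assumes evaluate(dataset, pipelines) yields one
# result dict per run, which is how Results.add consumes them in this diff.
class DummyAdditionalColumnsEvaluation(ev.WithinSessionEvaluation):

    def evaluate(self, dataset, pipelines):
        for res in super().evaluate(dataset, pipelines):
            res['one'] = 1.0   # values for the extra columns
            res['two'] = 2.0
            yield res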