refactor API
alexandrebarachant committed Mar 7, 2018
1 parent 2b46678 commit ce5dbf5
Showing 5 changed files with 146 additions and 104 deletions.
12 changes: 8 additions & 4 deletions examples/MotorImagery/two_class_motor_imagery.py
@@ -9,7 +9,7 @@

from collections import OrderedDict
from moabb.datasets import utils
from moabb.viz import analyze
from moabb.analysis import analyze

import mne
mne.set_log_level(False)
@@ -18,13 +18,17 @@
has_all_events=True, min_subjects=2,
multi_session=False)

paradigm = LeftRightImagery()

context = WithinSessionEvaluation(paradigm=paradigm,
datasets=datasets,
random_state=42)

pipelines = OrderedDict()
pipelines['TS'] = make_pipeline(Covariances('oas'), TSclassifier())
pipelines['CSP+LDA'] = make_pipeline(Covariances('oas'), CSP(8), LDA())
pipelines['CSP+SVM'] = make_pipeline(Covariances('oas'), CSP(8), SVC()) #

context = LeftRightImagery(pipelines, WithinSessionEvaluation(), datasets)

results = context.process()
results = context.process(pipelines)

analyze(results, './')
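
For reference, the new usage pattern in this example collapses to the short script below. It is a sketch, not part of the diff: the import paths for LeftRightImagery and WithinSessionEvaluation are assumed from the module layout in this commit (moabb/paradigms, moabb/evaluations), and the pipeline blocks come from pyriemann and scikit-learn.

# Sketch of the refactored API: the paradigm describes the task, the
# evaluation owns the paradigm (and optionally the datasets), and
# process() now receives the pipelines dict.
from collections import OrderedDict

from pyriemann.estimation import Covariances
from pyriemann.spatialfilters import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

from moabb.paradigms import LeftRightImagery           # assumed import path
from moabb.evaluations import WithinSessionEvaluation  # assumed import path
from moabb.analysis import analyze

paradigm = LeftRightImagery()

# with datasets=None the evaluation falls back to paradigm.datasets
context = WithinSessionEvaluation(paradigm=paradigm, datasets=None,
                                  random_state=42)

pipelines = OrderedDict()
pipelines['CSP+LDA'] = make_pipeline(Covariances('oas'), CSP(8), LDA())

results = context.process(pipelines)
analyze(results, './')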
94 changes: 86 additions & 8 deletions moabb/evaluations/base.py
@@ -1,33 +1,111 @@
from abc import ABC, abstractmethod
import numpy as np

from sklearn.base import BaseEstimator

from moabb.analysis import Results
from moabb.datasets.base import BaseDataset
from moabb.paradigms.base import BaseParadigm


class BaseEvaluation(ABC):
'''Base class that defines necessary operations for an evaluation.
Evaluations determine what the train and test sets are and can implement
additional data preprocessing steps for more complicated algorithms.
random_state: if not None, can guarantee same seed
n_jobs: 1; number of jobs for fitting of pipeline
Parameters
----------
    paradigm : Paradigm instance
        The paradigm to use.
    datasets : List of Dataset instances, or None
        The list of datasets on which to run the evaluation. If None, the
        list of compatible datasets is retrieved from the paradigm instance.
    random_state : int, or None
        If not None, the random seed is fixed so results are reproducible.
    n_jobs : int, default 1
        Number of jobs used when fitting the pipelines.
'''

def __init__(self, random_state=None, n_jobs=1):
def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1):
"""
Init.
"""
        if random_state is None:
            self.random_state = np.random.randint(0, 1000, 1)[0]
        else:
            self.random_state = random_state
self.n_jobs = n_jobs

# check paradigm
if not isinstance(paradigm, BaseParadigm):
            raise(ValueError("paradigm must be a Paradigm instance"))
self.paradigm = paradigm

        # if no datasets are provided, get the list from the paradigm
if datasets is None:
datasets = self.paradigm.datasets

if not isinstance(datasets, list):
if isinstance(datasets, BaseDataset):
datasets = [datasets]
else:
raise(ValueError("datasets must be a list or a dataset "
"instance"))

for dataset in datasets:
if not(isinstance(dataset, BaseDataset)):
                raise(ValueError("datasets must only contain dataset "
                                 "instances"))

for dataset in datasets:
self.paradigm.verify(dataset)

self.datasets = datasets

def process(self, pipelines, overwrite=False, suffix=''):
'''
Runs tasks on all given datasets.
'''

# check pipelines
if not isinstance(pipelines, dict):
raise(ValueError("pipelines must be a dict"))

for name, pipeline in pipelines.items():
if not(isinstance(pipeline, BaseEstimator)):
                raise(ValueError("pipelines must only contain Pipeline "
                                 "instances"))

results = Results(type(self),
type(self.paradigm),
overwrite=overwrite,
suffix=suffix)

for dataset in self.datasets:
print('\n\nProcessing dataset: {}'.format(dataset.code))
self.preprocess_data(dataset)

for subject in dataset.subject_list:
# check if we already have result for this subject/pipeline
run_pipes = results.not_yet_computed(pipelines,
dataset,
subject)
if len(run_pipes) > 0:
try:
res = self.evaluate(dataset, subject, run_pipes)
results.add(res)
except Exception as e:
print(e)
print('Skipping subject {}'.format(subject))
return results

@abstractmethod
def evaluate(self, dataset, subject, clf, paradigm):
def evaluate(self, dataset, subject, pipelines):
'''
Return results in a dict
'''
pass

def preprocess_data(self, dataset, paradigm):
@abstractmethod
def preprocess_data(self, dataset):
'''
        Optional parameter if any sort of dataset-wide computation is needed
        per subject
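
To make the contract of the refactored base class concrete: a subclass now only has to implement evaluate(dataset, subject, pipelines) and preprocess_data(dataset), and it reaches the paradigm through self.paradigm instead of a method argument. The skeleton below is purely illustrative (the class name and return values are made up); the real evaluations also report timing, n_samples and n_channels.

from moabb.evaluations.base import BaseEvaluation  # path as in this commit


class DummyEvaluation(BaseEvaluation):
    """Illustrative skeleton -- not part of the commit."""

    def preprocess_data(self, dataset):
        # dataset-wide preparation (e.g. caching epochs); nothing to do here
        pass

    def evaluate(self, dataset, subject, pipelines):
        # return one result dict per pipeline name
        out = {}
        for name, clf in pipelines.items():
            out[name] = {'dataset': dataset, 'score': 0.0, 'time': 0.0}
        return out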
27 changes: 17 additions & 10 deletions moabb/evaluations/evaluations.py
@@ -48,7 +48,7 @@ class CrossSubjectEvaluation(TrainTestEvaluation):
"""

def evaluate(self, dataset, subject, pipelines, paradigm):
def evaluate(self, dataset, subject, pipelines):
# requires that subject be an int
s = subject-1
self.ind_cache[s] = self.ind_cache[s]*0
@@ -60,7 +60,7 @@ def evaluate(self, dataset, subject, pipelines, paradigm):
out = {}
for name, clf in pipelines.items():
t_start = time()
score = self.score(clf, allX, ally, groups, paradigm.scoring)
score = self.score(clf, allX, ally, groups, self.paradigm.scoring)
duration = time() - t_start
out[name] = {'time': duration,
'dataset': dataset,
@@ -70,7 +70,7 @@ def evaluate(self, dataset, subject, pipelines, paradigm):
'n_channels': allX.shape[1]}
return out

def preprocess_data(self, dataset, paradigm):
def preprocess_data(self, dataset):
assert len(dataset.subject_list) > 1, "Dataset {} has only one subject".format(
dataset.code)
self.X_cache = []
@@ -82,7 +82,7 @@ def preprocess_data(self, dataset, paradigm):
for s in dataset.subject_list:
sub = dataset.get_data([s], stack_sessions=True)[0]
# get all epochs for individual files in given subject
epochs = paradigm._epochs(sub, event_id, dataset.interval)
epochs = self.paradigm._epochs(sub, event_id, dataset.interval)
# equalize events from different classes
X, y = self.extract_data_from_cont(epochs, event_id)
self.X_cache.append(X)
@@ -103,7 +103,7 @@ class WithinSessionEvaluation(TrainTestEvaluation):
"""

def evaluate(self, dataset, subject, pipelines, paradigm):
def evaluate(self, dataset, subject, pipelines):
"""Prepare data for classification."""
event_id = dataset.selected_events
if not event_id:
@@ -116,14 +116,14 @@ def evaluate(self, dataset, subject, pipelines, paradigm):
# sess_id = '{:03d}_{:d}'.format(subject, ind)

# get all epochs for individual files in given session
epochs = paradigm._epochs(session, event_id, dataset.interval)
epochs = self.paradigm._epochs(session, event_id, dataset.interval)
X, y = self.extract_data_from_cont(epochs, event_id)
if len(np.unique(y)) > 1:
counts = np.unique(y,return_counts=True)[1]
print('score imbalance: {}'.format(counts))
for name, clf in pipelines.items():
t_start = time()
score = self.score(clf, X, y, paradigm.scoring)
score = self.score(clf, X, y, self.paradigm.scoring)
duration = time() - t_start
out[name].append({'time': duration,
'dataset': dataset,
@@ -142,6 +142,13 @@ def score(self, clf, X, y, scoring):
scoring=scoring, n_jobs=self.n_jobs)
return acc.mean()

def preprocess_data(self, dataset):
'''
        Optional parameter if any sort of dataset-wide computation is needed
per subject
'''
pass


class CrossSessionEvaluation(TrainTestEvaluation):
"""Cross session Context.
@@ -151,7 +158,7 @@ class CrossSessionEvaluation(TrainTestEvaluation):
"""

def evaluate(self, dataset, subject, pipelines, paradigm):
def evaluate(self, dataset, subject, pipelines):
event_id = dataset.selected_events
if not event_id:
raise(ValueError("Dataset had no selected events"))
@@ -161,7 +168,7 @@ def evaluate(self, dataset, subject, pipelines, paradigm):
listX, listy = ([], [])
for ind, session in enumerate(sub):
# get list epochs for individual files in given session
epochs = paradigm._epochs(session, event_id, dataset.interval)
epochs = self.paradigm._epochs(session, event_id, dataset.interval)
# equalize events from different classes
X, y = self.extract_data_from_cont(epochs, event_id)
listX.append(X)
@@ -175,7 +182,7 @@ def evaluate(self, dataset, subject, pipelines, paradigm):
out = {}
for name, clf in pipelines.items():
t_start = time()
score = self.score(clf, allX, ally, groupvec, paradigm.scoring)
score = self.score(clf, allX, ally, groupvec, self.paradigm.scoring)
duration = time() - t_start
out[name] = {'time': duration,
'dataset': dataset,
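
The score helpers above are thin wrappers around scikit-learn cross-validation; the call itself is collapsed in this view. Roughly, the within-session variant does something like the sketch below, where the splitter is an assumption because the actual cv object is not visible in the hunk.

from sklearn.model_selection import KFold, cross_val_score


def score(clf, X, y, scoring, n_jobs=1):
    # the splitter is an assumption; the commit's own cv choice sits in the
    # collapsed part of the diff
    cv = KFold(n_splits=5, shuffle=True, random_state=42)
    acc = cross_val_score(clf, X, y, cv=cv, scoring=scoring, n_jobs=n_jobs)
    return acc.mean()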
62 changes: 16 additions & 46 deletions moabb/paradigms/base.py
@@ -1,56 +1,13 @@
from abc import ABC, abstractmethod, abstractproperty
import numpy as np

from sklearn.base import BaseEstimator

from moabb.datasets.base import BaseDataset
from moabb.datasets import utils
from abc import ABC, abstractproperty, abstractmethod


class BaseParadigm(ABC):
"""Base Context.
Parameters
----------
datasets : List of Dataset instances, or None
List of dataset instances on which the pipelines will be evaluated.
If None, uses all datasets (and should break...)
pipelines : Dict of pipelines instances.
Dictionary of pipelines. Keys identifies pipeline names, and values
are scikit-learn pipelines instances.
evaluator: Evaluator instance
Instance that defines evaluation scheme
"""

def __init__(self, pipelines, evaluator, datasets=None):
def __init__(self):
"""init"""
self.evaluator = evaluator
if datasets is None:
datasets = utils.dataset_list
# check dataset
if not isinstance(datasets, list):
if isinstance(datasets, BaseDataset):
datasets = [datasets]
else:
raise(ValueError("datasets must be a list or a dataset "
"instance"))

for dataset in datasets:
if not(isinstance(dataset, BaseDataset)):
raise(ValueError("datasets must only contains dataset "
"instance"))

self.datasets = datasets

# check pipelines
if not isinstance(pipelines, dict):
raise(ValueError("pipelines must be a dict"))

for name, pipeline in pipelines.items():
if not(isinstance(pipeline, BaseEstimator)):
raise(ValueError("pipelines must only contains Pipelines "
"instance"))
self.pipelines = pipelines
pass

@abstractproperty
def scoring(self):
@@ -60,3 +17,16 @@ def scoring(self):
'''
pass

@abstractproperty
def datasets(self):
        '''Property that defines the list of compatible datasets
'''
pass

@abstractmethod
def verify(self, dataset):
'''
Method that verifies dataset is correct for given parameters
'''
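
The slimmed-down BaseParadigm no longer holds pipelines, datasets or an evaluator; a concrete paradigm only exposes a scoring string, a list of compatible datasets, and a verify() check that the evaluation calls for each dataset. A hypothetical minimal subclass, for illustration only:

from moabb.paradigms.base import BaseParadigm  # path as in this commit


class DummyParadigm(BaseParadigm):
    """Illustrative skeleton -- not part of the commit."""

    @property
    def scoring(self):
        # any scikit-learn scoring string works here
        return 'roc_auc'

    @property
    def datasets(self):
        # real paradigms would return compatible Dataset instances,
        # e.g. via moabb.datasets.utils; an empty list keeps the sketch short
        return []

    def verify(self, dataset):
        # raise if the dataset cannot be used with this paradigm
        pass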
