-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
ems.py
223 lines (180 loc) · 8.11 KB
/
ems.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Author: Denis Engemann <denis.engemann@gmail.com>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)
from collections import Counter
import numpy as np
from .mixin import TransformerMixin, EstimatorMixin
from .base import _set_cv
from ..utils import logger, verbose
from ..parallel import parallel_func
from .. import pick_types, pick_info
class EMS(TransformerMixin, EstimatorMixin):
    """Transformer that computes event-matched spatial filters (EMS).

    This variant of EMS [1]_ works on the full time course, so no time
    window has to be chosen. Fitting yields one spatial filter per time
    point; transforming yields, for every epoch, a single time course that
    reflects how similar the sensor pattern at each time point is to the
    filter of that time point.

    .. note:: EMS only works for binary classification.

    Attributes
    ----------
    ``filters_`` : ndarray, shape (n_channels, n_times)
        The spatial filters: the difference between the two class means,
        scaled to unit norm across channels at every time point.
    ``classes_`` : ndarray, shape (n_classes,)
        The sorted unique target values found in ``y`` during ``fit``.

    References
    ----------
    .. [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
           multi-sensor data to a single time course that reveals experimental
           effects", BMC Neuroscience 2013, 14:122
    """

    def __repr__(self):  # noqa: D105
        # ``filters_`` only exists once ``fit`` has completed.
        if not hasattr(self, 'filters_'):
            return '<EMS: not fitted.>'
        n_filters = len(self.filters_)
        n_classes = len(self.classes_)
        return '<EMS: fitted with %i filters on %i classes.>' % (
            n_filters, n_classes)

    def fit(self, X, y):
        """Estimate the spatial filters from labelled epochs.

        .. note:: This method does not rescale the data itself; any
                  normalization by channel type has to be applied to ``X``
                  before calling it.

        Parameters
        ----------
        X : array, shape (n_epochs, n_channels, n_times)
            The training data.
        y : array of int, shape (n_epochs)
            The target classes. Exactly two distinct values are required.

        Returns
        -------
        self : instance of EMS
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``y`` does not contain exactly two distinct classes.
        """
        labels = np.unique(y)
        if len(labels) != 2:
            raise ValueError('EMS only works for binary classification.')
        self.classes_ = labels

        # Contrast of the class averages (first sorted class minus second).
        first, second = labels
        mean_first = X[y == first].mean(0)
        mean_second = X[y == second].mean(0)
        contrast = mean_first - mean_second

        # Normalize across channels, independently at each time point.
        contrast /= np.linalg.norm(contrast, axis=0)[None, :]
        self.filters_ = contrast
        return self

    def transform(self, X):
        """Project the data onto the fitted spatial filters.

        Parameters
        ----------
        X : array, shape (n_epochs, n_channels, n_times)
            The input data.

        Returns
        -------
        X : array, shape (n_epochs, n_times)
            For each epoch and time point, the sum over channels of the
            data multiplied by the filter of that time point.
        """
        weighted = X * self.filters_
        return np.sum(weighted, axis=1)
@verbose
def compute_ems(epochs, conditions=None, picks=None, n_jobs=1, verbose=None,
                cv=None):
    """Compute event-matched spatial filter on epochs.

    This version of EMS [1]_ operates on the entire time course. No time
    window needs to be specified. The result is a spatial filter at each
    time point and a corresponding time course. Intuitively, the result
    gives the similarity between the filter at each time point and the
    data vector (sensors) at that time point.

    .. note:: EMS only works for binary classification.

    .. note:: By default the function applies a leave-one-out
              cross-validation, following Schurger et al's paper. However,
              we recommend using a stratified k-fold cross-validation.
              Indeed, leave-one-out tends to overfit and cannot be used to
              estimate the variance of the prediction within a given fold.

    .. note:: Every event id present in ``epochs.events`` must occur the
              same number of times; this is checked on the input epochs
              before ``conditions`` are selected and bad epochs dropped.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs.
    conditions : list of str | None, defaults to None
        If a list of strings, the strings must be keys of
        ``epochs.event_id`` and there must be exactly two of them.
        If None, the keys of ``epochs.event_id`` are used.
    picks : array-like of int | None, defaults to None
        Channels to be included. If None, MEG and EEG channels are selected
        with ``pick_types(epochs.info, meg=True, eeg=True)``.
    n_jobs : int, defaults to 1
        Number of jobs to run in parallel.
    verbose : bool, str, int, or None, defaults to self.verbose
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).
    cv : cross-validation object | str | None, defaults to LeaveOneOut
        The cross-validation scheme. Only the first test index of every
        split is used to build a surrogate trial.

    Returns
    -------
    surrogate_trials : ndarray, shape (n_splits, n_times)
        The trial surrogates, one per cross-validation split (i.e. one per
        epoch with the default leave-one-out scheme).
    mean_spatial_filter : ndarray, shape (n_channels, n_times)
        The spatial filters averaged over the cross-validation splits.
    conditions : ndarray, shape (n_epochs,)
        The event id of every epoch that was used
        (``epochs.events[:, 2]`` after selection and bad-epoch dropping).
        Values correspond to original event ids.

    Raises
    ------
    ValueError
        If the event ids of ``epochs`` do not all have the same number of
        epochs, or if the number of conditions is not exactly 2.

    References
    ----------
    .. [1] Aaron Schurger, Sebastien Marti, and Stanislas Dehaene, "Reducing
           multi-sensor data to a single time course that reveals experimental
           effects", BMC Neuroscience 2013, 14:122
    """
    logger.info('...computing surrogate time series. This can take some time')

    # Default to leave-one-out cv
    cv = 'LeaveOneOut' if cv is None else cv

    if picks is None:
        picks = pick_types(epochs.info, meg=True, eeg=True)

    if not len(set(Counter(epochs.events[:, 2]).values())) == 1:
        raise ValueError('The same number of epochs is required by '
                         'this function. Please consider '
                         '`epochs.equalize_event_counts`')

    if conditions is None:
        conditions = epochs.event_id.keys()
        # Work on a copy so that dropping bad epochs below does not modify
        # the caller's object.
        epochs = epochs.copy()
    else:
        epochs = epochs[conditions]

    epochs.drop_bad()

    if len(conditions) != 2:
        raise ValueError('Currently this function expects exactly 2 '
                         'conditions but you gave me %i' %
                         len(conditions))

    # Event id of every remaining epoch; also used as the CV labels and
    # returned to the caller.
    ev = epochs.events[:, 2]
    # Special care to avoid path dependent mappings and orders
    conditions = list(sorted(conditions))
    cond_idx = [np.where(ev == epochs.event_id[k])[0] for k in conditions]

    info = pick_info(epochs.info, picks)
    data = epochs.get_data()[:, picks]

    # Scale the data by channel type: each type is divided by one standard
    # deviation pooled over all its epochs, channels and time points (the
    # mean is not removed).
    # XXX the scaling is applied outside the CV, which is not standard.
    for ch_type in ['mag', 'grad', 'eeg']:
        if ch_type in epochs:
            # FIXME should be applied to all sort of data channels
            if ch_type == 'eeg':
                this_picks = pick_types(info, meg=False, eeg=True)
            else:
                this_picks = pick_types(info, meg=ch_type, eeg=False)
            if len(this_picks) == 0:
                # The type exists in ``epochs`` but was excluded by
                # ``picks``: nothing to scale, and np.std of an empty
                # selection would only produce NaN and a RuntimeWarning.
                continue
            data[:, this_picks] /= np.std(data[:, this_picks])

    # Setup cross-validation. Need to use _set_cv to deal with sklearn
    # deprecation of cv objects.
    _, cv_splits = _set_cv(cv, 'classifier', X=ev, y=ev)

    parallel, p_func, _ = parallel_func(_run_ems, n_jobs=n_jobs)
    # FIXME this parallelization should be removed.
    #   1) it's numpy computation so it's already efficient,
    #   2) it duplicates the data in RAM,
    #   3) the computation is already super fast.
    out = parallel(p_func(_ems_diff, data, cond_idx, train, test)
                   for train, test in cv_splits)

    # Each job returns (surrogate time course, spatial filter) for one split.
    surrogate_trials, spatial_filter = zip(*out)
    surrogate_trials = np.array(surrogate_trials)
    spatial_filter = np.mean(spatial_filter, axis=0)

    return surrogate_trials, spatial_filter, ev
def _ems_diff(data0, data1):
    """Return the default EMS objective: mean of data0 minus mean of data1.

    Both means are taken along the first axis of their input.
    """
    mean0 = np.mean(data0, axis=0)
    mean1 = np.mean(data1, axis=0)
    return mean0 - mean1
def _run_ems(objective_function, data, cond_idx, train, test):
    """Fit an EMS filter on the training epochs and project one test epoch.

    ``objective_function`` receives one array per entry of ``cond_idx``,
    holding the rows of ``data`` whose index is both in that entry and in
    ``train``. Its result is normalized in place to unit Euclidean norm
    along the first axis. Only ``data[test[0]]`` is projected; any further
    test indices are ignored.

    Returns a tuple ``(surrogate, filter)``: the projection summed over the
    first axis, and the normalized filter.
    """
    # Training epochs of each condition for this split.
    train_groups = [data[np.intersect1d(idx, train)] for idx in cond_idx]
    filt = objective_function(*train_groups)

    # Unit-norm scaling along axis 0, done in place on the filter.
    col_norms = np.sqrt(np.sum(filt ** 2, axis=0))
    filt /= col_norms[None, :]

    # compute surrogates
    held_out = data[test[0]]
    surrogate = np.sum(held_out * filt, axis=0)
    return surrogate, filt