# Multivariate statistics (decoding / MVPA) on MEG/EEG

Author: Alexandre Gramfort

See more info on decoding on this page: https://martinos.org/mne/stable/auto_tutorials/plot_sensors_decoding.html

In [None]:
# display plots inline in the page
%matplotlib inline
import matplotlib.pyplot as plt

First, load the mne package:

In [None]:
import mne

We set the log level to 'WARNING' so the output is less verbose:

In [None]:
mne.set_log_level('WARNING')

## Access raw data

Now we import the sample dataset. If you don't already have it, it will be downloaded automatically (be patient: it is approx. 2 GB).

In [None]:
from mne.datasets import sample
data_path = sample.data_path()

#data_path = '/Users/alex/mne_data/MNE-sample-data'

# str() keeps this working whether data_path is a str or a pathlib.Path
raw_fname = str(data_path) + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)
raw

High-pass filter the data at 1 Hz:

In [None]:
raw.filter(1, None)

In [None]:
print(raw.info)

## Define and read epochs

First extract events:

In [None]:
events = mne.find_events(raw, stim_channel='STI 014', verbose=True)

Look at the design in a graphical way:

In [None]:
mne.viz.plot_events(events, raw.info['sfreq'], raw.first_samp);

## From raw to epochs

Define epochs parameters:

In [None]:
#event_id = dict(aud_l=1, aud_r=2)  # event trigger and conditions
event_id = {'aud_l': 1, 'aud_r': 2}  # event trigger and conditions
tmin = -0.1  # start of each epoch
tmax = 0.4  # end of each epoch
baseline = None  # no baseline correction, as the data were high-pass filtered

reject = dict(eeg=80e-6, eog=40e-6)

picks = mne.pick_types(raw.info, eeg=True, meg=True,
                       eog=True, stim=False, exclude='bads')

epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                    picks=picks, baseline=baseline,
                    reject=reject, preload=True)  # with preload

print(epochs)

Look at the ERF for each condition and at the contrast between the left and right responses:

In [None]:
evoked_left = epochs['aud_l'].average()
evoked_right = epochs['aud_r'].average()
evoked_contrast = mne.combine_evoked([evoked_left, evoked_right],
                                     [0.5, -0.5])

In [None]:
fig = evoked_left.plot()
fig = evoked_right.plot()
fig = evoked_contrast.plot()

Plot some topographies

In [None]:
vmin, vmax = -4, 4
fig = evoked_left.plot_topomap(ch_type='eeg', contours=0, vmin=vmin, vmax=vmax)
fig = evoked_right.plot_topomap(ch_type='eeg', contours=0, vmin=vmin, vmax=vmax)
fig = evoked_contrast.plot_topomap(ch_type='eeg', contours=0, vmin=None, vmax=None)

## Now let's see if we can classify single trials

So that the chance level is 50% accuracy, equalize the number of epochs in each condition:

In [None]:
epochs.equalize_event_counts(event_id)
print(epochs)
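
Why balanced classes put chance at 50% can be checked on made-up labels. The sketch below is only an illustration: `majority_class_accuracy`, `y_imbalanced` and `y_balanced` are hypothetical names that are not part of the notebook or of MNE, and the label counts are invented.

```python
import numpy as np


def majority_class_accuracy(y):
    """Accuracy of a classifier that always predicts the most frequent label."""
    _, counts = np.unique(y, return_counts=True)
    return counts.max() / counts.sum()


# Hypothetical label vectors, not taken from the sample dataset
y_imbalanced = np.array([0] * 70 + [1] * 30)
y_balanced = np.array([0] * 50 + [1] * 50)

print(majority_class_accuracy(y_imbalanced))  # 0.7
print(majority_class_accuracy(y_balanced))  # 0.5
```

With unequal counts, a classifier that ignores the data already beats 50%, so an accuracy above 50% would not by itself show that the sensors carry information about the condition.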

A classifier takes an input `x` and returns a label `y` (0 or 1). Here `x` holds the data from all gradiometers, which is why the approach is called multivariate: we first use every time point of an epoch flattened into one vector, and later the data at a single time point. We work with all sensors jointly and try to find a pattern that discriminates between the 2 conditions to predict the class.

For classification we will use the scikit-learn package (http://scikit-learn.org/) together with MNE functions.

Reference: Pedregosa et al., Scikit-learn: Machine Learning in Python, JMLR 12, pp. 2825-2830, 2011.

In [None]:
import numpy as np
# make response vector
y = np.zeros(len(epochs.events), dtype=int)
y[epochs.events[:, 2] == 2] = 1

y.size

In [None]:
X = epochs.copy().pick_types(meg='grad').get_data()
X.shape

In [None]:
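# flatten channels x times into one feature vector per epoch,
# without hard-coding the number of epochs
XX = X.reshape(len(X), -1)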
XX.shape
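
To see what this flattening does, here is a small sketch on a synthetic array. The sizes and the names `X_demo` and `XX_demo` are assumptions made for illustration; they are unrelated to the sample dataset.

```python
import numpy as np

# Hypothetical sizes standing in for (n_epochs, n_channels, n_times)
n_epochs, n_channels, n_times = 4, 3, 5
X_demo = np.arange(n_epochs * n_channels * n_times).reshape(
    n_epochs, n_channels, n_times)

# One row per epoch; all channels and time points become features
XX_demo = X_demo.reshape(len(X_demo), -1)
print(XX_demo.shape)  # (4, 15)

# Each row is the epoch's channels laid end to end, time points within a channel
print(np.array_equal(XX_demo[0], X_demo[0].ravel()))  # True
```

The number of features is therefore `n_channels * n_times`, which is usually much larger than the number of epochs.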

In [None]:
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.linear_model import LogisticRegression

logreg = LogisticRegression(C=1e6, solver='liblinear')
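# random_state only has an effect with shuffle=True, and recent scikit-learn
# versions raise an error if it is set without shuffling
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)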
scores = cross_val_score(logreg, XX[2:], y[2:], cv=cv, scoring='roc_auc')
print(scores)
print('AUC = %0.3f (std %.3f)' % (np.mean(scores), np.std(scores)))

In [None]:
plt.hist(scores, bins=20)

Now we can do this more simply using the `mne.decoding` module:

In [None]:
from sklearn.pipeline import make_pipeline
from mne.decoding import Scaler, Vectorizer, cross_val_multiscore

epochs_decoding = epochs.copy().pick_types(meg='grad')

clf = make_pipeline(Scaler(epochs_decoding.info),
                    Vectorizer(),
                    logreg)

X = epochs_decoding.get_data()
y = epochs_decoding.events[:, 2]

scores = cross_val_multiscore(clf, X, y, cv=5, n_jobs=1)

# Mean scores across cross-validation splits
score = np.mean(scores, axis=0)
print('Spatio-temporal: %0.1f%%' % (100 * score,))

## Decoding over time

In [None]:
from sklearn.preprocessing import StandardScaler
from mne.decoding import SlidingEstimator

clf = make_pipeline(StandardScaler(), logreg)

time_decod = SlidingEstimator(clf, n_jobs=1, scoring='roc_auc', verbose=True)
scores = cross_val_multiscore(time_decod, X, y, cv=5, n_jobs=1)

# Mean scores across cross-validation splits
scores = np.mean(scores, axis=0)

# Plot
fig, ax = plt.subplots()
ax.plot(epochs.times, scores, label='score')
ax.axhline(.5, color='k', linestyle='--', label='chance')
ax.set_xlabel('Time (s)')
ax.set_ylabel('AUC')  # Area Under the Curve
ax.legend()
ax.axvline(.0, color='k', linestyle='-')
ax.set_title('Sensor space decoding')
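
Conceptually, decoding over time fits and scores a separate classifier at each time point. The sketch below mimics that idea with plain scikit-learn on synthetic data. It is not the implementation of `SlidingEstimator`, and the array sizes, the injected effect and every `*_demo` name are assumptions made for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
n_epochs, n_channels, n_times = 80, 10, 6  # hypothetical sizes
X_demo = rng.randn(n_epochs, n_channels, n_times)
y_demo = np.repeat([0, 1], n_epochs // 2)

# Make the two classes differ only at the last three time points
X_demo[y_demo == 1, :, 3:] += 1.0

clf_demo = make_pipeline(StandardScaler(),
                         LogisticRegression(solver='liblinear'))

# One cross-validated AUC per time point, using only the data at that time
scores_demo = np.array([
    cross_val_score(clf_demo, X_demo[:, :, t], y_demo,
                    cv=5, scoring='roc_auc').mean()
    for t in range(n_times)
])
print(scores_demo.round(2))
```

In this toy setup the scores should stay near 0.5 where the classes do not differ and rise well above it where the effect was injected, which is the pattern to look for in the plot of the real data.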

For more details see: https://martinos.org/mne/stable/auto_tutorials/plot_sensors_decoding.html

and this book chapter:

Jean-Rémi King, Laura Gwilliams, Chris Holdgraf, Jona Sassenhagen, Alexandre Barachant, Denis Engemann, Eric Larson, Alexandre Gramfort. Encoding and Decoding Neuronal Dynamics: Methodological Framework to Uncover the Algorithms of Cognition. 2018. https://hal.archives-ouvertes.fr/hal-01848442/

<div class="alert alert-success">
    <b>EXERCISE</b>:
     <ul>
      <li>Do a time-by-time decoding on the SPM face dataset to see if you can classify faces vs. scrambled faces.</li>
      <li>Do a generalization over time analysis as explained in the <a href="https://martinos.org/mne/dev/auto_tutorials/plot_sensors_decoding.html?highlight=generalizingestimator#temporal-generalization">documentation on decoding</a>.</li>
    </ul>
</div>

Example using the SPM face dataset: https://martinos.org/mne/dev/auto_examples/datasets/spm_faces_dataset.html