In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Motor imagery decoding from EEG data using the Common Spatial Pattern (CSP)

Authors: Martin Billinger

Decoding of motor imagery from EEG data decomposed using CSP.
Here the classifier is applied to features extracted from the CSP-filtered signals.

See https://en.wikipedia.org/wiki/Common_spatial_pattern and [1].

The EEGBCI dataset is documented in [2]
The data set is available at PhysioNet [3]

[1] Zoltan J. Koles. The quantitative extraction and topographic mapping
    of the abnormal components in the clinical EEG. Electroencephalography
    and Clinical Neurophysiology, 79(6):440--447, December 1991.

[2] Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N.,
    Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface
    (BCI) System. IEEE TBME 51(6):1034-1043

[3] Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG,
    Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank,
    PhysioToolkit, and PhysioNet: Components of a New Research Resource for
    Complex Physiologic Signals. Circulation 101(23):e215-e220



In [None]:
import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.channels import read_layout
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

In [None]:
# Set parameters and read data

# Epochs span 1 s before to 4 s after cue onset (tmin=-1, tmax=4) so the
# running classifier further down can be evaluated across the whole trial.
# To avoid classifying evoked responses, the classifier is *trained* only on
# the 1-2 s post-cue window (see the `crop` in the epoching cell).
tmin, tmax = -1., 4.
subject = 1
runs = [6, 10, 14]  # motor imagery: hands vs feet

# Fetch the EEGBCI recordings for this subject and these runs; returns local
# file paths (presumably downloading them on first use).
raw_fnames = eegbci.load_data(subject, runs)
raw_fnames

In [None]:
# Load every run fully into memory, then join them into one continuous Raw.
raw_files = []
for fname in raw_fnames:
    raw_files.append(read_raw_edf(fname, preload=True))
raw = concatenate_raws(raw_files)

In [None]:
# Set channel locations to plot topographies.
# Strip the '.' characters from the channel names so they can be matched
# against the names in the standard 10-05 montage.
raw.rename_channels(lambda x: x.strip('.'))
montage = make_standard_montage('standard_1005')
# Match channel names case-insensitively (the dataset's capitalisation
# presumably differs from the montage's — verify). The flag is passed by
# keyword so its meaning is explicit; this requires an MNE release whose
# `Raw.set_montage` accepts `match_case`.
raw.set_montage(montage, match_case=False)

In [None]:
# Summary repr of the concatenated recording.
raw

In [None]:
# Measurement metadata; later cells read the sampling frequency from it.
raw.info

In [None]:
# Annotations attached to the recording; the event array used for epoching is
# derived from these via `mne.events_from_annotations`.
raw.annotations

In [None]:
# Only the description strings of the annotations (the labels that are turned
# into event codes).
raw.annotations.description

In [None]:
# Browse the recording in an interactive Qt window.
# NOTE(review): needs a Qt backend and a display, so this cell fails on a
# headless "Run All"; the following cell switches back to inline figures.
%matplotlib qt
raw.plot()

In [None]:
%matplotlib inline
# strip channel names of "." characters
raw.rename_channels(lambda x: x.strip('.'))

In [None]:
# Channel names after stripping the '.' characters.
raw.ch_names

In [None]:
# Apply band-pass filter (7-30 Hz, IIR).
# NOTE: `raw` is filtered in place, so re-running this cell filters it again.
raw.filter(7., 30., method='iir')

# Turn the annotations into an events array. The returned annotation->code
# mapping is discarded; the codes hard-coded in `event_id` in the epoching
# cell (hands=2, feet=3) must agree with it.
events, _ = mne.events_from_annotations(raw)

In [None]:
# First ten rows of the events array; the last column holds the event code
# that the class labels are derived from.
events[:10]

In [None]:
# Motor-imagery classes and their event codes (from `events_from_annotations`).
event_id = dict(hands=2, feet=3)

# EEG channels only, excluding channels marked as bad.
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                    baseline=None, preload=True)
epochs_train = epochs.copy().crop(tmin=1., tmax=2.)
# Binary labels: 0 = hands, 1 = feet. They are derived from `event_id`, not
# from a hard-coded offset, so they stay correct if the event codes change.
labels = (epochs.events[:, -1] == event_id['feet']).astype(int)

In [None]:
# Average over the 'feet' epochs, plotted as a quick visual sanity check.
epochs['feet'].average().plot();

Classification with linear discriminant analysis



In [None]:
# Assemble a classifier: CSP spatial filtering followed by LDA.
# (The sklearn imports live in the imports cell at the top.)
lda = LinearDiscriminantAnalysis()
# 4 CSP components, no covariance regularization (reg=None); log=True applies
# a log transform to the extracted features.
csp = CSP(n_components=4, reg=None, log=True)

In [None]:
# Shape of the full-length epochs array (epochs, channels, time samples); the
# sliding-window code further down slices the last axis in time.
epochs.get_data().shape

In [None]:
# Define a monte-carlo cross-validation generator (reduce variance):
# 10 random 80/20 train/test splits, seeded for reproducibility.
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
X = epochs.get_data()              # full epochs (tmin..tmax), used for testing over time
X_train = epochs_train.get_data()  # 1-2 s crop, used for training
# Both arrays index the same epochs, so `labels` lines up with either.
assert X.shape[0] == X_train.shape[0] == len(labels)

In [None]:
# Use a scikit-learn Pipeline with cross_val_score: chaining CSP and LDA into
# one estimator means CSP is re-fit on the training part of every fold.
# (Pipeline / cross_val_score are imported in the imports cell at the top.)
clf = Pipeline([('CSP', csp), ('LDA', lda)])

In [None]:
# Display the pipeline's repr (its steps and their parameters).
clf

In [None]:
# Per-split accuracy of the CSP+LDA pipeline, trained and tested on the
# 1-2 s crop of the epochs.
scores = cross_val_score(estimator=clf, X=X_train, y=labels,
                         cv=cv, n_jobs=1)
print(scores)

In [None]:
# Printing the results.
# Chance level = share of the majority class (the score of always predicting it).
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1. - class_balance)
print("Classification accuracy: %f / Chance level: %f" % (np.mean(scores),
                                                          class_balance))

# plot CSP patterns estimated on full data for visualization.
# Only the fitted filters/patterns are needed, so `fit` is enough; there is
# no need to also transform X.
csp.fit(X, labels)

csp.plot_patterns(epochs.info, ch_type='eeg',
                  units='Patterns (AU)', size=1.5);

## Look at performance over time

In [None]:
# Running classifier: for each CV split, train CSP+LDA on the 1-2 s crop of
# the training epochs, then score it on sliding windows spanning the full
# test epochs to see when the classes become separable.
sfreq = raw.info['sfreq']
w_length = int(sfreq * 0.5)  # running classifier: window length (samples)
w_step = int(sfreq * 0.1)    # running classifier: window step size (samples)
# Window start samples. A window starting at n covers n .. n + w_length - 1,
# so the last valid start is X.shape[2] - w_length (hence the +1 on the
# exclusive upper bound).
w_start = np.arange(0, X.shape[2] - w_length + 1, w_step)

scores_windows = []

for train_idx, test_idx in cv.split(X, labels):
    y_train, y_test = labels[train_idx], labels[test_idx]

    # Fit CSP and the classifier on the training crop only.
    XX_train = csp.fit_transform(X_train[train_idx], y_train)
    lda.fit(XX_train, y_train)

    # running classifier: test classifier on sliding window
    X_test = X[test_idx]  # loop-invariant: select the test epochs once
    score_this_window = []
    for n in w_start:
        XX_test = csp.transform(X_test[:, :, n:(n + w_length)])
        score_this_window.append(lda.score(XX_test, y_test))
    scores_windows.append(score_this_window)

# Plot scores over time; each window is placed at its centre time.
w_times = (w_start + w_length / 2.) / sfreq + epochs.tmin

fig, ax = plt.subplots()
ax.plot(w_times, np.mean(scores_windows, axis=0), label='Score')
ax.axvline(0, linestyle='--', color='k', label='Onset')
# Same majority-class chance level as printed with the accuracy results.
ax.axhline(class_balance, linestyle='-', color='k', label='Chance')
ax.set_xlabel('time (s)')
ax.set_ylabel('classification accuracy')
ax.set_title('Classification score over time')
ax.legend(loc='lower right')
plt.show()

<div class="alert alert-success">
    <b>EXERCISES:</b>
     <ul>
      <li>How is the performance of CSP affected by the choice of its parameters, such as n_components, reg or log? Plot the performance of the model as a function of n_components.</li>
      <li>How would you tune these parameters without the risk of overfitting?</li>
    </ul>
</div>

See documentation: https://mne.tools/stable/generated/mne.decoding.CSP.html#mne.decoding.CSP

In [None]:
# Re-display the recording before experimenting with annotations.
raw

In [None]:
# Current annotation labels, before they are replaced by `set_annotations`.
raw.annotations.description

In [None]:
# Show the documentation of `set_annotations`. This is the plain-Python
# equivalent of IPython's `raw.set_annotations?`, so the cell also runs
# outside IPython and the help text stays in the cell output.
help(raw.set_annotations)

In [None]:
# Two custom 'alpha' annotations: onsets 2 and 5, durations 1 and 3.
onset = [2.0, 5.0]
duration = [1.0, 3.0]
description = ['alpha'] * 2
annotations = mne.Annotations(onset=onset, duration=duration,
                              description=description)
annotations

In [None]:
# Attach the new annotations to `raw`.
# NOTE: this replaces the existing annotations rather than adding to them.
# The epochs built earlier came from the already-extracted events array, so
# they are unaffected. To keep both sets, combine them first (e.g.
# `raw.annotations + annotations`) — verify against the MNE docs.
raw.set_annotations(annotations)

In [None]:
# Browse the recording with the new annotations in an interactive Qt window.
# NOTE(review): needs a Qt backend and a display; fails on headless runs.
%matplotlib qt
raw.plot()

In [None]:
# Confirm which annotations are now attached to the recording.
raw.annotations