

# Motor imagery decoding from EEG data using the Common Spatial Pattern (CSP)

Decoding of motor imagery applied to EEG data decomposed using CSP. A
classifier is then applied to features extracted on CSP-filtered signals.

See https://en.wikipedia.org/wiki/Common_spatial_pattern and
:footcite:`Koles1991`. The EEGBCI dataset is documented in
:footcite:`SchalkEtAl2004` and is available at PhysioNet
:footcite:`GoldbergerEtAl2000`.


In [None]:
# Authors: Martin Billinger <martin.billinger@tugraz.at>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

In [None]:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from glob import glob

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, accuracy_score, make_scorer, confusion_matrix

from mne import Epochs, pick_types, find_events
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP, UnsupervisedSpatialFilter, Vectorizer, Scaler
from mne.io import concatenate_raws, read_raw_fif, read_raw_edf

# Prints this module's docstring -- a sphinx-gallery convention carried over
# from the MNE-Python example this notebook is based on.
# NOTE(review): the notebook defines no docstring of its own, so this output is
# probably not meaningful here; consider removing.
print(__doc__)

In [None]:
# #############################################################################
# # Set parameters and read data

# Epoch limits in seconds relative to each trigger. The classifier is trained
# on a short crop inside this window (``epochs_train`` below); the full window
# is used by the sliding-window (running) classifier further down.
tmin, tmax = -1.0, 4

subject = "laurids"  # another recorded subject: "chris"
# glob() returns files in arbitrary, filesystem-dependent order; sort so the
# recordings are always concatenated in the same order.
raw_fnames = sorted(glob(f"data/scratch/{subject}/*_b*/*raw.fif"))
if not raw_fnames:
    raise FileNotFoundError(
        f"No recordings matching data/scratch/{subject}/*_b*/*raw.fif"
    )
raw = concatenate_raws([read_raw_fif(f, preload=True) for f in raw_fnames])

montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
# Add an EEG reference projector and apply it immediately.
raw.set_eeg_reference(projection=True)
raw.apply_proj()

# Band-pass filter to 7-30 Hz. skip_by_annotation="edge" keeps the filter from
# running across the boundaries concatenate_raws marks between recordings.
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
events = find_events(raw, stim_channel='trigger', verbose=True)
# Trigger codes of the classes to decode.
classes = [90, 180, 270]

# Epochs over the full [tmin, tmax] window. Training uses only the 2.0-2.5 s
# crop; the full epochs are used later by the running classifier.
epochs = Epochs(
    raw,
    events=events,
    event_id=classes,
    tmin=tmin,
    tmax=tmax,
    proj=True,
    picks=picks,
    baseline=None,
    preload=True,
)
epochs_train = epochs.copy().crop(tmin=2, tmax=2.5)
# Class label of each epoch is its trigger code.
labels = epochs.events[:, -1]

epochs

Classification with linear discriminant analysis (LDA)



In [None]:
# Shared plot styling: green sequential colormap, save toggle, axis-label font.
pink = sns.light_palette('#69B45B', as_cmap=True)
savefigs = False
font = {
    'size': 12
}


def plot_cf(cf: np.ndarray, title: str, classes):
    """Plot a confusion matrix as an annotated heatmap.

    Parameters
    ----------
    cf : np.ndarray
        Square confusion matrix with actual classes on rows and predicted
        classes on columns, both ordered like ``classes``.
    title : str
        Suffix for the figure super-title; also used in the saved file name.
    classes : sequence of int
        Trigger codes in matrix order. 90, 180 and 270 are displayed as
        Right, Feet and Left; any other code is displayed as its number.

    Returns
    -------
    fig, ax
        The matplotlib figure and the heatmap axes.

    Raises
    ------
    ValueError
        If ``cf`` is not a ``len(classes) x len(classes)`` matrix.
    """
    label_convert = {90: 'Right', 180: 'Feet', 270: 'Left'}
    names = [label_convert.get(c, str(c)) for c in classes]
    n = len(names)
    cf = np.asarray(cf)
    if cf.shape != (n, n):
        raise ValueError(
            f"confusion matrix shape {cf.shape} does not match {n} classes"
        )

    # Cell annotation: "True <name>" on the diagonal (correct predictions),
    # "False <predicted name>" elsewhere, followed by the count.
    labels = np.array([
        [f"{'True' if i == j else 'False'} {names[j]}\n{cf[i, j]:0.0f}" for j in range(n)]
        for i in range(n)
    ])
    print(labels)

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(cf, annot=labels, fmt='', cmap=pink, xticklabels=names, yticklabels=names, ax=ax)

    ax.set_xlabel('Predicted', fontdict=font)
    ax.set_ylabel('Actual', fontdict=font)

    fig.suptitle(f'Motor imagery prediction {title}', fontsize=18, x=0.45)
    ax.set_title('Heatmap of the confusion matrix', fontsize=12)

    if savefigs:
        # savefig does not create missing directories.
        from pathlib import Path
        Path('figs').mkdir(exist_ok=True)
        fig.savefig(f'figs/figure_{title}.png', bbox_inches='tight')

    return fig, ax

In [None]:
# True/predicted labels of every cross-validation test fold, collected by the
# scorer below so one report and confusion matrix covers all folds.
true_classes, pred_classes = [], []


def classification_report_with_accuracy_score(y_true, y_pred):
    """Record a fold's labels and return its accuracy.

    Used through ``make_scorer`` as the scoring callback of
    ``cross_val_score``. Labels are appended to the module-level
    ``true_classes`` / ``pred_classes`` lists, which only works while scoring
    runs in this process (``n_jobs=None``); worker processes would append to
    their own copies.
    """
    true_classes.extend(y_true)
    pred_classes.extend(y_pred)
    return accuracy_score(y_true, y_pred)


# Monte-Carlo cross-validation: 10 random 80/20 splits (reduces variance).
epochs_data_train = epochs_train.get_data(copy=False)
cv = ShuffleSplit(10, test_size=0.2, random_state=42)

# Preprocessing / feature-extraction steps (csp and svc are also used by the
# sliding-window analysis below).
scaler = Scaler(epochs.info)
csp = CSP(n_components=8, reg=None, norm_trace=False, log=True)
pca = UnsupervisedSpatialFilter(PCA(), average=True)
vec = Vectorizer()

# Candidate classifiers
lda = LinearDiscriminantAnalysis()
svc = SVC(kernel="linear")

# Alternatives tried:
# clf = Pipeline([("PCA", pca), ("CSP", csp), ("LDA", lda)])
# clf = Pipeline([("Scaler", scaler), ("PCA", pca), ("Vectorizer", vec), ("SVM", svc)])
clf = Pipeline([("Scaler", scaler), ("PCA", pca), ("Vectorizer", vec), ("LDA", lda)])

scores = cross_val_score(
    clf,
    epochs_data_train,
    labels,
    cv=cv,
    n_jobs=None,  # keep in-process so the scorer can collect the labels
    scoring=make_scorer(classification_report_with_accuracy_score),
)

# To inspect CSP patterns (requires fitting csp on the data first):
# csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)

# Chance level = accuracy of always predicting the most frequent class
# (reduces to max(p, 1 - p) in the two-class case).
_, class_counts = np.unique(labels, return_counts=True)
class_balance = class_counts.max() / len(labels)

# classification_report / confusion_matrix take (y_true, y_pred) in that order.
print(classification_report(true_classes, pred_classes))
print("=" * 20)
print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}")
# labels=classes fixes the row/column order to match plot_cf's tick labels.
plot_cf(confusion_matrix(true_classes, pred_classes, labels=classes), "test", classes)



Look at performance over time (sliding-window classifier)



In [None]:
# Running classifier: fit PCA -> CSP -> linear SVM on the training crop, then
# score it on sliding windows across the full epochs to see how decodability
# evolves over time.
#
# MNE's CSP supports more than two classes (with its default
# component_order="mutual_info"), so all entries of ``classes`` are used.
epochs_data = epochs.get_data(copy=False)
cv_split = cv.split(epochs_data_train)
sfreq = raw.info["sfreq"]
w_length = int(sfreq * 0.5)  # running classifier: window length (samples)
w_step = int(sfreq * 0.1)  # running classifier: window step size (samples)
# +1 so the last window that still fits inside the epoch is included.
w_start = np.arange(0, epochs_data.shape[2] - w_length + 1, w_step)

scores_windows = []

minipca = Pipeline([("PCA", pca), ("Vectorizer", vec)])  # not used below
minicsp = Pipeline([("PCA", pca), ("CSP", csp)])

for train_idx, test_idx in cv_split:
    y_train, y_test = labels[train_idx], labels[test_idx]

    # Fit spatial filters and classifier on the training crop only.
    X_train = minicsp.fit_transform(epochs_data_train[train_idx], y_train)
    svc.fit(X_train, y_train)

    # Score the held-out epochs on each sliding window of the full epochs.
    score_this_window = []
    for n in w_start:
        X_test = minicsp.transform(epochs_data[test_idx][:, :, n : (n + w_length)])
        score_this_window.append(svc.score(X_test, y_test))

    scores_windows.append(score_this_window)

# Chance level = accuracy of always predicting the most frequent class.
_, class_counts = np.unique(labels, return_counts=True)
chance_level = class_counts.max() / len(labels)

# Plot scores over time; each window is placed at its centre time.
w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin
plt.figure()
plt.plot(w_times, np.mean(scores_windows, 0), label="Score")
plt.axvline(0, linestyle="--", color="k", label="Onset")
plt.axhline(chance_level, linestyle="-", color="k", label="Chance")
plt.ylim(bottom=0, top=1)
plt.xlabel("time (s)")
plt.ylabel("classification accuracy")
plt.title("Classification score over time")
plt.legend(loc="lower right")
plt.show()

## References
.. footbibliography::

