In [1]:
import os
import math 
import numpy as np
import matplotlib.pyplot as plt
import joblib

from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score, train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted
from sklearn.svm import SVC 
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

import mne
from mne import Epochs, pick_types, annotations_from_events, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci 
from src.CSP import CSP
from mne.viz import plot_events, plot_montage
from mne.preprocessing import ICA, create_eog_epochs, create_ecg_epochs, corrmap, Xdawn

# Raise MNE's log threshold to CRITICAL so lower-severity MNE messages are
# suppressed (presumably to keep cell output limited to this notebook's own
# prints).
mne.set_log_level("CRITICAL")

In [2]:
# Ensure the output directory for the trained models exists. `exist_ok=True`
# makes this a single idempotent call with no check-then-create race, and it
# fails loudly if "models" exists but is not a directory.
os.makedirs("models", exist_ok=True)

In [3]:
# Epoch window bounds handed to `Epochs(...)` when the epochs are built.
tmin = -1.0
tmax = 2.0

# Subjects are iterated as 1..subjects_count (inclusive) by the training loop.
subjects_count = 109

# When true, `cleanup_raw` keeps only its fixed subset of channels.
drop_channels = False

# When true, cross-validation runs on a cropped copy of the epochs.
crop_train = True

In [4]:
# One entry per task; a separate model is trained for every (subject, entry).
#   "runs":     run numbers whose EDF recordings are concatenated for the task
#   "mapping":  event code -> annotation description written onto the raw data
#   "event_id": annotation description -> class label used to build the epochs
# Each "event_id" must be the exact inverse of the non-rest part of "mapping",
# otherwise the re-annotated events cannot be found again when epoching.
experiments = [
    {
        "runs": [3, 7, 11],
        "mapping": {0: "rest", 1: "left/fist", 2: "right/fist"},
        "event_id": {"left/fist": 1, "right/fist": 2},
    },
    {
        "runs": [4, 8, 12],
        "mapping": {0: "rest", 1: "left/imaginefist", 2: "right/imaginefist"},
        "event_id": {"left/imaginefist": 1, "right/imaginefist": 2},
    },
    {
        "runs": [5, 9, 13],
        "mapping": {0: "rest", 1: "top/fists", 2: "bottom/feets"},
        "event_id": {"top/fists": 1, "bottom/feets": 2},
    },
    {
        "runs": [6, 10, 14],
        # The feet class is tagged "bottom/..." like its counterpart in runs
        # 5/9/13, so the two classes no longer share the same "top" tag.
        "mapping": {0: "rest", 1: "top/imaginefists", 2: "bottom/imaginefeets"},
        "event_id": {"top/imaginefists": 1, "bottom/imaginefeets": 2},
    },
]

In [5]:
def create_model(n_components=8, random_state=42):
    """Build the untrained CSP -> logistic-regression pipeline.

    Parameters
    ----------
    n_components : int, default=8
        Forwarded to ``CSP(n_components=...)``.
    random_state : int or None, default=42
        Seed forwarded to ``LogisticRegression``. The liblinear solver
        shuffles the data while fitting, so a fixed seed makes repeated fits
        on the same data reproducible.

    Returns
    -------
    sklearn.pipeline.Pipeline
        Two steps, named ``"CSP"`` and ``"LogisticRegression"``.
    """
    # Decomposer
    csp = CSP(n_components=n_components)

    # Classifier. `multi_class` is intentionally not passed: 'auto' was
    # already the default, and the parameter is deprecated in recent
    # scikit-learn releases.
    logr = LogisticRegression(
        penalty="l1",
        solver="liblinear",
        random_state=random_state,
    )

    # Pipeline
    return Pipeline([("CSP", csp), ("LogisticRegression", logr)])

In [6]:
def cleanup_raw(raw):
    """Standardize, optionally reduce, and band-limit a raw recording.

    The recording is changed through its own methods and nothing is returned.
    Steps, in order:

    1. EEGBCI channel-name standardization.
    2. Assignment of the "biosemi64" montage (missing channels are ignored).
    3. If the module-level ``drop_channels`` flag is true, removal of every
       channel that is not in the fixed keep-list below.
    4. A notch filter at 60 using the IIR method.
    5. A 7.0-32.0 band-pass using a "firwin" FIR design.

    Parameters
    ----------
    raw
        The raw recording to clean; it must provide ``info``,
        ``set_montage``, ``drop_channels``, ``notch_filter`` and ``filter``.
    """
    # Channel naming and sensor layout
    eegbci.standardize(raw)
    raw.set_montage(make_standard_montage("biosemi64"), on_missing='ignore')

    # Optional channel selection
    if drop_channels:
        keep = frozenset((
            "FC3", "FC1", "FCz", "FC2", "FC4",
            "C3", "C1", "Cz", "C2", "C4",
            "CP3", "CP1", "CPz", "CP2", "CP4",
            "Fpz",
        ))
        # Preserve the recording's own channel order in the drop list.
        to_drop = [name for name in raw.info["ch_names"] if name not in keep]
        raw.drop_channels(to_drop)

    # Filtering: notch first, then band-pass
    raw.notch_filter(60, method="iir")
    raw.filter(7.0, 32.0, fir_design="firwin", skip_by_annotation="edge")

In [7]:
# Window (passed to `Epochs.crop`) that cross-validation is restricted to
# when `crop_train` is enabled.
crop_tmin, crop_tmax = 1, 2

# Monte-Carlo cross-validation generator. It has a fixed seed and fixed split
# sizes, so it is loop-invariant: build it once and reuse it for every model.
cv = ShuffleSplit(10, test_size=0.2, random_state=42)

all_accuracies = []        # training accuracy of every saved model
all_cross_accuracies = []  # per-split cross-validation scores of every model
for subject in range(1, subjects_count + 1):
    print(f"Subject #{subject}")

    for experiment_id, experiment in enumerate(experiments):
        # Load and concatenate this subject's recordings for the experiment.
        raw_fnames = [f"dataset/S{subject:03d}/S{subject:03d}R{run:02d}.edf" for run in experiment["runs"]]
        raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])

        # Replace the generic T1/T2 annotations with the experiment's own
        # descriptions so the epochs below can be selected by name.
        events, _ = events_from_annotations(raw, event_id=dict(T1=1, T2=2))
        annot_from_events = annotations_from_events(
            events=events, event_desc=experiment["mapping"], sfreq=raw.info["sfreq"]
        )
        raw.set_annotations(annot_from_events)
        cleanup_raw(raw)

        # Read epochs
        events, _ = events_from_annotations(raw, event_id=experiment["event_id"])
        picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
        epochs = Epochs(raw, events, experiment["event_id"], tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)
        epochs_data = epochs.get_data()
        labels = epochs.events[:, -1]

        # Cross-validate on the (optionally cropped) epochs. The copy is only
        # made when cropping, since `crop` modifies the epochs it is called on.
        model = create_model()
        epochs_train = epochs.copy().crop(crop_tmin, crop_tmax) if crop_train else epochs
        score = cross_val_score(model, epochs_train.get_data(), labels, cv=cv, n_jobs=-1, verbose=False)

        # Fit the model that gets saved and report its accuracy on the same
        # data it was fitted on (i.e. a training accuracy, not a held-out one).
        # NOTE(review): cross-validation above scores the cropped window while
        # the saved model is fitted on the full epochs, so the reported CV
        # score does not describe the saved model exactly -- confirm intended.
        model.fit(epochs_data, labels)
        accuracy = model.score(epochs_data, labels)
        print(f"[Training] Accuracy: {accuracy:.2%} (score: {score.mean():.2} ~{score.std():.2})")
        all_accuracies.append(accuracy)
        all_cross_accuracies.append(score)

        # `joblib.dump` overwrites an existing file, so no prior removal is
        # needed.
        model_path = f"models/model_{subject}_{experiment_id}.z"
        joblib.dump(model, model_path)

Subject #1
[Training] Accuracy: 100.00% (score: 0.53 ~0.16)
[Training] Accuracy: 100.00% (score: 0.66 ~0.18)
[Training] Accuracy: 100.00% (score: 0.96 ~0.1)
[Training] Accuracy: 100.00% (score: 0.92 ~0.087)
Subject #2
[Training] Accuracy: 73.33% (score: 0.47 ~0.083)
[Training] Accuracy: 71.11% (score: 0.68 ~0.14)
[Training] Accuracy: 97.78% (score: 0.97 ~0.051)
[Training] Accuracy: 75.56% (score: 0.72 ~0.19)
Subject #3
[Training] Accuracy: 97.78% (score: 0.46 ~0.12)
[Training] Accuracy: 100.00% (score: 0.43 ~0.15)
[Training] Accuracy: 100.00% (score: 0.53 ~0.16)
[Training] Accuracy: 100.00% (score: 0.54 ~0.18)
Subject #4
[Training] Accuracy: 100.00% (score: 0.56 ~0.15)
[Training] Accuracy: 100.00% (score: 0.3 ~0.087)
[Training] Accuracy: 93.33% (score: 0.89 ~0.13)
[Training] Accuracy: 95.56% (score: 0.68 ~0.13)
Subject #5
[Training] Accuracy: 97.78% (score: 0.44 ~0.16)
[Training] Accuracy: 100.00% (score: 0.51 ~0.13)
[Training] Accuracy: 100.00% (score: 0.77 ~0.14)
[Training] Accuracy:

In [8]:
# Average over every (subject, experiment) model trained above; the
# cross-validation figure averages all individual split scores.
mean_training_accuracy = np.mean(all_accuracies)
mean_cv_score = np.mean(all_cross_accuracies)
print(f"Training accuracy: {mean_training_accuracy:.2%} (cross validation score: {mean_cv_score:.2%})")

Training accuracy: 93.19% (cross validation score: 60.72%)
