In [9]:
import numpy as np 
import joblib
import timeit
import os

from sklearn.metrics import accuracy_score

import mne
from mne import Epochs, pick_types, annotations_from_events, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci

# Raise MNE's logging threshold to CRITICAL so reading/filtering runs stay quiet.
mne.set_log_level(verbose="CRITICAL")

In [10]:
# Epoch window, in seconds, relative to each selected event.
tmin = -1.0
tmax = 4.0
# When True, keep only a fixed channel subset before filtering/epoching.
drop_channels = False
# Subjects 1..subjects_count (inclusive) are evaluated.
subjects_count = 5

In [11]:
# One entry per task: which runs to load, the descriptions assigned to the
# annotation codes 0/1/2 ("mapping"), and the two classes to epoch on
# ("event_id"). Each class label is written once and reused for both dicts,
# so "mapping" and "event_id" cannot drift apart.
experiments = [
    {
        "runs": list(runs),
        "mapping": {0: "rest", 1: first_class, 2: second_class},
        "event_id": {first_class: 1, second_class: 2},
    }
    for runs, first_class, second_class in (
        ((3, 7, 11), "left fist", "right fist"),
        ((4, 8, 12), "imagine left fist", "imagine right fist"),
        ((5, 9, 13), "both fists", "both feets"),
        ((6, 10, 14), "imagine both fists", "imagine both feets"),
    )
]

In [12]:
# Score every saved (subject, experiment) model on that subject's recordings.
# Model files are looked up as models/model_{subject}_{experiment_id}.z with a
# 0-based experiment_id; experiments without a model file are reported and skipped.
# NOTE(review): if these models were fitted on these same runs, the reported
# accuracy is a training-set score and overstates generalization — confirm.

# Channel subset kept when `drop_channels` is True. Presumably this must match
# the subset used when the models were trained — TODO confirm.
good_channels = [
    "FC3", "FC1", "FCz", "FC2", "FC4",
    "C3", "C1", "Cz", "C2", "C4",
    "CP3", "CP1", "CPz", "CP2", "CP4",
    "Fpz",
]
# The montage does not depend on subject or experiment: build it once.
montage = make_standard_montage("biosemi64")

all_accuracies = []
for subject in range(1, subjects_count + 1):
    print(f"Subject #{subject}")

    for experiment_id, experiment in enumerate(experiments):
        model_file = f"models/model_{subject}_{experiment_id}.z"
        if not os.path.isfile(model_file):
            # Nothing exists to delete here (removing it would raise
            # FileNotFoundError); just skip this experiment.
            print(f"Skipped experiment {experiment_id + 1}: Missing model")
            continue
        # joblib.load unpickles arbitrary objects: only load trusted model files.
        model = joblib.load(model_file)

        # Load and concatenate this experiment's EDF runs for the subject.
        raw_fnames = [
            f"dataset/S{subject:03d}/S{subject:03d}R{run:02d}.edf"
            for run in experiment["runs"]
        ]
        raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])

        # Replace the T0/T1/T2 annotations with this experiment's descriptions so
        # the classes can be selected by name via experiment["event_id"] below.
        events, _ = events_from_annotations(raw, event_id=dict(T0=0, T1=1, T2=2))
        annot_from_events = annotations_from_events(
            events=events, event_desc=experiment["mapping"], sfreq=raw.info["sfreq"]
        )
        raw.set_annotations(annot_from_events)
        eegbci.standardize(raw)  # set channel names
        raw.set_montage(montage, on_missing="ignore")

        # Select channels
        if drop_channels:
            channels = raw.info["ch_names"]
            bad_channels = [x for x in channels if x not in good_channels]
            raw.drop_channels(bad_channels)

        # Filter: 60 Hz notch, then 7-32 Hz band-pass.
        raw.notch_filter(60, method="iir")
        raw.filter(7.0, 32.0, fir_design="firwin", skip_by_annotation="edge")

        # Read epochs. "rest" is not a key of event_id, so only the two task
        # classes (codes 1 and 2) become events.
        events, _ = events_from_annotations(raw, event_id=experiment["event_id"])
        picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
        epochs = Epochs(
            raw, events, experiment["event_id"], tmin, tmax,
            proj=True, picks=picks, baseline=None, preload=True,
        )
        epochs_data = epochs.get_data()
        labels = epochs.events[:, -1]

        # Score, timing only the model.score call.
        start_time = timeit.default_timer()
        accuracy = model.score(epochs_data, labels)
        elapsed = timeit.default_timer() - start_time
        print(f"[Training] Accuracy: {accuracy:.2%} in {elapsed:.2}s")
        all_accuracies.append(accuracy)

Subject #1
[Training] Accuracy: 91.11% in 0.005s
[Training] Accuracy: 100.00% in 0.003s
[Training] Accuracy: 100.00% in 0.0043s
[Training] Accuracy: 100.00% in 0.006s
Subject #2
[Training] Accuracy: 88.89% in 0.0031s
[Training] Accuracy: 93.33% in 0.0063s
[Training] Accuracy: 100.00% in 0.004s
[Training] Accuracy: 84.44% in 0.0029s
Subject #3
[Training] Accuracy: 100.00% in 0.0031s
[Training] Accuracy: 100.00% in 0.003s
[Training] Accuracy: 97.78% in 0.0031s
[Training] Accuracy: 93.33% in 0.003s
Subject #4
[Training] Accuracy: 100.00% in 0.0031s
[Training] Accuracy: 97.78% in 0.0031s
[Training] Accuracy: 97.78% in 0.0033s
[Training] Accuracy: 88.89% in 0.003s
Subject #5
[Training] Accuracy: 100.00% in 0.003s
[Training] Accuracy: 97.78% in 0.0029s
[Training] Accuracy: 95.56% in 0.005s
[Training] Accuracy: 97.78% in 0.0032s


In [13]:
# Average of the per-experiment scores collected above; skipped experiments
# (missing model files) contribute nothing.
if all_accuracies:
    print(f"Global accuracy: {np.mean(all_accuracies):.2%}")
else:
    # np.mean([]) would emit a RuntimeWarning and print "nan%".
    print("Global accuracy: n/a (no models were evaluated)")

Global accuracy: 96.22%
