In [None]:
# Uncomment if needed to install dependencies in requirements.txt
# %pip install -r ../requirements.txt

In [1]:
import sys
from pathlib import Path
import mne
import numpy as np
# Set MNE's log level to ERROR.
mne.set_log_level("ERROR")

In [2]:
# Absolute path of the parent of the current working directory. It is used
# below to import the project's `src` package and to locate the data folder.
PROJECT_ROOT = Path("..").resolve()

# Make the project importable. The membership check keeps this cell
# idempotent: re-running it no longer appends duplicate entries to sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.data import get_subject_data
from src.model import csp_lda_classifier
from src.evaluate import evaluate

In [3]:
# Per-subject left-hand vs right-hand motor-imagery benchmark on the Physionet
# motor movement/imagery dataset: for every subject, fit the classifier returned
# by csp_lda_classifier() on two runs and score it on a held-out third run.

# Zero-based positions of the runs inside each subject's sorted *.edf files.
TRAIN_RUN_INDICES = (3, 7)    # runs 4 and 8: motor imagery training data
VALID_RUN_INDICES = (11,)     # run 12: held-out run for testing the motor imagery model
CLASSES = ("T1", "T2")        # event classes to retrieve. T1=left hand. T2=right hand.
BCI_LITERACY_THRESHOLD = 0.7  # minimum validation accuracy to count a subject as BCI literate

# Number of *.edf files a subject folder must contain for the indices above to exist.
REQUIRED_RUNS = max(TRAIN_RUN_INDICES + VALID_RUN_INDICES) + 1

# Get all subject folders: every sub-directory of data/files is one subject.
subjects = sorted(child for child in (PROJECT_ROOT / "data" / "files").iterdir() if child.is_dir())
num_subjects = len(subjects)

accuracies = []  # Validation accuracy of each subject, in processing order
bci_literate_subjects = []  # 1-based ids of subjects whose accuracy is at least the threshold

# Get results for every subject in the dataset (ids are 1-based positions in `subjects`).
for subject_id, subject in enumerate(subjects, start=1):
    print(f"Processing subject {subject_id}/{num_subjects}")

    runs = sorted(subject.glob("*.edf"))
    if len(runs) < REQUIRED_RUNS:
        # Same exception type as the bare indexing would raise, but it names
        # the offending folder instead of just "list index out of range".
        raise IndexError(
            f"Subject folder {subject.name!r} contains {len(runs)} .edf run(s); "
            f"at least {REQUIRED_RUNS} are needed to select the training and validation runs."
        )
    train_runs = [runs[index] for index in TRAIN_RUN_INDICES]
    valid_runs = [runs[index] for index in VALID_RUN_INDICES]

    # Retrieve EEG training data for the subject, restricted to the two classes.
    # NOTE(review): an earlier comment described a 0.5 second offset from event
    # onset (subject reaction time) and a 3.5 second duration (events are
    # 4 seconds long). Neither value is passed here, so they presumably come
    # from get_subject_data's defaults -- confirm.
    train_data = get_subject_data(train_runs, CLASSES)

    clf = csp_lda_classifier()  # Create classifier object

    # Evaluate performance on the training runs.
    # Within-run classification if only one run is used in the data;
    # cross-run classification if multiple runs are included in the data.
    # NOTE(review): the "Evaluation accuracy" line in the cell output is not
    # printed by this cell, so it is presumably printed inside evaluate() -- confirm.
    mean, scores = evaluate(clf, train_data)

    clf.fit(train_data["x"], train_data["y"])  # Fit classifier on all training data

    valid_data = get_subject_data(valid_runs, CLASSES)  # Get validation data

    # Accuracy on the validation run = fraction of predictions equal to the labels.
    preds = clf.predict(valid_data["x"])
    corrects = preds == valid_data["y"]
    accuracy = corrects.mean()
    print(f"Validation accuracy: {accuracy:.3f}", end="\n\n")

    # Store accuracy and check if subject is BCI literate
    accuracies.append(accuracy)
    if accuracy >= BCI_LITERACY_THRESHOLD:
        bci_literate_subjects.append(subject_id)

print(f"Average accuracy across all subjects: {np.mean(accuracies):.3f}")
print(f"Number of BCI literate subjects: {len(bci_literate_subjects)}/{num_subjects}")
print(f"BCI literate subjects: {bci_literate_subjects}")



Processing subject 1/109
Evaluation accuracy: 0.567
Validation accuracy: 0.733

Processing subject 2/109
Evaluation accuracy: 0.933
Validation accuracy: 0.733

Processing subject 3/109
Evaluation accuracy: 0.500
Validation accuracy: 0.667

Processing subject 4/109
Evaluation accuracy: 0.567
Validation accuracy: 0.400

Processing subject 5/109
Evaluation accuracy: 0.533
Validation accuracy: 0.533

Processing subject 6/109
Evaluation accuracy: 0.533
Validation accuracy: 0.533

Processing subject 7/109
Evaluation accuracy: 0.800
Validation accuracy: 1.000

Processing subject 8/109
Evaluation accuracy: 0.433
Validation accuracy: 0.400

Processing subject 9/109
Evaluation accuracy: 0.333
Validation accuracy: 0.467

Processing subject 10/109
Evaluation accuracy: 0.567
Validation accuracy: 0.267

Processing subject 11/109
Evaluation accuracy: 0.533
Validation accuracy: 0.733

Processing subject 12/109
Evaluation accuracy: 0.500
Validation accuracy: 0.800

Processing subject 13/109
Evaluation 