In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

# mne library to analyse EEG
import mne
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.decoding import Vectorizer
mne.set_log_level('error') # Avoid long log


from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV, StratifiedKFold, cross_val_predict
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from scipy import stats

# Models
from sklearn import svm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression

In [2]:
# Create a list with each tipe of experimental run
openeye_runs = [1]
closedeye_runs = [2]
fists_runs = [3, 7, 11]
imaginefists_runs = [4, 8, 12]
fistsfeet_runs = [5, 9, 13]
imaginefistsfeet_run = [6, 10, 14]

# List with the ID of each participant
participants = [_ for _ in range(1,110)]


# Defining the EEG standard EEG bands. These are indicative situation in which we see appearing this type of waves on healty subjects.
delta_waves = {
    'freq_min': 0.5,
    'f_max': 4
} # normally occur during deep sleep

theta_waves = {
    'freq_min': 4,
    'freq_max': 8
} # transiently during sleep

alpha_waves = {
    'freq_min': 8, 
    'freq_max': 13
} # relaxed but awake state, resting with the eyes closed

beta_waves = {
    'freq_min': 13, 
    'freq_max': 30   
} # attention to tasks or stimuli,logical thinking

gamma_waves = {
    'freq_min': 30, 
    'freq_max': 70  
} # large-scale brain network activity and cognitive phenomena such as working memory, attention



In [3]:
# Charging the data
participant = 14

# Get the path to the data
def file_path(participant, run):
    return f'files/S{participant:03}/S{participant:03}R{run:02}.edf'

# Load the data
# Preload = True charges also the data, not just the headers
# raw = concatenate_raws([read_raw_edf(file_path(participant, run), preload = True) for run in fists_runs])

In [4]:
raws = [read_raw_edf(file_path(participant, 3)) for participant in range(1,10)]

In [5]:
# We choose the "Standard_1020" montage

montage = mne.channels.make_standard_montage("standard_1020")


# Here we change the names of the electrode to match the standard notation and set the choosen montage on the raw data charged.

# Dictionary with the structure old_name : correct_cases_name. To respect the upper and lower cases of the standard notation for the electrde's position.
replacement = {
    'Fc': 'FC',
    'Cp': 'CP',
    'Af': 'AF',
    'Ft': 'FT',
    'Tp': 'TP',
    'Po': 'PO'   
}

# new_name is the dictionary to use to cange the name of the electrode's positions to respect the usual sandard notataions.
# First get rid of the excessive "." 
new_names = {
    name : name.replace(".", "") for name in raws[0].info['ch_names']
}   

# Change the lower and upper case of the electrode's names
for key in new_names.keys():
    for old_string, new_string in replacement.items():
        new_names[key] = new_names[key].replace(old_string, new_string)

# Choose the montage and set it for the uploaded data
# montage = 'standard_1020'
for raw in raws:
    raw.rename_channels(new_names)
    raw.set_montage(montage)



In [6]:
low_cut = 0.1 # We filter the low frequency to remove slow drift
high_cut = 30 # We filter the high frequency to eliminate noise, and because the motor signals appears mostly as alpha and beta waves

# copy the raw data and apply the filter
raws_filt = [raw.load_data().copy().filter(low_cut, high_cut) for raw in raws]

In [7]:
tmin =  -1.  # start of each epoch (in sec)
tmax =  4.1  # end of each epoch (in sec)
baseline = (-1, 0) # for the baseline correction we choose the interval that reflect the resting state before the event

# Making it easyer to read the events
event_mapping = {
    1: 'rest',
    2: 'left_fist',
    3: 'right_fist'
}
event_id = {v:k for k,v in event_mapping.items()}

epochs = [
    Epochs(raw_filt, mne.events_from_annotations(raw_filt)[0], event_id, tmin=tmin, tmax=tmax, baseline= baseline)
    for raw_filt in raws_filt
]

rememeber to eliminate subjects 88, 92, 100 because the experiments have been done using differents timings

In [8]:

labels = [epoch.events[:,-1] for epoch in epochs]
raws_epochs = [epoch.get_data() for epoch in epochs]

In [None]:
for i in range(len(raws_filt)):
    if raws_filt[i].shape[2] != 817:
        print(raws_filt[i].shape[] , 'subject in position', i )

654 subject in position 87
654 subject in position 91
654 subject in position 99


In [9]:
data = [epoch.get_data() for epoch in epochs]
labels = [epoch.events[:,-1] for epoch in epochs]

In [10]:
for i in data:
    print(i.shape)

(29, 64, 817)
(28, 64, 817)
(29, 64, 817)
(28, 64, 817)
(28, 64, 817)
(28, 64, 817)
(29, 64, 817)
(28, 64, 817)
(28, 64, 817)


In [11]:
X = np.concatenate(data, axis = 0)
y = np.concatenate(labels)

In [14]:
print('X shape', X.shape)
print('y shape', y.shape)

X shape (255, 64, 817)
y shape (255,)


In [105]:
for i in data_vec[1][-100:]:
    print(i)

-6.20620336508404e-05
-8.182442426522105e-05
-7.161551098853623e-05
-4.0339538447180534e-05
-9.205146280631682e-06
6.814036612374151e-06
6.607751441151636e-06
-3.1586807543384517e-06
-1.533727737955922e-05
-2.3085456083235042e-05
-1.999675102612292e-05
-5.471825705648514e-06
1.1460499395136669e-05
1.9352248258763882e-05
1.6757987439230225e-05
1.3230631126082594e-05
1.5682016277715548e-05
1.8410362257705626e-05
1.2185315534966168e-05
9.168305492517619e-07
1.0484260626577472e-06
2.091867423711383e-05
4.669716626697539e-05
5.57870789041648e-05
4.242051903250098e-05
2.2827576617740745e-05
1.3708132754532219e-05
1.2712316807836866e-05
6.0064379224555546e-06
-9.000922242370547e-06
-1.7399568809462985e-05
-5.34590023675676e-06
2.136499868547174e-05
4.2613582476301424e-05
4.560590961146256e-05
3.5742285664130755e-05
2.564350511542425e-05
1.891763229274564e-05
9.009874748903339e-06
-7.6019828641584594e-06
-2.2163686333879186e-05
-2.0669986875295122e-05
2.577566290030404e-07
2.6519089982511098e-

In [63]:
data_vec.shape

(29, 52288)

In [108]:
data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.3, random_state=0)

TypeError: Singleton array array(<generator object <genexpr> at 0x739df5dc5d80>, dtype=object) cannot be considered a valid collection.

In [44]:
clf_lda_pip = make_pipeline(Vectorizer(), StandardScaler(), LinearDiscriminantAnalysis(solver='svd'))

In [52]:
clf_lda_pip = make_pipeline(Vectorizer(), StandardScaler(), LinearDiscriminantAnalysis(solver='svd'))

In [None]:
clf_lda_pip

In [23]:
def applyCrossValidation(models, model_names, data, labels, kfold):
    results = []
    if np.all(np.isfinite(data)) == True and np.any(np.isnan(data)) == False:
        for i in range(len(models)):
            #print(model_names[i])
            cv_accuracy = cross_val_score(models[i], data, labels, cv=kfold)
            results.append(cv_accuracy)
            #print('CV accuracy of model ' + model_names[i] + ': ' + str(cv_accuracy))
            
    else:
        print('Data has infinite or NaN value!')
    
    return results


In [46]:
results_perParticipant_UP = []
model_names = [ 'LR', 'LDA'] 
kfold = StratifiedKFold(n_splits=3, random_state=42, shuffle=True)
for i in range(len(data_train)):
    # Linear Discriminant Analysis
    clf_lda_pip = make_pipeline(Vectorizer(), StandardScaler(), LinearDiscriminantAnalysis(solver='svd'))
    #Logistic Regression
    clf_lr_pip = make_pipeline(Vectorizer(), StandardScaler(), LogisticRegression(penalty='l1', solver='liblinear', random_state=42))
    
    models = [ clf_lr_pip, clf_lda_pip]
    scores = applyCrossValidation(models, model_names, data_train[i], labels_train[i], kfold)
    results_perParticipant_UP.append(scores)


In [51]:
results_perParticipant_UP

[[array([0.3       , 0.5       , 0.44444444]),
  array([0.4       , 0.6       , 0.55555556])],
 [array([0.6       , 0.55555556, 0.55555556]),
  array([0.4       , 0.55555556, 0.66666667])],
 [array([0.6       , 0.55555556, 0.33333333]),
  array([0.5       , 0.33333333, 0.44444444])],
 [array([0.4       , 0.44444444, 0.77777778]),
  array([0.3       , 0.55555556, 0.66666667])],
 [array([0.5       , 0.66666667, 0.66666667]),
  array([0.4       , 0.55555556, 0.55555556])],
 [array([0.6       , 0.77777778, 0.77777778]),
  array([0.5       , 0.77777778, 0.77777778])],
 [array([0.2       , 0.66666667, 0.33333333]),
  array([0.4       , 0.44444444, 0.44444444])],
 [array([0.4       , 0.66666667, 0.66666667]),
  array([0.4       , 0.77777778, 0.33333333])],
 [array([0.3       , 0.55555556, 0.55555556]),
  array([0.4       , 0.44444444, 0.22222222])],
 [array([0.3       , 0.44444444, 0.22222222]),
  array([0.4       , 0.44444444, 0.44444444])],
 [array([0.5       , 0.44444444, 0.44444444]),
  a