# 1. Load the data

In [51]:
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
import os
import pywt

def apply_wavelet_transform(raw, wavelet='db4', level=4):
    """
    This function use the pywt library to apply wavelet transform by using 'ondelette' signal.
    The wavelet transform is more precise than fourier transform because the output frequence are localised with time

    Parameters
    ----------
    raw: mne.io.Raw
        An MNE Raw object containing the EEG data with annotations.
    wavelet: Wavelet object or name string, optional
        'db4' 'ondelette' family naming 'Daubechies' wuth 4 coefficient's filter. This 'ondelette' is often used in EEG analysis.
    level: int, optional
        level of decomposition (more is high, more is precise)
    Returns
    -------
    raw: mne.io.Raw
        An MNE Raw object containing the EEG data with annotations filtered by wavelet transform
    """
    data, times = raw.get_data(return_times=True)
    cleaned_data = np.zeros_like(data)
    for i in range(data.shape[0]):
        coeffs = pywt.wavedec(data[i, :], wavelet, level=level)

        threshold = 0.04
        coeffs[1:] = [pywt.threshold(c, threshold, mode='soft') for c in coeffs[1:]]

        cleaned_data[i, :] = pywt.waverec(coeffs, wavelet)

    raw._data = cleaned_data
    return raw

def load_all_eeg_from_documentation(data_dir,
                                    subjects,
                                    experiments):
    def edf_file(subject, experiment):
        return os.path.join(data_dir, f'S{subject:03d}', f'S{subject:03d}R{experiment:02d}.edf')
    raw = concatenate_raws([
        read_raw_edf(edf_file(subject, experiment), preload=True)
        for subject in subjects
        for experiment in experiments
    ])
    eegbci.standardize(raw)
    montage = make_standard_montage("standard_1005")
    raw.set_montage(montage)
    raw.set_eeg_reference(projection=True)
    raw.filter(16.0, 30.0, fir_design="firwin", skip_by_annotation="edge")
    raw = apply_wavelet_transform(raw)
    event_id = {
        'T1': 1,
        'T2': 2
    }
    picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
    epochs = Epochs(
        raw,
        event_id=event_id,
        tmin=0,
        tmax=2,
        proj=True,
        picks=picks,
        baseline=None,
        preload=True,
    )
    epochs_train = epochs.copy().crop(tmin=0., tmax=2.0)
    labels = epochs.events[:, -1] - 2
    return epochs, epochs_train, labels

subjects =  list(range(1, 50))
experiments = [3, 4, 7, 8, 11, 12]
epochs, epochs_train, labels = load_all_eeg_from_documentation("../data/raw/", subjects, experiments)
epochs

Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R07.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Ignoring annotation durations and creating fixed-duration epochs around annotation onsets.
Not setting metadata
6629 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data from preloaded Raw for 6629 events and 321 original time points ...
0 bad epochs dropped


Unnamed: 0,General,General.1
,MNE object type,Epochs
,Measurement date,2009-08-12 at 16:15:00 UTC
,Participant,X
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,6629
,Events counts,T1: 4410  T2: 2219
,Time range,0.000 – 2.000 s
,Baseline,off
,Sampling frequency,160.00 Hz


# 2. Train

## 2.1. Train version MNE

In [45]:
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import numpy as np

# Monte-carlo cross-validation generator (reduce variance):
scores = []
epochs_data = epochs.get_data(copy=False)
epochs_data_train = epochs_train.get_data(copy=False)
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
cv_split = cv.split(epochs_data)

# Classifier
lda = LinearDiscriminantAnalysis()
csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)

clf = Pipeline([
    ('CSP', csp),
    ('LDA', lda)
])

scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=1)
print(f"Scores de validation croisée : {scores}")
print(f"Score moyen : {scores.mean():.3f} (+/- {scores.std() * 2:.3f})")
# class_balance = np.mean(labels == labels[0])
# class_balance = max(class_balance, 1.0 - class_balance)
# print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}")
# csp.fit_transform(epochs_data, labels)
# csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)

Computing rank from data with rank=None
    Using tolerance 1.9e-05 (2.2e-16 eps * 64 dim * 1.3e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=-1 covariance using EMPIRICAL
Done.
Estimating class=0 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.9e-05 (2.2e-16 eps * 64 dim * 1.3e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=-1 covariance using EMPIRICAL
Done.
Estimating class=0 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.9e-05 (2.2e-16 eps * 64 dim * 1.3e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 compu

## 2.2. Train version mine

In [52]:
from sklearn.model_selection import train_test_split
# Recherche des meilleurs hyperparamètres
from sklearn.model_selection import GridSearchCV
param_grid = {
    'CSP__n_components': [2, 4, 6, 8],
    'LDA__solver': ['svd', 'lsqr']
}

clf = Pipeline([
    ('CSP', CSP()),
    ('LDA', LinearDiscriminantAnalysis())
])

X_train, X_test, y_train, y_test = train_test_split(
    epochs_data_train,
    labels,
    test_size=0.2,
    random_state=42
)
grid = GridSearchCV(clf, param_grid, cv=5)
grid.fit(X_train, y_train)
print(f"Meilleurs paramètres : {grid.best_params_}")
best_model = grid.best_estimator_
print(f"Score de validation croisée : {grid.best_score_}")

Computing rank from data with rank=None
    Using tolerance 1.7e-05 (2.2e-16 eps * 64 dim * 1.2e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=-1 covariance using EMPIRICAL
Done.
Estimating class=0 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.7e-05 (2.2e-16 eps * 64 dim * 1.2e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=-1 covariance using EMPIRICAL
Done.
Estimating class=0 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 1.7e-05 (2.2e-16 eps * 64 dim * 1.2e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 compu

# Prédiction

In [54]:
from sklearn.metrics import accuracy_score

predictions = []
true_labels = []
predict_subjects = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# predict_experiments = [5, 6, 9, 10, 13, 14]
predict_experiments = [3, 4, 7, 8, 11, 12]
epochs_predict, _, labels_predict = load_all_eeg_from_documentation("../data/raw/", predict_subjects, predict_experiments)

Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R07.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /Users/alexis/Documents/42/total-perspective-vortex/data/raw/S001/S001R08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Ignoring annotation durations and creating fixed-duration epochs around annotation onsets.
Not setting metadata
1357 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data from preloaded Raw for 1357 events and 321 original time points ...
0 bad epochs dropped


In [74]:
predictions = best_model.predict(epochs_predict.get_data())
print(f"Accuracy: {accuracy_score(labels_predict, predictions):.3f}")

Accuracy: 0.664
