In [29]:
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

# mne library to analyse EEG
import mne
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf

from autoreject import AutoReject
mne.set_log_level('error') # Avoid long log

In [None]:
def load_raws(participant: int, run: int, preload = True, path: str = 'files/'):
    """Load the EDF recording(s) for the given participant(s) and run(s).

    NOTE(review): a later cell redefines ``load_raws(participants, runs, ...)``;
    once both cells have run, that definition shadows this one.

    Parameters
    ----------
    participant : int or sequence of int
        Subject number(s); formatted as ``S001``, ``S002``, ...
    run : int or sequence of int
        Run number(s); formatted as ``R01``, ``R02``, ...
    preload : bool
        Forwarded unchanged to ``read_raw_edf``.
    path : str or path-like
        Directory that contains the ``Sxxx`` subject folders.

    Returns
    -------
    list
        One ``read_raw_edf`` result per (participant, run) pair, ordered
        participant-major (all runs of the first participant first).
    """
    # Accept a single id or any sequence of ids; tolist() yields plain Python ints.
    participant_ids = np.atleast_1d(participant).tolist()
    run_ids = np.atleast_1d(run).tolist()
    return [
        read_raw_edf(os.path.join(path, f'S{p:03}', f'S{p:03}R{r:02}.edf'), preload=preload)
        for p in participant_ids
        for r in run_ids
    ]

In [None]:


def load_raws(participants: list, runs: list, preload = True, path: str = 'files/'):
    """Load one EDF recording per (participant, run) combination.

    Files are expected at ``<path>/Sxxx/SxxxRyy.edf``, e.g. ``files/S001/S001R03.edf``.
    The path is joined with ``os.path.join``, so ``path`` works with or without a
    trailing separator and may also be a ``pathlib.Path``.

    Parameters
    ----------
    participants : list of int (a single int is also accepted)
        Subject numbers, zero-padded to three digits in the file name.
    runs : list of int (a single int is also accepted)
        Run numbers, zero-padded to two digits in the file name.
    preload : bool
        Forwarded unchanged to ``read_raw_edf``.
    path : str or path-like
        Directory containing the subject folders.

    Returns
    -------
    list
        ``read_raw_edf`` results in participant-major order.
    """
    participant_ids = np.atleast_1d(participants).tolist()
    run_ids = np.atleast_1d(runs).tolist()
    raws = [
        read_raw_edf(os.path.join(path, f'S{participant:03}', f'S{participant:03}R{run:02}.edf'),
                     preload=preload)
        for participant in participant_ids
        for run in run_ids
    ]
    return raws

# Function to rename the channels to standard 10-20 notation and set the montage
def load_and_set_montage(raws: list, montage_type: str = "standard_1020"):
    """Rename channels to standard 10-20 spelling and attach a standard montage.

    Each channel name loses its '.' padding, and its two-letter prefix is
    upper-cased where the standard notation requires it (e.g. 'Fc5.' -> 'FC5',
    'Afz.' -> 'AFz', 'Fp1.' -> 'Fp1'). The rename mapping is built from each
    recording's own channel list. Recordings with different channel sets are
    therefore handled correctly, and an empty list is returned unchanged.

    Parameters
    ----------
    raws : list
        Recordings exposing ``info['ch_names']``, ``rename_channels`` and
        ``set_montage`` (presumably ``mne.io.Raw``). They are updated through
        those methods; the return values of the methods are not used.
    montage_type : str
        Name passed to ``make_standard_montage`` (default "standard_1020").

    Returns
    -------
    list
        The same ``raws`` list object.
    """
    # Built first so that an unknown montage name fails even for an empty list.
    montage = make_standard_montage(montage_type)

    for raw in raws:
        mapping = {name: _standardize_channel_name(name) for name in raw.info['ch_names']}
        raw.rename_channels(mapping)
        raw.set_montage(montage)

    return raws


def _standardize_channel_name(name: str) -> str:
    """Strip '.' padding and fix the prefix case of one channel name ('Fc5.' -> 'FC5')."""
    # lower-case-second-letter prefix -> standard notation; applied after the dots are removed
    case_fixes = {
        'Fc': 'FC',
        'Cp': 'CP',
        'Af': 'AF',
        'Ft': 'FT',
        'Tp': 'TP',
        'Po': 'PO',
    }
    standardized = name.replace(".", "")
    for old, new in case_fixes.items():
        standardized = standardized.replace(old, new)
    return standardized


# Function to filter the data
def filter(raws : list, low_cut: float = 0.1, high_cut: float = 30, preload = False):
    """Return band-pass-filtered copies of the given recordings.

    NOTE(review): this name shadows Python's built-in ``filter`` for the rest of
    the notebook; it is kept because existing cells call it.

    Parameters
    ----------
    raws : list
        Recordings exposing ``copy()``, ``load_data()`` and ``filter()``
        (presumably ``mne.io.Raw``). The originals are not filtered; only copies are.
    low_cut, high_cut : float
        Passed positionally to each copy's ``filter`` method (presumably Hz).
    preload : bool
        When exactly ``True``, the copies are filtered directly. Otherwise
        ``load_data()`` is called on each copy before filtering.

    Returns
    -------
    list
        The filtered copies, in input order.
    """
    already_loaded = preload == True
    filtered = []
    for recording in raws:
        duplicate = recording.copy()
        if not already_loaded:
            # Bring the copy's samples into memory before filtering it.
            duplicate = duplicate.load_data()
        filtered.append(duplicate.filter(low_cut, high_cut))
    return filtered


def create_epochs(raws_filt, tmin: float, tmax: float, baseline = None, event_id = {'rest': 1, 'left_fist': 2, 'right_fist': 3}, preload: bool = True):
    """Cut each recording into epochs around the events found in its annotations.

    Events come from ``mne.events_from_annotations(raw)``. The numeric codes in
    ``event_id`` must match the codes it assigns (TODO confirm per dataset).

    Each ``Epochs`` object receives its own copy of ``event_id``. This matters
    because the default is a single dict literal: without the copy, every call
    and every returned ``Epochs`` would share one mutable mapping.

    Parameters
    ----------
    raws_filt : iterable
        Filtered recordings (presumably ``mne.io.Raw``).
    tmin, tmax : float
        Epoch window relative to each event (seconds).
    baseline : tuple or None
        Baseline interval forwarded to ``Epochs``.
    event_id : dict or other value accepted by ``Epochs``
        Label -> event-code mapping. Dicts are copied; other values are passed as-is.
    preload : bool
        Forwarded to ``Epochs``.

    Returns
    -------
    list
        One ``Epochs`` object per input recording, in input order.
    """
    epochs = []
    for raw_filt in raws_filt:
        events, _ = mne.events_from_annotations(raw_filt)
        mapping = dict(event_id) if isinstance(event_id, dict) else event_id
        epochs.append(
            Epochs(raw_filt, events, mapping, tmin=tmin, tmax=tmax, baseline=baseline, preload=preload)
        )
    return epochs




In [37]:
def preprocessing(raw, ica_low_cut: float = 1., ica_high_cut: float = 30., *,
                  epoch_duration: float = 1,
                  ica_n_components=25,
                  ica_random_state: int = 0,
                  max_ic: int = 2,
                  z_thresh: float = 3.5,
                  z_step: float = .05,
                  eog_proxy_channels=('Fp1', 'Fp2', 'AF7', 'AF8', 'F7', 'F8', 'Fpz')):
    """Fit an ICA on a cleaned copy of ``raw`` and flag EOG-like components.

    Steps:
    1. Band-pass a copy of ``raw`` between ``ica_low_cut`` and ``ica_high_cut``.
    2. Cut that copy into fixed-length epochs of ``epoch_duration`` seconds.
    3. Fit AutoReject on those epochs and keep only the epochs it does not mark bad.
    4. Fit the ICA on the kept epochs (``decim=3``).
    5. Lower a z-score threshold by ``z_step`` until ``ica.find_bads_eog`` returns
       at least ``max_ic`` components. The search runs on all step-2 epochs and
       uses the frontal EEG channels in ``eog_proxy_channels`` as EOG proxies.
       The indices found are stored in ``ica.exclude``.

    ``raw`` itself is not filtered (only its copy is). Apply the returned ICA
    with ``ica.apply(...)``.

    Parameters
    ----------
    raw : mne.io.Raw (presumably)
        Recording to decompose.
    ica_low_cut, ica_high_cut : float
        Band-pass edges used only for the ICA fit.
    epoch_duration : float
        Length of the fixed-length epochs. 1 s matches the time-span of common
        artifacts (heartbeat, eye blinks).
    ica_n_components : int or float
        Forwarded to ``mne.preprocessing.ICA``. A float in (0, 1) selects
        components by explained variance.
    ica_random_state : int
        Seed that makes the ICA reproducible.
    max_ic : int
        Minimum number of EOG-like components to look for. It is capped at the
        number of fitted components.
    z_thresh, z_step : float
        Starting threshold for ``find_bads_eog`` and the amount it is lowered
        on each retry.
    eog_proxy_channels : sequence of str
        Channel names passed as ``ch_name`` to ``find_bads_eog``.

    Returns
    -------
    ica : mne.preprocessing.ICA
        The fitted ICA, with ``exclude`` set. ``exclude`` can hold more than
        ``max_ic`` indices when several components cross the threshold at the
        same step.

    Raises
    ------
    ValueError
        If ``z_step`` is not positive, because the threshold search would never
        loosen and so might never end.
    """
    if z_step <= 0:
        raise ValueError(f"z_step must be positive, got {z_step}")

    # Filtered copy used only for the decomposition.
    raw_ica = raw.copy().filter(ica_low_cut, ica_high_cut)

    events_ica = mne.make_fixed_length_events(raw_ica, duration=epoch_duration)
    epochs_ica = mne.Epochs(raw_ica, events_ica,
                            tmin=0.0, tmax=epoch_duration,
                            baseline=None,
                            preload=True)

    reject_log = _autoreject_log(epochs_ica)

    ica = mne.preprocessing.ICA(n_components=ica_n_components,
                                random_state=ica_random_state)
    # Fit only on the epochs AutoReject kept.
    ica.fit(epochs_ica[~reject_log.bad_epochs], decim=3)

    ica.exclude = _find_eog_components(ica, epochs_ica, list(eog_proxy_channels),
                                       max_ic, z_thresh, z_step)
    return ica


def _autoreject_log(epochs):
    """Fit AutoReject on the EEG channels of ``epochs`` and return its reject log."""
    ar = AutoReject(random_state=42,
                    picks=mne.pick_types(epochs.info, eeg=True, eog=False),
                    n_jobs=-1,
                    verbose=False)
    ar.fit(epochs)
    return ar.get_reject_log(epochs)


def _find_eog_components(ica, epochs, ch_names, max_ic, z_thresh, z_step):
    """Relax ``z_thresh`` by ``z_step`` until at least ``max_ic`` EOG-like components are found."""
    # A fitted ICA exposes n_components_ sources. Asking for more than that could
    # never be satisfied, and the loop would not terminate.
    target = min(max_ic, ica.n_components_)
    eog_indices = []  # also the result when target <= 0 (no search needed)
    while len(eog_indices) < target:
        eog_indices, _ = ica.find_bads_eog(epochs, ch_name=ch_names, threshold=z_thresh)
        z_thresh -= z_step
    return eog_indices


In [38]:
# Run numbers for each type of experimental task
openeye_runs = [1]
closedeye_runs = [2]
fists_runs = [3, 7, 11]
imaginefists_runs = [4, 8, 12]
fistsfeet_runs = [5, 9, 13]
imaginefistsfeet_run = [6, 10, 14]

# Event labels used when epoching the fist runs
event_id = {'rest': 1, 'left_fist': 2, 'right_fist': 3}

# Loading options: participants S001-S004, data read into memory, files under files/
participants = list(range(1, 5))
preload = True
path = 'files/'

# Load -> standardize channel names + montage -> band-pass filter (filtered copies)
raws = load_and_set_montage(load_raws(participants, fists_runs, preload, path))
raws_filt = filter(raws)

# One fitted ICA per recording (these are ICA objects, not raws; the name is kept
# for the later cells). preprocessing() filters its own copy of each raw.
raws_ica = [preprocessing(raw) for raw in raws]







In [47]:
tmin = -1.   # start of each epoch (in s)
tmax = 4.1   # end of each epoch (in s)
baseline = (-1, 0)  # baseline correction over the resting interval just before each event

event_id = {'rest': 1, 'left_fist': 2, 'right_fist': 3}

# Create epochs with the shared helper (same arguments as the former inline mne.Epochs call)
epochs = create_epochs(raws_filt, tmin, tmax, baseline=baseline, event_id=event_id, preload=True)

# ICA.apply modifies the instance it is given and returns it. Applying it to a copy
# keeps `epochs` as the pre-ICA data instead of silently overwriting it.
assert len(epochs) == len(raws_ica), "expected one fitted ICA per recording"
list_epochs_postica = [ica.apply(ep.copy()) for ica, ep in zip(raws_ica, epochs)]

# Second AutoReject pass on the ICA-corrected epochs; the reject logs are kept for inspection.
list_epochs_postica_cleaned = []
reject_logs_postica = []
for epochs_postica in list_epochs_postica:
    ar = AutoReject(n_interpolate=[1, 2, 4],
                    random_state=42,
                    picks=mne.pick_types(epochs_postica.info,
                                         eeg=True,
                                         eog=False
                                         ),
                    n_jobs=-1,
                    verbose=False
                    )
    epochs_cleaned, reject_log = ar.fit_transform(epochs_postica, return_log=True)
    list_epochs_postica_cleaned.append(epochs_cleaned)
    reject_logs_postica.append(reject_log)


In [48]:
# Re-reference each cleaned recording to channels T9 and T10, working on copies so that
# list_epochs_postica_cleaned is left as-is.
list_epochs_masref = [
    cleaned.copy().set_eeg_reference(ref_channels=['T9', 'T10'])
    for cleaned in list_epochs_postica_cleaned
]

array([[[-1.89408167e-05, -1.08178429e-05, -6.57928802e-06, ...,
         -3.49665191e-05, -4.78734092e-05, -4.59465927e-05],
        [-1.01840391e-05, -8.97076136e-06, -4.38843805e-06, ...,
         -3.30464477e-05, -4.49483345e-05, -3.80733551e-05],
        [ 1.10088418e-05,  8.49344778e-06,  1.01400541e-05, ...,
         -1.58297086e-05, -2.18850870e-05, -1.32407123e-05],
        ...,
        [-1.31769438e-06, -2.13841539e-06, -3.83006830e-06, ...,
         -4.61249303e-06,  2.20133040e-06,  8.66263195e-06],
        [ 6.17635354e-06, -5.33437351e-06, -1.27538837e-05, ...,
         -1.00869600e-05,  5.17565516e-06,  2.34081457e-05],
        [-9.61382372e-06, -2.92840163e-06,  5.46646421e-06, ...,
         -3.12512324e-05, -1.61799884e-05,  1.24885155e-06]],

       [[-8.71225536e-06, -1.10283943e-05, -1.43448283e-05, ...,
          4.60573344e-05,  3.52922275e-05,  2.11748454e-05],
        [-8.35975191e-06, -1.18054294e-05, -1.38618801e-05, ...,
          4.59548490e-05,  3.61762547e