In [3]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

import os

# mne library to analyse EEG
import mne
from mne import Epochs
from mne.channels import make_standard_montage
from mne.io import read_raw_edf
from mne.preprocessing import ICA
mne.set_log_level('error') # Avoid long log

# autoreject to reject bad epochs
from autoreject import AutoReject

# from mne import Epochs, pick_types
# from mne.decoding import CSP
# from mne.io import concatenate_raws, read_raw_edf

# from autoreject import AutoReject

In [4]:
# Function to clean up the channel names and set the montage
def set_montage(raw, montage_type: str = "standard_1020"):
    """Normalise the channel labels of ``raw`` and attach a standard montage.

    Each channel name has every "." removed and the two-letter prefixes
    listed in ``case_fixes`` rewritten in upper case, to respect the usual
    upper/lower-case notation for electrode positions. The renamed
    recording then receives the montage built by ``make_standard_montage``.

    Parameters
    ----------
    raw
        Recording exposing ``info['ch_names']``, ``rename_channels`` and
        ``set_montage``. It is modified through those calls, not copied.
    montage_type : str
        Name handed to ``make_standard_montage``; "standard_1020" when no
        other choice is made.

    Returns
    -------
    The same ``raw`` object that was passed in.
    """
    # Built first, so a problem with ``montage_type`` surfaces before any
    # channel is renamed.
    montage = make_standard_montage(montage_type)

    # (spelling found in the file, spelling of the standard notation)
    case_fixes = (
        ('Fc', 'FC'),
        ('Cp', 'CP'),
        ('Af', 'AF'),
        ('Ft', 'FT'),
        ('Tp', 'TP'),
        ('Po', 'PO'),
    )

    def _standardise(label):
        # Drop the padding dots, then fix the letter case prefix by prefix.
        cleaned = label.replace(".", "")
        for found, wanted in case_fixes:
            cleaned = cleaned.replace(found, wanted)
        return cleaned

    # old_name -> corrected_name, in the form expected by rename_channels
    renaming = {label: _standardise(label) for label in raw.info['ch_names']}

    raw.rename_channels(renaming)
    raw.set_montage(montage)

    return raw

In [6]:

def filter_raw(raw, low_cut: float = 0.1, high_cut: float = 30, preload = False):
    """Return a filtered copy of ``raw``; the input object itself is not filtered.

    Parameters
    ----------
    raw
        Recording exposing ``copy()``, ``load_data()`` and ``filter()``.
    low_cut, high_cut : float
        Passed, in this order, as the two positional arguments of ``filter``.
    preload
        Pass ``True`` when the data of ``raw`` is already loaded. For any
        value that does not compare equal to ``True`` the copy is loaded
        with ``load_data()`` before being filtered.

    Returns
    -------
    The object returned by ``filter`` on the copy.
    """
    working_copy = raw.copy()

    # Only an explicit ``preload == True`` skips the loading step.
    if preload != True:
        working_copy = working_copy.load_data()

    return working_copy.filter(low_cut, high_cut)

# def create_epochs(raws_filt, tmin: float, tmax: float, baseline = None, event_id = {'rest': 1, 'left_fist': 2, 'right_fist': 3}, preload: bool = True):
#     epochs = [
#         Epochs(raw_filt, mne.events_from_annotations(raw_filt)[0], event_id, tmin=tmin, tmax=tmax, baseline= baseline, preload= preload)
#         for raw_filt in raws_filt
#     ]
#     return epochs

In [22]:
def autoreject_ica_preprocessing(raw, ica_n_components: int = 25, ica_low_cut: float = 1., ica_high_cut: float = 30., time_step: float = 1., 
        random_state: int = None, max_ic: int = 2, z_thresh: float = 3.5, z_step: float = .05):
    """Fit an ICA on autoreject-screened epochs and mark the EOG-like components.

    The function works on a filtered copy of ``raw``:

    1. filter the copy between ``ica_low_cut`` and ``ica_high_cut``;
    2. cut it into fixed-length epochs of ``time_step`` seconds;
    3. fit ``AutoReject`` on those epochs to find the bad ones;
    4. fit the ICA on the epochs that were not flagged as bad;
    5. call ``find_bads_eog`` with a progressively lower threshold until at
       least ``max_ic`` components are flagged, and store them in
       ``ica.exclude``.

    Parameters
    ----------
    raw
        Recording to analyse; only a copy of it is filtered. Its channel
        names must already contain the frontal labels listed in
        ``eog_proxy_channels`` (upper-case notation, e.g. "AF7").
    ica_n_components : int
        ``n_components`` given to ``ICA``.
    ica_low_cut, ica_high_cut : float
        Band edges used for the ICA-only filtering.
    time_step : float
        Length, in seconds, of the fixed-length epochs.
    random_state : int or None
        Seed shared by ``AutoReject`` and ``ICA``.
    max_ic : int
        Number of flagged components the threshold search tries to reach.
        This is a target, not a cap: the last ``find_bads_eog`` call may
        return more components, and all of them are excluded. With
        ``max_ic <= 0`` nothing is searched and nothing is excluded.
    z_thresh : float
        Starting threshold for ``find_bads_eog``.
    z_step : float
        Amount by which the threshold is lowered after every attempt.

    Returns
    -------
    The fitted ``ICA`` object with ``exclude`` set; it has not been applied
    to any data yet.
    """
    # Creating a raw data to use to compute the ICA.
    # load_data() makes sure the copy is in memory before filtering, so the
    # function also accepts a recording that was opened without preloading.
    raw_ica = raw.copy().load_data().filter(ica_low_cut, ica_high_cut)

    # Creating the epochs for the artifact detection and the ICA.
    # Coherent with the time-span of the most common artifacts (heart, eye blink).
    events_ica = mne.make_fixed_length_events(raw_ica, duration=time_step)

    epochs_ica = Epochs(raw_ica, events_ica,
                        tmin=0.0, tmax=time_step,
                        baseline=None,
                        preload=True)

    # Instancing the autoreject object
    ar = AutoReject(random_state=random_state,
                    picks=mne.pick_types(epochs_ica.info,
                                         eeg=True,
                                         eog=False
                                         ),
                    n_jobs=-1,
                    verbose=False
                    )

    ar.fit(epochs_ica)
    reject_log = ar.get_reject_log(epochs_ica)

    # Fit ICA only on the epochs autoreject did not flag as bad
    ica = ICA(n_components=ica_n_components,
              random_state=random_state,
              )
    ica.fit(epochs_ica[~reject_log.bad_epochs], decim=3)

    ica.exclude = []

    # Frontal channels used as stand-ins for a dedicated EOG channel.
    eog_proxy_channels = ['Fp1', 'Fp2', 'AF7', 'AF8', 'F7', 'F8', 'Fpz']

    # Defined up front so that max_ic <= 0 yields an empty exclusion list
    # instead of an undefined name.
    eog_indices = []
    threshold = z_thresh

    while len(eog_indices) < max_ic:
        eog_indices, _ = ica.find_bads_eog(epochs_ica,
                                           ch_name=eog_proxy_channels,
                                           threshold=threshold
                                           )
        threshold -= z_step
        # Stop when the threshold cannot be lowered any further; otherwise
        # the search would never end if max_ic components cannot be reached
        # (e.g. max_ic larger than the number of fitted components) or if
        # z_step is not positive.
        if z_step <= 0 or threshold <= 0:
            break

    # assign the bad EOG components to the ICA.exclude attribute so they can be removed later
    ica.exclude = eog_indices

    return ica

In [23]:
# Create a list with each type of experimental run
openeye_runs = [1]
closedeye_runs = [2]
fists_runs = [3, 7, 11]
imaginefists_runs = [4, 8, 12]
fistsfeet_runs = [5, 9, 13]
imaginefistsfeet_run = [6, 10, 14]

# Parameters to load the files
path = 'files/'
participants = [2, 3]
preload = True

# Seed shared by the ICA and by both AutoReject fits, so that re-running the
# cell reproduces the same excluded components and the same cleaned epochs
random_state = 42

# Parameters for the first filtering
low_cut = 0.1 
high_cut = 30

# Parameters to create the epochs
tmin =  -1.  # start of each epoch (in sec)
tmax =  4.1  # end of each epoch (in sec)
baseline = (-1, 0) # for the baseline correction we choose the interval that reflects the resting state before the event
event_id = {'rest': 1, 'left_fist': 2, 'right_fist': 3}

for participant in participants:

    # Input folder of the participant and output folder for the cleaned epochs
    participant_dir = os.path.join(path, f'S{participant:03}')
    folder_path = os.path.join(participant_dir, 'preprocessing_ica')
    # Create the output folder if it doesn't exist
    os.makedirs(folder_path, exist_ok=True)

    for run in fists_runs:

        # Load the file (and the data when preload is True)
        raw = read_raw_edf(os.path.join(participant_dir, f'S{participant:03}R{run:02}.edf'), preload = preload)

        # Set the montage
        raw = set_montage(raw)

        # Filter the data; preload tells filter_raw whether the data is already in memory
        raw_filt = filter_raw(raw, low_cut=low_cut, high_cut=high_cut, preload=preload)

        # Apply autoreject and then create the ICA object, with the EOG-like
        # components (normally eye blinks) marked for exclusion.
        # The ICA is computed from `raw`, which is filtered separately inside the function.
        ica = autoreject_ica_preprocessing(raw, random_state=random_state)

        # Create the epochs we want to study
        events, events_dict = mne.events_from_annotations(raw_filt)
        epochs = mne.Epochs(raw_filt,
                            events, event_id,
                            tmin = tmin,
                            tmax = tmax,
                            preload=True,
                            baseline = baseline
                        ) 

        # Apply the ICA we computed previously
        epochs_postica = ica.apply(epochs.copy())

        # Removing components can shift the DC level, so the baseline interval
        # is no longer guaranteed to be zero-mean: correct the baseline again.
        epochs_postica.apply_baseline(baseline)

        # Clean the epochs with autoreject
        ar = AutoReject(n_interpolate=[1, 2, 4],
                        random_state=random_state,
                        picks=mne.pick_types(epochs_postica.info, 
                                            eeg=True,
                                            eog=False
                                            ),
                        n_jobs=-1, 
                        verbose=False
                        )

        epochs_clean, reject_log_clean = ar.fit_transform(epochs_postica, return_log=True)

        # Change the reference to the average of T9 and T10
        epochs_clean_masref = epochs_clean.copy().set_eeg_reference(ref_channels=['T9', 'T10'])
        
        # Saving the data as a FIF epochs file in the output folder.
        # NOTE(review): MNE's naming convention for epochs files is a name
        # ending in "-epo.fif" / "_epo.fif"; the name is kept unchanged here so
        # that anything reading these files keeps working - confirm before renaming.
        file_name = f'S{participant:03}R{run:02}filt01_30_ICA_masref.fif'
        output_file = os.path.join(folder_path, file_name)
        epochs_clean_masref.save(output_file, overwrite=True)

