In [15]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

# mne library to analyse EEG
import mne
from mne import Epochs
from mne.channels import make_standard_montage
from mne.io import read_raw_edf
from mne.preprocessing import ICA
mne.set_log_level('error') # Avoid long log

# autoreject to automatically reject bad epochs
from autoreject import AutoReject

# from mne import Epochs, pick_types
# from mne.decoding import CSP
# from mne.io import concatenate_raws, read_raw_edf

# from autoreject import AutoReject

In [4]:
# Normalise the channel names of an already-loaded recording and attach a standard montage
def set_montage(raw, montage_type: str = "standard_1020"):
    """Rename the channels of ``raw`` to standard notation and set a montage on it.

    Every channel name has all "." characters removed. After that, the substrings
    Fc, Cp, Af, Ft, Tp and Po are upper-cased (FC, CP, AF, FT, TP, PO), so that
    a name such as "Fc5." becomes "FC5" and matches the usual electrode notation.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channels are renamed and whose montage is set. The
        object itself is updated; no copy is made.
    montage_type : str
        Name given to ``make_standard_montage`` (default "standard_1020").

    Returns
    -------
    The same ``raw`` object, to allow ``raw = set_montage(raw)`` style chaining.
    """
    # Build the montage first so an invalid ``montage_type`` fails before any renaming.
    montage = make_standard_montage(montage_type)

    # (wrong case, standard case) pairs, applied in this order to every name.
    case_fixes = (
        ('Fc', 'FC'),
        ('Cp', 'CP'),
        ('Af', 'AF'),
        ('Ft', 'FT'),
        ('Tp', 'TP'),
        ('Po', 'PO'),
    )

    def _normalise(channel):
        cleaned = channel.replace(".", "")
        for wrong, right in case_fixes:
            cleaned = cleaned.replace(wrong, right)
        return cleaned

    mapping = {channel: _normalise(channel) for channel in raw.info['ch_names']}
    raw.rename_channels(mapping)
    raw.set_montage(montage)
    return raw

In [None]:
# Function to band-pass filter a copy of the data
def filter_raw(raw, low_cut: float = 0.1, high_cut: float = 30, preload: bool = False):
    """Return a band-pass filtered copy of ``raw``; ``raw`` itself is left untouched.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to filter.
    low_cut, high_cut : float
        Pass-band edges in Hz, forwarded to ``Raw.filter`` as ``l_freq`` / ``h_freq``.
    preload : bool
        Kept only for backward compatibility and ignored. Whether the data
        must be loaded is read from the copy's own ``preload`` attribute, so a
        flag that disagrees with the object's real state can no longer make
        ``filter`` fail.

    Returns
    -------
    mne.io.Raw
        A new, filtered object.
    """
    raw_filt = raw.copy()
    # ``Raw.filter`` only operates on samples held in memory.
    if not raw_filt.preload:
        raw_filt.load_data()
    return raw_filt.filter(low_cut, high_cut)

# def create_epochs(raws_filt, tmin: float, tmax: float, baseline = None, event_id = {'rest': 1, 'left_fist': 2, 'right_fist': 3}, preload: bool = True):
#     epochs = [
#         Epochs(raw_filt, mne.events_from_annotations(raw_filt)[0], event_id, tmin=tmin, tmax=tmax, baseline= baseline, preload= preload)
#         for raw_filt in raws_filt
#     ]
#     return epochs

In [18]:
def _make_ica_epochs(raw, low_cut: float, high_cut: float, duration: float = 1.):
    """Band-pass filter a copy of ``raw`` and cut it into consecutive fixed-length epochs.

    The default 1 s duration is meant to be coherent with the time-span of the
    most common artifacts (heartbeat, eye blink).
    """
    # ``Raw.filter`` needs the samples in memory; ``load_data`` is a no-op when
    # the copy is already preloaded.
    raw_ica = raw.copy().load_data().filter(low_cut, high_cut)
    events = mne.make_fixed_length_events(raw_ica, duration=duration)
    # NOTE(review): MNE treats tmax as inclusive, so neighbouring epochs share one sample.
    return Epochs(raw_ica, events,
                  tmin=0.0, tmax=duration,
                  baseline=None,
                  preload=True)


def _autoreject_bad_epochs(epochs):
    """Fit AutoReject on the EEG channels of ``epochs`` and return its boolean bad-epoch mask."""
    ar = AutoReject(
        random_state=42,
        picks=mne.pick_types(epochs.info, eeg=True, eog=False),
        n_jobs=-1,
        verbose=False,
    )
    ar.fit(epochs)
    return ar.get_reject_log(epochs).bad_epochs


def _select_eog_components(ica, epochs, eog_ch_names, max_ic: int,
                           z_thresh: float, z_step: float, min_z_thresh: float):
    """Return the indices of EOG-like ICA components, loosening the z-threshold as needed.

    The threshold starts at ``z_thresh`` and is lowered by ``z_step`` until at
    least ``max_ic`` components are flagged, or until it drops to
    ``min_z_thresh``. Every component flagged at the last threshold tried is
    returned, so the result may contain more than ``max_ic`` indices.
    """
    eog_indices = []  # defined even if the loop body never runs
    while len(eog_indices) < max_ic and z_thresh > min_z_thresh:
        eog_indices, _ = ica.find_bads_eog(epochs,
                                           ch_name=list(eog_ch_names),
                                           threshold=z_thresh)
        z_thresh -= z_step
    return eog_indices


def autoreject_ica_preprocessing(raw, ica_low_cut: float = 1., ica_high_cut: float = 30.,
                                 n_components=25, max_ic: int = 2,
                                 z_thresh: float = 3.5, z_step: float = .05,
                                 min_z_thresh: float = 0.,
                                 eog_ch_names=('Fp1', 'Fp2', 'AF7', 'AF8', 'F7', 'F8', 'Fpz'),
                                 random_state: int = 0, decim: int = 3):
    """Fit an ICA on AutoReject-cleaned epochs and mark EOG-like components for exclusion.

    Pipeline:
      1. band-pass filter a copy of ``raw`` (``ica_low_cut``-``ica_high_cut`` Hz)
         and cut it into 1 s epochs;
      2. let AutoReject flag bad epochs using the EEG channels;
      3. fit the ICA on the epochs that were not flagged;
      4. use the frontal channels in ``eog_ch_names`` as EOG proxies and store the
         components ``ICA.find_bads_eog`` flags in ``ica.exclude``.

    ``raw`` is not modified. Apply the result yourself, e.g. ``ica.apply(raw_filt.copy())``.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose channel names already follow the standard notation
        (``eog_ch_names`` must exist in it).
    ica_low_cut, ica_high_cut : float
        Pass band in Hz of the copy used for AutoReject and ICA.
    n_components : int | float
        Passed to ``ICA``. An int gives the number of components; a float in
        (0, 1) selects the fraction of explained variance.
    max_ic : int
        Minimum number of EOG components to look for before the threshold search stops.
    z_thresh, z_step : float
        Starting z-score threshold for ``find_bads_eog`` and how much it is lowered per try.
    min_z_thresh : float
        The search stops once the threshold falls to this value. The default 0
        only guarantees termination. Raising it (e.g. to 2) avoids removing
        components that are only weakly related to the eyes.
    eog_ch_names : sequence of str
        Channels used as EOG proxies.
    random_state : int
        Seed of the ICA, so the decomposition is reproducible.
    decim : int
        Decimation factor used while fitting the ICA.

    Returns
    -------
    mne.preprocessing.ICA
        Fitted ICA whose ``exclude`` attribute lists the EOG-like components.

    Raises
    ------
    ValueError
        If ``z_step`` is not positive (the threshold search could never end).
    RuntimeError
        If AutoReject flags every epoch, leaving nothing to fit the ICA on.
    """
    # Validate before the expensive steps.
    if z_step <= 0:
        raise ValueError("z_step must be positive, otherwise the threshold never decreases.")

    epochs_ica = _make_ica_epochs(raw, ica_low_cut, ica_high_cut)

    bad_epochs = _autoreject_bad_epochs(epochs_ica)
    if bad_epochs.all():
        raise RuntimeError("AutoReject rejected every epoch; no data left to fit the ICA.")

    ica = ICA(n_components=n_components, random_state=random_state)
    ica.fit(epochs_ica[~bad_epochs], decim=decim)

    # As in the original pipeline, EOG scores are computed on all epochs (bad ones included).
    # Keeping 2 components (default max_ic) is the setting chosen to remove
    # eye movements and blinks.
    ica.exclude = _select_eog_components(ica, epochs_ica, eog_ch_names,
                                         max_ic, z_thresh, z_step, min_z_thresh)
    return ica

In [19]:
# Recording selection: files are laid out as files/S001/S001R03.edf for participant 1, run 3.
path, participant, run = 'files/', 1, 3
preload = True

# Load the EDF and immediately normalise channel names + attach the montage.
raw = set_montage(
    read_raw_edf(f'{path}S{participant:03}/S{participant:03}R{run:02}.edf', preload=preload)
)

# Band-pass filtered copy for analysis; the ICA is fitted from the unfiltered recording.
raw_filt = filter_raw(raw)
ica = autoreject_ica_preprocessing(raw)
