# Replication of Clayson et al., 2023 Preprocessing scripts

In [None]:
import logging
from dataclasses import dataclass

import numpy as np
import mne
from mne.utils import set_log_file
from sklearn.neighbors import NearestNeighbors
import re
from collections import OrderedDict
import matplotlib.pyplot as plt
import pandas as pd
import json
import os
import math

SST INFO: 

Triale z reakcją wyprzedzającą stop (przedwczesną) wyróżnione zostały numerem 1 (w kontraście do 0) na czwartej pozycji w triggerze i zostały wykluczone z
"Segmentation Global".

TODO

- SST-005 - sygnał zreferowany wyłącznie do A1, ponieważ A2 jest zbyt zaszumiona
- SST-044 - sygnał zreferowany wyłącznie do A2, ponieważ A1 jest mocniej zaszumiona


Constants

In [None]:
# Fixed seed passed to every ICA fit below so decompositions are reproducible.
random_state: int = 42

Loggers

In [None]:
######## LOGGERS (PREPROCESSING / ERRORS) ##############################
def _build_info_logger(logger_name):
    """Return the logger called ``logger_name`` with its level set to INFO.

    No handler is attached here; records propagate to ancestor loggers.
    """
    named_logger = logging.getLogger(logger_name)
    named_logger.setLevel(logging.INFO)
    return named_logger


# Progress messages emitted while preprocessing a participant.
logger_preprocessing_info = _build_info_logger('preprocessing_info')
# Problems detected while preprocessing (logger name: 'errors').
logger_errors_info = _build_info_logger('errors')

Utils

In [None]:
def read_trigger_map(file_name):
    """Parse a trigger-map text file into ``(trigger_id, trigger_code)`` pairs.

    Each non-blank line must look like ``<trigger_id>:<trigger_code>``. The line is
    split at its LAST colon, so ``trigger_id`` may contain colons but ``trigger_code``
    may not. Blank lines are ignored, and the final line does not need a trailing
    newline.

    :param file_name: path of the trigger-map file.
    :return: list of ``(trigger_id, trigger_code)`` string tuples, in file order.
    :raises AssertionError: if a non-blank line has no colon, i.e. the number of
        parsed entries differs from the number of non-blank lines.
    """
    line_count = 0
    trigger_map = []
    with open(file_name, 'r') as file:
        for raw_line in file:
            line = raw_line.rstrip('\n')
            if not line.strip():
                # Blank lines (e.g. an empty last line) carry no trigger.
                continue
            line_count += 1
            trigger_id, separator, trigger_code = line.rpartition(':')
            if separator:
                trigger_map.append((trigger_id, trigger_code))

    assert len(trigger_map) == line_count, \
        f'The length of trigger file ({line_count}) not equals length of created trigger_map ({len(trigger_map)})'

    return trigger_map

def create_triggers_dict(trigger_map):
    """Number the distinct trigger codes of ``trigger_map`` consecutively from 1000.

    Codes are numbered in order of first appearance; repeated codes share a number.

    :param trigger_map: sequence of ``(trigger_id, trigger_code)`` pairs.
    :return: ``(code_to_number, number_to_code)`` -- two mutually inverse dicts.
    """
    distinct_codes = OrderedDict((entry[1], None) for entry in trigger_map)
    code_to_number = {}
    number_to_code = {}
    for number, code in enumerate(distinct_codes, start=1000):
        code_to_number[code] = number
        number_to_code[number] = code
    return code_to_number, number_to_code

# Participants whose recordings miss one trigger (too short reaction / too short
# time between the stop trigger and the reaction trigger); a synthetic event is inserted.
_GNG_MISSING_TRIGGER_IDS = ('B-GNG-199', 'B-GNG-208')
_SST_MISSING_TRIGGER_IDS = ('SST-165', 'SST-211', 'SST-122', 'SST-088', 'SST-045',
                            'SST-012', 'SST-083', 'SST-136', 'SST-125')


def _insert_missing_event(events, trigger_map, delta_time=3, event_code=65281):
    """Insert one synthetic event where ``events`` first diverges from ``trigger_map``.

    An event diverges when the last digit of its code differs from the trigger id of the
    trigger-map entry at the same position. The synthetic event is placed ``delta_time``
    samples after the preceding event and gets ``event_code``. Returns ``events`` unchanged
    when no divergence is found within the common length.
    """
    for idx, (event, trigger) in enumerate(zip(events, trigger_map)):
        if str(event[2])[-1] != trigger[0]:
            # NOTE(review): for idx == 0 this takes events[-1] as the preceding event.
            synthetic = [[events[idx - 1][0] + delta_time, 0, event_code]]
            return np.concatenate([events[:idx, :], synthetic, events[idx:, :]])
    return events


def _drop_first_event_with_code(events, event_code):
    """Return ``events`` without its first row whose event code equals ``event_code``."""
    matches = np.flatnonzero(events[:, 2] == event_code)
    if matches.size == 0:
        return events
    idx = matches[0]
    return np.concatenate([events[:idx, :], events[idx + 1:, :]])


def replace_trigger_names(raw, participant_id, trigger_map, new_response_event_dict=None, replace=False, search='RE'):
    """Return a copy of ``raw`` annotated with the trigger codes of ``trigger_map``.

    Events are read from the 'Status' stim channel, repaired for known corrupted
    recordings (missing or ghost triggers), then relabelled one-to-one with the codes
    of ``trigger_map``, numbered as in ``create_triggers_dict``.

    The repair branch is chosen from the global ``paradigm`` ('GNG' / 'SST'), which is
    read from the enclosing scope (not a parameter).

    :param raw: mne Raw with a 'Status' stim channel.
    :param participant_id: participant identifier, e.g. 'SST-130'.
    :param trigger_map: list of ``(trigger_id, trigger_code)`` pairs, one per event.
    :param new_response_event_dict: unused; kept for backward compatibility.
    :param replace: unused; kept for backward compatibility.
    :param search: unused; kept for backward compatibility.
    :return: copy of ``raw`` whose annotations are the relabelled events.
    :raises AssertionError: if the repaired event count differs from ``len(trigger_map)``.
    """
    events = mne.find_events(raw, stim_channel='Status')
    new_events_list = events.copy()

    if paradigm == 'GNG' and participant_id in _GNG_MISSING_TRIGGER_IDS:
        # add trigger to corrupted bdf files - too short reaction
        new_events_list = _insert_missing_event(new_events_list, trigger_map)
    elif paradigm == 'SST':
        if participant_id in _SST_MISSING_TRIGGER_IDS:
            # add missing RE trigger - too short time between stop and reaction triggers
            new_events_list = _insert_missing_event(new_events_list, trigger_map)
        elif participant_id == 'SST-181':
            # delete ghost event 65312
            new_events_list = _drop_first_event_with_code(new_events_list, 65312)
        elif participant_id == 'SST-075':
            # delete ghost event 130816
            new_events_list = _drop_first_event_with_code(new_events_list, 130816)
        elif participant_id == 'SST-130':
            # this recording misses two triggers
            new_events_list = _insert_missing_event(new_events_list, trigger_map)
            new_events_list = _insert_missing_event(new_events_list, trigger_map)

    logger_preprocessing_info.info(f'EVENTS: {new_events_list}')

    assert len(new_events_list) == len(trigger_map), \
            f'The length of trigger map ({len(trigger_map)}) not equals length of events in eeg recording ({len(new_events_list)})'

    trigger_map_codes, mapping = create_triggers_dict(trigger_map)

    for idx, event in enumerate(new_events_list):
        trigger_id = trigger_map[idx][0]
        trigger_new_code = trigger_map[idx][1]

        if str(event[2])[-1] != trigger_id:
            logger_errors_info.info(f'An event {idx} has different number than in provided file. {trigger_id} expected, {str(event[2])} found. Triggers may need to be checked.')

        new_events_list[idx, 2] = trigger_map_codes[trigger_new_code]

    annot_from_events = mne.annotations_from_events(
        events=new_events_list,
        event_desc=mapping,
        sfreq=raw.info["sfreq"],
        orig_time=raw.info["meas_date"],
    )
    raw_copy = raw.copy()
    raw_copy.set_annotations(annot_from_events)

    return raw_copy

def find_items_matching_regex(dictionary, regex_list):
    """Keep the entries of ``dictionary`` whose key matches any pattern in ``regex_list``.

    Patterns use ``re.match`` semantics, i.e. they are anchored at the start of the key.
    Matches are ordered pattern by pattern, then in ``dictionary`` order.
    """
    compiled_patterns = [re.compile(regex) for regex in regex_list]
    return {
        key: value
        for pattern in compiled_patterns
        for key, value in dictionary.items()
        if pattern.match(key)
    }

@dataclass
class ParticipantTriggerMappingContext:
    """Event selection and merging rules used to epoch one participant's recording."""

    # Selected trigger codes (annotation descriptions) -> numeric event ids.
    event_dict: dict
    # Category name -> list of numeric event ids merged into that category.
    events_mapping: dict
    # Category name -> numeric id the merged events receive.
    new_event_dict: dict

    def __str__(self):
        parts = (self.event_dict, self.events_mapping, self.new_event_dict)
        return "\n".join(str(part) for part in parts)

def create_events_mappings(trigger_map, case='RE') -> ParticipantTriggerMappingContext:
    """Build the event selection and merging rules for one participant.

    Trigger codes are split on '*'; field index 3 and the last fields decide the
    category of each selected event.

    - ``'RE'``: codes matching ``RE*image*...*0*...`` or ``RE*image*...*-*...``.
      If field 3 is '-', the event is a correct response when the last two fields are
      equal, and an incorrect go response otherwise. If field 3 is '0' and the last
      field is not '-', it is an error response.
    - ``'STIM'``: codes starting with ``ST`` whose field 3 is '0'. The event is an
      inhibited stop when the last field is '-', and an uninhibited stop otherwise.

    :param trigger_map: list of ``(trigger_id, trigger_code)`` pairs.
    :param case: 'RE' (response-locked) or 'STIM' (stimulus-locked).
    :return: ParticipantTriggerMappingContext. For any other ``case`` the problem is
        logged to the errors logger and a context with empty dicts is returned.
    """
    trigger_map_codes, _ = create_triggers_dict(trigger_map)

    if case == 'RE':
        new_event_dict = {"correct_response": 0, "error_response": 1, "incorrect_go_response": 2}
        events_mapping = {
            'correct_response': [],
            'error_response': [],
            'incorrect_go_response': [],
        }

        # find response events from experimental blocks
        regex_pattern = [r'RE\*image\*.*\*0\*.*', r'RE\*image\*.*\*-\*.*']
        event_dict = find_items_matching_regex(trigger_map_codes, regex_pattern)

        for event_id, event_code in event_dict.items():
            fields = event_id.split('*')
            if fields[3] == '-':
                category = 'correct_response' if fields[-1] == fields[-2] else 'incorrect_go_response'
                events_mapping[category].append(event_code)
            elif fields[3] == '0' and fields[-1] != '-':
                events_mapping['error_response'].append(event_code)

    elif case == 'STIM':
        new_event_dict = {"inhibited_stop": 0, "uninhibited_stop": 1}
        events_mapping = {
            'inhibited_stop': [],
            'uninhibited_stop': [],
        }
        # find all target stimuli events from experimental blocks
        regex_pattern = [r'ST.*']
        event_dict = find_items_matching_regex(trigger_map_codes, regex_pattern)

        for event_id, event_code in event_dict.items():
            fields = event_id.split('*')
            if fields[3] != '0':
                continue
            category = 'inhibited_stop' if fields[-1] == '-' else 'uninhibited_stop'
            events_mapping[category].append(event_code)

    else:
        # A Logger object is not callable; log through its method.
        logger_errors_info.info(
            f"Not known case: {case!r}. Possible cases: 'RE' for response-locked "
            f"and 'STIM' for stimulus-locked events extraction."
        )
        # TODO: raise an error instead of returning an empty context.
        event_dict = {}
        events_mapping = {}
        new_event_dict = {}

    return ParticipantTriggerMappingContext(event_dict=event_dict,
                                            events_mapping=events_mapping,
                                            new_event_dict=new_event_dict)

def create_epochs(
        raw,
        context: ParticipantTriggerMappingContext,
        tmin=-.1,
        tmax=.6,
        reject=None,
        reject_by_annotation=False,
):
    """Cut ``raw`` into preloaded epochs around the events selected by ``context``.

    Annotations listed in ``context.event_dict`` become events. The ids of each
    category in ``context.events_mapping`` are then replaced by that category's id
    from ``context.new_event_dict``, and only those new ids are epoched. EEG and EOG
    channels are kept, no baseline correction is applied, and missing categories only
    produce a warning.
    """
    selected_events, _ = mne.events_from_annotations(raw, event_id=context.event_dict)

    # Collapse all source ids of a category into the category's single id.
    for category, source_ids in context.events_mapping.items():
        selected_events = mne.merge_events(
            events=selected_events,
            ids=source_ids,
            new_id=context.new_event_dict[category],
            replace_events=True,
        )

    return mne.Epochs(
        raw=raw,
        events=selected_events,
        event_id=context.new_event_dict,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        reject_by_annotation=reject_by_annotation,
        preload=True,
        reject=reject,
        picks=['eeg', 'eog'],
        on_missing='warn',
    )



# def ocular_correction_ica(raw, raw_unfiltered, veog=[], heog=[], info=False, from_template='auto', context=None):
#     filtered_raw_ica = raw_unfiltered.copy().drop_channels(['EXG7', 'EXG8']).filter(l_freq=1.0, h_freq=None)
    
#     ica_epochs = create_epochs(
#         filtered_raw_ica,
#         context,
#     )

#     ica = mne.preprocessing.ICA(
#         n_components=15,
#         method='infomax',
#         max_iter="auto",
#         random_state=random_state
#     )
#     ica.fit(filtered_raw_ica)
    
#     ica.exclude = []
    
#     # find which ICs match the VEOG pattern
#     veog_indices_eogs, veog_scores = ica.find_bads_eog(
#         filtered_raw_ica,
#         ch_name=veog,
#         threshold=0.9,
#         measure='correlation'
#     )

#     # find which ICs match the HEOG pattern
#     heog_indices_eogs, heog_scores = ica.find_bads_eog(
#         filtered_raw_ica,
#         ch_name=heog,
#         threshold=0.6,
#         measure='correlation'
#     )
    
#     logger_preprocessing_info.info(f'EOG based excluded ICA components:\nVEOG: {veog_indices_eogs}\nHEOG: {heog_indices_eogs}')

    
#     heog_indices = []
#     veog_indices = []
    
#     if info:
#         logger_preprocessing_info.info('ICA components')
#         fig = ica.plot_components()
  
#     if (len(veog_indices_eogs + heog_indices_eogs) == 0 and from_template == 'auto') or from_template == True:
#         logger_preprocessing_info.info('Using templates...')
#         templates_ica = pd.read_pickle('public_data/eog_templates.pkl')

#         template_veog_component = templates_ica['VEOG'].to_numpy()[0].flatten()
#         template_heog_component = templates_ica['HEOG'].to_numpy()[0].flatten()

#         # todo: make try - except and if faile - do another ICA with 20 components. Default amount of components: 15
        
#         try:
#             mne.preprocessing.corrmap(
#                 [ica], 
#                 template=template_veog_component, 
#                 threshold=0.9, 
#                 label="veog blink", 
#                 plot=False
#             )
#             mne.preprocessing.corrmap(
#                 [ica], 
#                 template=template_heog_component, 
#                 threshold=0.8, 
#                 label="heog blink", 
#                 plot=False
#             )

#             veog_indices = ica.labels_['veog blink']
#             heog_indices = ica.labels_['heog blink']
            
#             if (len(veog_indices_eogs) == 1) and (len(veog_indices) != 1) and (veog_indices_eogs[0] in veog_indices):
#                 veog_indices = veog_indices_eogs
#                 logger_errors_info.info(f'Too much VEOG patterns found. Comparing to VEOG channels')
            
#             if (len(heog_indices_eogs) == 1) and (len(heog_indices) != 1) and (heog_indices_eogs[0] in heog_indices):
#                 heog_indices = heog_indices_eogs
#                 logger_errors_info.info(f'Too much HEOG patterns found. Comparing to HEOG channels')
            
#         except:
#             logger_errors_info.info(f'No patterns found. Trying to use results from EOG channels')
#             if (len(veog_indices_eogs) == 1) and (len(veog_indices) == 0):
#                 veog_indices = veog_indices_eogs
#             if (len(heog_indices_eogs) == 1) and (len(heog_indices) == 0):
#                 heog_indices = heog_indices_eogs

#     logger_preprocessing_info.info(f'Excluded ICA components:\nVEOG: {veog_indices}\nHEOG: {heog_indices}')
#     fig = ica.plot_components(veog_indices + heog_indices)
    
#     assert len((veog_indices + heog_indices)) == 2, f'Number of ICA components to exclude ({len(veog_indices + heog_indices)}) is different than 2' 
    
#     ica.exclude = veog_indices + heog_indices
    
#     reconstructed_raw = raw.copy()
#     ica.apply(reconstructed_raw)
    
#     del filtered_raw_ica

#     return reconstructed_raw

def _load_eog_templates(path='public_data/eog_templates.pkl'):
    """Return the flattened VEOG and HEOG ICA template maps stored in the pickle at ``path``."""
    # NOTE(review): unpickling can execute arbitrary code - only load trusted template files.
    templates_ica = pd.read_pickle(path)
    template_veog_component = templates_ica['VEOG'].to_numpy()[0].flatten()
    template_heog_component = templates_ica['HEOG'].to_numpy()[0].flatten()
    return template_veog_component, template_heog_component


def _match_template(ica, template, threshold, label, eog_indices, kind):
    """Label components of ``ica`` resembling ``template`` and reconcile them with EOG-channel hits.

    ``corrmap`` stores the matching components in ``ica.labels_[label]``. If several
    components match and the single EOG-channel hit is among them, only that hit is kept.
    If ``corrmap`` or the label lookup raises, a single EOG-channel hit (if any) is used.
    Otherwise an empty list is returned.
    """
    indices = []
    try:
        mne.preprocessing.corrmap(
            [ica],
            template=template,
            threshold=threshold,
            label=label,
            plot=False,
        )
        indices = ica.labels_[label]
        if len(eog_indices) == 1 and len(indices) != 1 and eog_indices[0] in indices:
            indices = eog_indices
            logger_errors_info.info(f'Too much {kind} patterns found. Comparing to {kind} channels')
    except Exception:
        logger_errors_info.info(f'No {kind} patterns found. Trying to use results from EOG channels')
        if len(eog_indices) == 1 and len(indices) == 0:
            indices = eog_indices
    return indices


def ica_epochs(filtered_raw_ica, participant_context, return_ica=False):
    """Fit ICA on epochs and label VEOG/HEOG components by template matching.

    Epochs come from ``create_epochs``. Epochs exceeding 200 µV peak-to-peak EEG
    amplitude are dropped before an infomax ICA (99% explained variance) is fitted.

    :param filtered_raw_ica: high-pass filtered Raw used for fitting.
    :param participant_context: ParticipantTriggerMappingContext selecting the events.
    :param return_ica: also return the fitted ICA. Component indices are only meaningful
        for the ICA they were computed on.
    :return: ``(veog_indices, heog_indices)``, or ``(veog_indices, heog_indices, ica)``
        when ``return_ica`` is True.
    """
    epochs_for_ica = create_epochs(
        filtered_raw_ica,
        participant_context,
    )
    epochs_for_ica.drop_bad(reject=dict(eeg=200e-6))

    ica = mne.preprocessing.ICA(
        n_components=.99,
        method='infomax',
        max_iter="auto",
        random_state=random_state,
    )
    ica.fit(epochs_for_ica)

    print('ICA components')
    ica.plot_components()

    template_veog_component, template_heog_component = _load_eog_templates()

    mne.preprocessing.corrmap(
        [ica],
        template=template_veog_component,
        threshold=0.9,
        label="veog blink",
        plot=False,
    )
    mne.preprocessing.corrmap(
        [ica],
        template=template_heog_component,
        label="heog blink",
        threshold=0.8,
        plot=False,
    )

    veog_indices = ica.labels_['veog blink']
    heog_indices = ica.labels_['heog blink']

    logger_preprocessing_info.info(f'Pattern based excluded epochs-based ICA components:\nVEOG: {veog_indices}\nHEOG: {heog_indices}')

    if return_ica:
        return veog_indices, heog_indices, ica
    return veog_indices, heog_indices


def ocular_correction_ica(raw, raw_unfiltered, veog=None, heog=None, info=False, from_template='auto', context=None):
    """Remove one VEOG and one HEOG component from a copy of ``raw`` with ICA.

    ICA (30 components, infomax) is fitted on ``raw_unfiltered`` without EXG7/EXG8,
    high-pass filtered at 1 Hz. Ocular components are found by correlation with the
    VEOG/HEOG channels and/or by template matching. If this does not give exactly two
    components, an epoch-based ICA (``ica_epochs``) is tried, and that ICA is the one
    applied.

    :param raw: Raw to clean; a cleaned copy is returned.
    :param raw_unfiltered: Raw used to fit ICA; must contain EXG7 and EXG8.
    :param veog: VEOG channel names for ``find_bads_eog`` (default: empty list).
    :param heog: HEOG channel names for ``find_bads_eog`` (default: empty list).
    :param info: if True, plot all ICA components.
    :param from_template: True -> always use templates. 'auto' -> use templates only
        when EOG-channel correlation finds nothing. Otherwise the EOG-channel results
        are used.
    :param context: ParticipantTriggerMappingContext for the epoch-based fallback.
    :return: ICA-cleaned copy of ``raw``.
    :raises AssertionError: if the number of components to exclude is not 2.
    """
    veog = [] if veog is None else veog
    heog = [] if heog is None else heog

    filtered_raw_ica = raw_unfiltered.copy().drop_channels(['EXG7', 'EXG8']).filter(l_freq=1.0, h_freq=None)

    ica = mne.preprocessing.ICA(
        n_components=30,
        method='infomax',
        max_iter="auto",
        random_state=random_state,
    )
    ica.fit(filtered_raw_ica)
    ica.exclude = []

    # find which ICs match the VEOG / HEOG channels
    veog_indices_eogs, _ = ica.find_bads_eog(
        filtered_raw_ica,
        ch_name=veog,
        threshold=0.9,
        measure='correlation',
    )
    heog_indices_eogs, _ = ica.find_bads_eog(
        filtered_raw_ica,
        ch_name=heog,
        threshold=0.6,
        measure='correlation',
    )

    logger_preprocessing_info.info(f'EOG based excluded ICA components:\nVEOG: {veog_indices_eogs}\nHEOG: {heog_indices_eogs}')

    if info:
        logger_preprocessing_info.info('ICA components')
        ica.plot_components()

    use_templates = (
        (len(veog_indices_eogs + heog_indices_eogs) == 0 and from_template == 'auto')
        or from_template == True  # noqa: E712 - keeps accepting truthy-equal values like 1
    )

    if use_templates:
        logger_preprocessing_info.info('Using templates...')
        template_veog_component, template_heog_component = _load_eog_templates()
        veog_indices = _match_template(ica, template_veog_component, 0.9, 'veog blink', veog_indices_eogs, 'VEOG')
        heog_indices = _match_template(ica, template_heog_component, 0.8, 'heog blink', heog_indices_eogs, 'HEOG')
    else:
        # Without templates the EOG-channel correlation results are the detections.
        veog_indices = list(veog_indices_eogs)
        heog_indices = list(heog_indices_eogs)

    logger_preprocessing_info.info(f'Excluded ICA components:\nVEOG: {veog_indices}\nHEOG: {heog_indices}')
    excluded = list(veog_indices) + list(heog_indices)
    if excluded:
        ica.plot_components(excluded)

    if len(excluded) != 2:
        logger_errors_info.info('Number of raw-based ICA components is different than 2. Trying epoch-based ICA...')
        # The fallback indices refer to the epoch-based ICA, so that ICA must be applied.
        veog_indices, heog_indices, ica = ica_epochs(filtered_raw_ica, context, return_ica=True)
        excluded = list(veog_indices) + list(heog_indices)

    assert len(excluded) == 2, f'Number of ICA components to exclude ({len(excluded)}) is different than 2'

    ica.exclude = excluded

    reconstructed_raw = raw.copy()
    ica.apply(reconstructed_raw)

    return reconstructed_raw

def get_k_nearest_neighbors(target_ch_name, epochs, k=6):
    """
    Finds k nearest neighbors of given channel according to the 3D channels positions from the Epoch INFO.
    Only EEG channels are considered, and the target channel itself is never returned.
    :param target_ch_name: String
        Name of the target channel (must be an EEG channel of ``epochs``).
    :param epochs: mne Epochs
        Epochs with info attribute that consists of channels positions.
    :param k: int
        Number of neighbors to return.
    :return:
        indices: ndarray of shape (n_neighbors)
            Indices of the nearest channels among the EEG channels of ``epochs``.
        neighbor_ch_names: ndarray of shape (n_neighbors)
            Names of the nearest channels.
    :raises ValueError: if ``target_ch_name`` is not an EEG channel of ``epochs``.
    """
    eeg_info = epochs.copy().pick('eeg').info
    ch_names = eeg_info['ch_names']
    if target_ch_name not in ch_names:
        raise ValueError(f'Channel {target_ch_name!r} is not an EEG channel of the given epochs.')
    target_idx = ch_names.index(target_ch_name)

    # The first three entries of a channel's 'loc' are its 3D position.
    ch_coordinates = np.array([ch['loc'][:3] for ch in eeg_info['chs']])

    # Query k + 1 neighbours because the target itself is returned at distance 0.
    neighbors_model = NearestNeighbors(n_neighbors=k + 1, algorithm='auto')
    neighbors_model.fit(ch_coordinates)
    distances, indices = neighbors_model.kneighbors(ch_coordinates[target_idx].reshape(1, -1))

    neighbor_indices = []
    neighbor_ch_names = []
    logger_preprocessing_info.debug(f"{k} Nearest Neighbors of {target_ch_name}:")
    for distance, index in zip(distances.flatten(), indices.flatten()):
        # Exclude the target by identity (not by rank) so ties at distance 0 cannot drop a real neighbour.
        if index == target_idx or len(neighbor_indices) == k:
            continue
        neighbor_indices.append(index)
        neighbor_ch_names.append(ch_names[index])
        logger_preprocessing_info.debug(
            f"Neighbor {len(neighbor_indices)}: Index {index}, Distance {distance:.2f}, "
            f"Coordinates {(ch_names[index], ch_coordinates[index])}"
        )

    return np.array(neighbor_indices, dtype=indices.dtype), np.array(neighbor_ch_names)

def find_bad_trails(epochs):
    """
    Channels that meet following conditions will be marked as bad for the trial:
        (1) Channels with a voltage difference of 100 μV through the duration of the epoch;
        (2) Channels that were flat (less than 1 μV peak-to-peak);
        (3) Channels with more than a 30 μV difference (of the trial mean) with each of the
            nearest six neighbors.
    :param epochs: mne Epochs
        Epochs to find bad channels per trial. Not modified.
    :return: drop_log: tuple of tuples
        ``drop_log`` of a copy of ``epochs`` after ``drop_bad`` with criteria (1) and (2).
        Each entry lists the offending channel names, extended with channels failing (3).
        It is indexed like ``epochs.drop_log``.
    """
    epochs_copy = epochs.copy()
    # (1) + (2): drop_bad records the offending channel names for every rejected trial.
    epochs_copy.drop_bad(reject=dict(eeg=100e-6), flat=dict(eeg=1e-6))
    drop_log = [tuple(entry) for entry in epochs_copy.drop_log]

    if len(epochs) == 0:
        return tuple(drop_log)

    ch_names = epochs.info['ch_names']
    # Channel positions are identical in every trial, so neighbours are computed once.
    neighbors_indices = {
        ch_name: get_k_nearest_neighbors(target_ch_name=ch_name, epochs=epochs[0], k=6)[0]
        for ch_name in ch_names
    }

    # Mean of every channel in every trial: (n_epochs, n_channels).
    channel_means = epochs.get_data(copy=True).mean(axis=-1)

    # (3) iterate over ALL trials of the original epochs (the copy may have lost some)
    # and map each trial to its own drop_log entry via epochs.selection.
    for idx in range(len(epochs)):
        log_idx = epochs.selection[idx]
        for ch_idx, ch_name in enumerate(ch_names):
            mean_channel_data = channel_means[idx, ch_idx]
            mean_neighbors_data = channel_means[idx, neighbors_indices[ch_name]]

            if (abs(mean_neighbors_data - mean_channel_data) > 30e-6).all():
                logger_preprocessing_info.info(f'Channel {ch_name} has more than a 30 μV difference with the nearest six neighbors at {idx} trail.\n Mean channel data: {mean_channel_data}\nMean neighbors data: {mean_neighbors_data}')
                if ch_name not in drop_log[log_idx]:
                    drop_log[log_idx] = drop_log[log_idx] + (ch_name,)

    return tuple(drop_log)

def calculate_percentage(tuple_of_tuples, element):
    """Return the fraction (0-1) of inner tuples that contain ``element``.

    An empty outer collection yields 0 instead of dividing by zero.
    """
    total = len(tuple_of_tuples)
    if total == 0:
        return 0
    return sum(element in group for group in tuple_of_tuples) / total

def _add_channel_to_every_trial(drop_log, ch_name):
    """Return ``drop_log`` with ``ch_name`` appended to every entry that lacks it."""
    return tuple(item if ch_name in item else item + (ch_name,) for item in drop_log)


def find_global_bad_channels(epochs, drop_log):
    '''
    Mark channels as bad for the whole recording:
    (1) Channels with an absolute correlation with each of the nearest six neighboring
        channels below .4 (all epochs concatenated in time);
    (2) Channels that were marked as bad for more than 20% of drop_log entries.
    A globally bad channel is added to every drop_log entry.
    :param epochs: mne Epochs (not modified).
    :param drop_log: tuple of per-trial tuples of bad channel names.
    :return: (updated drop_log, dict mapping channel name -> list of reasons)
    '''
    epochs_data = epochs.get_data(copy=True)
    # (n_epochs, n_channels, n_times) -> (n_channels, n_epochs * n_times)
    concatenated_epochs_data = np.concatenate(epochs_data, axis=1)
    ch_names = epochs.info['ch_names']

    global_bad_channels_drop_log = {}

    # (1) Channels with an absolute correlation with the nearest six neighboring channels that fell below .4
    for ch_idx, ch_name in enumerate(ch_names):
        ch_neighbors_indices, _ = get_k_nearest_neighbors(
            target_ch_name=ch_name,
            epochs=epochs,
            k=6,
        )
        stacked = np.vstack([concatenated_epochs_data[ch_neighbors_indices],
                             concatenated_epochs_data[ch_idx]])
        # Last row of the correlation matrix: target channel vs each neighbour.
        neighbor_corrs = np.corrcoef(stacked)[-1, :-1]

        if (abs(neighbor_corrs) < .4).all():
            logger_preprocessing_info.info(f'Channel {ch_name} has < 0.4 abs corr with six nearest neighbors. Set as globally BAD. Channel corrs with neighbors: {neighbor_corrs}')
            global_bad_channels_drop_log[ch_name] = ['LOW CORR NEIGH']
            drop_log = _add_channel_to_every_trial(drop_log, ch_name)

    # (2) Channels that were marked as bad for more than 20% of epochs
    for ch_name in ch_names:
        percentage = calculate_percentage(drop_log, ch_name)

        if percentage > 0.2:
            # percentage is a fraction, so it is formatted as a percent.
            logger_preprocessing_info.info(f'Channel {ch_name} is bad for {percentage:.1%} of epochs. Set as globally BAD.')
            global_bad_channels_drop_log.setdefault(ch_name, []).append('BAD FOR MORE THAN 20%')
            drop_log = _add_channel_to_every_trial(drop_log, ch_name)

    return drop_log, global_bad_channels_drop_log

def mark_bad_trials(epochs, drop_log, threshold=0.1):
    """
    Select trials in which more than ``threshold`` of the channels are marked bad.

    A trial is selected when its drop_log entry holds more than
    ``int(threshold * n_channels)`` items (every item counts, not only channel names).
    Each selected entry gets a 'TO DROP' tag.
    :param epochs: mne Epochs; only its length and ``info['ch_names']`` are read.
    :param drop_log: tuple with one tuple of bad-channel names per epoch.
    :param threshold: fraction of channels (default 0.1, i.e. 10%).
    :return: (list of trial indices to drop, updated drop_log)
    :raises AssertionError: if ``epochs`` and ``drop_log`` differ in length.
    """
    assert len(epochs) == len(drop_log), f'Length of epochs ({len(epochs)}) not equals length of drop_log ({len(drop_log)}). Cannot mark trials as BAD.'

    threshold_items = int(threshold * len(epochs.info['ch_names']))
    trials_to_drop_indices = [idx for idx, item in enumerate(drop_log) if len(item) > threshold_items]

    for idx in trials_to_drop_indices:
        logger_preprocessing_info.info(f'More than {threshold:.0%} of channels are bad for trial {idx}. Trial set as \'TO DROP\'')

    # update drop_log in a single pass
    to_drop = set(trials_to_drop_indices)
    drop_log = tuple(item + ('TO DROP',) if idx in to_drop else item for idx, item in enumerate(drop_log))

    return trials_to_drop_indices, drop_log


def reject_bad_trials(epochs, drop_log, trials_to_drop_indices=None):
    """Drop the given trials from a copy of ``epochs`` and mark them in ``drop_log``.

    :param epochs: mne Epochs (not modified).
    :param drop_log: tuple with one entry per epoch.
    :param trials_to_drop_indices: positions of the trials to drop (e.g. from
        ``mark_bad_trials``). None or empty means nothing is dropped.
    :return: (cleaned copy of epochs, drop_log whose dropped entries are replaced by ('REJECTED',))
    """
    clean_epochs = epochs.copy()
    to_drop = [] if trials_to_drop_indices is None else list(trials_to_drop_indices)
    if not to_drop:
        return clean_epochs, drop_log

    clean_epochs = clean_epochs.drop(
        indices=to_drop,
        reason='MORE THAN 10% CHANNELS MARKED AS BAD',
    )

    # update drop_log in a single pass
    dropped = set(to_drop)
    drop_log = tuple(('REJECTED',) if i in dropped else element for i, element in enumerate(drop_log))

    return clean_epochs, drop_log

def interpolate_bad_channels(epochs, drop_log, global_bad_channels_drop_log):
    """
    Interpolate globally bad channels with spherical splines and clear them from drop_log.
    :param epochs: mne Epochs (not modified; a copy is interpolated).
    :param drop_log: per-trial collections of bad channel names.
    :param global_bad_channels_drop_log: dict whose keys are the channels to interpolate.
    :return: (interpolated copy of epochs, drop_log without the interpolated channel names;
        returned unchanged when there is nothing to interpolate)
    """
    channels_to_interpolate = list(global_bad_channels_drop_log.keys())

    interpolated = epochs.copy()
    interpolated.info['bads'] = list(global_bad_channels_drop_log.keys())
    interpolated = interpolated.interpolate_bads(method='spline')

    cleaned_log = drop_log
    if channels_to_interpolate:
        cleaned_log = tuple(
            tuple(reason for reason in trial_log if reason not in channels_to_interpolate)
            for trial_log in drop_log
        )

    return interpolated, cleaned_log

def create_erps_waves(epochs, drop_log, new_response_event_dict, type = 'error_response', tmin=0, tmax=0.1, picks=None):
    """Collect single-trial waveforms of one event category for the picked channels.

    Trials whose drop_log entry contains 'TO DROP' or any picked channel are skipped.

    :param epochs: mne Epochs aligned one-to-one with ``drop_log``.
    :param drop_log: per-trial collections of bad channel names / tags.
    :param new_response_event_dict: category name -> event id.
    :param type: category to collect (key of ``new_response_event_dict``).
    :param tmin: start of the time window (s).
    :param tmax: end of the time window (s).
    :param picks: channel name(s) to extract; a single string is accepted. Defaults to ['FCz'].
    :return: ndarray of the kept trials' data (one item per trial), or None if the
        lengths of epochs, events and drop_log differ.
    """
    if picks is None:
        picks = ['FCz']
    # Treat a single channel name as one name, not as a sequence of characters.
    pick_names = [picks] if isinstance(picks, str) else list(picks)

    epochs_data = epochs.get_data(copy=True, picks=picks, tmin=tmin, tmax=tmax)
    events = epochs.events

    if len(epochs_data) != len(drop_log) or len(events) != len(drop_log):
        logger_preprocessing_info.info(f'Epochs length is not equal to drop_log length:\nepochs: {len(epochs_data)}\ndrop_log: {len(drop_log)}')
        return None

    eeg_data = [
        epochs_data[idx]
        for idx, item in enumerate(drop_log)
        if events[idx][-1] == new_response_event_dict[type]
        and not ('TO DROP' in item or any(name in item for name in pick_names))
    ]
    return np.array(eeg_data)

def create_erps(epochs, drop_log, new_response_event_dict, type = 'error_response', tmin=0, tmax=0.1, picks=None):
    """Compute the per-trial mean amplitude of one event category for the picked channels.

    For each trial of the requested category, the row is the time-average of each picked
    channel in [tmin, tmax]. If the trial's drop_log entry contains 'TO DROP' or any picked
    channel, the row is all None instead. Rows therefore always have one value per picked
    channel, also when several channels are picked.

    :param epochs: mne Epochs aligned one-to-one with ``drop_log``.
    :param drop_log: per-trial collections of bad channel names / tags.
    :param new_response_event_dict: category name -> event id.
    :param type: category to collect (key of ``new_response_event_dict``).
    :param tmin: start of the time window (s).
    :param tmax: end of the time window (s).
    :param picks: channel name(s) to extract; a single string is accepted. Defaults to ['FCz'].
    :return: ndarray of shape (n_trials_of_category, n_picks), or None if the lengths of
        epochs, events and drop_log differ.
    """
    if picks is None:
        picks = ['FCz']
    # Treat a single channel name as one name, not as a sequence of characters.
    pick_names = [picks] if isinstance(picks, str) else list(picks)

    epochs_data = epochs.get_data(copy=True, picks=picks, tmin=tmin, tmax=tmax)
    events = epochs.events

    if len(epochs_data) != len(drop_log) or len(events) != len(drop_log):
        logger_preprocessing_info.info(f'Epochs length is not equal to drop_log length:\nepochs: {len(epochs_data)}\ndrop_log: {len(drop_log)}')
        return None

    n_picks = epochs_data.shape[1]
    erps_data = []
    for idx, item in enumerate(drop_log):
        if events[idx][-1] != new_response_event_dict[type]:
            continue
        if 'TO DROP' in item or any(name in item for name in pick_names):
            # One None per picked channel keeps the rows rectangular for np.array.
            erps_data.append([None] * n_picks)
        else:
            erps_data.append(np.mean(epochs_data[idx], axis=-1))
    return np.array(erps_data)

In [None]:
def _compute_ocular_corrected_raw(input_fname, participant_id, context, trigger_fname):
    """Steps 0-5: load the BDF, relabel triggers, re-reference, filter and ICA-correct.

    Returns the ocular-corrected continuous recording, which still contains the
    EXG7/EXG8 reference channels (they are dropped at segmentation time).
    """
    # 0. read bdf
    raw = mne.io.read_raw_bdf(
        input_fname,
        eog=['EXG1', 'EXG2', 'EXG3', 'EXG4'],
        exclude=['EXG5', 'EXG6'],
        preload=True
    )

    try:
        raw = raw.set_montage('biosemi64')
    except ValueError as e:
        # Tolerate only EXG7/EXG8 lacking a montage position (they are used as
        # the reference below); anything else is just logged.
        if '[\'EXG7\', \'EXG8\']' in e.args[0]:
            raw = raw.set_montage('biosemi64', on_missing='ignore')
            logger_preprocessing_info.info('On missing')
        else:
            logger_preprocessing_info.info('Lacks important channels!')

    # 1. replace trigger names
    trigger_map = read_trigger_map(trigger_fname)
    raw_new_triggers = replace_trigger_names(raw, participant_id, trigger_map)

    # 2. re-reference: to mastoids
    raw_ref = raw_new_triggers.copy().set_eeg_reference(ref_channels=['EXG7', 'EXG8'])

    # 3. 4-th order Butterworth filters
    raw_filtered = raw_ref.copy().filter(
        l_freq=.1,
        h_freq=30.0,
        n_jobs=10,
        method='iir',
        iir_params=None,
    )

    # 4. Notch filter at 50 Hz (and harmonics below Nyquist)
    raw_filtered = raw_filtered.notch_filter(
        freqs=np.arange(50, (raw_filtered.info['sfreq'] / 2), 50),
        n_jobs=10,
    )

    # 5. ocular correction with ICA
    return ocular_correction_ica(
        raw_filtered,
        raw_ref,
        heog=['EXG3', 'EXG4'],
        veog=['EXG1', 'EXG2'],
        from_template=True,
        info=True,
        context=context,
    )


def pre_process_eeg(input_fname, participant_id, context, case='RE', trigger_fname=None):
    """Run the preprocessing pipeline for one participant's BDF recording.

    Steps 0-5 (load, trigger relabelling, re-reference, filtering, ICA ocular
    correction) are cached in ``{ica_cache_dir_path}{participant_id}_raw.fif.gz``.
    When that file can be read, they are skipped.

    Parameters
    ----------
    input_fname : str
        Path to the participant's ``.bdf`` recording.
    participant_id : str
        Used for the cache file name and passed to trigger relabelling.
    context
        Event mapping (as returned by ``create_events_mappings``) passed to
        ``ocular_correction_ica`` and ``create_epochs``.
    case : str
        Selects the baseline window: 'RE' -> (-0.4, -0.2) s, 'STIM' -> (-0.2, 0) s.
        Any other value is logged as an error and falls back to (-0.2, 0) s.
    trigger_fname : str | None
        Trigger-map file; only read when the ICA cache is not usable.

    Returns
    -------
    tuple
        ``(interpolated_epochs, drop_log)``. Bad channels are interpolated and the
        baseline is applied, but bad trials are NOT removed; they are only marked
        in ``drop_log``.
    """
    cache_fname = f'{ica_cache_dir_path}{participant_id}_raw.fif.gz'

    # Use the ICA cache when available. Only I/O/parse errors count as a cache
    # miss (the former bare `except:` also swallowed e.g. KeyboardInterrupt).
    raw_corrected_eogs = None
    try:
        raw_corrected_eogs = mne.io.read_raw_fif(cache_fname, preload=True)
    except (OSError, ValueError, EOFError) as cache_error:
        logger_preprocessing_info.info(
            f'ICA cache not used ({type(cache_error).__name__}: {cache_error}); recomputing.'
        )

    if raw_corrected_eogs is None:
        raw_corrected_eogs = _compute_ocular_corrected_raw(input_fname, participant_id, context, trigger_fname)
        raw_corrected_eogs.save(cache_fname, overwrite=True)

    # 6. segmentation -400 to 800 ms around the event
    raw_corrected_eogs_drop_ref = raw_corrected_eogs.copy().drop_channels(['EXG7', 'EXG8']).pick('eeg')

    epochs = create_epochs(
        raw_corrected_eogs_drop_ref,
        context=context,
        tmin=-.4,
        tmax=.8,
        reject=None,
        reject_by_annotation=False,
    )

    # 7. Trial-wise Bad Channels Identification
    drop_log = find_bad_trails(epochs)

    # 8. Global Bad Channel Identification - <.4 corr with 6 neigh. and channels marked as bad for more than 20% trials
    drop_log, global_bad_channels_drop_log = find_global_bad_channels(epochs, drop_log)

    # 9. calculate trials to remove
    trials_to_drop_indices, drop_log = mark_bad_trials(epochs, drop_log, threshold=0.1)

    # 10. Interpolate bad channels (and thus update drop log)
    interpolated_epochs, drop_log = interpolate_bad_channels(epochs, drop_log, global_bad_channels_drop_log)

    # 11. Flag participants with fewer than 6 clean trials. This only logs;
    # the returned epochs still contain every trial.
    clean_epochs, _ = reject_bad_trials(interpolated_epochs.copy(), drop_log, trials_to_drop_indices)
    if len(clean_epochs) < 6:
        logger_errors_info.info(f"Participant ID: {participant_id} has not enough clean trials")

    # 13. Baseline correction
    if case == 'RE':
        interpolated_epochs.apply_baseline(baseline=(-0.4, -0.2),)
    elif case == 'STIM':
        interpolated_epochs.apply_baseline(baseline=(-0.2, 0),)
    else:
        logger_errors_info.info('Not know case. Setting baseline from -0.2 to 0')
        interpolated_epochs.apply_baseline(baseline=(-0.2, 0),)

    return interpolated_epochs, drop_log

In [None]:
def save_epochs_with_drop_log_separately(epochs, drop_log, participant_id):
    """Write the drop log as JSON and the epochs as a FIF file.

    Files go to ``preprocessed_data_dir_path``:
    ``drop_log_{participant_id}.json`` and ``preprocessed_{participant_id}-epo.fif``
    (the FIF file is overwritten if it exists).

    Returns
    -------
    None
    """
    json_target = f'{preprocessed_data_dir_path}drop_log_{participant_id}.json'
    fif_target = f'{preprocessed_data_dir_path}preprocessed_{participant_id}-epo.fif'

    with open(json_target, 'w') as json_handle:
        json.dump(drop_log, json_handle)

    epochs.save(fif_target, overwrite=True)

    logger_preprocessing_info.info('Epochs saved to fif. Drop log saved to json.')
    return None

In [None]:
def save_epochs_with_drop_log(epochs, drop_log, participant_id):
    """Pickle the epochs and drop log together as a one-row DataFrame.

    Writes ``{preprocessed_data_dir_path}preprocessed_{participant_id}.pkl`` with
    columns 'epochs' and 'drop_log'.

    Returns
    -------
    None
    """
    target = f'{preprocessed_data_dir_path}preprocessed_{participant_id}.pkl'
    single_row = pd.DataFrame(dict(epochs=[epochs], drop_log=[drop_log]))
    single_row.to_pickle(target)

    logger_preprocessing_info.info('Epochs saved to pickle.')
    return None

In [None]:
def read_behavioral_file(participant_id):
    """Load a participant's behavioral CSV and add a 1-based 'trial number' column.

    Reads ``{behavioral_dir_path}beh_{participant_id}.csv``. Every row is
    numbered consecutively in file order, regardless of its block type.
    """
    csv_path = f'{behavioral_dir_path}beh_{participant_id}.csv'
    frame = pd.read_csv(csv_path)
    frame['trial number'] = list(range(1, len(frame) + 1))
    return frame

In [None]:
def _select_trial_categories(behavioral_data_df, paradigm_name):
    """Split a behavioral log into the named trial categories of a paradigm.

    Pure pandas (no logging, no I/O).

    Returns
    -------
    dict
        ``{category_key: (count_log_prefix, filtered_df)}`` in the order in which
        the per-category counts are logged.

    Raises
    ------
    ValueError
        For a paradigm other than 'GNG', 'FLA' or 'SST'.
    """
    df = behavioral_data_df

    if paradigm_name == 'GNG':
        experiment = df['block type'] == 'experiment'
        nogo = df['trial type'] != 'go'
        return {
            'uninhibited_nogo': ('Number of uninhibited NOGO trials',
                                 df[experiment & nogo & (df['reaction'] == False)]),  # noqa: E712 (element-wise)
            'inhibited_nogo': ('Number of inhibited NOGO trials',
                               df[experiment & nogo & (df['reaction'] == True)]),  # noqa: E712 (element-wise)
            'correct_go': ('Number correct GO trials',
                           df[experiment & (df['trial type'] == 'go') & (df['response'] == 'num_separator')]),
        }

    if paradigm_name == 'FLA':
        experiment = df['block_type'] == 'experiment'
        key_pressed = (df['response'] == 'l') | (df['response'] == 'r')

        def trials(trial_type, reaction):
            # Experiment trials of one congruency/accuracy with an 'l' or 'r' response.
            return df[experiment & (df['trial_type'] == trial_type) & (df['reaction'] == reaction) & key_pressed]

        return {
            'incorrect_incongruent': ('Number of incorrect incongruent trials', trials('incongruent', 'incorrect')),
            'correct_incongruent': ('Number of correct incongruent trials', trials('incongruent', 'correct')),
            'incorrect_congruent': ('Number incorrect congruent trials', trials('congruent', 'incorrect')),
            'correct_congruent': ('Number correct congruent trials', trials('congruent', 'correct')),
        }

    if paradigm_name == 'SST':
        # The first 30 rows are skipped. Masks are built on the slice itself so
        # pandas does not have to reindex full-length boolean Series (which warns).
        task = df.iloc[30:]
        stop = task['STOP_TYPE'] == 0
        return {
            'inhibited_stop': ('Number of correctly inhibited STOP trials',
                               task[stop & task['RE_time'].isna()]),
            'uninhibited_stop': ('Number of incorrectly uninhibited STOP trials',
                                 task[stop & task['RE_time'].notna()]),
            'correct_go': ('Number correct GO trials',
                           task[task['STOP_TYPE'].isna() & (task['RE_key'] == task['RE_true'])]),
            'incorrect_go': ('Number incorrect GO trials',
                             task[(task['STOP_TYPE'] != 0) &
                                  (task['STOP_TYPE'] != 1) &
                                  (task['RE_key'] != task['RE_true']) &
                                  task['RE_key'].notna()]),
        }

    raise ValueError(f"Unknown paradigm {paradigm_name!r}; expected 'GNG', 'FLA' or 'SST'.")


def _case_category_keys(paradigm_name, case):
    """Return the ordered category keys that were epoched for (paradigm, case), or None if unsupported."""
    flanker_all = ('incorrect_incongruent', 'correct_incongruent', 'incorrect_congruent', 'correct_congruent')
    return {
        ('GNG', 'RE'): ('uninhibited_nogo', 'correct_go'),
        ('GNG', 'STIM'): ('uninhibited_nogo', 'correct_go', 'inhibited_nogo'),
        # NOTE(review): FLA uses the same categories for STIM as for RE.
        ('FLA', 'RE'): flanker_all,
        ('FLA', 'STIM'): flanker_all,
        ('SST', 'RE'): ('uninhibited_stop', 'incorrect_go', 'correct_go'),
        ('SST', 'STIM'): ('inhibited_stop', 'uninhibited_stop'),
    }.get((paradigm_name, case))


def _attach_epochs(epochs, drop_log, behavioral_df):
    """Pair each epoch with its drop-log entry and align it row-wise with behavioral_df.

    Assumes the i-th epoch corresponds to the i-th row of ``behavioral_df``
    (sorted by trial number); TODO confirm against create_epochs.

    Returns
    -------
    tuple
        ``(results_df, epochs_df)``, where ``results_df`` holds the behavioral
        columns plus 'epoch', 'event' and 'drop_log'.
    """
    logger_preprocessing_info.info(f'Len drop log: {len(drop_log)}')
    logger_preprocessing_info.info(f'Len behavioral df: {len(behavioral_df)}')
    assert len(behavioral_df) == len(drop_log), f'Number of events read from behavioral file ({len(behavioral_df)}) not equals number of events from drop_log ({len(drop_log)})'

    # Collect rows first and build the frame once (repeated concat was O(n^2)).
    rows = []
    for idx, _ in enumerate(epochs):
        epoch = epochs[idx]
        epoch_type = list(epoch.event_id.keys())
        assert len(epoch_type) == 1, \
            f'Single trial is not single. Length of epoch: {len(epoch_type)}. Error during trial-wise saving.'
        rows.append({'epoch': epoch, 'event': epoch_type[0], 'drop_log': drop_log[idx]})

    epochs_df = pd.DataFrame(rows, columns=['epoch', 'event', 'drop_log'])

    # Set the indexes of epochs to match reactions
    epochs_df.set_index(behavioral_df.index, inplace=True)
    results_df = pd.concat([behavioral_df, epochs_df], axis=1)
    return results_df, epochs_df


def save_epochs_with_behavioral_data_long(epochs, drop_log, participant_id, case='RE'):
    """Align preprocessed epochs with behavioral trials and save them in long format.

    Uses the module-level ``paradigm`` ('GNG', 'FLA' or 'SST') to categorise the
    behavioral rows and ``case`` to choose which categories were epoched. The
    selected rows are sorted by 'trial number' and paired row-by-row with
    ``epochs`` / ``drop_log``. The result is pickled to
    ``{preprocessed_data_dir_path}preprocessed-beh_{participant_id}.pkl``.

    Returns
    -------
    pandas.DataFrame
        Behavioral columns plus 'epoch', 'event' and 'drop_log'. When ``case`` is
        not supported for the paradigm, 'Not implemented' is logged and an empty
        DataFrame is returned without writing a file.

    Raises
    ------
    ValueError
        If ``paradigm`` is not 'GNG', 'FLA' or 'SST' (previously an
        UnboundLocalError at the final return).
    AssertionError
        If behavioral rows and drop-log entries disagree in number, or an epoch
        carries more than one event type.
    """
    if paradigm not in ('GNG', 'FLA', 'SST'):
        raise ValueError(f"Unknown paradigm {paradigm!r}; expected 'GNG', 'FLA' or 'SST'.")

    # read behavioral file
    behavioral_data_df = read_behavioral_file(participant_id)
    categories = _select_trial_categories(behavioral_data_df, paradigm)
    for count_prefix, subset in categories.values():
        logger_preprocessing_info.info(f'{count_prefix}: {len(subset)}')

    selected_keys = _case_category_keys(paradigm, case)
    if selected_keys is None:
        logger_preprocessing_info.info('Not implemented')
        return pd.DataFrame()

    behavioral_df = pd.concat([categories[key][1] for key in selected_keys]).sort_values(by='trial number')
    results_df, epochs_df = _attach_epochs(epochs, drop_log, behavioral_df)

    assert len(results_df) == len(behavioral_df) == len(epochs_df), f'Length of trial-wise dataframe ({len(results_df)}) not equals number of events from behavioral file ({len(behavioral_df)}) and number of epochs ({len(epochs_df)})'

    results_df.to_pickle(f'{preprocessed_data_dir_path}preprocessed-beh_{participant_id}.pkl')
    logger_preprocessing_info.info('Epochs and behavioral data in long format saved to pickle.')

    return results_df

## Preprocessing

Set globals

In [None]:
# Paradigm to process: 'GNG' | 'SST' | 'FLA' (Flanker)
paradigm = 'SST'
# Event-locking case: 'RE' | 'STIM' | 'FBCK'
# (save_epochs_with_behavioral_data_long only handles 'RE' and 'STIM')
case = 'STIM'
# TODO: consider moving global variables such as paradigm and case into a data class.

Set paths based on the global values

In [None]:
trigger_dir_path = f'data/{paradigm}/raw/triggers/'
bdf_dir_path = f'data/{paradigm}/raw/bdfs/'
behavioral_dir_path = f'data/{paradigm}/behavioral/'
preprocessed_data_dir_path = f'data/{paradigm}/preprocessed/{case}/'
logger_dir_path = f'data/{paradigm}/'
ica_cache_dir_path = f'data/{paradigm}/ica_cache/'

# The pipeline writes pickles and the ICA cache into these folders. Create them
# up front so the first save does not fail with FileNotFoundError for every
# participant. This also creates data/{paradigm}/, where the log files go.
for _output_dir in (preprocessed_data_dir_path, ica_cache_dir_path):
    os.makedirs(_output_dir, exist_ok=True)

Set output files for loggers

In [None]:
######## RESET ##############################################
# Re-running this cell (e.g. after changing `paradigm` or `case`) would otherwise
# stack another FileHandler on each logger. Every message would be duplicated,
# and logs would still go to the previous run's files. Close old file handlers first.
for _logger in (logger_preprocessing_info, logger_errors_info):
    for _handler in list(_logger.handlers):
        if isinstance(_handler, logging.FileHandler):
            _logger.removeHandler(_handler)
            _handler.close()

######## PREPROCESSING ##############################################
# Create a file handler for preprocessing and set the level to INFO
file_handler_preprocessing = logging.FileHandler(f'data/{paradigm}/{case}_preprocessing.txt')
file_handler_preprocessing.setLevel(logging.INFO)

# Create a formatter and add it to the file handler for preprocessing
formatter_preprocessing = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_preprocessing.setFormatter(formatter_preprocessing)

# Attach the file handler to the preprocessing logger
logger_preprocessing_info.addHandler(file_handler_preprocessing)

######## ERRORS ##############################################
# Create a file handler for errors and set the level to INFO
file_handler_errors = logging.FileHandler(f'data/{paradigm}/{case}_errors.txt')
file_handler_errors.setLevel(logging.INFO)

# Create a formatter and add it to the file handler for errors
formatter_errors = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_errors.setFormatter(formatter_errors)

# Attach the file handler to the errors logger
logger_errors_info.addHandler(file_handler_errors)

##### MNE ###################################################
# Redirect MNE logs to a file (overwrite=None: append, with a warning if it exists)
logger_f_name = f'data/{paradigm}/{case}_MNE-logs.txt'
set_log_file(fname=logger_f_name, output_format="%(asctime)s - %(message)s", overwrite=None)

Read participant IDs

In [None]:
# Participant IDs are the .bdf file names without extension. Other files in the
# folder (e.g. .DS_Store) are ignored. Sorting keeps the processing order reproducible.
id_list = sorted(
    os.path.splitext(fname)[0]
    for fname in os.listdir(bdf_dir_path)
    if fname.endswith('.bdf')
)

In [None]:
# Preprocess every participant and save trial-wise epochs with behavioral data.
# A failing participant is logged to the errors log and the loop continues.
try:
    for participant_id in id_list:
        print(f'{participant_id}\n')
        bdf_fname = f'{bdf_dir_path}{participant_id}.bdf'
        trigger_fname = f'{trigger_dir_path}triggerMap_{participant_id}.txt'

        logger_preprocessing_info.info(f'#### PARTICIPANT ID: {participant_id} #########')
        logger_errors_info.info(f'#### PARTICIPANT ID: {participant_id} #########')

        try:
            trigger_map = read_trigger_map(trigger_fname)
            participant_context = create_events_mappings(trigger_map, case=case)
            logger_preprocessing_info.info(f'Context: {participant_context}')

            epochs_preprocessed, drop_log = pre_process_eeg(
                input_fname=bdf_fname,
                participant_id=participant_id,
                context=participant_context,
                trigger_fname=trigger_fname,
                case=case,
            )

            _ = save_epochs_with_behavioral_data_long(
                epochs_preprocessed,
                drop_log,
                participant_id,
                case=case,
            )
        except Exception as e:
            # Top-level boundary: keep the exception type and traceback. str(e)
            # alone is empty for a bare AssertionError.
            logger_errors_info.exception(f"{type(e).__name__}: {e}")

        logger_preprocessing_info.info('\n')
        logger_errors_info.info('\n')

    print('##########\n DONE\n')
finally:
    # Restore MNE logging to std out, even if the run is interrupted.
    set_log_file(fname=None)

In [None]:
# Scratch re-run of the SST trial categorisation for the last participant
# processed by the loop above (uses the global `participant_id`).
behavioral_data_df = read_behavioral_file(participant_id)

# Skip the first 30 rows, as save_epochs_with_behavioral_data_long does. Masks
# are built on the slice itself so pandas does not reindex full-length boolean
# Series (which emits "Boolean Series key will be reindexed" warnings).
task_df = behavioral_data_df.iloc[30:]

beh_data_inhibited_stop_df = task_df[
    (task_df['STOP_TYPE'] == 0) &
    (task_df['RE_time'].isna())
    ]
logger_preprocessing_info.info(f'Number of correctly inhibited STOP trials: {len(beh_data_inhibited_stop_df)}')

beh_data_uninhibited_stop_df = task_df[
    (task_df['STOP_TYPE'] == 0) &
    (task_df['RE_time'].notna())
    ]
logger_preprocessing_info.info(f'Number of incorrectly uninhibited STOP trials: {len(beh_data_uninhibited_stop_df)}')

beh_data_correct_go_responses_df = task_df[
    (task_df['STOP_TYPE'].isna()) &
    (task_df['RE_key'] == task_df['RE_true'])
    ]
logger_preprocessing_info.info(f'Number correct GO trials: {len(beh_data_correct_go_responses_df)}')

beh_data_incorrect_go_responses_df = task_df[
    (task_df['STOP_TYPE'] != 0) &
    (task_df['STOP_TYPE'] != 1) &
    (task_df['RE_key'] != task_df['RE_true']) &
    (task_df['RE_key'].notna())
    ]
logger_preprocessing_info.info(f'Number incorrect GO trials: {len(beh_data_incorrect_go_responses_df)}')

In [None]:
# Quick check: number of incorrect GO trials selected in the cell above.
len(beh_data_incorrect_go_responses_df)

In [None]:
# Reuse the last participant's preprocessed epochs from the preprocessing loop.
epochs=epochs_preprocessed

In [None]:
# Scratch copy of the SST 'RE' branch of save_epochs_with_behavioral_data_long,
# run on notebook globals (`epochs`, `drop_log`, `case` and the beh_data_* frames
# from the cells above). It only does anything when the global `case` is 'RE'.
results_df = pd.DataFrame()
epochs_df = pd.DataFrame()
behavioral_df = pd.DataFrame()

if case == 'RE':
    behavioral_df = pd.concat([beh_data_uninhibited_stop_df, beh_data_incorrect_go_responses_df, beh_data_correct_go_responses_df]).sort_values(by='trial number')

    logger_preprocessing_info.info(f'Len drop log: {len(drop_log)}')
    logger_preprocessing_info.info(f'Len behavioral df: {len(behavioral_df)}')
    # Length assert disabled here; if the lengths differ, set_index below raises instead.
    # assert len(behavioral_df) == len(drop_log), f'Number of events read from behavioral file ({len(behavioral_df)}) not equals number of events from drop_log ({len(drop_log)})'

    for idx, _ in enumerate(epochs):
        epoch = epochs[idx]
        epoch_type = list(epoch.event_id.keys())
        assert len(epoch_type) == 1, \
            f'Single trial is not single. Length of epoch: {len(epoch_type)}. Error during trial-wise saving.'
        drop_log_item = drop_log[idx]

        this_df = pd.DataFrame({
            'epoch': [epoch],
            'event': epoch_type,
            'drop_log': [drop_log_item],
        })

        epochs_df = pd.concat([epochs_df, this_df], ignore_index=True)

    # Set the indexes of epochs to match reactions
    indexes = behavioral_df.index
    epochs_df.set_index(indexes, inplace=True)
    results_df = pd.concat([behavioral_df, epochs_df], axis=1)

In [None]:
# Display the merged behavioral + epochs frame built above.
results_df

In [None]:
# Display the selected behavioral rows (sorted by trial number).
behavioral_df

---
## Manual checking

In [None]:
def _insert_missing_trigger(events, trigger_map, delta_time=3, code=65281):
    """Insert one synthetic event before the first event that disagrees with trigger_map.

    An event disagrees when the last digit of its code (column 2) differs from
    ``trigger_map[idx][0]`` at the same position. The inserted row is
    ``[previous_event_sample + delta_time, 0, code]``. Returns ``events``
    unchanged when no compared position disagrees.
    """
    for idx, (event, (trigger_id, _)) in enumerate(zip(events, trigger_map)):
        if str(event[2])[-1] != trigger_id:
            # NOTE(review): for idx == 0 this takes the sample of events[-1]
            # (the last event), exactly as the original inline code did.
            synthetic = [[events[idx - 1][0] + delta_time, 0, code]]
            return np.concatenate([events[:idx, :], synthetic, events[idx:, :]])
    return events


def _drop_first_event_with_code(events, code):
    """Return ``events`` without the first row whose code (column 2) equals ``code``."""
    matches = np.flatnonzero(events[:, 2] == code)
    if matches.size == 0:
        return events
    return np.delete(events, matches[0], axis=0)


def replace_trigger_names2(raw, participant_id, trigger_map, new_response_event_dict=None, replace=False, search='RE'):
    """Manual-checking variant of trigger relabelling, with per-participant repairs.

    Reads events from the 'Status' channel and repairs known recording defects
    (missing triggers are inserted, ghost events are deleted). Each event code is
    then replaced with the code that ``create_triggers_dict`` assigns to
    ``trigger_map[idx][1]``. The result is stored as annotations on a copy of ``raw``.
    Mismatches between event codes and ``trigger_map`` are logged to the errors
    logger, and every pair is printed for inspection.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording with a 'Status' stim channel.
    participant_id : str
        Selects the participant-specific repairs below.
    trigger_map : list of tuple
        ``(expected last digit of the event code, new trigger name)`` per event,
        as produced by ``read_trigger_map``.
    new_response_event_dict, replace, search
        Unused; kept for signature compatibility.

    Returns
    -------
    mne.io.Raw
        Copy of ``raw`` whose annotations are built from the relabelled events.
    """
    # Replace event IDs in the Raw object
    events = mne.find_events(raw, stim_channel='Status')
    new_events_list = events.copy()

    # Corrupted recordings that lost one trigger (too short reaction): insert one.
    single_missing_trigger = {
        'GNG': {'B-GNG-199', 'B-GNG-208'},
        'SST': {'SST-165', 'SST-211', 'SST-122', 'SST-088', 'SST-045',
                'SST-012', 'SST-083', 'SST-136', 'SST-125'},
    }
    if participant_id in single_missing_trigger.get(paradigm, set()):
        new_events_list = _insert_missing_trigger(new_events_list, trigger_map)

    # SST-130 lost two triggers: repeat the repair on the already-repaired list.
    if paradigm == 'SST' and participant_id == 'SST-130':
        for _ in range(2):
            new_events_list = _insert_missing_trigger(new_events_list, trigger_map)

    # Ghost events to delete (first occurrence only).
    ghost_event_codes = {'SST-181': 65312, 'SST-075': 130816}
    if paradigm == 'SST' and participant_id in ghost_event_codes:
        new_events_list = _drop_first_event_with_code(new_events_list, ghost_event_codes[participant_id])

    logger_preprocessing_info.info(f'EVENTS: {new_events_list}')

    # assert len(new_events_list) == len(trigger_map), \
    #         f'The length of trigger map ({len(trigger_map)}) not equals length of events in eeg recording ({len(new_events_list)})'

    trigger_map_codes, mapping = create_triggers_dict(trigger_map)

    for idx, event in enumerate(new_events_list):
        event_id = str(event[2])[-1]
        trigger_id = trigger_map[idx][0]
        trigger_new_code = trigger_map[idx][1]

        print(f'{event[2]} {trigger_id}')

        if event_id != trigger_id:
            logger_errors_info.info(f'An event {idx} has different number than in provided file. {trigger_id} expected, {str(event[2])} found. Triggers may need to be checked.')

        # Relabel in place (new_events_list is never the array returned by find_events).
        new_events_list[idx][2] = trigger_map_codes[trigger_new_code]

    annot_from_events = mne.annotations_from_events(
        events=new_events_list,
        event_desc=mapping,
        sfreq=raw.info["sfreq"],
        orig_time=raw.info["meas_date"],
    )
    raw_copy = raw.copy()
    raw_copy.set_annotations(annot_from_events)

    return raw_copy

In [None]:
# SST-165
# SST-211
# SST-122
# SST-088
# SST-045
# SST-012
# SST-024 (The length of trigger map (816) not equals length of events in eeg recording (776))
# SST-083
# SST-136
# SST-130 (The length of trigger map (848) not equals length of events in eeg recording (846))
# SST-125

In [None]:
# Participant to inspect manually. Note: this overrides the global `paradigm`
# and `case` set in the "Set globals" cell for every cell run afterwards.
id_ = 'SST-130'
paradigm= 'SST'
case = 'RE'

In [None]:
# Re-derive the input paths for the manual check. The output paths are commented
# out, so preprocessed_data_dir_path / ica_cache_dir_path keep whatever values
# the earlier "Set paths" cell assigned.
trigger_dir_path = f'data/{paradigm}/raw/triggers/'
bdf_dir_path = f'data/{paradigm}/raw/bdfs/'
behavioral_dir_path = f'data/{paradigm}/behavioral/'
# preprocessed_data_dir_path = f'data/{paradigm}/preprocessed/{case}/'
logger_dir_path = f'data/{paradigm}/'
# ica_cache_dir_path = f'data/{paradigm}/ica_cache/'

In [None]:
# Mirror both loggers to the notebook output for manual checking.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)  # handler passes everything; the loggers themselves are at INFO

for _logger in (logger_errors_info, logger_preprocessing_info):
    # Re-running this cell must not attach a second console handler, which would
    # print every message twice. FileHandler subclasses StreamHandler, so compare
    # the exact type to leave file handlers out of the check.
    if not any(type(_existing) is logging.StreamHandler for _existing in _logger.handlers):
        _logger.addHandler(console_handler)

In [None]:
bdf_fname = f'{bdf_dir_path}{id_}.bdf'
trigger_fname = f'{trigger_dir_path}triggerMap_{id_}.txt'
logger_f_name = f'{logger_dir_path}{case}_logs.txt'

# 0. read bdf
raw = mne.io.read_raw_bdf(
    bdf_fname,
    eog=['EXG1', 'EXG2', 'EXG3', 'EXG4'],
    exclude=['EXG5', 'EXG6'],
    preload=True
)

try:
    raw = raw.set_montage('biosemi64')
except ValueError as e:
    # Tolerate only EXG7/EXG8 lacking a montage position (used as reference below).
    if '[\'EXG7\', \'EXG8\']' in e.args[0]:
        raw = raw.set_montage('biosemi64', on_missing='ignore')
        logger_preprocessing_info.info('On missing')
    else:
        logger_preprocessing_info.info('Lacks important channels!')

# 1. replace trigger names
trigger_map = read_trigger_map(trigger_fname)
raw_new_triggers = replace_trigger_names2(raw, id_, trigger_map)
participant_context = create_events_mappings(trigger_map, case=case)

# 2. re-reference: to mastoids
# Kept active because the ICA cell below reads `raw_ref`, which is otherwise
# undefined in this manual flow.
raw_ref = raw_new_triggers.copy().set_eeg_reference(ref_channels=['EXG7', 'EXG8'])

# # 3. 4-th order Butterworth filters
# raw_filtered = raw_ref.copy().filter(
#     l_freq=.1,
#     h_freq=30.0,
#     n_jobs=10,
#     method='iir',
#     iir_params=None,
# )

# # 4. Notch filter at 50 Hz
# raw_filtered = raw_filtered.notch_filter(
#     freqs=np.arange(50, (raw_filtered.info['sfreq'] / 2), 50),
#     n_jobs=10,
# )

In [None]:
# Manual ICA inspection for the participant loaded above (needs `raw_ref` and
# `participant_context` from the previous cell).
filtered_raw_ica = raw_ref.copy().drop_channels(['EXG7', 'EXG8']).filter(l_freq=1.0, h_freq=None)
heog=['EXG3', 'EXG4']
veog=['EXG1', 'EXG2']

# Not used further in this cell; ICA below is fitted on the continuous data.
ica_epochs = create_epochs(
    filtered_raw_ica,
    participant_context,
)

ica = mne.preprocessing.ICA(
    n_components=15,
    method='infomax',
    max_iter="auto",
    random_state=random_state
)
# Fit on the continuous, 1 Hz high-passed data.
ica.fit(filtered_raw_ica)

ica.exclude = []

# find which ICs match the VEOG pattern
veog_indices_eogs, veog_scores = ica.find_bads_eog(
    filtered_raw_ica,
    ch_name=veog,
    threshold=0.9,
    measure='correlation'
)

# find which ICs match the HEOG pattern
heog_indices_eogs, heog_scores = ica.find_bads_eog(
    filtered_raw_ica,
    ch_name=heog,
    threshold=0.6,
    measure='correlation'
)

print(f'EOG based excluded ICA components:\nVEOG: {veog_indices_eogs}\nHEOG: {heog_indices_eogs}')

print('ICA components')
fig = ica.plot_components()

templates_ica = pd.read_pickle('public_data/eog_templates.pkl')

template_veog_component = templates_ica['VEOG'].to_numpy()[0].flatten()
template_heog_component = templates_ica['HEOG'].to_numpy()[0].flatten()

mne.preprocessing.corrmap(
    [ica],
    template=template_veog_component,
    threshold=0.9,
    label="veog blink",
    plot=False
)

mne.preprocessing.corrmap(
    [ica],
    template=template_heog_component,
    label="heog blink",
    threshold=0.8,
    plot=False
)

# corrmap records a label in `ica.labels_` only when a component matches
# (NOTE(review): confirm for the installed MNE version). Default to "none found"
# instead of raising KeyError.
veog_indices = ica.labels_.get('veog blink', [])
heog_indices = ica.labels_.get('heog blink', [])

print(veog_indices)
print(heog_indices)


# if (len(veog_indices_eogs) == 1) and (len(veog_indices) != 1) and (veog_indices_eogs[0] in veog_indices):
#     veog_indices = veog_indices_eogs
# if (len(heog_indices_eogs) == 1) and (len(heog_indices) != 1) and (heog_indices_eogs[0] in heog_indices):
#     heog_indices = heog_indices_eogs
# logger_errors_info.info(f'Too much patterns found. Comparing to EOG channels')


# logger_preprocessing_info.info(f'Excluded ICA components:\nVEOG: {veog_indices}\nHEOG: {heog_indices}')
# fig = ica.plot_components(veog_indices + heog_indices)

# assert len((veog_indices + heog_indices)) == 2, f'Number of ICA components to exclude ({len(veog_indices + heog_indices)}) is different than 2' 

# ica.exclude = veog_indices + heog_indices

# reconstructed_raw = raw.copy()
# ica.apply(reconstructed_raw)

# del filtered_raw_ica


In [None]:
# raw_corrected_eogs.save(f'{ica_cache_dir_path}{id}_raw.fif.gz', overwrite=True)

In [None]:
# 6. segmentation -400 to 800 ms around the response
raw_corrected_eogs_drop_ref = raw_corrected_eogs.copy().drop_channels(['EXG7', 'EXG8']).pick('eeg')

epochs = create_epochs(
    raw_corrected_eogs_drop_ref,
    # event mapping created in step 1 by create_events_mappings
    context=participant_context,
    tmin=-.4,
    tmax=.8,
    reject=None,
    reject_by_annotation=False,
)

# 7. Trial-wise Bad Channels Identification
drop_log = find_bad_trails(epochs)

# 8. Global Bad Channel Identification - <.4 corr with 6 neigh. and channels marked as bad for more than 20% trials
drop_log, global_bad_channels_drop_log = find_global_bad_channels(epochs, drop_log)

# 9. calculate trials to remove
trials_to_drop_indices, drop_log = mark_bad_trials(epochs, drop_log, threshold=0.1)

# 10. Interpolate bad channels (and thus update drop log)
interpolated_epochs, drop_log = interpolate_bad_channels(epochs, drop_log, global_bad_channels_drop_log)

# 11. Flag participants that have fewer than 6 clean trials
clean_epochs, _ = reject_bad_trials(interpolated_epochs.copy(), drop_log, trials_to_drop_indices)
if len(clean_epochs) < 6:
    # `id_` is this participant's ID (used in step 1); a bare `id` would be the
    # builtin function or a leftover value from the test cells.
    logger_errors_info.info(f"Participant ID: {id_} has not enough clean trials")

# 13. Baseline correction; the window depends on `case`
if case == 'RE':
    interpolated_epochs.apply_baseline(baseline=(-0.4, -0.2),)
elif case == 'STIM':
    interpolated_epochs.apply_baseline(baseline=(-0.2, 0),)
else:
    logger_errors_info.info('Not know case. Setting baseline from -0.2 to 0')
    interpolated_epochs.apply_baseline(baseline=(-0.2, 0),)

    

---
## Tests

In [None]:
# Participant inspected by the test cells below. NOTE: this shadows the
# builtin `id`; later cells read it under this name.
id = "B-GNG-102"

In [None]:
# Load the stored preprocessing result for participant `id`; the first row
# holds the epochs object and its drop log.
epochs_df = pd.read_pickle(f'{preprocessed_data_dir_path}preprocessed_{id}.pkl')
epochs = epochs_df['epochs'].iloc[0]
drop_log = epochs_df['drop_log'].iloc[0]

In [None]:
# Average FCz waveforms (-0.1 to 0.6 s) for uninhibited vs inhibited responses.
error_wave, correct_wave = (
    create_erps_waves(
        epochs,
        drop_log,
        type=wave_type,
        tmin=-0.1,
        tmax=0.6,
        picks=['FCz'],
        new_response_event_dict={"inhibited_response": 0, "uninhibited_response": 1},
    )
    for wave_type in ('uninhibited_response', 'inhibited_response')
)

# Plot the mean over axis 0 of each wave on a shared time axis.
for wave in (error_wave, correct_wave):
    mean_wave = np.mean(wave, axis=0).flatten()
    plt.plot(np.linspace(-0.1, 0.6, len(mean_wave)), mean_wave)

plt.show()

In [None]:
error_wave.shape  # dimensions of the wave array (the plots average over axis 0)

In [None]:
# ERPs scoring: single-trial ERN (FCz, 0-0.1 s) and Pe (Pz, 0.2-0.4 s) values
# computed from the preprocessed `epochs` loaded from the pickle above.
ern_single_trials = create_erps(epochs, drop_log, type='error_response', tmin=0, tmax=0.1, picks=['FCz'])
pe_single_trials = create_erps(epochs, drop_log, type='error_response', tmin=0.2, tmax=0.4, picks=['Pz'])

---
## For testing

In [None]:
input_fname = 'data/GNG/raw/bdfs/B-GNG-102.bdf'
raw = mne.io.read_raw_bdf(
    input_fname, 
    eog=['EXG1', 'EXG2', 'EXG3', 'EXG4'], 
    exclude=['EXG5', 'EXG6'], 
    preload=True
)

# Attach BioSemi-64 positions; only the mastoid channels EXG7/EXG8 may be
# missing from the montage.
try:
    raw = raw.set_montage('biosemi64')
except ValueError as e:
    if '[\'EXG7\', \'EXG8\']' in str(e):
        raw = raw.set_montage('biosemi64', on_missing='ignore')
        print('On missing')
    else:
        # Any other missing channel leaves the data without a montage; stop
        # here instead of continuing silently.
        print('Lacks important channels!')
        raise


file_path = 'data/GNG/raw/triggers/triggerMap_B-GNG-102.txt'
trigger_map = read_trigger_map(file_path)
# raw_new_triggers = replace_trigger_names(raw, trigger_map)

In [None]:
trigger_map  # display the parsed (trigger_id, code) pairs

In [None]:
'TG' in trigger_map[1][1]  # does the second trigger's code contain 'TG'?

In [None]:
def add_response_info(trigger_map):
    """Mark target triggers that are immediately followed by a response.

    For every entry whose code contains ``'TG'`` and whose next entry's code
    contains ``'RE'``, the last character of the target's code is replaced
    with ``'R'``. The input list is not modified.

    Args:
        trigger_map: list of ``(trigger_id, code)`` tuples, e.g. as returned
            by ``read_trigger_map``.

    Returns:
        A new list of ``(trigger_id, code)`` tuples with updated target codes.
    """
    new_trigger_map = list(trigger_map)
    last_idx = len(trigger_map) - 1
    for idx, trigger in enumerate(trigger_map):
        if 'TG' not in trigger[1]:
            continue
        print('in target')
        # A target in the last position has no following trigger, so it cannot
        # have been answered (previously this indexed past the end).
        if idx < last_idx and 'RE' in trigger_map[idx + 1][1]:
            print('adding RE')
            new_trigger_map[idx] = (trigger[0], trigger[1][:-1] + 'R')

    return new_trigger_map

In [None]:
# Trigger map with answered targets re-coded with an 'R' suffix.
ne_tg_map = add_response_info(trigger_map=trigger_map)

In [None]:
ne_tg_map  # display the trigger map after add_response_info

In [None]:
# Replace event IDs in the Raw object: each event found on the Status channel
# is re-coded with the code from the trigger map entry at the same position,
# and the result is attached to a copy of the recording as annotations.
events = mne.find_events(raw, stim_channel='Status')
new_events_list = events.copy()
print(f'EVENTS: {new_events_list}')

assert len(events) == len(trigger_map), \
    f'The length of trigger map ({len(trigger_map)}) not equals length of events in eeg recording ({len(events)})'

trigger_map_codes, mapping = create_triggers_dict(trigger_map)

for idx, event in enumerate(events):
    trigger = trigger_map[idx]
    trigger_id, trigger_new_code = trigger[0], trigger[1]
    # only the last digit of the recorded code is compared with the map's ID
    event_id = str(event[2])[-1]

    print(event, event_id, trigger_id, sep='\n')

    if event_id != trigger_id:
        print(f'An event {idx} has different number than in provided file. {trigger_id} expected, {str(event[2])} found. Triggers may need to be checked.')

    trigger_new_code_int = trigger_map_codes[trigger_new_code]
    new_events_list[idx, 2] = trigger_new_code_int

annot_from_events = mne.annotations_from_events(
    events=new_events_list,
    event_desc=mapping,
    sfreq=raw.info["sfreq"],
    orig_time=raw.info["meas_date"],
)
raw_copy = raw.copy()
raw_copy.set_annotations(annot_from_events)

In [None]:
len(new_events_list)  # number of re-coded events

In [None]:
import matplotlib
import matplotlib.pyplot as plt


In [None]:
# matplotlib.use('Qt5Agg')
# plt.switch_backend('QtAgg')

In [None]:
# Browse a 15 s window of the re-annotated copy starting at 1404 s.
fig = raw_copy.plot(duration=15, start=1404)
# plt.pause(0.0001)
# %matplotlib
# plt.show()

In [None]:
mne.viz.plot_raw(raw_copy)  # open the browser on the annotated copy

In [None]:
# 1. re-reference to the mastoid channels EXG7/EXG8, on a copy so that
# raw_new_triggers stays unreferenced
raw_ref = raw_new_triggers.copy()
raw_ref = raw_ref.set_eeg_reference(ref_channels=['EXG7', 'EXG8'])
# fig = raw_ref.plot(start=60, duration=1)

In [None]:
# 2. 4-th order Butterworth filters: 0.1-30 Hz band-pass on a copy of the
# referenced data (method='iir' with iir_params=None -> MNE's default IIR design)
raw_filtered = raw_ref.copy()
raw_filtered = raw_filtered.filter(
    l_freq=.1, h_freq=30.0,
    method='iir', iir_params=None,
    n_jobs=10,
)

In [None]:
# 3. Notch filter at 50 Hz and its harmonics strictly below the Nyquist
# frequency. The upper bound is derived from the sampling rate (as in the
# pipeline's notch step) rather than hard-coded to 250 Hz, which would put
# notches at or above Nyquist for lower sampling rates.
raw_filtered = raw_filtered.notch_filter(
    freqs=np.arange(50, raw_filtered.info['sfreq'] / 2, 50),
    n_jobs=10,
    # method='iir',
    # iir_params=None,
)

In [None]:
# 5. ocular artifact correction with ICA
# raw_corrected_eogs = ocular_correction_ica(raw_filtered, raw_ref, heog=['EXG3', 'EXG4'], veog=['EXG1', 'EXG2'])
heog = ['EXG3', 'EXG4']
veog = ['EXG1', 'EXG2']

# ICA input: referenced data without the mastoid channels, high-passed at 1 Hz.
filtered_raw_ica = raw_ref.copy().drop_channels(['EXG7', 'EXG8'])
filtered_raw_ica = filtered_raw_ica.filter(l_freq=1.0, h_freq=None)

ica = mne.preprocessing.ICA(
    n_components=12, method='infomax', max_iter="auto", random_state=random_state
)
ica.fit(filtered_raw_ica)
ica.exclude = []

# Components correlating (threshold 0.9) with the vertical EOG pair.
veog_indices, veog_scores = ica.find_bads_eog(
    filtered_raw_ica, ch_name=veog, threshold=0.9, measure='correlation'
)

# Components correlating (threshold 0.7) with the horizontal EOG pair.
heog_indices, heog_scores = ica.find_bads_eog(
    filtered_raw_ica, ch_name=heog, threshold=0.7, measure='correlation'
)

In [None]:
print('ICA components')
fig = ica.plot_components()

print(f"VEOG indices: {veog_indices}\nVEOG scores: {veog_scores}\n")
print(f"HEOG indices: {heog_indices}\nHEOG scores: {heog_scores}\n")

excluded_components = veog_indices + heog_indices
print(f'Excluded ICA components:\nVEOG: {veog_indices}\nHEOG: {heog_indices}')
# plot_components rejects an empty pick list, so only plot the flagged
# components when find_bads_eog actually returned some.
if excluded_components:
    fig = ica.plot_components(excluded_components)

In [None]:
# if len(veog_indices + heog_indices) == 0 and from_template == 'auto':
# print('No ICA component correlates with EOG channels. Using templates...')
# ica_template = pd.read_pickle('data/eog_templates.pkl')
# veog_template = ica_template[['VEOG']].to_numpy()
# heog_template = ica_template[['HEOG']].to_numpy()
# 
# mne.preprocessing.corrmap([ica], template=veog_template, threshold=0.9)
# mne.preprocessing.corrmap([ica], template=heog_template, threshold=0.8)


# elif from_template is True:
#     print('Using templates...')

# ica.exclude = veog_indices + heog_indices
# # 
# reconstructed_raw = raw.copy()
# ica.apply(reconstructed_raw)

# Build VEOG/HEOG templates from the first component matched to each EOG pair
# and store them in data/eog_templates.pkl.
if not veog_indices or not heog_indices:
    raise ValueError(
        f'Cannot build EOG templates: VEOG components {veog_indices}, '
        f'HEOG components {heog_indices}'
    )
template_veog_component = ica.get_components()[:, veog_indices[0]]
template_heog_component = ica.get_components()[:, heog_indices[0]]

templates_ica = pd.DataFrame({
    'VEOG': [template_veog_component],
    'HEOG': [template_heog_component],
})
templates_ica.to_pickle('data/eog_templates.pkl')
#########
# Reload the saved templates and match them back against this ICA.
templates_ica = pd.read_pickle('data/eog_templates.pkl')

template_veog_component = templates_ica['VEOG'].to_numpy()[0].flatten()
template_heog_component = templates_ica['HEOG'].to_numpy()[0].flatten()


mne.preprocessing.corrmap([ica], template=template_veog_component, threshold=0.9, label="veog blink", plot=False)
mne.preprocessing.corrmap([ica], template=template_heog_component, threshold=0.8, label="heog blink", plot=False)

# corrmap may leave a label unset when nothing passes the threshold.
to_exclude = ica.labels_.get('veog blink', []) + ica.labels_.get('heog blink', [])
print(to_exclude)



# del filtered_raw_ica


In [None]:
ica.labels_  # labels recorded by corrmap ('veog blink', 'heog blink')

In [None]:
# Segment -400 to 800 ms around the events, EEG channels only (mastoids dropped).
raw_corrected_eogs_drop_ref = raw_corrected_eogs.copy().drop_channels(['EXG7', 'EXG8']).pick('eeg')

epochs = create_epochs(
    raw_corrected_eogs_drop_ref, 
    context=participant_context,
    tmin=-.4, 
    tmax=.8,
    reject=None,
    reject_by_annotation=False,
)

fig = epochs.plot(n_epochs=20, n_channels=68)

In [None]:
# 6. Trial-wise bad channel identification; the returned drop log is updated
# by the following steps.
drop_log = find_bad_trails(epochs)
drop_log

In [None]:
# 7. Global bad channel identification - <.4 corr with 6 neighbours and
# channels marked as bad in more than 20% of trials
drop_log, global_bad_channels_drop_log = find_global_bad_channels(epochs, drop_log)

In [None]:
# 8. Calculate trials to remove (threshold=0.1).
# NOTE(review): the pipeline cell performs this step with
# mark_bad_trials(epochs, drop_log, threshold=0.1); confirm that
# calculate_bad_trials is still the intended helper.
trials_to_drop_indices, drop_log = calculate_bad_trials(drop_log, threshold=0.1)

In [None]:
# 10. Interpolate globally bad channels (and thus update the drop log)
interpolated_epochs, drop_log = interpolate_bad_channels(epochs, drop_log, global_bad_channels_drop_log)

In [None]:
# 12. Flag participants that have fewer than 6 clean trials
clean_epochs, _ = reject_bad_trials(interpolated_epochs.copy(), drop_log, trials_to_drop_indices)
too_few_trials = len(clean_epochs) < 6
if too_few_trials:
    print(f"Participant ID: {id} has not enough clean trials")

In [None]:
# 13. Baseline correction over the -400 to -200 ms window
interpolated_epochs.apply_baseline((-0.4, -0.2))

In [None]:
drop_log  # display the current drop log

In [None]:
# 13. ERPs scoring: single-trial ERN (FCz, 0-0.1 s) and Pe (Pz, 0.2-0.4 s)
# values from error-response epochs.
ern_single_trials = create_erps(
    interpolated_epochs, drop_log,
    type='error_response', picks=['FCz'], tmin=0, tmax=0.1,
)
pe_single_trials = create_erps(
    interpolated_epochs, drop_log,
    type='error_response', picks=['Pz'], tmin=0.2, tmax=0.4,
)

In [None]:
# Mean FCz error-response waveform from -0.1 to 0.6 s.
error_wave = create_erps_waves(
    interpolated_epochs, drop_log,
    type='error_response', picks=['FCz'], tmin=-0.1, tmax=0.6,
)
mean_error_wave = np.mean(error_wave, axis=0).flatten()
plt.plot(np.linspace(-0.1, 0.6, mean_error_wave.size), mean_error_wave)

In [None]:
# Re-plot the mean error waveform on the -0.1 to 0.6 s axis.
mean_error_wave = np.mean(error_wave, axis=0).flatten()
plt.plot(np.linspace(-0.1, 0.6, mean_error_wave.size), mean_error_wave)

In [None]:
# filtered_raw_ica = raw_ref.copy().drop_channels(['EXG7', 'EXG8']).filter(l_freq=1.0, h_freq=None)
# 
# ica = mne.preprocessing.ICA(
#     n_components=12,
#     method='infomax',
#     max_iter="auto",
#     random_state=random_state
# )
# ica.fit(filtered_raw_ica)
# 
# ica.exclude = []
# heog=['EXG3', 'EXG4']
# veog=['EXG1', 'EXG2']
# 
# # find which ICs match the VEOG pattern
# veog_indices, veog_scores = ica.find_bads_eog(
#     filtered_raw_ica,
#     ch_name=veog,
#     threshold=0.9,
#     measure='correlation'
# )
# # 
# print(f'{veog_indices}')
# 
# heog_indices, heog_scores = ica.find_bads_eog(
#     filtered_raw_ica,
#     ch_name=heog,
#     threshold=0.7,
#     measure='correlation'
# )
# print(f'{heog_indices}')
# 
# 
# print(f'Excluded ICA components: {veog_indices + heog_indices}')
# fig = ica.plot_components(veog_indices + heog_indices)
# 
# ica.exclude = veog_indices + heog_indices
# # 
# reconstructed_raw = raw_filtered.copy()
# ica.apply(reconstructed_raw)

In [None]:
# ica.exclude = []
# heog=['EXG3', 'EXG4']
# veog=['EXG1', 'EXG2']
# 
# # find which ICs match the VEOG pattern
# veog_indices, veog_scores = ica.find_bads_eog(
#     raw_ref,
#     ch_name=veog,
#     threshold=0.9,
#     measure='correlation'
# )
# # 
# print(f'{veog_indices}')
# 
# heog_indices, heog_scores = ica.find_bads_eog(
#     raw_ref,
#     ch_name=heog,
#     threshold=0.7,
#     measure='correlation'
# )
# print(f'{heog_indices}')
# 
#  
# print(f'Excluded ICA components: {veog_indices + heog_indices}')
# fig = ica.plot_components(veog_indices + heog_indices)
# 
# ica.exclude = veog_indices + heog_indices
# # 
# reconstructed_raw = raw_filtered.copy()
# ica.apply(reconstructed_raw)

In [None]:
# HEOG scores first, then VEOG scores, one per line.
print(heog_scores, veog_scores, sep='\n')