In [None]:
import logging
from dataclasses import dataclass

import mne
import re
from collections import OrderedDict
import matplotlib.pyplot as plt
import seaborn as sns
from mne.utils import set_log_file
from sklearn.decomposition import PCA
import numpy as np
from ssqueezepy import Wavelet, cwt, icwt
from lifelines import KaplanMeierFitter
from ssqueezepy.experimental import scale_to_freq
from sklearn.linear_model import LinearRegression
import pandas as pd
import os
import json

# matplotlib.use('Qt5Agg')
# plt.switch_backend('QtAgg')

Constants

In [None]:
# Fixed seed for reproducibility.
# NOTE(review): no cell visible in this notebook reads `random_state` - confirm it is still needed.
random_state = 42

Loggers

In [None]:
######## PREPROCESSING ##############################################
# Dedicated logger for step-by-step preprocessing information.
logger_preprocessing_info = logging.getLogger('preprocessing_info')
logger_preprocessing_info.setLevel(logging.INFO)
# Keep these records away from the root logger; handlers are attached in a later cell.
logger_preprocessing_info.propagate = False


######## ERRORS ##############################################
# Dedicated logger for per-participant problems.
logger_errors_info = logging.getLogger('errors')
logger_errors_info.setLevel(logging.INFO)
# Fixed: the original disabled propagation on the preprocessing logger a second
# time here, leaving the errors logger propagating to the root logger.
logger_errors_info.propagate = False

In [None]:
def read_trigger_map(file_name):
    """Parse a trigger-map text file into an ordered list of ``(id, code)`` pairs.

    Every line is expected to have the form ``<id>:<code>``. The line is split
    at its *last* colon, so ``<id>`` may itself contain colons while ``<code>``
    never does.

    Parameters
    ----------
    file_name : str or path-like
        Path of the trigger-map file.

    Returns
    -------
    list of tuple of (str, str)
        One ``(id, code)`` tuple per line, in file order.

    Raises
    ------
    AssertionError
        If at least one line (including blank lines) contains no colon, i.e.
        the number of parsed triggers differs from the number of lines.
    """
    trigger_map = []
    line_count = 0

    with open(file_name, 'r') as file:
        for line in file:
            line_count += 1
            # Strip the line terminator only. The original regex required a
            # trailing newline, so a final line without one was never parsed.
            content = line[:-1] if line.endswith('\n') else line
            trigger_id, separator, trigger_code = content.rpartition(':')
            if separator:
                trigger_map.append((trigger_id, trigger_code))

    assert len(trigger_map) == line_count, \
        f'The length of trigger file ({line_count}) not equals length of created trigger_map ({len(trigger_map)})'

    return trigger_map

def create_triggers_dict(trigger_map):
    """Number the distinct trigger codes of ``trigger_map`` starting at 1000.

    Codes are taken from position 1 of every entry and numbered in order of
    first appearance.

    Returns
    -------
    tuple of (dict, dict)
        ``code -> number`` and the inverse ``number -> code`` mapping.
    """
    # dict.fromkeys removes duplicates while preserving first-seen order.
    distinct_codes = list(dict.fromkeys(entry[1] for entry in trigger_map))

    code_to_number = {}
    number_to_code = {}
    for offset, code in enumerate(distinct_codes):
        number = 1000 + offset
        code_to_number[code] = number
        number_to_code[number] = code

    return code_to_number, number_to_code

def _insert_missing_trigger(events, trigger_map, delta_time=3, missing_code=65281):
    """Return a copy of ``events`` with one synthetic event inserted at the first mismatch.

    The first event whose last decimal digit differs from the id expected by
    ``trigger_map`` at the same position marks the gap. The synthetic event is
    placed ``delta_time`` samples after the preceding event.
    """
    for idx, event in enumerate(events):
        if str(event[2])[-1] != trigger_map[idx][0]:
            synthetic_event = [[events[idx - 1][0] + delta_time, 0, missing_code]]
            return np.concatenate([events[:idx, :], synthetic_event, events[idx:, :]])
    return events.copy()


def _drop_first_event_with_code(events, code):
    """Return a copy of ``events`` without the first row whose event code equals ``code``."""
    for idx, event in enumerate(events):
        if event[2] == code:
            return np.concatenate([events[:idx, :], events[idx + 1:, :]])
    return events.copy()


def replace_trigger_names(raw, participant_id, trigger_map, new_response_event_dict=None, replace=False, search='RE'):
    """Replace the numeric triggers of a recording with descriptive annotations.

    Events are read from the ``'Status'`` channel, repaired for a hard-coded
    list of known-corrupted recordings, renumbered through
    ``create_triggers_dict`` and attached to a copy of ``raw`` as annotations
    whose descriptions are the codes from ``trigger_map``.

    Reads the notebook-level global ``paradigm`` to decide which repairs apply.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording with a ``'Status'`` stim channel.
    participant_id : str
        Selects the per-participant repairs.
    trigger_map : list of tuple
        ``(id, code)`` pairs, one per expected event, in recording order.
    new_response_event_dict, replace, search
        Unused; kept for backward compatibility.

    Returns
    -------
    mne.io.Raw
        Copy of ``raw`` carrying the new annotations.

    Raises
    ------
    AssertionError
        If the (repaired) number of events differs from ``len(trigger_map)``.
    """
    events = mne.find_events(raw, stim_channel='Status')

    # Recordings with a trigger lost because two triggers followed each other too closely.
    # NOTE(review): 65281 is presumably the Status value of the lost response trigger - confirm.
    sst_missing_response_ids = ['SST-165', 'SST-211', 'SST-122', 'SST-088', 'SST-045',
                                'SST-012', 'SST-083', 'SST-136', 'SST-125']

    # The conditions below are mutually exclusive, so at most one repair runs.
    if paradigm == 'GNG' and participant_id in ('B-GNG-199', 'B-GNG-208'):
        # corrupted bdf files - too short reaction
        new_events_list = _insert_missing_trigger(events, trigger_map)
    elif paradigm == 'SST' and participant_id in sst_missing_response_ids:
        # too short time between stop trigger and reaction trigger
        new_events_list = _insert_missing_trigger(events, trigger_map)
    elif paradigm == 'SST' and participant_id == 'SST-181':
        # ghost event 65312
        new_events_list = _drop_first_event_with_code(events, 65312)
    elif paradigm == 'SST' and participant_id == 'SST-075':
        # ghost event 130816
        new_events_list = _drop_first_event_with_code(events, 130816)
    elif paradigm == 'SST' and participant_id == 'SST-130':
        # two triggers are missing: repair, then repair the repaired list again
        new_events_list = _insert_missing_trigger(events, trigger_map)
        new_events_list = _insert_missing_trigger(new_events_list, trigger_map)
    else:
        new_events_list = events.copy()

    logger_preprocessing_info.info(f'EVENTS: {new_events_list}')

    assert len(new_events_list) == len(trigger_map), \
            f'The length of trigger map ({len(trigger_map)}) not equals length of events in eeg recording ({len(new_events_list)})'

    trigger_map_codes, mapping = create_triggers_dict(trigger_map)

    for idx, event in enumerate(new_events_list):
        # Only the last decimal digit of the Status value is compared with the expected id.
        event_id = str(event[2])[-1]
        trigger_id = trigger_map[idx][0]
        trigger_new_code = trigger_map[idx][1]

        if event_id != trigger_id:
            logger_errors_info.info(f'An event {idx} has different number than in provided file. {trigger_id} expected, {str(event[2])} found. Triggers may need to be checked.')

        # Overwrite the raw Status value with the 1000-based number of the descriptive code.
        new_events_list[idx][2] = trigger_map_codes[trigger_new_code]

    annot_from_events = mne.annotations_from_events(
        events=new_events_list,
        event_desc=mapping,
        sfreq=raw.info["sfreq"],
        orig_time=raw.info["meas_date"],
    )
    raw_copy = raw.copy()
    raw_copy.set_annotations(annot_from_events)

    return raw_copy

def find_items_matching_regex(dictionary, regex_list):
    """Select the items of ``dictionary`` whose key matches any pattern in ``regex_list``.

    Patterns are applied with ``re.match`` (anchored at the start of the key).
    Items appear in the result in the order they are first selected: pattern
    by pattern, and within a pattern in dictionary order.
    """
    selected = {}
    for compiled_pattern in map(re.compile, regex_list):
        for key, value in dictionary.items():
            if compiled_pattern.match(key):
                selected[key] = value
    return selected

@dataclass
class ParticipantTriggerMappingContext:
    """Per-participant bundle of the three event mappings used when epoching."""

    # descriptive trigger code -> numeric id, restricted to the events of interest
    event_dict: dict
    # category name -> list of numeric ids that are merged into that category
    events_mapping: dict
    # category name -> numeric id assigned to the merged category
    new_event_dict: dict

    def __str__(self):
        fields_in_order = (self.event_dict, self.events_mapping, self.new_event_dict)
        return "\n".join(str(field) for field in fields_in_order)

def create_events_mappings(trigger_map, case='RE') -> ParticipantTriggerMappingContext:
    """Build the event-selection context for response- or stimulus-locked epochs.

    Trigger codes are ``'*'``-separated strings. The classification below uses
    field 3 and the last two fields of each code.

    Parameters
    ----------
    trigger_map : list of tuple
        ``(id, code)`` pairs as returned by ``read_trigger_map``.
    case : str
        ``'RE'`` selects response events (``correct_response``,
        ``error_response``, ``incorrect_go_response``); ``'STIM'`` selects stop
        events (``inhibited_stop``, ``uninhibited_stop``).

    Returns
    -------
    ParticipantTriggerMappingContext
        Selected codes with their numeric ids, the ids grouped per category
        and the id assigned to each category. For an unknown ``case`` the
        problem is logged and a context of three empty dicts is returned.
    """
    trigger_map_codes, mapping = create_triggers_dict(trigger_map)

    if case == 'RE':
        new_event_dict = {"correct_response": 0, "error_response": 1, "incorrect_go_response": 2}
        events_mapping = {
            'correct_response': [],
            'error_response': [],
            'incorrect_go_response' : []
        }

        # find response events from experimental blocks
        regex_pattern = [r'RE\*image\*.*\*0\*.*', r'RE\*image\*.*\*-\*.*']
        event_dict = find_items_matching_regex(trigger_map_codes, regex_pattern)

        for event_id in event_dict.keys():
            event_id_splitted = event_id.split('*')

            # field 3 == '-': the response is compared with the preceding field
            if (event_id_splitted[3] == '-') and (event_id_splitted[-1] == event_id_splitted[-2]):
                events_mapping['correct_response'].append(event_dict[event_id])
            elif (event_id_splitted[3] == '-') and (event_id_splitted[-1] != event_id_splitted[-2]):
                events_mapping['incorrect_go_response'].append(event_dict[event_id])
            # field 3 == '0' with any response other than '-' counts as an error
            elif (str(event_id_splitted[3]) == '0') and (event_id_splitted[-1] != '-'):
                events_mapping['error_response'].append(event_dict[event_id])

    elif case == 'STIM':
        new_event_dict = {"inhibited_stop": 0, "uninhibited_stop": 1}
        events_mapping = {
            'inhibited_stop': [],
            'uninhibited_stop': [],
        }
        # find all target stimuli events from experimental blocks
        regex_pattern = [r'ST.*']
        event_dict = find_items_matching_regex(trigger_map_codes, regex_pattern)

        for event_id in event_dict.keys():
            event_id_splitted = event_id.split('*')

            # last field '-' means no response was recorded for this code
            if (event_id_splitted[-1] == '-') and (event_id_splitted[3] == '0'):
                events_mapping['inhibited_stop'].append(event_dict[event_id])
            elif (event_id_splitted[-1] != '-') and (event_id_splitted[3] == '0'):
                events_mapping['uninhibited_stop'].append(event_dict[event_id])

    else:
        # Fixed: the original called the Logger object itself, which raised
        # TypeError instead of logging, and used invalid backtick escapes.
        # NOTE(review): 'FBCK' is advertised in the message but has no branch here.
        logger_errors_info.info(
            "Not known case. Possible cases: 'RE' for response, 'STIM' for stimuli, "
            "and 'FBCK' for feedback-locked events extraction."
        )
        # todo raise an Error
        event_dict = {}
        events_mapping = {}
        new_event_dict = {}

    return ParticipantTriggerMappingContext(event_dict=event_dict,
                                            events_mapping=events_mapping,
                                            new_event_dict=new_event_dict)

def create_epochs(
    raw,
    context: ParticipantTriggerMappingContext,
    tmin=-.1,
    tmax=.6,
    baseline=None,
    reject=None,
    reject_by_annotation=False,
    detrend=None
):
    """Cut ``raw`` into epochs around the events described by ``context``.

    Events are read from the annotations of ``raw`` (restricted to
    ``context.event_dict``), the ids listed per category in
    ``context.events_mapping`` are merged into the category id from
    ``context.new_event_dict``, and preloaded epochs of the EEG and EOG
    channels are returned. Categories without any event only produce a
    warning (``on_missing='warn'``).
    """
    # select specific events
    events, _ = mne.events_from_annotations(raw, event_id=context.event_dict)

    # Merge different events of one kind
    for category, source_ids in context.events_mapping.items():
        events = mne.merge_events(
            events=events,
            ids=source_ids,
            new_id=context.new_event_dict[category],
            replace_events=True,
        )

    # Read epochs
    return mne.Epochs(
        raw=raw,
        events=events,
        event_id=context.new_event_dict,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        reject_by_annotation=reject_by_annotation,
        preload=True,
        reject=reject,
        picks=['eeg', 'eog'],
        detrend=detrend,
        on_missing = 'warn',
    )


def ocular_correction_gratton(epochs, subtract_evoked=False):
    """Remove ocular activity from the EEG channels by regression on the EOG channels.

    The regression weights are fitted either on ``epochs`` or, when
    ``subtract_evoked`` is true, on a copy with the evoked response
    subtracted. In both cases the fitted model is applied to the original
    ``epochs`` and the corrected epochs are returned.
    """
    fit_data = epochs.copy().subtract_evoked() if subtract_evoked else epochs

    eog_model = mne.preprocessing.EOGRegression(
        picks="eeg",
        picks_artifact="eog"
    )
    eog_model = eog_model.fit(fit_data)

    return eog_model.apply(epochs)

def find_bad_trials(epochs, picks=['FCz','Cz']):
    """Flag, per trial, the picked channels whose absolute amplitude exceeds 150e-6.

    Works on a copy of ``epochs`` restricted to ``picks``; ``epochs`` itself is
    not modified. The drop log of that copy is extended, for every trial, with
    the names of the channels that exceed the limit at any sample.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to inspect.
    picks : list of str or str
        Channel selection forwarded to ``Epochs.pick``.

    Returns
    -------
    tuple of tuple
        The extended drop log.

    Notes
    -----
    NOTE(review): the drop log is indexed with the position of the epoch among
    the *kept* epochs. That matches MNE's drop log only while no epoch was
    dropped or ignored when the epochs were created - confirm for the data at
    hand.
    """
    # EEG signal greater than +-150 uV (assumes get_data() returns volts - TODO confirm)
    amplitude_limit = 150e-6

    epochs_picked_channels = epochs.copy().pick(picks=picks)
    epochs_picked_channels.drop_bad()

    ch_names = epochs_picked_channels.info['ch_names']
    # Mutable working copy: the original rebuilt the whole tuple for every flagged channel.
    drop_log = list(epochs_picked_channels.drop_log)

    # One array for all trials instead of slicing an Epochs object per trial.
    data = epochs_picked_channels.get_data(copy=False)
    exceeded = (np.abs(data) > amplitude_limit).any(axis=-1)  # (n_trials, n_channels)

    # np.nonzero walks the matrix row by row, i.e. trial by trial, channel by channel.
    for idx, ch_idx in zip(*np.nonzero(exceeded)):
        idx = int(idx)
        ch_name = ch_names[ch_idx]
        logger_preprocessing_info.info(f'Channel {ch_name} exceeded +- 150 μV threshold at {idx} trail')
        if ch_name not in drop_log[idx]:
            drop_log[idx] = drop_log[idx] + (ch_name, )

    return tuple(drop_log)


def reject_bad_trials(epochs, drop_log, picks=['FCz', 'Cz']):
    """Drop the trials whose drop-log entry names one of the ``picks`` channels.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to clean; a copy is modified, ``epochs`` is left untouched.
    drop_log : sequence of tuple
        One entry per epoch, each holding the names of the flagged channels.
    picks : list of str or str
        Channels whose presence in an entry rejects the trial. The original
        ignored this argument and always tested ``'FCz'`` and ``'Cz'``, which
        is still the default.

    Returns
    -------
    tuple
        ``(clean_epochs, drop_log)`` where rejected trials are replaced by
        ``('REJECTED',)`` in the returned drop log.

    Raises
    ------
    AssertionError
        If ``drop_log`` and ``epochs`` differ in length.
    """
    if isinstance(picks, str):
        picks = [picks]
    picks_label = ' or '.join(picks)

    clean_epochs = epochs.copy()

    assert len(clean_epochs) == len(drop_log), f'Length of epochs ({len(clean_epochs)}) not equals length of drop_log ({len(drop_log)}). Cannot mark trials as BAD.'

    epochs_to_drop_indices = []
    for idx, item in enumerate(drop_log):
        if any(ch_name in item for ch_name in picks):
            # Fixed: the message used to name "Fz or FCz" regardless of the channels tested.
            logger_preprocessing_info.info(f'Rejecting trial {idx}. Artifacts at {picks_label}')
            epochs_to_drop_indices.append(idx)

    clean_epochs = clean_epochs.drop(
        indices = epochs_to_drop_indices,
        reason = 'EXCEED 150uV',
    )

    # update drop_log in a single pass (the original rebuilt it once per rejected trial)
    if epochs_to_drop_indices:
        rejected = set(epochs_to_drop_indices)
        drop_log = tuple(('REJECTED',) if i in rejected else element for i, element in enumerate(drop_log))

    return clean_epochs, drop_log

In [None]:
def pre_process_eeg(input_fname, participant_id, context, trigger_fname=None, tmin=-0.1, tmax=0.9):
    """Run the base preprocessing chain for one BDF recording.

    Steps: read the file, set the ``biosemi64`` montage, replace triggers by
    annotations, re-reference, resample to 500 Hz, epoch with linear
    detrending and a (-0.1, 0) baseline, regress out EOG, re-baseline and flag
    trials with excessive amplitudes.

    Returns
    -------
    tuple
        ``(epochs_eog_corrected, drop_log)``. The returned epochs still contain
        the flagged trials; rejection is only performed here to count the
        artifact-free trials and log a warning when fewer than 6 remain.
    """
    montage_name = 'biosemi64'

    # 0. read bdf
    raw = mne.io.read_raw_bdf(
        input_fname,
        eog=['EXG1', 'EXG2', 'EXG3', 'EXG4'],
        exclude=['EXG5', 'EXG6'],
        preload=True,
    )

    try:
        raw = raw.set_montage(montage_name)
    except ValueError as montage_error:
        # The reference channels EXG7/EXG8 have no position in the montage;
        # an error that names exactly those two is tolerated.
        mentions_exg7_exg8 = "['EXG7', 'EXG8']" in montage_error.args[0]
        if not mentions_exg7_exg8:
            logger_errors_info.info('Lacks important channels!')
        else:
            raw = raw.set_montage(montage_name, on_missing='ignore')
            logger_preprocessing_info.info('On missing')

    # 1. replace trigger names
    trigger_map = read_trigger_map(trigger_fname)
    raw_new_triggers = replace_trigger_names(raw, participant_id, trigger_map)

    # 2. re-reference: to mastoids (a single channel for two participants)
    if '005' in participant_id:
        reference_channels, reference_note = ['EXG7'], 'Referencing to EX7'
    elif '044' in participant_id:
        reference_channels, reference_note = ['EXG8'], 'Referencing to EX8'
    else:
        reference_channels, reference_note = ['EXG7', 'EXG8'], None

    raw_ref = raw_new_triggers.copy().set_eeg_reference(ref_channels=reference_channels)
    if reference_note is not None:
        logger_errors_info.info(reference_note)

    # 3. Resampling
    raw_resampled = raw_ref.copy().resample(sfreq=500)

    # 4. Detrending, Segmentation, and first baseline correction
    epochs = create_epochs(
        raw_resampled,
        context=context,
        tmin=tmin,
        tmax=tmax,
        baseline=(-0.1, 0),
        detrend=1,
        reject=None,
        reject_by_annotation=False,
    )

    # 5. ocular artifact correction with Gratton
    epochs_eog_corrected = ocular_correction_gratton(epochs)

    # 6. Second re-baseline (apply_baseline is called with its default interval)
    epochs_eog_corrected.apply_baseline()

    # 7. Mark bad trials on all EEG channels
    drop_log = find_bad_trials(epochs_eog_corrected, picks='eeg')

    # 8. Reject bad trials - only to count what would remain
    clean_epochs, _ = reject_bad_trials(epochs_eog_corrected, drop_log)
    remaining_trials = len(clean_epochs)
    if remaining_trials < 6:
        logger_errors_info.info(f'Participant has only {remaining_trials} artifact-free trials')

    return epochs_eog_corrected, drop_log

In [None]:
def save_epochs_with_drop_log(epochs, drop_log, participant_id):
    """Pickle ``epochs`` and ``drop_log`` together as a one-row DataFrame.

    The file is written to ``preprocessed_<participant_id>.pkl`` inside the
    notebook-level ``preprocessed_data_dir_path``. Returns the result of the
    logging call that confirms the save.
    """
    record = pd.DataFrame({'epochs': [epochs], 'drop_log': [drop_log]})

    output_path = f'{preprocessed_data_dir_path}preprocessed_{participant_id}.pkl'
    record.to_pickle(output_path)

    return logger_preprocessing_info.info('Epochs saved to pickle.')

## Base preprocessing

Set globals

In [None]:
# Experimental paradigm; selects the data directories and the per-participant
# trigger repairs in replace_trigger_names. One of: GNG | SST | Flanker
paradigm = 'SST'
# Event-locking case; used in the output paths below. One of: RE | STIM | FBCK
case = 'RE'
# TODO: consider moving the global variables `paradigm` and `case` into a data/case class

Set paths based on the values of the globals

In [None]:
# Input directories (per paradigm). All paths are relative to the notebook's
# working directory and end with a slash because file names are appended directly.
trigger_dir_path = f'data/{paradigm}/raw/triggers/'
bdf_dir_path = f'data/{paradigm}/raw/bdfs/'
behavioral_dir_path = f'data/{paradigm}/behavioral/'
# Output directory of the pickled epochs (per paradigm and case).
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Directory that receives the log files.
logger_dir_path = f'data/joint/{paradigm}/'

Set output files for loggers

In [None]:
######## PREPROCESSING ##############################################
# Create a file handler for preprocessing and set the level to INFO
file_handler_preprocessing = logging.FileHandler(f'data/joint/{paradigm}/{case}_preprocessing.txt')
file_handler_preprocessing.setLevel(logging.INFO)

# Create a formatter and add it to the file handler for preprocessing
formatter_preprocessing = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_preprocessing.setFormatter(formatter_preprocessing)

# Add the file handler for method A to the logger for preprocessing
logger_preprocessing_info.addHandler(file_handler_preprocessing)

######## ERRORS ##############################################
# Create a file handler for errors and set the level to INFO
file_handler_errors = logging.FileHandler(f'data/joint/{paradigm}/{case}_errors.txt')
file_handler_errors.setLevel(logging.INFO)

# Create a formatter and add it to the file handler for errors
formatter_errors = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_errors.setFormatter(formatter_errors)

# Add the file handler for method A to the logger for preprocessing
logger_errors_info.addHandler(file_handler_errors)

##### MNE ###################################################
# Create logger for MNE logs
logger_f_name = f'data/joint/{paradigm}/{case}_MNE-logs.txt'
set_log_file(fname=logger_f_name, output_format="%(asctime)s - %(message)s", overwrite=None)

In [None]:
# console_handler = logging.StreamHandler()
# console_handler.setLevel(logging.DEBUG)  # Set the desired logging level

# logger_errors_info.addHandler(console_handler)
# logger_preprocessing_info.addHandler(console_handler)

Read participant IDs

In [None]:
# Participant IDs are the BDF file names without their extension. Entries that
# are not BDF recordings (e.g. `.ipynb_checkpoints`, `.DS_Store`) are ignored;
# the original turned them into bogus IDs. Sorted for a reproducible order.
id_list = sorted(
    os.path.splitext(item)[0]
    for item in os.listdir(bdf_dir_path)
    if item.endswith('.bdf')
)

Perform base preprocessing

In [None]:
# DataFrame.to_pickle does not create missing directories.
os.makedirs(preprocessed_data_dir_path, exist_ok=True)

for participant_id in id_list:
    print(f'{participant_id}\n')
    bdf_fname = f'{bdf_dir_path}{participant_id}.bdf'
    trigger_fname = f'{trigger_dir_path}triggerMap_{participant_id}.txt'

    logger_preprocessing_info.info(f'#### PARTICIPANT ID: {participant_id} #########')
    logger_errors_info.info(f'#### PARTICIPANT ID: {participant_id} #########')

    try:
        trigger_map = read_trigger_map(trigger_fname)
        # Fixed: `case` was not forwarded, so response-locked ('RE') events were
        # always extracted, even when the output directory belongs to another case.
        participant_context = create_events_mappings(trigger_map, case=case)
        logger_preprocessing_info.info(f'Context: {participant_context}')

        epochs_preprocessed, drop_log = pre_process_eeg(
            input_fname=bdf_fname,
            participant_id=participant_id,
            context=participant_context,
            trigger_fname=trigger_fname,
            tmin=-0.1,
            tmax=0.9,
        )

        save_epochs_with_drop_log(epochs_preprocessed, drop_log, participant_id)

    except Exception as e:
        # Best effort: a failing participant must not stop the batch. The
        # traceback is kept in the log so the failure can be located.
        logger_errors_info.info(f"{e}", exc_info=True)

    logger_preprocessing_info.info('\n')
    logger_errors_info.info('\n')

print('##########\n DONE\n')
# Restore MNE logging to std out
set_log_file(fname=None)

## Wavelet filtering

In [None]:
def calculate_wavelet_filter(grand_average, scales, central_freq = 6, signal_freq=500, threshold_point=0.85):
    """Derive, per channel, a binary time-frequency mask from the grand average.

    For every channel of ``grand_average`` the signal is transformed with a
    Morlet CWT, the power spectrum is normalized to sum to one, and an
    empirical CDF of the normalized power values is estimated. The power value
    at which that CDF first exceeds the threshold becomes the cut-off;
    coefficients with normalized power at or above the cut-off get mask value
    1, all others 0. Diagnostic figures are shown for every channel.

    Parameters
    ----------
    grand_average : iterable of array
        One signal per channel; each is flattened before the transform.
    scales : array-like
        CWT scales passed to ``cwt`` and ``scale_to_freq``.
    central_freq : float
        ``mu`` parameter of the Morlet wavelet.
    signal_freq : float
        Sampling frequency passed to ``cwt`` as ``fs``.
    threshold_point : float
        Position of the threshold within the range of CDF values
        (0 = minimum, 1 = maximum).

    Returns
    -------
    list of tuple
        ``(mask, wavelet, scales)`` for every channel.

    Raises
    ------
    IndexError
        If no CDF value is strictly greater than the threshold
        (e.g. ``threshold_point=1``).
    """
    results_per_channel = []
    for channel_grand_average in grand_average:
        x = channel_grand_average.flatten()
        # Time axis used only for the plot extents; the -0.1..0.9 s window is hard-coded.
        t = np.linspace(-0.1, 0.9, len(x))
        # construct wavelet function
        wavelet = Wavelet(('morlet', {'mu': central_freq}))
        Wx, _ = cwt(x, wavelet, fs=signal_freq, scales=scales, padtype='wrap', l1_norm=True, nv=None)
        
        # # baseline
        # baseline_mean = np.mean(Wx[:, :50], axis=1, keepdims=True)
        # Wx = Wx - baseline_mean
        
        # Frequencies corresponding to the scales; used only for the plot extents.
        freq = scale_to_freq(scales, wavelet, N=len(x), fs=signal_freq)
        
        # Compute and normalize the power spectrum from the CWT coefficients
        power_spectrum = np.abs(Wx)**2
        normalized_power_spectrum = power_spectrum / np.sum(power_spectrum)

        # Flatten the normalized power spectrum for CDF calculation
        flattened_spectrum = normalized_power_spectrum.flatten()

        # Use the Kaplan–Meier estimator; every value is marked as observed (no censoring)
        kmf = KaplanMeierFitter()
        kmf.fit(durations=flattened_spectrum, event_observed=np.ones_like(flattened_spectrum))

        # Get the CDF values from the Kaplan–Meier estimator
        cdf_values = 1 - kmf.survival_function_.KM_estimate

        # Calculate the threshold as a fraction of the CDF range
        threshold = threshold_point * (np.max(cdf_values) - np.min(cdf_values)) + np.min(cdf_values)

        # Plot the empirical CDF and the filtering model
        plt.step(kmf.survival_function_.index, cdf_values, where='post', label='Empirical CDF')
        plt.axhline(threshold, color='red', linestyle='--', label='Threshold')
        plt.title('Empirical CDF and Filtering Model')
        plt.xlabel('Wavelet Coefficient')
        plt.ylabel('Cumulative Probability')
        plt.legend()
        plt.show()

        # Find the smallest normalized power whose CDF value is above the threshold
        cutoff_wavelet_index = np.where(cdf_values > threshold)[0][0]
        cutoff_wavelet_coef = kmf.survival_function_.index[cutoff_wavelet_index]
        print(f'Estimated threshold value for wavelet coefficients: {cutoff_wavelet_coef}')

        # 1 where the normalized power reaches the cut-off, 0 elsewhere
        cwt_result_threshold_mask = np.where(normalized_power_spectrum >= cutoff_wavelet_coef, 1, 0)

        # Plot the CWT result
        plt.figure(figsize=(12, 16))

        plt.subplot(4, 1, 1)
        plt.imshow(np.abs(Wx), extent=[t[0], t[-1], freq[-1], freq[0]], aspect='auto', cmap='jet')
        plt.colorbar(label='Magnitude')
        plt.title('CWT Magnitude')

        # Normalized power spectrum
        plt.subplot(4, 1, 2)
        plt.imshow(normalized_power_spectrum, extent=[t[0], t[-1], freq[-1], freq[0]], aspect='auto', cmap='jet')
        plt.colorbar(label='Magnitude')
        plt.title('Normalized Power Spectrum')

        # Binary mask
        plt.subplot(4, 1, 3)
        plt.imshow(cwt_result_threshold_mask, extent=[t[0], t[-1], freq[-1], freq[0]], aspect='auto', cmap='jet')
        plt.colorbar(label='Magnitude')
        plt.title('Threshold Mask')

        # CWT magnitude after masking
        plt.subplot(4, 1, 4)
        plt.imshow(cwt_result_threshold_mask*np.abs(Wx), extent=[t[0], t[-1], freq[-1], freq[0]], aspect='auto', cmap='jet')
        plt.colorbar(label='Magnitude')
        plt.title('Thresholded Grand Average - Examle')
        plt.show()
        
        results_per_channel.append((cwt_result_threshold_mask, wavelet, scales))

    return results_per_channel

In [None]:
def signal_cwt(signal, scales, central_freq = 6, signal_freq=500):
    """Return the Morlet CWT coefficients of the flattened ``signal``.

    ``central_freq`` is the ``mu`` parameter of the wavelet and ``signal_freq``
    the sampling frequency passed to ``cwt``. Only the coefficient matrix is
    returned; the scales reported by ``cwt`` are discarded.
    """
    samples = signal.flatten()

    # construct wavelet function
    morlet = Wavelet(('morlet', {'mu': central_freq}))
    coefficients, _ = cwt(
        samples,
        morlet,
        fs=signal_freq,
        scales=scales,
        padtype='wrap',
        l1_norm=True,
        nv=None,
    )

    return coefficients

def epochs_to_tfr(epochs, scales, picks=['FCz', 'Cz'], events=['error_response']):
    '''
    Transform every selected epoch and channel into the time-frequency domain.

    :param epochs: epochs object to transform; it is copied, not modified
    :param scales: CWT scales forwarded to ``signal_cwt``
    :param picks: channels to keep
    :param events: event name(s) to select, or the string ``'all'`` to keep every epoch
    :return: ndarray of shape (n_events, n_channels, n_freqs, n_timepoints)
    '''
    selected = epochs.copy() if events == 'all' else epochs.copy()[events]
    epochs_picked = selected.pick(picks)

    n_channels = len(epochs_picked.info['ch_names'])

    tfr_epochs = []
    for epoch_idx in range(len(epochs_picked)):
        # get_data on a single-epoch selection has a leading axis of length 1
        epoch_data = epochs_picked[epoch_idx].get_data(copy=True)
        tfr_epochs.append([
            signal_cwt(epoch_data[0, ch_idx, :], scales)
            for ch_idx in range(n_channels)
        ])

    return np.array(tfr_epochs)

def filter_signal(Wx, x, mask, wavelet, scales):
    """Reconstruct a time-domain signal from the masked CWT coefficients.

    ``mask`` is multiplied element-wise with ``Wx`` before the inverse
    transform; the mean of the original signal ``x`` is passed to ``icwt`` as
    ``x_mean``.
    """
    masked_coefficients = mask * Wx
    return icwt(
        masked_coefficients,
        wavelet,
        scales,
        nv=None,
        padtype='wrap',
        l1_norm=True,
        x_mean=np.mean(x),
    )

def tfr_filter_epochs(tfr, original_signal, per_channel_cwt_results):
    """Apply the per-channel wavelet filter to every epoch and return the filtered signals.

    ``tfr`` and ``original_signal`` are iterated in parallel (epoch by epoch);
    within an epoch, channel ``i`` is filtered with the
    ``(mask, wavelet, scales)`` triple stored at ``per_channel_cwt_results[i]``.
    The result is an array indexed by epoch, then channel.
    """
    filtered_epochs = []
    for epoch_tfr, epoch_signal in zip(tfr, original_signal):
        epoch_result = []
        for ch_idx, channel_tfr in enumerate(epoch_tfr):
            channel_filter = per_channel_cwt_results[ch_idx]
            epoch_result.append(
                filter_signal(
                    channel_tfr,
                    epoch_signal[ch_idx],
                    channel_filter[0],
                    channel_filter[1],
                    channel_filter[2],
                )
            )
        filtered_epochs.append(epoch_result)

    return np.array(filtered_epochs)

In [None]:
def get_grand_average(path_to_dir, picks, event):
    """Average the per-participant evoked responses of ``event`` over all participants.

    Every ``.pkl`` file in ``path_to_dir`` is expected to hold a one-row
    DataFrame with the columns ``'epochs'`` and ``'drop_log'``. Bad trials are
    rejected first; participants with fewer than 6 remaining trials are logged
    and skipped.

    Parameters
    ----------
    path_to_dir : str
        Directory path ending with a separator (file names are appended directly).
    picks : list of str
        Channels extracted from each participant's evoked response.
    event : str
        Event name used to select the epochs to average.

    Returns
    -------
    ndarray
        Mean over participants of the evoked data for ``picks``.
    """
    # Only pickled participant files; other entries (e.g. `.ipynb_checkpoints`)
    # previously produced bogus file names and crashed the loop.
    # NOTE: unpickling executes code - only point this at files produced by this notebook.
    pickle_names = [item for item in os.listdir(path_to_dir) if item.endswith('.pkl')]
    all_evokeds = []

    for pickle_name in pickle_names:
        preprocessed_epochs = pd.read_pickle(f'{path_to_dir}{pickle_name}')
        epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
        drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

        clean_epochs, _ = reject_bad_trials(epochs, drop_log)
        if len(clean_epochs) < 6:
            logger_errors_info.info(f'Participant has only {len(clean_epochs)} artifact-free trials')
        else:
            all_evokeds.append(clean_epochs[event].average().get_data(picks=picks))

    all_evokeds = np.array(all_evokeds)
    grand_average = np.mean(all_evokeds, axis=0)

    return grand_average

In [None]:
def create_wavelet_filter(path_to_dir, scales, picks=['FCz', 'Cz'], event='error_response', threshold=0.85, central_freq=6):
    """Compute the grand average of ``event`` and derive the per-channel wavelet filter from it.

    Returns the list produced by ``calculate_wavelet_filter`` (one entry per
    channel in ``picks``).
    """
    return calculate_wavelet_filter(
        get_grand_average(path_to_dir, picks=picks, event=event),
        scales=scales,
        central_freq=central_freq,
        threshold_point=threshold,
    )

### Test wavelet filtering

#### 1. Test the quality of the inverse wavelet transform

See: https://dsp.stackexchange.com/questions/87097/why-is-inverse-cwt-inexact-inaccurate/

Perform wavelet deconstruction and inverse transform

In [None]:
# Paradigm and event-locking case whose preprocessed epochs are inspected below.
paradigm = 'SST'
case = 'RE'

In [None]:
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Only pickled participant files; other directory entries (e.g.
# `.ipynb_checkpoints`) would otherwise turn into bogus IDs and crash the loop.
id_list = [item.split('.')[0] for item in os.listdir(preprocessed_data_dir_path) if item.endswith('.pkl')]

# Per participant: reconstructed epochs, original epochs and |mean difference| per epoch.
all_epochs_reconstructed = []
all_epochs_original = []
diffs = []
nv=None

# create scales
# preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_list[0]}.pkl')
# epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
# drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

# x = epochs[event].get_data(picks=picks)[0].flatten()
# wavelet = Wavelet(('morlet', {'mu': 6}))

# Wx, scales = cwt(x, wavelet, fs=500, scales='log-piecewise', padtype='wrap', l1_norm=True, nv=nv)
# new_scales = scales[34:]
# 200 geometrically spaced scales between 16 and 500
new_scales = np.geomspace(16,500,200)

for id_ in id_list:
    # read data
    preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_}.pkl')
    epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]

    drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

    epochs_data = epochs.copy()['error_response'].pick(['FCz'])

    # transform data into TFR space
    wavelet = Wavelet(('morlet', {'mu': 6}))

    participant_epochs_reconstructed = []
    participant_epochs_original = []
    participant_diffs = []
    for epoch in epochs_data:
        # forward transform followed directly by the inverse transform (no masking)
        Wx, scales = cwt(epoch.flatten(), wavelet, fs=500, scales=new_scales, padtype='wrap', l1_norm=True, nv=None)
        time_domain_signal = icwt(Wx, wavelet, scales=new_scales, nv=None, padtype='wrap', l1_norm=True, x_mean=np.mean(epoch.flatten()))

        # absolute difference between the means of the original and the reconstructed signal
        diff = abs(np.mean(epoch.flatten()) - np.mean(time_domain_signal))

        participant_epochs_reconstructed.append(time_domain_signal)
        participant_epochs_original.append(epoch.flatten())
        participant_diffs.append(diff)

    all_epochs_reconstructed.append(participant_epochs_reconstructed)
    all_epochs_original.append(participant_epochs_original)
    diffs.append(participant_diffs)

Plot per participant grand average similarities

In [None]:
# 1- 30 Hz adjusted with mean od the original signal 
x = np.linspace(-0.25, 0.9, 501)
for i in range(0, len(all_epochs_reconstructed)):
    plt.figure()

    plt.plot(x, np.mean(all_epochs_original[i], axis=0).flatten())
    plt.plot(x, np.mean(all_epochs_reconstructed[i], axis=0).flatten() )

Calculate differences between original and reconstructed signals

In [None]:
epsilon = 1e-7 # 0.1 uV

# For every participant print the per-epoch mean differences and a boolean
# array telling which of them exceed epsilon.
for idx, participant_values in enumerate(diffs):
    print(f'IDX: {idx}\n{participant_values}\n')
    exceed = np.asarray(participant_values) > epsilon
    print(exceed)

Test wavelet inverse transform quality per-participant

In [None]:
# Time axis from -0.1 s to 0.9 s with as many points as one reconstructed epoch
# of the first participant.
x = np.linspace(-0.1, 0.9, np.array(all_epochs_reconstructed[0]).shape[-1])

# Position of the inspected participant in the lists built above.
# NOTE(review): hard-coded; raises IndexError when fewer than 101 participants were loaded.
idx = 100
participant_reconstructed = all_epochs_reconstructed[idx]
participant_original = all_epochs_original[idx]

# One figure per epoch: original signal and its reconstruction overlaid.
for i in range(0, len(participant_reconstructed)):
    plt.figure()
    print(i)
    plt.plot(x, participant_original[i])
    plt.plot(x, participant_reconstructed[i])
    
    plt.show()

#### 2. Test thresholds and their impact on amplitude reduction

In [None]:
# Paradigm and event-locking case used for the threshold test below.
paradigm = 'SST'
case = 'RE'

In [None]:
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Candidate threshold points: 0.0, 0.05, ..., 0.95
thresholds = np.arange(0.0, 1.0, 0.05)
picks = ['FCz']
event = 'error_response'

# Only pickled participant files; other directory entries (e.g.
# `.ipynb_checkpoints`) would otherwise turn into bogus IDs and crash the loop.
id_list = [item.split('.')[0] for item in os.listdir(preprocessed_data_dir_path) if item.endswith('.pkl')]
# Per participant: epochs in the time-frequency domain and the unfiltered epochs.
tfr_epochs_participants = []
epochs_participants = []

# create scales
# preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_list[0]}.pkl')
# epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
# drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

# x = epochs[event].get_data(picks=picks)[0].flatten()
# wavelet = Wavelet(('morlet', {'mu': 6}))
new_scales = np.geomspace(16,500,200) # from 1 to 30 Hz

for id_ in id_list:
    # read data
    preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_}.pkl')
    epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
    drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

    # save unfiltered data
    epochs_participants.append(epochs[event].get_data(picks=picks))

    # transform data into TFR space
    tfr_epochs = epochs_to_tfr(epochs, scales=new_scales, picks=picks, events=event)
    # save tfr data
    tfr_epochs_participants.append(tfr_epochs)

In [None]:
# Grand average of `event` at `picks`, computed from the artifact-free trials
# of all participants with at least 6 such trials (see get_grand_average).
grand_average = get_grand_average(preprocessed_data_dir_path, picks=picks, event=event)

In [None]:
# For every threshold: derive the filter from the grand average, apply it to
# every participant's epochs and keep both the filter and the filtered epochs.
reconstructed_epochs_per_threshold = []
filtered_ths = []
for threshold in thresholds:
    print(f'Threshold: {threshold}')
    # Shows the diagnostic figures of calculate_wavelet_filter for every threshold.
    filter_per_channel = calculate_wavelet_filter(grand_average, scales=new_scales, signal_freq=500, central_freq=6, threshold_point=threshold)
    filtered_ths.append(filter_per_channel)
    
    # NOTE: reuses (and overwrites) the name from the inverse-transform test above.
    all_epochs_reconstructed = []
    for idx, tfr_epochs in enumerate(tfr_epochs_participants):
        reconstructed_epochs = tfr_filter_epochs(
            tfr_epochs,
            epochs_participants[idx],
            per_channel_cwt_results = filter_per_channel
        )
        
    
        all_epochs_reconstructed.append(reconstructed_epochs)
    reconstructed_epochs_per_threshold.append(all_epochs_reconstructed)

Plot results

In [None]:
# Per threshold: mean over participants of each participant's mean over epochs.
grand_average_per_threshold = np.array([np.mean([np.mean(participant, axis=0) for participant in threshold_], axis=0) for threshold_ in reconstructed_epochs_per_threshold])

# Peak = minimum of the first channel within samples 50..149
# (presumably 0.0-0.2 s at 500 Hz with epochs starting at -0.1 s - TODO confirm).
grand_average_peak_amplitude = np.min(grand_average[0][50:150])
filtered_grand_averages_amplitudes = [np.min(item[50:150]) for item in grand_average_per_threshold[:,0,:]]

# Ratio of the filtered peak to the unfiltered peak (1.0 = no change).
# NOTE: overwrites `diffs` from the inverse-transform test above.
diffs = [item/grand_average_peak_amplitude for item in filtered_grand_averages_amplitudes]

fig, ax = plt.subplots()
plt.plot(thresholds, diffs)
ax.set_xticks(np.arange(0.1, 1.0, 0.1))
# Reference lines: 99 % of the original peak and the threshold 0.45
plt.axhline(y=0.99, c='r', linestyle='--')
plt.axvline(x=0.45, c='orange', linestyle='--')

# plt.axhline(y=1.01, c='r', linestyle='--')
# plt.axvline(x=0.4, c='orange', linestyle='--')

# NOTE(review): the plotted value is the filtered/original ratio, not a reduction.
plt.xlabel("Threshold")
plt.ylabel("Amplitude reduction")

plt.show()

In [None]:
# Overlay the unfiltered grand average and the filtered grand average of every threshold.
plt.figure(figsize=(12,10))
ax = plt.subplot(111)

# Time axis from -0.1 s to 0.9 s with as many points as one reconstructed epoch.
x = np.linspace(-0.1, 0.9, np.array(all_epochs_reconstructed[0]).shape[-1])

# Flattening only yields a single trace because `picks` holds one channel.
plt.plot(x, grand_average.flatten(), linestyle='--', label='original signal')

for i in range(0, len(thresholds)):
    plt.plot(x, grand_average_per_threshold[i,0,:], label=str(round(thresholds[i], 2)))

plt.legend()
# ax.legend(bbox_to_anchor=(0.7, 1.0))

plt.xlabel("Time (s)")
plt.ylabel("Amplitude (V)")

plt.show()

### Perform wavelet filtering

In [None]:
######## PREPROCESSING ##############################################
# Logger for wavelet-filtering information (same named logger as in the preprocessing stage).
logger_preprocessing_info = logging.getLogger('preprocessing_info')
logger_preprocessing_info.setLevel(logging.INFO)
logger_preprocessing_info.propagate = False

######## ERRORS ##############################################
# Logger for errors (same named logger as in the preprocessing stage).
logger_errors_info = logging.getLogger('errors')
logger_errors_info.setLevel(logging.INFO)
# Fixed: propagation was only disabled for the preprocessing logger, so error
# records were also passed on to the root logger.
logger_errors_info.propagate = False

Set globals

In [None]:
# GNG | SST | Flanker
paradigm = 'SST'
case = 'RE'

Set paths based on the values of the globals

In [None]:
# Directory of the pickled, preprocessed epochs (trailing slash required: file names are appended directly).
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Directory that receives the log files.
logger_dir_path = f'data/joint/{paradigm}/'

Set output files for loggers

In [None]:
# logging.FileHandler does not create missing directories.
os.makedirs(f'data/joint/{paradigm}/', exist_ok=True)

# The named loggers may still carry the file handlers of the preprocessing
# stage (or of an earlier run of this cell). Detach them so wavelet records
# are written once, and only to the wavelet log files.
for configured_logger in (logger_preprocessing_info, logger_errors_info):
    for stale_handler in list(configured_logger.handlers):
        configured_logger.removeHandler(stale_handler)
        stale_handler.close()

######## PREPROCESSING ##############################################
# Create a file handler for wavelet information and set the level to INFO
file_handler_preprocessing = logging.FileHandler(f'data/joint/{paradigm}/{case}_wavelets_info.txt')
file_handler_preprocessing.setLevel(logging.INFO)

# Create a formatter and add it to the file handler
formatter_preprocessing = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_preprocessing.setFormatter(formatter_preprocessing)

# Add the file handler to the information logger
logger_preprocessing_info.addHandler(file_handler_preprocessing)

######## ERRORS ##############################################
# Create a file handler for errors and set the level to INFO
file_handler_errors = logging.FileHandler(f'data/joint/{paradigm}/{case}_wavelets_errors.txt')
file_handler_errors.setLevel(logging.INFO)

# Create a formatter and add it to the file handler for errors
formatter_errors = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_errors.setFormatter(formatter_errors)

# Add the file handler to the errors logger
logger_errors_info.addHandler(file_handler_errors)

##### MNE ###################################################
# Redirect MNE's own log output to a file
logger_f_name = f'data/joint/{paradigm}/{case}_wavelets_MNE-logs.txt'
set_log_file(fname=logger_f_name, output_format="%(asctime)s - %(message)s", overwrite=None)

#### 1. Read all participants data and create grand averages per event

In [None]:
def create_clean_data_dict(path_to_dir):
    """Collect the single-trial EEG data of every participant, grouped by event name.

    Every regular ``*.pkl`` file in ``path_to_dir`` is read with ``pd.read_pickle``; the
    first cell of its ``'epochs'`` column is used as the epochs object and the first cell of
    its ``'drop_log'`` column as the per-trial drop log. For each trial, channels whose name
    appears in that trial's drop-log entry are replaced by NaN.

    :param path_to_dir: directory containing one pickle per participant
    :return: dict mapping event name -> list of arrays of shape (n_channels, n_samples),
        one array per trial across all participants
    :raises AssertionError: if a drop log and its epochs differ in length
    """
    events_data_dict = dict()

    # Only regular ``*.pkl`` files are participant recordings. Sub-directories and other
    # files are skipped; previously each of them was turned into a non-existent
    # ``<name>.pkl`` path and crashed the read. Sorting makes the trial order reproducible.
    participant_files = sorted(
        item for item in os.listdir(path_to_dir)
        if item.endswith('.pkl') and os.path.isfile(os.path.join(path_to_dir, item))
    )

    for file_name in participant_files:
        preprocessed_epochs = pd.read_pickle(os.path.join(path_to_dir, file_name))

        epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
        drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

        assert len(drop_log) == len(epochs)

        epochs_copy = epochs.copy().pick('eeg')
        ch_names = epochs_copy.ch_names

        for trial_idx, _ in enumerate(epochs_copy):
            epoch = epochs_copy[trial_idx]
            epoch_type = list(epoch.event_id.keys())
            drop_log_item = drop_log[trial_idx]

            # Float copy of the single trial, reshaped to (n_channels, n_samples)
            trial_data = np.array(epoch.get_data(copy=True)[0], dtype=float).reshape(len(ch_names), -1)
            # Blank out channels flagged for this trial so nan-aware averaging ignores them
            for ch_idx, ch_name in enumerate(ch_names):
                if ch_name in drop_log_item:
                    trial_data[ch_idx, :] = np.nan

            events_data_dict.setdefault(epoch_type[0], []).append(trial_data)

    return events_data_dict

In [None]:
def create_grand_averages_dict(events_data_dict):
    """Average the trials of every event into one grand average, ignoring NaNs.

    :param events_data_dict: dict mapping event name -> sequence of equally shaped trial arrays
    :return: dict mapping event name -> NaN-aware mean over the trials (axis 0)
    """
    return {
        event_name: np.nanmean(trials, axis=0)
        for event_name, trials in events_data_dict.items()
    }

Create dict of clean data per event to create grand averages

In [None]:
# Gather every participant's trials per event, then reduce them to one grand average per event.
events_data_dict = create_clean_data_dict(path_to_dir = preprocessed_data_dir_path)
grand_averages_dict = create_grand_averages_dict(events_data_dict)
# Show which event names were found in the data.
print(grand_averages_dict.keys())

In [None]:
# save grand averages dict
# Each array is wrapped in a one-element list so the DataFrame gets a single row with
# one column per event, each cell holding that event's grand-average array.
grand_averages_dict_ = {key: [value] for key, value in grand_averages_dict.items()}
grand_averages_dict_df = pd.DataFrame(grand_averages_dict_)

# Persist next to the paradigm's other outputs so later steps can reload it.
grand_averages_dict_df.to_pickle(f'data/joint/{paradigm}/grand_averages_dict.pkl')

#### 2. Create wavelet filter based on erroneous trials

In [None]:
# Event whose grand average is used to build the wavelet filter; it is used below as a
# column name of grand_averages_dict_df.
event_of_interest = 'error_response'

In [None]:
# 200 geometrically spaced wavelet scales between 16 and 500.
# NOTE(review): the "1 to 30 Hz" mapping depends on the wavelet and sampling rate and is not checked here.
scales = np.geomspace(16,500,200) # from 1 to 30 Hz
# Passed to calculate_wavelet_filter as central_freq
central_freq = 6
# Passed to calculate_wavelet_filter as signal_freq (presumably the sampling rate in Hz — TODO confirm)
signal_freq = 500
threshold_point = 0.5 # chosen based on the threshold tests

In [None]:
# read grand averages dict
grand_averages_dict_df = pd.read_pickle(f'data/joint/{paradigm}/grand_averages_dict.pkl')
grand_averages_dict_df

In [None]:
# Build one wavelet filter per channel from the grand average of the event of interest.
# `.to_numpy()[0]` unwraps the single array stored in that column's only row.
wavelet_filters_per_channel = calculate_wavelet_filter(
    grand_average = grand_averages_dict_df[event_of_interest].to_numpy()[0], 
    scales = scales,
    central_freq = central_freq, 
    signal_freq=signal_freq, 
    threshold_point=threshold_point
)

In [None]:
# Sanity check of the filter structure: one entry per channel, indexed positionally below.
# NOTE(review): in the calculate_wavelet_filter definition visible later in this notebook,
# element [0] of each entry is the binary threshold mask (not the raw CWT coefficients) and
# element [2] holds the scales — confirm that the 'Wx' label is intended.
print(f'Filter created for: {len(wavelet_filters_per_channel)} channels')
print(f'Number of scales: {len(wavelet_filters_per_channel[0][2])}')
print(f'Shape of Wx: {wavelet_filters_per_channel[0][0].shape}')

In [None]:
# save filter || TODO
# pd.DataFrame({'wavelet_filter_FLA': [wavelet_filters_per_channel]}).to_pickle(f'data/joint/{paradigm}/wavelet_filters_per_channel.pkl')

#### 3. Transform all epochs into tfrs, apply filter, and save

In [None]:
def read_behavioral_file(participant_id):
    """Load a participant's behavioral CSV and number its rows consecutively.

    The file ``beh_<participant_id>.csv`` is read from the module-level
    ``behavioral_dir_path``. Every row counts as a trial: a ``'trial number'`` column is
    added that starts at 1 and follows the row order of the file.

    :param participant_id: identifier used in the CSV file name
    :return: the behavioral DataFrame with the added ``'trial number'`` column
    """
    csv_path = f'{behavioral_dir_path}beh_{participant_id}.csv'
    behavioral_data_df = pd.read_csv(csv_path)

    n_rows = len(behavioral_data_df)
    behavioral_data_df['trial number'] = list(range(1, n_rows + 1))
    return behavioral_data_df

In [None]:
def _gng_trial_groups(behavioral_data_df):
    """Select the GNG trial groups, log their sizes and list the groups combined per case."""
    in_experiment = behavioral_data_df['block type'] == 'experiment'

    beh_data_uninhibited_nogo_responses_df = behavioral_data_df[
        in_experiment &
        (behavioral_data_df['trial type'] != 'go') &
        (behavioral_data_df['reaction'] == False)
        ]
    logger_preprocessing_info.info(f'Number of uninhibited NOGO trials: {len(beh_data_uninhibited_nogo_responses_df)}')

    beh_data_inhibited_nogo_responses_df = behavioral_data_df[
        in_experiment &
        (behavioral_data_df['trial type'] != 'go') &
        (behavioral_data_df['reaction'] == True)
        ]
    logger_preprocessing_info.info(f'Number of inhibited NOGO trials: {len(beh_data_inhibited_nogo_responses_df)}')

    beh_data_correct_go_responses_df = behavioral_data_df[
        in_experiment &
        (behavioral_data_df['trial type'] == 'go') &
        (behavioral_data_df['response'] == 'num_separator')
        ]
    logger_preprocessing_info.info(f'Number correct GO trials: {len(beh_data_correct_go_responses_df)}')

    return {
        'RE': [beh_data_uninhibited_nogo_responses_df, beh_data_correct_go_responses_df],
        'STIM': [
            beh_data_uninhibited_nogo_responses_df,
            beh_data_correct_go_responses_df,
            beh_data_inhibited_nogo_responses_df,
        ],
    }


def _fla_trial_groups(behavioral_data_df):
    """Select the Flanker trial groups, log their sizes and list the groups combined per case."""

    def select(trial_type, reaction):
        # Experimental trials of one congruency/accuracy combination answered with 'l' or 'r'
        return behavioral_data_df[
            (behavioral_data_df['block_type'] == 'experiment') &
            (behavioral_data_df['trial_type'] == trial_type) &
            (behavioral_data_df['reaction'] == reaction) &
            (behavioral_data_df['response'].isin(['l', 'r']))
            ]

    beh_data_incorrect_incongruent_responses_df = select('incongruent', 'incorrect')
    logger_preprocessing_info.info(f'Number of incorrect incongruent trials: {len(beh_data_incorrect_incongruent_responses_df)}')

    beh_data_correct_incongruent_responses_df = select('incongruent', 'correct')
    logger_preprocessing_info.info(f'Number of correct incongruent trials: {len(beh_data_correct_incongruent_responses_df)}')

    beh_data_incorrect_congruent_responses_df = select('congruent', 'incorrect')
    logger_preprocessing_info.info(f'Number incorrect congruent trials: {len(beh_data_incorrect_congruent_responses_df)}')

    beh_data_correct_congruent_responses_df = select('congruent', 'correct')
    logger_preprocessing_info.info(f'Number correct congruent trials: {len(beh_data_correct_congruent_responses_df)}')

    # Both cases use the same four groups
    all_groups = [
        beh_data_incorrect_incongruent_responses_df,
        beh_data_correct_incongruent_responses_df,
        beh_data_incorrect_congruent_responses_df,
        beh_data_correct_congruent_responses_df,
    ]
    return {'RE': all_groups, 'STIM': list(all_groups)}


def _sst_trial_groups(behavioral_data_df):
    """Select the SST trial groups, log their sizes and list the groups combined per case."""
    # The first 30 rows are excluded from every group (presumably practice trials — TODO confirm).
    # Slicing first and masking the slice with its own columns selects the same rows as masking
    # with the full-length Series, without pandas' "Boolean Series key will be reindexed" warning.
    trials_df = behavioral_data_df.iloc[30:]

    beh_data_inhibited_stop_df = trials_df[
        (trials_df['STOP_TYPE'] == 0) &
        (trials_df['RE_time'].isna())
        ]
    logger_preprocessing_info.info(f'Number of correctly inhibited STOP trials: {len(beh_data_inhibited_stop_df)}')

    beh_data_uninhibited_stop_df = trials_df[
        (trials_df['STOP_TYPE'] == 0) &
        (trials_df['RE_time'].notna())
        ]
    logger_preprocessing_info.info(f'Number of incorrectly uninhibited STOP trials: {len(beh_data_uninhibited_stop_df)}')

    beh_data_correct_go_responses_df = trials_df[
        (trials_df['STOP_TYPE'].isna()) &
        (trials_df['RE_key'] == trials_df['RE_true'])
        ]
    logger_preprocessing_info.info(f'Number correct GO trials: {len(beh_data_correct_go_responses_df)}')

    beh_data_incorrect_go_responses_df = trials_df[
        (trials_df['STOP_TYPE'] != 0) &
        (trials_df['STOP_TYPE'] != 1) &
        (trials_df['RE_key'] != trials_df['RE_true']) &
        (trials_df['RE_key'].notna())
        ]
    logger_preprocessing_info.info(f'Number incorrect GO trials: {len(beh_data_incorrect_go_responses_df)}')

    return {
        'RE': [beh_data_uninhibited_stop_df, beh_data_incorrect_go_responses_df, beh_data_correct_go_responses_df],
        'STIM': [beh_data_inhibited_stop_df, beh_data_uninhibited_stop_df],
    }


def _epochs_to_long_df(epochs, drop_log):
    """Build one row per epoch: the single-epoch object, its event name and its drop-log entry."""
    single_epochs, event_names, drop_log_items = [], [], []
    for idx, _ in enumerate(epochs):
        epoch = epochs[idx]
        epoch_type = list(epoch.event_id.keys())
        assert len(epoch_type) == 1, \
            f'Single trial is not single. Length of epoch: {len(epoch_type)}. Error during trial-wise saving.'
        single_epochs.append(epoch)
        event_names.append(epoch_type[0])
        drop_log_items.append(drop_log[idx])

    # Build the frame once instead of concatenating one-row frames (quadratic in the number of epochs)
    return pd.DataFrame({
        'epoch': single_epochs,
        'event': event_names,
        'drop_log': drop_log_items,
    })


# Per paradigm: the trial-group selector and the sub-directory (relative to
# preprocessed_data_dir_path) that receives the output pickle.
_PARADIGM_SETTINGS = {
    'GNG': (_gng_trial_groups, ''),
    'FLA': (_fla_trial_groups, ''),
    'SST': (_sst_trial_groups, 'wavelets/'),
}


def save_epochs_with_behavioral_data_long(epochs, drop_log, participant_id, case='RE'):
    """Join single-trial epochs with the matching behavioral rows and pickle the result.

    The behavioral file of ``participant_id`` is read, the rows relevant for the module-level
    ``paradigm`` and the given ``case`` are selected and ordered by ``'trial number'``, and
    the epochs are attached row by row (epoch ``i`` goes with the ``i``-th selected row).
    The combined frame is written as ``preprocessed-beh_<participant_id>.pkl`` below the
    module-level ``preprocessed_data_dir_path`` (inside ``wavelets/`` for SST).

    :param epochs: epochs object; indexing it with an integer must yield a single epoch
    :param drop_log: per-epoch drop-log entries, aligned with ``epochs``
    :param participant_id: identifier used in the behavioral and output file names
    :param case: 'RE' or 'STIM'; any other value is logged as 'Not implemented'
    :return: DataFrame with the behavioral columns plus 'epoch', 'event' and 'drop_log'
    :raises ValueError: if ``paradigm`` is not 'GNG', 'FLA' or 'SST', or if the number of
        epochs differs from the number of selected behavioral rows
    :raises AssertionError: if the selected behavioral rows and ``drop_log`` differ in length
    """
    if paradigm not in _PARADIGM_SETTINGS:
        # Previously this fell through every branch and failed with UnboundLocalError
        raise ValueError(f'Unsupported paradigm {paradigm!r}; expected one of {sorted(_PARADIGM_SETTINGS)}')
    select_trial_groups, output_subdir = _PARADIGM_SETTINGS[paradigm]

    # read behavioral file
    behavioral_data_df = read_behavioral_file(participant_id)
    groups_by_case = select_trial_groups(behavioral_data_df)

    results_df = pd.DataFrame()
    epochs_df = pd.DataFrame()
    behavioral_df = pd.DataFrame()

    if case in groups_by_case:
        behavioral_df = pd.concat(groups_by_case[case]).sort_values(by='trial number')

        logger_preprocessing_info.info(f'Len drop log: {len(drop_log)}')
        logger_preprocessing_info.info(f'Len behavioral df: {len(behavioral_df)}')
        assert len(behavioral_df) == len(drop_log), f'Number of events read from behavioral file ({len(behavioral_df)}) not equals number of events from drop_log ({len(drop_log)})'

        # Set the indexes of epochs to match reactions
        epochs_df = _epochs_to_long_df(epochs, drop_log).set_index(behavioral_df.index)
        results_df = pd.concat([behavioral_df, epochs_df], axis=1)
    else:
        # NOTE(review): as before, an unsupported case still writes an empty frame below —
        # confirm whether that is intended.
        logger_preprocessing_info.info('Not implemented')

    assert len(results_df) == len(behavioral_df) == len(epochs_df), f'Length of trial-wise dataframe ({len(results_df)}) not equals number of events from behavioral file ({len(behavioral_df)}) and number of epochs ({len(epochs_df)})'

    output_dir = f'{preprocessed_data_dir_path}{output_subdir}'
    if output_dir:
        # The output sub-directory may not exist yet on a fresh run
        os.makedirs(output_dir, exist_ok=True)
    results_df.to_pickle(f'{output_dir}preprocessed-beh_{participant_id}.pkl')
    logger_preprocessing_info.info('Epochs and behavioral data in long format saved to pickle.')

    return results_df

In [None]:
# Channel selection used when transforming and filtering the epochs
picks = 'eeg'
# 200 geometrically spaced wavelet scales between 16 and 500.
# NOTE(review): the "1 to 30 Hz" mapping depends on the wavelet and sampling rate and is not checked here.
scales = np.geomspace(16,500,200) # from 1 to 30 Hz
central_freq = 6
signal_freq = 500
# threshold_point = 0.45 # chosen based on the threshold tests

In [None]:
# Directory with the behavioral CSV files; read as a global by read_behavioral_file.
# Keep the trailing slash: the file name is appended to this string directly.
behavioral_dir_path = f'data/{paradigm}/behavioral/'

In [None]:
# Participant ids are the names of the regular ``*.pkl`` files in the input directory.
# Sub-directories and other files are skipped; previously each of them was turned into a
# non-existent ``<name>.pkl`` path that failed inside the loop. Sorting makes the
# processing order reproducible.
id_list = sorted(
    os.path.splitext(item)[0]
    for item in os.listdir(preprocessed_data_dir_path)
    if item.endswith('.pkl') and os.path.isfile(os.path.join(preprocessed_data_dir_path, item))
)

# NOTE(review): these lists are not filled by the loop below, although the plotting cell
# further down reads them — confirm whether collecting the data here is still wanted.
epochs_participants_reconstructed = []
epochs_participants = []

for participant_id in id_list:
    # read data
    print(participant_id)
    logger_preprocessing_info.info(f'#### PARTICIPANT ID: {participant_id} #########')
    logger_errors_info.info(f'#### PARTICIPANT ID: {participant_id} #########')

    try:
        preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{participant_id}.pkl')
        epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
        drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

        # Time-domain data of the picked channels; extracted once and reused below
        original_data = epochs.get_data(picks=picks)

        # transform data into TFR space
        tfr_epochs = epochs_to_tfr(
            epochs,
            scales=scales,
            picks=picks,
            events='all',
        )

        # filter data with created wavelet filter
        reconstructed_epochs = tfr_filter_epochs(
            tfr_epochs,
            original_data,
            per_channel_cwt_results = wavelet_filters_per_channel
        )

        assert original_data.shape == reconstructed_epochs.shape

        epochs_copy = epochs.copy().pick('eeg')

        # Wrap the filtered data in a new epochs object that reuses the original info, events and tmin.
        # NOTE(review): event_id is not passed on, so event names of the new object may differ
        # from the original ones — confirm that downstream code does not rely on them.
        e_arr = mne.EpochsArray(
                data = reconstructed_epochs,
                info = epochs_copy.info,
                events = epochs_copy.events,
                tmin=epochs_copy.tmin,
            )

        # The part after the first underscore of the file name is used as the participant id
        _ = save_epochs_with_behavioral_data_long(
            e_arr,
            drop_log,
            participant_id.split('_')[1],
            case=case,
        )
    except Exception as e:
        # Keep processing the remaining participants, but record the full traceback
        # so the failure can be diagnosed from the error log.
        logger_errors_info.exception(f"{e}")

    logger_preprocessing_info.info('\n')
    logger_errors_info.info('\n')

print('##########\n DONE\n')

Check similarity between grand average of original and filtered signal per person

In [None]:
# Index of FCz among the channels selected by `picks`. The data compared below is extracted
# with `picks`, so the index must refer to the picked channel order; looking it up in the
# unpicked channel list would be off whenever other channel types precede FCz.
# `epochs` is the object left over from the last participant processed in the loop above.
channel_idx = epochs.copy().pick(picks).ch_names.index('FCz')
channel_idx

In [None]:
# Compare, per participant, the average over epochs of the original and the filtered signal
# at the selected channel.
if not epochs_participants:
    # Indexing element 0 of an empty list would raise IndexError
    print('epochs_participants is empty - nothing to plot.')
else:
    # Time axis spanning -0.1 s to 0.9 s with one point per sample
    x = np.linspace(-0.1, 0.9, epochs_participants[0].shape[-1])
    for original_data, reconstructed_data in zip(epochs_participants, epochs_participants_reconstructed):
        plt.figure()

        plt.plot(x, np.mean(original_data, axis=0)[channel_idx].flatten(), label='original')
        plt.plot(x, np.mean(reconstructed_data, axis=0)[channel_idx].flatten(), label='filtered')

        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude (V)")
        plt.legend()

---

In [None]:
def signal_cwt(signal, central_freq = 6, signal_freq=500, scales='log-piecewise'):
    """Compute the continuous wavelet transform of a single signal with a Morlet wavelet.

    :param signal: array holding one signal; it is flattened before the transform
    :param central_freq: value passed as the Morlet ``mu`` parameter
    :param signal_freq: value passed to ``cwt`` as ``fs``
    :param scales: scales passed to ``cwt``; the default keeps the original 'log-piecewise' behaviour
    :return: the CWT coefficients ``Wx`` returned by ``cwt``
    """
    x = signal.flatten()

    # construct wavelet function
    wavelet = Wavelet(('morlet', {'mu': central_freq}))
    # Only the coefficients are needed; the scales returned by cwt are discarded
    Wx, _ = cwt(x, wavelet, fs=signal_freq, scales=scales, padtype='wrap', l1_norm=True, nv=10)

    return Wx

def epochs_to_tfr(epochs, picks=['FCz', 'Cz'], events=['error_response'], scales='log-piecewise',
                  central_freq=6, signal_freq=500):
    '''
    Transform every selected epoch and channel into the time-frequency domain.

    :param epochs: epochs object to transform (it is copied, not modified)
    :param picks: channels to transform
    :param events: event selection used to index the epochs; the string 'all' keeps every
        epoch unless an event name (or '/'-separated tag) is literally 'all'
    :param scales: scales forwarded to signal_cwt
    :param central_freq: forwarded to signal_cwt
    :param signal_freq: forwarded to signal_cwt
    :return: ndarray of shape (n_events, n_channels, n_freqs, n_timepoints)
    '''
    event_tags = {tag for name in epochs.event_id for tag in name.split('/')}
    if isinstance(events, str) and events == 'all' and 'all' not in event_tags:
        # 'all' is not an event name here: keep every epoch instead of failing on the lookup
        epochs_selected = epochs.copy()
    else:
        epochs_selected = epochs.copy()[events]
    epochs_picked = epochs_selected.pick(picks)

    # Extract the data once instead of indexing the epochs object per trial
    data = epochs_picked.get_data(copy=True)

    tfr_epochs = [
        [
            signal_cwt(channel_data, central_freq=central_freq, signal_freq=signal_freq, scales=scales)
            for channel_data in epoch_data
        ]
        for epoch_data in data
    ]

    tfr_epochs = np.array(tfr_epochs)
    return tfr_epochs

def filter_signal(Wx, x, mask, wavelet, scales):
    """Reconstruct a time-domain signal from masked wavelet coefficients.

    :param Wx: wavelet coefficients of the signal
    :param x: original time-domain signal; only its mean is used (passed as ``x_mean``)
    :param mask: array multiplied element-wise with ``Wx`` before inversion
    :param wavelet: wavelet passed to ``icwt``
    :param scales: scales passed to ``icwt``
    :return: the signal returned by ``icwt``
    """
    masked_coefficients = mask * Wx
    signal_mean = np.mean(x)
    return icwt(masked_coefficients, wavelet, scales, nv=10, padtype='wrap', l1_norm=True, x_mean = signal_mean)

# def tfr_filter_epochs(tfr, mask, wavelet, scales):
#     filtered_epochs = []
#     for epochs in tfr:
#         filtered_channel_data = []
#         for channel_data in epochs:
#             signal = channel_data
#             reconstructed_signal = filter_signal(signal, mask, wavelet, scales)
#             filtered_channel_data.append(reconstructed_signal)
#         filtered_epochs.append(filtered_channel_data)

#     filtered_epochs = np.array(filtered_epochs)    
#     return filtered_epochs

def tfr_filter_epochs(tfr, original_signal, per_channel_cwt_results):
    """Apply the per-channel wavelet filter to every epoch and return the reconstructed signals.

    :param tfr: wavelet coefficients, iterated as epochs -> channels
    :param original_signal: time-domain data aligned with ``tfr`` (epochs -> channels)
    :param per_channel_cwt_results: one entry per channel, indexed as (mask, wavelet, scales)
    :return: ndarray with the reconstructed signal of every epoch and channel
    """

    def reconstruct_epoch(epoch_tfr, epoch_signal):
        # Filter each channel with the entry stored at the same channel position
        reconstructed_channels = []
        for ch_idx, channel_tfr in enumerate(epoch_tfr):
            channel_filter = per_channel_cwt_results[ch_idx]
            reconstructed_channels.append(
                filter_signal(
                    channel_tfr,
                    epoch_signal[ch_idx],
                    channel_filter[0],
                    channel_filter[1],
                    channel_filter[2],
                )
            )
        return reconstructed_channels

    return np.array([
        reconstruct_epoch(epoch_tfr, epoch_signal)
        for epoch_tfr, epoch_signal in zip(tfr, original_signal)
    ])

In [None]:
def calculate_wavelet_filter(grand_average, central_freq = 6, signal_freq=500, threshold_point=0.85,
                             scales='log-piecewise', plot=True):
    """Build a binary time-frequency mask for every channel of a grand average.

    For each channel the signal is transformed with a Morlet CWT, the power spectrum is
    normalised to sum to 1, and the empirical CDF of the normalised power values is
    estimated (Kaplan–Meier fit with every value marked as observed). The first coefficient
    value whose CDF exceeds the threshold becomes the cutoff; the mask is 1 where the
    normalised power is at least the cutoff and 0 elsewhere.

    :param grand_average: iterable of per-channel signals
    :param central_freq: value passed as the Morlet ``mu`` parameter
    :param signal_freq: value passed to ``cwt`` as ``fs``
    :param threshold_point: position of the threshold within the CDF range
        (min + threshold_point * (max - min))
    :param scales: scales passed to ``cwt``; the default keeps the original 'log-piecewise' behaviour
    :param plot: when True (default) show the diagnostic figures for every channel
    :return: list with one tuple (mask, wavelet, scales) per channel
    :raises ValueError: if no CDF value lies above the threshold
    """
    results_per_channel = []
    for channel_grand_average in grand_average:
        x = channel_grand_average.flatten()
        # construct wavelet function
        wavelet = Wavelet(('morlet', {'mu': central_freq}))
        # cwt returns the scales it used; they are kept per channel so the ``scales``
        # argument reaches every channel unchanged
        Wx, channel_scales = cwt(x, wavelet, fs=signal_freq, scales=scales, padtype='wrap', l1_norm=True, nv=10)

        # Compute and normalize the power spectrum from the CWT coefficients
        power_spectrum = np.abs(Wx)**2
        normalized_power_spectrum = power_spectrum / np.sum(power_spectrum)

        # Flatten the normalized power spectrum for CDF calculation
        flattened_spectrum = normalized_power_spectrum.flatten()

        # Use the Kaplan–Meier estimator
        kmf = KaplanMeierFitter()
        kmf.fit(durations=flattened_spectrum, event_observed=np.ones_like(flattened_spectrum))

        # Get the CDF values from the Kaplan–Meier estimator
        cdf_values = 1 - kmf.survival_function_.KM_estimate

        # Calculate the threshold
        threshold = threshold_point * (np.max(cdf_values) - np.min(cdf_values)) + np.min(cdf_values)

        if plot:
            # Plot the empirical CDF and the filtering model
            plt.step(kmf.survival_function_.index, cdf_values, where='post', label='Empirical CDF')
            plt.axhline(threshold, color='red', linestyle='--', label='Threshold')
            plt.title('Empirical CDF and Filtering Model')
            plt.xlabel('Wavelet Coefficient')
            plt.ylabel('Cumulative Probability')
            plt.legend()
            plt.show()

        # Find the value of wavelets coefficient that are above threshold
        indices_above_threshold = np.where(cdf_values > threshold)[0]
        if len(indices_above_threshold) == 0:
            # Previously this surfaced as a bare IndexError on the empty result
            raise ValueError(
                f'No CDF value lies above the threshold ({threshold}); '
                f'threshold_point={threshold_point} is too high.'
            )
        cutoff_wavelet_index = indices_above_threshold[0]
        cutoff_wavelet_coef = kmf.survival_function_.index[cutoff_wavelet_index]
        print(f'Estimated threshold value for wavelet coefficients: {cutoff_wavelet_coef}')

        cwt_result_threshold_mask = np.where(normalized_power_spectrum >= cutoff_wavelet_coef, 1, 0)

        if plot:
            # Axes for the images: time from -0.1 s to 0.9 s, frequencies derived from the scales
            t = np.linspace(-0.1, 0.9, len(x))
            freq = scale_to_freq(channel_scales, wavelet, N=len(x), fs=signal_freq)
            extent = [t[0], t[-1], freq[-1], freq[0]]

            # Plot the CWT result
            plt.figure(figsize=(12, 16))

            plt.subplot(4, 1, 1)
            plt.imshow(np.abs(Wx), extent=extent, aspect='auto', cmap='jet')
            plt.colorbar(label='Magnitude')
            plt.title('CWT Magnitude')

            plt.subplot(4, 1, 2)
            plt.imshow(normalized_power_spectrum, extent=extent, aspect='auto', cmap='jet')
            plt.colorbar(label='Magnitude')
            plt.title('Normalized Power Spectrum')

            plt.subplot(4, 1, 3)
            plt.imshow(cwt_result_threshold_mask, extent=extent, aspect='auto', cmap='jet')
            plt.colorbar(label='Magnitude')
            plt.title('Threshold Mask')

            plt.subplot(4, 1, 4)
            plt.imshow(cwt_result_threshold_mask*np.abs(Wx), extent=extent, aspect='auto', cmap='jet')
            plt.colorbar(label='Magnitude')
            plt.title('Thresholded Grand Average - Examle')
            plt.show()

        results_per_channel.append((cwt_result_threshold_mask, wavelet, channel_scales))

    return results_per_channel

In [None]:
def get_grand_average(path_to_dir, picks, event):
    """Average the per-participant evoked responses of one event into a grand average.

    Every regular ``*.pkl`` file in ``path_to_dir`` is read; bad trials are removed with
    ``reject_bad_trials``. Participants with fewer than 6 remaining trials are logged to the
    errors logger and skipped.

    :param path_to_dir: directory containing one pickle per participant
    :param picks: channels passed to ``get_data`` of the averaged epochs
    :param event: event selection used to index the clean epochs
    :return: mean over participants of the per-participant averages
    :raises ValueError: if no participant contributes data
    """
    # Only regular ``*.pkl`` files are participant recordings. Sub-directories and other
    # files are skipped; previously each of them was turned into a non-existent
    # ``<name>.pkl`` path and crashed the read.
    participant_files = sorted(
        item for item in os.listdir(path_to_dir)
        if item.endswith('.pkl') and os.path.isfile(os.path.join(path_to_dir, item))
    )
    all_evokeds = []

    for file_name in participant_files:
        preprocessed_epochs = pd.read_pickle(os.path.join(path_to_dir, file_name))
        clean_epochs, _ = reject_bad_trials(
            preprocessed_epochs['epochs'].to_numpy().flatten()[0],
            preprocessed_epochs['drop_log'].to_numpy().flatten()[0],
        )
        if len(clean_epochs) < 6:
            logger_errors_info.info(f'Participant has only {len(clean_epochs)} artifact-free trials')
        else:
            all_evokeds.append(clean_epochs[event].average().get_data(picks=picks))

    if not all_evokeds:
        # The mean of an empty array would silently be NaN and break the filter computation later
        raise ValueError(f'No participant data available in {path_to_dir!r} to build a grand average')

    all_evokeds = np.array(all_evokeds)
    grand_average = np.mean(all_evokeds, axis=0)

    return grand_average

In [None]:
def create_wavelet_filter(path_to_dir, picks=['FCz', 'Cz'], event='error_response', threshold=0.85, central_freq=6):
    """Build the wavelet filter from the grand average of all participants.

    The grand average of ``event`` on ``picks`` is computed from the pickles in
    ``path_to_dir`` and handed to ``calculate_wavelet_filter`` together with the
    wavelet central frequency and the threshold point.

    Returns whatever ``calculate_wavelet_filter`` returns for that grand average.
    """
    return calculate_wavelet_filter(
        get_grand_average(path_to_dir, picks=picks, event=event),
        central_freq=central_freq,
        threshold_point=threshold,
    )

## Wavelet filter

Set globals

In [None]:
# Paradigm to process; one of: GNG | SST | Flanker.
# Interpolated into the data and log paths built in the cells below.
paradigm = 'SST'
# Name of the preprocessed sub-directory to read and prefix of the log files.
case = 'RE'

Set paths based on the global values

In [None]:
# Directory with the per-participant preprocessed pickles (trailing slash is relied on below).
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Directory of the paradigm, where the log files are written.
logger_dir_path = f'data/joint/{paradigm}/'

Set output files for loggers

In [None]:
def _add_file_handler_once(logger, log_path, formatter):
    """Attach an INFO-level FileHandler for ``log_path`` to ``logger``.

    Re-running this notebook cell must not stack a second handler for the same
    file (every message would then be written twice), so an already attached
    handler for that file is reused instead.
    """
    target = os.path.abspath(log_path)
    for existing in logger.handlers:
        # FileHandler stores the absolute path of its file in `baseFilename`.
        if isinstance(existing, logging.FileHandler) and existing.baseFilename == target:
            return existing

    handler = logging.FileHandler(log_path)
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return handler


######## PREPROCESSING ##############################################
# Formatter and INFO-level file handler for the preprocessing logger
formatter_preprocessing = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_preprocessing = _add_file_handler_once(
    logger_preprocessing_info,
    f'{logger_dir_path}{case}_wavelets_info.txt',
    formatter_preprocessing,
)

######## ERRORS ##############################################
# Formatter and INFO-level file handler for the errors logger
formatter_errors = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler_errors = _add_file_handler_once(
    logger_errors_info,
    f'{logger_dir_path}{case}_wavelets_errors.txt',
    formatter_errors,
)

##### MNE ###################################################
# Redirect MNE's own log output to a file next to the other logs
logger_f_name = f'{logger_dir_path}{case}_wavelets_MNE-logs.txt'
set_log_file(fname=logger_f_name, output_format="%(asctime)s - %(message)s", overwrite=None)

### Test wavelet filtering

#### Test quality of the wavelet inverse transform

See: https://dsp.stackexchange.com/questions/87097/why-is-inverse-cwt-inexact-inaccurate/

Perform wavelet deconstruction and inverse transform

In [None]:
# Round-trip test: CWT followed by inverse CWT of every single 'error_response'
# trial at FCz, keeping original and reconstructed signals per participant.
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Participant ids are the file names up to the first dot
id_list = [item.split('.')[0] for item in os.listdir(preprocessed_data_dir_path)]

# One inner list per participant, one entry per trial
all_epochs_reconstructed = []
all_epochs_original = []
diffs = []
# NOTE(review): `nv` is never read; the cwt/icwt calls below pass nv=None literally
nv=None

# create scales
# preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_list[0]}.pkl')
# epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
# drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

# x = epochs[event].get_data(picks=picks)[0].flatten()
# wavelet = Wavelet(('morlet', {'mu': 6}))

# Wx, scales = cwt(x, wavelet, fs=500, scales='log-piecewise', padtype='wrap', l1_norm=True, nv=nv)
# new_scales = scales[34:]
# Fixed set of 200 geometrically spaced scales shared by the forward and inverse transform
new_scales = np.geomspace(16,500,200)

for id_ in id_list:
    # read data
    preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_}.pkl')
    epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]

    # NOTE(review): drop_log is read but not used in this cell, so bad trials are not rejected here
    drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

    # Work on a copy so that pick() does not modify the loaded epochs
    epochs_data = epochs.copy()['error_response'].pick(['FCz'])

    # transform data into TFR space
    wavelet = Wavelet(('morlet', {'mu': 6}))

    participant_epochs_reconstructed = []
    participant_epochs_original = []
    participant_diffs = []
    for epoch in epochs_data:

        # Forward and inverse transform use the same wavelet, scales, padding and normalisation
        Wx, scales = cwt(epoch.flatten(), wavelet, fs=500, scales=new_scales, padtype='wrap', l1_norm=True, nv=None)
        time_domain_signal = icwt(Wx, wavelet, scales=new_scales, nv=None, padtype='wrap', l1_norm=True, x_mean=np.mean(epoch.flatten()))

        # Absolute difference between the means of the original and the reconstructed trial
        # NOTE(review): this compares means only, not the signals sample by sample
        diff = abs(np.mean(epoch.flatten()) - np.mean(time_domain_signal))

        participant_epochs_reconstructed.append(time_domain_signal)
        participant_epochs_original.append(epoch.flatten())
        participant_diffs.append(diff)

    all_epochs_reconstructed.append(participant_epochs_reconstructed)
    all_epochs_original.append(participant_epochs_original)
    diffs.append(participant_diffs)

Plot per participant grand average similarities

In [None]:
# Time axis from -0.1 to 0.9 with one point per sample of a reconstructed trial
x = np.linspace(-0.1, 0.9, np.array(all_epochs_reconstructed[0]).shape[-1])
# One figure per participant: trial average of the original vs. the reconstructed signal
for i in range(0, len(all_epochs_reconstructed)):
    plt.figure()

    plt.plot(x, np.mean(all_epochs_original[i], axis=0).flatten())
    plt.plot(x, np.mean(all_epochs_reconstructed[i], axis=0).flatten())

Calculate differences between original and reconstructed signals

In [None]:
epsilon = 1e-7 # 0.1 uV

# For every participant print the per-trial mean differences and a boolean
# array flagging the trials whose difference is above epsilon.
for idx, paricipant in enumerate(diffs):
    print(f'IDX: {idx}\n{paricipant}\n')
    exceed = np.array([value > epsilon for value in paricipant])
    print(exceed)

Test wavelet inverse transform quality per-participant

In [None]:
# Time axis from -0.1 to 0.9 with one point per sample of a reconstructed trial
x = np.linspace(-0.1, 0.9, np.array(all_epochs_reconstructed[0]).shape[-1])

# Position of the inspected participant in the lists built above
# NOTE(review): hard-coded; raises IndexError when fewer than 101 participants were loaded
idx = 100
participant_reconstructed = all_epochs_reconstructed[idx]
participant_original = all_epochs_original[idx]

# One figure per trial: original vs. reconstructed signal
for i in range(0, len(participant_reconstructed)):
    plt.figure()
    print(i)
    plt.plot(x, participant_original[i])
    plt.plot(x, participant_reconstructed[i])

    plt.show()

#### Test thresholds and their impact on amplitude reduction

In [None]:
# Collect, per participant, the unfiltered epochs and their time-frequency representation
preprocessed_data_dir_path = f'data/joint/{paradigm}/preprocessed/{case}/'
# Candidate threshold points: 0.0, 0.05, ..., 0.95
thresholds = np.arange(0.0, 1.0, 0.05)
picks = ['FCz']
event = 'error_response'

# Participant ids are the file names up to the first dot
id_list = [item.split('.')[0] for item in os.listdir(preprocessed_data_dir_path)]
tfr_epochs_participants = []
epochs_participants = []

# create scales
# preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_list[0]}.pkl')
# epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
# drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

# x = epochs[event].get_data(picks=picks)[0].flatten()
# wavelet = Wavelet(('morlet', {'mu': 6}))
new_scales = np.geomspace(16,500,200) # from 1 to 30 Hz

for id_ in id_list:
    # read data
    preprocessed_epochs = pd.read_pickle(f'{preprocessed_data_dir_path}{id_}.pkl')
    epochs = preprocessed_epochs['epochs'].to_numpy().flatten()[0]
    # NOTE(review): drop_log is read but not used in this cell
    drop_log = preprocessed_epochs['drop_log'].to_numpy().flatten()[0]

    # save unfiltered data
    epochs_participants.append(epochs[event].get_data(picks=picks))

    # transform data into TFR space
    tfr_epochs = epochs_to_tfr(epochs, scales=new_scales, picks=picks, events=event)
    # save tfr data
    tfr_epochs_participants.append(tfr_epochs)

Plot results

In [None]:
# NOTE(review): `reconstructed_epochs_per_threshold` and `grand_average` are not assigned in any
# visible cell above this one; on a fresh kernel this cell depends on cells that are missing or run later.
# Per threshold: mean over trials for each participant, then mean over participants
grand_average_per_threshold = np.array([np.mean([np.mean(participant, axis=0) for participant in threshold_], axis=0) for threshold_ in reconstructed_epochs_per_threshold])

# Most negative value within samples 50-149 of the first channel, unfiltered and per threshold
grand_average_peak_amplitude = np.min(grand_average[0][50:150])
filtered_grand_averages_amplitudes = [np.min(item[50:150]) for item in grand_average_per_threshold[:,0,:]]

# Ratio of the filtered peak amplitude to the unfiltered one (1.0 means no reduction)
diffs = [item/grand_average_peak_amplitude for item in filtered_grand_averages_amplitudes]

fig, ax = plt.subplots()
# The last threshold is left out of the x axis, so `diffs` must hold len(thresholds) - 1 values
plt.plot(thresholds[:-1], diffs)
ax.set_xticks(np.arange(0.1, 1.0, 0.1))
# Reference lines at ratio 0.85 and threshold 0.7
plt.axhline(y=0.85, c='r', linestyle='--')
plt.axvline(x=0.7, c='orange', linestyle='--')

plt.xlabel("Threshold")
plt.ylabel("Amplitude reduction")

plt.show()

In [None]:
# Overlay the unfiltered grand average and the filtered grand average of every threshold
plt.figure(figsize=(12,10))
ax = plt.subplot(111)

# Time axis from -0.1 to 0.9; its length is taken from the reconstructed trials of the round-trip test
x = np.linspace(-0.1, 0.9, np.array(all_epochs_reconstructed[0]).shape[-1])

plt.plot(x, grand_average.flatten(), linestyle='--', label='original signal')

# One line per threshold (first channel), labelled with the threshold rounded to 2 decimals
for i in range(0, len(thresholds[:-1])):
    plt.plot(x, grand_average_per_threshold[i,0,:], label=str(round(thresholds[i], 2)))

plt.legend()
# ax.legend(bbox_to_anchor=(0.7, 1.0))

plt.xlabel("Time (s)")
plt.ylabel("Amplitude (V)")

plt.show()

### Perform wavelet filtering

In [None]:
# Channels for which the wavelet filter is created
picks = ['FCz', 'Cz']

In [None]:
# Build the wavelet filter from the grand average of the error responses,
# with threshold point 0.7 and a Morlet central frequency of 6
filter_per_channel = create_wavelet_filter(
    preprocessed_data_dir_path, 
    picks=picks, 
    event='error_response', 
    threshold=0.7,
    central_freq=6
)

Transform all epochs into tfrs and apply filter

Check similarity between grand average of original and filtered signal per person

In [None]:
# NOTE(review): `all_epochs` is first assigned in a later cell and `all_epochs_reconstructed2`
# in no visible cell; this cell relies on kernel state from cells not shown here.
x = np.linspace(-0.1, 0.9, all_epochs[0].shape[-1])
# One figure per participant: trial average of the first channel, original vs. filtered
for i in range(0, len(all_epochs)):
    plt.figure()

    plt.plot(x, np.mean(all_epochs[i], axis=0)[0].flatten())
    plt.plot(x, np.mean(all_epochs_reconstructed2[i], axis=0)[0].flatten())

    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude (V)")

---

## Generate PCA components:

1. Average the wavelet-filtered single trials into a grand average.
2. Create a variability matrix based on the grand average: variability in peak latency and variability in waveform compression.
3. Perform PCA decomposition on set of modified ERPs. 

In [None]:
def ms_to_tp(value_in_ms, freq=500):
    """
    Convert a duration in milliseconds into a number of time points (samples).

    Only for relative conversion of the lengths
    :param value_in_ms: duration in milliseconds (may be negative)
    :param freq: sampling frequency in Hz
    :return: number of time points as int, truncated towards zero
    """
    samples_per_ms = freq/1000
    return int(value_in_ms*samples_per_ms)

def stretch(xs, coef, centre):
    """Scale a list by a coefficient around a point in the list.

    Every value is repeated ``round(100 * coef)`` times, the result is averaged
    in consecutive blocks of 100, and a window of ``len(xs)`` values is cut out
    so that the value at ``centre`` stays at index ``centre``.

    Parameters
    ----------
    xs : list
        Input values.
    coef : float
        Coefficient to scale by. Values below 1 are not supported (the
        length-preservation assertion fails).
    centre : int
        Position in the list to use as a centre point.

    Returns
    -------
    list
        Stretched values, same length as ``xs``.

    Raises
    ------
    AssertionError
        If the stretched signal is too short to cut a window of ``len(xs)``.

    """
    grain = 100

    # np.repeat only accepts an integer repeat count; `grain * coef` is a float
    # for every non-integer coefficient (and may be off by 1e-14, e.g.
    # 100 * 1.14), so round it before converting.
    repeats = int(round(grain * coef))

    stretched_array = np.repeat(np.asarray(xs), repeats)
    # Average consecutive blocks of `grain` values (the last block may be shorter).
    result = [
        stretched_array[start:start + grain].mean()
        for start in range(0, len(stretched_array), grain)
    ]

    # Index of the centre point after stretching; cut the window around it.
    pivot_point = int(centre * coef)
    first = pivot_point - centre
    last = pivot_point + len(xs) - centre
    result = result[first:last]

    assert len(result) == len(xs), "Length should be preserved"
    return result


def chunks(iterable, n):
    """
    Lazily split a sliceable sequence into consecutive pieces of length n.

    The final piece is shorter when len(iterable) is not a multiple of n.
    Source: http://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks-in-python#answer-312464

    """
    yield from (iterable[start:start + n] for start in range(0, len(iterable), n))

# todo: implement compressing
def generate_variability_matrix(X):
    '''
    Build a set of latency-shifted and width-stretched copies of a grand average.

    :param X: ndarray of shape (n_timepoints,)
        Grand Average on given channel
    :return: ndarray with one modified response per row, ordered by width
        change first and latency shift second
    '''
    # Index of the minimum of the grand average, used as the stretching centre
    peak_latency_tp = np.argmin(X)
    print(f'Peak latency in tp: {peak_latency_tp}')

    # Latency shifts from -60 ms (inclusive) to 60 ms (exclusive), converted to time points
    max_shift_tp = ms_to_tp(60)
    latency_shifts = np.arange(-max_shift_tp, max_shift_tp, ms_to_tp(5))
    # Width coefficients from 1 (inclusive) to 1.5 (exclusive) in steps of 0.02
    width_changes = np.arange(1, 1.5, 0.02)

    # Circularly shift the signal (np.roll wraps around), then stretch it around the peak
    modified_responses = [
        np.array(
            stretch(
                np.roll(X.flatten(), int(latency_shift)),
                coef=width_change,
                centre=peak_latency_tp,
            )
        )
        for width_change in width_changes
        for latency_shift in latency_shifts
    ]

    return np.array(modified_responses)

def create_variability_PCA_components(variability_matrix, n_components=3, seed=None):
    """Project the time points of a variability matrix onto its first PCA axes.

    Parameters
    ----------
    variability_matrix : ndarray
        One modified response per row; it is transposed before fitting, so each
        time point is one PCA sample.
    n_components : int
        Number of principal components to keep.
    seed : int or None
        Seed for the PCA solver. ``None`` uses the notebook-level
        ``random_state`` constant, so that repeated runs give identical
        components even when scikit-learn selects its randomized SVD solver.

    Returns
    -------
    ndarray
        Result of ``PCA.fit_transform`` on the transposed matrix, one row per
        time point and one column per component.
    """
    pca = PCA(
        n_components=n_components,
        random_state=random_state if seed is None else seed,
    )
    X = variability_matrix.T
    X_transformed = pca.fit_transform(X)

    return X_transformed

In [None]:
# 1. Create grand average of filtered signal
wavelets_path = f'data/joint/{paradigm}/preprocessed/wavelets/'
# Participant ids are the file names without their last four characters (assumes a '.pkl' suffix)
id_list = [item[:-4] for item in os.listdir(wavelets_path)]
all_epochs = []

for id_ in id_list:
    preprocessed_filtered_epochs_df = pd.read_pickle(f'{wavelets_path}{id_}.pkl')
    # First entry of the 'epochs' column holds the participant's filtered epochs
    preprocessed_filtered_epochs = preprocessed_filtered_epochs_df['epochs'].to_numpy()[0]
    # clean_epochs, _ = reject_bad_trials(preprocessed_epochs['epochs'].to_numpy().flatten()[0], preprocessed_epochs['drop_log'].to_numpy().flatten()[0])
    # if len(clean_epochs) < 6:
    #     logger_errors_info.info(f'Participant has only {len(clean_epochs)} artifact-free trials')
    # else:
    #     all_evokeds.append(clean_epochs[event].average().get_data(picks=picks))

    all_epochs.append(preprocessed_filtered_epochs)
# Mean over trials per participant, then mean over participants
grand_average = np.mean(np.array([np.mean(item, axis=0) for item in all_epochs]), axis=0)
grand_average.shape

In [None]:
# Grand average of the first channel on a -0.1 to 0.9 time axis
plt.plot(np.linspace(-0.1, 0.9, grand_average.shape[-1]), grand_average[0])

In [None]:
# 2. Generate variability matrix
# One variability matrix per channel of the grand average, each shown as a heatmap
variability_matrices = []
for channel_grand_average in grand_average:
    variability_matrix = generate_variability_matrix(channel_grand_average)
    variability_matrices.append(variability_matrix)

    plt.figure()

    # Rows are modified responses, columns are time points; colour scale centred on 0
    sns.heatmap(
        variability_matrix,
        center=0,
        cmap='Spectral'  
    )

    plt.show()

In [None]:
# 3. get PCA components
# One PCA result per channel, in the same order as `variability_matrices`
pca_per_channel = []
for variability_matrix in variability_matrices:
    # Fit on the matrix of the current channel; indexing `variability_matrices[0]`
    # here would give every channel the components of the first channel.
    pca = create_variability_PCA_components(variability_matrix)
    pca_per_channel.append(pca)

In [None]:
# Time courses of the three PCA components of the first channel (one row per component after .T)
PCA_comp = pca_per_channel[0].T
x = np.linspace(-0.1, 0.9, grand_average.shape[-1])
plt.plot(x, PCA_comp[0])
plt.plot(x, PCA_comp[1])
plt.plot(x, PCA_comp[2])

plt.show()

### Regress signal on PCA components

In [None]:
def regress_signal_on_PCA(epochs, PCA_list):
    """Replace every single-trial channel signal by its linear fit on PCA regressors.

    For each epoch and each channel position ``i`` a fresh ``LinearRegression``
    is fitted with ``PCA_list[i]`` as regressors and the flattened channel
    signal as target; the fitted values become the filtered signal.

    :param epochs: iterable of epochs, each an iterable of per-channel signals
    :param PCA_list: one regressor matrix per channel position
    :return: ndarray with the fitted signals, indexed by epoch and then channel
    """
    def _fitted_values(regressors, signal):
        model = LinearRegression()
        model.fit(X=regressors, y=signal.flatten())
        return model.predict(regressors)

    return np.array([
        [
            _fitted_values(PCA_list[channel_idx], channel_signal)
            for channel_idx, channel_signal in enumerate(epoch)
        ]
        for epoch in epochs
    ])

In [None]:
# Apply the PCA regression to the wavelet-filtered epochs of every participant
all_epochs_pca_filtered = []
for participant_data in all_epochs:
    filtered_pca_epochs = regress_signal_on_PCA(participant_data, pca_per_channel)
    all_epochs_pca_filtered.append(filtered_pca_epochs)        

In [None]:
# filtered_pca_epochs = regress_signal_on_PCA(reconstructed_epochs, pca)

In [None]:
x = np.linspace(-0.1, 0.9, grand_average.shape[-1])

# One figure per participant: trial average of the first channel before and after PCA regression
for i in range(0, len(all_epochs_pca_filtered)):
    plt.figure()

    # plt.plot(x, epochs_preprocessed['error_response'].average().get_data(picks=['FCz']).flatten(), label='original signal')
    plt.plot(x, np.mean(all_epochs[i], axis=0)[0], label = 'wavelet filtered signal')
    plt.plot(x, np.mean(all_epochs_pca_filtered[i], axis=0)[0], label='PCA filtered signal')

    plt.legend()

In [None]:
# NOTE(review): this cell repeats the code of the preceding cell without any change
x = np.linspace(-0.1, 0.9, grand_average.shape[-1])

for i in range(0, len(all_epochs_pca_filtered)):
    plt.figure()

    # plt.plot(x, epochs_preprocessed['error_response'].average().get_data(picks=['FCz']).flatten(), label='original signal')
    plt.plot(x, np.mean(all_epochs[i], axis=0)[0], label = 'wavelet filtered signal')
    plt.plot(x, np.mean(all_epochs_pca_filtered[i], axis=0)[0], label='PCA filtered signal')

    plt.legend()

In [None]:
x = np.linspace(-0.1, 0.9, grand_average.shape[-1])
# Position of the inspected participant in `all_epochs`
# NOTE(review): hard-coded; raises IndexError when fewer than 223 participants were loaded
idx = 222

# One figure per trial of that participant: first channel before and after PCA regression
for i in range(0, len(all_epochs[idx])):
    plt.figure()

    # plt.plot(x, epochs_preprocessed['error_response'].average().get_data(picks=['FCz']).flatten(), label='original signal')
    plt.plot(x, all_epochs[idx][i][0], label = 'wavelet filtered signal')
    plt.plot(x, all_epochs_pca_filtered[idx][i][0], label='PCA filtered signal')

    plt.legend()

---
## For testing

In [None]:
# Load one raw BDF recording, marking EXG1-EXG4 as EOG channels and excluding EXG5/EXG6
input_fname = 'data/raw/A-GNG-000.bdf'
raw = mne.io.read_raw_bdf(
    input_fname,
    eog=['EXG1', 'EXG2', 'EXG3', 'EXG4'],
    exclude=['EXG5', 'EXG6'],
    preload=True
)

# Apply the biosemi64 montage; if the ValueError message names exactly EXG7/EXG8,
# retry while ignoring channels that are missing from the montage.
try:
    raw = raw.set_montage('biosemi64')
except ValueError as e:
    if '[\'EXG7\', \'EXG8\']' in e.args[0]:
        raw = raw.set_montage('biosemi64', on_missing='ignore')
        print('On missing')
    else:
        # NOTE(review): the error is only printed; processing continues without a montage
        print('Lacks important channels!')


# Replace the trigger names in the recording using the participant's trigger map file
file_path = 'data/raw/triggerMap_A-GNG-000.txt'
trigger_map = read_trigger_map(file_path)
raw_new_triggers = replace_trigger_names(raw, trigger_map)

In [None]:
# 1. re-reference: to mastoids
# Uses channels EXG7 and EXG8 as reference, on a copy of the recording
raw_ref = raw_new_triggers.copy().set_eeg_reference(ref_channels=['EXG7', 'EXG8'])

In [None]:
# 2. Resampling
# Resample a copy of the re-referenced recording to 500 Hz
raw_resampled = raw_ref.copy().resample(sfreq=500)

In [None]:
# # (Filter)
# # 2. 4-th order Butterworth filters
# raw_filtered = raw_resampled.copy().filter(
#     l_freq=.1,
#     h_freq=30.0,
#     n_jobs=10,
#     method='iir',
#     iir_params=None,
#     picks=['eeg', 'eog']
# )

In [None]:
# 3. Detrending, Segmentation, and first baseline correction
# NOTE(review): response_event_dict, new_response_event_dict and events_mapping are
# assigned in a cell further down in this notebook, so this cell fails on a fresh top-to-bottom run.

epochs = create_epochs(
    # raw_resampled,
    raw_resampled,
    tmin = -.1,
    tmax = .9,
    baseline = (-0.1, 0),
    detrend = 1,
    events_to_select = response_event_dict,  # response_event_dict
    new_events_dict = new_response_event_dict,  # new_response_event_dict
    events_mapping = events_mapping,  # events_mapping
    reject = None,
    reject_by_annotation = False,
)

In [None]:
# 4. ocular artifact correction with Gratton
# Delegated to the notebook helper ocular_correction_gratton
epochs_eog_corrected = ocular_correction_gratton(epochs)

In [None]:
# 5. Second re-baseline
# Called without arguments, i.e. with the method's default baseline
epochs_eog_corrected.apply_baseline()

In [None]:
# 6. Mark bad trials
# Only channels FCz and Cz are inspected; the resulting drop log is displayed
drop_log = find_bad_trials(epochs_eog_corrected, picks=['FCz','Cz'])
drop_log

In [None]:
# 7. Reject bad trials
# reject_bad_trials returns the cleaned epochs together with a drop log
clean_epochs, drop_log = reject_bad_trials(epochs_eog_corrected, drop_log)
print(clean_epochs)

In [None]:
# Plot the average over all clean epochs at FCz (on a copy, so clean_epochs keeps all channels)
fig = clean_epochs.copy().pick(['FCz']).average().plot()

## Wavelet transform

In [None]:
# Average of this recording's error responses at FCz, used below as the "grand average"
pick = ['FCz']
grand_average = clean_epochs['error_response'].average().get_data(picks=pick)

In [None]:
def calculate_wavelet_filter(grand_average, central_freq = 6, signal_freq=500, threshold_point=0.85):
    """Derive a binary time-frequency mask from the CWT of a grand average.

    The signal is transformed with a Morlet wavelet, the power spectrum is
    normalised to sum to 1, and an empirical CDF of the normalised power values
    is estimated with a Kaplan-Meier fitter. The power value at which the CDF
    first exceeds ``threshold_point`` of its range becomes the cut-off:
    coefficients with at least that power are kept (mask 1), the rest dropped
    (mask 0). The CDF and four diagnostic images are plotted.

    Parameters
    ----------
    grand_average : ndarray
        Signal to analyse; it is flattened before the transform.
        NOTE(review): a multi-channel input would be concatenated by the
        flatten - presumably a single channel is expected; confirm with callers.
    central_freq : float
        ``mu`` parameter of the Morlet wavelet.
    signal_freq : float
        Sampling frequency in Hz, used for the transform and the frequency axis.
    threshold_point : float
        Fraction of the CDF range used as threshold (default 0.85).

    Returns
    -------
    tuple
        ``(mask, wavelet, scales)``: the 0/1 mask with the shape of the CWT
        result, the wavelet object and the scales used by ``cwt``.
    """
    x = grand_average.flatten()
    # Time of each sample in seconds, counted from the first sample; only used
    # for the horizontal extent of the images below.
    t = np.arange(len(x)) / signal_freq

    # construct wavelet function
    wavelet = Wavelet(('morlet', {'mu': central_freq}))
    Wx, scales = cwt(x, wavelet, fs=signal_freq)

    freq = scale_to_freq(scales, wavelet, N=len(x), fs=signal_freq)
    print(freq)
    # Compute and normalize the power spectrum from the CWT coefficients
    power_spectrum = np.abs(Wx)**2
    normalized_power_spectrum = power_spectrum / np.sum(power_spectrum)

    # Flatten the normalized power spectrum for CDF calculation
    flattened_spectrum = normalized_power_spectrum.flatten()

    # Use the Kaplan–Meier estimator from the lifelines library
    kmf = KaplanMeierFitter()
    kmf.fit(durations=flattened_spectrum, event_observed=np.ones_like(flattened_spectrum))

    # Get the CDF values from the Kaplan–Meier estimator
    cdf_values = 1 - kmf.survival_function_.KM_estimate

    # Calculate the threshold as a fraction of the CDF range
    threshold = threshold_point * (np.max(cdf_values) - np.min(cdf_values)) + np.min(cdf_values)

    # Plot the empirical CDF and the filtering model
    plt.step(kmf.survival_function_.index, cdf_values, where='post', label='Empirical CDF')
    plt.axhline(threshold, color='red', linestyle='--', label='Threshold')
    plt.title('Empirical CDF and Filtering Model')
    plt.xlabel('Wavelet Coefficient')
    plt.ylabel('Cumulative Probability')
    plt.legend()
    plt.show()

    # find the value of wavelets coefficient that are above threshold
    cutoff_wavelet_index = np.where(cdf_values > threshold)[0][0]
    cutoff_wavelet_coef = kmf.survival_function_.index[cutoff_wavelet_index]
    print(f'Estimated threshold value for wavelet coefficients: {cutoff_wavelet_coef}')

    cwt_result_threshold_mask = np.where(normalized_power_spectrum >= cutoff_wavelet_coef, 1, 0)

    # Plot the CWT result
    image_extent = [t[0], t[-1], freq[-1], freq[0]]
    plt.figure(figsize=(12, 16))

    plt.subplot(4, 1, 1)
    plt.imshow(np.abs(Wx), extent=image_extent, aspect='auto', cmap='jet')
    plt.colorbar(label='Magnitude')
    plt.title('CWT Magnitude')

    plt.subplot(4, 1, 2)
    plt.imshow(normalized_power_spectrum, extent=image_extent, aspect='auto', cmap='jet')
    plt.colorbar(label='Magnitude')
    plt.title('Normalized Power Spectrum')

    plt.subplot(4, 1, 3)
    plt.imshow(cwt_result_threshold_mask, extent=image_extent, aspect='auto', cmap='jet')
    plt.colorbar(label='Magnitude')
    plt.title('Threshold Mask')

    plt.subplot(4, 1, 4)
    plt.imshow(cwt_result_threshold_mask*np.abs(Wx), extent=image_extent, aspect='auto', cmap='jet')
    plt.colorbar(label='Magnitude')
    plt.title('Thresholded Grand Average - Examle')
    plt.show()

    return cwt_result_threshold_mask, wavelet, scales

def filter_signal(Wx, mask, wavelet, scales):
    """Apply a mask to CWT coefficients and transform the result back to the time domain.

    :param Wx: CWT coefficients
    :param mask: array multiplied element-wise with ``Wx``
    :param wavelet: wavelet passed to ``icwt``
    :param scales: scales passed to ``icwt``
    :return: output of ``icwt`` for the masked coefficients
    """
    masked_coefficients = mask * Wx
    return icwt(masked_coefficients, wavelet, scales)

In [None]:
# Build the mask from this recording's average; also keeps the wavelet and scales for the inverse transform
cwt_result_threshold_mask, wavelet, scales = calculate_wavelet_filter(grand_average, central_freq=6)

In [None]:
x = grand_average.flatten()

# construct wavelet function
# NOTE(review): `wavelet_this` is created but the transform below uses `wavelet` from the
# previous cell, and `scales_this` is not used by the visible cells that follow.
wavelet_this = Wavelet(('morlet', {'mu': 6}))
Wx, scales_this = cwt(x, wavelet, fs=500)

In [None]:
# Mask the coefficients and invert them with the wavelet and scales returned by calculate_wavelet_filter
filtered_signal = filter_signal(Wx, cwt_result_threshold_mask, wavelet, scales)

In [None]:
# Original signal and wavelet-filtered signal against the sample index
plt.plot(x)
plt.plot(filtered_signal)

In [None]:
# Frequencies corresponding to the CWT scales at a 500 Hz sampling rate
freq = scale_to_freq(scales, wavelet, N=len(x), fs=500)

In [None]:
# NOTE(review): pandas is already imported in the first cell of the notebook
import pandas as pd
# CWT magnitudes as a table: rows indexed by frequency, columns by time from -0.1 to 0.9
data = pd.DataFrame(abs(Wx), index=freq, columns=np.linspace(-0.1, 0.9, len(x)))

In [None]:
data

In [None]:
# Heatmap of the CWT magnitude table
sns.heatmap(data, cmap='jet', )

## Regress single trial ERP on PCA components

In [None]:
# NOTE(review): `X_transformed` is first assigned in a later cell ("3. fit PCA"),
# so this cell fails on a fresh top-to-bottom run.
lm = LinearRegression()
epochs_regressed=[]

# Fit every error-response trial at FCz (-0.1 to 0.5) on the PCA regressors and keep the fitted values
for idx, _ in enumerate(clean_epochs['error_response']):
    epoch_data = clean_epochs['error_response'][idx].get_data(picks='FCz', tmin=-0.1, tmax=0.5).flatten()
    lm.fit(X=X_transformed, y=epoch_data)
    epoch_pred = lm.predict(X_transformed)
    epochs_regressed.append(epoch_pred)

epochs_regressed = np.array(epochs_regressed)
print(epochs_regressed.shape)

## Generate PCA components:

1. Average the wavelet-filtered single trials into a grand average.
2. Create a variability matrix based on the grand average: variability in peak latency and variability in waveform compression.
3. Perform PCA decomposition on set of modified ERPs. 

In [None]:
# 1. Create grand average

In [None]:
# Average of the error responses at FCz, restricted to -0.1 to 0.5
pick = ['FCz']
grand_average = clean_epochs['error_response'].average().get_data(picks=pick, tmin=-0.1, tmax=0.5)

In [None]:
def ms_to_tp(value_in_ms, freq=500):
    """
    Convert a duration in milliseconds into a number of time points (samples).

    Only for relative conversion of the lengths
    :param value_in_ms: duration in milliseconds (may be negative)
    :param freq: sampling frequency in Hz
    :return: number of time points as int, truncated towards zero
    """
    samples_per_ms = freq/1000
    return int(value_in_ms*samples_per_ms)

def stretch(xs, coef, centre):
    """Scale a list by a coefficient around a point in the list.

    Every value is repeated ``round(100 * coef)`` times, the result is averaged
    in consecutive blocks of 100, and a window of ``len(xs)`` values is cut out
    so that the value at ``centre`` stays at index ``centre``.

    Parameters
    ----------
    xs : list
        Input values.
    coef : float
        Coefficient to scale by. Values below 1 are not supported (the
        length-preservation assertion fails).
    centre : int
        Position in the list to use as a centre point.

    Returns
    -------
    list
        Stretched values, same length as ``xs``.

    Raises
    ------
    AssertionError
        If the stretched signal is too short to cut a window of ``len(xs)``.

    """
    grain = 100

    # np.repeat only accepts an integer repeat count; `grain * coef` is a float
    # for every non-integer coefficient (and may be off by 1e-14, e.g.
    # 100 * 1.14), so round it before converting.
    repeats = int(round(grain * coef))

    stretched_array = np.repeat(np.asarray(xs), repeats)
    # Average consecutive blocks of `grain` values (the last block may be shorter).
    result = [
        stretched_array[start:start + grain].mean()
        for start in range(0, len(stretched_array), grain)
    ]

    # Index of the centre point after stretching; cut the window around it.
    pivot_point = int(centre * coef)
    first = pivot_point - centre
    last = pivot_point + len(xs) - centre
    result = result[first:last]

    assert len(result) == len(xs), "Length should be preserved"
    return result


def chunks(iterable, n):
    """
    Lazily split a sliceable sequence into consecutive pieces of length n.

    The final piece is shorter when len(iterable) is not a multiple of n.
    Source: http://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks-in-python#answer-312464

    """
    yield from (iterable[start:start + n] for start in range(0, len(iterable), n))


def generate_variability_matrix(X):
    '''
    Build a set of latency-shifted and width-stretched copies of a grand average.

    NOTE(review): this redefines generate_variability_matrix from the earlier
    PCA section of the notebook with a different latency grid (+/-50 ms, 2 ms step).

    :param X: ndarray of shape (n_timepoints,)
        Grand Average on given channel 
    :return: ndarray with one modified response per row, ordered by width
        change first and latency shift second
    '''
    # Find the peak latency in the grand average signal /in tp
    # (index of the minimum of X)
    peak_latency_tp = np.argmin(X)
    print(f'Peak latency in tp: {peak_latency_tp}')

    # Parameters
    latency_shifts = np.arange(-ms_to_tp(50), ms_to_tp(50), ms_to_tp(2))  # From -50 ms (inclusive) to 50 ms (exclusive) in steps of 2 ms, in time points
    width_changes = np.arange(1, 1.5, 0.02)  # From 1 (inclusive) to 1.5 (exclusive) in steps of 0.02

    # Initialize a list to store modified ERP responses
    modified_responses = []

    # Enumerate through latency shifts and width changes
    for width_change in width_changes:
        for latency_shift in latency_shifts:
            # Apply latency shift (circular: np.roll wraps samples around the ends)
            evoked_shifted = np.roll(X.flatten(), int(latency_shift))

            # Calculate the stretched array
            evoked_stretched = np.array(stretch(evoked_shifted, coef=width_change, centre=peak_latency_tp))
            modified_responses.append(evoked_stretched)

    # Convert the list of modified responses to a numpy array
    modified_responses = np.array(modified_responses)

    return modified_responses

In [None]:
# 2. Generate variability matrix
variability_matrix = generate_variability_matrix(grand_average)

# Rows are modified responses, columns are time points; colour scale centred on 0
sns.heatmap(
    variability_matrix,
    center=0,
    cmap='Spectral'  
)

plt.show()

In [None]:
# 3. fit PCA
pca = PCA(n_components=3)
# Transposed so that every time point is one PCA sample
X = variability_matrix.T
# One row per time point, one column per component
X_transformed = pca.fit_transform(X)

In [None]:
# Time courses of the three PCA components (one row per component after .T)
PCA_comp = X_transformed.T
# The variability matrix is built from the grand average taken from -0.1 to 0.5,
# so the time axis has to start at -0.1 as well (it started at 0 before).
x = np.linspace(-0.1, 0.5, len(variability_matrix[0]))
plt.plot(x, PCA_comp[0])
plt.plot(x, PCA_comp[1])
plt.plot(x, PCA_comp[2])

plt.show()

In [None]:
# def pre_process_eeg(input_fname, trigger_map=None, parameters=None):
#     raw = mne.io.read_raw_bdf(input_fname, eog=['EX7', 'EX8'])
# 
#     # 1. re-reference: to mastoids
#     raw.set_eeg_reference(ref_channels=['M1', 'M2'])
# 
#     # 2. segmentation -100 to 900 ms around the response
#     epochs = create_epochs(raw_filtered, tmin=-.1, tmax=.9)
# 
#     # 3. ocular artifact correction with ICA
#     refined_epochs = ocular_correction_gratton(epochs)
# 
#     # 6. Second re-baseline
#     refined_epochs.apply_baseline()
#
#     # 7. Find bad trials: trials in which the EEG signal at the FCz or Cz site was greater than ± 150 μV are marked
#     drop_log = find_bad_trials(refined_epochs, picks=['FCz','Cz'])
# 
#     # 9. Wavelet filter (1 to 30 Hz in steps of 0.3 Hz)
#     # todo

#     # 10. Slicing wavelets: -100 - 500 around response
#     # todo
#
#     # 11. PCA on grand average of inverted wavelets (after wavelets -> invert to get signal, average, do PCA)
#     # todo -> this on cleaned_epochs = reject_bad_trials(refined_epochs, drop_log))
#
#     # 12. Regression: Y (invert single-trial wavelets) = PCA_3 .fit(); y_hat = .predict()
#     # todo
#
#     # 13. peak amplitude of y_hat (single trail denoised signal)
#     # todo
#


In [None]:
# Trigger names of the response events to select, with their numeric event codes
response_event_dict = {
    'Stimulus/RE*ex*1_n*1_c_1*R*FB': 10003,
    'Stimulus/RE*ex*1_n*1_c_1*R*FG': 10004,
    'Stimulus/RE*ex*1_n*1_c_2*R': 10005,
    'Stimulus/RE*ex*1_n*2_c_1*R': 10006,
    'Stimulus/RE*ex*2_n*1_c_1*R': 10007,
    'Stimulus/RE*ex*2_n*2_c_1*R*FB': 10008,
    'Stimulus/RE*ex*2_n*2_c_1*R*FG': 10009,
    'Stimulus/RE*ex*2_n*2_c_2*R': 10010,
}

# Event names and ids of the two response categories after recoding
new_response_event_dict = {"correct_response": 0, "error_response": 1}

# Which of the original event codes above belong to each response category
events_mapping = {
    'correct_response': [10003, 10004, 10008, 10009],
    'error_response': [10005, 10006, 10007, 10010],
}

In [None]:
# Load a BrainVision recording into memory; overwrites the `raw` loaded from BDF above
raw = mne.io.read_raw_brainvision(
    vhdr_fname = 'data/GNG_AA0303-64 el.vhdr', preload=True
)

In [None]:
# Resample a copy of the recording to 500 Hz
raw_resampled = raw.copy().resample(sfreq=500)

In [None]:
# 2. 4-th order Butterworth filters
# Band-pass 0.1-30 Hz with an IIR filter; iir_params=None leaves the filter design to MNE's default
raw_filtered = raw_resampled.copy().filter(
        l_freq=.1,
        h_freq=30.0,
        n_jobs=10,
        method='iir',
        iir_params=None,
)

In [None]:
# Segment the filtered recording from -0.1 to 0.5 around the response events,
# with linear detrending and a -0.1 to 0 baseline; no amplitude-based rejection
epochs = create_epochs(
    raw_filtered,
    tmin=-.1,
    tmax=.5,
    baseline=(-0.1, 0),
    detrend=1,
    events_to_select=response_event_dict,  # response_event_dict
    new_events_dict=new_response_event_dict,  # new_response_event_dict
    events_mapping=events_mapping,  # events_mapping
    reject=None,
    reject_by_annotation=False,
)

In [None]:
# Work on a copy restricted to FCz and Cz, so `epochs` keeps all channels
epochs_copy = epochs.copy()
epochs_picked_channels = epochs_copy.pick(picks=['FCz', 'Cz'])

epochs_picked_channels.drop_bad()
drop_log = epochs_picked_channels.drop_log
print(drop_log)

# Mark trials in which the EEG signal at the FCz or Cz site exceeds +/- 150 uV.
# The data are read from the picked epochs: reading `epochs[idx]` (all channels)
# with channel positions of the FCz/Cz subset would test the first two channels
# of the full montage instead of FCz and Cz.
picked_ch_names = epochs_picked_channels.info['ch_names']
for idx, epoch_data in enumerate(epochs_picked_channels):
    # Iterating over Epochs yields one array per trial, indexed by channel first
    for ch_idx, ch_name in enumerate(picked_ch_names):
        channel_data = epoch_data[ch_idx, :]

        # EEG signal at the FCz or Cz site was greater than ± 150 μV were removed
        if (abs(channel_data) > 150e-6).any():
            print(f'BAD------ trail index {idx}, channel: {ch_name}')
            # NOTE(review): drop_log is indexed by the position of the trial among the
            # remaining epochs; this matches only while drop_bad() removed nothing - confirm.
            new_drop_log_item = drop_log[idx] + (ch_name, ) if ch_name not in drop_log[idx] else drop_log[idx]
            drop_log = tuple(new_drop_log_item if i == idx else item for i, item in enumerate(drop_log))


In [None]:
drop_log

In [None]:

epochs_copy = epochs.copy()
# reject_bad_trials returns (clean_epochs, drop_log) at its other call sites in
# this notebook; keep only the epochs, otherwise `.copy().pick(...)` below is
# called on a tuple and fails.
cleaned_epochs = reject_bad_trials(epochs_copy, drop_log)[0]
evokes = []
picks=['FCz', 'Cz']

# One evoked response of the error trials per channel, in the order of `picks`
for ch_name in picks:
    evoked = cleaned_epochs.copy().pick(picks=ch_name)['error_response'].average()
    print(evoked)
    evokes.append(evoked)
    # print(X.shape)

#     pca = PCA(n_components=3)
#     X_transformed = pca.fit_transform(X)
#     transformed_evokes[ch_name] = X_transformed

In [None]:
# if len(epochs) == len(drop_log):
#     epochs_to_drop_indices = []
#     for idx, item in enumerate(drop_log):
#         if ('FCz' in item) or ('Cz' in item):
#             print(f'In item: {idx}')
#             epochs_to_drop_indices.append(idx)
# 
#     clean_epochs = epochs.copy().drop(
#         indices = epochs_to_drop_indices,
#         reason = 'EXCEED 150uV', 
#     )
#     
# else:
#     print(f'Epochs length is not equal drop_log length:\nepochs: {len(epochs)}\ndrop_log{len(drop_log)}')
#     

In [None]:
# _, value = mne.preprocessing.peak_finder(evokes[0].copy().crop(0.02, 0.1).get_data().flatten(), extrema=-1)
# idx, _ = np.where(evokes[0] == value)
# Latency and amplitude of the largest absolute deflection between 0.02 and 0.1 in evokes[1] (the Cz evoked)
_, lat, amp = evokes[1].get_peak(tmin=0.02, tmax=0.1, return_amplitude=True, mode='abs')
print(lat)

In [None]:
# Plot evokes[0] (the FCz evoked)
fig = evokes[0].plot()

In [None]:
evokes[0].get_data().shape

In [None]:
# NOTE(review): scratch cell. `evokes[1]` is the object returned by `.average()`, not a data array,
# and `idx` is whatever an earlier cell left in the kernel (its dedicated assignment above is
# commented out) - presumably `evokes[1].get_data().flatten()` and a peak index were intended; confirm.
plt.plot(evokes[1])
plt.axvline(x = idx, color = 'b')
plt.plot(np.roll(evokes[1],12))

In [None]:
idx

In [None]:
# Average of the error responses at FCz from -0.1 to 0.5, flattened to one dimension
grand_average = clean_epochs['error_response'].average().get_data(picks='FCz', tmin=-0.1, tmax=0.5)
grand_average = grand_average.flatten()

In [None]:
# x = np.linspace(-0.1, 0.5, len(grand_average))
# plt.plot(x, grand_average)
# 
# plt.show()

In [None]:
# Variability matrix of the flattened grand average (one modified response per row)
vm = generate_variability_matrix(grand_average)

In [None]:
# sns.heatmap(
#     vm,
#     center=0,
#     cmap='Spectral',
#     # xticklabels=np.arange(0, len(grand_average))
# )
# plt.show()

In [None]:
# NOTE(review): PCA is already imported in the first cell of the notebook
from sklearn.decomposition import PCA
pca = PCA(n_components=3)
# Transposed so that every time point is one PCA sample
X = vm.T
X_transformed = pca.fit_transform(X)

In [None]:
print(X_transformed.shape)

# Time courses of the three PCA components (one row per component after .T)
PCA_comp = X_transformed.T

# `vm` is built from the grand average taken from -0.1 to 0.5, so the time axis
# has to start at -0.1 as well (it started at 0 before).
x = np.linspace(-0.1, 0.5, len(vm[0]))
plt.plot(x, PCA_comp[0])
plt.plot(x, PCA_comp[1])
plt.plot(x, PCA_comp[2])

plt.show()

In [None]:
# NOTE(review): LinearRegression is already imported in the first cell of the notebook
from sklearn.linear_model import LinearRegression

results=[]

lm = LinearRegression()

# Fit every error-response trial at FCz (-0.1 to 0.5) on the PCA regressors and keep the fitted values
for idx, _ in enumerate(clean_epochs['error_response']):
    epoch_data = clean_epochs['error_response'][idx].get_data(picks='FCz', tmin=-0.1, tmax=0.5).flatten()
    lm.fit(X=X_transformed, y=epoch_data)
    epoch_pred = lm.predict(X_transformed)
    results.append(epoch_pred)

In [None]:
len(results)

In [None]:
x = np.linspace(-0.1, 0.5, len(grand_average))

# One figure per trial: original FCz signal and its PCA-regressed version
for i in range(0, len(results)):
    plt.figure()
    plt.plot(x, clean_epochs['error_response'][i].get_data(picks='FCz', tmin=-0.1, tmax=0.5).flatten())
    plt.plot(x, results[i].flatten())

plt.show()

In [None]:
# Number of latency shifts on a grid from -50 ms (inclusive) to 50 ms (exclusive), in time points
latency_shifts = np.arange(-ms_to_tp(50), ms_to_tp(50), ms_to_tp(5))  # ms_to_tp(5) truncates to 2 time points at the default 500 Hz
len(latency_shifts)

In [None]:
# Overlay the first ten rows of the variability matrix on a -0.1 to 0.5 time axis
for vm_row_index in range(10):
    plt.plot(np.linspace(-0.1, 0.5, len(vm[0])), vm[vm_row_index])


plt.show()

In [None]:
# NOTE(review): `this_evoked` is not assigned in any visible cell; this scratch cell relies on kernel state
a = this_evoked
b = a
# Single pass: stretch by a factor of 2 around index 28
for i in range(0,1):
    b = np.array(stretch(b, 2, centre=28))

In [None]:
# Lengths of the original and the stretched signal
print(len(a))
print(len(b))

In [None]:
# Original signal with a vertical marker at sample 175
plt.plot(a)
plt.axvline(x=175)

In [None]:
# Stretched signal
plt.plot(b)
