# Reducing and Filtering of Chord-Oddball data

In [None]:
import mne
import matplotlib
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, "../")
import ccs_eeg_utils
import numpy as np
import mne.preprocessing as prep
import os
import sklearn 
from contextlib import contextmanager
from autoreject import AutoReject
import json

from mne_bids import (BIDSPath, read_raw_bids, write_raw_bids, inspect_dataset)

matplotlib.use('Qt5Agg')

%matplotlib qt

# Root directory of the BIDS dataset on disk
bids_root = "./data/ds003570/"
# BIDS task label used when building the BIDSPath of a recording
TASK = 'AuditoryOddballChords'
# Subject label (without the "sub-" prefix) processed by this notebook
SUBJECT = '014'
# When True, stdout/stderr are silenced while read_raw_bids runs
SUPRESS_BIDS_OUTPUT = True
# When True, every block is plotted and the bad channels found on it
# afterwards are recorded for the bad-channel JSON file
PROMPT_BADS = False
# When True, the ICA components to exclude are taken from the JSON structure
# instead of being typed in interactively
USE_ICA_JSON = True

In [None]:
@contextmanager
def suppress_stdout_stderr():
    """Temporarily point ``sys.stdout`` and ``sys.stderr`` at ``os.devnull``.

    Both streams are put back when the ``with`` body finishes, including
    when it raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    with open(os.devnull, 'w') as sink:
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams

def read_raw_data(subject_id):
    """
    Reads one subject's EEG recording from the BIDS dataset into memory.

    The recording is located under ``bids_root`` for task ``TASK`` and read
    with ``read_raw_bids``; console output of that call is silenced when
    ``SUPRESS_BIDS_OUTPUT`` is set.

    Parameters:
    subject_id (str): Subject label used to build the BIDS path.

    Returns:
    tuple: ``(raw, bids_path)`` - the loaded recording and the path it was
    read from.
    """
    location = BIDSPath(root=bids_root, subject=subject_id, task=TASK,
                        datatype='eeg', suffix='eeg')

    if not SUPRESS_BIDS_OUTPUT:
        recording = read_raw_bids(location)
    else:
        with suppress_stdout_stderr():
            recording = read_raw_bids(location)

    # NOTE(review): the return value is unused, so this presumably attaches
    # the annotations to `recording` in place - confirm in ccs_eeg_utils.
    ccs_eeg_utils.read_annotations_core(location, recording)
    recording.load_data()

    return recording, location

def preprocessing(raw):
    """
    Resamples, band-pass filters and average-references a recording.

    All steps are called on ``raw`` itself, which is also returned.

    Parameters:
    raw (mne.io.Raw): The recording to preprocess.

    Returns:
    mne.io.Raw: The same ``raw`` object after preprocessing.
    """
    target_sfreq = 128
    low_cut, high_cut = 0.5, 30

    # TODO: bandpass first, downsample later? -> expensive!
    # Step 1: anything sampled faster than 128 Hz is brought down to 128 Hz.
    if raw.info['sfreq'] > target_sfreq:
        raw.resample(target_sfreq)

    # As soon as one channel is neither 'eeg' nor 'stim', every channel is
    # declared to be EEG.
    if any(kind not in ('eeg', 'stim') for kind in raw.get_channel_types()):
        raw.set_channel_types(dict.fromkeys(raw.ch_names, 'eeg'))

    # Step 2: band-pass between 0.5 Hz and 30 Hz.
    raw.filter(low_cut, high_cut, fir_design='firwin')

    # Step 3: average reference, requested as a projector.
    # TODO: add apply_proj() here to apply it?
    raw.set_eeg_reference('average', projection=True)

    return raw


def save_preprocessed_data(file_path, raw):
    """
    Writes a preprocessed recording to disk, overwriting an existing file.

    A '.fif' suffix is appended when ``file_path`` does not end with one.
    Any error raised while saving is printed instead of propagated.

    Parameters:
    file_path (str): Destination path of the saved recording.
    raw (mne.io.Raw): The preprocessed MNE Raw object to store.
    """
    # Normalise the destination so that it always carries the .fif suffix
    file_path = file_path if file_path.endswith('.fif') else file_path + '.fif'

    try:
        raw.save(file_path, overwrite=True)
        print(f"Data saved successfully to {file_path}")
    except Exception as e:
        # Best effort: report the failure and let the caller carry on
        print(f"Error saving data: {e}")

# see https://neuraldatascience.io/7-eeg/erp_artifacts.html
def get_ica(data, ica_bads, block_idx):
    """
    Fits an ICA on one block and returns a copy with components removed.

    The components to exclude are read from ``ica_bads`` when
    ``USE_ICA_JSON`` is set. Otherwise the components are plotted, the
    indices are asked for on the console and stored back into ``ica_bads``.

    Parameters:
    data (mne.io.Raw): Block to decompose; its montage is set to
        'standard_1020' as part of this call.
    ica_bads (dict): Excluded components, indexed by "sub-<SUBJECT>" and then
        by ``block_idx``.
    block_idx (str): Key of the block inside the subject's entry.

    Returns:
    mne.io.Raw: A copy of ``data`` with the ICA applied.
    """
    data.set_montage('standard_1020', match_case=False)

    decomposition = mne.preprocessing.ICA(method="fastica", random_state=0)
    decomposition.fit(data, verbose=True)

    if not USE_ICA_JSON:
        decomposition.plot_components()
        # Ask for the indices and turn "0 3 7" into [0, 3, 7]
        answer = input("Enter index of the components to be separated by space: ")
        exclude_components = [int(token) for token in answer.split()]
        # Remember the choice so that it can be written to disk by the caller
        ica_bads[f"sub-{SUBJECT}"][block_idx] = exclude_components
    else:
        exclude_components = ica_bads[f"sub-{SUBJECT}"][block_idx]

    print("List of components:", exclude_components)
    decomposition.exclude = exclude_components

    # Apply to a copy so that `data` keeps the signal with all components
    return decomposition.apply(data.copy())


def split_in_blocks(raw):
    """
    Splits a recording into blocks at its 'STATUS:boundary' annotations.

    Each block is a cropped copy of ``raw``; ``raw`` itself is not cropped.
    A recording without any 'STATUS:boundary' annotation is returned as a
    single block.

    Parameters:
    raw (mne.io.Raw): The continuous recording to split.

    Returns:
    list: The blocks in temporal order (number of boundaries + 1 entries).
    """
    events, event_id = mne.events_from_annotations(raw)

    boundary_code = event_id.get('STATUS:boundary')
    if boundary_code is None:
        # Nothing to split on: the whole recording is one block
        return [raw.copy()]

    # Event samples are counted from the start of the acquisition and thus
    # include raw.first_samp, whereas raw.times is indexed from 0.
    boundary_samples = events[events[:, 2] == boundary_code, 0] - raw.first_samp

    blocks = []
    start_idx = 0
    for end_idx in boundary_samples:
        # NOTE(review): crop() includes tmax by default, so the boundary
        # sample ends up in both neighbouring blocks - confirm this is wanted.
        blocks.append(raw.copy().crop(tmin=raw.times[start_idx],
                                      tmax=raw.times[end_idx]))
        start_idx = end_idx

    # Everything after the last boundary forms the final block
    blocks.append(raw.copy().crop(tmin=raw.times[start_idx]))

    return blocks


def get_epochs_from_events(data, event_str, reaction_time_threshold=None):
    """
    Epochs a recording around every event whose name contains ``event_str``.

    Epochs run from -0.4 s to 1.6 s around the event, are baseline-corrected
    on (-0.4, 0) and are dropped when the EEG peak-to-peak amplitude exceeds
    400 uV.

    When ``reaction_time_threshold`` is given, an epoch is additionally
    removed if the first 'Correct -' event after its time-locking event
    follows sooner than ``reaction_time_threshold`` seconds.

    Parameters:
    data (mne.io.Raw): Annotated recording; its projectors are normalised as
        part of this call.
    event_str (str): Substring selecting the annotation names to epoch on.
    reaction_time_threshold (float | None): Shortest accepted delay in
        seconds between an event and the following 'Correct -' event. No
        reaction-time filtering happens when it is None (or 0).

    Returns:
    mne.Epochs: The epochs that survived the rejection steps.
    """
    evts, evts_dict = mne.events_from_annotations(data)

    # Keep only the event codes whose annotation name contains event_str
    selected_ids = {name: code for name, code in evts_dict.items()
                    if event_str in name}

    data.info.normalize_proj()

    # `reject` is a peak-to-peak limit given in volts (0.0004 V = 400 uV)
    reject = dict(eeg=0.0004)
    epochs = mne.Epochs(data, evts, selected_ids, tmin=-0.4, tmax=1.6,
                        baseline=(-0.4, 0), preload=True, reject=reject)

    if reaction_time_threshold:
        # NOTE(review): 'Correct -' events presumably mark correct responses;
        # confirm against the dataset's event description.
        correct_ids = [code for name, code in evts_dict.items()
                       if 'Correct -' in name]

        # The responses have to be looked up in the full event list:
        # epochs.events only holds the events selected for epoching.
        correct_samples = np.sort(evts[np.isin(evts[:, 2], correct_ids), 0])

        # Event samples are the time-locking points (epoch time 0), so the
        # distance to the next response is the reaction time itself.
        onsets = epochs.events[:, 0]
        next_pos = np.searchsorted(correct_samples, onsets, side='right')
        has_response = next_pos < len(correct_samples)

        latency = np.full(len(onsets), np.inf)
        latency[has_response] = (
            correct_samples[next_pos[has_response]] - onsets[has_response]
        ) / epochs.info['sfreq']

        too_fast = latency < reaction_time_threshold
        if too_fast.any():
            epochs = epochs[np.flatnonzero(~too_fast)]

    return epochs


def interpolate_bads_and_merge(blocks):
    """
    Interpolates the bad channels of every block and concatenates the blocks.

    Parameters:
    blocks (list): Blocks (mne.io.Raw) whose bad channels are interpolated
        before merging.

    Returns:
    mne.io.Raw: The result of ``mne.concatenate_raws`` on the blocks.
    """
    for segment in blocks:
        segment.interpolate_bads()

    return mne.concatenate_raws(blocks)


def create_bad_json_structure(n_subjects=40, n_blocks=8):
    """
    Builds an empty bookkeeping structure with one list per subject and block.

    Parameters:
    n_subjects (int): Number of subjects; keys run from 'sub-001' upwards.
    n_blocks (int): Number of blocks per subject; keys run from '1' upwards.

    Returns:
    dict: ``{'sub-001': {'1': [], ..., '<n_blocks>': []}, ...}`` where every
    list is a separate object.
    """
    return {
        f'sub-{subject:03d}': {f'{block}': [] for block in range(1, n_blocks + 1)}
        for subject in range(1, n_subjects + 1)
    }

def set_bad_channels_from_json(blocks, bad_json):
    """
    Marks the bad channels of each block as recorded in ``bad_json``.

    Parameters:
    blocks (list): Blocks (mne.io.Raw) in temporal order; block number 1 is
        the first entry.
    bad_json (dict): Bad channels, indexed by "sub-<SUBJECT>" and then by the
        block number as a string.

    Returns:
    list: The same ``blocks`` list, with ``info['bads']`` set on every block.
    """
    # Number the blocks by position instead of searching the list for each one
    for number, block in enumerate(blocks, start=1):
        block.info['bads'] = bad_json[f"sub-{SUBJECT}"][f"{number}"]

    return blocks

In [None]:
# reduce bids eeg data
# The preprocessed recording is cached on disk; the pipeline below only runs
# when that cache file does not exist yet.
processed_path = f"./data/processed_{SUBJECT}_raw.fif"
bad_channels_path = "./data/bad_channels.json"
bad_ica_path = "./data/bad_ica_components.json"

if not os.path.isfile(processed_path):
    raw, bids_path = read_raw_data(SUBJECT)
    channel_types = {ch: 'eeg' for ch in raw.ch_names}
    raw.set_channel_types(channel_types)

    raw.set_montage('standard_1020', match_case=False)
    blocks = split_in_blocks(raw.copy())

    # Bad channels: reuse the stored marks, or start from an empty structure
    if os.path.isfile(bad_channels_path):
        with open(bad_channels_path) as f:
            bads = json.load(f)
        blocks = set_bad_channels_from_json(blocks, bads)
    else:
        bads = create_bad_json_structure()

    # Excluded ICA components: same scheme, kept in a separate file
    if os.path.isfile(bad_ica_path):
        with open(bad_ica_path) as f:
            ica_bads = json.load(f)
    else:
        ica_bads = create_bad_json_structure()

    if PROMPT_BADS:
        for block_number, b in enumerate(blocks, start=1):
            # Blocks until the plot window is closed; whatever is in
            # info['bads'] afterwards is recorded for this block
            b.plot(n_channels=64)
            plt.show(block=True)
            bads[f"sub-{SUBJECT}"][f"{block_number}"] = b.info['bads']

    with open(bad_channels_path, "w") as f:
        json.dump(bads, f)

    ica_blocks = []

    for block_number, b in enumerate(blocks, start=1):
        b.interpolate_bads()
        prep_block = preprocessing(b.copy())

        # ICA
        ica_blocks.append(get_ica(prep_block, ica_bads, f"{block_number}"))

    # get_ica stores interactively chosen components in ica_bads
    with open(bad_ica_path, "w") as f:
        json.dump(ica_bads, f)

    prep_raw = mne.concatenate_raws(ica_blocks)
    save_preprocessed_data(processed_path, prep_raw)

else:
    prep_raw = mne.io.read_raw_fif(processed_path, preload=True)

prep_raw.info

In [None]:
# Quick visual check: first channel of the preprocessed recording over time.
# Indexing a Raw object yields a (data, times) pair.
series = prep_raw[0, :]
first_channel, times = series[0][0], series[1]
plt.plot(times, first_channel)
plt.xlabel("Time (s)")
plt.ylabel(prep_raw.ch_names[0])
plt.show()

## Epochs

In [None]:
# Work on a copy of the preprocessed recording. Restricting it to a few
# electrodes is currently switched off:
# .pick(["Cz", "T7", "T8", "P3", "P4"])
raw_subselect = prep_raw.copy()
raw_subselect.annotations

# Standards are epoched without a reaction-time check; both deviant types
# are epoched with a 0.2 s reaction-time threshold.
standard_epochs = get_epochs_from_events(raw_subselect, '_S')
exemplar_epochs, function_epochs = (
    get_epochs_from_events(raw_subselect, marker, reaction_time_threshold=0.2)
    for marker in ('_deviantEcorrect_E', '_deviantFcorrect_F')
)

In [None]:
def epoch_rejection(epochs, shape):
    """
    Removes every epoch in which at least one channel is too large.

    A channel passes when its absolute peak is below all three limits:
    5 times the pooled standard deviation, 250 uV, and 5 times the channel's
    own standard deviation. An epoch is kept only if all its channels pass.

    Parameters:
    epochs (array-like): Epoch data indexed as (epoch, channel, sample), in
        volts.
    shape (tuple): Unused; kept so that existing calls keep working.

    Returns:
    numpy.ndarray: A new array holding the kept epochs in their original
    order, with the channel and sample axes unchanged.
    """
    data = np.asarray(epochs)
    print("len prior: ", len(data))

    keep = np.zeros(len(data), dtype=bool)

    # Running statistics of all samples seen so far, merged epoch by epoch
    # (pairwise variance update) instead of re-concatenating the data.
    # NOTE(review): the pooled std therefore covers this epoch and all earlier
    # ones - confirm whether per-epoch or whole-dataset was intended.
    count, mean, m2 = 0, 0.0, 0.0

    for idx, epoch in enumerate(data):
        if epoch.size:
            batch_mean = epoch.mean()
            delta = batch_mean - mean
            total = count + epoch.size
            m2 += (np.square(epoch - batch_mean).sum()
                   + delta ** 2 * count * epoch.size / total)
            mean += delta * epoch.size / total
            count = total
        pooled_std = np.sqrt(m2 / count) if count else np.nan

        peaks = np.abs(epoch).max(axis=-1)
        channel_std = epoch.std(axis=-1)

        channel_ok = ((peaks < 5 * pooled_std)
                      & (peaks < 250 * 1e-6)
                      & (peaks < 5 * channel_std))
        keep[idx] = np.all(channel_ok)

    kept = data[keep]
    print("len after: ", len(kept))

    return kept



In [None]:
# Run the amplitude-based rejection on the data of each condition. From here
# on the three names hold what epoch_rejection returns, not mne.Epochs.
# (epoch_rejection_prob(<epochs>.get_data()) was the alternative tried here.)
exemplar_data = exemplar_epochs.get_data()
exemplar_epochs = epoch_rejection(exemplar_data, np.array(exemplar_data).shape)

function_data = function_epochs.get_data()
function_epochs = epoch_rejection(function_data, np.array(function_data).shape)

standard_data = standard_epochs.get_data()
standard_epochs = epoch_rejection(standard_data, np.array(standard_data).shape)

In [None]:
from ipynb.fs.defs.analysis import generate_AUC_ROC

step = 1

# One curve per deviant type, each computed against the standard epochs.
# NOTE(review): the meaning of the third argument (7) is defined by
# generate_AUC_ROC - confirm in the analysis notebook.
roc_exemplar, roc_function = (
    generate_AUC_ROC(standard_epochs, deviant_epochs, 7, step)
    for deviant_epochs in (exemplar_epochs, function_epochs)
)

# Time axis: one point per ROC value, spaced step/128 s apart and starting
# at -0.4 s
time = [index * step / 128 - 0.4 for index, _ in enumerate(roc_exemplar)]

plt.plot(time, roc_exemplar, label="exemplar")
plt.plot(time, roc_function, label="function")
plt.legend()
plt.show()


In [None]:
# Print the data shape of each condition. The names may hold mne.Epochs or,
# once the rejection cell has run, plain arrays - both are handled.
for condition in (exemplar_epochs, function_epochs, standard_epochs):
    condition_data = condition.get_data() if hasattr(condition, "get_data") else condition
    print(np.asarray(condition_data).shape)