# Reducing and Filtering of Chord-Oddball data

In [65]:
import mne
import matplotlib
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, "../")
import ccs_eeg_utils
import numpy as np
import mne.preprocessing as prep
import os
import sklearn 
from contextlib import contextmanager
from autoreject import AutoReject
import json

from mne_bids import (BIDSPath, read_raw_bids, write_raw_bids, inspect_dataset)

matplotlib.use('Qt5Agg')

%matplotlib qt

# path where dataset is stored
bids_root = "./data/ds003570/"
TASK = 'AuditoryOddballChords'
SUBJECT = '014'
SUPRESS_BIDS_OUTPUT = True
PROMPT_BADS = False
USE_ICA_JSON = True

In [66]:
# Context manager to suppress stdout and stderr
@contextmanager
def suppress_stdout_stderr():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

def read_raw_data(subject_id):
    bids_path = BIDSPath(subject=subject_id,
                         datatype='eeg', suffix='eeg', task=TASK,
                         root=bids_root)

    if SUPRESS_BIDS_OUTPUT:
        with suppress_stdout_stderr():
            raw = read_raw_bids(bids_path)
    else:
        raw = read_raw_bids(bids_path)

    # Inplace?
    ccs_eeg_utils.read_annotations_core(bids_path, raw)
    raw.load_data()
    
    return raw, bids_path

def preprocessing(raw):
    # TODO: bandpass first, downsample later? -> expensive!
    # 1. Downsampling to 128 Hz
    if raw.info['sfreq'] > 128:
        raw.resample(128)

    # Set channel types to EEG if not already set
    if not all(ch_type in ['eeg', 'stim'] for ch_type in raw.get_channel_types()):
        eeg_channel_names = raw.ch_names
        channel_types = {name: 'eeg' for name in eeg_channel_names}
        raw.set_channel_types(channel_types)

    # 2. Band-pass filter between 0.5 Hz and 30 Hz
    raw.filter(0.5, 30, fir_design='firwin')

    # 3. Re-referencing to the average activity of all electrodes
    #TODO: add apply_proj() here to apply arp?
    raw.set_eeg_reference('average', projection=True)

    """ events = prep.find_eog_events(raw)
    print(events) """

    # 5. Data Reduction (optional)
    # For instance, crop the first 60 seconds of the data

    return raw


def save_preprocessed_data(file_path, raw):
    """
    Saves the preprocessed EEG data to a file.

    Parameters:
    file_path (str): The path where the preprocessed data will be saved.
    raw (mne.io.Raw): The preprocessed MNE Raw object containing EEG data.
    """
    # Check if file_path ends with .fif extension
    if not file_path.endswith('.fif'):
        file_path += '.fif'

    # Save the data
    try:
        raw.save(file_path, overwrite=True)
        print(f"Data saved successfully to {file_path}")
    except Exception as e:
        print(f"Error saving data: {e}")

# see https://neuraldatascience.io/7-eeg/erp_artifacts.html
def get_ica(data, ica_bads, block_idx):
    data.set_montage('standard_1020', match_case=False)
    ica = mne.preprocessing.ICA(method="fastica", random_state=0)
  
    ica.fit(data, verbose=True)

    if USE_ICA_JSON:
        exclude_components = ica_bads[f"sub-{SUBJECT}"][block_idx]
    else:
        ica.plot_components()
        #ica.plot_properties(data)
        # components to be excluded
        # add python input to determine which components to exclude
        input_str = input("Enter index of the components to be separated by space: ")  
        # Converting input string to a list of integers  
        exclude_components = input_str.split()  
        exclude_components = [int(num) for num in exclude_components]  
        
        ica_bads[f"sub-{SUBJECT}"][block_idx] = exclude_components

    # Printing the list  
    print("List of components:", exclude_components) 
    ica.exclude = exclude_components
    reconst_raw = data.copy()
    # apply with excluded components
    reconst_raw = ica.apply(reconst_raw)
    
    return reconst_raw


def split_in_blocks(raw):
    events, event_id = mne.events_from_annotations(raw)

    # Identify indices of 'STATUS:boundary' events
    boundary_events = events[events[:, 2] == event_id['STATUS:boundary'], 0]

    # Split the data into blocks
    blocks = []
    start_idx = 0
    for end_idx in boundary_events:
        block = raw.copy().crop(tmin=raw.times[start_idx], tmax=raw.times[end_idx])
        blocks.append(block)
        start_idx = end_idx

    block = raw.copy().crop(tmin=raw.times[start_idx])
    blocks.append(block)    

    return blocks


def get_epochs_from_events(data, event_str, min_reaction_s=None, max_reaction_s=None):
    evts, evts_dict = mne.events_from_annotations(data)

    # Identify deviant events
    deviant_keys = [e for e in evts_dict.keys() if e.endswith(event_str)]
    correct_keys = [e for e in evts_dict.keys() if "Correct -" in e]

    # Construct a dictionary of deviant events
    evts_dict_stim = {}
    for key in deviant_keys:
        evts_dict_stim[key] = evts_dict[key]

    data.info.normalize_proj()

    # Reject threshold
    reject = dict(eeg=0.0004)  # in V
    epochs = mne.Epochs(data, evts, evts_dict_stim, tmin=-0.4, tmax=1.6, baseline=(-0.4, 0), preload=True, reject=reject)

    if min_reaction_s and max_reaction_s:
        # Calculate reaction times
        reaction_times = []
        for key in deviant_keys:
            d_evts = evts[evts[:, 2] == evts_dict_stim[key]]
            for d_evt in d_evts:
                # Find the closest "Correct -" event
                c_evts = evts[(evts[:, 2] == evts_dict[correct_keys[0]]) & (evts[:, 0] > d_evt[0])]
                if len(c_evts) > 0:
                    # Calculate the reaction time
                    reaction_time = (c_evts[0][0] - d_evt[0]) / data.info['sfreq']  # convert to ms
                    reaction_times.append((reaction_time, key))

        # Filter epochs based on reaction time
        valid_epochs = [i for i, (rt, _) in enumerate(reaction_times) if min_reaction_s <= rt <= max_reaction_s]
        print(f"Filtered {len(valid_epochs)} epochs out of {len(epochs)} based on reaction time threshold")
        epochs = epochs[valid_epochs]

    return epochs


def interpolate_bads_and_merge(blocks):
    # Interpolate bad channels
    for block in blocks:
        block.interpolate_bads()

    # Merge the blocks
    raw = mne.concatenate_raws(blocks)

    return raw


def create_bad_json_structure():
    subjects = {}
    for s in range(1, 41):
        subject_key = f'sub-{s:03d}'
        subjects[subject_key] = {}
        for b in range(1, 9):
            block_key = f'{b}'
            subjects[subject_key][block_key] = []

    return subjects

def set_bad_channels_from_json(blocks, bad_json):
    for block in blocks:
        # Set bad channels
        block.info['bads'] = bad_json[f"sub-{SUBJECT}"][f"{blocks.index(block)+1}"]

    return blocks

In [67]:
# reduce bids eeg data
if not os.path.isfile(f"./data/processed_{SUBJECT}_raw.fif"):
    raw, bids_path = read_raw_data(SUBJECT)
    channel_types = {ch: 'eeg' for ch in raw.ch_names}
    raw.set_channel_types(channel_types)

    raw.set_montage('standard_1020', match_case=False)
    blocks = split_in_blocks(raw.copy())

    if os.path.isfile("./data/bad_channels.json"):
        bads = json.load(open("./data/bad_channels.json"))
        blocks = set_bad_channels_from_json(blocks, bads)
    else:
        bads = create_bad_json_structure()

    if os.path.isfile("./data/bad_ica_components.json"):
        ica_bads = json.load(open("./data/bad_ica_components.json"))
        blocks = set_bad_channels_from_json(blocks, bads)
    else:
        ica_bads = create_bad_json_structure()

    if PROMPT_BADS == True:
        for b in blocks:
            b.plot(n_channels=64)
            plt.show(block=True)
            bads[f"sub-{SUBJECT}"][f"{blocks.index(b)+1}"] = b.info['bads']
    
    with open("./data/bad_channels.json", "w") as f:
        json.dump(bads, f)

    ica_blocks = []

    for b in blocks:
        b.interpolate_bads()
        prep_block = preprocessing(b.copy())

        # ICA
        ica_block = get_ica(prep_block, ica_bads, f"{blocks.index(b)+1}")
        ica_blocks.append(ica_block)
    
    with open("./data/bad_ica_components.json", "w") as f:
        json.dump(ica_bads, f)

    prep_raw = mne.concatenate_raws(ica_blocks)
    save_preprocessed_data(f"./data/processed_{SUBJECT}_raw.fif", prep_raw)

else:
    prep_raw = mne.io.read_raw_fif(f"./data/processed_{SUBJECT}_raw.fif", preload=True)

prep_raw.info

Opening raw data file ./data/processed_014_raw.fif...
    Read a total of 1 projection items:
        Average EEG reference (1 x 64)  idle
    Range : 0 ... 301183 =      0.000 ...  2352.992 secs
Ready.


Reading 0 ... 301183  =      0.000 ...  2352.992 secs...


0,1
Measurement date,Unknown
Experimenter,Unknown
Digitized points,67 points
Good channels,64 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,128.00 Hz
Highpass,0.50 Hz
Lowpass,30.00 Hz


## Epochs

In [68]:
# test with some electrodes

raw_subselect = prep_raw.copy()#.pick(["Cz", "T7", "T8", "P3", "P4"])
raw_subselect.annotations

standard_epochs = get_epochs_from_events(raw_subselect, '_S')
exemplar_epochs = get_epochs_from_events(raw_subselect, '_deviantEcorrect_E', min_reaction_s=0.2, max_reaction_s=1.9)
function_epochs = get_epochs_from_events(raw_subselect, '_deviantFcorrect_F', min_reaction_s=0.2, max_reaction_s=1.9)

Used Annotations descriptions: ['STATUS:16128', 'STATUS:Correct - Exemplar!', 'STATUS:Correct - Function!', 'STATUS:Incorrect - Standard!', 'STATUS:boundary', 'STATUS:five6_S', 'STATUS:five6_deviantE', 'STATUS:five6_deviantEcorrect_E', 'STATUS:five6_deviantF', 'STATUS:five6_deviantFcorrect_F', 'STATUS:fiveRoot_S', 'STATUS:fiveRoot_deviantE', 'STATUS:fiveRoot_deviantEcorrect_E', 'STATUS:fiveRoot_deviantF', 'STATUS:fiveRoot_deviantFcorrect_F', 'STATUS:four6_S', 'STATUS:four6_Sincorrect', 'STATUS:four6_deviantE', 'STATUS:four6_deviantEcorrect_E', 'STATUS:four6_deviantF', 'STATUS:four6_deviantFcorrect_F', 'STATUS:fourRoot_S', 'STATUS:fourRoot_Sincorrect', 'STATUS:fourRoot_deviantE', 'STATUS:fourRoot_deviantEcorrect_E', 'STATUS:fourRoot_deviantF', 'STATUS:fourRoot_deviantFcorrect_F', 'STATUS:one', 'STATUS:two']


Not setting metadata
1218 matching events found
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data from preloaded Raw for 1218 events and 257 original time points ...
    Rejecting  epoch based on EEG : ['Fpz']
    Rejecting  epoch based on EEG : ['Fpz']
    Rejecting  epoch based on EEG : ['AF7']
    Rejecting  epoch based on EEG : ['AF4']
    Rejecting  epoch based on EEG : ['Fpz']
    Rejecting  epoch based on EEG : ['Afz']
    Rejecting  epoch based on EEG : ['C2']
7 bad epochs dropped
Used Annotations descriptions: ['STATUS:16128', 'STATUS:Correct - Exemplar!', 'STATUS:Correct - Function!', 'STATUS:Incorrect - Standard!', 'STATUS:boundary', 'STATUS:five6_S', 'STATUS:five6_deviantE', 'STATUS:five6_deviantEcorrect_E', 'STATUS:five6_deviantF', 'STATUS:five6_deviantFcorrect_F', 'STATUS:fiveRoot_S', 'STATUS:fiveRoot_deviantE', 'STATUS:fiveRoot_deviantEcorrect_E', 'STATUS:fiveRoot_deviantF', 'STATUS:fiveRoot

IndexError: index 84 is out of bounds for axis 0 with size 84

In [None]:
def epoch_rejection(epochs, shape):
    epochs_concat = epochs.copy()
    concat = []
    epochs_concat_removed = np.zeros(shape=shape)
    print("len prior: ", len(epochs_concat))

    for epoch in epochs_concat:
        for channel in epoch:
            concat = np.concatenate((concat, channel))
        std_concat = np.std(concat)
        for channel in epoch:
            channel_max = np.max(abs(channel))
            std_channel = np.std(channel)

            if(channel_max < (5 * std_concat)  and channel_max < (250 * 1e-6)  and channel_max < (5 * std_channel)):
                np.append(epochs_concat_removed, channel)

    print("len after: ", len(epochs_concat_removed))

    return epochs_concat_removed



In [None]:
# exemplar_epochs = epoch_rejection(exemplar_epochs.get_data(), np.array(exemplar_epochs.get_data()).shape )
# # exemplar_epochs = epoch_rejection_prob(exemplar_epochs.get_data())

# function_epochs = epoch_rejection(function_epochs.get_data(), np.array(function_epochs.get_data()).shape)
# # function_epochs = epoch_rejection_prob(function_epochs.get_data())

# standard_epochs = epoch_rejection(standard_epochs.get_data(), np.array(standard_epochs.get_data()).shape)
# # standard_epochs = epoch_rejection_prob(standard_epochs.get_data())

In [None]:
from ipynb.fs.defs.analysis import generate_AUC_ROC
step = 1

roc_exemplar = generate_AUC_ROC(standard_epochs, exemplar_epochs, 7, step)
#print(roc_exemplar)
roc_function = generate_AUC_ROC(standard_epochs, function_epochs, 7, step)
#print(roc_function)
time = [i*step/128 -0.4 for i in range(len(roc_exemplar))]

plt.plot(time, roc_exemplar, label="exemplar")
plt.plot(time, roc_function, label="function")
plt.legend()
plt.show()


ModuleNotFoundError: No module named 'ipynb'

In [None]:
print(np.array(exemplar_epochs.get_data()).shape)
print(np.array(function_epochs.get_data()).shape)
print(np.array(standard_epochs.get_data()).shape)