# Reducing and Filtering of Chord-Oddball data

In [4]:
import mne
import matplotlib
import matplotlib.pyplot as plt
import sys
sys.path.insert(1, "../")
import ccs_eeg_utils
import numpy as np
import mne.preprocessing as prep
import os
import sklearn 
from contextlib import contextmanager
from autoreject import AutoReject
import json

from mne_bids import (BIDSPath, read_raw_bids, write_raw_bids, inspect_dataset)

matplotlib.use('Qt5Agg')

%matplotlib qt

# path where dataset is stored
bids_root = "./data/ds003570/"
TASK = 'AuditoryOddballChords'
SUBJECT = '014'
SUPRESS_BIDS_OUTPUT = True

In [5]:
# Context manager to suppress stdout and stderr
@contextmanager
def suppress_stdout_stderr():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

def read_raw_data(subject_id):
    bids_path = BIDSPath(subject=subject_id,
                         datatype='eeg', suffix='eeg', task=TASK,
                         root=bids_root)

    if SUPRESS_BIDS_OUTPUT:
        with suppress_stdout_stderr():
            raw = read_raw_bids(bids_path)
    else:
        raw = read_raw_bids(bids_path)

    # Inplace?
    ccs_eeg_utils.read_annotations_core(bids_path, raw)
    raw.load_data()
    
    return raw, bids_path

def preprocessing(raw):
    # TODO: bandpass first, downsample later? -> expensive!
    # 1. Downsampling to 128 Hz
    if raw.info['sfreq'] > 128:
        raw.resample(128)

    # Set channel types to EEG if not already set
    if not all(ch_type in ['eeg', 'stim'] for ch_type in raw.get_channel_types()):
        eeg_channel_names = raw.ch_names
        channel_types = {name: 'eeg' for name in eeg_channel_names}
        raw.set_channel_types(channel_types)

    # 2. Band-pass filter between 0.5 Hz and 30 Hz
    raw.filter(0.5, 30, fir_design='firwin')

    # 3. Re-referencing to the average activity of all electrodes
    #TODO: add apply_proj() here to apply arp?
    raw.set_eeg_reference('average', projection=True)

    """ events = prep.find_eog_events(raw)
    print(events) """

    # 5. Data Reduction (optional)
    # For instance, crop the first 60 seconds of the data

    return raw


def save_preprocessed_data(file_path, raw):
    """
    Saves the preprocessed EEG data to a file.

    Parameters:
    file_path (str): The path where the preprocessed data will be saved.
    raw (mne.io.Raw): The preprocessed MNE Raw object containing EEG data.
    """
    # Check if file_path ends with .fif extension
    if not file_path.endswith('.fif'):
        file_path += '.fif'

    # Save the data
    try:
        raw.save(file_path, overwrite=True)
        print(f"Data saved successfully to {file_path}")
    except Exception as e:
        print(f"Error saving data: {e}")

# see https://neuraldatascience.io/7-eeg/erp_artifacts.html
def get_ica(data):
    data.set_montage('standard_1020', match_case=False)
    ica = mne.preprocessing.ICA(method="fastica")
    blocks = split_in_blocks(data.copy())
    for block in blocks:
        ica.fit(block, verbose=True)
        #ica.plot_components()
        ica.plot_properties(block)
    # components to be excluded
    # add python input to determine which components to exclude
    input_str = input("Enter index of the components to be separated by space: ")  
    # Converting input string to a list of integers  
    exclude_components = input_str.split()  
    exclude_components = [int(num) for num in exclude_components]  
  
    # Printing the list  
    print("List of components:", exclude_components) 
    ica.exclude = exclude_components
    reconst_raw = data.copy()
    # apply with excluded components
    ica.apply(reconst_raw)
    
    # TODO: add finf_bads_ of mne.ICA?
    return ica


def split_in_blocks(raw):
    events, event_id = mne.events_from_annotations(raw)

    # Identify indices of 'STATUS:boundary' events
    boundary_events = events[events[:, 2] == event_id['STATUS:boundary'], 0]

    # Split the data into blocks
    blocks = []
    start_idx = 0
    for end_idx in boundary_events:
        block = raw.copy().crop(tmin=raw.times[start_idx], tmax=raw.times[end_idx])
        blocks.append(block)
        start_idx = end_idx

    return blocks


def get_epochs_from_events(data, event_str, reaction_time_threshold=None):
    evts,evts_dict = mne.events_from_annotations(data)

    # not sure what annotations to use here
    # wanted_keys = [e for e in evts_dict.keys() if 'STATUS:two' in e]

    # deviant events (get all deviant events)
    deviant_keys = [e for e in evts_dict.keys() if event_str in e]

    evts_dict_stim=dict((k, evts_dict[k]) for k in deviant_keys if k in evts_dict)

    data.info.normalize_proj()

    # EITHER use total threshold -250uV - 250uV 
    reject = dict(eeg=0.0005) # in V
    epochs = mne.Epochs(data, evts, evts_dict_stim, tmin=-0.4, tmax=1.6, baseline=(0,0.4), preload=True, reject=reject)

    # remove deviant epochs with reaction time under 200ms
    correct_keys = [e for e in evts_dict.keys() if 'Correct -' in e]
    correct_evt_ids = [evts_dict[key] for key in correct_keys]

    if reaction_time_threshold:
        # get all event times
        event_times = epochs.events[:, 0] / epochs.info['sfreq']

        # if a correct event exists less than 600ms in the epoch (epochs starts 400ms before deviant + 200ms reactiontime), remove deviant event
        epochs_to_remove = []
        for i in range(len(event_times) - 1):
            current_time = event_times[i]
            next_time = event_times[i + 1]

            # Check if the next event is a correct one and within 600ms
            if (epochs.events[i + 1, 2] in correct_evt_ids and 
                (next_time - current_time) < 0.6):
                epochs_to_remove.append(i)

        # Remove the epochs with unrealistic (<200ms) reaction time
        epochs_to_keep = [i for i in range(len(epochs)) if i not in epochs_to_remove]
        epochs = epochs[epochs_to_keep]

    return epochs


def interpolate_bads_and_merge(blocks):
    # Interpolate bad channels
    for block in blocks:
        block.interpolate_bads()

    # Merge the blocks
    raw = mne.concatenate_raws(blocks)

    return raw


def create_bad_json_structure():
    subjects = {}
    for s in range(1, 41):
        subject_key = f'sub-{s:03d}'
        subjects[subject_key] = {}
        for b in range(1, 9):
            block_key = f'{b}'
            subjects[subject_key][block_key] = []

    return subjects

In [6]:
# reduce bids eeg data
if not os.path.isfile(f"./data/processed_{SUBJECT}_raw.fif"):
    raw, bids_path = read_raw_data(SUBJECT)
    channel_types = {ch: 'eeg' for ch in raw.ch_names}
    raw.set_channel_types(channel_types)

    raw.set_montage('standard_1020', match_case=False)
    blocks = split_in_blocks(raw.copy())

    if os.path.isfile("./data/bad_channels.json"):
        bads = json.load(open("./data/bad_channels.json"))
    else:
        bads = create_bad_json_structure()

    for b in blocks:
        b.plot()
        plt.show(block=True)
        bads[f"sub-{SUBJECT}"][f"{blocks.index(b)+1}"] = b.info['bads']
    
    with open("./data/bad_channels.json", "w") as f:
        json.dump(bads, f)

    interpol_raw = interpolate_bads_and_merge(blocks)
    prep_raw = preprocessing(interpol_raw.copy())
    save_preprocessed_data(f"./data/processed_{SUBJECT}_raw.fif", prep_raw)

else:
    prep_raw = mne.io.read_raw_fif(f"./data/processed_{SUBJECT}_raw.fif", preload=True)

prep_raw.info

Reading 0 ... 4818943  =      0.000 ...  2353.000 secs...


  raw.set_channel_types(channel_types)


Used Annotations descriptions: ['STATUS:16128', 'STATUS:Correct - Exemplar!', 'STATUS:Correct - Function!', 'STATUS:Incorrect - Standard!', 'STATUS:boundary', 'STATUS:five6_S', 'STATUS:five6_deviantE', 'STATUS:five6_deviantEcorrect_E', 'STATUS:five6_deviantF', 'STATUS:five6_deviantFcorrect_F', 'STATUS:fiveRoot_S', 'STATUS:fiveRoot_deviantE', 'STATUS:fiveRoot_deviantEcorrect_E', 'STATUS:fiveRoot_deviantF', 'STATUS:fiveRoot_deviantFcorrect_F', 'STATUS:four6_S', 'STATUS:four6_Sincorrect', 'STATUS:four6_deviantE', 'STATUS:four6_deviantEcorrect_E', 'STATUS:four6_deviantF', 'STATUS:four6_deviantFcorrect_F', 'STATUS:fourRoot_S', 'STATUS:fourRoot_Sincorrect', 'STATUS:fourRoot_deviantE', 'STATUS:fourRoot_deviantEcorrect_E', 'STATUS:fourRoot_deviantF', 'STATUS:fourRoot_deviantFcorrect_F', 'STATUS:one', 'STATUS:two']
Channels marked as bad:
none


  block.interpolate_bads()
  block.interpolate_bads()
  block.interpolate_bads()
  block.interpolate_bads()
  block.interpolate_bads()
  block.interpolate_bads()
  block.interpolate_bads()


Filtering raw data in 7 contiguous segments
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 845 samples (6.602 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


EEG channel type selected for re-referencing
Adding average EEG reference projection.
1 projection items deactivated
Average reference projection was added, but has not been applied yet. Use the apply_proj method to apply it.
Writing /home/linus/git/uni/eeg-chord-oddball/data/processed_014_raw.fif
Closing /home/linus/git/uni/eeg-chord-oddball/data/processed_014_raw.fif
[done]
Data saved successfully to ./data/processed_014_raw.fif


0,1
Measurement date,Unknown
Experimenter,Unknown
Digitized points,67 points
Good channels,64 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,128.00 Hz
Highpass,0.50 Hz
Lowpass,30.00 Hz


In [7]:
series = prep_raw[0,:]
#print(series[0][0])
plt.plot(series[0][0])
plt.show()
plt.plot(series[0][0])
plt.show()

## Epochs

In [8]:
# test with some electrodes

raw_subselect = prep_raw.copy().pick(["Cz", "T7", "T8", "P3", "P4"])
raw_subselect.annotations

standard_epochs = get_epochs_from_events(raw_subselect, '_S')
exemplar_epochs = get_epochs_from_events(raw_subselect, '_deviantEcorrect_E', reaction_time_threshold=0.2)
function_epochs = get_epochs_from_events(raw_subselect, '_deviantFcorrect_F', reaction_time_threshold=0.2)

Used Annotations descriptions: ['STATUS:16128', 'STATUS:Correct - Exemplar!', 'STATUS:Correct - Function!', 'STATUS:Incorrect - Standard!', 'STATUS:boundary', 'STATUS:five6_S', 'STATUS:five6_deviantE', 'STATUS:five6_deviantEcorrect_E', 'STATUS:five6_deviantF', 'STATUS:five6_deviantFcorrect_F', 'STATUS:fiveRoot_S', 'STATUS:fiveRoot_deviantE', 'STATUS:fiveRoot_deviantEcorrect_E', 'STATUS:fiveRoot_deviantF', 'STATUS:fiveRoot_deviantFcorrect_F', 'STATUS:four6_S', 'STATUS:four6_Sincorrect', 'STATUS:four6_deviantE', 'STATUS:four6_deviantEcorrect_E', 'STATUS:four6_deviantF', 'STATUS:four6_deviantFcorrect_F', 'STATUS:fourRoot_S', 'STATUS:fourRoot_Sincorrect', 'STATUS:fourRoot_deviantE', 'STATUS:fourRoot_deviantEcorrect_E', 'STATUS:fourRoot_deviantF', 'STATUS:fourRoot_deviantFcorrect_F', 'STATUS:one', 'STATUS:two']
Not setting metadata
1067 matching events found
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data fro

In [9]:
# TODO: Copied from exercise 1, has to be adapted. Change selection of electrodes to electrodes near the ears? --> Cz is apparently connected to N2c
evts,evts_dict = mne.events_from_annotations(raw_subselect)

# not sure what annotations to use here
# wanted_keys = [e for e in evts_dict.keys() if 'STATUS:two' in e]

# deviant events (get all deviant events)
deviant_keys = [e for e in evts_dict.keys() if 'deviant' in e]

evts_dict_stim=dict((k, evts_dict[k]) for k in deviant_keys if k in evts_dict)

print(evts_dict_stim)



Used Annotations descriptions: ['STATUS:16128', 'STATUS:Correct - Exemplar!', 'STATUS:Correct - Function!', 'STATUS:Incorrect - Standard!', 'STATUS:boundary', 'STATUS:five6_S', 'STATUS:five6_deviantE', 'STATUS:five6_deviantEcorrect_E', 'STATUS:five6_deviantF', 'STATUS:five6_deviantFcorrect_F', 'STATUS:fiveRoot_S', 'STATUS:fiveRoot_deviantE', 'STATUS:fiveRoot_deviantEcorrect_E', 'STATUS:fiveRoot_deviantF', 'STATUS:fiveRoot_deviantFcorrect_F', 'STATUS:four6_S', 'STATUS:four6_Sincorrect', 'STATUS:four6_deviantE', 'STATUS:four6_deviantEcorrect_E', 'STATUS:four6_deviantF', 'STATUS:four6_deviantFcorrect_F', 'STATUS:fourRoot_S', 'STATUS:fourRoot_Sincorrect', 'STATUS:fourRoot_deviantE', 'STATUS:fourRoot_deviantEcorrect_E', 'STATUS:fourRoot_deviantF', 'STATUS:fourRoot_deviantFcorrect_F', 'STATUS:one', 'STATUS:two']
{'STATUS:five6_deviantE': 7, 'STATUS:five6_deviantEcorrect_E': 8, 'STATUS:five6_deviantF': 9, 'STATUS:five6_deviantFcorrect_F': 10, 'STATUS:fiveRoot_deviantE': 12, 'STATUS:fiveRoot_d

In [10]:

# 400 ms before stimulus onset, 1600 ms after stimulus onset - Stimulus 2 --> since each chord lasts 400ms --> 0s - 2000
# maybe add baseline=(0.3,0.4) according to paper --> then no ica though? -> why?
raw_subselect.info.normalize_proj()

# EITHER use total threshold -250uV - 250uV 
reject = dict(eeg=0.0005) # in V
epochs = mne.Epochs(raw_subselect, evts, evts_dict_stim, tmin=-0.4, tmax=1.6, baseline=(0,0.4), preload=True, reject=reject)

# OR use annotations
""" annotations, bads = prep.annotate_amplitude(raw_subselect, flat=dict(eeg=-250 * 1e-6), peak=dict(eeg=250 * 1e-6), picks="eeg")
raw_subselect.set_annotations(annotations)
epochs = mne.Epochs(raw_subselect, evts, evts_dict_stim, tmin=-0.4, tmax=1.6, baseline=(0.3,0.4), preload=True, reject_by_annotation=True) """


# remove deviant epochs with reaction time under 200ms
correct_keys = [e for e in evts_dict.keys() if 'Correct -' in e]
correct_evt_ids = [evts_dict[key] for key in correct_keys]

# get all event times
event_times = epochs.events[:, 0] / epochs.info['sfreq']

# if a correct event exists less than 600ms in the epoch (epochs starts 400ms before deviant + 200ms reactiontime), remove deviant event
epochs_to_remove = []
for i in range(len(event_times) - 1):
    current_time = event_times[i]
    next_time = event_times[i + 1]

    # Check if the next event is a correct one and within 600ms
    if (epochs.events[i + 1, 2] in correct_evt_ids and 
        (next_time - current_time) < 0.6):
        epochs_to_remove.append(i)


# Remove the epochs with unrealistic (<200ms) reaction time
epochs_to_keep = [i for i in range(len(epochs)) if i not in epochs_to_remove]
epochs = epochs[epochs_to_keep]

# TODO: ICA vor oder nach Epoching?
ica = get_ica(raw_subselect)
#ica = get_ica(epochs)

# extract data from epochs object
data = epochs.get_data()
times = epochs.times

n_trials = 3 # use data.shape[0] for all
fig, axs = plt.subplots(n_trials, 1, figsize=(10, 3*n_trials), sharex=True, sharey=True)

# plot selected trials
for i in range(n_trials):
    for ch in range(data.shape[1]):
        axs[i].plot(times, data[i, ch, :], label=f'Channel {ch}')
    axs[i].set_title(f'Trial {i}')
    axs[i].legend(loc='upper right')

# label the x-axis
plt.xlabel('Time (s)')

# display the plot
plt.show()


Not setting metadata
193 matching events found
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data from preloaded Raw for 193 events and 257 original time points ...


0 bad epochs dropped
Used Annotations descriptions: ['STATUS:16128', 'STATUS:Correct - Exemplar!', 'STATUS:Correct - Function!', 'STATUS:Incorrect - Standard!', 'STATUS:boundary', 'STATUS:five6_S', 'STATUS:five6_deviantE', 'STATUS:five6_deviantEcorrect_E', 'STATUS:five6_deviantF', 'STATUS:five6_deviantFcorrect_F', 'STATUS:fiveRoot_S', 'STATUS:fiveRoot_deviantE', 'STATUS:fiveRoot_deviantEcorrect_E', 'STATUS:fiveRoot_deviantF', 'STATUS:fiveRoot_deviantFcorrect_F', 'STATUS:four6_S', 'STATUS:four6_Sincorrect', 'STATUS:four6_deviantE', 'STATUS:four6_deviantEcorrect_E', 'STATUS:four6_deviantF', 'STATUS:four6_deviantFcorrect_F', 'STATUS:fourRoot_S', 'STATUS:fourRoot_Sincorrect', 'STATUS:fourRoot_deviantE', 'STATUS:fourRoot_deviantEcorrect_E', 'STATUS:fourRoot_deviantF', 'STATUS:fourRoot_deviantFcorrect_F', 'STATUS:one', 'STATUS:two']


IndexError: index 263552 is out of bounds for axis 0 with size 263552

In [None]:
epochs_concat = np.array(epochs)
concat = []
epochs_concat_removed = []
removed = 0

for index, epoch in enumerate(epochs_concat[:]):
    for channel in epoch:
        concat = np.concatenate((concat, channel))
    mean_concat = np.mean(concat)
    std_concat = np.std(concat)
    for signal in concat:
        if(signal > mean_concat - (5 * std_concat) and signal < mean_concat + (5 * std_concat) or signal < (-250 * 1e-6) or signal > (250 * 1e-6)):
            epochs_concat_removed.append(signal)

print("len after: ", len(epochs_concat_removed))

len after:  24024915


In [None]:
epochs_channel = np.array(epochs)
print("len prior: ", len(epochs_channel))

for index, epoch in enumerate(epochs_channel):
    for channel in epoch:
        mean_channel = np.mean(channel)
        std_channel = np.std(channel)
        for signal in channel:
            if(signal < mean_channel - (5 * std_channel) or signal > mean_channel + (5 * std_channel) or signal < (-250 * 1e-6) or signal > (250 * 1e-6)):
                print("remove ", signal)
                epochs_channel = np.delete(epochs_channel, index, 0)

print("len after: ", len(epochs_channel))

len prior:  193
remove  -5.9523672294814104e-05
remove  -5.781054439996942e-05
len after:  191
