In [1]:
import os
import numpy as np
import h5py
import pandas as pd
from scipy.signal import resample
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [2]:
import mne
from mne.preprocessing import ICA
try:
    from mne_icalabel import label_components
except Exception:
    label_components = None

In [3]:
SAMPLE_RATE = 200  # fs
# SAMPLE_LEN = 1.0   # sample seconds
# OVERLAPPING = 0.8  # overlapping seconds
sub_folder_path = str(SAMPLE_RATE) + 'Hz'
sub_folder_path

'200Hz'

## Load participants.tsv file

In [4]:
# root dir
root = 'PD-SIM/'
participants_path = os.path.join(root, 'participants.tsv')
participants = pd.read_csv(participants_path, sep='\t')
participants

Unnamed: 0,participant_id,GROUP,ID,EEG,BEHAVIOR,AGE,GENDER,MOCA,UPDRS,TYPE
0,sub-001,PD,1003,PD1003,Subject_1001_10-Oct-2017_10.26h,80,M,19,28.0,1
1,sub-002,PD,1013,PD1013,Subject_1002_17-Oct-2017_13.5h,81,M,17,25.0,1
2,sub-003,PD,1023,PD1023,Subject_1003_27-Oct-2017_10.47h,68,F,26,10.0,1
3,sub-004,PD,1033,PD1033,Subject_1004_31-Oct-2017_15.3h,80,M,22,10.0,1
4,sub-005,PD,1043,PD1043,Subject_1005_07-Nov-2017_12.3h,56,M,21,13.0,1
...,...,...,...,...,...,...,...,...,...,...
142,sub-143,Control,1453,Control1453,Subject_2046_03-Feb-2021_11.46h,64,F,27,,0
143,sub-144,Control,1463,Control1463,Subject_2047_03-Mar-2021_14.33h,71,M,30,,0
144,sub-145,Control,1473,Control1473,Subject_2048_04-Mar-2021_10.43h,78,M,27,,0
145,sub-146,Control,1483,Control1483,Subject_2049_08-Mar-2021_14.59h,68,F,27,,0


## Find if there are bad channels, check sampling frequency and data shape to avoid inconsistency

In [5]:
# Test for bad channels, sampling freq and shape
bad_channel_list, sampling_freq_list, data_shape_list = [], [], []
for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub, 'eeg/')
        # print(sub_path)
        for file in os.listdir(sub_path):
            if '.set' in file:
                file_path = os.path.join(sub_path, file)
                raw = mne.io.read_raw_eeglab(file_path, preload=True)
                # get bad channels
                bad_channel = raw.info['bads']
                bad_channel_list.append(bad_channel)
                # get sampling frequency
                sampling_freq = raw.info['sfreq']
                sampling_freq_list.append(sampling_freq)
                # get eeg data
                data = raw.get_data()
                data_shape = data.shape
                data_shape_list.append(data_shape)

Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-001\eeg\sub-001_task-Simon_eeg.fdt
Reading 0 ... 476079  =      0.000 ...   952.158 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-002\eeg\sub-002_task-Simon_eeg.fdt
Reading 0 ... 577699  =      0.000 ...  1155.398 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-003\eeg\sub-003_task-Simon_eeg.fdt
Reading 0 ... 489439  =      0.000 ...   978.878 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-004\eeg\sub-004_task-Simon_eeg.fdt
Reading 0 ... 460929  =      0.000 ...   921.858 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-005\eeg\sub-005_task-Simon_eeg.fdt
Reading 0 ... 426349  =      0.000 ...   852.698 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-006\eeg\sub-006_task-Simon_eeg.fdt
Reading

In [6]:
from collections import Counter

print(bad_channel_list)
print(data_shape_list[0])
print("Channel number counter:", Counter(i[0] for i in data_shape_list))
print("Sampling rate counter:", Counter(sampling_freq_list))

[[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
(63, 476080)
Channel number counter: Counter({63: 118, 64: 28, 66: 1})
Sampling rate counter: Counter({500.0: 147})


## Pick common channels

In [7]:
# channel number not consistent, take the common channels
common_channels = []
for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub, 'eeg/')
        # print(sub_path)
        for file in os.listdir(sub_path):
            if '.set' in file:
                file_path = os.path.join(sub_path, file)
                raw = mne.io.read_raw_eeglab(file_path, preload=True)
                current_channels = set(raw.info['ch_names'])
                if not common_channels:
                    common_channels = current_channels
                else:
                    common_channels &= current_channels
common_channels = list(common_channels)
print(common_channels)
print("Common channels number: ", len(common_channels))

Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-001\eeg\sub-001_task-Simon_eeg.fdt
Reading 0 ... 476079  =      0.000 ...   952.158 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-002\eeg\sub-002_task-Simon_eeg.fdt
Reading 0 ... 577699  =      0.000 ...  1155.398 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-003\eeg\sub-003_task-Simon_eeg.fdt
Reading 0 ... 489439  =      0.000 ...   978.878 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-004\eeg\sub-004_task-Simon_eeg.fdt
Reading 0 ... 460929  =      0.000 ...   921.858 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-005\eeg\sub-005_task-Simon_eeg.fdt
Reading 0 ... 426349  =      0.000 ...   852.698 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-006\eeg\sub-006_task-Simon_eeg.fdt
Reading

In [8]:
common_channels = ['FT7', 'Fz', 'P7', 'P8', 'POz', 'F7', 'FT8', 'AF8', 'TP7', 'Cz', 'CP1', 'FC4', 'FC1', 'T7', 'Oz', 'AF7', 'CP6', 'C4', 'C2', 'C3', 'P4', 'F4', 'F6', 'P1', 'TP10', 'P3', 'O2', 'T8', 'CP5', 'Fp1', 'AF3', 'P5', 'P2', 'F2', 'Fp2', 'CPz', 'O1', 'CP3', 'F5', 'FCz', 'PO7', 'TP9', 'PO8', 'P6', 'C1', 'TP8', 'CP4', 'FT10', 'F3', 'F8', 'FC2', 'C6', 'C5', 'CP2', 'AFz', 'FC5', 'AF4', 'FC6', 'F1', 'FC3']

## Data preprocessing and segmentation

In [9]:
feature_path = 'Processed/' + sub_folder_path + '/PD-SIM/Feature'
if not os.path.exists(feature_path):
    os.makedirs(feature_path)
    
label_path = 'Processed/' + sub_folder_path + '/PD-SIM/Label'
if not os.path.exists(label_path):
    os.makedirs(label_path)

In [10]:
def data_preprocessing(
    raw: mne.io.Raw,
    common_channels: list,
    sample_rate: int = 250,
    notch_freq: float = 60.0,
    l_freq: float = 0.5,
    h_freq: float = 40.0,
    do_bad_interp: bool = True,
    verbose: bool = True,
):
    """
    Preprocessing steps ：
      1) choose common channels and reorder
      2) Set Montage 
      3) 60 Hz Notch（before band pass）
      4) bandpass filter（default 0.5–40 Hz）
      5) interpolate bad channels（if do_bad_interp is True）
      6) re-reference to average
      7) ICA（在 1 Hz 高通的副本上拟合，自动剔除眼动/肌电等分量，需 mne-icalabel）
      8) downsample to 250 Hz
    """
    
    # 1. select common channels and reorder to given order
    keep = [ch for ch in common_channels if ch in raw.ch_names]
    raw.pick_channels(keep)
    raw.reorder_channels(keep)
    if verbose:
        print(f"✔ Step 2: Picked common channels ({len(keep)}): {keep}")
        
    # 2. Set Montage
    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
    if verbose:
        print("✔ Step 1, Montage set: 'standard_1020'.")
        
    # 3. Notch（工频）
    if notch_freq is not None:
        raw.notch_filter(freqs=[notch_freq], picks="eeg", verbose=False)
        if verbose:
            print(f"✔ Step 3: Notch @ {notch_freq} Hz")
        
    # 4. Bandpass Filter (0.5–40 Hz)
    raw.filter(l_freq=l_freq, h_freq=h_freq, picks="eeg", verbose=False)
    if verbose:
        print(f"✔ Step 4: Band-pass {l_freq}–{h_freq} Hz")
        
    # 5. Interpolate bad channels
    if do_bad_interp and raw.info.get("bads"):
        raw.interpolate_bads(reset_bads=True, verbose=False)
        if verbose:
            print(f"✔ Step 5: Interpolated bads: {raw.info.get('bads', [])}")
    else:
        if verbose:
            print("ℹ Step 5: No bads to interpolate (set raw.info['bads'] first if needed)")
            
    # 6) Re-reference to average
    raw.set_eeg_reference("average", verbose=False)
    if verbose:
        print("✔ Step 6: Average reference")
    
    print()
    # 7) ICA (fit ICLabel on 1 Hz high-pass filtered copy, then apply to original)
    raw_for_ica = raw.copy().filter(l_freq=1.0, h_freq=None, picks="eeg", verbose=False)
    ica = ICA(n_components=0.999, method="fastica", random_state=97, max_iter="auto")
    ica.fit(raw_for_ica)

    excluded = []
    if label_components is not None:
        try:
            ic_labels = label_components(raw_for_ica, ica, method="iclabel")
            labels = ic_labels["labels"]
            probs = ic_labels["y_pred_proba"]  # (n_comp, n_classes)
            thresholds = {
                "eye blink": 0.7,
                "muscle artifact": 0.6,
                "heart beat": 0.5,
                "line noise": 0.8,
                "channel noise": 0.9,
            }
            for i, lab in enumerate(labels):
                if lab in thresholds:
                    if probs is not None:
                        p = probs[i].max()
                    else:
                        p = 1.0
                    if p >= thresholds[lab]:
                        excluded.append(i)
        except Exception as e:
            if verbose:
                print(f"⚠ ICLabel failed ({e}). Skipping auto exclusion.")
    else:
        if verbose:
            print("ℹ ICLabel not available; fitted ICA but no auto component exclusion.")

    if excluded:
        ica.exclude = sorted(set(excluded))
        raw = ica.apply(raw.copy())
        if verbose:
            print(f"✔ Step 7: ICA applied. Excluded comps: {ica.exclude}")
    else:
        if verbose:
            print("ℹ Step 7: No ICA components excluded.")

    # 8) downsample to 250 Hz
    if raw.info["sfreq"] != sample_rate:
        raw.resample(sample_rate, npad="auto", verbose=False)
    if verbose:
        print(f"✔ Step 8: Resampled to {sample_rate} Hz")
        
    return raw

In [11]:
def epoch_and_make_xy(
    raw: mne.io.Raw,
    events_tsv_path: str,
    beh_tsv_path: str,
    tmin: float = -0.5,
    tmax: float = 1.5,
    baseline=(-0.5, 0),               # e.g, (-0.5, 0)   baseline correction from -0.5 to 0 s
    task_id: int = 1,            # assigned task ID, here are all Simon Conflict tasks, set to 1
    subject_id: int = 1,
    disease_id: int = 1,            # assigned disease ID, e.g, 0 for healthy control, 1 for pakinson's disease
):
    """
    - from events.tsv, take S 1 cue as lock, S 1 is the cue event; S 2 is the response event
    - match with beh.tsv file's trial information, only select accuracy==1 (response correct)
    - segment [-0.5, 1.0] s of S1 cue event, get X: (N, 375(1.5s), C); Y: (N, 5)=[task_id, congruentcondition, accuracy, subject_id, disease_id]
    """
    # read events and select S 1 cue events
    ev = pd.read_csv(events_tsv_path, sep="\t")
    ev = ev[ev["value"].astype(str).str.match(r"^S\s*\d+$")].copy()  # select S 1, S 2, etc.
    ev["code"] = ev["value"].str.extract(r"(\d+)").astype(int)
    s1 = ev[ev["code"] == 1].reset_index(drop=True)

    # read behavior data and match with S1 events
    beh = pd.read_csv(beh_tsv_path, sep="\t")
    print(len(beh), "trials in behavior data")
    print(len(s1), "S1 events in events data")
    n = min(len(beh), len(s1))
    print(f"Using the first {n} trials from behavior data and S1 events to avoid mismatch.")
    beh = beh.iloc[:n].reset_index(drop=True)
    s1 = s1.iloc[:n].reset_index(drop=True)

    # only keep correct responses
    """mask = (beh["accuracy"] == 1).to_numpy()  # accuracy==1 correct, 2 incorrect, 99 no response (I guess)
    beh_ok = beh.iloc[mask]            
    s1_ok  = s1.iloc[mask]"""   # Let's keep all trials for now, and add accuracy to y later

    # from onset second to timestamp under current sampling frequency
    sfreq = raw.info["sfreq"]
    print(f"Current sampling frequency: {sfreq} Hz")
    s1_samples = np.round(s1["onset"].values * sfreq).astype(int)
    # (n_events, 3) array, first column is sample index, second column is 0 (previous event), third column is 1 (current event)
    # later we use cue=1 to epoch
    events = np.c_[s1_samples, np.zeros_like(s1_samples), np.ones_like(s1_samples).astype(int)]

    # segment epoch
    picks = mne.pick_types(raw.info, eeg=True, eog=False, exclude="bads")
    epochs = mne.Epochs(
        raw, events, event_id=dict(cue=1),
        tmin=tmin, tmax=tmax,
        baseline=baseline, picks=picks,
        proj=False, preload=True, reject=None, verbose=False
    )

    # guarantee int((tmax - tmin) * sfreq) timestamps 
    target_len = int((tmax - tmin) * sfreq)
    data = epochs.get_data()  # (N, C, T)
    if data.shape[-1] > target_len:
        data = data[..., :target_len]
    elif data.shape[-1] < target_len:
        pad = target_len - data.shape[-1]
        data = np.pad(data, ((0, 0), (0, 0), (0, pad)), mode="edge")

    # convert to (N, int((tmax - tmin) * sfreq), C)
    X = np.transpose(data, (0, 2, 1))

    # generate to Y: [task_id, congruentcondition, accuracy, subject_id, disease_id]
    congruentcondition = beh["congruentcondition"].values.astype(int)
    accuracy = beh["accuracy"].values.astype(int)  # 1 for correct, 2 for incorrect, 99 for no response
    y = np.column_stack([
        np.full_like(congruentcondition, task_id),
        congruentcondition,
        accuracy,
        np.full_like(congruentcondition, subject_id),
        np.full_like(congruentcondition, disease_id)
    ])
    
    sel = epochs.selection  # it may drop some trials during epoch segmenting due to trial out of bound
    y = y[sel]  # y align with X to avoid mismatch

    return X, y

In [12]:
for sub in os.listdir(root):
    if 'sub-' in sub:
        print("Current subject: ", sub)
        eeg_path = os.path.join(root, sub, 'eeg/')
        beh_path = os.path.join(root, sub, 'beh/')
        # print(sub_path)
        # look for set, events, and beh file, set file is the EEG data, events and beh file is the erp information
        set_file_path, events_file_path, beh_file_path = None, None, None
        for file in os.listdir(eeg_path):
            if '.set' in file:
                set_file_path = os.path.join(eeg_path, file)
            if 'events.tsv' in file:
                events_file_path = os.path.join(eeg_path, file)
        for file in os.listdir(beh_path):
            if 'beh.tsv' in file:
                beh_file_path = os.path.join(beh_path, file)
        if set_file_path is None:
            raise FileNotFoundError(f".set file not found in: {eeg_path}")
        if events_file_path is None:
            raise FileNotFoundError(f"events .tsv not found in: {eeg_path}")
        if beh_file_path is None:
            raise FileNotFoundError(f"beh .tsv not found in: {beh_path}")
        
        # load eeg data and preprocess
        print("Start preprocessing EEG data...")
        raw = mne.io.read_raw_eeglab(set_file_path, preload=True)
        raw = data_preprocessing(raw, common_channels, SAMPLE_RATE, notch_freq=60, l_freq=0.5, h_freq=40, verbose=True)
        
        print()
        
        # segment and make X, y
        subject_id = int((sub.split('-')[-1]).split('.')[0])
        disease_id = participants[participants['participant_id'] == sub]['GROUP'].values[0]
        if disease_id == 'Control':
            disease_id = 0
        elif disease_id == 'PD':
            disease_id = 1
            
        print(f"Subject ID: {subject_id}, Disease ID: {disease_id}")
        print("Start epoching and making X, y...")
        X, y = epoch_and_make_xy(
            raw, events_file_path, beh_file_path,
            tmin=-0.5, tmax=1.0, baseline=(-0.3, -0.2),
            task_id=1, subject_id=subject_id, disease_id=disease_id
        )
        print(f"X shape: {X.shape}, y shape: {y.shape}")  # X: (N, SAMPLE_RATE * (tmax - tmin), C), y: (N, 5), [task_id, congruentcondition, accuracy, subject_id, disease_id]
        # save X, y to npy files
        np.save(feature_path + '/feature_{:03d}.npy'.format(subject_id), X)
        np.save(label_path + '/label_{:03d}.npy'.format(subject_id), y)
        
        print("------------------------------------------------\n")

Current subject:  sub-001
Start preprocessing EEG data...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-SIM\PD-SIM\sub-001\eeg\sub-001_task-Simon_eeg.fdt
Reading 0 ... 476079  =      0.000 ...   952.158 secs...
✔ Step 2: Picked common channels (60): ['FT7', 'Fz', 'P7', 'P8', 'POz', 'F7', 'FT8', 'AF8', 'TP7', 'Cz', 'CP1', 'FC4', 'FC1', 'T7', 'Oz', 'AF7', 'CP6', 'C4', 'C2', 'C3', 'P4', 'F4', 'F6', 'P1', 'TP10', 'P3', 'O2', 'T8', 'CP5', 'Fp1', 'AF3', 'P5', 'P2', 'F2', 'Fp2', 'CPz', 'O1', 'CP3', 'F5', 'FCz', 'PO7', 'TP9', 'PO8', 'P6', 'C1', 'TP8', 'CP4', 'FT10', 'F3', 'F8', 'FC2', 'C6', 'C5', 'CP2', 'AFz', 'FC5', 'AF4', 'FC6', 'F1', 'FC3']
✔ Step 1, Montage set: 'standard_1020'.
✔ Step 3: Notch @ 60 Hz
✔ Step 4: Band-pass 0.5–40 Hz
ℹ Step 5: No bads to interpolate (set raw.info['bads'] first if needed)
✔ Step 6: Average reference

Fitting ICA to data using 60 channels (please be patient, this may take a while)
Selecting by explained variance: 49 components
Fitting ICA

## Load and check the processed data

In [13]:
# Test the saved npy file
# example

total_samples = 0
sub_id = 1
for feature_file, label_file in zip(os.listdir(feature_path), os.listdir(label_path)):
    feature_file_path = os.path.join(feature_path, feature_file)
    label_file_path = os.path.join(label_path, label_file)
    X = np.load(feature_file_path)
    y = np.load(label_file_path)
    print(f"Subject {sub_id}: X shape: {X.shape}, y shape: {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise(f"Subject {sub_id} data and label length mismatch: " 
                f"{X.shape[0]} vs {y.shape[0]}")
    total_samples += np.load(feature_file_path).shape[0]
    sub_id += 1
print("\nTotal number of samples:", total_samples)

Subject 1: X shape: (384, 300, 60), y shape: (384, 5)
Subject 2: X shape: (370, 300, 60), y shape: (370, 5)
Subject 3: X shape: (384, 300, 60), y shape: (384, 5)
Subject 4: X shape: (384, 300, 60), y shape: (384, 5)
Subject 5: X shape: (384, 300, 60), y shape: (384, 5)
Subject 6: X shape: (381, 300, 60), y shape: (381, 5)
Subject 7: X shape: (384, 300, 60), y shape: (384, 5)
Subject 8: X shape: (384, 300, 60), y shape: (384, 5)
Subject 9: X shape: (384, 300, 60), y shape: (384, 5)
Subject 10: X shape: (383, 300, 60), y shape: (383, 5)
Subject 11: X shape: (381, 300, 60), y shape: (381, 5)
Subject 12: X shape: (382, 300, 60), y shape: (382, 5)
Subject 13: X shape: (383, 300, 60), y shape: (383, 5)
Subject 14: X shape: (384, 300, 60), y shape: (384, 5)
Subject 15: X shape: (384, 300, 60), y shape: (384, 5)
Subject 16: X shape: (192, 300, 60), y shape: (192, 5)
Subject 17: X shape: (382, 300, 60), y shape: (382, 5)
Subject 18: X shape: (384, 300, 60), y shape: (384, 5)
Subject 19: X shape