In [1]:
import os
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.utils import shuffle
from scipy.signal import resample
import matplotlib.pyplot as plt
import warnings
import mne
from mne.preprocessing import ICA
warnings.filterwarnings("ignore")
try:
    from mne_icalabel import label_components
except Exception:
    label_components = None

In [2]:
SAMPLE_RATE = 200  # fs
# SAMPLE_LEN = 1.0   # sample seconds
# OVERLAPPING = 0.8  # overlapping seconds
sub_folder_path = str(SAMPLE_RATE) + 'Hz'
sub_folder_path

'200Hz'

In [3]:
feature_path = 'Processed/' + sub_folder_path + '/PD-ODD/Feature'
if not os.path.exists(feature_path):
    os.makedirs(feature_path)
    
label_path = 'Processed/' + sub_folder_path + '/PD-ODD/Label'
if not os.path.exists(label_path):
    os.makedirs(label_path)

In [4]:
root = "PD-ODD/"
participants_path = os.path.join(root, 'participants.tsv')
participants = pd.read_csv(participants_path, sep='\t')
participants

Unnamed: 0,participant_id,GROUP,ID,EEG,BEHAVIOR,AGE,GENDER,MOCA,UPDRS,TYPE
0,sub-001,PD,1004,PD1004,Subject_1001_10-Oct-2017_10.47h,80,M,19,28.0,1
1,sub-002,PD,1014,PD1014,Subject_1002_17-Oct-2017_13.28h,81,M,17,25.0,1
2,sub-003,PD,1024,PD1024,Subject_1003_27-Oct-2017_11.13h,68,F,26,10.0,1
3,sub-004,PD,1034,PD1034,Subject_1004_31-Oct-2017_15.29h,80,M,22,10.0,1
4,sub-005,PD,1044,PD1044,Subject_1005_07-Nov-2017_12.19h,56,M,21,13.0,1
...,...,...,...,...,...,...,...,...,...,...
141,sub-142,Control,1454,Control1454,Subject_2046_03-Feb-2021_12.5h,64,F,27,,0
142,sub-143,Control,1464,Control1464,Subject_2047_03-Mar-2021_14.51h,71,M,30,,0
143,sub-144,Control,1474,Control1474,Subject_2048_04-Mar-2021_11.0h,78,M,27,,0
144,sub-145,Control,1484,Control1484,Subject_2049_08-Mar-2021_15.18h,68,F,27,,0


In [5]:
# Test for bad channels, sampling freq and shape
bad_channel_list, sampling_freq_list, data_shape_list = [], [], []
for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub, 'eeg/')
        # print(sub_path)
        for file in os.listdir(sub_path):
            if '.set' in file:
                file_path = os.path.join(sub_path, file)
                raw = mne.io.read_raw_eeglab(file_path, preload=True)
                # get bad channels
                bad_channel = raw.info['bads']
                bad_channel_list.append(bad_channel)
                # get sampling frequency
                sampling_freq = raw.info['sfreq']
                sampling_freq_list.append(sampling_freq)
                # get eeg data
                data = raw.get_data()
                data_shape = data.shape
                data_shape_list.append(data_shape)

Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-001\eeg\sub-001_task-Oddball_eeg.fdt
Reading 0 ... 408099  =      0.000 ...   816.198 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-002\eeg\sub-002_task-Oddball_eeg.fdt
Reading 0 ... 543579  =      0.000 ...  1087.158 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-003\eeg\sub-003_task-Oddball_eeg.fdt
Reading 0 ... 601269  =      0.000 ...  1202.538 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-004\eeg\sub-004_task-Oddball_eeg.fdt
Reading 0 ... 403759  =      0.000 ...   807.518 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-005\eeg\sub-005_task-Oddball_eeg.fdt
Reading 0 ... 362229  =      0.000 ...   724.458 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-006\eeg\sub-006_task-Oddball_eeg

In [6]:
from collections import Counter

print(bad_channel_list)
print(data_shape_list[0])
print("Channel number counter:", Counter(i[0] for i in data_shape_list))
print("Sampling rate counter:", Counter(sampling_freq_list))

[[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
(63, 408100)
Channel number counter: Counter({63: 116, 64: 29, 66: 1})
Sampling rate counter: Counter({500.0: 146})


In [7]:
common_channels = []
for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub, 'eeg/')
        for file in os.listdir(sub_path):
            if '.set' in file:
                file_path = os.path.join(sub_path, file)
                raw = mne.io.read_raw_eeglab(file_path, preload=True)
                current_channels = set(raw.info['ch_names'])
                if not common_channels:
                    common_channels = current_channels
                else:
                    common_channels &= current_channels
common_channels = list(common_channels)
print(f"Common channels ({len(common_channels)}):", common_channels)

Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-001\eeg\sub-001_task-Oddball_eeg.fdt
Reading 0 ... 408099  =      0.000 ...   816.198 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-002\eeg\sub-002_task-Oddball_eeg.fdt
Reading 0 ... 543579  =      0.000 ...  1087.158 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-003\eeg\sub-003_task-Oddball_eeg.fdt
Reading 0 ... 601269  =      0.000 ...  1202.538 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-004\eeg\sub-004_task-Oddball_eeg.fdt
Reading 0 ... 403759  =      0.000 ...   807.518 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-005\eeg\sub-005_task-Oddball_eeg.fdt
Reading 0 ... 362229  =      0.000 ...   724.458 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-006\eeg\sub-006_task-Oddball_eeg

In [8]:
def data_preprocessing(
    raw: mne.io.Raw,
    common_channels: list,
    sample_rate: int = 250,
    notch_freq: float = 60.0,
    l_freq: float = 0.5,
    h_freq: float = 40.0,
    do_bad_interp: bool = True,
    verbose: bool = True,
):
    """
    Preprocessing steps ：
      1) choose common channels and reorder
      2) Set Montage 
      3) 60 Hz Notch（before band pass）
      4) bandpass filter（default 0.5–40 Hz）
      5) interpolate bad channels（if do_bad_interp is True）
      6) re-reference to average
      7) ICA（在 1 Hz 高通的副本上拟合，自动剔除眼动/肌电等分量，需 mne-icalabel）
      8) downsample to 250 Hz
    """
    
    # 1. select common channels and reorder to given order
    keep = [ch for ch in common_channels if ch in raw.ch_names]
    raw.pick_channels(keep)
    raw.reorder_channels(keep)
    if verbose:
        print(f"✔ Step 2: Picked common channels ({len(keep)}): {keep}")
        
    # 2. Set Montage
    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
    if verbose:
        print("✔ Step 1, Montage set: 'standard_1020'.")
        
    # 3. Notch（工频）
    if notch_freq is not None:
        raw.notch_filter(freqs=[notch_freq], picks="eeg", verbose=False)
        if verbose:
            print(f"✔ Step 3: Notch @ {notch_freq} Hz")
        
    # 4. Bandpass Filter (0.5–40 Hz)
    raw.filter(l_freq=l_freq, h_freq=h_freq, picks="eeg", verbose=False)
    if verbose:
        print(f"✔ Step 4: Band-pass {l_freq}–{h_freq} Hz")
        
    # 5. Interpolate bad channels
    if do_bad_interp and raw.info.get("bads"):
        raw.interpolate_bads(reset_bads=True, verbose=False)
        if verbose:
            print(f"✔ Step 5: Interpolated bads: {raw.info.get('bads', [])}")
    else:
        if verbose:
            print("ℹ Step 5: No bads to interpolate (set raw.info['bads'] first if needed)")
            
    # 6) Re-reference to average
    raw.set_eeg_reference("average", verbose=False)
    if verbose:
        print("✔ Step 6: Average reference")
    
    print()
    # 7) ICA (fit ICLabel on 1 Hz high-pass filtered copy, then apply to original)
    raw_for_ica = raw.copy().filter(l_freq=1.0, h_freq=None, picks="eeg", verbose=False)
    ica = ICA(n_components=0.999, method="fastica", random_state=97, max_iter="auto")
    ica.fit(raw_for_ica)

    excluded = []
    if label_components is not None:
        try:
            ic_labels = label_components(raw_for_ica, ica, method="iclabel")
            labels = ic_labels["labels"]
            probs = ic_labels["y_pred_proba"]  # (n_comp, n_classes)
            thresholds = {
                "eye blink": 0.7,
                "muscle artifact": 0.6,
                "heart beat": 0.5,
                "line noise": 0.8,
                "channel noise": 0.9,
            }
            for i, lab in enumerate(labels):
                if lab in thresholds:
                    if probs is not None:
                        p = probs[i].max()
                    else:
                        p = 1.0
                    if p >= thresholds[lab]:
                        excluded.append(i)
        except Exception as e:
            if verbose:
                print(f"⚠ ICLabel failed ({e}). Skipping auto exclusion.")
    else:
        if verbose:
            print("ℹ ICLabel not available; fitted ICA but no auto component exclusion.")

    if excluded:
        ica.exclude = sorted(set(excluded))
        raw = ica.apply(raw.copy())
        if verbose:
            print(f"✔ Step 7: ICA applied. Excluded comps: {ica.exclude}")
    else:
        if verbose:
            print("ℹ Step 7: No ICA components excluded.")

    # 8) downsample to 250 Hz
    if raw.info["sfreq"] != sample_rate:
        raw.resample(sample_rate, npad="auto", verbose=False)
    if verbose:
        print(f"✔ Step 8: Resampled to {sample_rate} Hz")
        
    return raw

In [9]:
def epoch_and_make_xy(
    raw: mne.io.Raw,
    events_tsv_path: str,
    beh_tsv_path: str,
    tmin: float = -0.5,
    tmax: float = 1.5,
    baseline=(-0.5, 0),               # e.g, (-0.5, 0)   baseline correction from -0.5 to 0 s
    task_id: int = 1,            # assigned task ID, here are all Oddball tasks, set to 1
    subject_id: int = 1,
    disease_id: int = 1,            # assigned disease ID, e.g, 0 for healthy control, 1 for pakinson's disease
):
    """
    Generate EEG segments X and labels y:
    - X: (n_trials, n_times, n_channels)
    - y: [task_id, trial_type, acc, subject_id, disease_id]
          trial_type: 0=standard, 1=auditory oddball, 2=visual oddball
    """
    # Read and filter events
    ev = pd.read_csv(events_tsv_path, sep="\t")
    ev = ev[ev["value"].astype(str).str.match(r"^S\s*\d+$")].copy()
    ev["code"] = ev["value"].str.extract(r"(\d+)").astype(int)
    s2 = ev[ev["code"] == 2].reset_index(drop=True)

    # Read behavioral data and match the number of trials
    beh = pd.read_csv(beh_tsv_path, sep="\t")
    n = min(len(beh), len(s2))
    beh = beh.iloc[:n].reset_index(drop=True)
    s2 = s2.iloc[:n].reset_index(drop=True)

    # Construct mne.Epochs
    sfreq = raw.info["sfreq"]
    s2_samples = np.round(s2["onset"].values * sfreq).astype(int)
    # (n_events, 3) array, first column is sample index, second column is 0 (previous event), third column is 1 (current event)
    # later we use cue=1 to epoch
    events = np.c_[s2_samples, np.zeros_like(s2_samples), np.ones_like(s2_samples)]  

    picks = mne.pick_types(raw.info, eeg=True, eog=False, exclude="bads")
    epochs = mne.Epochs(
        raw, events, event_id=dict(cue=1),
        tmin=tmin, tmax=tmax, baseline=baseline,
        picks=picks, proj=False, preload=True, reject=None, verbose=False
    )

    # Standardize the time length of all epochs and transpose to (n_trials, n_times, n_channels)
    target_len = int((tmax - tmin) * sfreq)
    data = epochs.get_data()
    if data.shape[-1] > target_len:
        data = data[..., :target_len]
    elif data.shape[-1] < target_len:
        pad = target_len - data.shape[-1]
        data = np.pad(data, ((0, 0), (0, 0), (0, pad)), mode="edge")
    X = np.transpose(data, (0, 2, 1))

    
    # Robustly read 'Odd: Audio' / 'Odd: Visual' / 'Odd: Haptic' columns; fill with 0 if missing
    def _safe_get_int(colname):
        return beh[colname].fillna(0).astype(int) if colname in beh.columns else np.zeros(len(beh), dtype=int)

    odd_audio = _safe_get_int("Odd: Audio")
    odd_visual = _safe_get_int("Odd: Visual")
    #odd_haptic = _safe_get_int("Odd: Haptic")  # Not included in labels for now; loaded only for compatibility

    # # Generate odd_type: default 0 = standard
    odd_type = np.zeros(len(beh), dtype=int)
    odd_type[odd_visual == 1] = 2
    odd_type[odd_audio == 1] = 1

    # Accuracy
    acc = beh["Acc"].values.astype(int)

    # Updated y structure: [task_id, odd_type, accuracy, subject_id, disease_id]
    y = np.column_stack([
        np.full_like(odd_type, task_id),   # 0: task_id
        odd_type,                          # 1: odd_type (0/1/2)
        acc,                                 # 2: acc
        np.full_like(odd_type, subject_id),# 3: subject_id
        np.full_like(odd_type, disease_id) # 4: disease_id
    ])
    
    sel = epochs.selection  # it may drop some trials during epoch segmenting due to trial out of bound
    y = y[sel]  # y align with X to avoid mismatch

    return X, y

In [10]:
for sub in os.listdir(root):
    if 'sub-' in sub:
        print("Current subject:", sub)
        eeg_path = os.path.join(root, sub, 'eeg/')
        beh_path = os.path.join(root, sub, 'beh/')
        set_file_path, events_file_path, beh_file_path = None, None, None
        for file in os.listdir(eeg_path):
            if '.set' in file:
                set_file_path = os.path.join(eeg_path, file)
            if 'events.tsv' in file:
                events_file_path = os.path.join(eeg_path, file)
        for file in os.listdir(beh_path):
            if 'beh.tsv' in file:
                beh_file_path = os.path.join(beh_path, file)
        if not set_file_path or not events_file_path or not beh_file_path:
            print(f"❌ Missing required files for {sub}. Skipping.")
            continue


        print("Start preprocessing EEG data...")
        try:
            raw = mne.io.read_raw_eeglab(set_file_path, preload=True)
            raw = data_preprocessing(raw, common_channels, SAMPLE_RATE, notch_freq=60, l_freq=0.5, h_freq=40, verbose=True)
        except Exception as e:
            print(f"⚠ Skipping subject {sub} due to preprocessing error: {e}")
            continue
            
        print()

        subject_id = int(sub.split('-')[-1])
        disease_id = participants[participants['participant_id'] == sub]['GROUP'].values[0]
        if disease_id == 'Control':
            disease_id = 0
        elif disease_id == 'PD':
            disease_id = 1

        print(f"Subject ID: {subject_id}, Disease ID: {disease_id}")
        print("Start epoching and making X, y...")
        X, y = epoch_and_make_xy(
            raw, events_file_path, beh_file_path,
            tmin=-0.5, tmax=1.0, baseline=(-0.3, -0.2),
            task_id=1, subject_id=subject_id, disease_id=disease_id
        )
        # X: (N, SAMPLE_RATE * (tmax - tmin), C), y: (N, 5), [task_id, odd_type, accuracy, subject_id, disease_id]
        print(f"X shape: {X.shape}, y shape: {y.shape}")
        np.save(os.path.join(feature_path, f'feature_{subject_id:03d}.npy'), X)
        np.save(os.path.join(label_path, f'label_{subject_id:03d}.npy'), y)
        
        print("------------------------------------------------\n")

Current subject: sub-001
Start preprocessing EEG data...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\PD-ODD\PD-ODD\sub-001\eeg\sub-001_task-Oddball_eeg.fdt
Reading 0 ... 408099  =      0.000 ...   816.198 secs...
✔ Step 2: Picked common channels (60): ['Fp2', 'P2', 'PO8', 'FC1', 'C1', 'AFz', 'CP1', 'FT7', 'CP4', 'F3', 'AF8', 'Fz', 'C4', 'AF7', 'CPz', 'CP2', 'TP8', 'F8', 'CP3', 'AF3', 'PO7', 'F5', 'O1', 'TP10', 'T8', 'FT8', 'P8', 'F4', 'C6', 'FT10', 'P3', 'P5', 'AF4', 'F2', 'F7', 'P1', 'TP7', 'FC5', 'C5', 'C3', 'FC2', 'C2', 'Oz', 'F6', 'F1', 'O2', 'CP6', 'TP9', 'FC6', 'P4', 'CP5', 'FC3', 'Cz', 'P6', 'FC4', 'FCz', 'P7', 'Fp1', 'POz', 'T7']
✔ Step 1, Montage set: 'standard_1020'.
✔ Step 3: Notch @ 60 Hz
✔ Step 4: Band-pass 0.5–40 Hz
ℹ Step 5: No bads to interpolate (set raw.info['bads'] first if needed)
✔ Step 6: Average reference

Fitting ICA to data using 60 channels (please be patient, this may take a while)
Selecting by explained variance: 48 components
Fitting IC

## Load and check the processed data

In [11]:
# Test the saved npy file
# example

total_samples = 0
sub_id = 1
for feature_file, label_file in zip(os.listdir(feature_path), os.listdir(label_path)):
    feature_file_path = os.path.join(feature_path, feature_file)
    label_file_path = os.path.join(label_path, label_file)
    X = np.load(feature_file_path)
    y = np.load(label_file_path)
    print(f"Subject {sub_id}: X shape: {X.shape}, y shape: {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise(f"Subject {sub_id} data and label length mismatch: " 
                f"{X.shape[0]} vs {y.shape[0]}")
    total_samples += np.load(feature_file_path).shape[0]
    sub_id += 1
print("\nTotal number of samples:", total_samples)

Subject 1: X shape: (236, 300, 60), y shape: (236, 5)
Subject 2: X shape: (238, 300, 60), y shape: (238, 5)
Subject 3: X shape: (240, 300, 60), y shape: (240, 5)
Subject 4: X shape: (240, 300, 60), y shape: (240, 5)
Subject 5: X shape: (240, 300, 60), y shape: (240, 5)
Subject 6: X shape: (240, 300, 60), y shape: (240, 5)
Subject 7: X shape: (240, 300, 60), y shape: (240, 5)
Subject 8: X shape: (240, 300, 60), y shape: (240, 5)
Subject 9: X shape: (240, 300, 60), y shape: (240, 5)
Subject 10: X shape: (229, 300, 60), y shape: (229, 5)
Subject 11: X shape: (237, 300, 60), y shape: (237, 5)
Subject 12: X shape: (234, 300, 60), y shape: (234, 5)
Subject 13: X shape: (231, 300, 60), y shape: (231, 5)
Subject 14: X shape: (238, 300, 60), y shape: (238, 5)
Subject 15: X shape: (238, 300, 60), y shape: (238, 5)
Subject 16: X shape: (240, 300, 60), y shape: (240, 5)
Subject 17: X shape: (240, 300, 60), y shape: (240, 5)
Subject 18: X shape: (239, 300, 60), y shape: (239, 5)
Subject 19: X shape