In [1]:
import os
import numpy as np
import h5py
import pandas as pd
from scipy.signal import resample
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [2]:
import mne
from mne.preprocessing import ICA
try:
    from mne_icalabel import label_components
except Exception:
    label_components = None

In [3]:
SAMPLE_RATE = 200  # fs
# SAMPLE_LEN = 1.0   # sample seconds
# OVERLAPPING = 0.8  # overlapping seconds
sub_folder_path = str(SAMPLE_RATE) + 'Hz'
sub_folder_path

'200Hz'

In [4]:
## Load participants.tsv file
# root dir
root = "RLPD/"
participants_path = os.path.join(root, "participants.tsv")
participants = pd.read_csv(participants_path, sep="\t")
participants


Unnamed: 0,participant_id,Original_ID,Group,sess1_Med,sess2_Med,sex,age
0,sub-001,8010,CTL,,no s2,Female,61
1,sub-002,801,PD,ON,OFF,Female,60
2,sub-003,802,PD,OFF,ON,Male,75
3,sub-004,803,PD,OFF,ON,Female,76
4,sub-005,804,PD,ON,OFF,Male,75
5,sub-006,805,PD,ON,OFF,Male,79
6,sub-007,8060,CTL,,no s2,Female,83
7,sub-008,806,PD,OFF,ON,Female,79
8,sub-009,8070,CTL,,no s2,Female,67
9,sub-010,807,PD,OFF,ON,Female,72


In [5]:
# Test for bad channels, sampling freq and shape
bad_channel_list, sampling_freq_list, data_shape_list = [], [], []
for sub in os.listdir(root):
    if sub.startswith("sub-"):   # Confirm it is a subject directory
        sub_path = os.path.join(root, sub)
        for ses in os.listdir(sub_path):
            ses_path = os.path.join(sub_path, ses, "eeg")
            if os.path.exists(ses_path):
                for file in os.listdir(ses_path):
                    if file.endswith(".set"):
                        file_path = os.path.join(ses_path, file)
                        print(f"Reading in progress: {file_path}")
                        
                        try:
                            # Try loading raw first
                            raw = mne.io.read_raw_eeglab(file_path, preload=True)
                            bad_channel_list.append(raw.info['bads'])
                            sampling_freq_list.append(raw.info['sfreq'])
                            data_shape_list.append(raw.get_data().shape)
                        except Exception as e1:
                            print(f"  -> If raw fails, try epochs: {e1}")
                            try:
                                epochs = mne.read_epochs_eeglab(file_path)
                                epochs.load_data()
                                bad_channel_list.append(epochs.info['bads'])
                                sampling_freq_list.append(epochs.info['sfreq'])
                                data_shape_list.append(epochs.get_data().shape)
                            except Exception as e2:
                                print(f"  -> Failed to load epochs as well: {e2}")

Reading in progress: RLPD/sub-001\ses-01\eeg\sub-001_ses-01_task-ReinforcementLearning_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-001\ses-01\eeg\sub-001_ses-01_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 704899  =      0.000 ...  1409.798 secs...
Reading in progress: RLPD/sub-002\ses-01\eeg\sub-002_ses-01_task-ReinforcementLearning_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-002\ses-01\eeg\sub-002_ses-01_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 960124  =      0.000 ...  1920.248 secs...
Reading in progress: RLPD/sub-002\ses-02\eeg\sub-002_ses-02_task-ReinforcementLearning_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-002\ses-02\eeg\sub-002_ses-02_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 312124  =      0.000 ...   624.248 secs...
Reading in progress: RLPD/sub-003\ses-01\eeg\sub-003_ses-01_task-ReinforcementLearning_eeg.set
Reading D:\Pychar

In [6]:
from collections import Counter

print(bad_channel_list)
print(data_shape_list[0])
print("Channel number counter:", Counter(i[0] for i in data_shape_list))
print("Sampling rate counter:", Counter(sampling_freq_list))

[[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
(67, 704900)
Channel number counter: Counter({67: 84})
Sampling rate counter: Counter({500.0: 84})


In [7]:
common_channels = []
for sub in os.listdir(root):
    if sub.startswith("sub-"):   # Only process subject directories
        sub_path = os.path.join(root, sub)
        for ses in os.listdir(sub_path):  # Iterate over ses-xx directories
            ses_path = os.path.join(sub_path, ses, "eeg")
            if os.path.exists(ses_path):
                for file in os.listdir(ses_path):
                    if file.endswith(".set"):
                        file_path = os.path.join(ses_path, file)
                        print(f"Reading in progress: {file_path}")
                        
                        # Try loading raw data first, fallback to epochs if raw fails
                        try:
                            raw = mne.io.read_raw_eeglab(file_path, preload=True)
                            current_channels = set(raw.info['ch_names'])
                        except Exception as e1:
                            print(f"  -> Raw read failed, fallback to epochs: {e1}")
                            try:
                                epochs = mne.read_epochs_eeglab(file_path)
                                epochs.load_data()
                                current_channels = set(epochs.info['ch_names'])
                            except Exception as e2:
                                print(f"  -> Failed to read epochs as well: {e2}")
                                continue  # Skip this file
                        
                        # Find common channels
                        if not common_channels:
                            common_channels = current_channels
                        else:
                            common_channels &= current_channels

common_channels = list(common_channels)
print(f"Common channels ({len(common_channels)}):", common_channels)

Reading in progress: RLPD/sub-001\ses-01\eeg\sub-001_ses-01_task-ReinforcementLearning_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-001\ses-01\eeg\sub-001_ses-01_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 704899  =      0.000 ...  1409.798 secs...
Reading in progress: RLPD/sub-002\ses-01\eeg\sub-002_ses-01_task-ReinforcementLearning_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-002\ses-01\eeg\sub-002_ses-01_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 960124  =      0.000 ...  1920.248 secs...
Reading in progress: RLPD/sub-002\ses-02\eeg\sub-002_ses-02_task-ReinforcementLearning_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-002\ses-02\eeg\sub-002_ses-02_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 312124  =      0.000 ...   624.248 secs...
Reading in progress: RLPD/sub-003\ses-01\eeg\sub-003_ses-01_task-ReinforcementLearning_eeg.set
Reading D:\Pychar

In [8]:
## Data preprocessing and segmentation
feature_path = 'Processed/' + sub_folder_path + '/RLPD/Feature'
if not os.path.exists(feature_path):
    os.makedirs(feature_path)

label_path = 'Processed/' + sub_folder_path + '/RLPD/Label'
if not os.path.exists(label_path):
    os.makedirs(label_path)

In [9]:
def data_preprocessing(
        input_data,
        common_channels: list,
        sample_rate: int = 250,
        notch_freq: float = 60.0,
        l_freq: float = 0.1,
        h_freq: float = 20.0,
        do_bad_interp: bool = True,
        verbose: bool = True,
):
    """
    Simon Conflict Task EEG Preprocessing:
      1) Select common channels, remove unreliable ones (FT9, FT10, TP9, TP10), reorder
      2) Set Montage
      3) 60 Hz Notch filter
      4) Bandpass filter (0.1–20 Hz)
      5) Interpolate bad channels (if enabled)
      6) Re-reference to average
      7) ICA removal (fit on 1 Hz copy + ICLabel)
      8) Downsample to 250 Hz (if needed)
    """

    raw = input_data
    # 1. Remove unreliable electrodes
    bad_electrodes = {"FT9", "FT10", "TP9", "TP10"}
    keep = [ch for ch in common_channels if ch in raw.ch_names and ch not in bad_electrodes]
    raw.pick_channels(keep)
    raw.reorder_channels(keep)
    if verbose:
        print(f"✔ Step 1: Picked & reordered channels (excluding {bad_electrodes}): {keep}")

    # 2. Set Montage
    raw.drop_channels(['X', 'Y', 'Z', 'VEOG'])

    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
    if verbose:
        print("✔ Step 2: Montage set to 'standard_1020'.")

    # 3. Notch filter (60 Hz)
    if notch_freq is not None:
        raw.notch_filter(freqs=[notch_freq], picks="eeg", verbose=False)
        if verbose:
            print(f"✔ Step 3: Notch filtered at {notch_freq} Hz")

    # 4. Bandpass filter (0.1–20 Hz)
    raw.filter(l_freq=l_freq, h_freq=h_freq, picks="eeg", verbose=False)
    if verbose:
        print(f"✔ Step 4: Band-pass filtered from {l_freq}–{h_freq} Hz")

    # 5. Interpolate bad channels
    if do_bad_interp and raw.info.get("bads"):
        raw.interpolate_bads(reset_bads=True, verbose=False)
        if verbose:
            print(f"✔ Step 5: Interpolated bad channels: {raw.info.get('bads', [])}")
    else:
        if verbose:
            print("ℹ Step 5: No bad channels to interpolate.")

    # 6. Re-reference to average
    raw.set_eeg_reference("average", verbose=False)
    if verbose:
        print("✔ Step 6: EEG re-referenced to average")

    # 7. ICA
    raw_for_ica = raw.copy().filter(l_freq=1.0, h_freq=None, picks="eeg", verbose=False)
    ica = ICA(n_components=None, method="fastica", random_state=97, max_iter="auto")

    ica.fit(raw_for_ica)

    excluded = []
    if label_components is not None:
        try:
            ic_labels = label_components(raw_for_ica, ica, method="iclabel")
            labels = ic_labels["labels"]
            probs = ic_labels["y_pred_proba"]
            thresholds = {
                "eye blink": 0.7,
                "muscle artifact": 0.6,
                "heart beat": 0.5,
                "line noise": 0.8,
                "channel noise": 0.9,
            }
            for i, lab in enumerate(labels):
                if lab in thresholds:
                    p = probs[i].max() if probs is not None else 1.0
                    if p >= thresholds[lab]:
                        excluded.append(i)
        except Exception as e:
            if verbose:
                print(f"⚠ ICLabel failed: {e}")
    else:
        if verbose:
            print("ℹ ICLabel not available; ICA fitted but no auto exclusion")

    if excluded:
        ica.exclude = sorted(set(excluded))
        raw = ica.apply(raw.copy())
        if verbose:
            print(f"✔ Step 7: ICA applied. Excluded components: {ica.exclude}")
    else:
        if verbose:
            print("ℹ Step 7: No ICA components excluded.")

    # 8. Downsample (if not already 250 Hz)
    if raw.info["sfreq"] != sample_rate:
        raw.resample(sample_rate, npad="auto", verbose=False)
        if verbose:
            print(f"✔ Step 8: Resampled to {sample_rate} Hz")

    return raw

In [10]:
def epoch_and_make_xy(
    raw: mne.io.Raw,
    events_tsv_path: str,
    tmin: float = -0.5,
    tmax: float = 1.0,
    baseline=(-0.3, -0.2),  # Change 1: Update baseline setting
    task_id: int = 1,
    subject_id: int = 1,
    disease_id: int = 1,
):
    """
    Epoch Simon Conflict data using cue-lock strategy.
    - Use 'Test Stim: ...' as cue-lock events.
    - Extract accuracy from events.tsv directly. (congruent info is unclear in Test Stim)
    """

    # Extract stimulus events (cues) from events.tsv
    ev = pd.read_csv(events_tsv_path, sep="\t")
    ev = ev.reset_index(drop=True)

    stim_mask = ev["trial_type"].astype(str).str.contains("Button", na=False)
    test_response_ev = ev[stim_mask].reset_index()

    # Construct response-locked events using onset times from events.tsv
    sfreq = raw.info["sfreq"]
    print(f"Current sampling frequency: {sfreq} Hz")
    s1_samples = np.round(test_response_ev["onset"].values * sfreq).astype(int)
    events = np.c_[s1_samples, np.zeros_like(s1_samples), np.ones_like(s1_samples)]

    # Construct epochs
    picks = mne.pick_types(raw.info, eeg=True, eog=False, exclude="bads")
    epochs = mne.Epochs(
        raw, events, event_id=dict(cue=1),
        tmin=tmin, tmax=tmax,
        baseline=baseline, picks=picks,
        proj=False, preload=True, reject=None, verbose=False
    )

    # Ensure shape is correct: (N, T, C)
    target_len = int((tmax - tmin) * sfreq)
    data = epochs.get_data()  # (N, C, T)
    if data.shape[-1] > target_len:
        data = data[..., :target_len]
    elif data.shape[-1] < target_len:
        pad = target_len - data.shape[-1]
        data = np.pad(data, ((0, 0), (0, 0), (0, pad)), mode="edge")
    X = np.transpose(data, (0, 2, 1))  # → (N, T, C)

    # Modification 2: derive y from events.tsv
    # [task_id, subject_id, disease_id]
    y = np.column_stack([
        np.full_like(s1_samples, task_id),
        np.full_like(s1_samples, subject_id),
        np.full_like(s1_samples, disease_id)
    ])

    return X, y

In [11]:
for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub)
        for ses in os.listdir(sub_path):
            if 'ses-' in ses:
                print(f"Current subject/session: {sub} / {ses}")
                eeg_path = os.path.join(sub_path, ses, 'eeg/')

                # Initialize file paths
                set_file_path, events_file_path = None, None

                # Traverse the eeg folder to find files
                for file in os.listdir(eeg_path):
                    if file.endswith('.set'):
                        set_file_path = os.path.join(eeg_path, file)
                    if 'events.tsv' in file:
                        events_file_path = os.path.join(eeg_path, file)

                # Error checking
                if set_file_path is None:
                    raise FileNotFoundError(f".set file not found in: {eeg_path}")
                if events_file_path is None:
                    raise FileNotFoundError(f"events.tsv not found in: {eeg_path}")

                # Load EEG data and preprocess
                print("Start preprocessing EEG data...")
                raw = mne.io.read_raw_eeglab(set_file_path, preload=True)
                raw = data_preprocessing(raw, common_channels, SAMPLE_RATE, notch_freq=60, l_freq=0.5, h_freq=40, verbose=True)
                print()


                # # Extract subject_id and disease_id

                subject_id = int(sub.split('-')[-1])
                disease_id = participants[participants['participant_id'] == sub]['Group'].values[0]
                if disease_id == 'CTL':
                    disease_id = 0
                elif disease_id == 'PD':
                    disease_id = 1

                print(f"Subject ID: {subject_id}, Disease ID: {disease_id}")
                print("Start epoching and making X, y...")

                # Remove beh_file_path, keep other parameters unchanged
                X, y = epoch_and_make_xy(
                    raw, events_file_path,
                    tmin=-2.0, tmax=1.0, baseline=(-0.2, 0.0),   # baseline correction according to paper, tmim tmas was observed from the data
                    task_id=1, subject_id=subject_id, disease_id=disease_id
                )

                print(f"X shape: {X.shape}, y shape: {y.shape}")
                np.save(os.path.join(feature_path, 'feature_{:03d}.npy'.format(subject_id)), X)
                np.save(os.path.join(label_path, 'label_{:03d}.npy'.format(subject_id)), y)

                print("------------------------------------------------\n")

Current subject/session: sub-001 / ses-01
Start preprocessing EEG data...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\RLPD\RLPD\sub-001\ses-01\eeg\sub-001_ses-01_task-ReinforcementLearning_eeg.fdt
Reading 0 ... 704899  =      0.000 ...  1409.798 secs...
✔ Step 1: Picked & reordered channels (excluding {'TP9', 'TP10', 'FT10', 'FT9'}): ['AF7', 'P4', 'F8', 'X', 'CP5', 'F1', 'AFz', 'Fz', 'CP3', 'O2', 'PO4', 'FC6', 'F5', 'PO8', 'AF3', 'FC1', 'FC2', 'FC3', 'O1', 'CP6', 'Fp1', 'F4', 'P5', 'PO3', 'F2', 'FC4', 'Y', 'P7', 'Pz', 'P8', 'FT7', 'P1', 'T7', 'P6', 'CP4', 'F6', 'VEOG', 'C3', 'C4', 'PO7', 'C6', 'CP2', 'P2', 'C5', 'POz', 'FT8', 'FCz', 'Oz', 'Fp2', 'F3', 'AF4', 'FC5', 'C2', 'Cz', 'AF8', 'TP8', 'Z', 'T8', 'TP7', 'F7', 'CP1', 'P3', 'C1']
✔ Step 2: Montage set to 'standard_1020'.
✔ Step 3: Notch filtered at 60 Hz
✔ Step 4: Band-pass filtered from 0.5–40 Hz
ℹ Step 5: No bad channels to interpolate.
✔ Step 6: EEG re-referenced to average
Fitting ICA to data using 59 channe

## Load and check the processed data

In [12]:
# Test the saved npy file
# example

total_samples = 0
sub_id = 1
for feature_file, label_file in zip(os.listdir(feature_path), os.listdir(label_path)):
    feature_file_path = os.path.join(feature_path, feature_file)
    label_file_path = os.path.join(label_path, label_file)
    X = np.load(feature_file_path)
    y = np.load(label_file_path)
    print(f"Subject {sub_id}: X shape: {X.shape}, y shape: {y.shape}")
    if X.shape[0] != y.shape[0]:
        print(f"Subject {sub_id} data and label length mismatch: " 
                f"{X.shape[0]} vs {y.shape[0]}")
    total_samples += np.load(feature_file_path).shape[0]
    sub_id += 1
print("\nTotal number of samples:", total_samples)

Subject 1: X shape: (252, 600, 59), y shape: (252, 3)
Subject 2: X shape: (86, 600, 59), y shape: (86, 3)
Subject 3: X shape: (262, 600, 59), y shape: (262, 3)
Subject 4: X shape: (254, 600, 59), y shape: (254, 3)
Subject 5: X shape: (254, 600, 59), y shape: (254, 3)
Subject 6: X shape: (256, 600, 59), y shape: (256, 3)
Subject 7: X shape: (257, 600, 59), y shape: (257, 3)
Subject 8: X shape: (256, 600, 59), y shape: (256, 3)
Subject 9: X shape: (255, 600, 59), y shape: (255, 3)
Subject 10: X shape: (255, 600, 59), y shape: (255, 3)
Subject 11: X shape: (334, 600, 59), y shape: (334, 3)
Subject 12: X shape: (254, 600, 59), y shape: (254, 3)
Subject 13: X shape: (255, 600, 59), y shape: (255, 3)
Subject 14: X shape: (254, 600, 59), y shape: (254, 3)
Subject 15: X shape: (254, 600, 59), y shape: (254, 3)
Subject 16: X shape: (255, 600, 59), y shape: (255, 3)
Subject 17: X shape: (258, 600, 59), y shape: (258, 3)
Subject 18: X shape: (254, 600, 59), y shape: (254, 3)
Subject 19: X shape: 