In [1]:
import os
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.utils import shuffle
from scipy.signal import resample
import matplotlib.pyplot as plt
import warnings
import mne
from mne.preprocessing import ICA
warnings.filterwarnings("ignore")
try:
    from mne_icalabel import label_components
except Exception:
    label_components = None

In [2]:
SAMPLE_RATE = 200  # fs
# SAMPLE_LEN = 1.0   # sample seconds
# OVERLAPPING = 0.8  # overlapping seconds
sub_folder_path = str(SAMPLE_RATE) + 'Hz'
sub_folder_path

'200Hz'

In [3]:
root = "mTBI-ODD/"
participants_path = os.path.join(root, 'participants.tsv')
participants = pd.read_csv(participants_path, sep='\t')
participants

Unnamed: 0,participant_id,Original_ID,URSI,sex,age,Group
0,sub-001,3044,M87102477,1,23,0
1,sub-002,3009,M87104171,0,30,1
2,sub-003,3042,M87104582,0,19,0
3,sub-004,3074,M87106278,0,19,0
4,sub-005,3076,M87107075,0,22,0
...,...,...,...,...,...,...
91,sub-092,86461,M87186461,1,38,2
92,sub-093,86772,M87186772,0,28,2
93,sub-094,93385,M87193385,1,45,2
94,sub-095,94391,M87194391,0,52,2


In [17]:
bad_channel_list, sampling_freq_list, data_shape_list = [], [], []

bad_files = [
    # data loading error
    ("sub-042", "ses-02"),
    ("sub-033", "ses-02"),
    ("sub-012", "ses-02"),
    ("sub-031", "ses-03"),
    
    # no standard target events (might be caused by wrong event name)
    ("sub-040", "ses-01"),
    ("sub-060", "ses-01"),
]

In [4]:
for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub)
        for ses in os.listdir(sub_path):  # Iterate over all ses-xx folders
            ses_path = os.path.join(sub_path, ses, 'eeg')  
            if os.path.exists(ses_path):
                for file in os.listdir(ses_path):
                    if file.endswith('.set'):
                        file_path = os.path.join(ses_path, file)

                        # Skip known problematic files
                        if any(subj in file_path and sess in file_path for subj, sess in bad_files):
                            print(f"Skipping bad file: {file_path}")
                            continue

                        print(f"Reading EEG file: {file_path}")
                        raw = mne.io.read_raw_eeglab(file_path, preload=True)

                        # Get bad channels
                        bad_channel = raw.info['bads']
                        bad_channel_list.append(bad_channel)

                        # Get sampling rate
                        sampling_freq = raw.info['sfreq']
                        sampling_freq_list.append(sampling_freq)

                        # Get data dimensions
                        data = raw.get_data()
                        data_shape = data.shape
                        data_shape_list.append(data_shape)


Reading EEG file: mTBI-ODD/sub-001\ses-01\eeg\sub-001_ses-01_task-ThreeStimAuditoryOddball_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-001\ses-01\eeg\sub-001_ses-01_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 482524  =      0.000 ...   965.048 secs...
Reading EEG file: mTBI-ODD/sub-002\ses-01\eeg\sub-002_ses-01_task-ThreeStimAuditoryOddball_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-002\ses-01\eeg\sub-002_ses-01_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 761649  =      0.000 ...  1523.298 secs...
Reading EEG file: mTBI-ODD/sub-002\ses-02\eeg\sub-002_ses-02_task-ThreeStimAuditoryOddball_eeg.set
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-002\ses-02\eeg\sub-002_ses-02_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 526624  =      0.000 ...  1053.248 secs...
Reading EEG file: mTBI-ODD/sub-003\ses-01\eeg\sub-003_ses-01_task-T

In [5]:
from collections import Counter

print(bad_channel_list)
print(data_shape_list[0])
print("Channel number counter:", Counter(i[0] for i in data_shape_list))
print("Sampling rate counter:", Counter(sampling_freq_list))

[[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
(65, 482525)
Channel number counter: Counter({65: 188, 64: 8})
Sampling rate counter: Counter({500.0: 196})


In [6]:
common_channels = []

for sub in os.listdir(root):
    if 'sub-' in sub:
        sub_path = os.path.join(root, sub)
        for ses in os.listdir(sub_path):  
            ses_path = os.path.join(sub_path, ses, 'eeg')
            if os.path.exists(ses_path):
                for file in os.listdir(ses_path):
                    if '.set' in file:
                        file_path = os.path.join(ses_path, file)

                        # Skip error files
                        if any(subj in file_path and sess in file_path for subj, sess in bad_files):
                            print(f"Skipping bad file: {file_path}")
                            continue

                        raw = mne.io.read_raw_eeglab(file_path, preload=True)
                        current_channels = set(raw.info['ch_names'])
                        if not common_channels:
                            common_channels = current_channels
                        else:
                            common_channels &= current_channels
common_channels = list(common_channels)
print(common_channels)
print("Common channels number: ", len(common_channels))

Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-001\ses-01\eeg\sub-001_ses-01_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 482524  =      0.000 ...   965.048 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-002\ses-01\eeg\sub-002_ses-01_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 761649  =      0.000 ...  1523.298 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-002\ses-02\eeg\sub-002_ses-02_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 526624  =      0.000 ...  1053.248 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-003\ses-01\eeg\sub-003_ses-01_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 481349  =      0.000 ...   962.698 secs...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-004\ses-01\eeg\sub-004_ses-01_task-ThreeStimAuditoryOddball_eeg.f

In [7]:
## Data preprocessing and segmentation
feature_path = 'Processed/' + sub_folder_path + '/mTBI-ODD/Feature'
if not os.path.exists(feature_path):
    os.makedirs(feature_path)

label_path = 'Processed/' + sub_folder_path + '/mTBI-ODD/Label'
if not os.path.exists(label_path):
    os.makedirs(label_path)

In [8]:
def data_preprocessing(
    raw: mne.io.Raw,
    common_channels: list,
    sample_rate: int = 250,
    notch_freq: float = 60.0,
    l_freq: float = 0.5,
    h_freq: float = 40.0,
    do_bad_interp: bool = True,
    verbose: bool = True,
):
    """
    Preprocessing stepsï¼š
      1) Choose common channels and reorder
      2) Set Montage
      3) Notch filter (before bandpass)
      4) Bandpass filter (default 0.1â€“30 Hz)
      5) Interpolate bad channels (if do_bad_interp is True)
      6) Re-reference to average
      7) ICA (on 1 Hz high-pass copy, auto exclude with ICLabel if available)
      8) Downsample to sample_rate
    """

    raw.set_channel_types({'VEOG': 'eog', 'EKG': 'ecg'})
    raw.pick_types(eeg=True)

    # 1. Select common channels
    keep = [ch for ch in common_channels if ch in raw.ch_names]
    raw.pick_channels(keep)
    raw.reorder_channels(keep)
    if verbose:
        print(f"âœ” Step 1: Picked {len(keep)} common EEG channels: {keep}")

    # 2. Set standard montage
    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
    if verbose:
        print("âœ” Step 2: Montage set to 'standard_1020'.")

    # 3. Notch filter
    if notch_freq is not None:
        raw.notch_filter(freqs=[notch_freq], picks="eeg", verbose=False)
        if verbose:
            print(f"âœ” Step 3: Notch filter applied @ {notch_freq} Hz")

    # 4. Bandpass filter
    raw.filter(l_freq=l_freq, h_freq=h_freq, picks="eeg", verbose=False)
    if verbose:
        print(f"âœ” Step 4: Band-pass filtered ({l_freq}â€“{h_freq} Hz)")

    # 5. Interpolate bad channels
    if do_bad_interp and raw.info.get("bads"):
        raw.interpolate_bads(reset_bads=True, verbose=False)
        if verbose:
            print(f"âœ” Step 5: Interpolated bads: {raw.info.get('bads', [])}")
    else:
        if verbose:
            print("â„¹ Step 5: No bads to interpolate")

    # 6. Re-reference to average
    raw.set_eeg_reference("average", verbose=False)
    if verbose:
        print("âœ” Step 6: Average reference applied")

    # 7) ICA (fit ICLabel on 1 Hz high-pass filtered copy, then apply to original)
    raw_for_ica = raw.copy().filter(l_freq=1.0, h_freq=None, picks="eeg", verbose=False)
    ica = ICA(n_components=None, method="fastica", random_state=97, max_iter="auto")
    ica.fit(raw_for_ica)

    excluded = []

    if 'label_components' in globals() and label_components is not None:
        try:
            ic_labels = label_components(raw_for_ica, ica, method="iclabel")
            labels = ic_labels["labels"]
            probs = ic_labels["y_pred_proba"]  # (n_components, n_classes)
            thresholds = {
                "eye blink": 0.7,
                "muscle artifact": 0.6,
                "heart beat": 0.5,
                "line noise": 0.8,
                "channel noise": 0.9,
            }
            for i, lab in enumerate(labels):
                if lab in thresholds:
                    p = probs[i].max() if probs is not None else 1.0
                    if p >= thresholds[lab]:
                        excluded.append(i)
        except Exception as e:
            if verbose:
                print(f"âš  ICLabel failed: {e}")
    else:
        if verbose:
            print("â„¹ ICLabel not available; fitted ICA but no auto exclusion")

    if excluded:
        ica.exclude = sorted(set(excluded))
        raw = ica.apply(raw.copy())
        if verbose:
            print(f"âœ” Step 7: ICA applied. Excluded components: {ica.exclude}")
    else:
        if verbose:
            print("â„¹ Step 7: No ICA components excluded")

    # 8. Resample
    if raw.info["sfreq"] != sample_rate:
        raw.resample(sample_rate, npad="auto", verbose=False)
        if verbose:
            print(f"âœ” Step 8: Resampled to {sample_rate} Hz")

    return raw


In [9]:
def epoch_and_make_xy(
    raw: mne.io.Raw,
    events_tsv_path: str,
    tmin: float = -0.2,
    tmax: float = 0.8,
    baseline=(-0.2, 0),
    task_id: int = 1,
    subject_id: int = 1,
    session_id: int = 1,
    disease_id: int = 1,
):
    """
    Epoching & Constructing X, y for the Oddball Dataset
    - Extract stimulus onsets from events.tsv and align with trials
    - Epoch window: [-0.2, 0.8] seconds
    - baseline: [-0.2, 0] seconds
    - y = [task_id, stimulus_type, subject_id, session_id, disease_id]
    """
    # Read stimulus events from events.tsv
    ev = pd.read_csv(events_tsv_path, sep="\t")

    # Keep only events with event codes starting with 'S' followed by a number
    ev = ev[ev["value"].astype(str).str.match(r"^S\s*\d+$")].copy()
    ev["code"] = ev["value"].str.extract(r"(\d+)").astype(int)

    # Stimulus type mapping
    stim_map = {200: 0, 201: 1, 202: 2}   # S200=target, S201=standard, S202=novel
    ev["stimulus"] = ev["code"].map(stim_map).fillna(3).astype(int)  # Others are categorized as irrelevant
    ev = ev[ev['stimulus'].isin([0, 1, 2])].reset_index(drop=True)   # only use oddball tasks trials

    # onset â†’ sample 
    sfreq = raw.info["sfreq"]
    stim_samples = np.round(ev["onset"].values * sfreq).astype(int)
    events = np.c_[stim_samples,
                   np.zeros_like(stim_samples),
                   np.ones_like(stim_samples).astype(int)]

    # Epoching
    picks = mne.pick_types(raw.info, eeg=True, eog=False, exclude="bads")
    epochs = mne.Epochs(
        raw, events, event_id=dict(cue=1),
        tmin=tmin, tmax=tmax,
        baseline=baseline, picks=picks,
        proj=False, preload=True, reject=None, verbose=False
    )

    # Ensure uniform (epoch) duration
    target_len = int((tmax - tmin) * sfreq)
    data = epochs.get_data()  # (N, C, T)
    if data.shape[-1] > target_len:
        data = data[..., :target_len]
    elif data.shape[-1] < target_len:
        pad = target_len - data.shape[-1]
        data = np.pad(data, ((0, 0), (0, 0), (0, pad)), mode="edge")

    # Reshape to (N, T, C) 
    X = np.transpose(data, (0, 2, 1))
    
    if len(data) == 0:
        print("âš  Warning: No valid epochs found â€” returning None.")
        return None, None
    if len(ev) != len(data):
        print(f"âš  Warning: Event count mismatch â€” ev: {len(ev)}, epochs: {len(data)}")
    
    ev = ev.iloc[:len(data)]  # truncate to match epoch count
    # Create label array y with structure: [task_id, stimulus_type, subject_id, session_id, disease_id]
    stimulus = ev["stimulus"].values.astype(int)
    y = np.column_stack([
        np.full_like(stimulus, task_id),
        stimulus,
        np.full_like(stimulus, subject_id),
        np.full_like(stimulus, session_id),
        np.full_like(stimulus, disease_id),
    ])

    return X, y

In [18]:
# Iterate over subjects/sessions 
for sub in os.listdir(root):
    if 'sub-' in sub:
        subject_id = int((sub.split('-')[-1]).split('.')[0])
        disease_id = participants.loc[participants['participant_id'] == sub, 'Group'].values[0]
        sub_path = os.path.join(root, sub)
        for ses in os.listdir(sub_path):
            if 'ses-' in ses:
                session_id = int((ses.split('-')[-1]).split('.')[0])
                eeg_path = os.path.join(sub_path, ses, 'eeg/')
                if not os.path.exists(eeg_path):
                    continue  # Skip if 'eeg' folder does not exist

                print(f"Current subject: {sub}, session: {ses}")

                # look for set and events file
                set_file_path, events_file_path = None, None
                for file in os.listdir(eeg_path):
                    if file.endswith('.set'):
                        set_file_path = os.path.join(eeg_path, file)
                    if file.endswith('events.tsv'):
                        events_file_path = os.path.join(eeg_path, file)

                if set_file_path is None:
                    raise FileNotFoundError(f".set file not found in: {eeg_path}")
                if events_file_path is None:
                    raise FileNotFoundError(f"events.tsv not found in: {eeg_path}")

                print("Find the EEG file:", set_file_path)
                print("Find the event file:", events_file_path)
                
                # load eeg data and preprocess
                print("Start preprocessing EEG data...")
                if any(subj in set_file_path and sess in set_file_path for subj, sess in bad_files):
                    print(f"Skipping bad file: {set_file_path}")
                    continue

                raw = mne.io.read_raw_eeglab(set_file_path, preload=True)
                raw = data_preprocessing(raw, common_channels, SAMPLE_RATE, notch_freq=60, l_freq=0.5, h_freq=40, verbose=True)

                #  Skip if None is returned
                if raw is None:
                    print(f"âš  Skipping {sub}, {ses}")
                    continue

                print()

                print(f"Subject ID: {subject_id}, Disease ID: {disease_id}")
                print("Start epoching and making X, y...")

                # Epoch and construct X, y (as required by the paper: -0.2 to 0.8 s, baseline -0.2 to 0 s)
                print(raw.get_data().shape)
                X, y = epoch_and_make_xy(
                    raw, events_file_path,
                    tmin=-0.2, tmax=0.8, baseline=(-0.2, 0),
                    task_id=1, subject_id=subject_id, session_id=session_id, disease_id=disease_id
                )

                print(f"X shape: {X.shape}, y shape: {y.shape}")  
                # X: (N, SAMPLE_RATE * (tmax - tmin), C)
                # y: (N, 5) = [task_id, stimulus_type, subject_id, session_id, disease_id]

                # save X, y to npy files
                np.save(feature_path + f'/feature_{subject_id:03d}.npy', X)
                np.save(label_path + f'/label_{subject_id:03d}.npy', y)

                print("------------------------------------------------\n")
                 

Current subject: sub-060, session: ses-01
Find the EEG file: mTBI-ODD/sub-060\ses-01\eeg/sub-060_ses-01_task-ThreeStimAuditoryOddball_eeg.set
Find the event file: mTBI-ODD/sub-060\ses-01\eeg/sub-060_ses-01_task-ThreeStimAuditoryOddball_events.tsv
Start preprocessing EEG data...
Skipping bad file: mTBI-ODD/sub-060\ses-01\eeg/sub-060_ses-01_task-ThreeStimAuditoryOddball_eeg.set
Current subject: sub-060, session: ses-02
Find the EEG file: mTBI-ODD/sub-060\ses-02\eeg/sub-060_ses-02_task-ThreeStimAuditoryOddball_eeg.set
Find the event file: mTBI-ODD/sub-060\ses-02\eeg/sub-060_ses-02_task-ThreeStimAuditoryOddball_events.tsv
Start preprocessing EEG data...
Reading D:\PycharmPorjects\DataProcessingLocalDisk\ERP-Benchmark\mTBI-ODD\mTBI-ODD\sub-060\ses-02\eeg\sub-060_ses-02_task-ThreeStimAuditoryOddball_eeg.fdt
Reading 0 ... 703199  =      0.000 ...  1406.398 secs...
âœ” Step 1: Picked 61 common EEG channels: ['CP2', 'FC1', 'F1', 'C2', 'C3', 'FT9', 'TP8', 'C1', 'FC3', 'Fz', 'AF8', 'C4', 'FCz', '

## Load and check the processed data

In [19]:
# Test the saved npy file
# example
import re

total_samples = 0
for feature_file, label_file in zip(os.listdir(feature_path), os.listdir(label_path)):
    sub_id = int(re.search(r'\d+', feature_file).group())
    feature_file_path = os.path.join(feature_path, feature_file)
    label_file_path = os.path.join(label_path, label_file)
    X = np.load(feature_file_path)
    y = np.load(label_file_path)
    print(f"Subject {sub_id}: X shape: {X.shape}, y shape: {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise(f"Subject {sub_id} data and label length mismatch: " 
                f"{X.shape[0]} vs {y.shape[0]}")
    total_samples += np.load(feature_file_path).shape[0]
    sub_id += 1
print("\nTotal number of trials:", total_samples)

Subject 1: X shape: (260, 200, 61), y shape: (260, 5)
Subject 2: X shape: (260, 200, 61), y shape: (260, 5)
Subject 3: X shape: (260, 200, 61), y shape: (260, 5)
Subject 4: X shape: (260, 200, 61), y shape: (260, 5)
Subject 5: X shape: (260, 200, 61), y shape: (260, 5)
Subject 6: X shape: (260, 200, 61), y shape: (260, 5)
Subject 7: X shape: (260, 200, 61), y shape: (260, 5)
Subject 8: X shape: (260, 200, 61), y shape: (260, 5)
Subject 9: X shape: (260, 200, 61), y shape: (260, 5)
Subject 10: X shape: (260, 200, 61), y shape: (260, 5)
Subject 11: X shape: (260, 200, 61), y shape: (260, 5)
Subject 12: X shape: (260, 200, 61), y shape: (260, 5)
Subject 13: X shape: (260, 200, 61), y shape: (260, 5)
Subject 14: X shape: (260, 200, 61), y shape: (260, 5)
Subject 15: X shape: (260, 200, 61), y shape: (260, 5)
Subject 16: X shape: (260, 200, 61), y shape: (260, 5)
Subject 17: X shape: (260, 200, 61), y shape: (260, 5)
Subject 18: X shape: (260, 200, 61), y shape: (260, 5)
Subject 19: X shape