In [1]:
import os
import numpy as np
import h5py
import pandas as pd
from scipy.signal import resample
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

In [2]:
import mne
from mne.preprocessing import ICA
try:
    from mne_icalabel import label_components
except Exception:
    label_components = None

In [3]:
# Target sampling rate (Hz) after preprocessing; it also names the output sub-folder.
SAMPLE_RATE = 200
sub_folder_path = f"{SAMPLE_RATE}Hz"
sub_folder_path

'200Hz'

In [4]:
# Directory holding the raw Neuroscan .cnt recordings.
# `re` is imported here because later cells rely on it.
import os
import re
root = "ADHD-WMRI/"

for file in os.listdir(root):
    if not file.endswith(".cnt"):
        continue
    # File names look like Subject<ID>_<Control|ADHD>_<Task>_Raw.cnt
    match = re.match(r"Subject(\d+)_(Control|ADHD)_(\w+)_Raw\.cnt", file)
    if match is None:
        continue
    subject_id = int(match.group(1))
    group, task = match.group(2), match.group(3)  # task: NBackTask / GoNogoTask / CombinedTask
    # The regex only admits Control or ADHD, so this covers every matched file.
    disease_id = 0 if group == "Control" else 1
    print(f"Subject {subject_id}, Group={group}, DiseaseID={disease_id}, Task={task}, File={file}")


Subject 10, Group=Control, DiseaseID=0, Task=CombinedTask, File=Subject10_Control_CombinedTask_Raw.cnt
Subject 10, Group=Control, DiseaseID=0, Task=GoNogoTask, File=Subject10_Control_GoNogoTask_Raw.cnt
Subject 10, Group=Control, DiseaseID=0, Task=NBackTask, File=Subject10_Control_NBackTask_Raw.cnt
Subject 11, Group=Control, DiseaseID=0, Task=CombinedTask, File=Subject11_Control_CombinedTask_Raw.cnt
Subject 11, Group=Control, DiseaseID=0, Task=GoNogoTask, File=Subject11_Control_GoNogoTask_Raw.cnt
Subject 11, Group=Control, DiseaseID=0, Task=NBackTask, File=Subject11_Control_NBackTask_Raw.cnt
Subject 12, Group=Control, DiseaseID=0, Task=CombinedTask, File=Subject12_Control_CombinedTask_Raw.cnt
Subject 12, Group=Control, DiseaseID=0, Task=GoNogoTask, File=Subject12_Control_GoNogoTask_Raw.cnt
Subject 12, Group=Control, DiseaseID=0, Task=NBackTask, File=Subject12_Control_NBackTask_Raw.cnt
Subject 13, Group=Control, DiseaseID=0, Task=CombinedTask, File=Subject13_Control_CombinedTask_Raw.cnt


In [5]:
# Collect per-recording metadata (marked bad channels, sampling rate, data shape)
# to check that all recordings are consistent before preprocessing.
# Only header information is needed, so files are opened with preload=False:
# (len(ch_names), n_times) equals raw.get_data().shape without loading the signal.
bad_channel_list, sampling_freq_list, data_shape_list = [], [], []

# os.listdir order is arbitrary; sort so list positions are reproducible.
for fname in sorted(os.listdir(root)):
    if not fname.endswith(".cnt"):
        continue
    file_path = os.path.join(root, fname)
    print(f"🔹 Loading {file_path} ...")

    # Neuroscan CNT header only
    raw = mne.io.read_raw_cnt(file_path, preload=False, verbose=False)

    # Channels already marked bad in the file (copied so later edits to raw don't leak in)
    bad_channel_list.append(list(raw.info['bads']))
    sampling_freq_list.append(raw.info['sfreq'])
    data_shape_list.append((len(raw.ch_names), raw.n_times))



ğŸ”¹ Loading ADHD-WMRI/Subject10_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject10_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject10_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject11_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject11_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject11_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject12_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject12_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject12_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject13_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject13_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject13_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject14_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject14_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject14_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Loading ADHD-WMRI/Subject15_Co

In [6]:
from collections import Counter

# Consistency report on the metadata gathered while scanning the files:
# bad-channel marks per file, one example (channels, samples) shape, and how
# often each channel count / sampling rate occurs across all files.
channel_counts = Counter(shape[0] for shape in data_shape_list)
sampling_rates = Counter(sampling_freq_list)
print(bad_channel_list)
print(data_shape_list[0])
print("Channel number counter:", channel_counts)
print("Sampling rate counter:", sampling_rates)

[[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
(23, 461400)
Channel number counter: Counter({23: 110})
Sampling rate counter: Counter({500.0: 110})


In [7]:
# Keep only channels present in every recording. Equal channel counts do not
# guarantee identical channel names, so intersect the name sets of all files.
channel_sets = []
for file in os.listdir(root):
    if not file.endswith(".cnt"):
        continue
    file_path = os.path.join(root, file)
    print(f"ğŸ”¹ Checking channels in {file_path} ...")
    # Header-only read: channel names are available without loading samples.
    raw = mne.io.read_raw_cnt(file_path, preload=False, verbose=False)
    channel_sets.append(set(raw.info['ch_names']))

# Alphabetical order gives every subject the same deterministic channel layout.
common_channels = sorted(set.intersection(*channel_sets)) if channel_sets else []
print(common_channels)
print("Common channels number: ", len(common_channels))

ğŸ”¹ Checking channels in ADHD-WMRI/Subject10_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject10_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject10_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject11_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject11_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject11_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject12_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject12_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject12_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject13_Control_CombinedTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject13_Control_GoNogoTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject13_Control_NBackTask_Raw.cnt ...
ğŸ”¹ Checking channels in ADHD-WMRI/Subject14_Control_CombinedTask_Raw.c

In [8]:
# Output folders for the per-subject feature / label arrays,
# e.g. Processed/<rate>Hz/ADHD-WMRI/Feature. exist_ok=True keeps the cell
# idempotent on re-run without a separate (racy) exists() check.
feature_path = os.path.join('Processed', sub_folder_path, 'ADHD-WMRI', 'Feature')
label_path = os.path.join('Processed', sub_folder_path, 'ADHD-WMRI', 'Label')
for out_dir in (feature_path, label_path):
    os.makedirs(out_dir, exist_ok=True)

In [9]:
# Minimum ICLabel probability for a component of each artifact class to be excluded.
_DEFAULT_ICLABEL_THRESHOLDS = {
    "eye blink": 0.7,
    "muscle artifact": 0.6,
    "heart beat": 0.5,
    "line noise": 0.8,
    "channel noise": 0.9,
}

# Upper-case labels used in the recordings -> spelling expected by the standard_1020 montage.
_STANDARD_1020_RENAMES = {
    'FP1': 'Fp1',
    'FP2': 'Fp2',
    'FZ': 'Fz',
    'CZ': 'Cz',
    'PZ': 'Pz',
}


def _iclabel_exclusions(raw_for_ica, ica, thresholds, verbose):
    """Return indices of ICA components that ICLabel flags as artifacts.

    A component is selected when its predicted label is a key of ``thresholds`` and its
    predicted probability is at least that threshold. Best-effort: returns [] (reporting
    the reason only when ``verbose``) if mne-icalabel is missing or labelling fails, so a
    failure never leaves a partially built exclusion list.
    """
    if label_components is None:
        if verbose:
            print("ℹ ICLabel not available; fitted ICA but no auto component exclusion.")
        return []
    try:
        ic_labels = label_components(raw_for_ica, ica, method="iclabel")
        labels = ic_labels["labels"]
        # Probability of the predicted label per component; np.max also copes with a
        # per-class 2-D array. NOTE(review): confirm shape against the installed mne-icalabel.
        probs = ic_labels["y_pred_proba"]
        excluded = []
        for idx, label in enumerate(labels):
            if label not in thresholds:
                continue
            p = 1.0 if probs is None else float(np.max(probs[idx]))
            if p >= thresholds[label]:
                excluded.append(idx)
        return excluded
    except Exception as e:
        if verbose:
            print(f"⚠ ICLabel failed ({e}). Skipping auto exclusion.")
        return []


def data_preprocessing(
    raw: mne.io.Raw,
    common_channels: list,
    sample_rate: int = 250,
    notch_freq: float = 50.0,
    l_freq: float = 0.5,
    h_freq: float = 40.0,
    do_bad_interp: bool = True,
    verbose: bool = True,
    iclabel_thresholds: dict = None,
):
    """
    Clean a continuous EEG recording and return it ready for epoching.

    Preprocessing steps:
      1) pick the common channels (in the given order), drop HEOG/VEOG and other non-EEG channels
      2) rename upper-case labels (FP1, FZ, ...) to standard_1020 spelling and set the montage
      3) notch filter at ``notch_freq`` (before the band-pass)
      4) band-pass filter (default 0.5-40 Hz)
      5) interpolate channels listed in raw.info['bads'] (if do_bad_interp is True)
      6) re-reference to the average
      7) ICA fitted on a 1 Hz high-pass copy; components that mne-icalabel classifies as
         artifacts with enough confidence are removed
      8) resample to ``sample_rate`` Hz

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording. Steps 1-6 modify it in place; step 7 may return a copy.
    common_channels : list of str
        Channel names to keep, in the desired order; names missing from raw are skipped.
    sample_rate : int
        Target sampling rate in Hz.
    notch_freq : float or None
        Line-noise frequency to notch out; None skips the notch.
    l_freq, h_freq : float
        Band-pass edges in Hz.
    do_bad_interp : bool
        Interpolate channels marked in raw.info['bads'].
    verbose : bool
        Print one progress line per step.
    iclabel_thresholds : dict or None
        ICLabel class name -> minimum predicted probability for excluding a component of
        that class. None uses ``_DEFAULT_ICLABEL_THRESHOLDS``.

    Returns
    -------
    mne.io.Raw
        The preprocessed recording.
    """
    thresholds = _DEFAULT_ICLABEL_THRESHOLDS if iclabel_thresholds is None else iclabel_thresholds

    # 1) Keep the common channels in the given order (pick() replaces legacy pick_channels()).
    keep = [ch for ch in common_channels if ch in raw.ch_names]
    raw.pick(keep)
    raw.reorder_channels(keep)
    if verbose:
        print(f"✔ Step 1: Picked common channels ({len(keep)}): {keep}")
    # EOG channels are not part of the EEG feature set; drop them and any other non-EEG type.
    raw.drop_channels([ch for ch in ['HEOG', 'VEOG'] if ch in raw.ch_names])
    non_eeg_channels = [
        ch_name
        for ch_name, ch_type in zip(raw.ch_names, raw.get_channel_types(picks=None))
        if ch_type != 'eeg'
    ]
    if non_eeg_channels:
        raw.drop_channels(non_eeg_channels)
        if verbose:
            print(f"✔ Step 1b: Dropped non-EEG channels: {non_eeg_channels}")

    # 2) Rename to standard_1020 spelling so every channel gets a montage position.
    raw.rename_channels({old: new for old, new in _STANDARD_1020_RENAMES.items() if old in raw.ch_names})
    raw.set_montage(mne.channels.make_standard_montage('standard_1020'))
    if verbose:
        print("✔ Step 2: Montage set: 'standard_1020'.")

    # 3) Notch out line noise before band-passing.
    if notch_freq is not None:
        raw.notch_filter(freqs=[notch_freq], picks="eeg", verbose=False)
        if verbose:
            print(f"✔ Step 3: Notch @ {notch_freq} Hz")

    # 4) Band-pass filter.
    raw.filter(l_freq=l_freq, h_freq=h_freq, picks="eeg", verbose=False)
    if verbose:
        print(f"✔ Step 4: Band-pass {l_freq}–{h_freq} Hz")

    # 5) Interpolate bad channels. Record the names first: reset_bads=True clears info['bads'].
    bads = list(raw.info["bads"])
    if do_bad_interp and bads:
        raw.interpolate_bads(reset_bads=True, verbose=False)
        if verbose:
            print(f"✔ Step 5: Interpolated bads: {bads}")
    elif verbose:
        if bads:
            print(f"ℹ Step 5: Bad channels left as-is (do_bad_interp=False): {bads}")
        else:
            print("ℹ Step 5: No bads to interpolate (set raw.info['bads'] first if needed)")

    # 6) Re-reference to the average of the EEG channels.
    raw.set_eeg_reference("average", verbose=False)
    if verbose:
        print("✔ Step 6: Average reference")
        print()

    # 7) ICA: fit on a 1 Hz high-pass copy (slow drifts degrade the decomposition), then
    #    remove the selected components from the original band-passed data.
    #    NOTE(review): mne-icalabel documents ICLabel as expecting extended-infomax ICA on
    #    1-100 Hz, average-referenced data; with fastica on 1-40 Hz data the labels may be
    #    less reliable, and any warning about it is hidden by the notebook's warning filter.
    raw_for_ica = raw.copy().filter(l_freq=1.0, h_freq=None, picks="eeg", verbose=False)
    ica = ICA(n_components=0.999, method="fastica", random_state=97, max_iter="auto")
    ica.fit(raw_for_ica)

    excluded = _iclabel_exclusions(raw_for_ica, ica, thresholds, verbose)
    if excluded:
        ica.exclude = sorted(set(excluded))
        raw = ica.apply(raw.copy())
        if verbose:
            print(f"✔ Step 7: ICA applied. Excluded comps: {ica.exclude}")
    elif verbose:
        print("ℹ Step 7: No ICA components excluded.")

    # 8) Resample to the requested rate.
    if raw.info["sfreq"] != sample_rate:
        raw.resample(sample_rate, npad="auto", verbose=False)
        if verbose:
            print(f"✔ Step 8: Resampled to {sample_rate} Hz")
    elif verbose:
        print(f"ℹ Step 8: Already at {sample_rate} Hz; no resampling needed")

    return raw

In [10]:
def _stimulus_event_codes(event_id_map, stimulus_code):
    """Return the MNE event codes whose annotation description denotes ``stimulus_code``.

    For stimulus_code=2 this matches descriptions such as "2", "S 2", "S  2" and
    "Stimulus/S  2".
    """
    wanted = str(stimulus_code)
    codes = []
    for description, code in event_id_map.items():
        label = description.strip()
        if label.startswith("Stimulus/"):
            label = label[len("Stimulus/"):]
        if label.startswith("S"):
            label = label[1:].strip()
        if label == wanted:
            codes.append(code)
    return codes


def epoch_and_make_xy(
    raw: mne.io.Raw,
    tmin: float = -0.2,
    tmax: float = 0.65,
    baseline=(-0.2, 0),
    task_id: int = 1,
    subject_id: int = 1,
    disease_id: int = 1,
    task_type: str = "CombinedTask",  # "NBackTask", "GoNogoTask", or "CombinedTask"
    stimulus_code: int = 2,
):
    """
    Epoch a preprocessed recording around its stimulus events and build (X, y).

    Stimulus events (annotation description "2" / "S 2" by default) are read from
    raw.annotations, so no external behavioral file is needed.

    Parameters
    ----------
    raw : mne.io.Raw
        Preprocessed recording whose annotations carry the stimulus markers.
    tmin, tmax : float
        Epoch window in seconds relative to stimulus onset.
    baseline : tuple
        Baseline window passed to mne.Epochs.
    task_id, subject_id, disease_id : int
        Values copied into every row of y.
    task_type : str
        "NBackTask", "GoNogoTask" or "CombinedTask"; sets the first label column.
    stimulus_code : int
        Marker number of the stimulus events to epoch on.

    Returns
    -------
    X : ndarray, shape (N, T, C)
        Epoch data with T = round((tmax - tmin) * sfreq).
    y : ndarray of int, shape (N, 5)
        Columns: [task label (0=NBack, 1=GoNogo, 2=Combined), correctness placeholder (1),
        task_id, subject_id, disease_id].

    Raises
    ------
    ValueError
        If task_type is unknown, raw has no annotations, or no stimulus events are found.
    """
    task_labels = {"NBackTask": 0, "GoNogoTask": 1, "CombinedTask": 2}
    if task_type not in task_labels:
        raise ValueError(f"Unsupported task_type: {task_type}")

    if raw.annotations is None or len(raw.annotations) == 0:
        raise ValueError("raw.annotations is empty; unable to extract events from the .cnt file.")
    # events_from_annotations assigns its own integer codes to annotation descriptions
    # (for non-BrainVision files: arbitrary values starting at 1), so the stimulus must be
    # found by description, not by assuming its code equals the marker number.
    events, event_id_map = mne.events_from_annotations(raw)
    codes = _stimulus_event_codes(event_id_map, stimulus_code)
    stimulus_events = events[np.isin(events[:, 2], codes)].copy()
    if len(stimulus_events) == 0:
        raise ValueError(
            f"No stimulus events with code == {stimulus_code} (S {stimulus_code}) were found; "
            f"available annotation descriptions: {sorted(event_id_map)}"
        )
    # One code for all selected events so they form a single epoch condition.
    stimulus_events[:, 2] = stimulus_code

    picks = mne.pick_types(raw.info, eeg=True, eog=False, exclude="bads")
    epochs = mne.Epochs(
        raw,
        events=stimulus_events,
        event_id={"stimulus": stimulus_code},
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        picks=picks,
        proj=False,
        preload=True,
        reject=None,
        verbose=False,
    )

    # Force a fixed epoch length across files. MNE epochs include both tmin and tmax, so they
    # are normally one sample longer than target_len and get trimmed; round() (not int())
    # keeps float error in (tmax - tmin) * sfreq from dropping a further sample.
    target_len = int(round((tmax - tmin) * raw.info["sfreq"]))
    data = epochs.get_data()  # (N, C, T)
    if data.shape[-1] > target_len:
        data = data[..., :target_len]
    elif data.shape[-1] < target_len:
        pad = target_len - data.shape[-1]
        data = np.pad(data, ((0, 0), (0, 0), (0, pad)), mode="edge")

    X = np.transpose(data, (0, 2, 1))  # (N, T, C)

    n_trials = len(X)
    y = np.column_stack([
        np.full(n_trials, task_labels[task_type], dtype=int),
        np.ones(n_trials, dtype=int),  # no behavioral file: every trial marked correct (placeholder)
        np.full(n_trials, task_id),
        np.full(n_trials, subject_id),
        np.full(n_trials, disease_id),
    ])
    return X, y


In [11]:
from collections import defaultdict

# Group the .cnt files by the subject ID in their name (Subject<ID>_...).
# File names are sorted first, so each subject's list (and therefore the order in
# which its tasks are concatenated and saved) does not depend on os.listdir order.
files_by_sub = defaultdict(list)
for fname in sorted(os.listdir(root)):
    if not fname.endswith(".cnt"):
        continue
    m = re.search(r"Subject(\d+)", fname)
    if m is None:
        continue
    files_by_sub[int(m.group(1))].append(fname)
files_by_sub

defaultdict(list,
            {10: ['Subject10_Control_CombinedTask_Raw.cnt',
              'Subject10_Control_GoNogoTask_Raw.cnt',
              'Subject10_Control_NBackTask_Raw.cnt'],
             11: ['Subject11_Control_CombinedTask_Raw.cnt',
              'Subject11_Control_GoNogoTask_Raw.cnt',
              'Subject11_Control_NBackTask_Raw.cnt'],
             12: ['Subject12_Control_CombinedTask_Raw.cnt',
              'Subject12_Control_GoNogoTask_Raw.cnt',
              'Subject12_Control_NBackTask_Raw.cnt'],
             13: ['Subject13_Control_CombinedTask_Raw.cnt',
              'Subject13_Control_GoNogoTask_Raw.cnt',
              'Subject13_Control_NBackTask_Raw.cnt'],
             14: ['Subject14_Control_CombinedTask_Raw.cnt',
              'Subject14_Control_GoNogoTask_Raw.cnt',
              'Subject14_Control_NBackTask_Raw.cnt'],
             15: ['Subject15_Control_CombinedTask_Raw.cnt',
              'Subject15_Control_GoNogoTask_Raw.cnt',
              'Subject15_Con

In [12]:
# Preprocess every recording, epoch it, and save one feature/label pair per subject:
#   feature_XXX.npy: X of shape (N, T, C)
#   label_XXX.npy:   y of shape (N, 5), columns
#                    [task label, correctness placeholder, task_id, subject_id, disease_id]
GROUP_TO_DISEASE_ID = {"Control": 0, "ADHD": 1}
TASK_TO_ID = {"NBackTask": 0, "GoNogoTask": 1, "CombinedTask": 2}

for subject_id, file_list in sorted(files_by_sub.items()):
    X_parts, y_parts = [], []
    print(f"\n=== Subject {subject_id:03d} ===")
    print(file_list)

    for filename in file_list:
        print("Current file:", filename)

        # File names look like Subject<ID>_<Control|ADHD>_<Task>_Raw.cnt
        name_parts = filename.replace(".cnt", "").split("_")
        group = name_parts[1]      # Control or ADHD
        task_type = name_parts[2]  # NBackTask / GoNogoTask / CombinedTask
        # Fail loudly instead of silently labelling an unexpected group as ADHD.
        if group not in GROUP_TO_DISEASE_ID:
            raise ValueError(f"Unexpected group {group!r} in file name {filename!r}")
        disease_id = GROUP_TO_DISEASE_ID[group]
        task_id = TASK_TO_ID[task_type]

        raw_path = os.path.join(root, filename)
        print(f"Subject ID: {subject_id}, Group: {group}, Task: {task_type}")

        raw = mne.io.read_raw_cnt(raw_path, preload=True)
        raw = data_preprocessing(raw, common_channels, SAMPLE_RATE, verbose=True)

        X, y = epoch_and_make_xy(
            raw,
            tmin=-0.2,
            tmax=0.65,
            baseline=(-0.2, 0),
            task_id=task_id,
            subject_id=subject_id,
            disease_id=disease_id,
            task_type=task_type,
        )

        print(f"Task {task_id} trial with X shape: {X.shape}, y shape: {y.shape}")
        X_parts.append(X)
        y_parts.append(y)
        print("----------------\n")

    if X_parts:
        X_all = np.concatenate(X_parts, axis=0)
        y_all = np.concatenate(y_parts, axis=0)
        np.save(os.path.join(feature_path, f"feature_{subject_id:03d}.npy"), X_all)
        np.save(os.path.join(label_path,   f"label_{subject_id:03d}.npy"),   y_all)
        print(f"Saved Subject {subject_id:03d}: X {X_all.shape}, y {y_all.shape}")
    else:
        print(f"(Nothing to save for Subject {subject_id:03d})")

    print("------------------------------------------------\n")


=== Subject 001 ===
['Subject1_Control_CombinedTask_Raw.cnt', 'Subject1_Control_GoNogoTask_Raw.cnt', 'Subject1_Control_NBackTask_Raw.cnt']
Current file: Subject1_Control_CombinedTask_Raw.cnt
Subject ID: 1, Group: Control, Task: CombinedTask
Reading 0 ... 467199  =      0.000 ...   934.398 secs...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
âœ” Step 2: Picked common channels (23): ['C3', 'C4', 'Cz', 'F3', 'F4', 'F7', 'F8', 'FC5', 'FC6', 'FP1', 'FP2', 'FT10', 'FT9', 'Fz', 'HEOG', 'O1', 'O2', 'P3', 'P4', 'P7', 'P8', 'Pz', 'VEOG']
âœ” Step 1, Montage set: 'standard_1020'.
âœ” Step 3: Notch @ 50.0 Hz
âœ” Step 4: Band-pass 0.5â€“40.0 Hz
â„¹ Step 5: No bads to interpolate (set raw.info['bads'] first if needed)
âœ” Step 6: Average reference

Fitting ICA to data using 21 channels (please be patient, this may take a while)
Selecting by explained variance: 20 components
Fitting ICA took 8.7s.
Applying ICA to Raw instance
    Transforming to ICA space (20 compo

In [13]:
# Reload the saved per-subject arrays and check that features and labels line up.
# Label files are paired with feature files by subject ID (label_XXX.npy for
# feature_XXX.npy), not by zipping two os.listdir results whose order is arbitrary.
import re

feature_files = {}
for name in os.listdir(feature_path):
    m = re.fullmatch(r"feature_(\d+)\.npy", name)
    if m:
        feature_files[int(m.group(1))] = name

total_samples = 0
for sub_id in sorted(feature_files):
    X = np.load(os.path.join(feature_path, feature_files[sub_id]))
    y = np.load(os.path.join(label_path, f"label_{sub_id:03d}.npy"))
    print(f"Subject {sub_id}: X shape: {X.shape}, y shape: {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise ValueError(f"Subject {sub_id} data and label length mismatch: "
                         f"{X.shape[0]} vs {y.shape[0]}")
    # Column 3 of the label array holds subject_id; it must match the file's subject.
    if len(y) and not np.all(y[:, 3] == sub_id):
        raise ValueError(f"Subject {sub_id}: label file contains rows for other subjects")
    total_samples += X.shape[0]
print("\nTotal number of trials:", total_samples)

Subject 1: X shape: (536, 170, 21), y shape: (536, 5)
Subject 2: X shape: (536, 170, 21), y shape: (536, 5)
Subject 3: X shape: (538, 170, 21), y shape: (538, 5)
Subject 4: X shape: (532, 170, 21), y shape: (532, 5)
Subject 5: X shape: (533, 170, 21), y shape: (533, 5)
Subject 6: X shape: (529, 170, 21), y shape: (529, 5)
Subject 7: X shape: (537, 170, 21), y shape: (537, 5)
Subject 8: X shape: (528, 170, 21), y shape: (528, 5)
Subject 9: X shape: (537, 170, 21), y shape: (537, 5)
Subject 10: X shape: (539, 170, 21), y shape: (539, 5)
Subject 11: X shape: (530, 170, 21), y shape: (530, 5)
Subject 12: X shape: (535, 170, 21), y shape: (535, 5)
Subject 13: X shape: (537, 170, 21), y shape: (537, 5)
Subject 14: X shape: (538, 170, 21), y shape: (538, 5)
Subject 15: X shape: (536, 170, 21), y shape: (536, 5)
Subject 16: X shape: (540, 170, 21), y shape: (540, 5)
Subject 17: X shape: (536, 170, 21), y shape: (536, 5)
Subject 18: X shape: (539, 170, 21), y shape: (539, 5)
Subject 19: X shape