In [None]:
# Preprocess oddball (robust handling of events.tsv onset units)
import os
import re
import traceback

import mne
import numpy as np
import pandas as pd
from tqdm import tqdm

# ========== CONFIG ==========
# Input folder holding the BrainVision recordings + events.tsv files, and the
# subject ids to process (sub-01 .. sub-10).
BASE_DIR = r"E:\UNIVERSITY\neurouScience\btl-EEG\data\oddball_files"
SUBS = [f"sub-{i:02d}" for i in range(1, 11)]

# Outputs go to a sub-folder of BASE_DIR, created up front if missing.
OUT_DIR = os.path.join(BASE_DIR, "preprocessed_mne")
os.makedirs(OUT_DIR, exist_ok=True)

# Preprocessing parameters
HP, LP = 1.0, 40.0          # band-pass edges in Hz
NOTCH = [50]                # frequencies (Hz) to notch-filter
SF_RESAMPLE = 250           # target sampling rate in Hz, or None to skip resampling
EPOCH_TMIN = -0.2           # epoch start relative to the event (s)
EPOCH_TMAX = 0.8            # epoch end relative to the event (s)
BASELINE = (EPOCH_TMIN, 0)  # baseline window = the whole pre-event interval
REJECT = {"eeg": 150e-6}    # peak-to-peak rejection at 150 µV; set None to disable
ICA_N_COMPONENTS = 0.95     # ICA explained-variance fraction; set None to skip ICA

# Event mapping: numeric marker codes (as found in events.tsv) -> condition names.
# According to the dataset's events.json:
#   "S  5" -> standard (frequent)
#   "S  6" -> target
#   "S  7" -> deviant / distractor
EVENT_CODE_TO_NAME = {5: 'standard', 6: 'target', 7: 'distractor'}

# Helper: turn a marker label such as "S  5" into its integer code.
def parse_event_code(s):
    """Return the first run of digits in ``s`` as an int.

    Example: "S  5" -> 5. Missing values (None/NaN) and labels that contain
    no digits yield None.
    """
    if pd.isna(s):
        return None
    found = re.search(r'(\d+)', str(s))
    if found is None:
        return None
    return int(found.group(1))

# Helper: build the MNE events array for one subject from its events.tsv.
def _events_from_tsv(events_tsv_path, raw, orig_sfreq, duration_sec):
    """Build an MNE events array ``(sample, 0, code)`` from a BIDS events.tsv.

    When a ``trial_type`` column exists, only rows whose value is "stimulus"
    (case-insensitive) are kept. Onsets may be stored either in seconds or as
    sample indices at the ORIGINAL sampling rate ``orig_sfreq``. Both forms
    are converted to sample indices at ``raw``'s CURRENT (possibly resampled)
    rate and offset by ``raw.first_samp``, which is how MNE numbers event
    samples. Rows with no numeric onset or no parsable event code, and events
    outside the recording, are dropped with a warning. The result is sorted
    by sample.

    Raises:
        RuntimeError: if the file has no ``event_type`` column or contains no
            numeric onset values.
    """
    ev = pd.read_csv(events_tsv_path, sep='\t', dtype=str)
    # Keep only stimulus trials (per the dataset's events.json)
    if 'trial_type' in ev.columns:
        ev = ev[ev['trial_type'].str.lower() == 'stimulus']
    if 'event_type' not in ev.columns:
        raise RuntimeError("events.tsv missing 'event_type' column")
    # Copy so the new columns are not assigned onto a view of the filtered frame.
    ev = ev.copy()
    ev['event_code'] = ev['event_type'].apply(parse_event_code)
    ev['onset_raw'] = pd.to_numeric(ev['onset'], errors='coerce')

    max_onset = ev['onset_raw'].max()
    if np.isnan(max_onset):
        raise RuntimeError("No numeric onset values found in events.tsv")

    # Heuristic: onsets well beyond the recording length (in seconds) can only
    # be sample indices.
    if max_onset > duration_sec * 1.2:
        print("Detected: onset values in EVENTS.TSV are in SAMPLES (converting to sample indices).")
        # These indices refer to the original sampling rate. Go through
        # seconds so they stay aligned after `raw` has been resampled.
        onset_sec = ev['onset_raw'] / orig_sfreq
    else:
        print("Detected: onset values in EVENTS.TSV are in SECONDS (converting to samples).")
        onset_sec = ev['onset_raw']

    # Rows that cannot become an MNE event; they would crash the int casts below.
    usable = onset_sec.notna() & ev['event_code'].notna()
    if not usable.all():
        print(f"[WARN] {int((~usable).sum())} events without numeric onset or event code -> removed")
    onset_samples = np.round(
        onset_sec[usable].to_numpy(dtype=float) * raw.info['sfreq']
    ).astype(int)
    event_codes = ev.loc[usable, 'event_code'].astype(int).to_numpy()

    # Remove events falling outside raw sample range (safety)
    in_bounds = (onset_samples >= 0) & (onset_samples < raw.n_times)
    if not in_bounds.all():
        print(f"[WARN] {int((~in_bounds).sum())} events outside recording bounds -> removed")
    onset_samples = onset_samples[in_bounds] + raw.first_samp
    event_codes = event_codes[in_bounds]

    events = np.column_stack((onset_samples, np.zeros_like(onset_samples), event_codes))
    # MNE expects events in chronological order.
    return events[np.argsort(events[:, 0], kind='stable')]


# Helper: ICA-based ocular artifact removal.
def _remove_eog_ica(raw, epochs):
    """Fit ICA on a 1-40 Hz copy of ``raw`` and remove EOG-like components from ``epochs`` in place.

    EOG channels are used as reference signals when present. Otherwise
    whichever of Fp1/Fp2/Fpz exist serve as proxies. If neither exists, ICA
    is still applied but no components are excluded.
    """
    ica = mne.preprocessing.ICA(n_components=ICA_N_COMPONENTS, random_state=97, max_iter='auto')
    # Fit on an explicitly 1-40 Hz filtered copy. This keeps the 1 Hz
    # high-pass that ICA benefits from even if HP above is lowered.
    ica.fit(raw.copy().filter(1., 40., picks='eeg'), picks='eeg')
    eog_picks = mne.pick_types(raw.info, eeg=False, eog=True)
    if len(eog_picks) > 0:
        ref_chs = [raw.ch_names[i] for i in eog_picks]
    else:
        ref_chs = [ch for ch in ['Fp1', 'Fp2', 'Fpz'] if ch in raw.ch_names]
    if ref_chs:
        eog_inds, _ = ica.find_bads_eog(raw, ch_name=ref_chs)
        ica.exclude = sorted(set(ica.exclude) | set(eog_inds))
    ica.apply(epochs)


# Preprocess one subject
def preprocess_subject(vhdr_path, events_tsv_path, out_dir=OUT_DIR):
    """Preprocess one subject's oddball recording and save the results.

    Steps: notch and band-pass filtering (EEG channels), average reference,
    optional resampling to SF_RESAMPLE, epoching on the stimulus events from
    ``events_tsv_path``, optional ICA ocular correction, and then optional
    peak-to-peak epoch rejection. The following are written to
    ``out_dir/<sub>``:
    ``<stem>_preproc_raw.fif``, ``<stem>_preproc-epo.fif`` and
    ``preproc_summary.csv``.

    Args:
        vhdr_path: BrainVision header. The subject id is the part of its file
            name before the first underscore (e.g. "sub-01").
        events_tsv_path: the matching events.tsv.
        out_dir: root output directory.

    Returns:
        dict summary of the run (also written as preproc_summary.csv).

    Raises:
        RuntimeError: if events.tsv is malformed, or if none of the codes in
            EVENT_CODE_TO_NAME occur in it.
    """
    file_name = os.path.basename(vhdr_path)
    sub = file_name.split('_')[0]  # e.g. sub-01
    stem = os.path.splitext(file_name)[0]
    print(f"--- Processing {sub} ---")
    raw = mne.io.read_raw_brainvision(vhdr_path, preload=True, verbose=False)
    orig_sfreq = raw.info['sfreq']
    duration_sec = raw.times[-1]
    print(f"sfreq={orig_sfreq} Hz, duration={duration_sec:.1f}s, nchan={raw.info['nchan']}")

    # Basic filters & ref
    raw.notch_filter(NOTCH, picks='eeg', verbose=False)
    raw.filter(HP, LP, picks='eeg', fir_design='firwin', verbose=False)
    raw.set_eeg_reference('average', projection=False)
    if SF_RESAMPLE is not None and SF_RESAMPLE < raw.info['sfreq']:
        raw.resample(SF_RESAMPLE, npad='auto')

    # Must run after resampling: sample indices are computed at the current rate.
    events = _events_from_tsv(events_tsv_path, raw, orig_sfreq, duration_sec)

    # Only include the codes we care about that actually occur. By default
    # Epochs raises for event ids that are absent from the events array.
    present_codes = set(events[:, 2].tolist())
    event_id = {name: code for code, name in EVENT_CODE_TO_NAME.items()
                if code in present_codes}
    if not event_id:
        raise RuntimeError("No valid stimulus event codes (5/6/7) found in events.tsv / events array.")

    # Epoching: allow duplicate event times (event_repeated='merge')
    epochs = mne.Epochs(raw, events, event_id=event_id,
                        tmin=EPOCH_TMIN, tmax=EPOCH_TMAX,
                        baseline=BASELINE, preload=True,
                        event_repeated='merge', verbose=False)

    # ICA runs BEFORE amplitude rejection. Otherwise blink-contaminated trials
    # are dropped outright and ICA has nothing left to correct.
    if ICA_N_COMPONENTS is not None:
        try:
            _remove_eog_ica(raw, epochs)
        except Exception as e:
            print(f"[WARN] ICA failed or skipped: {e}")

    # Reject bad epochs (optional)
    if REJECT is not None:
        try:
            epochs.drop_bad(reject=REJECT)
        except Exception as e:
            print(f"[WARN] drop_bad failed: {e}")

    # Save outputs
    subj_out = os.path.join(out_dir, sub)
    os.makedirs(subj_out, exist_ok=True)
    raw_out = os.path.join(subj_out, f"{stem}_preproc_raw.fif")
    epochs_out = os.path.join(subj_out, f"{stem}_preproc-epo.fif")
    epochs.save(epochs_out, overwrite=True)
    raw.save(raw_out, overwrite=True)

    # Save a small summary CSV
    summary = {
        'subject': sub,
        'n_channels': raw.info['nchan'],
        'orig_sfreq': orig_sfreq,
        'processed_sfreq': raw.info['sfreq'],
        'n_epochs_total': len(epochs),
        'event_id_keys': list(event_id.keys())
    }
    pd.DataFrame([summary]).to_csv(os.path.join(subj_out, 'preproc_summary.csv'), index=False)

    print(f"[OK] {sub} done. Saved epochs -> {epochs_out}")
    return summary

# ========== Run for all subs found ==========
# Process every configured subject whose .vhdr and events.tsv both exist.
# A failing subject does not stop the batch. Its full traceback is printed so
# the failure can be diagnosed, and it is listed at the end.
results = []
failed_subs = []
for sub in SUBS:
    vhdr = os.path.join(BASE_DIR, f"{sub}_task-oddball_eeg.vhdr")
    tsv = os.path.join(BASE_DIR, f"{sub}_task-oddball_events.tsv")
    if not os.path.exists(vhdr):
        print(f"[SKIP] {sub}: vhdr not found at {vhdr}")
        continue
    if not os.path.exists(tsv):
        print(f"[SKIP] {sub}: events.tsv not found at {tsv}")
        continue
    try:
        res = preprocess_subject(vhdr, tsv)
    except Exception as e:
        # str(e) alone can be cryptic (e.g. just "'target'" for a KeyError),
        # so include the type and the traceback.
        print(f"[ERROR] {sub}: {type(e).__name__}: {e}")
        traceback.print_exc()
        failed_subs.append(sub)
    else:
        results.append(res)

# Save overall log
pd.DataFrame(results).to_csv(os.path.join(OUT_DIR, "all_subjects_preproc_log.csv"), index=False)
if failed_subs:
    print(f"[WARN] {len(failed_subs)} subject(s) failed: {', '.join(failed_subs)}")
print("All done.")
