# Prepare dataset

## Individual process

In [9]:
import os
import numpy as np
import pandas as pd
from scipy.signal import butter, sosfiltfilt, hilbert
from tqdm import tqdm


def fixation_bandpower_hilbert(df_page):
    """
    Compute fixation-aligned EEG bandpower features.

    Pipeline:
      1) Band-pass filter per frequency band (SciPy butter + sosfiltfilt)
      2) Hilbert transform to obtain the analytic signal
      3) Instantaneous bandpower time series = |analytic|^2
         (raw power; the log10 transform is currently disabled)
      4) Average the bandpower over the rows spanned by each fixation

    Fixation windows are addressed by *positional* row number inside
    `df_page`, so the index labels of `df_page` are irrelevant: a page
    sliced out of a longer recording (whose labels do not start at 0)
    is handled correctly.

    Expected input `df_page`:
      - Rows are time samples, ordered in time
      - First 64 columns are EEG channel values (continuous EEG samples)
      - Columns include:
          sfreq (only the first row's value is used),
          fix_R_tStart, fix_R_tEnd,
          is_mw,
          fix_R_fixed_word, fix_R_fixed_word_key,
          sentence_id, sentence

    Returns:
      DataFrame with one row per fixation interval (a unique
      (fix_R_tStart, fix_R_tEnd) pair) and columns:
        {ch}_{band} for 64*8 features + metadata cols
        (is_mw, fix_R_fixed_word, fix_R_fixed_word_key, sentence_id,
        sentence). `is_mw` is the mean of the flag over the fixation's rows.
      An empty DataFrame is returned when the page has no complete fixation.

    Raises:
      KeyError: if a required column is missing.
      ValueError: (from SciPy) if a band edge is not below the Nyquist
        frequency or the page is too short for zero-phase filtering.
    """

    bands = {
        "theta1": (4.0, 6.0),
        "theta2": (6.5, 8.0),
        "alpha1": (8.5, 10.0),
        "alpha2": (10.5, 13.0),
        "beta1":  (13.5, 18.0),
        "beta2":  (18.5, 30.0),
        "gamma1": (30.5, 40.0),
        "gamma2": (40.0, 49.5),
    }

    fix_cols = [
        "fix_R_tStart", "fix_R_tEnd",
        "is_mw", "fix_R_fixed_word", "fix_R_fixed_word_key",
        "sentence_id", "sentence",
    ]

    # ---------------------------
    # Build fixation index table: one row per fixation interval.
    # start_idx/end_idx are positional row numbers within df_page. An
    # explicit counter is used instead of the DataFrame index because the
    # index may carry labels from a larger parent frame.
    # ---------------------------
    fix_rows = df_page[fix_cols].reset_index(drop=True)
    fix_rows["row_pos"] = np.arange(len(fix_rows))

    df_fix_idx = (
        fix_rows.dropna(subset=fix_cols)
        .groupby(["fix_R_tStart", "fix_R_tEnd"], as_index=False)
        .agg(
            start_idx=("row_pos", "min"),
            end_idx=("row_pos", "max"),
            is_mw=("is_mw", "mean"),
            fix_R_fixed_word=("fix_R_fixed_word", "first"),
            fix_R_fixed_word_key=("fix_R_fixed_word_key", "first"),
            sentence_id=("sentence_id", "first"),
            sentence=("sentence", "first"),
        )
        .sort_values(["fix_R_tStart", "fix_R_tEnd"])
        .reset_index(drop=True)
    )

    if df_fix_idx.empty:
        # no fixations found on this page (also covers an empty page, so
        # we never try to read sfreq from a frame without rows)
        return pd.DataFrame()

    # ---------------------------
    # EEG: (n_ch, n_times)
    # ---------------------------
    eeg = df_page.iloc[:, :64].to_numpy(dtype=np.float64).T
    n_ch = eeg.shape[0]

    ch_names = df_page.columns.values[:64].tolist()
    sfreq = float(df_page["sfreq"].iloc[0])
    nyq = sfreq / 2.0

    # Half-open [start, stop) sample ranges; positions are in range by
    # construction, so no clipping is required.
    start_samp = df_fix_idx["start_idx"].to_numpy(dtype=np.int64)
    stop_samp = df_fix_idx["end_idx"].to_numpy(dtype=np.int64) + 1

    band_names = list(bands)
    n_fix = len(df_fix_idx)
    n_bands = len(band_names)

    feat = np.full((n_fix, n_ch, n_bands), np.nan, dtype=np.float64)

    # ---------------------------
    # Filter + Hilbert per band, then aggregate per fixation
    # ---------------------------
    for bi, (fmin, fmax) in enumerate(bands.values()):
        # SciPy band-pass (zero-phase); edges normalised by Nyquist
        sos = butter(
            N=4,
            Wn=[fmin / nyq, fmax / nyq],
            btype="bandpass",
            output="sos",
        )
        x_filt = sosfiltfilt(sos, eeg, axis=1)           # filter over time
        x_analytic = hilbert(x_filt, axis=1)             # hilbert over time
        power = np.abs(x_analytic) ** 2                  # (n_ch, n_times)
        # NOTE: log10 transform intentionally left disabled:
        # power = np.log10(power + 1e-20)

        for i, (s, e) in enumerate(zip(start_samp, stop_samp)):
            feat[i, :, bi] = np.nanmean(power[:, s:e], axis=1)

    # Flatten features: (n_fix, n_ch*n_bands); channel-major, band-minor,
    # matching the column name order below.
    feat_flat = feat.reshape(n_fix, n_ch * n_bands)
    columns = [f"{ch}_{band}" for ch in ch_names for band in band_names]
    df_psd = pd.DataFrame(feat_flat, columns=columns)

    # Attach fixation metadata
    df_psd = pd.concat([df_psd, df_fix_idx[[
        "is_mw",
        "fix_R_fixed_word",
        "fix_R_fixed_word_key",
        "sentence_id",
        "sentence",
    ]].reset_index(drop=True)], axis=1)

    return df_psd


# ==========================
# Main processing
# ==========================
import warnings

data_root = "/gpfs1/pi/djangraw/mindless_reading/data"
coords_root = "/gpfs1/pi/djangraw/hsun11/roamm_ml/res"

# Subject folders are the sub-directories of data_root whose name starts with "s".
all_subjects = sorted(
    d for d in os.listdir(data_root)
    if d.startswith("s") and os.path.isdir(os.path.join(data_root, d))
)

# Flag columns that are coerced to plain booleans when present.
bool_cols = ['is_blink', 'is_sacc', 'is_fix', 'is_mw', 'first_pass_reading']

for subject_id in all_subjects:
    print(f"Processing subject {subject_id}...")
    ml_data_dir = os.path.join(data_root, subject_id, "ml_data")
    save_dir = os.path.join(ml_data_dir, "eeg2text_data")
    os.makedirs(save_dir, exist_ok=True)

    pkl_files = sorted(f for f in os.listdir(ml_data_dir) if f.endswith(".pkl"))

    # make sure each subject has 5 runs of data
    if len(pkl_files) != 5:
        raise ValueError(f"Subject {subject_id} has {len(pkl_files)} runs instead of 5")

    subject_rows = []  # collect per-page fixation-level dfs, concat once

    for pkl_file in tqdm(pkl_files, desc=f"{subject_id} runs", unit="run", leave=False):
        # NOTE(review): read_pickle executes arbitrary code from the file;
        # only load pickles produced by this project.
        df = pd.read_pickle(os.path.join(ml_data_dir, pkl_file))

        # Filter: first pass reading only
        if "first_pass_reading" in df.columns:
            df = df[df["first_pass_reading"] == 1].copy()

        # convert bool col explicitly to avoid pandas warning; columns that
        # are absent are skipped (first_pass_reading is treated as optional above)
        for col in bool_cols:
            if col in df.columns:
                df[col] = df[col] == True  # noqa: E712 (element-wise; NaN -> False)

        # Filter out samples 2 seconds before page end
        if "time" in df.columns and "page_end" in df.columns:
            df = df[df["time"] < (df["page_end"] - 2)].copy()

        if df.empty:
            # nothing left in this run after filtering; story_name lookup
            # below would fail on an empty frame
            continue

        df["subject_id"] = subject_id

        # Sentence info merge (word_key)
        story_name = df["story_name"].iloc[0]
        coord_path = os.path.join(coords_root, f"{story_name}_coordinates.csv")
        df_coords = pd.read_csv(coord_path)

        # One coordinate row per word_key: a duplicated key would otherwise
        # duplicate EEG samples in the left merge and corrupt the time series.
        word_info = df_coords[["sentence_id", "sentence", "word_key"]].drop_duplicates(
            subset="word_key"
        )
        df = df.merge(
            word_info,
            left_on="fix_R_fixed_word_key",
            right_on="word_key",
            how="left",
        )

        # Process each page
        for page in sorted(df["page_num"].unique().tolist()):
            # Re-index from 0 so that row labels equal sample positions
            # within the page, which is what the feature extractor indexes by.
            df_page = df[df["page_num"] == page].reset_index(drop=True)
            df_fix_eeg = fixation_bandpower_hilbert(df_page)
            if df_fix_eeg.empty:
                continue
            # Add metadata columns
            df_fix_eeg["page_num"] = page
            df_fix_eeg["story_name"] = story_name
            df_fix_eeg["subject_id"] = subject_id
            subject_rows.append(df_fix_eeg)

    if subject_rows:
        df_group = pd.concat(subject_rows, ignore_index=True)
    else:
        # pd.concat([]) raises, so fall back to an empty frame and actually
        # emit the warning (a bare Warning(...) call only builds an object).
        warnings.warn(f"No fixation data found for subject {subject_id}, saving empty file.")
        df_group = pd.DataFrame()

    out_path = os.path.join(save_dir, f"{subject_id}_eeg2text_data.csv")
    df_group.to_csv(out_path, index=False)


Processing subject s10014...


s10014 runs:   0%|          | 0/5 [00:00<?, ?run/s]

                                                           

Processing subject s10052...


                                                           

Processing subject s10059...


                                                           

Processing subject s10073...


                                                           

Processing subject s10081...


                                                           

Processing subject s10084...


                                                           

Processing subject s10085...


                                                           

Processing subject s10089...


                                                           

Processing subject s10094...


                                                           

Processing subject s10100...


                                                           

Processing subject s10103...


                                                           

Processing subject s10110...


                                                           

Processing subject s10111...


                                                           

Processing subject s10115...


                                                           

Processing subject s10117...


                                                           

Processing subject s10121...


                                                           

Processing subject s10125...


                                                           

Processing subject s10138...


                                                           

Processing subject s10139...


                                                           

Processing subject s10141...


                                                           

Processing subject s10144...


                                                           

Processing subject s10145...


                                                           

Processing subject s10148...


                                                           

Processing subject s10153...


                                                           

Processing subject s10156...


                                                           

Processing subject s10158...


                                                           

Processing subject s10159...


                                                           

Processing subject s10160...


                                                           

Processing subject s10165...


                                                           

Processing subject s10173...


                                                           

Processing subject s10177...


                                                           

Processing subject s10178...


                                                           

Processing subject s10180...


                                                           

Processing subject s10181...


                                                           

Processing subject s10183...


                                                           

Processing subject s10185...


                                                           

Processing subject s10186...


                                                           

Processing subject s10188...


                                                           

Processing subject s10192...


                                                           

Processing subject s10195...


                                                           

Processing subject s10196...


                                                           

Processing subject s10197...


                                                           

Processing subject s10200...


                                                           

Processing subject s10202...


                                                           

## Merge subject files

In [10]:
import os
import pandas as pd

import warnings

data_root = "/gpfs1/pi/djangraw/mindless_reading/data"
all_subjects = sorted(
    d for d in os.listdir(data_root)
    if d.startswith("s") and os.path.isdir(os.path.join(data_root, d))
)
df_list = []

# Load each subject's fixation-level CSV; subjects whose file is missing or
# contains no rows are skipped with a warning instead of aborting the merge.
for subject_id in all_subjects:
    ml_data_dir = os.path.join(data_root, subject_id, "ml_data")
    save_dir = os.path.join(ml_data_dir, "eeg2text_data")
    infile = os.path.join(save_dir, f"{subject_id}_eeg2text_data.csv")

    try:
        df = pd.read_csv(infile)
    except (FileNotFoundError, pd.errors.EmptyDataError) as err:
        warnings.warn(f"Skipping subject {subject_id}: could not load {infile} ({err})")
        continue

    if df.empty:
        warnings.warn(f"Skipping subject {subject_id}: {infile} contains no rows")
        continue

    df["subject_id"] = subject_id  # enforce subject_id is correct
    df_list.append(df)

if not df_list:
    raise RuntimeError("No subject EEG2Text files were loaded. Nothing to concatenate.")

df_all = pd.concat(df_list, ignore_index=True)
out_file = os.path.join(data_root, "all_subjects_eeg2text_data.csv")
df_all.to_csv(out_file, index=False)

print(f"\nSaved merged dataset to:\n  {out_file}")
print(f"Total rows: {len(df_all):,}")
print(f"Total subjects loaded: {df_all['subject_id'].nunique()}")


Saved merged dataset to:
  /gpfs1/pi/djangraw/mindless_reading/data/all_subjects_eeg2text_data.csv
Total rows: 394,319
Total subjects loaded: 44


# Examine dataset

In [8]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

seed = 42

df = pd.read_csv("/gpfs1/pi/djangraw/mindless_reading/data/all_subjects_eeg2text_data.csv")
data_type = "all"  # options: "all", "nr", "mw"
sentence_id_col = "sentence_id"
sentence_col = "sentence"
fix_key_col = "fix_R_fixed_word_key"

# Explicit check rather than `assert`, which is stripped under `python -O`.
if data_type not in {"all", "nr", "mw"}:
    raise ValueError(f"Invalid data_type={data_type}")

# Optionally keep only rows with is_mw == 0 ("nr") or is_mw != 0 ("mw").
if data_type != "all":
    if "is_mw" not in df.columns:
        raise ValueError("Column 'is_mw' not found in CSV but data_type != 'all' was requested.")
    if data_type == "nr":
        df = df[df["is_mw"] == 0].reset_index(drop=True)
    elif data_type == "mw":
        df = df[df["is_mw"] != 0].reset_index(drop=True)

# Identify EEG feature columns: numeric columns excluding obvious metadata/labels.
meta_cols = {
    sentence_id_col, sentence_col, fix_key_col,
    "fix_R_fixed_word", "is_mw", "page_num", "story_name", "subject_id",
}
eeg_cols = [
    c for c in df.columns
    if c not in meta_cols and pd.api.types.is_numeric_dtype(df[c])
]
if not eeg_cols:
    raise ValueError("No numeric EEG feature columns found after excluding metadata columns.")

# NOTE: disabled draft below; it references `drop_nan_sentences` and `self`,
# neither of which is defined in this cell.
# Drop whole sentences containing any NaN in EEG features or missing sentence text
# if drop_nan_sentences:
#     good_ids = []
#     for sid, g in df.groupby(sentence_id_col, sort=False):
#         if g[self.sentence_col].isna().any():
#             continue
#         if g[self.eeg_cols].isna().any().any():
#             continue
#         good_ids.append(sid)
#     df = df[df[sentence_id_col].isin(good_ids)].reset_index(drop=True)

# Split by unique sentence text (unseen sentences in test).
# Despite the name, sent_ids holds the unique values of `sentence_col`.
sent_ids = df[sentence_col].dropna().unique().tolist()
rng = np.random.RandomState(seed)
rng.shuffle(sent_ids)

# 80 / 10 / remainder split over the shuffled unique sentences.
n = len(sent_ids)
n_train = int(round(0.8 * n))
n_dev = int(round(0.1 * n))

train_ids = set(sent_ids[:n_train])
dev_ids   = set(sent_ids[n_train:n_train + n_dev])
test_ids  = set(sent_ids[n_train + n_dev:])