# Precompute ONE_MAIN_FOOOF cache (A/B)

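This notebook scans the input directories selected by `DATASET_OPTION` for **epochs** files. Options 1–2 use EEGLAB `.set` files (e.g. `G:\ChristianMusaeus\Preprocessed_setfiles`), and option 3 uses MNE `*_epo.fif` / `.fif` files. It precomputes `ONE_MAIN_FOOOF` artifacts so they don't need to be recomputed every time.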

At the top you can choose between:

- **(A) cache alpha profile only**: saves per-subject `(alpha_cf, alpha_bw)`.
- **(B) cache full per-epoch features**: saves `(X, feature_names, metadata)` so inference becomes fast.

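Outputs are written under `SAVED_FOOOF_DIR` in a config-tagged subfolder. The default is `G:\ChristianMusaeus\saved_fooof`, and it can be overridden with the `SAVED_FOOOF_DIR` environment variable. If that location cannot be used (e.g. on macOS/Linux, where the Windows drive is not mounted), the notebook falls back to `outputs/saved_fooof` next to the notebook.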


In [None]:
from __future__ import annotations

import json
import os
import re
import time
import platform
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import mne

# -----------------
# User config
# -----------------

# Choose caching mode:
# - "A" = alpha profile only (subject-level MAIN_FOOOF params)
# - "B" = full per-epoch ONE_MAIN_FOOOF features
CACHE_MODE = "B"  # "A" or "B"
# Normalize and validate early so a typo fails here instead of silently running mode B.
CACHE_MODE = str(CACHE_MODE).strip().upper()
if CACHE_MODE not in {"A", "B"}:
    raise ValueError(f"CACHE_MODE must be 'A' or 'B', got {CACHE_MODE!r}")

# Data selection (choose which files to scan)
# 1 = Preprocessed_setfiles (EEGLAB .set epochs)
# 2 = Old marked EO/EC folders (Open_marked + Closed_marked, EEGLAB .set epochs)
# 3 = New_EEG processed folder (MNE .fif epochs)
DATASET_OPTION = 2  # 1, 2, or 3

# Default Windows locations, written as raw strings with single backslashes (like
# SAVED_FOOOF_DIR). Each can be overridden by an environment variable of the same name.
PREPROCESSED_SETFILES_DIR = os.getenv("PREPROCESSED_SETFILES_DIR", r"G:\ChristianMusaeus\Preprocessed_setfiles")
OLD_OPEN_MARKED_DIR = os.getenv("OLD_OPEN_MARKED_DIR", r"E:\Saxe_sandkasse\30EOEC_filer\Open_marked")
OLD_CLOSED_MARKED_DIR = os.getenv("OLD_CLOSED_MARKED_DIR", r"E:\Saxe_sandkasse\30EOEC_filer\Closed_marked")
NEW_EEG_PROCESSED_DIR = os.getenv("NEW_EEG_PROCESSED_DIR", r"G:\ChristianMusaeus\New_EEG\Processed")

# DATASET_OPTION -> (input directories, dataset tag used in the output folder name)
_DATASET_CHOICES = {
    1: ([PREPROCESSED_SETFILES_DIR], "old_dataset_preprocessed_setfiles"),
    2: ([OLD_OPEN_MARKED_DIR, OLD_CLOSED_MARKED_DIR], "old_dataset_open_closed_marked"),
    3: ([NEW_EEG_PROCESSED_DIR], "new_dataset_processed"),
}
try:
    _dataset_dirs, DATASET_TAG = _DATASET_CHOICES[int(DATASET_OPTION)]
except (KeyError, TypeError, ValueError):
    raise ValueError("DATASET_OPTION must be 1, 2, or 3") from None
INPUT_DIRS = list(_dataset_dirs)

RECURSIVE = True

# Output behavior
OVERWRITE = False
SAVE_PER_FILE = True  # if False, write one big combined file (B only)

# Channels
CHANNEL_SELECTION = ["all"]  # or e.g. ["O1", "O2", "P3", "P4", "P7", "P8", "Pz"]
ALL_CHANNELS = any(str(x).lower() == "all" for x in CHANNEL_SELECTION)

# If channels are generic ("Ch1".."Ch19"), optionally map them to standard 10-20 names.
# NOTE: This assumes the dataset uses the common 19-channel ordering.
AUTO_RENAME_CH1_TO_1020 = True
CH1_TO_1020_ORDER_19 = [
    "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2",
    "F7", "F8", "T7", "T8", "P7", "P8", "Fz", "Cz", "Pz",
]

# Alpha profile (subject MAIN_FOOOF)
ALPHA_PROFILE_RANGE = (4.0, 16.0)
ALPHA_PROFILE_ROI = ["O1", "O2", "P3", "P4", "P7", "P8", "Pz"]

# PSD computation
PSD_KWARGS = dict(fmin=1.0, fmax=45.0)
TARGET_SECS = 2.0
COMBINE_ADJACENT_EPOCHS = False  # If True, pair consecutive epochs by concatenating time (matches EC_EO_Classifier option)
ALPHA_FREQ_RANGE = (3.0, 40.0)   # Fit range used for per-epoch aperiodic-only fits (ONE_MAIN_FOOOF)
ETA_EVERY = 1  # Print an ETA every N files

# FOOOF/specparam settings

# Which library to use for spectral parameterization:
# - "auto": try specparam first, then fooof
# - "specparam": require specparam
# - "fooof": require fooof
BACKEND_PREFERENCE = os.getenv("FOOOF_BACKEND", "auto").strip().lower()
FOOOF_SETTINGS = {
    "aperiodic_mode": "fixed",
    "peak_width_limits": (0.5, 12.0),
    "max_n_peaks": 6,
    "min_peak_height": 0.05,
    "peak_threshold": 2.0,
    "verbose": False,
}

# Feature selection: keep subset of ["offset","exponent","alpha_cf","alpha_amp","alpha_bw"]
FOOOF_SELECTED_FEATURES = ["offset", "exponent", "alpha_amp"]
# Unknown names are silently dropped by the feature helpers; warn so typos are noticed.
_KNOWN_FOOOF_FEATURES = ("offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw")
_unknown_features = [f for f in FOOOF_SELECTED_FEATURES if f not in _KNOWN_FOOOF_FEATURES]
if _unknown_features:
    print(f"Warning: unknown FOOOF_SELECTED_FEATURES entries will be ignored: {_unknown_features}")

# -----------------
# Output folder setup
# -----------------

def _detect_notebook_path() -> Optional[Path]:
    try:
        vsc = globals().get("__vsc_ipynb_file__", None)
        if vsc:
            p = Path(str(vsc)).expanduser()
            if p.suffix.lower() == ".ipynb" and p.exists():
                return p.resolve()
    except Exception:
        pass
    for key in ("NOTEBOOK_PATH", "IPYNB_PATH"):
        v = os.getenv(key)
        if v:
            try:
                p = Path(v).expanduser()
                if p.suffix.lower() == ".ipynb" and p.exists():
                    return p.resolve()
            except Exception:
                pass
    # Repo-local fallback
    try:
        here = Path.cwd().resolve()
        for _ in range(6):
            cand = here / "New_EEG" / "Precompute_ONE_MAIN_FOOOF.ipynb"
            if cand.exists():
                return cand.resolve()
            here = here.parent
    except Exception:
        pass
    return None

NOTEBOOK_PATH = _detect_notebook_path()
# Directory used for repo-relative fallbacks: the notebook's folder if known, else the CWD.
if NOTEBOOK_PATH is None:
    NOTEBOOK_DIR = Path.cwd().resolve()
else:
    NOTEBOOK_DIR = NOTEBOOK_PATH.parent

# Output folder (can be outside repo; default points to your shared drive)
SAVED_FOOOF_DIR = os.getenv("SAVED_FOOOF_DIR", r"G:\ChristianMusaeus\saved_fooof")

def _maybe_wsl_path(p: str) -> str | None:
    m = re.match(r"^([A-Za-z]):[\\/](.*)$", str(p))
    if not m:
        return None
    drive = m.group(1).lower()
    rest = m.group(2).replace("\\", "/")
    return f"/mnt/{drive}/{rest}"

def is_wsl() -> bool:
    """Return True if the kernel release string mentions 'microsoft' (WSL); False on any error."""
    try:
        release = platform.uname().release
    except Exception:
        return False
    return "microsoft" in release.lower()

def resolve_saved_fooof_root() -> Path:
    """Create and return the root folder for cached FOOOF outputs.

    Candidates, tried in order (the first one that can be created wins):
      * On Windows: SAVED_FOOOF_DIR as given.
      * Elsewhere, a drive-letter SAVED_FOOOF_DIR is translated to ``/mnt/<drive>/...``
        on WSL and skipped otherwise, because creating it would make a literal
        drive-letter folder in the CWD. A POSIX-style SAVED_FOOOF_DIR (no backslashes,
        e.g. set through the environment variable) is used as-is.
      * The repo-local fallback ``NOTEBOOK_DIR/outputs/saved_fooof``.

    Raises:
        RuntimeError: if none of the candidates can be created; lists each failure.
    """
    configured = str(SAVED_FOOOF_DIR)
    candidates = []
    if platform.system() == "Windows":
        candidates.append(configured)
    else:
        wsl_path = _maybe_wsl_path(configured)
        if wsl_path is not None:
            if is_wsl():
                candidates.append(wsl_path)
        elif "\\" not in configured:
            candidates.append(configured)
    # repo-local fallback (safe on macOS/Linux)
    candidates.append(str(NOTEBOOK_DIR / "outputs" / "saved_fooof"))

    failures = []
    for c in candidates:
        try:
            p = Path(c).expanduser()
            p.mkdir(parents=True, exist_ok=True)
            return p.resolve()
        except Exception as exc:
            failures.append(f"{c}: {exc}")
    raise RuntimeError(f"Could not create SAVED_FOOOF_DIR: {SAVED_FOOOF_DIR} ({'; '.join(failures)})")

# Root folder for all cached outputs (created on resolution); each configuration writes
# into its own subfolder below it, see get_output_dir().
OUTPUTS_ROOT = resolve_saved_fooof_root()


def _safe_tag(s: str) -> str:
    s = re.sub(r"[^a-zA-Z0-9._-]+", "_", str(s))
    s = re.sub(r"_+", "_", s).strip("_")
    return s[:120] if len(s) > 120 else s

def _channel_tag() -> str:
    """Folder-name tag describing CHANNEL_SELECTION.

    Returns "allch" when "all" is selected, "ch_unknown" for an empty selection, and
    otherwise "ch_" followed by the upper-cased channel names joined with "-". A tag
    longer than 80 characters is cut to 71 characters plus "_" and an 8-hex-digit CRC32
    of the full tag. Without that suffix, two long selections sharing a prefix would be
    written into the same output folder.
    """
    import zlib

    sel = [str(x).strip() for x in (CHANNEL_SELECTION or []) if x is not None]
    if any(x.lower() == "all" for x in sel):
        return "allch"
    if not sel:
        return "ch_unknown"
    tag = "ch_" + "-".join(_safe_tag(x.upper()) for x in sel)
    if len(tag) <= 80:
        return tag
    digest = zlib.crc32(tag.encode("utf-8")) & 0xFFFFFFFF
    return f"{tag[:71]}_{digest:08x}"

def _config_tag() -> str:
    """Build the subfolder name that identifies the current cache configuration.

    The parts are joined with "__" and empty parts are dropped:
    dataset tag, "fooof", "one_main_fooof", the cache mode, the channel tag, the PSD band,
    "pair_epochs" (only when pairing is enabled), the aperiodic fit range, and the
    selected features (when any are selected).
    """
    psd_lo = PSD_KWARGS.get('fmin', 1.0)
    psd_hi = PSD_KWARGS.get('fmax', 45.0)
    pieces = [
        str(DATASET_TAG),
        "fooof",
        "one_main_fooof",
        "cache_" + _safe_tag(CACHE_MODE).lower(),
        _channel_tag(),
        f"psd_{psd_lo}-{psd_hi}Hz",
    ]
    if COMBINE_ADJACENT_EPOCHS:
        pieces.append("pair_epochs")
    fit_lo, fit_hi = ALPHA_FREQ_RANGE[0], ALPHA_FREQ_RANGE[1]
    pieces.append(f"fit_{fit_lo}-{fit_hi}Hz")
    if FOOOF_SELECTED_FEATURES:
        pieces.append("feat_" + _safe_tag("-".join(FOOOF_SELECTED_FEATURES)))
    return "__".join(filter(None, pieces))

def get_output_dir() -> Path:
    """Return ``OUTPUTS_ROOT/<config tag>``, creating it (and parents) if needed."""
    target = OUTPUTS_ROOT.joinpath(_config_tag())
    target.mkdir(parents=True, exist_ok=True)
    return target

def outpath(name: str) -> Path:
    """Path of *name* inside the config-tagged output directory (the directory is created on demand)."""
    folder = get_output_dir()
    return folder.joinpath(str(name))

print("Notebook dir:", NOTEBOOK_DIR)
print("Outputs root:", OUTPUTS_ROOT)
print("Output dir:", get_output_dir())

# Snapshot of every user-config value, written as config.json into the output folder so
# each cache folder records the settings that produced it (key order is preserved).
config_dump = dict(
    CACHE_MODE=CACHE_MODE,
    DATASET_OPTION=int(DATASET_OPTION),
    DATASET_TAG=str(DATASET_TAG),
    INPUT_DIRS=list(INPUT_DIRS),
    PREPROCESSED_SETFILES_DIR=PREPROCESSED_SETFILES_DIR,
    OLD_OPEN_MARKED_DIR=OLD_OPEN_MARKED_DIR,
    OLD_CLOSED_MARKED_DIR=OLD_CLOSED_MARKED_DIR,
    NEW_EEG_PROCESSED_DIR=NEW_EEG_PROCESSED_DIR,
    RECURSIVE=RECURSIVE,
    OVERWRITE=OVERWRITE,
    SAVE_PER_FILE=SAVE_PER_FILE,
    CHANNEL_SELECTION=CHANNEL_SELECTION,
    ALPHA_PROFILE_RANGE=list(ALPHA_PROFILE_RANGE),
    ALPHA_PROFILE_ROI=list(ALPHA_PROFILE_ROI),
    PSD_KWARGS=dict(PSD_KWARGS),
    TARGET_SECS=TARGET_SECS,
    COMBINE_ADJACENT_EPOCHS=COMBINE_ADJACENT_EPOCHS,
    ALPHA_FREQ_RANGE=list(ALPHA_FREQ_RANGE),
    ETA_EVERY=ETA_EVERY,
    FOOOF_SETTINGS=dict(FOOOF_SETTINGS),
    FOOOF_BACKEND_PREFERENCE=str(BACKEND_PREFERENCE),
    FOOOF_SELECTED_FEATURES=list(FOOOF_SELECTED_FEATURES),
)
config_path = outpath("config.json")
config_path.write_text(json.dumps(config_dump, indent=2), encoding="utf-8")

In [None]:
# -----------------
# Backend imports
# -----------------
# Select the spectral-parameterization library according to BACKEND_PREFERENCE and expose:
#   SpectralModel - model class (specparam.SpectralModel or fooof.FOOOF)
#   FitError      - that library's fit exception (stays `Exception` if nothing loads)
#   FOOOF_BACKEND - "specparam", "fooof" or "unavailable"

SpectralModel = None
FitError = Exception
FOOOF_BACKEND = "unavailable"
_backend_error = None
_backend_failures = []  # "<backend>: <error>" for each backend that failed to import

pref = str(globals().get('BACKEND_PREFERENCE', 'auto')).strip().lower()
if pref not in {'auto', 'specparam', 'fooof'}:
    print(f"Unknown BACKEND_PREFERENCE={pref!r}; falling back to 'auto'.")
    pref = 'auto'

# Try specparam
if pref in {'auto', 'specparam'}:
    try:
        from specparam import SpectralModel as _SpecModel
        from specparam.core.errors import FitError as _FitError
        SpectralModel = _SpecModel
        FitError = _FitError
        FOOOF_BACKEND = 'specparam'
    except Exception as exc:
        _backend_error = exc
        _backend_failures.append(f"specparam: {exc!r}")
        if pref == 'specparam':
            raise RuntimeError(
                "BACKEND_PREFERENCE=specparam but specparam could not be imported.\n"
                "Install specparam or set BACKEND_PREFERENCE=auto/fooof."
            ) from exc

# Try fooof
if SpectralModel is None and pref in {'auto', 'fooof'}:
    try:
        from fooof import FOOOF as _FooofModel
        from fooof.core.errors import FitError as _FitError
        SpectralModel = _FooofModel
        FitError = _FitError
        FOOOF_BACKEND = 'fooof'
    except Exception as exc:
        _backend_error = exc
        _backend_failures.append(f"fooof: {exc!r}")
        if pref == 'fooof':
            raise RuntimeError(
                "BACKEND_PREFERENCE=fooof but fooof could not be imported.\n"
                "Install fooof or set BACKEND_PREFERENCE=auto/specparam."
            ) from exc

print('FOOOF/specparam backend used:', FOOOF_BACKEND)
if SpectralModel is None:
    # Report why every attempted backend failed instead of a bare "unavailable".
    _detail = "; ".join(_backend_failures) or "no backend attempted"
    raise RuntimeError(
        f"FOOOF/specparam backend unavailable in this environment ({_detail})."
    ) from _backend_error

# Update config.json with the chosen backend (best-effort)
try:
    cfg_path = outpath('config.json')
    if cfg_path.exists():
        cfg = json.loads(cfg_path.read_text(encoding='utf-8'))
    else:
        cfg = {}
    cfg['FOOOF_BACKEND_USED'] = str(FOOOF_BACKEND)
    cfg_path.write_text(json.dumps(cfg, indent=2), encoding='utf-8')
except Exception as exc:
    print('Warning: could not update config.json with backend:', exc)


In [None]:
# -----------------
# Helpers
# -----------------

def parse_subject_id(path: Path) -> int:
    """Derive an integer subject id from a file name.

    The id is the first run of three or more digits in the stem
    (e.g. ``sub_012_EC.set`` -> 12). Stems without such a run fall back to a CRC32 of
    the stem modulo 1e9. CRC32 is stable across processes, unlike the built-in ``hash``
    of a str, which is randomized per interpreter (PYTHONHASHSEED). The derived cache
    file names therefore stay the same between runs, and the skip-if-cached logic keeps
    working.
    """
    import zlib

    stem = path.stem
    m = re.search(r"(\d{3,})", stem)
    if m:
        return int(m.group(1))
    return zlib.crc32(stem.encode("utf-8")) % 1_000_000_000

def canonical_channel_name(ch_name: str) -> str:
    """Normalize an EEG channel label.

    Surrounding whitespace is stripped. A leading "EEG " prefix and a trailing "-REF"
    suffix (both case-insensitive) are removed, as is any internal whitespace.

    Generic labels are optionally remapped: this requires the module-level
    AUTO_RENAME_CH1_TO_1020 to be truthy and CH1_TO_1020_ORDER_19 to be a 19-item
    list/tuple. When both hold, "Ch1".."Ch19" (any case, no zero padding) map to the
    matching entry of that list.
    """
    label = str(ch_name).strip()
    for pattern in (r"^EEG\s+", r"-REF$"):
        label = re.sub(pattern, "", label, flags=re.IGNORECASE)
    label = re.sub(r"\s+", "", label)

    if not globals().get('AUTO_RENAME_CH1_TO_1020', False):
        return label
    order = globals().get('CH1_TO_1020_ORDER_19', None)
    if not isinstance(order, (list, tuple)) or len(order) != 19:
        return label
    slot = {f"CH{n}": n - 1 for n in range(1, 20)}.get(label.upper())
    return label if slot is None else str(order[slot])

def rename_epochs_channels_canonical(epochs):
    """Rename *epochs*' channels to their canonical_channel_name() form, in place.

    Nothing is renamed if two channels would end up with the same canonical name.
    rename_channels() is only called when at least one name actually changes. The same
    object is returned for chaining.
    """
    renamed = list(map(canonical_channel_name, epochs.ch_names))
    if len(renamed) != len(set(renamed)):
        return epochs
    changes = {}
    for before, after in zip(epochs.ch_names, renamed):
        if before != after:
            changes[before] = after
    if changes:
        epochs.rename_channels(changes)
    return epochs

def psd_array_welch_clean(data: np.ndarray, sfreq: float, fmin=1.0, fmax=45.0, target_secs=2.0):
    """Welch PSD with roughly ``target_secs``-long Hann segments.

    Args:
        data: signal array with time on the last axis. This notebook passes
            (n_epochs, n_channels, n_times).
        sfreq: sampling frequency in Hz.
        fmin, fmax: frequency band kept in the output (Hz).
        target_secs: desired Welch segment length in seconds. It is clipped to the
            epoch length and raised to at least 8 samples (unless the epoch is shorter).

    Returns:
        ``(psds, freqs)`` from ``mne.time_frequency.psd_array_welch`` with
        ``average="mean"``. ``psds`` has ``data``'s leading axes plus a frequency axis.

    Segments overlap by 50% when they are at least 16 samples long. ``n_fft`` is set
    to the segment length because MNE caps ``n_per_seg`` at ``n_fft``. With the default
    ``n_fft=256``, segments would otherwise silently shrink to 256 samples (0.5 s at
    500 Hz). That ignores ``target_secs`` and coarsens the frequency resolution.

    NOTE(review): if EC_EO_Classifier.ipynb uses a copy of this helper, apply the same
    change there so cached and freshly computed features remain comparable.
    """
    n_times = int(data.shape[-1])
    n_per_seg = min(n_times, max(8, int(round(target_secs * sfreq))))
    n_overlap = n_per_seg // 2 if n_per_seg >= 16 else 0
    psds, freqs = mne.time_frequency.psd_array_welch(
        data,
        sfreq=sfreq,
        fmin=float(fmin),
        fmax=float(fmax),
        n_fft=n_per_seg,
        n_per_seg=n_per_seg,
        n_overlap=n_overlap,
        window="hann",
        average="mean",
        verbose=False,
    )
    return psds, freqs

def select_alpha_peak(peaks: np.ndarray, lo: float, hi: float):
    """Pick the strongest peak whose center frequency lies in ``[lo, hi]``.

    *peaks* holds one peak per row, with the center frequency in column 0 and the power
    in column 1. A single 1-D row is also accepted. Returns that row, or None if there
    are no peaks or none fall inside the band.
    """
    table = np.asarray(peaks, float)
    if not table.size:
        return None
    if table.ndim == 1:
        table = table.reshape(1, -1)
    center = table[:, 0]
    in_band = table[(center >= lo) & (center <= hi)]
    if in_band.shape[0] == 0:
        return None
    return in_band[int(np.argmax(in_band[:, 1]))]

def compute_one_main_fooof_features(freqs: np.ndarray, psd_cube: np.ndarray, subject_id: int, alpha_profile_map, include_aperiodic: bool = True) -> np.ndarray:
    """Compute ONE_MAIN_FOOOF features (matches `New_EEG/EC_EO_Classifier.ipynb`).

    For each subject, alpha center frequency and bandwidth are fixed from alpha_profile_map.
    For each epoch/channel, fit only the aperiodic component (max_n_peaks=0) and then fit
    the amplitude of a Gaussian alpha template to the residual.

    Feature layout per channel: [offset, exponent, alpha_cf, alpha_amp, alpha_bw].

    Args:
        freqs: frequency grid of the spectra.
        psd_cube: spectra iterated as epochs -> channels -> spectrum. The main loop
            passes the (epochs, channels, freqs) output of psd_array_welch_clean.
        subject_id: key looked up (as int) in alpha_profile_map.
        alpha_profile_map: mapping subject_id -> (alpha_cf, alpha_bw), or None. A missing
            entry, or one that does not have exactly 2 values, gives
            alpha_cf = alpha_bw = alpha_amp = 0.
        include_aperiodic: only matters when there is no profile. If it is False in that
            case, no model is fitted and offset/exponent are reported as 0.

    Returns:
        float array of shape (n_epochs, n_channels * 5). It is an empty 1-D array when
        psd_cube has no epochs.

    Raises:
        RuntimeError: if no FOOOF/specparam backend was loaded (SpectralModel is None).

    A channel contributes five zeros if its spectrum contains non-finite values, or if
    fitting raises FitError (the backend alias set in the backend cell), RuntimeError,
    ValueError or LinAlgError.

    NOTE(review): the "residual" is `spectrum - ap_fit`, where ap_fit is read from the
    fitted model (e.g. fooof's ``fooofed_spectrum_``). In fooof that spectrum is log10
    power over the fit range (ALPHA_FREQ_RANGE) only. Its shape then usually differs from
    `spectrum` (full PSD grid), so ap_fit falls back to zeros and alpha_amp becomes a
    Gaussian-weighted projection of the raw linear-power spectrum. The code is kept as-is
    so cached features stay identical to EC_EO_Classifier; confirm before changing
    either side.
    NOTE(review): alpha_bw is converted to sigma as if it were a FWHM, but fooof documents
    its peak bandwidth as 2*std of the Gaussian. Confirm which convention is intended.
    """
    if SpectralModel is None:
        raise RuntimeError("FOOOF backend unavailable.")
    subj = int(subject_id)
    profile = alpha_profile_map.get(subj) if alpha_profile_map is not None else None
    has_profile = profile is not None and len(profile) == 2
    if has_profile:
        alpha_cf, alpha_bw = map(float, profile)
    else:
        alpha_cf, alpha_bw = 0.0, 0.0
    freqs_arr = np.asarray(freqs, float)
    import math
    if has_profile and alpha_bw > 0:
        # FWHM -> Gaussian sigma: sigma = bw / (2 * sqrt(2 * ln 2))
        sigma = float(alpha_bw) / (2.0 * math.sqrt(2.0 * math.log(2.0)))
        gauss = np.exp(-0.5 * ((freqs_arr - alpha_cf) / sigma) ** 2)
    else:
        gauss = np.zeros_like(freqs_arr)
    # Least-squares amplitude of the template: alpha_amp = <gauss, residual> / <gauss, gauss>
    denom = float(np.sum(gauss ** 2)) if gauss.size else 0.0
    features = []
    # Same settings as the profile fit, but without peaks (aperiodic-only fit).
    ap_settings = dict(FOOOF_SETTINGS)
    try:
        ap_settings["max_n_peaks"] = 0
    except Exception:
        pass
    for epoch_psd in psd_cube:
        epoch_feats = []
        for spectrum in epoch_psd:
            try:
                if not np.all(np.isfinite(spectrum)):
                    raise ValueError("Non-finite in spectrum")
                offset, exponent = 0.0, 0.0
                alpha_amp = 0.0
                if include_aperiodic or has_profile:
                    model = SpectralModel(**ap_settings)
                    model.fit(freqs_arr, spectrum, freq_range=ALPHA_FREQ_RANGE)
                    if hasattr(model, "aperiodic_params_"):
                        params = np.asarray(model.aperiodic_params_)
                        if params.size > 0:
                            offset = float(params[0])
                        if params.size > 1:
                            exponent = float(params[1])
                    try:
                        ap_fit = None
                        get_fun = getattr(model, "get_model_spectrum", None)
                        if callable(get_fun):
                            ap_fit = np.asarray(get_fun(freqs_arr))
                    except Exception:
                        ap_fit = None
                    if ap_fit is None:
                        # Attribute name of the fitted spectrum differs across fooof/specparam versions.
                        for name in ("fooofed_spectrum_", "modeled_spectrum_", "model_spectrum_", "model_spectrum__"):
                            if hasattr(model, name):
                                ap_fit = np.asarray(getattr(model, name))
                                break
                    # Fit not on the same grid as `spectrum`: treat it as zero (see NOTE in the docstring).
                    if ap_fit is None or ap_fit.shape != spectrum.shape:
                        ap_fit = np.zeros_like(spectrum)
                    if has_profile and denom > 0.0:
                        residual = spectrum - ap_fit
                        num = float(np.sum(gauss * residual))
                        alpha_amp = max(num / denom, 0.0)  # negative projections clipped to 0
                epoch_feats.extend([offset, exponent, alpha_cf if has_profile else 0.0, alpha_amp, alpha_bw if has_profile else 0.0])
            except (FitError, RuntimeError, ValueError, np.linalg.LinAlgError):
                epoch_feats.extend([0.0, 0.0, 0.0, 0.0, 0.0])
        features.append(epoch_feats)
    return np.asarray(features, dtype=float)

def fooof_feature_names(channels: List[str]) -> List[str]:
    """Column names ``"<channel>_<feature>"`` for the selected FOOOF features.

    Features come from FOOOF_SELECTED_FEATURES, with unknown names dropped. If nothing
    valid remains (or nothing is selected), all five base features are used. Names are
    ordered channel-major, then in FOOOF_SELECTED_FEATURES order.
    """
    all_feats = ["offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw"]
    wanted = [name for name in (FOOOF_SELECTED_FEATURES or all_feats) if name in all_feats] or all_feats
    return [f"{ch}_{name}" for ch in channels for name in wanted]

def select_feature_columns_full(full_X: np.ndarray, channels: List[str]) -> Tuple[np.ndarray, List[str]]:
    """Keep only the FOOOF_SELECTED_FEATURES columns of a full ONE_MAIN_FOOOF matrix.

    ``full_X`` has 5 columns per channel (offset, exponent, alpha_cf, alpha_amp, alpha_bw),
    laid out channel-major. Returns the column subset, ordered channel-major and then by
    FOOOF_SELECTED_FEATURES order, together with fooof_feature_names(channels).
    """
    layout = ["offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw"]
    chosen = [name for name in (FOOOF_SELECTED_FEATURES or layout) if name in layout] or layout
    width = len(layout)
    columns = [
        ch_pos * width + layout.index(name)
        for ch_pos in range(len(channels))
        for name in chosen
    ]
    return full_X[:, columns], fooof_feature_names(channels)


In [None]:
# -----------------
# Scan input files
# -----------------

def _maybe_wsl_path(p: str) -> Optional[str]:
    m = re.match(r"^([A-Za-z]):[\\\\/](.*)$", str(p))
    if not m:
        return None
    drive = m.group(1).lower()
    rest = m.group(2).replace("\\\\", "/").replace("\\", "/")
    return f"/mnt/{drive}/{rest}"

def resolve_input_dirs() -> List[Path]:
    """Return existing, de-duplicated input directories for the selected DATASET_OPTION.

    Candidates are tried in two stages:
      1. INPUT_DIRS as configured, plus their WSL (``/mnt/<drive>/...``) translations.
      2. Only if none of those exist: the known WSL mounts for the *selected* dataset,
         then a local ``data`` folder in the notebook's parent directory.

    Fallbacks are never mixed with configured directories. Files from other datasets
    therefore cannot leak into this dataset's config-tagged cache folder.
    """
    def _existing_dirs(candidates) -> List[Path]:
        # Existing directories among *candidates*, resolved and de-duplicated in order.
        out: List[Path] = []
        seen = set()
        for c in candidates:
            if not c:
                continue
            try:
                p = Path(str(c)).expanduser()
                if not (p.exists() and p.is_dir()):
                    continue
                p = p.resolve()
            except Exception:
                continue
            if str(p) not in seen:
                seen.add(str(p))
                out.append(p)
        return out

    configured = list(INPUT_DIRS or [])
    # WSL convenience for Windows paths
    translated = [w for w in (_maybe_wsl_path(str(c)) for c in configured) if w]
    found = _existing_dirs(configured + translated)
    if found:
        return found

    # Nothing configured exists: try the known WSL mounts for this dataset only.
    wsl_mounts = {
        1: ["/mnt/g/ChristianMusaeus/Preprocessed_setfiles"],
        2: [
            "/mnt/e/Saxe_sandkasse/30EOEC_filer/Open_marked",
            "/mnt/e/Saxe_sandkasse/30EOEC_filer/Closed_marked",
        ],
        3: ["/mnt/g/ChristianMusaeus/New_EEG/Processed"],
    }
    fallback = list(wsl_mounts.get(int(DATASET_OPTION), []))
    # Local fallback
    fallback.append(str(NOTEBOOK_DIR.parent / "data"))
    return _existing_dirs(fallback)

roots = resolve_input_dirs()
print("Resolved input dirs:", roots)
if not roots:
    raise RuntimeError("Could not resolve any INPUT_DIRS for this DATASET_OPTION. Update paths or set env vars.")

# Decide which file types to scan: MNE .fif epochs for the new dataset, EEGLAB .set otherwise.
if int(DATASET_OPTION) == 3:
    patterns = ("*_epo.fif", "*_epo.FIF", "*.fif", "*.FIF")
else:
    patterns = ("*.set", "*.SET")

paths: List[Path] = []
for root in roots:
    for pat in patterns:
        paths += sorted(root.rglob(pat) if RECURSIVE else root.glob(pat))
# Resolve first, then de-duplicate while keeping order. Overlapping patterns
# ("*_epo.fif" vs "*.fif"), case-insensitive globbing on Windows or symlinks can
# yield the same file more than once.
paths = list(dict.fromkeys(p.resolve() for p in paths))
print("Input files found:", len(paths))
if not paths:
    raise RuntimeError(f"No input files found under {roots} for patterns={patterns}. Check DATASET_OPTION and directory paths.")
print("Example:", paths[0])

# Persist the input file list for reproducibility
try:
    outpath("input_files.txt").write_text("\n".join(str(p) for p in paths), encoding="utf-8")
except Exception as exc:
    print("Warning: could not write input_files.txt:", exc)

In [None]:
# -----------------
# Epoch loader + validations
# -----------------

def load_epochs(path: Path):
    """Load an epochs file, choosing the reader by (case-insensitive) suffix.

    ``.set`` files are read with ``mne.io.read_epochs_eeglab``. ``.fif`` files (MNE
    epochs, used by the new processed dataset) are read with ``mne.read_epochs`` and
    ``preload=False``. Any other suffix raises ValueError.
    """
    readers = {
        '.set': lambda p: mne.io.read_epochs_eeglab(str(p), verbose='ERROR'),
        '.fif': lambda p: mne.read_epochs(str(p), preload=False, verbose='ERROR'),
    }
    reader = readers.get(path.suffix.lower())
    if reader is None:
        raise ValueError(f"Unsupported file type: {path}")
    return reader(path)

def validate_epochs_basic(epochs, path: Path) -> Tuple[bool, str]:
    """Cheap sanity check of an epochs object; returns ``(ok, reason)``.

    The checks, in order, require at least one epoch, at least one channel and a finite,
    positive ``info['sfreq']``. An object whose length, channel list or sfreq cannot be
    read yields ``'bad epochs object: <error>'``. *path* is currently unused.
    """
    try:
        n_epochs = len(epochs)
        n_ch = len(getattr(epochs, 'ch_names', []))
        rate = float(epochs.info['sfreq'])
    except Exception as exc:
        return False, f'bad epochs object: {exc}'
    checks = (
        (n_epochs <= 0, 'no epochs'),
        (n_ch <= 0, 'no channels'),
        (not np.isfinite(rate) or rate <= 0, 'invalid sfreq'),
    )
    for failed, reason in checks:
        if failed:
            return False, reason
    return True, ''



def validate_saved_npz(npz_path: Path, expected_n_epochs: int, expected_feature_names: List[str]) -> Tuple[bool, str]:
    """Check that a cached feature ``.npz`` is complete and consistent.

    Args:
        npz_path: file written by this notebook's per-file cache.
        expected_n_epochs: required number of rows in ``X``.
        expected_feature_names: required column names, in order. An empty list skips
            the name comparison (only the length check against ``X`` remains).

    Returns:
        ``(True, '')`` when valid, otherwise ``(False, reason)``. Load errors
        (missing or corrupt file) are reported as the reason, never raised.

    The archive is opened as a context manager so its file handle is released right
    away. A handle left open keeps the file locked on Windows and blocks overwriting
    it. ``allow_pickle=True`` is needed for the object-dtype ``feature_names``; only use
    this on trusted cache files written by this notebook.
    """
    try:
        with np.load(npz_path, allow_pickle=True) as d:
            if 'X' not in d.files or 'feature_names' not in d.files:
                return False, 'missing keys'
            X = np.asarray(d['X'])
            names = [str(x) for x in np.asarray(d['feature_names']).ravel().tolist()]
    except Exception as exc:
        return False, str(exc)
    if X.ndim != 2:
        return False, f'X has wrong ndim: {X.ndim}'
    if int(X.shape[0]) != int(expected_n_epochs):
        return False, f'epoch count mismatch: {X.shape[0]} vs {expected_n_epochs}'
    if len(names) != int(X.shape[1]):
        return False, 'feature_names length mismatch'
    if expected_feature_names and names != list(expected_feature_names):
        return False, 'feature_names order mismatch'
    if not np.all(np.isfinite(X)):
        return False, 'non-finite values in X'
    return True, ''
# -----------------
# Main loop
# -----------------

@dataclass
class ProfileRow:
    """One row of alpha_profiles.csv: the alpha peak fitted to one input file's ROI mean spectrum."""
    subject_id: int  # from parse_subject_id(path)
    file: str  # source file path
    alpha_cf: float  # alpha peak center frequency in Hz; 0.0 when no peak was found or the fit failed
    alpha_bw: float  # alpha peak bandwidth as reported by the backend; 0.0 when no peak was found
    n_epochs_used: int  # rows of the PSD cube (finite epochs, or epoch pairs when pairing is enabled)
    roi_channels: str  # comma-separated channels averaged for the profile fit
    n_channels: int  # channels in the PSD cube
    n_freqs: int  # frequency bins in the PSD cube

# Accumulators filled by the main loop and written to CSV files at the end of it.
profile_rows: List[ProfileRow] = []
file_info_rows: List[dict] = []
feature_manifest_rows: List[dict] = []

def _fmt_secs(seconds: float) -> str:
    seconds = max(0.0, float(seconds))
    m, s = divmod(int(round(seconds)), 60)
    h, m = divmod(m, 60)
    return f"{h:d}h{m:02d}m" if h else f"{m:d}m{s:02d}s"

n_skipped = 0
n_processed = 0
n_failed = 0  # files that were found but could not be processed (read, channel, PSD errors)

# Mode "A" stops after the alpha profile; any other mode also computes per-epoch features.
_alpha_only = str(CACHE_MODE).strip().upper() == "A"

# Per-file results collected when SAVE_PER_FILE is False (mode B); written once after the loop.
combined_parts: List[dict] = []


def _eta_suffix(done: int, t_start: float) -> str:
    """Return " | elapsed ... | ETA ..." after file number *done*, or "" off the ETA_EVERY cadence.

    The ETA extrapolates the mean time per file so far (cache hits included) over the
    files remaining in `paths`.
    """
    if not ETA_EVERY or done % int(ETA_EVERY) != 0:
        return ""
    elapsed = time.time() - t_start
    eta = (elapsed / max(done, 1)) * max(len(paths) - done, 0)
    return f" | elapsed {_fmt_secs(elapsed)} | ETA {_fmt_secs(eta)}"


def _read_cached_summary(npz_path: Path) -> Tuple[Optional[dict], str]:
    """Read shape and alpha metadata from a cached per-file npz.

    Returns ``(summary, "")`` when the file loads and passes validate_saved_npz. Otherwise
    returns ``(None, reason)`` so the caller can recompute the file. Keys missing from older
    caches (e.g. roi_channels) default to 0 / "".
    """
    try:
        with np.load(npz_path, allow_pickle=True) as d:
            keys = set(d.files)
            x_shape = np.asarray(d["X"]).shape
            names = [str(x) for x in np.asarray(d["feature_names"]).ravel().tolist()]
            summary = {
                "n_epochs": int(x_shape[0]),
                "n_features": int(x_shape[1]) if len(x_shape) > 1 else 0,
                "n_channels": int(np.asarray(d["channels"]).size) if "channels" in keys else 0,
                "n_freqs": int(np.asarray(d["freqs"]).size) if "freqs" in keys else 0,
                "alpha_cf": float(d["alpha_cf"]) if "alpha_cf" in keys else 0.0,
                "alpha_bw": float(d["alpha_bw"]) if "alpha_bw" in keys else 0.0,
                "roi_channels": str(np.asarray(d["roi_channels"]).item()) if "roi_channels" in keys else "",
            }
    except Exception as exc:
        return None, f"unreadable ({exc})"
    ok, why = validate_saved_npz(npz_path, summary["n_epochs"], names)
    return (summary, "") if ok else (None, why)


def _pad_columns(arr: np.ndarray, width: int) -> np.ndarray:
    """Right-pad a 2-D integer array with -1 so that it has *width* columns."""
    out = np.full((arr.shape[0], width), -1, dtype=np.int64)
    out[:, : arr.shape[1]] = arr
    return out


t0 = time.time()
for i, path in enumerate(paths, start=1):
    subj = parse_subject_id(path)
    file_tag = _safe_tag(path.stem)
    t_file = time.time()
    print(f"[{i}/{len(paths)}] {path.name} (subject_id={subj})")

    # Reuse a valid per-file cache. Its metadata is still added to the profile and
    # manifest tables, so the CSVs written below describe every cached file and not
    # only those computed in this run. An unreadable or invalid cache is recomputed
    # and overwritten.
    force_save = False
    if not _alpha_only and SAVE_PER_FILE:
        target_npz = outpath(f"features_subject_{subj}__{file_tag}.npz")
        if target_npz.exists() and not OVERWRITE:
            cached, why = _read_cached_summary(target_npz)
            if cached is not None:
                n_skipped += 1
                profile_rows.append(
                    ProfileRow(
                        subject_id=int(subj),
                        file=str(path),
                        alpha_cf=float(cached["alpha_cf"]),
                        alpha_bw=float(cached["alpha_bw"]),
                        n_epochs_used=int(cached["n_epochs"]),
                        roi_channels=str(cached["roi_channels"]),
                        n_channels=int(cached["n_channels"]),
                        n_freqs=int(cached["n_freqs"]),
                    )
                )
                feature_manifest_rows.append(
                    {
                        "subject_id": int(subj),
                        "file": str(path),
                        "file_tag": str(file_tag),
                        "npz": str(target_npz),
                        "n_epochs": int(cached["n_epochs"]),
                        "n_features": int(cached["n_features"]),
                        "alpha_cf": float(cached["alpha_cf"]),
                        "alpha_bw": float(cached["alpha_bw"]),
                    }
                )
                print(f"  -> exists, skipping: {target_npz.name}{_eta_suffix(i, t0)}")
                continue
            print(f"  -> cached file is invalid ({why}); recomputing: {target_npz.name}")
            force_save = True

    try:
        epochs = load_epochs(path)
        ok, why = validate_epochs_basic(epochs, path)
        if not ok:
            n_failed += 1
            print("  -> invalid epochs file (skipping):", why)
            continue
        file_info_rows.append({
            "subject_id": int(subj),
            "file": str(path),
            "suffix": str(path.suffix),
            "n_epochs": int(len(epochs)),
            "n_channels": int(len(epochs.ch_names)),
            "sfreq": float(epochs.info["sfreq"]),
            "original_ch_names": ",".join(map(str, epochs.ch_names)),
        })
    except Exception as exc:
        n_failed += 1
        print("  -> failed to read epochs file (skipping):", exc)
        continue

    epochs = rename_epochs_channels_canonical(epochs)
    # The row appended just above belongs to this file.
    file_info_rows[-1]["mapped_ch_names"] = ",".join(map(str, epochs.ch_names))
    sfreq = float(epochs.info['sfreq'])

    if ALL_CHANNELS:
        picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude='bads')
        if len(picks_all) == 0:
            # Every EEG channel is marked bad: use them anyway rather than nothing.
            picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude=[])
        if len(picks_all) == 0:
            n_failed += 1
            print("  -> no EEG channels found; skipping")
            continue
        picks = picks_all
        channels = [epochs.ch_names[idx] for idx in picks]
    else:
        requested = [str(ch).upper() for ch in CHANNEL_SELECTION]
        name_lookup = {canonical_channel_name(ch).upper(): ch for ch in epochs.ch_names}
        missing = [ch for ch in requested if ch not in name_lookup]
        if missing:
            n_failed += 1
            print("  -> missing requested channels; skipping:", missing)
            continue
        channels = [name_lookup[ch] for ch in requested]
        picks = [epochs.ch_names.index(ch) for ch in channels]

    data_all = epochs.get_data(picks=picks)
    finite_mask = np.all(np.isfinite(data_all), axis=(1, 2))
    if not np.any(finite_mask):
        n_failed += 1
        print("  -> no finite epochs; skipping")
        continue
    finite_idx = np.flatnonzero(finite_mask)
    # epoch_index[r] lists the original epoch indices feeding output row r
    # (one column, or two when adjacent epochs are paired).
    data = data_all[finite_mask]
    epoch_index = finite_idx.reshape(-1, 1)
    if COMBINE_ADJACENT_EPOCHS:
        pairs = []
        cursor = 0
        while cursor < len(finite_idx) - 1:
            i0 = int(finite_idx[cursor])
            i1 = int(finite_idx[cursor + 1])
            if i1 == i0 + 1:
                pairs.append((i0, i1))
                cursor += 2
            else:
                cursor += 1
        if not pairs:
            print("  -> no adjacent finite epochs found for pairing; using single epochs")
        else:
            data = np.stack([np.concatenate([data_all[i0], data_all[i1]], axis=1) for i0, i1 in pairs], axis=0)
            epoch_index = np.asarray(pairs, dtype=np.int64)

    # PSD cube: (epochs, channels, freqs)
    try:
        psd, freqs = psd_array_welch_clean(
            data,
            sfreq=sfreq,
            fmin=PSD_KWARGS.get('fmin', 1.0),
            fmax=PSD_KWARGS.get('fmax', 45.0),
            target_secs=TARGET_SECS,
        )
    except Exception as exc:
        n_failed += 1
        print("  -> PSD failed; skipping:", exc)
        continue

    # Alpha profile from ROI mean spectrum
    roi_names = [ch for ch in ALPHA_PROFILE_ROI if ch in channels]
    if not roi_names:
        print("  -> no ROI channels present; using all channels for profile")
        roi_idx = list(range(len(channels)))
        roi_names = channels
    else:
        roi_idx = [channels.index(ch) for ch in roi_names]

    roi_cube = psd[:, roi_idx, :]
    mean_spectrum = np.nanmean(roi_cube, axis=(0, 1))
    alpha_cf = 0.0
    alpha_bw = 0.0
    try:
        model = SpectralModel(**FOOOF_SETTINGS)
        lo, hi = ALPHA_PROFILE_RANGE
        model.fit(freqs, mean_spectrum, freq_range=(float(lo), float(hi)))
        peaks = np.asarray(getattr(model, 'peak_params_', []))
        chosen = select_alpha_peak(peaks, float(lo), float(hi))
        if chosen is not None:
            alpha_cf, _amp, alpha_bw = map(float, chosen[:3])
    except Exception as exc:
        print("  -> alpha profile fit failed; continuing with zeros:", exc)

    profile_rows.append(
        ProfileRow(
            subject_id=int(subj),
            file=str(path),
            alpha_cf=float(alpha_cf),
            alpha_bw=float(alpha_bw),
            n_epochs_used=int(psd.shape[0]),
            roi_channels=','.join(map(str, roi_names)),
            n_channels=int(psd.shape[1]),
            n_freqs=int(psd.shape[2]),
        )
    )

    if _alpha_only:
        n_processed += 1
        suffix = _eta_suffix(i, t0)
        if suffix:
            print(f"  -> done in {_fmt_secs(time.time() - t_file)}{suffix}")
        continue

    # Full per-epoch features
    alpha_profile_map = {int(subj): (float(alpha_cf), float(alpha_bw))}
    full_X = compute_one_main_fooof_features(freqs=freqs, psd_cube=psd, subject_id=int(subj), alpha_profile_map=alpha_profile_map, include_aperiodic=True)
    X, feature_names = select_feature_columns_full(full_X, channels)

    if SAVE_PER_FILE:
        target_npz = outpath(f"features_subject_{subj}__{file_tag}.npz")
        if target_npz.exists() and not (OVERWRITE or force_save):
            print("  -> exists, skipping save:", target_npz.name)
        else:
            np.savez_compressed(
                target_npz,
                X=X.astype(np.float32),
                feature_names=np.asarray(feature_names, dtype=object),
                freqs=freqs.astype(np.float32),
                channels=np.asarray(channels, dtype=object),
                subject_id=np.int64(subj),
                source_file=str(path),
                alpha_cf=np.float32(alpha_cf),
                alpha_bw=np.float32(alpha_bw),
                roi_channels=','.join(map(str, roi_names)),
                epoch_indices=epoch_index.astype(np.int64),
            )
            feature_manifest_rows.append(
                {
                    "subject_id": int(subj),
                    "file": str(path),
                    "file_tag": str(file_tag),
                    "npz": str(target_npz),
                    "n_epochs": int(X.shape[0]),
                    "n_features": int(X.shape[1]),
                    "alpha_cf": float(alpha_cf),
                    "alpha_bw": float(alpha_bw),
                }
            )
    else:
        combined_parts.append(
            {
                "X": X.astype(np.float32),
                "feature_names": list(feature_names),
                "subject_id": int(subj),
                "file": str(path),
                "file_tag": str(file_tag),
                "alpha_cf": float(alpha_cf),
                "alpha_bw": float(alpha_bw),
                "epoch_indices": epoch_index.astype(np.int64),
            }
        )

    n_processed += 1
    suffix = _eta_suffix(i, t0)
    if suffix:
        print(f"  -> done in {_fmt_secs(time.time() - t_file)}{suffix}")

dt = time.time() - t0
print(f"Done. Processed: {n_processed}, skipped: {n_skipped}, failed: {n_failed}, total: {len(paths)} | elapsed: {dt/60:.1f} min")

# SAVE_PER_FILE=False (mode B): stack all files into one npz. Only files whose feature
# names match the first file's can be stacked; the others are reported and left out.
if not _alpha_only and not SAVE_PER_FILE:
    combined_path = outpath("features_combined.npz")
    if not combined_parts:
        print("No per-epoch features computed; combined file not written.")
    elif combined_path.exists() and not OVERWRITE:
        print("Combined file exists (OVERWRITE=False), not overwriting:", combined_path)
    else:
        ref_names = combined_parts[0]["feature_names"]
        kept = [part for part in combined_parts if part["feature_names"] == ref_names]
        left_out = [part["file"] for part in combined_parts if part["feature_names"] != ref_names]
        if left_out:
            print(f"Warning: {len(left_out)} file(s) have a different feature layout and were left out of the combined file:", left_out)
        row_counts = [int(part["X"].shape[0]) for part in kept]
        idx_width = max(int(part["epoch_indices"].shape[1]) for part in kept)
        np.savez_compressed(
            combined_path,
            X=np.concatenate([part["X"] for part in kept], axis=0),
            feature_names=np.asarray(ref_names, dtype=object),
            subject_id=np.repeat([part["subject_id"] for part in kept], row_counts).astype(np.int64),
            source_file=np.repeat([part["file"] for part in kept], row_counts),
            alpha_cf=np.repeat([part["alpha_cf"] for part in kept], row_counts).astype(np.float32),
            alpha_bw=np.repeat([part["alpha_bw"] for part in kept], row_counts).astype(np.float32),
            # -1 marks unused columns when only some files were paired.
            epoch_indices=np.concatenate([_pad_columns(part["epoch_indices"], idx_width) for part in kept], axis=0),
        )
        for part, n_rows in zip(kept, row_counts):
            feature_manifest_rows.append(
                {
                    "subject_id": int(part["subject_id"]),
                    "file": str(part["file"]),
                    "file_tag": str(part["file_tag"]),
                    "npz": str(combined_path),
                    "n_epochs": int(n_rows),
                    "n_features": int(part["X"].shape[1]),
                    "alpha_cf": float(part["alpha_cf"]),
                    "alpha_bw": float(part["alpha_bw"]),
                }
            )
        print("Wrote:", combined_path)

# Write outputs
profiles_df = pd.DataFrame([r.__dict__ for r in profile_rows])
profiles_path = outpath("alpha_profiles.csv")
profiles_df.to_csv(profiles_path, index=False)
print("Wrote:", profiles_path)

if feature_manifest_rows:
    manifest_df = pd.DataFrame(feature_manifest_rows)
    manifest_path = outpath("feature_manifest.csv")
    manifest_df.to_csv(manifest_path, index=False)
    print("Wrote:", manifest_path)

# Write file info summary (only files that were actually read in this run)
if file_info_rows:
    info_df = pd.DataFrame(file_info_rows)
    info_path = outpath('file_info.csv')
    info_df.to_csv(info_path, index=False)
    print('Wrote:', info_path)
