# Label unlabeled `Preprocessed_setfiles` with `EC_EO_Classifier` runs

This notebook replaces the role of `Ensembling_for_labeling.ipynb`, but uses **your** classifier pipeline outputs from:

- `New_EEG/EC_EO_Classifier.ipynb` (trained runs)
- `New_EEG/outputs/<run_folder>/` (model artifacts)

It loops through **all** `.set` files in:

- `G:\ChristianMusaeus\Preprocessed_setfiles` (override with the `PREPROCESSED_SETFILES_DIR` environment variable)

…and writes one `label_predictions.csv` **per run** under `LABELING_DIR` (default `G:\ChristianMusaeus\labeling`):

- `.../preprocessed_setfiles/<run_folder>/label_predictions.csv`

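For FOOOF runs it tries to reuse the precomputed cache from `SAVED_FOOOF_DIR` (default `G:\ChristianMusaeus\saved_fooof`), falling back, when that directory cannot be created, to:

- `New_EEG/outputs/saved_fooof/`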


In [None]:
from __future__ import annotations

import csv
import json
import os
import re
import time
import platform
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import joblib
import numpy as np
import pandas as pd
import mne

# -----------------
# User config
# -----------------

# Root folder holding the unlabeled .set files; override via PREPROCESSED_SETFILES_DIR.
# The default is a raw string, so its doubled backslashes are literal characters.
DATASET_DIR = os.getenv(
    "PREPROCESSED_SETFILES_DIR",
    r"G:\\ChristianMusaeus\\Preprocessed_setfiles",
)
# Also search sub-folders of DATASET_DIR.
RECURSIVE = True

# Generic channel labels ("Ch1".."Ch19") can be mapped onto standard 10-20 names.
# NOTE: the mapping assumes the common 19-channel ordering listed here.
AUTO_RENAME_CH1_TO_1020 = True
CH1_TO_1020_ORDER_19 = [
    "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2",
    "F7", "F8", "T7", "T8", "P7", "P8", "Fz", "Cz", "Pz",
]

# Prefer cached ONE_MAIN_FOOOF features when a cache file exists (strongly recommended).
USE_SAVED_FOOOF = True

# When False, a run whose label_predictions.csv already exists is left untouched.
OVERWRITE = False

# prob_ec at or above this value is labelled 1, otherwise 0.
THRESHOLD_EC = 0.5

# Cap on the number of files (handy for smoke tests); None processes everything.
MAX_FILES: Optional[int] = None

# -----------------
# Paths
# -----------------

def _detect_notebook_path() -> Optional[Path]:
    try:
        vsc = globals().get("__vsc_ipynb_file__", None)
        if vsc:
            p = Path(str(vsc)).expanduser()
            if p.suffix.lower() == ".ipynb" and p.exists():
                return p.resolve()
    except Exception:
        pass
    for key in ("NOTEBOOK_PATH", "IPYNB_PATH"):
        v = os.getenv(key)
        if v:
            try:
                p = Path(v).expanduser()
                if p.suffix.lower() == ".ipynb" and p.exists():
                    return p.resolve()
            except Exception:
                pass
    try:
        here = Path.cwd().resolve()
        for _ in range(6):
            cand = here / "New_EEG" / "Project-main" / "Label_with_EC_EO_Classifier.ipynb"
            if cand.exists():
                return cand.resolve()
            here = here.parent
    except Exception:
        pass
    return None

# Anchor repo-relative paths on the notebook's folder (or the CWD when undetected).
NOTEBOOK_PATH = _detect_notebook_path()
if NOTEBOOK_PATH is None:
    NOTEBOOK_DIR = Path.cwd().resolve()
else:
    NOTEBOOK_DIR = NOTEBOOK_PATH.parent
NEW_EEG_DIR = NOTEBOOK_DIR.parent
OUTPUTS_ROOT = NEW_EEG_DIR.joinpath("outputs")

# Preferred output locations; may live outside the repo (defaults target the shared drive).
SAVED_FOOOF_DIR = os.getenv("SAVED_FOOOF_DIR", r"G:\ChristianMusaeus\saved_fooof")
LABELING_DIR = os.getenv("LABELING_DIR", r"G:\ChristianMusaeus\labeling")

def resolve_output_dir(p: str, fallback: Path) -> Path:
    """Create and return the first usable output directory.

    Candidates, in order:

    1. ``p`` itself on Windows.
    2. On other systems, when ``p`` is a drive-letter path (``G:\\...``) and
       the kernel release string contains "microsoft" (WSL), the path
       translated to ``/mnt/<drive>/...``.
    3. On other systems, ``p`` itself when it is neither drive-like nor
       contains a backslash. Previously such a path was silently ignored.
    4. ``fallback``.

    Args:
        p: Preferred directory, typically read from an environment variable.
        fallback: Directory used when the preferred one is unusable.

    Returns:
        The resolved path of the first candidate that could be created.

    Raises:
        RuntimeError: If no candidate could be created. The last underlying
            error is chained as the cause.
    """
    p_str = str(p)
    drive_like = len(p_str) >= 2 and p_str[0].isalpha() and p_str[1] == ":"
    candidates = []

    if platform.system() == "Windows":
        candidates.append(p_str)
    elif drive_like:
        # A Windows path on a POSIX host is only meaningful under WSL; used
        # verbatim it would create a directory literally named "G:\...".
        try:
            on_wsl = "microsoft" in platform.uname().release.lower()
        except Exception:
            on_wsl = False
        if on_wsl and len(p_str) >= 3:
            drive = p_str[0].lower()
            rest = p_str[2:].lstrip("\\/").replace("\\", "/")
            candidates.append(f"/mnt/{drive}/{rest}")
    elif "\\" not in p_str:
        # Plain POSIX path: honour the caller's choice.
        candidates.append(p_str)

    candidates.append(str(fallback))

    last_error = None
    for c in candidates:
        if not c:
            continue
        try:
            pp = Path(c).expanduser()
            pp.mkdir(parents=True, exist_ok=True)
            return pp.resolve()
        except Exception as exc:
            # Best effort: remember why this candidate failed, try the next.
            last_error = exc
    raise RuntimeError(f"Could not create output dir: {p}") from last_error

# Resolve (and create) the FOOOF cache folder and the labeling output folder.
SAVED_FOOOF_ROOT = resolve_output_dir(SAVED_FOOOF_DIR, OUTPUTS_ROOT.joinpath("saved_fooof"))
_labeling_base = resolve_output_dir(LABELING_DIR, OUTPUTS_ROOT.joinpath("labeling"))
LABELING_ROOT = _labeling_base / "preprocessed_setfiles"
LABELING_ROOT.mkdir(parents=True, exist_ok=True)

# Echo every resolved location so a wrong path is visible immediately.
for _label, _location in (
    ("Saved-FOOOF root:", SAVED_FOOOF_ROOT),
    ("Labeling root:", LABELING_ROOT),
    ("Notebook dir:", NOTEBOOK_DIR),
    ("New_EEG dir:", NEW_EEG_DIR),
    ("Outputs root:", OUTPUTS_ROOT),
):
    print(_label, _location)

def _safe_tag(s: str) -> str:
    s = re.sub(r"[^a-zA-Z0-9._-]+", "_", str(s))
    s = re.sub(r"_+", "_", s).strip("_")
    return s[:160] if len(s) > 160 else s

def is_wsl() -> bool:
    """Return True when the kernel release string contains "microsoft".

    Any failure while querying the platform is treated as "not WSL".
    """
    try:
        kernel_release = platform.uname().release.lower()
    except Exception:
        return False
    return "microsoft" in kernel_release


def _maybe_wsl_path(p: str) -> Optional[str]:
    m = re.match(r"^([A-Za-z]):[\\\\/](.*)$", str(p))
    if not m:
        return None
    drive = m.group(1).lower()
    rest = m.group(2).replace("\\\\", "/").replace("\\", "/")
    return f"/mnt/{drive}/{rest}"

def resolve_dataset_dir() -> Path:
    """Return the first existing dataset directory among the known candidates.

    Candidates are ``DATASET_DIR`` as given, its ``/mnt/<drive>/...``
    translation when it is a drive-letter path, and the hard-coded
    ``/mnt/g/ChristianMusaeus/Preprocessed_setfiles`` location.

    Raises:
        RuntimeError: If none of the candidates is an existing directory.
    """
    translated = _maybe_wsl_path(str(DATASET_DIR))
    options = [DATASET_DIR]
    if translated:
        options.append(translated)
    options.append(r"/mnt/g/ChristianMusaeus/Preprocessed_setfiles")

    for option in options:
        if not option:
            continue
        try:
            folder = Path(str(option)).expanduser()
            if folder.is_dir():
                return folder.resolve()
        except Exception:
            # An unusable candidate is skipped; the next one is tried.
            pass
    raise RuntimeError(f"Could not resolve DATASET_DIR: {DATASET_DIR}")

# Resolve the dataset folder once at cell execution; raises RuntimeError if no candidate exists.
DATASET_PATH = resolve_dataset_dir()
print("Resolved dataset dir:", DATASET_PATH)

In [None]:

# -----------------
# Dataset sanity check: show file + channel info
# -----------------

# Gather the .set files (both extension spellings), de-duplicated via a set of resolved paths.
pats = ("*.set", "*.SET")
_dataset_root = Path(DATASET_PATH)
_unique_hits = set()
for _pat in pats:
    _hits = _dataset_root.rglob(_pat) if RECURSIVE else _dataset_root.glob(_pat)
    _unique_hits.update(_hit.resolve() for _hit in _hits)
set_files = sorted(_unique_hits)

print("Resolved dataset dir:", DATASET_PATH)
print("Recursive:", RECURSIVE)
print("Total .set files:", len(set_files))
if len(set_files) > 0:
    print("First file:", set_files[0])
    if len(set_files) > 1:
        print("Last file:", set_files[-1])

# Nothing to inspect without at least one file.
if len(set_files) == 0:
    raise FileNotFoundError(f"No .set files found under {DATASET_PATH}")

# Read the first file and report its basic recording metadata.
first_path = set_files[0]
epochs = mne.io.read_epochs_eeglab(str(first_path), verbose='ERROR')

print("\n--- MNE epochs info (first file) ---")
print("n_epochs:", len(epochs))
print("n_channels:", len(epochs.ch_names))
print("sfreq:", float(epochs.info['sfreq']))
print("ch_names:", epochs.ch_names)

print("\n--- Channel dict entries (epochs.info['chs']) ---")
for ch in epochs.info['chs']:
    # Only a compact, usually informative subset of the channel fields is shown.
    print({field: ch.get(field) for field in ('ch_name', 'kind', 'unit', 'coil_type')})


In [None]:
# -----------------
# Detect the 4 run folders (old/new × fooof/no_fooof)
# -----------------

def list_run_folders(outputs_root: Path) -> List[Path]:
    """List candidate run directories directly under ``outputs_root``.

    The bookkeeping folders ``saved_fooof``, ``comparisons`` and ``labeling``
    are excluded, as is any folder whose name starts with ``compare__``.

    Returns:
        Run directories sorted by name; an empty list if ``outputs_root``
        does not exist.
    """
    if not outputs_root.exists():
        return []
    reserved = {"saved_fooof", "comparisons", "labeling"}
    runs = [
        entry
        for entry in outputs_root.iterdir()
        if entry.is_dir()
        and entry.name not in reserved
        and not entry.name.startswith("compare__")
    ]
    runs.sort(key=lambda entry: entry.name)
    return runs

def auto_pick_four_runs(outputs_root: Path) -> Dict[str, Path]:
    """Pick one run folder for each of the four dataset/feature combinations.

    A folder is matched by its name prefix (``old_dataset__`` or
    ``new_dataset__``) and by containing ``__fooof__`` or ``__no_fooof__``.
    When several folders match the same key, the one that sorts last by
    name is kept.

    Returns:
        Mapping with keys ``old_fooof``, ``old_no_fooof``, ``new_fooof`` and
        ``new_no_fooof``.

    Raises:
        RuntimeError: If any of the four keys has no matching folder.
    """
    rules = (
        ("old_fooof", "old_dataset__", "__fooof__"),
        ("old_no_fooof", "old_dataset__", "__no_fooof__"),
        ("new_fooof", "new_dataset__", "__fooof__"),
        ("new_no_fooof", "new_dataset__", "__no_fooof__"),
    )
    runs = list_run_folders(outputs_root)
    by_name = {run.name: run for run in runs}
    picked: Dict[str, Path] = {}
    for run in runs:
        for key, prefix, marker in rules:
            if run.name.startswith(prefix) and marker in run.name:
                picked[key] = run
    # Fail loudly rather than label with an incomplete set of runs.
    missing = [key for key, _, _ in rules if key not in picked]
    if missing:
        raise RuntimeError(f"Could not auto-detect the 4 runs in {outputs_root}. Missing: {missing}. Found: {list(by_name)[:8]}...")
    return picked

RUNS = auto_pick_four_runs(OUTPUTS_ROOT)
for k, p in RUNS.items():
    print(f"{k}: {p.name}")

# Selection of classifier runs to apply:
#   "all"    -> the four runs back-to-back
#   "single" -> only RUN_KEY
#   "custom" -> the keys listed in RUN_KEYS_CUSTOM
RUN_MODE = "custom"
RUN_KEY = "old_fooof"
RUN_KEYS_CUSTOM = ["old_fooof", "new_no_fooof", "new_fooof"]

_keys_by_mode = {
    "all": ["old_fooof", "old_no_fooof", "new_fooof", "new_no_fooof"],
    "single": [str(RUN_KEY)],
    "custom": [str(x) for x in RUN_KEYS_CUSTOM],
}
if RUN_MODE not in _keys_by_mode:
    raise ValueError(f"Unknown RUN_MODE: {RUN_MODE}")
RUN_KEYS_TO_PROCESS = _keys_by_mode[RUN_MODE]

# Every requested key must correspond to a detected run.
missing = [k for k in RUN_KEYS_TO_PROCESS if k not in RUNS]
if missing:
    raise ValueError(f"Unknown run keys: {missing}. Available: {list(RUNS)}")

# Store the effective configuration in a timestamped JSON file for reproducibility.
cfg = {
    "DATASET_DIR": str(DATASET_PATH),
    "USE_SAVED_FOOOF": bool(USE_SAVED_FOOOF),
    "THRESHOLD_EC": float(THRESHOLD_EC),
    "MAX_FILES": MAX_FILES,
    "RUNS": {k: str(v) for k, v in RUNS.items()},
    "RUN_MODE": str(RUN_MODE),
    "RUN_KEY": str(RUN_KEY),
    "RUN_KEYS_CUSTOM": list(RUN_KEYS_CUSTOM),
    "RUN_KEYS_TO_PROCESS": list(RUN_KEYS_TO_PROCESS),
}
_cfg_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
_cfg_file = LABELING_ROOT / f"labeling_config__{_cfg_stamp}.json"
_cfg_file.write_text(json.dumps(cfg, indent=2), encoding="utf-8")


In [None]:
# -----------------
# Helpers: channels, PSD, saved-FOOOF loader
# -----------------

def canonical_channel_name(ch_name: str) -> str:
    """Normalise a channel label so differently formatted names compare equal.

    Surrounding whitespace, a leading ``EEG`` prefix, a trailing ``-REF``
    suffix (both case-insensitive) and all inner whitespace are removed.
    When the module-level ``AUTO_RENAME_CH1_TO_1020`` flag is truthy and
    ``CH1_TO_1020_ORDER_19`` is a 19-item list or tuple, generic labels
    ``Ch1``..``Ch19`` (any case) are replaced by the entry at that position.

    Returns:
        The cleaned label, or the mapped name for a generic label.
    """
    cleaned = str(ch_name).strip()
    for pattern, flags in (
        (r"^EEG\s+", re.IGNORECASE),
        (r"-REF$", re.IGNORECASE),
        (r"\s+", 0),
    ):
        cleaned = re.sub(pattern, "", cleaned, flags=flags)

    module_vars = globals()
    if not module_vars.get('AUTO_RENAME_CH1_TO_1020', False):
        return cleaned
    order = module_vars.get('CH1_TO_1020_ORDER_19', None)
    if not (isinstance(order, (list, tuple)) and len(order) == 19):
        return cleaned
    generic_to_1020 = {f'CH{pos}': str(label) for pos, label in enumerate(order, start=1)}
    return generic_to_1020.get(cleaned.upper(), cleaned)

def rename_epochs_channels_canonical(epochs):
    """Rename the channels of ``epochs`` to their canonical names, in place.

    If canonicalisation would make two channels share a name, nothing is
    renamed. The same ``epochs`` object is returned either way.
    """
    current = list(epochs.ch_names)
    proposed = [canonical_channel_name(label) for label in current]
    # Colliding names cannot be applied safely, so leave the object untouched.
    if len(proposed) != len(set(proposed)):
        return epochs
    changes = {old: new for old, new in zip(current, proposed) if old != new}
    if changes:
        epochs.rename_channels(changes)
    return epochs

def parse_subject_id(path: Path) -> int:
    """Derive an integer subject ID from a file name.

    The first run of three or more digits in the file stem is used when
    present. Otherwise the ID is a CRC32 of the UTF-8 encoded stem, reduced
    modulo 1e9.

    The fallback previously used the builtin ``hash()``, which is salted per
    interpreter process for ``str``. The same file therefore got a different
    ID on every run, so cache files named after the ID could not be found
    again. CRC32 is stable across processes and machines.
    NOTE(review): fallback IDs written by earlier runs will not match the
    new values.

    Args:
        path: Path whose ``stem`` is inspected.

    Returns:
        A non-negative integer subject ID.
    """
    import zlib  # local import: only the fallback branch needs it

    m = re.search(r"(\d{3,})", path.stem)
    if m:
        return int(m.group(1))
    return int(zlib.crc32(path.stem.encode("utf-8")) % 1_000_000_000)

def psd_array_welch_clean(data: np.ndarray, sfreq: float, fmin=1.0, fmax=45.0, target_secs=2.0):
    """Run ``mne.time_frequency.psd_array_welch`` with a derived segment length.

    The segment length is ``target_secs * sfreq`` samples, limited to the
    length of the last axis of ``data`` and never below 8. Segments of 16 or
    more samples overlap by half; shorter ones do not overlap. A Hann window
    and mean averaging are used.

    Args:
        data: Three-dimensional array; the last axis is treated as time.
        sfreq: Sampling frequency passed to MNE.
        fmin: Lower frequency bound passed to MNE.
        fmax: Upper frequency bound passed to MNE.
        target_secs: Desired segment duration in seconds.

    Returns:
        The ``(psds, freqs)`` pair returned by MNE.
    """
    # Unpacking three values also rejects input that is not 3-D.
    _, _, n_times = data.shape
    wanted_len = int(round(float(target_secs) * float(sfreq)))
    seg_len = max(8, min(n_times, wanted_len))
    overlap = seg_len // 2 if seg_len >= 16 else 0
    welch_options = {
        "sfreq": float(sfreq),
        "fmin": float(fmin),
        "fmax": float(fmax),
        "n_per_seg": int(seg_len),
        "n_overlap": int(overlap),
        "window": "hann",
        "average": "mean",
        "verbose": False,
    }
    psds, freqs = mne.time_frequency.psd_array_welch(data, **welch_options)
    return psds, freqs

def reduce_freq_resolution(psd_cube: np.ndarray, n_bins: int) -> np.ndarray:
    """Average neighbouring frequency samples into ``n_bins`` equal-width bins.

    The bin width is ``n_freqs // n_bins``. Samples past ``width * n_bins``
    on the last axis are dropped.

    Args:
        psd_cube: Three-dimensional array with frequency on the last axis.
        n_bins: Number of output bins.

    Returns:
        Array whose first two dimensions match ``psd_cube`` and whose last
        dimension has length ``n_bins``.

    Raises:
        ValueError: If ``n_bins`` exceeds the number of frequency samples.
    """
    bins = int(n_bins)
    n_samples, n_channels, n_freqs = psd_cube.shape
    width = n_freqs // bins
    if width == 0:
        raise ValueError(f"n_bins={n_bins} is too high for n_freqs={n_freqs}")
    usable = psd_cube[:, :, : width * bins]
    return usable.reshape(n_samples, n_channels, bins, width).mean(axis=3)

def _find_saved_fooof_npz(file_path: Path, subject_id: int) -> Optional[Path]:
    """Locate the cached feature file for one recording under ``SAVED_FOOOF_ROOT``.

    The file name searched for, recursively, is
    ``features_subject_<subject_id>__<safe tag of the file stem>.npz``.
    Among several hits, one whose parent path contains ``cache_b`` is
    preferred, then the shortest full path.

    Returns:
        The chosen cache file, or ``None`` if the root is missing or nothing
        matches.
    """
    if not SAVED_FOOOF_ROOT.exists():
        return None
    wanted = f"features_subject_{int(subject_id)}__{_safe_tag(file_path.stem)}.npz"
    hits = [hit for hit in SAVED_FOOOF_ROOT.rglob(wanted) if hit.is_file()]
    if not hits:
        return None

    def _rank(hit: Path):
        # False sorts before True, so "cache_b" parents rank first.
        return ("cache_b" not in str(hit.parent), len(str(hit)))

    return min(hits, key=_rank)

def load_saved_fooof_features(npz_path: Path) -> Tuple[np.ndarray, List[str]]:
    """Read the feature matrix and feature names from a cached ``.npz`` file.

    The archive is opened in a ``with`` block so its file handle is released
    as soon as both arrays are materialised. Previously the handle stayed
    open for every cache file read.

    Args:
        npz_path: Archive containing the arrays ``X`` and ``feature_names``.

    Returns:
        ``(X, names)`` where ``X`` is a float array and ``names`` is the
        flattened ``feature_names`` array as a list of ``str``.

    Raises:
        KeyError: If either array is missing from the archive.
    """
    # NOTE(security): allow_pickle=True can execute code from a crafted file;
    # only load caches produced by trusted runs.
    with np.load(npz_path, allow_pickle=True) as d:
        X = np.asarray(d["X"], dtype=float)
        names = [str(x) for x in np.asarray(d["feature_names"]).ravel().tolist()]
    return X, names

def align_by_feature_names(X: np.ndarray, names: List[str], desired_names: List[str]) -> np.ndarray:
    """Reorder the columns of ``X`` to follow ``desired_names``.

    Columns are matched on ``(canonical channel name upper-cased, feature)``.
    A desired column with no counterpart in ``names`` is left at zero, and
    non-finite values in the result are replaced by zero.

    Names are split into channel and feature on the known feature suffixes
    first. A plain ``rsplit("_", 1)`` splits ``"Ch1_alpha_cf"`` into
    ``("Ch1_alpha", "cf")``, so channel canonicalisation never applied to the
    ``alpha_*`` features and they were zero-filled whenever the two sides
    spelled the channel differently.

    Args:
        X: Matrix with one column per entry of ``names``.
        names: ``<channel>_<feature>`` label of every column of ``X``.
        desired_names: Labels, in output order, of the columns to produce.

    Returns:
        Float array of shape ``(X.shape[0], len(desired_names))``.

    Raises:
        ValueError: If a desired name contains no ``_`` separator.
    """
    known_feats = ("offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw")

    def _split(label) -> Tuple[str, str]:
        # Separate "<channel>_<feature>", keeping multi-word features intact.
        text = str(label)
        for feat in known_feats:
            suffix = "_" + feat
            if len(text) > len(suffix) and text.endswith(suffix):
                return text[: -len(suffix)], feat
        # Unknown feature: fall back to the last underscore.
        ch_part, feat = text.rsplit("_", 1)
        return ch_part, feat

    def _key(label) -> Tuple[str, str]:
        ch_part, feat = _split(label)
        return canonical_channel_name(ch_part).upper(), str(feat)

    mapping: Dict[Tuple[str, str], int] = {}
    for idx, n in enumerate(names):
        try:
            mapping[_key(n)] = int(idx)
        except ValueError:
            # Source labels without a separator cannot be matched; skip them.
            continue

    out = np.zeros((int(X.shape[0]), len(desired_names)), dtype=float)
    for j, dn in enumerate(desired_names):
        src = mapping.get(_key(dn))
        if src is None:
            continue
        out[:, j] = X[:, src]
    return np.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

def make_desired_fooof_feature_names(feature_channels: List[str], selected_feats: List[str]) -> List[str]:
    """Build the ``<channel>_<feature>`` column labels expected by a FOOOF run.

    Unknown entries in ``selected_feats`` are ignored. If nothing valid
    remains, or ``selected_feats`` is empty/None, all five features are used.
    Labels are grouped by channel, with features in the requested order.
    """
    base_order = ["offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw"]
    requested = selected_feats or base_order
    kept = [feat for feat in requested if feat in base_order] or base_order
    return [f"{ch}_{feat}" for ch in feature_channels for feat in kept]

def load_feature_channels(run_dir: Path) -> Optional[List[str]]:
    """Read ``feature_channels.npy`` from a run folder as a flat list of strings.

    Returns:
        The stored channel names, or ``None`` when the file is absent.
    """
    npy_file = run_dir / "feature_channels.npy"
    if npy_file.exists():
        stored = np.load(npy_file, allow_pickle=True)
        return [str(item) for item in np.asarray(stored).ravel().tolist()]
    return None

def load_psd_freqs(run_dir: Path) -> Optional[np.ndarray]:
    """Read ``psd_freqs.npy`` from a run folder as a flat float array.

    Returns:
        The stored values flattened to one dimension, or ``None`` when the
        file is absent.
    """
    npy_file = run_dir / "psd_freqs.npy"
    if npy_file.exists():
        stored = np.load(npy_file)
        return np.asarray(stored, dtype=float).ravel()
    return None

# -----------------
# FOOOF backend + ONE_MAIN_FOOOF on-the-fly fallback
# -----------------

# These defaults are meant to mirror `New_EEG/EC_EO_Classifier.ipynb`.
# Frequency range used for the per-spectrum model fit.
ALPHA_FREQ_RANGE = (3.0, 40.0)
# Frequency range used when fitting the ROI-average alpha profile.
ALPHA_PROFILE_RANGE = (4.0, 16.0)
# Channels averaged for the alpha profile.
ALPHA_PROFILE_ROI = ["O1", "O2", "P3", "P4", "P7", "P8", "Pz"]
# Keyword arguments handed to the spectral model constructor.
FOOOF_SETTINGS = dict(
    aperiodic_mode="fixed",
    peak_width_limits=(0.5, 12.0),
    max_n_peaks=6,
    min_peak_height=0.05,
    peak_threshold=2.0,
    verbose=False,
)

def _import_spectral_backend():
    """Pick the spectral-parameterisation backend that can be imported.

    ``specparam`` is tried first, then ``fooof``.

    Returns:
        ``(model class, fit-error class, backend name)``. When neither
        package imports, this is ``(None, Exception, "unavailable")``.
    """
    try:
        from specparam import SpectralModel as _SpecModel
    except Exception:
        pass
    else:
        try:
            from specparam.core.errors import FitError as _FitError
        except Exception:
            # No importable error class: use a private stand-in.
            class _FitError(Exception):
                pass
        return _SpecModel, _FitError, "specparam"

    try:
        from fooof import FOOOF as _FooofModel
        from fooof.core.errors import FitError as _FitError
    except Exception:
        return None, Exception, "unavailable"
    return _FooofModel, _FitError, "fooof"


SpectralModel, FitError, FOOOF_BACKEND = _import_spectral_backend()

def select_alpha_peak(peaks: np.ndarray, lo: float, hi: float):
    """Return the peak row with the largest second column inside ``[lo, hi]``.

    Each row of ``peaks`` is one peak. Column 0 is compared against the
    inclusive bounds and column 1 decides the winner. A one-dimensional
    input is treated as a single row.

    Returns:
        The selected row, or ``None`` when ``peaks`` is empty or no row falls
        inside the bounds.
    """
    table = np.asarray(peaks, float)
    if table.size == 0:
        return None
    if table.ndim == 1:
        table = table[np.newaxis, :]
    first_col = table[:, 0]
    in_band = np.logical_and(first_col >= float(lo), first_col <= float(hi))
    if not in_band.any():
        return None
    candidates = table[in_band]
    return candidates[candidates[:, 1].argmax()]

def build_alpha_profile(freqs: np.ndarray, psd_cube: np.ndarray, channels: List[str]) -> Tuple[float, float]:
    """Fit one alpha peak on the mean ROI spectrum and return ``(alpha_cf, alpha_bw)``.

    The spectrum is the NaN-ignoring mean of ``psd_cube`` over its first axis
    and over the ``ALPHA_PROFILE_ROI`` channels found in ``channels`` (all
    channels when none is found). ``(0.0, 0.0)`` is returned when no backend
    is available, the mean spectrum has no finite value, no peak lies inside
    ``ALPHA_PROFILE_RANGE``, or the fit raises.
    """
    no_profile = (0.0, 0.0)
    if SpectralModel is None:
        return no_profile

    roi_positions = [channels.index(ch) for ch in ALPHA_PROFILE_ROI if ch in channels]
    if not roi_positions:
        roi_positions = list(range(len(channels)))
    mean_spectrum = np.nanmean(psd_cube[:, roi_positions, :], axis=(0, 1))
    if not np.isfinite(mean_spectrum).any():
        return no_profile

    try:
        lo, hi = ALPHA_PROFILE_RANGE
        fitter = SpectralModel(**FOOOF_SETTINGS)
        fitter.fit(
            np.asarray(freqs, float),
            np.asarray(mean_spectrum, float),
            freq_range=(float(lo), float(hi)),
        )
        found = np.asarray(getattr(fitter, "peak_params_", []))
        best = select_alpha_peak(found, float(lo), float(hi))
        if best is None:
            return no_profile
        # The first three entries are unpacked; the middle one is not used.
        center, _, width = (float(value) for value in best[:3])
        return float(center), float(width)
    except Exception:
        # Any fitting problem degrades to "no alpha profile".
        return no_profile

def compute_one_main_fooof_features(freqs: np.ndarray, psd_cube: np.ndarray, subject_id: int, alpha_profile_map, include_aperiodic: bool = True) -> np.ndarray:
    """Compute ONE_MAIN_FOOOF features (matches `New_EEG/EC_EO_Classifier.ipynb`).

    For every spectrum in ``psd_cube`` (iterated as first axis, then second
    axis) five values are emitted, in this order: ``offset``, ``exponent``,
    ``alpha_cf``, ``alpha_amp``, ``alpha_bw``.

    * ``offset`` / ``exponent`` are the first two entries of the fitted
      model's ``aperiodic_params_`` (the model is built with
      ``max_n_peaks=0`` and fitted over ``ALPHA_FREQ_RANGE``).
    * ``alpha_cf`` / ``alpha_bw`` are copied from the subject's entry in
      ``alpha_profile_map`` and are therefore identical for every spectrum.
    * ``alpha_amp`` is the coefficient ``sum(g * r) / sum(g ** 2)`` clipped
      at zero, where ``g`` is a Gaussian template built from the profile and
      ``r`` is the spectrum minus the model spectrum.

    Args:
        freqs: Frequency axis shared by all spectra.
        psd_cube: Spectra; iterated as ``psd_cube[i][j]`` giving one spectrum.
        subject_id: Key used to look up the alpha profile.
        alpha_profile_map: Object with ``.get(subject_id)`` returning a
            2-item ``(alpha_cf, alpha_bw)`` profile, or ``None`` for no
            profile at all.
        include_aperiodic: When False and no profile exists, no model is
            fitted and all five values stay 0.0.

    Returns:
        Float array with one row per first-axis entry of ``psd_cube`` and
        five columns per second-axis entry.

    Raises:
        RuntimeError: If no spectral backend could be imported.
    """
    if SpectralModel is None:
        raise RuntimeError("FOOOF backend unavailable")
    subj = int(subject_id)
    profile = alpha_profile_map.get(subj) if alpha_profile_map is not None else None
    has_profile = profile is not None and len(profile) == 2
    if has_profile:
        alpha_cf, alpha_bw = map(float, profile)
    else:
        alpha_cf, alpha_bw = 0.0, 0.0
    freqs_arr = np.asarray(freqs, float)
    import math
    if has_profile and alpha_bw > 0:
        # 2*sqrt(2*ln 2) converts a full width at half maximum into a standard
        # deviation, i.e. alpha_bw is treated as the template's FWHM.
        sigma = float(alpha_bw) / (2.0 * math.sqrt(2.0 * math.log(2.0)))
        gauss = np.exp(-0.5 * ((freqs_arr - alpha_cf) / sigma) ** 2)
    else:
        # No usable profile: an all-zero template makes denom 0 and alpha_amp 0.
        gauss = np.zeros_like(freqs_arr)
    denom = float(np.sum(gauss ** 2)) if gauss.size else 0.0
    features = []
    # Work on a copy so FOOOF_SETTINGS is not modified; peaks are disabled here.
    ap_settings = dict(FOOOF_SETTINGS)
    try:
        ap_settings["max_n_peaks"] = 0
    except Exception:
        pass
    for epoch_psd in psd_cube:
        epoch_feats = []
        for spectrum in epoch_psd:
            try:
                if not np.all(np.isfinite(spectrum)):
                    # Routed to the except clause below -> five zeros.
                    raise ValueError("Non-finite in spectrum")
                offset, exponent = 0.0, 0.0
                alpha_amp = 0.0
                if include_aperiodic or has_profile:
                    model = SpectralModel(**ap_settings)
                    model.fit(freqs_arr, spectrum, freq_range=ALPHA_FREQ_RANGE)
                    if hasattr(model, "aperiodic_params_"):
                        params = np.asarray(model.aperiodic_params_)
                        if params.size > 0:
                            offset = float(params[0])
                        if params.size > 1:
                            exponent = float(params[1])
                    # Obtain the fitted model spectrum; the accessor name varies,
                    # so a method is tried first, then several attribute names.
                    try:
                        ap_fit = None
                        get_fun = getattr(model, "get_model_spectrum", None)
                        if callable(get_fun):
                            ap_fit = np.asarray(get_fun(freqs_arr))
                    except Exception:
                        ap_fit = None
                    if ap_fit is None:
                        for name in ("fooofed_spectrum_", "modeled_spectrum_", "model_spectrum_", "model_spectrum__"):
                            if hasattr(model, name):
                                ap_fit = np.asarray(getattr(model, name))
                                break
                    # NOTE(review): if the backend stores the model spectrum only
                    # for the fitted freq_range, its shape may differ from
                    # `spectrum` and this zero fallback would be what is used --
                    # confirm against the training notebook before changing.
                    if ap_fit is None or ap_fit.shape != spectrum.shape:
                        ap_fit = np.zeros_like(spectrum)
                    if has_profile and denom > 0.0:
                        # Projection of the residual onto the template, clipped at zero.
                        residual = spectrum - ap_fit
                        num = float(np.sum(gauss * residual))
                        alpha_amp = max(num / denom, 0.0)
                epoch_feats.extend([offset, exponent, alpha_cf if has_profile else 0.0, alpha_amp, alpha_bw if has_profile else 0.0])
            except (FitError, RuntimeError, ValueError, np.linalg.LinAlgError):
                # A failed or invalid spectrum contributes five zeros so the
                # column layout stays fixed.
                epoch_feats.extend([0.0, 0.0, 0.0, 0.0, 0.0])
        features.append(epoch_feats)
    return np.asarray(features, dtype=float)

def select_from_full_one_main(full_X: np.ndarray, n_channels: int, selected_feats: List[str]) -> np.ndarray:
    """Keep only the requested feature columns of a full ONE_MAIN matrix.

    ``full_X`` is read as blocks of five columns per channel in the order
    offset, exponent, alpha_cf, alpha_amp, alpha_bw. Unknown names in
    ``selected_feats`` are ignored; when nothing valid remains, or the
    selection is empty/None, all five are kept. Output columns are grouped
    by channel with features in the requested order.
    """
    base_order = ["offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw"]
    chosen = [feat for feat in (selected_feats or base_order) if feat in base_order] or base_order
    stride = len(base_order)
    positions = [base_order.index(feat) for feat in chosen]
    columns = [
        channel * stride + position
        for channel in range(int(n_channels))
        for position in positions
    ]
    return np.asarray(full_X, dtype=float)[:, columns]


In [None]:
# -----------------
# Main labeling loop
# -----------------

def collect_set_files(directory: Path, recursive: bool = True) -> List[Path]:
    """Return the resolved ``.set`` / ``.SET`` files under ``directory``.

    Args:
        directory: Folder to search.
        recursive: Search sub-folders too when True.

    Returns:
        Unique resolved paths ordered by their string form; an empty list if
        ``directory`` does not exist.
    """
    if not directory.exists():
        return []
    finder = directory.rglob if recursive else directory.glob
    # Keyed by path text so a file matched by both patterns is kept once.
    by_text = {}
    for pattern in ("*.set", "*.SET"):
        for hit in finder(pattern):
            resolved = hit.resolve()
            by_text.setdefault(str(resolved), resolved)
    return [by_text[text] for text in sorted(by_text)]

# Build the work list, truncated to the first MAX_FILES entries when a cap is set.
_file_cap = None if MAX_FILES is None else int(MAX_FILES)
all_files = collect_set_files(DATASET_PATH, recursive=RECURSIVE)[:_file_cap]
print(".set files to process:", len(all_files))
if len(all_files) > 0:
    print("Example:", all_files[0])

def _fmt_secs(seconds: float) -> str:
    seconds = max(0.0, float(seconds))
    m, s = divmod(int(round(seconds)), 60)
    h, m = divmod(m, 60)
    return f"{h:d}h{m:02d}m" if h else f"{m:d}m{s:02d}s"

# Apply every selected classifier run to every .set file.
# Per run, two CSVs are written under LABELING_ROOT/<run folder name>/:
#   label_predictions.csv - one row per epoch (label, probabilities, run, file)
#   labeling_log.csv      - one row per file (ok / skipped / error plus a note)
t_global = time.time()
for run_key in RUN_KEYS_TO_PROCESS:
    run_dir = RUNS[run_key]
    run_name = run_dir.name
    # The feature family is inferred from the run folder name.
    is_fooof = "__fooof__" in run_name
    out_dir = LABELING_ROOT / run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / "label_predictions.csv"
    log_path = out_dir / "labeling_log.csv"

    # Existing output is never clobbered unless OVERWRITE is set.
    # NOTE(review): the check is on file existence only, so a partially
    # written CSV from an interrupted run is also skipped.
    if out_csv.exists() and (not OVERWRITE):
        print(f"[SKIP] {run_key}: output exists: {out_csv}")
        continue

    # Load run artifacts
    # Model, scaler and imputer are mandatory; the component transformer is optional.
    model_path = run_dir / "final_model_lr.pkl"
    scaler_path = run_dir / "final_scaler_lr.pkl"
    imputer_path = run_dir / "final_imputer_lr.pkl"
    component_path = run_dir / "final_component_lr.pkl"
    if not (model_path.exists() and scaler_path.exists() and imputer_path.exists()):
        raise RuntimeError(f"Missing model artifacts in {run_dir}")
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)
    imputer = joblib.load(imputer_path)
    component = joblib.load(component_path) if component_path.exists() else None

    # Both may be None when the run folder does not contain the files.
    feature_channels = load_feature_channels(run_dir)
    psd_freqs_ref = load_psd_freqs(run_dir)

    # FOOOF runs: which of the five per-channel features the model was trained on.
    selected_feats = None
    if is_fooof:
        sf = run_dir / "fooof_selected_features.npy"
        if sf.exists():
            selected_feats = [str(x) for x in np.asarray(np.load(sf, allow_pickle=True)).ravel().tolist()]
        else:
            # Default used when the run did not store its selection.
            selected_feats = ["offset", "exponent", "alpha_amp"]

    # PSD runs: the stored number of frequency bins is required.
    n_bins = None
    if not is_fooof:
        nb = run_dir / "final_n_bins_lr.npy"
        if not nb.exists():
            raise RuntimeError(f"Missing final_n_bins_lr.npy for PSD run: {run_dir}")
        n_bins = int(np.asarray(np.load(nb)).ravel()[0])

    # Open output CSVs
    with out_csv.open("w", newline="", encoding="utf-8") as f_out, log_path.open("w", newline="", encoding="utf-8") as f_log:
        w = csv.DictWriter(
            f_out,
            fieldnames=["Test subject ID", "Epoch number", "Label", "Probability", "prob_ec", "run", "file"],
        )
        w.writeheader()
        wlog = csv.DictWriter(
            f_log,
            fieldnames=["file", "subject_id", "status", "n_epochs", "note"],
        )
        wlog.writeheader()

        print(f"\n[RUN] {run_key}: {run_name}")
        print("  features:", "FOOOF" if is_fooof else "PSD", "| n_bins:", n_bins)
        print("  output:", out_csv)

        t_run = time.time()
        for i, path in enumerate(all_files, start=1):
            subj = parse_subject_id(path)
            # Any failure for a single file is logged as "error" and the loop continues.
            try:
                epochs = mne.io.read_epochs_eeglab(str(path), verbose="ERROR")
                epochs = rename_epochs_channels_canonical(epochs)
                sfreq = float(epochs.info["sfreq"])

                # Skip files with NaNs (same policy as the original project)
                data_full = epochs.get_data()
                if not np.all(np.isfinite(data_full)):
                    wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "skipped", "n_epochs": int(len(epochs)), "note": "NaN/Inf in raw data"})
                    continue

                # Determine channel order to match the trained run
                if feature_channels is None:
                    # fallback: use all EEG channels in file order
                    picks = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude="bads")
                    if len(picks) == 0:
                        picks = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude=[])
                    channels = [epochs.ch_names[idx] for idx in picks]
                else:
                    channels = list(feature_channels)

                # Load features
                if is_fooof:
                    desired_names = make_desired_fooof_feature_names(channels, selected_feats)
                    X_raw = None
                    note = ""
                    # Preferred path: reuse cached features, re-ordered to the desired columns.
                    if USE_SAVED_FOOOF:
                        npz = _find_saved_fooof_npz(path, subject_id=int(subj))
                        if npz is not None:
                            X_saved, names_saved = load_saved_fooof_features(npz)
                            X_raw = align_by_feature_names(X_saved, names_saved, desired_names)
                            note = f"cache:{npz.name}"

                    # Fallback: compute ONE_MAIN_FOOOF features on-the-fly when cache is missing.
                    if X_raw is None:
                        if SpectralModel is None:
                            wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "skipped", "n_epochs": int(len(epochs)), "note": "Missing saved_fooof cache and no specparam/fooof installed"})
                            continue

                        # Compute PSD for the desired channel list, filling missing channels with zeros.
                        # lookup: canonical upper-case name -> channel index in this file.
                        lookup = {canonical_channel_name(ch).upper(): idx for idx, ch in enumerate(epochs.ch_names)}
                        present_pairs = [(out_i, lookup.get(canonical_channel_name(ch).upper(), None)) for out_i, ch in enumerate(channels)]
                        have = [(o, i_in) for o, i_in in present_pairs if i_in is not None]
                        if not have:
                            wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "skipped", "n_epochs": int(len(epochs)), "note": "No matching channels for on-the-fly FOOOF"})
                            continue
                        out_idx, in_idx = zip(*have)
                        data_present = epochs.get_data(picks=list(in_idx))
                        # Epochs shorter than three seconds use 1 s Welch segments instead of 2 s.
                        target_secs = 1.0 if data_present.shape[-1] < sfreq * 3 else 2.0
                        psd_present, freqs = psd_array_welch_clean(data_present, sfreq=sfreq, fmin=1.0, fmax=45.0, target_secs=target_secs)
                        # Start from NaN so absent channels are identifiable, then zero them.
                        psd_full = np.full((psd_present.shape[0], len(channels), psd_present.shape[2]), np.nan, dtype=float)
                        psd_full[:, list(out_idx), :] = psd_present
                        psd_full = np.nan_to_num(psd_full, nan=0.0, posinf=0.0, neginf=0.0)

                        # One alpha profile per file, shared by all its epochs and channels.
                        alpha_cf, alpha_bw = build_alpha_profile(freqs, psd_full, channels)
                        alpha_profile_map = {int(subj): (float(alpha_cf), float(alpha_bw))}
                        full_X = compute_one_main_fooof_features(freqs=freqs, psd_cube=psd_full, subject_id=int(subj), alpha_profile_map=alpha_profile_map, include_aperiodic=True)
                        X_raw = select_from_full_one_main(full_X, n_channels=len(channels), selected_feats=selected_feats)
                        note = "computed_on_fly"

                else:
                    # PSD features: compute PSD for all epochs
                    if feature_channels is None:
                        picks = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude="bads")
                        if len(picks) == 0:
                            picks = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude=[])
                        data = epochs.get_data(picks=picks)
                        psd, freqs = psd_array_welch_clean(data, sfreq=sfreq, fmin=1.0, fmax=45.0, target_secs=2.0)
                        psd_full = psd
                    else:
                        # compute PSD only for channels present, then place into full channel array
                        lookup = {canonical_channel_name(ch).upper(): idx for idx, ch in enumerate(epochs.ch_names)}
                        present_pairs = [(out_i, lookup.get(canonical_channel_name(ch).upper(), None)) for out_i, ch in enumerate(feature_channels)]
                        have = [(o, i_in) for o, i_in in present_pairs if i_in is not None]
                        if not have:
                            wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "skipped", "n_epochs": int(len(epochs)), "note": "No matching channels"})
                            continue
                        out_idx, in_idx = zip(*have)
                        data_present = epochs.get_data(picks=list(in_idx))
                        psd_present, freqs = psd_array_welch_clean(data_present, sfreq=sfreq, fmin=1.0, fmax=45.0, target_secs=2.0)
                        psd_full = np.full((psd_present.shape[0], len(feature_channels), psd_present.shape[2]), np.nan, dtype=float)
                        psd_full[:, list(out_idx), :] = psd_present

                    # The frequency grid must equal the one stored with the trained run.
                    if psd_freqs_ref is not None:
                        if freqs.shape != psd_freqs_ref.shape or not np.allclose(freqs, psd_freqs_ref):
                            wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "skipped", "n_epochs": int(len(epochs)), "note": "PSD freqs mismatch vs trained run"})
                            continue

                    # Absent channels (NaN) and non-finite values become 0 before flattening.
                    psd_full = np.nan_to_num(psd_full, nan=0.0, posinf=0.0, neginf=0.0)
                    if n_bins is not None:
                        reduced = reduce_freq_resolution(psd_full, int(n_bins))
                        X_raw = reduced.reshape(reduced.shape[0], -1)
                    else:
                        X_raw = psd_full.reshape(psd_full.shape[0], -1)

                # Apply preprocessing and predict
                X_imp = imputer.transform(X_raw)
                X_scaled = scaler.transform(X_imp)
                X_proc = component.transform(X_scaled) if component is not None else X_scaled

                # Column 1 of predict_proba is used as the EC probability.
                prob_ec = model.predict_proba(X_proc)[:, 1].astype(float)
                y_pred = (prob_ec >= float(THRESHOLD_EC)).astype(int)
                # "Probability" is the probability of whichever label was predicted.
                prob_pred = np.where(y_pred == 1, prob_ec, 1.0 - prob_ec)

                for ep in range(int(prob_ec.shape[0])):
                    w.writerow(
                        {
                            "Test subject ID": int(subj),
                            "Epoch number": int(ep),
                            "Label": int(y_pred[ep]),
                            "Probability": float(prob_pred[ep]),
                            "prob_ec": float(prob_ec[ep]),
                            "run": str(run_name),
                            "file": str(path),
                        }
                    )
                wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "ok", "n_epochs": int(prob_ec.shape[0]), "note": note if is_fooof else ""})

            except Exception as exc:
                # The message is truncated to 240 characters for the log.
                wlog.writerow({"file": str(path), "subject_id": int(subj), "status": "error", "n_epochs": 0, "note": str(exc)[:240]})

            # Progress every 25 files; files skipped via `continue` above do not reach this point.
            if i % 25 == 0:
                elapsed = time.time() - t_run
                avg = elapsed / max(i, 1)
                eta = avg * max(len(all_files) - i, 0)
                print(f"  [{i}/{len(all_files)}] elapsed {_fmt_secs(elapsed)} | ETA {_fmt_secs(eta)}")

        print(f"[DONE] {run_key} in {_fmt_secs(time.time() - t_run)}")

print("All runs done. Total elapsed:", _fmt_secs(time.time() - t_global))
