# EC vs EO Classification Pipeline

This notebook builds an **EC (eyes closed)** vs **EO (eyes open)** classifier from *clinical* EEG data.

**High-level flow**
1. Import dependencies.
2. Set global toggles (channel selection, feature mode, CV settings, temporal smoothing).
3. Resolve input files:
   - **OLD dataset**: EC/EO marked EEGLAB `.set` files.
   - **NEW dataset (optional)**: preprocessed MNE epoch `.fif` files (one per rater: doctor a / doctor b).
4. Extract features per epoch:
   - **PSD features** (optionally frequency-binned), or
   - **FOOOF/specparam features** (optionally `ONE_MAIN_FOOOF` subject-level alpha template).
5. Train + evaluate **logistic regression** with subject-wise CV and optional temporal smoothing.
6. Generate diagnostics/plots and save artifacts under `outputs/`.

**Core globals created downstream**
- Feature/label arrays: `X_combined`, `y_combined`, `subject_ids`
- Spectral data (for plots): `psd_cube`, `psd_freqs`, `feature_channels`
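
Because `subject_ids` holds one entry per epoch, a subject-wise split can be built by partitioning the unique subject IDs and then masking epochs. The sketch below is a minimal illustration, assuming that a "covering" scheme means every subject is held out exactly once. The helper `covering_subject_folds` and the toy IDs are hypothetical and are not the notebook's implementation.

```python
import random

def covering_subject_folds(subject_ids, subjects_per_fold=5, seed=13):
    """Partition unique subjects into held-out folds (hypothetical helper)."""
    unique = sorted(set(subject_ids))
    random.Random(seed).shuffle(unique)  # seeded shuffle for reproducibility
    return [unique[i:i + subjects_per_fold]
            for i in range(0, len(unique), subjects_per_fold)]

# Toy per-epoch subject IDs: 12 made-up subjects with 3 epochs each.
subject_ids = [sid for sid in range(101, 113) for _ in range(3)]

folds = covering_subject_folds(subject_ids, subjects_per_fold=5, seed=13)
for test_subjects in folds:
    held_out = set(test_subjects)
    test_idx = [i for i, sid in enumerate(subject_ids) if sid in held_out]
    train_idx = [i for i, sid in enumerate(subject_ids) if sid not in held_out]
    # No subject contributes epochs to both sides of the split.
    assert not {subject_ids[i] for i in train_idx} & held_out
    print(f"held-out subjects: {len(held_out)}, "
          f"train epochs: {len(train_idx)}, test epochs: {len(test_idx)}")
```

Splitting on subjects rather than on epochs keeps epochs from the same recording out of both the training and test sets at once.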


In [None]:
import os
import sys
import re
import warnings
from collections import Counter
from pathlib import Path

os.environ.setdefault("NUMBA_DISABLE_CACHE", "1")
os.environ.setdefault("NUMBA_CACHE_DIR", str((Path.cwd() / ".numba_cache").resolve()))
Path(os.environ["NUMBA_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from scipy.signal import stft
import joblib
import mne

from sklearn.model_selection import LeaveOneOut
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, FastICA
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    classification_report,
    confusion_matrix,
    log_loss,
    roc_curve,
    auc,
)
from sklearn.impute import SimpleImputer


## Data Locations & Global Toggles

This section controls *what data is loaded* and *how the pipeline behaves*.

**Data locations (clinical)**
- OLD `.set` data can be pointed to via environment variables:
  - `EC_EO_OPEN_DIR` (EO)
  - `EC_EO_CLOSED_DIR` (EC)
- NEW preprocessed epoch `.fif` data can be pointed to via:
  - `NEW_EEG_PROCESSED_DIR`
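
One way to set these overrides from inside the notebook is shown below. This is a sketch only: the directories are placeholders, not real dataset locations. The variables must be set before the path-resolution cell runs, because that cell reads them with `os.getenv`.

```python
import os
from pathlib import Path

# Placeholder locations; replace with directories that exist on your machine.
overrides = {
    "EC_EO_OPEN_DIR": Path.home() / "eeg_data" / "Open_marked",
    "EC_EO_CLOSED_DIR": Path.home() / "eeg_data" / "Closed_marked",
    "NEW_EEG_PROCESSED_DIR": Path.home() / "eeg_data" / "Processed",
}

for name, path in overrides.items():
    os.environ[name] = str(path)
    print(f"{name} -> {path} (exists: {path.exists()})")
```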

**Most important toggles**
- `NEW_DATA`: use NEW `.fif` epochs instead of OLD `.set` files.
- `USE_BOTH_DATASETS` / `TEST_ON_OTHER_DATASET`: cross-dataset evaluation modes.
- `CHANNEL_SELECTION`: list of channels to use, or `'all'`.
- `USE_FOOOF`: feature type (`False` = PSD, `True` = FOOOF/specparam).
- `ONE_MAIN_FOOOF`: subject-level alpha template mode (FOOOF only).
- `CV_LEVEL` and related knobs: define the subject-wise evaluation scheme.
- `USE_TIME_ADJUSTMENT`: enables run-length smoothing of predicted labels (uses `TIME_AXIS_MODE`).

Set these once up front, then run the next cell(s) top-to-bottom.
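
To make the `USE_TIME_ADJUSTMENT` idea concrete, here is a minimal sketch of run-length smoothing. It assumes that an interior run of predicted labels shorter than `min_run_length` is relabelled to match its neighbours. The function `smooth_short_runs` is hypothetical; the notebook's own smoothing (including its edge handling and run-length tuning) may differ.

```python
from itertools import groupby

def smooth_short_runs(labels, min_run_length=5):
    """Relabel short interior runs to match their neighbours (illustrative only)."""
    # Collapse the sequence into [label, run_length] pairs.
    runs = [[label, len(list(group))] for label, group in groupby(labels)]
    # The first and last runs are left untouched in this sketch.
    for i in range(1, len(runs) - 1):
        prev_label, next_label = runs[i - 1][0], runs[i + 1][0]
        if runs[i][1] < min_run_length and prev_label == next_label:
            runs[i][0] = prev_label
    # Expand the runs back into a per-epoch label sequence.
    return [label for label, length in runs for _ in range(length)]

# Made-up predictions: a 2-epoch blip of class 1 inside a long class-0 stretch.
preds = [0] * 6 + [1] * 2 + [0] * 6 + [1] * 7
smoothed = smooth_short_runs(preds, min_run_length=5)
print(smoothed)
```

The reasoning is that eye state changes slowly compared with the epoch length, so an isolated flip lasting only a couple of epochs is more likely a misclassification than a real transition.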


In [None]:
# -*- coding: utf-8 -*-
import platform

# -----------------
# High-level toggles
# -----------------
COMBINE_ADJACENT_EPOCHS = False  # Pair consecutive clean epochs into 2-second windows
NEW_DATA = False                 # True -> use NEW preprocessed FIF dataset (doctor a/b union); False -> use old EC/EO .set files

# Dataset evaluation / mixing toggles
TEST_ON_OTHER_DATASET = False   # If True: train on selected dataset (NEW_DATA), test on the other dataset
USE_BOTH_DATASETS = False       # If True: load NEW+OLD together and run subject-wise CV over the combined pool
DATASET_SUBJECT_OFFSET = 1000000  # Offset used to avoid subject_id collisions across datasets
CROSS_DATASET_TEST = bool(TEST_ON_OTHER_DATASET and (not USE_BOTH_DATASETS))

# How to build the per-epoch "time_idx" used in timeline plots
# - "append_files": old behavior (EO file(s) then EC file(s))
# - "align_conditions": use the epoch number as the timeline (best when EO/EC files are two label-views of the same 1800 epochs)
# - "interleave_conditions": treat EO epoch k then EC epoch k as consecutive time steps (t=2k, t=2k+1)
TIME_AXIS_MODE = "align_conditions"
CHANNEL_SELECTION = [
    # 'O1', 'O2',
    # 'P3', 'P4', 'P7', 'P8', 'Pz',
    # 'F3', 'F4', 'C3', 'C4', 'F7', 'F8', 'T7', 'T8', 'Fz', 'Cz',
    # 'Fp1', 'Fp2',
    'all',
    # "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2", "F7", "F8", "T7", "T8", "P7", "P8", "T9", "T10", "Fz",
]  # Explicit channel names (e.g. "O1", "O2", "P3", "P4", "P7", "P8"), or "all" to use every available EEG channel
USE_FOOOF = False             # True -> FOOOF (specparam) features, False -> PSD features
USE_SAVED_FOOOF = True        # If True and available: reuse precomputed ONE_MAIN_FOOOF features (only applies for ONE_MAIN_FOOOF)
ONE_MAIN_FOOOF = True  # Use subject-level alpha template FOOOF mode
MAIN_FOOOF_USE_ALL_EPOCHS = True  # If True, build SUBJECT_ALPHA_PROFILE from all finite epochs (ignores per-epoch labels / rejmanual)

# Logistic regression penalty tuning
TUNE_LOGREG_PENALTY = False   # If True, compare L2 vs L1 during inner-loop tuning
LOGREG_PENALTY_OPTIONS = ["l2", "l1"]
LOGREG_PENALTY_FIXED = "l2"   # Used when TUNE_LOGREG_PENALTY=False
LOGREG_MAX_ITER = 2000        # Used for all logistic regression fits
ALPHA_PROFILE_RANGE = (4.0, 16.0)  # Hz range to find subject alpha peak
ALPHA_PROFILE_ROI = ['O1', 'O2', 'P3', 'P4', 'P7', 'P8', 'Pz']  # channels used to build alpha profile

# -----------------
# Channel sanity checks (especially for NEW .fif dataset)
# -----------------
# The NEW dataset is intended to be 19ch 10-20. If your processed files contain extra channels
# or inconsistent naming across subjects, the pipeline will build a UNION, increasing feature dims.
STANDARD_19_CHANNELS = [
    "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2",
    "F7", "F8", "T7", "T8", "P7", "P8", "Fz", "Cz", "Pz",
]
EXPECTED_NEW_CHANNEL_COUNT = 19
FAIL_IF_NEW_CHANNEL_MISMATCH = True
# If True, force NEW dataset training to exactly STANDARD_19_CHANNELS (skip subjects missing any).
FORCE_STANDARD_19_FOR_NEW = False
USE_CLASS_WEIGHTS = True    # Toggle to enable class_weight='balanced' in classifiers
CLASS_WEIGHT = "balanced" if USE_CLASS_WEIGHTS else None
USE_FREQ_BINNING = True      # Only relevant when USE_FOOOF is False
FREQ_BIN_OPTIONS = [5, 10, 15, 20, 25, 30]
C_GRID = [0.01, 0.1, 0.2, 0.5, 1.0]

# FOOOF feature selection (applies when USE_FOOOF is True)
# Choose a subset of ["offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw"]
FOOOF_SELECTED_FEATURES = [
    "offset", "exponent",      # aperiodic features
    # "alpha_cf", "alpha_bw",  # center frequency and bandwidth
    "alpha_amp",               # amplitude
]

# PSD feature selection (applies when USE_FOOOF is False)
# Set PSD_FEATURE_RANGE to a tuple like (8.0, 12.0) to restrict
# classifier features to that frequency band; use None for full range.
PSD_FEATURE_RANGE = None  # e.g. (8.0, 12.0), or None for the full range

CV_LEVEL = 2                  # 1 = single hold-out, 2 = covering CV, 3 = repeated covering CV, 4 = fixed test split
CV_TEST_SUBJECTS_PER_SPLIT = 5  # Number of subjects assigned to each held-out split
CV_REPEAT_COUNT = 3           # Only applies when CV_LEVEL==3
CV_RANDOM_SEED = 13           # Controls subject shuffling for reproducibility
FIXED_TEST_SUBJECTS_LEVEL4 = [10135, 10171, 10193, 10203, 10204]

# ---- Component analysis (optional PCA/ICA) ----
USE_COMPONENT_ANALYSIS = False    # True to enable PCA/ICA before logistic regression
COMPONENT_METHOD = "pca"          # "pca" or "ica"
COMPONENT_N_COMPONENTS = 10       # int or None; maximum components used in training (None -> a default is used)
ELBOW_MAX_COMPONENTS = 50         # max components shown in elbow plot

# Temporal smoothing of predicted labels
USE_TIME_ADJUSTMENT = True      # Toggle run-length smoothing of predictions
LENGTH_TUNING = True            # If True, tune run-length on out-of-fold predictions
LENGTH_GRID = [1, 3, 5, 7, 9, 11, 15, 21]  # Candidate run lengths (in epochs)
LENGTH_TUNING_METRIC = "balanced_accuracy"  # "balanced_accuracy" or "accuracy"
MIN_RUN_LENGTH = 5              # Minimum interior run length to keep as-is
USE_EDGE_SMOOTHING = True       # Also smooth short runs at the edges
MIN_RUN_LENGTH_EDGE = 5         # Edge-window span (in original epochs) for majority vote


# PSD and FOOOF parameters
PSD_KWARGS = dict(method="welch", fmin=1.0, fmax=45.0)
ALPHA_BAND = (8.0, 12.0)
ALPHA_PRIMARY_RANGE = (8.0, 14.0)
ALPHA_EXPANDED_RANGE = (4.0, 17.0)
FOOOF_SETTINGS = dict(
    aperiodic_mode="fixed",
    peak_width_limits=(0.5, 12.0),
    max_n_peaks=6,
    min_peak_height=0.05,
    peak_threshold=2.0,
    verbose=False,
)


## Paths, Data Discovery, and Output Folders

This cell sets up **where data comes from** and **where results are written**.

**What it does**
- Detects a plausible project root.
- Defines helper functions for locating directories and collecting `.set` files.
- Reads optional environment overrides (`EC_EO_OPEN_DIR`, `EC_EO_CLOSED_DIR`, `NEW_EEG_PROCESSED_DIR`).
- Creates an `outputs/` folder and helper `outpath(...)` for saving artifacts.
- Resolves input files:
  - OLD `.set` files → `eyes_open_files`, `eyes_closed_files`, `set_files`, `set_labels`.
  - NEW `.fif` epoch files (if enabled) → `new_subject_pairs`.
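
The structure of `new_subject_pairs` is easiest to see on a few made-up filenames. The sketch below applies the same `sub<digits><a|b>` pattern that the cell uses. The filenames are hypothetical, and the variable names `raters_by_subject` and `subject_pairs` are illustrative stand-ins.

```python
import re
from pathlib import Path

# Made-up filenames; only the "sub<digits><a|b>" part matters to the pattern.
names = ["sub012a_epo.fif", "sub012b_epo.fif", "sub040b_epo.fif", "notes_epo.fif"]

raters_by_subject = {}
for name in names:
    match = re.search(r"sub(\d+)([ab])", Path(name).stem, flags=re.IGNORECASE)
    if match is None:
        continue  # files without a subject/rater tag are skipped
    sid, rater = int(match.group(1)), match.group(2).lower()
    raters_by_subject.setdefault(sid, {})[rater] = name

subject_pairs = []
for sid in sorted(raters_by_subject):
    entry = raters_by_subject[sid]
    primary = entry.get("a", entry.get("b"))  # prefer rater a when present
    secondary = entry["b"] if ("a" in entry and "b" in entry) else None
    subject_pairs.append((sid, primary, secondary))

print(subject_pairs)
```

Subjects with only one rater file are kept, with `None` in the secondary slot.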

Run it after configuring toggles in the previous cell.
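
Artifacts are written to a configuration-specific subfolder of `outputs/`, so runs with different toggles do not overwrite each other. The sketch below shows how such a folder name is assembled for the default PSD configuration. It is a simplified stand-in for the cell's `_config_subdir()` helper: the `config` dictionary and `safe_tag` function are illustrative, and optional parts (FOOOF template mode, component analysis, penalty tuning) are omitted.

```python
import re

def safe_tag(text):
    """Replace characters that are awkward in folder names with underscores."""
    text = re.sub(r"[^a-zA-Z0-9._-]+", "_", str(text))
    return re.sub(r"_+", "_", text).strip("_")[:120]

# Illustrative settings matching the default toggles in this notebook.
config = {
    "dataset": "old_dataset",
    "fooof": "no_fooof",
    "channels": "allch",
    "cv_level": 2,
    "time_axis_mode": "align_conditions",
    "penalty": "l2",
}

parts = [
    config["dataset"],
    config["fooof"],
    config["channels"],
    f"cv{config['cv_level']}",
    f"time_{safe_tag(config['time_axis_mode'])}",
    f"pen_{safe_tag(config['penalty'])}",
]
subdir = "__".join(parts)
print(subdir)  # old_dataset__no_fooof__allch__cv2__time_align_conditions__pen_l2
```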


In [None]:

# -----------------
# Helper utilities for locating data
# -----------------
def guess_project_root() -> Path:
    """Walk up the directory tree to find a plausible project root."""
    p = Path.cwd().resolve()
    for _ in range(6):
        if (p / ".git").exists() or (p / "data").exists():
            return p
        p = p.parent
    return Path.cwd().resolve()

def first_existing(candidates):
    """Return the first existing Path from a list of candidates."""
    for candidate in candidates:
        if candidate is None:
            continue
        path_obj = Path(candidate).expanduser()
        try:
            if path_obj.exists():
                return path_obj.resolve()
        except OSError:
            pass
    return None

def collect_set_files(directory: Path, recursive: bool = True):
    """Collect .set files (case-insensitive) from a directory."""
    if directory is None or not directory.exists():
        return []
    patterns = ("*.set", "*.SET")
    if recursive:
        files = [p.resolve() for pat in patterns for p in directory.rglob(pat)]
    else:
        files = [p.resolve() for pat in patterns for p in directory.glob(pat)]
    return sorted({str(f) for f in files})

def is_wsl() -> bool:
    """Detect Windows Subsystem for Linux."""
    try:
        return "microsoft" in platform.uname().release.lower()
    except Exception:
        return False

project_root = guess_project_root()
system = platform.system()
wsl = is_wsl()

ENV = {
    "EC_EO_CLOSED_DIR": os.getenv("EC_EO_CLOSED_DIR"),
    "EC_EO_OPEN_DIR": os.getenv("EC_EO_OPEN_DIR"),
    "NEW_EEG_PROCESSED_DIR": os.getenv("NEW_EEG_PROCESSED_DIR"),
}

WIN_CLOSED = [
    r"E:\Saxe_sandkasse\30EOEC_filer\Closed_marked",
    r"E:/Saxe_sandkasse/30EOEC_filer/Closed_marked",
]
WIN_OPEN = [
    r"E:\Saxe_sandkasse\30EOEC_filer\Open_marked",
    r"E:/Saxe_sandkasse/30EOEC_filer/Open_marked",
]
WSL_CLOSED = [r"/mnt/e/Saxe_sandkasse/30EOEC_filer/Closed_marked"]
WSL_OPEN = [r"/mnt/e/Saxe_sandkasse/30EOEC_filer/Open_marked"]

REL_CLINICAL_CLOSED = [
    r"E:\Saxe_sandkasse\30EOEC_filer\Closed_marked",
    project_root / "data/30EOEC_filer/Closed_marked",
]
REL_CLINICAL_OPEN = [
    r"E:\Saxe_sandkasse\30EOEC_filer\Open_marked",
    project_root / "data/30EOEC_filer/Open_marked",
]

# NEW data locations (preprocessed FIF epoch files)
NEW_MAC_PROCESSED = [
    r"/Users/Saxe/Desktop/GitHub/EEG-classifiers/data/NEW_processed.nosync",
    project_root / "data/NEW_processed.nosync",
]
NEW_WIN_PROCESSED = [
    r"G:\ChristianMusaeus\New_EEG\Processed",
]
NEW_WSL_PROCESSED = [r"/mnt/g/ChristianMusaeus/New_EEG/Processed"]

print("System:", system, "(WSL:", wsl, ")")
print("Project root:", project_root)
print("NEW_DATA:", NEW_DATA)
print("Channel selection:", CHANNEL_SELECTION)
print("Feature mode:", "FOOOF/specparam" if USE_FOOOF else "PSD")
print("Class weights:", CLASS_WEIGHT)
print("Combine adjacent epochs:", COMBINE_ADJACENT_EPOCHS)
print("Time axis mode:", TIME_AXIS_MODE)

# -------------------- Output folder setup --------------------
# Save all generated artifacts under: <notebook_dir>/outputs/<config_tag>/...
from typing import Optional
def _detect_notebook_path() -> Optional[Path]:
    # VS Code notebooks sometimes inject this.
    try:
        vsc = globals().get('__vsc_ipynb_file__', None)
        if vsc:
            p = Path(str(vsc)).expanduser()
            if p.suffix.lower() == '.ipynb' and p.exists():
                return p.resolve()
    except Exception:
        pass
    # Env override (optional)
    for key in ("NOTEBOOK_PATH", "IPYNB_PATH"):
        v = os.getenv(key)
        if v:
            try:
                p = Path(v).expanduser()
                if p.suffix.lower() == '.ipynb' and p.exists():
                    return p.resolve()
            except Exception:
                pass
    # Repo-local fallback (this notebook lives at New_EEG/EC_EO_Classifier.ipynb)
    try:
        guess = Path(project_root) / 'New_EEG' / 'EC_EO_Classifier.ipynb'
        if guess.exists():
            return guess.resolve()
    except Exception:
        pass
    # Search fallback
    try:
        for p in Path(project_root).rglob('EC_EO_Classifier.ipynb'):
            return p.resolve()
    except Exception:
        pass
    return None

NOTEBOOK_PATH = _detect_notebook_path()
NOTEBOOK_DIR = NOTEBOOK_PATH.parent if NOTEBOOK_PATH is not None else Path.cwd().resolve()
OUTPUTS_ROOT = NOTEBOOK_DIR / 'outputs'
OUTPUTS_ROOT.mkdir(parents=True, exist_ok=True)
SAVED_FOOOF_ROOT = OUTPUTS_ROOT / 'saved_fooof'

def _safe_tag(s: str) -> str:
    s = re.sub(r'[^a-zA-Z0-9._-]+', '_', str(s))
    s = re.sub(r'_+', '_', s).strip('_')
    return s[:120]

def _channel_tag() -> str:
    try:
        sel = list(CHANNEL_SELECTION)
    except Exception:
        sel = []
    sel_norm = [str(x).strip() for x in sel if x is not None]
    if any(x.lower() == 'all' for x in sel_norm):
        return 'allch'
    if not sel_norm:
        return 'ch_unknown'
    joined = '-'.join(_safe_tag(x.upper()) for x in sel_norm)
    return f"ch_{joined}"[:80]

def _dataset_tag() -> str:
    if 'USE_BOTH_DATASETS' in globals() and USE_BOTH_DATASETS:
        return 'combined_datasets'
    if 'CROSS_DATASET_TEST' in globals() and CROSS_DATASET_TEST:
        return 'train_new_test_old' if NEW_DATA else 'train_old_test_new'
    return 'new_dataset' if NEW_DATA else 'old_dataset'

def _fooof_tag() -> str:
    return 'fooof' if USE_FOOOF else 'no_fooof'

def _config_subdir() -> str:
    parts = [_dataset_tag(), _fooof_tag(), _channel_tag(), f"cv{CV_LEVEL}"]
    try:
        parts.append(f"time_{_safe_tag(TIME_AXIS_MODE)}")
    except Exception:
        pass
    if USE_FOOOF and ONE_MAIN_FOOOF:
        parts.append('one_main_fooof')
        if 'MAIN_FOOOF_USE_ALL_EPOCHS' in globals() and MAIN_FOOOF_USE_ALL_EPOCHS:
            parts.append('mainfooof_all_epochs')
    if 'USE_COMPONENT_ANALYSIS' in globals() and USE_COMPONENT_ANALYSIS:
        parts.append(f"comp_{_safe_tag(COMPONENT_METHOD)}{COMPONENT_N_COMPONENTS}")
    if 'TUNE_LOGREG_PENALTY' in globals() and TUNE_LOGREG_PENALTY:
        parts.append('tune_penalty')
    else:
        try:
            parts.append(f"pen_{_safe_tag(LOGREG_PENALTY_FIXED)}")
        except Exception:
            pass
    return '__'.join([p for p in parts if p])

def get_output_dir() -> Path:
    out_dir = OUTPUTS_ROOT / _config_subdir()
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir

def outpath(filename: str) -> Path:
    return get_output_dir() / str(filename)

print('Notebook dir:', NOTEBOOK_DIR)
print('Outputs root:', OUTPUTS_ROOT)
print('Saved-FOOOF root:', SAVED_FOOOF_ROOT)
print('Output dir:', get_output_dir())

# -------------------- Resolve input files --------------------
# Depending on TEST_ON_OTHER_DATASET / USE_BOTH_DATASETS, we may need to resolve BOTH datasets.
if USE_BOTH_DATASETS and TEST_ON_OTHER_DATASET:
    warnings.warn('USE_BOTH_DATASETS=True overrides TEST_ON_OTHER_DATASET; running combined CV instead of cross-test.')
need_new = bool(NEW_DATA or CROSS_DATASET_TEST or USE_BOTH_DATASETS)
need_old = bool((not NEW_DATA) or CROSS_DATASET_TEST or USE_BOTH_DATASETS)
mode = "combined" if USE_BOTH_DATASETS else ("cross-test" if CROSS_DATASET_TEST else "single")
print("Dataset mode:", mode)

new_processed_dir = None
new_subject_pairs = []  # [(subject_id, file_a, file_b), ...] for NEW_DATA
new_epoch_files = []

eyes_closed_files, eyes_open_files = [], []
set_files, set_labels = [], np.array([], dtype=int)

if need_new:
    tried_new = [ENV["NEW_EEG_PROCESSED_DIR"]] + NEW_WIN_PROCESSED + NEW_WSL_PROCESSED + NEW_MAC_PROCESSED

    new_processed_dir = first_existing(tried_new)
    print("Tried NEW processed paths (in order):")
    for path in tried_new:
        print("  -", path)
    print("Resolved NEW processed dir:", new_processed_dir)

    if new_processed_dir is None:
        warnings.warn("Could not resolve NEW processed directory.")
        new_epoch_files = []
    else:
        # Deduplicate: on case-insensitive filesystems both patterns match the same files.
        new_epoch_files = sorted(
            {p.resolve() for pat in ("*_epo.fif", "*_epo.FIF") for p in Path(new_processed_dir).glob(pat)}
        )
    print(f"Total NEW epoch FIF files collected: {len(new_epoch_files)}")

    # Collect doctor a/b per subject.
    # If a subject only has a single rater file (only 'a' OR only 'b'), keep it as a valid NEW subject entry.
    pairs = {}
    for p in new_epoch_files:
        m = re.search(r"sub(\d+)([ab])", p.stem, flags=re.IGNORECASE)
        if not m:
            continue
        sid = int(m.group(1))
        rater = m.group(2).lower()
        pairs.setdefault(sid, {})[rater] = str(p)

    new_subject_pairs = []  # [(subject_id, file_primary, file_secondary_or_None), ...]
    paired_count = 0
    single_count = 0
    for sid in sorted(pairs):
        entry = pairs[sid]
        if ('a' in entry) and ('b' in entry):
            new_subject_pairs.append((sid, entry['a'], entry['b']))
            paired_count += 1
        elif 'a' in entry:
            new_subject_pairs.append((sid, entry['a'], None))
            single_count += 1
        elif 'b' in entry:
            new_subject_pairs.append((sid, entry['b'], None))
            single_count += 1

    print(f"NEW subjects found: total={len(new_subject_pairs)} (paired={paired_count}, single={single_count})")
    if new_subject_pairs:
        print("Example NEW subject entry:", new_subject_pairs[0])
else:
    print("Skipping NEW dataset resolution (need_new=False).")

if need_old:
    tried_closed = [ENV["EC_EO_CLOSED_DIR"]] + WIN_CLOSED + WSL_CLOSED + REL_CLINICAL_CLOSED
    tried_open = [ENV["EC_EO_OPEN_DIR"]] + WIN_OPEN + WSL_OPEN + REL_CLINICAL_OPEN

    eyes_closed_dir = first_existing(tried_closed)
    eyes_open_dir = first_existing(tried_open)

    print("Tried eyes-closed paths (in order):")
    for path in tried_closed:
        print("  -", path)
    print("Tried eyes-open paths (in order):")
    for path in tried_open:
        print("  -", path)

    eyes_closed_files = collect_set_files(eyes_closed_dir, recursive=True)
    eyes_open_files = collect_set_files(eyes_open_dir, recursive=True)

    if not eyes_closed_files or not eyes_open_files:
        warnings.warn(
            "No .set files found for the configured clinical EC/EO paths.\n"
            f"Resolved eyes_closed_dir: {eyes_closed_dir}\n"
            f"Resolved eyes_open_dir  : {eyes_open_dir}\n"
            "Tips:\n"
            "  • Verify the folders contain .set files.\n"
            "  • Override with env vars if the defaults differ on your machine.\n"
            "  • If on WSL, ensure /mnt/... paths are mounted."
        )
    else:
        print(
            f"Clinical data: {len(eyes_closed_files)} closed files, "
            f"{len(eyes_open_files)} open files."
        )

    set_files = eyes_open_files + eyes_closed_files
    set_labels = np.array([0] * len(eyes_open_files) + [1] * len(eyes_closed_files))
    print(f"Total .set files collected: {len(set_files)}")
else:
    print("Skipping OLD dataset resolution (need_old=False).")


### Sanity check: confirm that files were found

This cell prints one example EO and EC file path from the discovered lists.

If either list is empty, go back and fix the input paths (env vars or the default Windows/WSL locations).


In [None]:
example_open = eyes_open_files[0] if eyes_open_files else "None"
example_closed = eyes_closed_files[0] if eyes_closed_files else "None"
print(f"Example eyes-open file: {example_open}")
print(f"Example eyes-closed file: {example_closed}")


## Feature Extraction & Label Preparation

This stage turns raw epochs into a model-ready feature matrix.

**What happens**
- Loads epochs for EO/EC and assigns labels (`EO=0`, `EC=1`).
- Drops rejected or invalid epochs (e.g., non-finite values or `rejmanual` where available).
- Selects channels (`CHANNEL_SELECTION`) and aligns channels across subjects.
- Computes PSDs for each epoch and channel (`psd_cube`, `psd_freqs`).
- Builds features:
  - **PSD mode** (`USE_FOOOF=False`): flattens PSD (optionally frequency-binned).
  - **FOOOF mode** (`USE_FOOOF=True`): fits aperiodic + peaks and produces per-channel features.
  - **ONE_MAIN_FOOOF** (`ONE_MAIN_FOOOF=True`): derives a subject-level alpha profile (`SUBJECT_ALPHA_PROFILE`) and uses it to compute a fixed alpha template amplitude per epoch.
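
For PSD mode, the step from `psd_cube` to a feature matrix can be sketched with NumPy alone. The example uses random numbers in place of real spectra, and the sizes (4 epochs, 19 channels, 89 frequency points, 10 bins) are arbitrary choices for illustration. It mirrors the mean-binning idea of `reduce_freq_resolution` and then flattens channels and bins into one row per epoch.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for psd_cube: (n_epochs, n_channels, n_freqs).
psd_cube = rng.random((4, 19, 89))

n_bins = 10
n_epochs, n_channels, n_freqs = psd_cube.shape
bin_size = n_freqs // n_bins                    # 8 frequency points per bin
trimmed = psd_cube[:, :, : bin_size * n_bins]   # the last 9 points are dropped
binned = trimmed.reshape(n_epochs, n_channels, n_bins, bin_size).mean(axis=3)

# Flatten channels x bins into one feature row per epoch.
X = binned.reshape(n_epochs, -1)
print(binned.shape, X.shape)  # (4, 19, 10) (4, 190)
```

Binning trades frequency resolution for a smaller feature count, which matters when the number of epochs per subject is limited.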

**Outputs used by training**
- `X_combined`: shape `(n_epochs_total, n_features)`
- `y_combined`: shape `(n_epochs_total,)`
- `subject_ids`: shape `(n_epochs_total,)`

Run this once after file discovery succeeds.
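
In FOOOF mode, the alpha features depend on which fitted peak is treated as "alpha". The sketch below is a compact restatement of the two-stage search used in the cell (strongest peak in 8–14 Hz, otherwise strongest peak in 4–17 Hz), applied to made-up peak tables. The function name `pick_alpha_peak` and the peak values are hypothetical; each row is `(center frequency, amplitude, bandwidth)`.

```python
import numpy as np

def pick_alpha_peak(peaks, primary=(8.0, 14.0), expanded=(4.0, 17.0)):
    """Return the tallest peak in `primary`, else the tallest in `expanded`, else None."""
    peaks = np.atleast_2d(np.asarray(peaks, dtype=float))
    if peaks.size == 0:
        return None
    for lo, hi in (primary, expanded):
        in_range = peaks[(peaks[:, 0] >= lo) & (peaks[:, 0] <= hi)]
        if len(in_range):
            return in_range[np.argmax(in_range[:, 1])]
    return None

# Made-up peak tables for three spectra.
with_alpha = [[6.0, 0.9, 2.0], [10.5, 0.4, 1.5], [21.0, 1.2, 3.0]]
slow_alpha = [[6.0, 0.9, 2.0], [21.0, 1.2, 3.0]]
no_alpha = [[21.0, 1.2, 3.0]]

print(pick_alpha_peak(with_alpha))  # the 10.5 Hz peak wins despite lower amplitude
print(pick_alpha_peak(slow_alpha))  # falls back to the 6.0 Hz peak
print(pick_alpha_peak(no_alpha))    # None
```

The primary window takes precedence over amplitude, so a strong theta or beta peak does not displace a weaker peak inside the classic alpha range.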


In [None]:
# ---- Feature extraction pipeline (drop-in replacement; supports "all" or ["ALL"]) ----
from typing import Optional
from pathlib import Path
from collections import Counter, OrderedDict
import os
import json
import csv
import re
import warnings
import numpy as np
os.environ.setdefault("NUMBA_DISABLE_CACHE", "1")
os.environ.setdefault("NUMBA_CACHE_DIR", str((Path.cwd() / ".numba_cache").resolve()))
Path(os.environ["NUMBA_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
import mne
from scipy.io import loadmat

# Prefer specparam; fallback to legacy fooof if needed
try:
    from specparam import SpectralModel
    # If specparam FitError isn't available in your env, we create a fallback
    try:
        from specparam.core.errors import FitError  # type: ignore
    except Exception:
        class FitError(Exception):
            pass
    FOOOF_BACKEND = "specparam"
except Exception:
    try:
        from fooof import FOOOF as SpectralModel
        from fooof.core.errors import FitError
        FOOOF_BACKEND = "fooof"
    except Exception:
        SpectralModel = None
        class FitError(Exception):
            pass
        FOOOF_BACKEND = None
        warnings.warn("FOOOF/specparam not available. Set USE_FOOOF=False to proceed.")

# ---------- Helper: normalize CHANNEL_SELECTION so "ALL" or ["ALL"] means use all channels ----------
def _is_all_channels(sel):
    if isinstance(sel, str):
        return sel.strip().lower() == "all"
    if isinstance(sel, (list, tuple, set)) and len(sel) == 1:
        only = next(iter(sel))
        return isinstance(only, str) and only.strip().lower() == "all"
    return False

ALL_CHANNELS = _is_all_channels(CHANNEL_SELECTION)

ALPHA_FREQ_RANGE = (3.0, 40.0)  # Frequency range (Hz) passed to the spectral model fit, not an alpha band

# -------------------- Saved ONE_MAIN_FOOOF cache (precomputed) --------------------
def _canonical_channel_name(ch_name: str) -> str:
    name = str(ch_name).strip()
    name = re.sub(r"^EEG\s+", "", name, flags=re.IGNORECASE)
    name = re.sub(r"-REF$", "", name, flags=re.IGNORECASE)
    name = re.sub(r"\s+", "", name)
    return name

def _normalize_channel_selection(sel) -> str:
    try:
        items = list(sel)
    except Exception:
        items = []
    items = [str(x).strip() for x in items if x is not None]
    if any(x.lower() == 'all' for x in items):
        return 'all'
    return ','.join([_canonical_channel_name(x).upper() for x in items])

def _saved_fooof_dataset_tag_from_file(path_str: str) -> str:
    s = str(path_str).lower()
    if 'open_marked' in s or 'closed_marked' in s:
        return 'old_dataset_open_closed_marked'
    if 'preprocessed_setfiles' in s:
        return 'old_dataset_preprocessed_setfiles'
    if 'new_eeg' in s and 'processed' in s:
        return 'new_dataset_processed'
    return 'old_dataset_preprocessed_setfiles'

def _saved_fooof_expected_config_tag(dataset_tag: str) -> str:
    parts = [
        str(dataset_tag),
        'fooof',
        'one_main_fooof',
        'cache_b',
        _channel_tag(),
        f"psd_{PSD_KWARGS.get('fmin', 1.0)}-{PSD_KWARGS.get('fmax', 45.0)}Hz",
        f"fit_{ALPHA_FREQ_RANGE[0]}-{ALPHA_FREQ_RANGE[1]}Hz",
    ]
    if 'FOOOF_SELECTED_FEATURES' in globals() and FOOOF_SELECTED_FEATURES:
        parts.append('feat_' + _safe_tag('-'.join(list(FOOOF_SELECTED_FEATURES))))
    return '__'.join([p for p in parts if p])

def _saved_fooof_config_ok(folder: Path) -> bool:
    cfg_path = folder / 'config.json'
    if not cfg_path.exists():
        return True
    try:
        cfg = json.loads(cfg_path.read_text(encoding='utf-8'))
    except Exception:
        return False
    # Channel selection
    if 'CHANNEL_SELECTION' in cfg:
        if _normalize_channel_selection(cfg.get('CHANNEL_SELECTION')) != _normalize_channel_selection(CHANNEL_SELECTION):
            return False
    # PSD band
    psd_cfg = cfg.get('PSD_KWARGS') or {}
    if float(psd_cfg.get('fmin', PSD_KWARGS.get('fmin', 1.0))) != float(PSD_KWARGS.get('fmin', 1.0)):
        return False
    if float(psd_cfg.get('fmax', PSD_KWARGS.get('fmax', 45.0))) != float(PSD_KWARGS.get('fmax', 45.0)):
        return False
    # Must be ONE_MAIN_FOOOF cache mode B
    if str(cfg.get('CACHE_MODE', '')).upper() != 'B':
        return False
    if bool(cfg.get('COMBINE_ADJACENT_EPOCHS', False)):
        return False
    # Fit range
    afr = cfg.get('ALPHA_FREQ_RANGE', None)
    if afr is not None:
        try:
            afr = tuple(float(x) for x in afr)
            if tuple(float(x) for x in ALPHA_FREQ_RANGE) != afr:
                return False
        except Exception:
            return False
    # Selected features
    if 'FOOOF_SELECTED_FEATURES' in cfg and 'FOOOF_SELECTED_FEATURES' in globals():
        if list(cfg.get('FOOOF_SELECTED_FEATURES') or []) != list(FOOOF_SELECTED_FEATURES or []):
            return False
    # FOOOF settings (match on important keys)
    want = dict(FOOOF_SETTINGS) if 'FOOOF_SETTINGS' in globals() else {}
    have = cfg.get('FOOOF_SETTINGS') or {}
    for k in ('aperiodic_mode', 'peak_width_limits', 'max_n_peaks', 'min_peak_height', 'peak_threshold'):
        if k in want and k in have and want.get(k) != have.get(k):
            return False
    return True

def _find_saved_fooof_npz(file_path: str, subject_raw: int) -> Optional[Path]:
    try:
        root = SAVED_FOOOF_ROOT
    except Exception:
        return None
    if root is None or not Path(root).exists():
        return None
    file_tag = _safe_tag(Path(str(file_path)).stem)
    ds_tag = _saved_fooof_dataset_tag_from_file(str(file_path))
    cfg_tag = _saved_fooof_expected_config_tag(ds_tag)
    cfg_dir = Path(root) / cfg_tag
    cand = cfg_dir / f"features_subject_{int(subject_raw)}__{file_tag}.npz"
    if cand.exists() and _saved_fooof_config_ok(cfg_dir):
        return cand
    # Fallback: search anywhere under saved_fooof
    pattern = f"features_subject_{int(subject_raw)}__{file_tag}.npz"
    matches = [p for p in Path(root).rglob(pattern) if p.is_file()]
    if len(matches) == 1:
        folder = matches[0].parent
        if _saved_fooof_config_ok(folder):
            return matches[0]
    elif len(matches) > 1:
        for p in matches:
            if _saved_fooof_config_ok(p.parent):
                return p
    return None
# Lookup sets used by infer_class_from_name when the filename alone is not decisive.
closed_set = set(map(str, eyes_closed_files))
open_set = set(map(str, eyes_open_files))

def infer_class_from_name(filepath: str) -> Optional[int]:
    """Return 1 for EC, 0 for EO based on filename or folder membership."""
    name = Path(filepath).name.lower()
    if "epoched_60epochsmarked" in name or "closed_marked" in name:
        return 1
    if "eyesopen_marked" in name or "open_marked" in name:
        return 0
    if filepath in closed_set:
        return 1
    if filepath in open_set:
        return 0
    return None

def parse_subject_id(filepath: str) -> int:
    """Extract a subject identifier from the filename."""
    path = Path(filepath)
    match = re.search(r"(\\|/)(\d{5})_", filepath)
    if match:
        return int(match.group(2))
    match = re.search(r"label[01]_(\d+)", path.name)
    if match:
        return int(match.group(1))
    match = re.search(r"(\d+)", path.stem)
    if match:
        return int(match.group(1))
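    # Built-in hash() is salted per process for strings, so use a stable checksum
    # to keep fallback subject IDs reproducible across runs.
    import zlib
    return zlib.crc32(path.stem.encode("utf-8")) % (10 ** 7)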

def _load_rejmanual_vector(mat_path: str, n_epochs_expected: int) -> Optional[np.ndarray]:
    """Load EEGLAB reject.rejmanual (1=reject, 0=keep)."""
    try:
        mat = loadmat(mat_path, struct_as_record=False, squeeze_me=True)
    except Exception:
        return None
    rej = None
    block = mat.get("reject", None)
    if block is not None and hasattr(block, "rejmanual"):
        rej = np.array(block.rejmanual)
    if rej is None and "EEG" in mat:
        EEG = mat["EEG"]
        try:
            reject_section = getattr(EEG, "reject", None)
        except Exception:
            reject_section = None
        if reject_section is not None:
            if hasattr(reject_section, "rejmanual"):
                rej = np.array(reject_section.rejmanual)
            elif isinstance(reject_section, dict) and "rejmanual" in reject_section:
                rej = np.array(reject_section["rejmanual"])
    if rej is None:
        return None
    rej = np.asarray(rej).ravel().astype(int)
    if rej.size != n_epochs_expected:
        warnings.warn(f"{Path(mat_path).name}: rejmanual length {rej.size} != n_epochs {n_epochs_expected}.")
        return None
    return (rej != 0).astype(int)

def _psd_array_welch_clean(data: np.ndarray, sfreq: float, fmin=1.0, fmax=45.0, target_secs=2.0):
    """Compute PSDs with Welch while handling NaNs and short epochs."""
    n_epochs, _, n_times = data.shape
    n_per_seg = max(8, min(n_times, int(round(target_secs * sfreq))))
    n_overlap = n_per_seg // 2 if n_per_seg >= 16 else 0
    psds, freqs = mne.time_frequency.psd_array_welch(
        data,
        sfreq=sfreq,
        fmin=fmin,
        fmax=fmax,
        n_per_seg=n_per_seg,
        n_overlap=n_overlap,
        window="hann",
        average="mean",
        verbose=False,
    )
    return psds, freqs

def reduce_freq_resolution(psd_cube: np.ndarray, n_bins: int) -> np.ndarray:
    """Mean-bin the frequency axis to reduce dimensionality.

    Trailing frequency points that do not fill a complete bin are dropped.
    """
    n_samples, n_channels, n_freqs = psd_cube.shape
    bin_size = n_freqs // n_bins
    if bin_size == 0:
        raise ValueError(f"n_bins={n_bins} is too high for n_freqs={n_freqs}")
    trimmed = psd_cube[:, :, : bin_size * n_bins]
    reshaped = trimmed.reshape(n_samples, n_channels, n_bins, bin_size)
    binned = reshaped.mean(axis=3)
    return binned


def _select_alpha_peak(peaks: np.ndarray):
    """Select alpha peak in a two-stage window.
    First search 8–14 Hz, then 4–17 Hz.
    Returns a (CF, Amp, BW) row or None.
    """
    peaks_arr = np.asarray(peaks, float)
    if peaks_arr.size == 0:
        return None
    if peaks_arr.ndim == 1:
        peaks_arr = peaks_arr.reshape(1, -1)
    try:
        primary_lo, primary_hi = (ALPHA_PRIMARY_RANGE if 'ALPHA_PRIMARY_RANGE' in globals() else (8.0, 14.0))
        exp_lo, exp_hi = (ALPHA_EXPANDED_RANGE if 'ALPHA_EXPANDED_RANGE' in globals() else (4.0, 17.0))
    except Exception:
        primary_lo, primary_hi = 8.0, 14.0
        exp_lo, exp_hi = 4.0, 17.0

    chosen = None
    mask = (peaks_arr[:, 0] >= primary_lo) & (peaks_arr[:, 0] <= primary_hi)
    if np.any(mask):
        subset = peaks_arr[mask]
        chosen = subset[np.argmax(subset[:, 1])]
    if chosen is None:
        mask = (peaks_arr[:, 0] >= exp_lo) & (peaks_arr[:, 0] <= exp_hi)
        if np.any(mask):
            subset = peaks_arr[mask]
            chosen = subset[np.argmax(subset[:, 1])]
    return chosen
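
# Illustrative sketch (not part of the pipeline): the two-stage search above on toy peaks.
# `_demo_two_stage_peak_search` and its peak values are hypothetical, chosen to show that a
# peak inside the primary window wins even when a stronger peak exists outside it.
def _demo_two_stage_peak_search():
    import numpy as np

    # Rows are (CF, Amp, BW); the 6 Hz peak has the largest amplitude overall.
    peaks = np.array([[6.0, 1.2, 2.0], [10.0, 0.5, 1.5], [12.0, 0.9, 2.5]])

    def strongest_in(lo, hi):
        mask = (peaks[:, 0] >= lo) & (peaks[:, 0] <= hi)
        if not np.any(mask):
            return None
        subset = peaks[mask]
        return subset[np.argmax(subset[:, 1])]

    chosen = strongest_in(8.0, 14.0)       # primary window
    if chosen is None:
        chosen = strongest_in(4.0, 17.0)   # expanded window, used only as a fallback
    return chosen                          # the 12 Hz row: strongest peak within 8–14 Hz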

def compute_fooof_features(freqs: np.ndarray, psd_cube: np.ndarray) -> np.ndarray:
    """Compute FOOOF/specparam features per epoch/channel.
    Uses a two-stage alpha peak search (8–14 Hz, then 4–17 Hz).
    Peaks outside this band are ignored; if none are found, alpha features stay at 0.
    NaN spectra -> zeros via exception path.
    """
    if SpectralModel is None:
        raise RuntimeError('FOOOF backend unavailable; set USE_FOOOF=False.')
    features = []
    for epoch_psd in psd_cube:
        epoch_feats = []
        for spectrum in epoch_psd:
            try:
                if not np.all(np.isfinite(spectrum)):
                    raise ValueError('Non-finite in spectrum')
                model = SpectralModel(**FOOOF_SETTINGS)
                model.fit(freqs, spectrum, freq_range=ALPHA_FREQ_RANGE)
                offset, exponent = 0.0, 0.0
                if hasattr(model, 'aperiodic_params_'):
                    params = np.asarray(model.aperiodic_params_)
                    if params.size > 0:
                        offset = float(params[0])
                    if params.size > 1:
                        exponent = float(params[1])
                alpha_cf = alpha_amp = alpha_bw = 0.0
                peaks = np.asarray(getattr(model, 'peak_params_', []))
                if peaks.size:
                    chosen = _select_alpha_peak(peaks)
                    if chosen is not None:
                        alpha_cf, alpha_amp, alpha_bw = map(float, chosen[:3])
                epoch_feats.extend([offset, exponent, alpha_cf, alpha_amp, alpha_bw])
            except (FitError, RuntimeError, ValueError, np.linalg.LinAlgError):
                epoch_feats.extend([0.0, 0.0, 0.0, 0.0, 0.0])
        features.append(epoch_feats)
    return np.asarray(features, dtype=float)


def compute_one_main_fooof_features(freqs: np.ndarray, psd_cube: np.ndarray, subject_id: int, alpha_profile_map, include_aperiodic: bool = True) -> np.ndarray:
    """Compute features using subject-level alpha profile and per-epoch aperiodic fits.
    For each subject, alpha center frequency and bandwidth are fixed from SUBJECT_ALPHA_PROFILE.
    For each epoch/channel, we fit only the aperiodic component (max_n_peaks=0) and then
    fit the amplitude of a Gaussian template at the subject alpha params to the residual.
    Feature layout per channel: [offset, exponent, alpha_cf, alpha_amp, alpha_bw].
    """
    if SpectralModel is None:
        raise RuntimeError("FOOOF backend unavailable; set USE_FOOOF=False.")
    subj = int(subject_id)
    profile = alpha_profile_map.get(subj) if alpha_profile_map is not None else None
    # If no profile is available, fall back to zero alpha amplitude (still allow aperiodic if requested).
    has_profile = profile is not None and len(profile) == 2
    if has_profile:
        alpha_cf, alpha_bw = map(float, profile)
    else:
        alpha_cf, alpha_bw = 0.0, 0.0
    freqs_arr = np.asarray(freqs, float)
    # Gaussian template (unit amplitude): exp(-(f - cf)^2 / (2 * sigma^2)).
    # FOOOF/specparam report peak bandwidth as 2 * sigma (two-sided standard deviation), not FWHM,
    # so sigma is recovered as bw / 2.
    if has_profile and alpha_bw > 0:
        sigma = float(alpha_bw) / 2.0
        gauss = np.exp(-0.5 * ((freqs_arr - alpha_cf) / sigma) ** 2)
    else:
        gauss = np.zeros_like(freqs_arr)
    denom = float(np.sum(gauss ** 2)) if gauss.size else 0.0
    features = []
    # Prepare FOOOF settings for aperiodic-only fits (no peaks)
    ap_settings = {**FOOOF_SETTINGS, "max_n_peaks": 0}
    for epoch_psd in psd_cube:
        epoch_feats = []
        for spectrum in epoch_psd:
            try:
                if not np.all(np.isfinite(spectrum)):
                    raise ValueError("Non-finite in spectrum")
                offset, exponent = 0.0, 0.0
                alpha_amp = 0.0
                if include_aperiodic or has_profile:
                    model = SpectralModel(**ap_settings)
                    model.fit(freqs_arr, spectrum, freq_range=ALPHA_FREQ_RANGE)
                    if hasattr(model, "aperiodic_params_"):
                        params = np.asarray(model.aperiodic_params_)
                        if params.size > 0:
                            offset = float(params[0])
                        if params.size > 1:
                            exponent = float(params[1])
                    # reconstruct aperiodic-only spectrum (log10 power)
                    try:
                        # specparam offers get_model_spectrum with no peaks if max_n_peaks=0
                        ap_fit = None
                        get_fun = getattr(model, "get_model_spectrum", None)
                        if callable(get_fun):
                            ap_fit = np.asarray(get_fun(freqs_arr))
                    except Exception:
                        ap_fit = None
                    if ap_fit is None:
                        # fallback: use modeled spectrum attribute if available
                        for name in ("fooofed_spectrum_", "modeled_spectrum_", "model_spectrum_", "model_spectrum__"):
                            if hasattr(model, name):
                                ap_fit = np.asarray(getattr(model, name))
                                break
                    if ap_fit is None or ap_fit.shape != spectrum.shape:
                        ap_fit = None  # no model spectrum available on this frequency grid
                    if has_profile and denom > 0.0 and ap_fit is not None and np.all(spectrum > 0):
                        # The model spectrum is in log10 power while the Welch PSD is linear,
                        # so convert the PSD to log10 before fitting the template amplitude to the residual.
                        residual = np.log10(spectrum) - ap_fit
                        num = float(np.sum(gauss * residual))
                        alpha_amp = max(num / denom, 0.0)
                epoch_feats.extend([offset, exponent, alpha_cf if has_profile else 0.0, alpha_amp, alpha_bw if has_profile else 0.0])
            except (FitError, RuntimeError, ValueError, np.linalg.LinAlgError):
                epoch_feats.extend([0.0, 0.0, 0.0, 0.0, 0.0])
        features.append(epoch_feats)
    return np.asarray(features, dtype=float)
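
# Illustrative sketch (not part of the pipeline): the closed-form least-squares amplitude used
# above, amp = sum(g * r) / sum(g**2), for a unit-amplitude Gaussian template g and residual r.
# `_demo_template_amplitude`, its frequency grid, and its default values are hypothetical.
def _demo_template_amplitude(true_amp=0.8, cf=10.0, sigma=1.0):
    import numpy as np

    freqs = np.arange(1.0, 45.5, 0.5)                     # 1–45 Hz on a 0.5 Hz grid
    gauss = np.exp(-0.5 * ((freqs - cf) / sigma) ** 2)    # unit-amplitude template
    residual = true_amp * gauss                           # noise-free residual for the demo
    amp = float(np.sum(gauss * residual) / np.sum(gauss ** 2))
    return max(amp, 0.0)                                  # clipped at zero, as in the pipeline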

def _canonical_eeg_channel_name(ch_name: str) -> str:
    name = str(ch_name).strip()
    name = re.sub(r"^EEG\s+", "", name, flags=re.IGNORECASE)
    name = re.sub(r"-REF$", "", name, flags=re.IGNORECASE)
    name = re.sub(r"\s+", "", name)
    return name

def _rename_epochs_channels_canonical(epochs):
    new_names = [_canonical_eeg_channel_name(ch) for ch in epochs.ch_names]
    if len(set(new_names)) != len(new_names):
        warnings.warn("Canonical channel renaming would create duplicate names; keeping original channel names.")
        return epochs
    mapping = {old: new for old, new in zip(epochs.ch_names, new_names) if old != new}
    if mapping:
        epochs.rename_channels(mapping)
    return epochs

def _labels_from_epochs_events(epochs) -> np.ndarray:
    code_to_name = {int(v): str(k).upper() for k, v in epochs.event_id.items()}
    labels = np.full(len(epochs), -1, dtype=int)
    for i, code in enumerate(epochs.events[:, 2].astype(int)):
        name = code_to_name.get(int(code), "")
        if name.startswith("EO"):
            labels[i] = 0
        elif name.startswith("EC"):
            labels[i] = 1
    return labels
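
# Illustrative sketch (not part of the pipeline): how the NEW-dataset branch of the extraction
# loop merges labels from two raters. Epochs unlabeled by rater a (-1) take rater b's label;
# epochs where both raters gave different labels are reset to -1 and later excluded.
# `_demo_merge_rater_labels` and its toy labels are hypothetical.
def _demo_merge_rater_labels():
    import numpy as np

    labels_a = np.array([0, 1, -1, -1, 1])
    labels_b = np.array([0, 0, 1, -1, -1])
    merged = labels_a.copy()
    take_from_b = merged < 0
    merged[take_from_b] = labels_b[take_from_b]
    conflict = (labels_a >= 0) & (labels_b >= 0) & (labels_a != labels_b)
    merged[conflict] = -1
    return merged   # epoch 1 is a conflict, epoch 3 is unlabeled by both raters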

# -------------------- MAIN EXTRACTION LOOP --------------------
records = []
psd_aligned = []
class_epoch_found = Counter()
class_epoch_kept = Counter()
dropped_due_to_nonfinite = 0
paired_epochs_created = 0
paired_epochs_dropped = 0
freq_reference = None
# Track per-subject original epoch positions across files
subject_epoch_cursor = {}

conflict_rows = []
conflict_log_path = None
if 'need_new' in globals() and need_new:
    # Path() accepts both str and Path, so .open() below works whatever outpath() returns.
    conflict_log_path = Path(outpath('doctor_label_conflicts.csv'))

items = []
use_offset = bool(CROSS_DATASET_TEST or USE_BOTH_DATASETS)
if 'need_new' in globals() and need_new:
    items.extend([("new", sid, fa, fb) for (sid, fa, fb) in new_subject_pairs])
    if not new_subject_pairs:
        warnings.warn('need_new=True but no NEW subject files were found.')
if 'need_old' in globals() and need_old:
    items.extend([("old", None, f, None) for f in set_files])
    if not set_files:
        warnings.warn('need_old=True but no OLD .set files were found.')
if not items:
    raise RuntimeError('No input files found for the selected dataset mode. Check paths and toggles.')

for dataset_tag, subj_hint, file_a, file_b in items:
    is_new = (dataset_tag == 'new')
    path = Path(file_a)
    if not path.exists():
        warnings.warn(f'Missing file: {path}')
        continue

    file_level_label = None
    path_b = None
    subject_id_raw = None
    if is_new:
        try:
            epochs = mne.read_epochs(str(path), preload=False, verbose='ERROR')
        except Exception as exc:
            warnings.warn(f'Failed to read NEW FIF epochs for subject {subj_hint}: {exc}')
            continue
        epochs = _rename_epochs_channels_canonical(epochs)
        labels_a = _labels_from_epochs_events(epochs)
        # Optional second rater file (doctor b). If missing, proceed with labels from the single file.
        path_b = Path(file_b) if file_b else None
        labels_all = labels_a
        if path_b is not None:
            if not path_b.exists():
                warnings.warn(f'Missing paired file (b): {path_b}. Proceeding with single NEW file: {path.name}.')
                path_b = None
            else:
                try:
                    epochs_b = mne.read_epochs(str(path_b), preload=False, verbose='ERROR')
                    labels_b = _labels_from_epochs_events(epochs_b)
                except Exception as exc:
                    warnings.warn(f'Failed to read NEW paired FIF epochs for subject {subj_hint}: {path_b.name}: {exc}. Proceeding with single file.')
                    labels_b = None
                    path_b = None
                if labels_b is not None:
                    if labels_a.shape != labels_b.shape:
                        warnings.warn(
                            f'Label length mismatch for subject {subj_hint}: {path.name} ({labels_a.size}) vs {path_b.name} ({labels_b.size}). Skipping.'
                        )
                        continue
                    union_labels = labels_a.copy()
                    take_from_b = (union_labels < 0)
                    union_labels[take_from_b] = labels_b[take_from_b]
                    conflict_mask = (labels_a >= 0) & (labels_b >= 0) & (labels_a != labels_b)
                    if np.any(conflict_mask):
                        sid = int(subj_hint) if subj_hint is not None else int(parse_subject_id(str(path)))
                        for ep_idx in np.flatnonzero(conflict_mask):
                            conflict_rows.append({
                                'subject_id': int(sid),
                                'epoch_index': int(ep_idx),
                                'label_a': int(labels_a[ep_idx]),
                                'label_b': int(labels_b[ep_idx]),
                                'file_a': str(path),
                                'file_b': str(path_b),
                            })
                    union_labels[conflict_mask] = -1
                    labels_all = union_labels

        keep_mask_labels = (labels_all >= 0)
        subject_id_raw = int(subj_hint) if subj_hint is not None else int(parse_subject_id(str(path)))
    else:
        label = infer_class_from_name(str(path))
        if label is None:
            warnings.warn(f'Cannot infer EC/EO label for {path.name}. Skipping.')
            continue
        file_level_label = int(label)
        try:
            epochs = mne.io.read_epochs_eeglab(str(path), verbose='ERROR')
        except Exception as exc:
            warnings.warn(f'Failed to read {path.name}: {exc}')
            continue
        labels_all = np.full(len(epochs), int(label), dtype=int)
        rej = _load_rejmanual_vector(str(path), n_epochs_expected=len(epochs))
        if rej is None:
            warnings.warn(f'{path.name}: no valid rejmanual vector. Skipping.')
            continue
        keep_mask_labels = (rej == 0)
        if not np.any(keep_mask_labels):
            warnings.warn(f'{path.name}: no epochs marked as keep (rejmanual==0). Skipping.')
            continue
        subject_id_raw = int(parse_subject_id(str(path)))

    # Use a disambiguated subject_id when mixing datasets (prevents collisions across NEW vs OLD).
    subject_id = int(subject_id_raw)
    if use_offset and (not is_new):
        subject_id = int(subject_id_raw) + int(DATASET_SUBJECT_OFFSET)

    channels_info = epochs.info['chs']

    if ALL_CHANNELS:
        picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude='bads')
        if len(picks_all) == 0:
            picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude=[])
        available_channels = [epochs.ch_names[idx] for idx in picks_all]
        picks = picks_all
    else:
        requested = CHANNEL_SELECTION
        requested_upper = [ch.upper() for ch in requested]
        name_lookup = {ch_info['ch_name'].upper(): ch_info['ch_name'] for ch_info in channels_info}
        missing = [ch for ch in requested_upper if ch not in name_lookup]
        if missing:
            print(f'{path.name}: missing requested channels {missing}. Available channel names:')
            for ch in channels_info:
                print(f"  - {ch['ch_name']}")
            continue
        available_channels = [name_lookup[ch] for ch in requested_upper]
        picks = [epochs.ch_names.index(ch) for ch in available_channels]

    if not available_channels:
        warnings.warn(f'{path.name}: no EEG channels matched the selection. Skipping.')
        continue

    data = epochs.get_data(picks=picks)
    n_epochs, _, _ = data.shape
    sfreq = float(epochs.info['sfreq'])
    if labels_all.shape[0] != n_epochs or keep_mask_labels.shape[0] != n_epochs:
        warnings.warn(f'{path.name}: label/mask length mismatch (n_epochs={n_epochs}). Skipping.')
        continue

    for cls in (0, 1):
        class_epoch_found[cls] += int(np.sum(labels_all == cls))

    finite_mask = np.all(np.isfinite(data), axis=(1, 2))
    if not np.all(finite_mask):
        dropped = int(np.sum(~finite_mask))
        dropped_due_to_nonfinite += dropped
        warnings.warn(f'{path.name}: dropping {dropped} epochs containing NaN/Inf.')
    keep_mask = keep_mask_labels & finite_mask
    if not np.any(keep_mask):
        warnings.warn(f'{path.name}: no usable epochs after keep-mask and NaN/Inf filtering. Skipping.')
        continue

    kept_indices = np.flatnonzero(keep_mask)
    kept_labels = labels_all[keep_mask].astype(int)
    if COMBINE_ADJACENT_EPOCHS:
        grouped_pairs = []
        pair_labels = []
        cursor = 0
        while cursor < len(kept_indices):
            run_end = cursor + 1
            while (
                run_end < len(kept_indices)
                and kept_indices[run_end] == kept_indices[run_end - 1] + 1
                and kept_labels[run_end] == kept_labels[run_end - 1]
            ):
                run_end += 1
            run = kept_indices[cursor:run_end]
            run_label = int(kept_labels[cursor])
            n_pairs = len(run) // 2
            for pair_idx in range(n_pairs):
                grouped_pairs.append((run[2 * pair_idx], run[2 * pair_idx + 1]))
                pair_labels.append(run_label)
            if len(run) % 2 == 1:
                paired_epochs_dropped += 1
            cursor = run_end
        if not grouped_pairs:
            warnings.warn(f'{path.name}: no adjacent same-label kept epochs available for pairing. Skipping file.')
            continue
        combined = [np.concatenate([data[i0], data[i1]], axis=1) for i0, i1 in grouped_pairs]
        data_kept = np.stack(combined, axis=0)
        epoch_local_indices = np.array([min(i0, i1) for i0, i1 in grouped_pairs], dtype=int)
        labels_kept = np.asarray(pair_labels, dtype=int)
        paired_epochs_created += len(grouped_pairs)
    else:
        data_kept = data[keep_mask]
        epoch_local_indices = kept_indices.astype(int)
        labels_kept = kept_labels.astype(int)

    for cls in (0, 1):
        class_epoch_kept[cls] += int(np.sum(labels_kept == cls))

    target_secs = 1.0 if data_kept.shape[-1] < sfreq * 3 else 2.0
    try:
        psd_data, freqs = _psd_array_welch_clean(
            data_kept,
            sfreq=sfreq,
            fmin=PSD_KWARGS.get('fmin', 1.0),
            fmax=PSD_KWARGS.get('fmax', 45.0),
            target_secs=target_secs,
        )
    except Exception as exc:
        warnings.warn(f'PSD computation failed for {path.name}: {exc}')
        continue

    if freq_reference is None:
        freq_reference = freqs
    else:
        if freqs.shape != freq_reference.shape or not np.allclose(freqs, freq_reference):
            raise RuntimeError(f'Frequency axis mismatch in {path.name} compared to earlier files.')

    # For ONE_MAIN_FOOOF, optionally build the per-subject alpha profile from ALL epochs (to mimic unseen data).
    # This does not change which epochs are used for supervised training (labels_kept still comes from keep_mask).
    psd_profile = psd_data
    if USE_FOOOF and ONE_MAIN_FOOOF and MAIN_FOOOF_USE_ALL_EPOCHS and SpectralModel is not None:
        profile_mask = finite_mask  # ignore keep_mask_labels (labels/rejmanual) for alpha-profile PSD
        if not np.any(profile_mask):
            warnings.warn(f'{path.name}: no finite epochs available for MAIN_FOOOF_USE_ALL_EPOCHS; using labeled/kept epochs instead.')
        else:
            data_profile = None
            if COMBINE_ADJACENT_EPOCHS:
                # Pair consecutive finite epochs regardless of label (labels may be missing/unknown).
                profile_indices = np.flatnonzero(profile_mask)
                profile_pairs = []
                cursor = 0
                while cursor < len(profile_indices) - 1:
                    i0 = int(profile_indices[cursor])
                    i1 = int(profile_indices[cursor + 1])
                    if i1 == i0 + 1:
                        profile_pairs.append((i0, i1))
                        cursor += 2
                    else:
                        cursor += 1
                if profile_pairs:
                    combined_profile = [np.concatenate([data[i0], data[i1]], axis=1) for i0, i1 in profile_pairs]
                    data_profile = np.stack(combined_profile, axis=0)
                else:
                    warnings.warn(f'{path.name}: no adjacent finite epochs available for pairing; using labeled/kept epochs for main-FOOOF profile.')
            else:
                data_profile = data[profile_mask]

            if data_profile is not None and data_profile.shape[0] > 0:
                try:
                    psd_profile, freqs_profile = _psd_array_welch_clean(
                        data_profile,
                        sfreq=sfreq,
                        fmin=PSD_KWARGS.get('fmin', 1.0),
                        fmax=PSD_KWARGS.get('fmax', 45.0),
                        target_secs=target_secs,
                    )
                    if freqs_profile.shape != freqs.shape or not np.allclose(freqs_profile, freqs):
                        raise RuntimeError('Frequency axis mismatch for main-FOOOF profile PSD.')
                except Exception as exc:
                    warnings.warn(f'Main-FOOOF profile PSD computation failed for {path.name}: {exc}; using labeled/kept epochs instead.')
                    psd_profile = psd_data

    # Optional: use precomputed ONE_MAIN_FOOOF features (skip per-epoch SpectralModel.fit)
    saved_fooof_npz = None
    if (
        'USE_SAVED_FOOOF' in globals() and USE_SAVED_FOOOF
        and USE_FOOOF and ONE_MAIN_FOOOF
        and (not COMBINE_ADJACENT_EPOCHS)
    ):
        try:
            cand = _find_saved_fooof_npz(str(path), subject_raw=int(subject_id_raw))
            if cand is not None:
                saved_fooof_npz = str(cand)
        except Exception:
            saved_fooof_npz = None

    if TIME_AXIS_MODE == 'append_files':
        offset = subject_epoch_cursor.get(subject_id, 0)
        time_indices = offset + epoch_local_indices
        subject_epoch_cursor[subject_id] = offset + n_epochs
    elif TIME_AXIS_MODE == 'align_conditions':
        time_indices = epoch_local_indices
    elif TIME_AXIS_MODE == 'interleave_conditions':
        time_indices = 2 * epoch_local_indices + labels_kept
    else:
        raise ValueError(f'Unknown TIME_AXIS_MODE: {TIME_AXIS_MODE!r}')

    records.append(dict(
        file=str(path),
        file_b=str(path_b) if (is_new and path_b is not None) else None,
        label=file_level_label,
        labels=labels_kept.tolist(),
        dataset=str(dataset_tag),
        subject_raw=int(subject_id_raw),
        subject=subject_id,
        finite_mask=np.asarray(finite_mask, dtype=bool).tolist(),
        keep_mask_labels=np.asarray(keep_mask_labels, dtype=bool).tolist(),
        saved_fooof_npz=saved_fooof_npz,
        channels=available_channels,
        psd=psd_data,
        psd_profile=psd_profile,
        freqs=freqs,
        total_epochs=n_epochs,
        kept_epochs=data_kept.shape[0],
        epoch_time_indices=np.asarray(time_indices, dtype=int).tolist(),
    ))

    if is_new:
        cnt_eo = int(np.sum(labels_kept == 0))
        cnt_ec = int(np.sum(labels_kept == 1))
        print(f'[OK] {path.name}: kept {data_kept.shape[0]}/{n_epochs} epochs (EO={cnt_eo}, EC={cnt_ec})')
    else:
        status = 'EC' if file_level_label == 1 else 'EO'
        print(f'[OK] {path.name}: kept {data_kept.shape[0]}/{n_epochs} epochs (class={status})')

if ('need_new' in globals() and need_new) and conflict_log_path is not None and conflict_rows:
    import csv  # local import: csv is not among the notebook's top-level imports
    fieldnames = ['subject_id', 'epoch_index', 'label_a', 'label_b', 'file_a', 'file_b']
    try:
        with conflict_log_path.open('w', newline='') as f:
            w = csv.DictWriter(f, fieldnames=fieldnames)
            w.writeheader()
            w.writerows(conflict_rows)
        print(f'Wrote conflict log: {conflict_log_path} ({len(conflict_rows)} conflicts)')
    except Exception as exc:
        warnings.warn(f'Failed to write conflict log to {conflict_log_path}: {exc}')

if not records:
    raise RuntimeError("No usable recordings were loaded. Check input paths, channel selection, labels/rejmanual, or data quality.")

# -------------------- CHANNEL RECONCILIATION (UNION IF ALL) --------------------
if ALL_CHANNELS:
    # Build UNION of channels across all files, preserving first-seen order
    union_upper_to_name = OrderedDict()
    for rec in records:
        for ch in rec['channels']:
            key = ch.upper()
            if key not in union_upper_to_name:
                union_upper_to_name[key] = ch
    feature_channels = list(union_upper_to_name.values())
else:
    requested_upper = [ch.upper() for ch in CHANNEL_SELECTION]
    feature_channels = []
    missing_global = []
    for req in requested_upper:
        actual = None
        for rec in records:
            mapping = {ch.upper(): ch for ch in rec['channels']}
            if req in mapping:
                actual = mapping[req]
                break
        if actual is None:
            missing_global.append(req)
        else:
            feature_channels.append(actual)
    if missing_global:
        raise RuntimeError(f"Requested channels not found in the available data: {missing_global}")

if not feature_channels:
    raise RuntimeError("Resolved feature channel list is empty. Adjust CHANNEL_SELECTION.")

# -------------------- Channel sanity summary --------------------
def _union_channels_for_dataset(ds: str) -> list:
    union = OrderedDict()
    for rec in records:
        if str(rec.get("dataset")) != str(ds):
            continue
        for ch in rec.get("channels", []):
            key = _canonical_channel_name(ch).upper()
            if key not in union:
                union[key] = ch
    return list(union.values())

new_feature_channels = _union_channels_for_dataset("new")
old_feature_channels = _union_channels_for_dataset("old")
if new_feature_channels:
    print(f"NEW dataset channel summary: {len(new_feature_channels)} channels -> {new_feature_channels}")
    std_upper = [c.upper() for c in STANDARD_19_CHANNELS] if "STANDARD_19_CHANNELS" in globals() else []
    new_upper = [str(c).upper() for c in new_feature_channels]
    missing_std = [c for c in (STANDARD_19_CHANNELS if "STANDARD_19_CHANNELS" in globals() else []) if c.upper() not in new_upper]
    extra = [c for c in new_feature_channels if c.upper() not in std_upper] if std_upper else []
    if missing_std:
        print("NEW dataset missing STANDARD_19_CHANNELS:", missing_std)
    if extra:
        print("NEW dataset extra channels vs STANDARD_19_CHANNELS:", extra)
    if "FAIL_IF_NEW_CHANNEL_MISMATCH" in globals() and FAIL_IF_NEW_CHANNEL_MISMATCH:
        if "EXPECTED_NEW_CHANNEL_COUNT" in globals() and len(new_feature_channels) != int(EXPECTED_NEW_CHANNEL_COUNT):
            raise RuntimeError(f"NEW dataset channel count is {len(new_feature_channels)} (expected {EXPECTED_NEW_CHANNEL_COUNT}); feature dimensionality would change.")

    if "FORCE_STANDARD_19_FOR_NEW" in globals() and FORCE_STANDARD_19_FOR_NEW:
        # Only safe when training NEW dataset alone; forces a stable feature axis.
        if missing_std:
            raise RuntimeError(f"FORCE_STANDARD_19_FOR_NEW=True but missing channels: {missing_std}")
        # Force global feature_channels to the standard order (NEW dataset channels assumed canonical already).
        feature_channels = list(STANDARD_19_CHANNELS)
        print("FORCE_STANDARD_19_FOR_NEW applied: feature_channels forced to STANDARD_19_CHANNELS")

if old_feature_channels:
    print(f"OLD dataset channel summary: {len(old_feature_channels)} channels -> {old_feature_channels}")
# ---------------------------------------------------------------

# -------------------- ALIGN TO FEATURE CHANNELS + BUILD FEATURES --------------------
psd_freqs = freq_reference
psd_aligned = []          # features: aligned PSDs for labeled/kept epochs
psd_aligned_profile = []  # profile: aligned PSDs used to build SUBJECT_ALPHA_PROFILE

missing_report = []
for rec in records:
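    # Match channel names case-insensitively, consistent with how the channel union was built.
    name_to_idx = {ch.upper(): idx for idx, ch in enumerate(rec['channels'])}
    pick_idx = [name_to_idx.get(ch.upper(), None) for ch in feature_channels]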

    psd_feat = rec['psd']
    psd_prof = rec.get('psd_profile', psd_feat)

    n_epochs_here = psd_feat.shape[0]
    n_freqs_here = psd_feat.shape[2]

    aligned = np.full((n_epochs_here, len(feature_channels), n_freqs_here), np.nan, dtype=float)

    n_epochs_prof = psd_prof.shape[0]
    aligned_prof = np.full((n_epochs_prof, len(feature_channels), n_freqs_here), np.nan, dtype=float)

    have_pairs = [(out_i, in_i) for out_i, in_i in enumerate(pick_idx) if in_i is not None]
    if have_pairs:
        out_idx, in_idx = zip(*have_pairs)
        aligned[:, list(out_idx), :] = psd_feat[:, list(in_idx), :]
        aligned_prof[:, list(out_idx), :] = psd_prof[:, list(in_idx), :]

    n_missing = sum(1 for x in pick_idx if x is None)
    if n_missing:
        missing_report.append((Path(rec['file']).name, n_missing))

    psd_aligned.append(aligned)
    psd_aligned_profile.append(aligned_prof)

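# -------------------- SUBJECT ALPHA PROFILE ESTIMATION (per subject, from the profile PSDs over the ALPHA_PROFILE_ROI channels) --------------------
# Profile PSDs cover all finite epochs when MAIN_FOOOF_USE_ALL_EPOCHS is set, otherwise the kept epochs (EC and EO alike).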
SUBJECT_ALPHA_PROFILE = {}
if USE_FOOOF and ONE_MAIN_FOOOF and SpectralModel is not None:
    # Build subject-level alpha profiles by averaging ROI spectra and running FOOOF once per subject.
    # We reuse feature_channels (already aligned with psd_aligned).
    roi_names_profile = [ch for ch in ALPHA_PROFILE_ROI if ch in feature_channels]
    roi_idx_profile = [feature_channels.index(ch) for ch in roi_names_profile]
    if roi_idx_profile:
        per_subject_psds = {}
        for rec, aligned in zip(records, psd_aligned_profile):
            subj = int(rec['subject'])
            # aligned: (n_epochs, n_channels, n_freqs); take ROI channels
            roi_cube = aligned[:, roi_idx_profile, :]
            lst = per_subject_psds.setdefault(subj, [])
            lst.append(roi_cube)
        for subj, cubes in per_subject_psds.items():
            try:
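                roi_data = np.concatenate(cubes, axis=0)  # (epochs_total, n_roi, n_freqs)
                # Channels absent from a file are NaN-padded during alignment, so average with nanmean.
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', category=RuntimeWarning)
                    mean_spectrum = np.nanmean(roi_data, axis=(0, 1))
                if not np.all(np.isfinite(mean_spectrum)):
                    continue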
                # FOOOF model for alpha profile
                model = SpectralModel(**FOOOF_SETTINGS)
                lo, hi = ALPHA_PROFILE_RANGE
                model.fit(psd_freqs, mean_spectrum, freq_range=(lo, hi))
                peaks = np.asarray(getattr(model, 'peak_params_', []))
                if peaks.size:
                    # pick strongest peak in ALPHA_PROFILE_RANGE
                    mask = (peaks[:, 0] >= lo) & (peaks[:, 0] <= hi)
                    if np.any(mask):
                        subset = peaks[mask]
                        best = subset[np.argmax(subset[:, 1])]
                        cf, amp, bw = map(float, best[:3])
                        SUBJECT_ALPHA_PROFILE[subj] = (cf, bw)
            except Exception:
                continue
    print('Built alpha profiles for', len(SUBJECT_ALPHA_PROFILE), 'subject(s); ROI channels:', (roi_names_profile if roi_idx_profile else 'None found'))
else:
    SUBJECT_ALPHA_PROFILE = {}

def _load_saved_one_main_fooof_features(rec: dict, feature_channels: list, psd_freqs: np.ndarray) -> Optional[np.ndarray]:
    """Load cached ONE_MAIN_FOOOF features and expand to full 5-feature layout.

    Returns an array shaped (n_epochs, len(feature_channels) * 5) or None.
    """
    if not ('USE_SAVED_FOOOF' in globals() and USE_SAVED_FOOOF):
        return None
    if not (USE_FOOOF and ONE_MAIN_FOOOF):
        return None
    if COMBINE_ADJACENT_EPOCHS:
        # The precompute notebook pairs consecutive finite epochs, whereas this notebook pairs
        # kept same-label runs, so cached rows would not line up with the epochs used here.
        return None
    npz_path = rec.get('saved_fooof_npz', None)
    if not npz_path:
        return None
    npz_path = Path(str(npz_path))
    if not npz_path.exists():
        return None
    try:
        d = np.load(npz_path, allow_pickle=True)
    except Exception:
        return None
    if 'X' not in d or 'feature_names' not in d or 'freqs' not in d:
        return None
    freqs_saved = np.asarray(d['freqs'], float).ravel()
    freqs_here = np.asarray(psd_freqs, float).ravel()
    if freqs_saved.shape != freqs_here.shape or not np.allclose(freqs_saved, freqs_here):
        return None
    X_saved = np.asarray(d['X'])
    feat_names = [str(x) for x in np.asarray(d['feature_names']).ravel().tolist()]

    # Cache is computed on finite epochs only; subset to the same kept epochs used here.
    finite_mask = np.asarray(rec.get('finite_mask', []), dtype=bool)
    keep_mask_labels = np.asarray(rec.get('keep_mask_labels', []), dtype=bool)
    if finite_mask.size and keep_mask_labels.size and finite_mask.shape == keep_mask_labels.shape:
        keep_finite = keep_mask_labels[finite_mask]
        if keep_finite.shape[0] != X_saved.shape[0]:
            return None
        X_saved = X_saved[keep_finite]

    # Ensure epoch count matches labels
    n_labels = len(rec.get('labels', []))
    if n_labels and X_saved.shape[0] != n_labels:
        return None

    base_order = ['offset', 'exponent', 'alpha_cf', 'alpha_amp', 'alpha_bw']
    try:
        needed = [f for f in list(FOOOF_SELECTED_FEATURES) if f in base_order]
    except Exception:
        needed = base_order
    if not needed:
        needed = base_order
    available_feats = set()
    mapping = {}
    for idx, name in enumerate(feat_names):
        try:
            ch_part, feat = str(name).rsplit('_', 1)
        except ValueError:
            continue
        feat = str(feat)
        if feat not in base_order:
            continue
        ch_key = _canonical_channel_name(ch_part).upper()
        mapping[(ch_key, feat)] = int(idx)
        available_feats.add(feat)
    if not set(needed).issubset(available_feats):
        return None

    # Expand into the full 5-feature-per-channel layout expected by the classifier notebook.
    out = np.zeros((int(X_saved.shape[0]), int(len(feature_channels) * len(base_order))), dtype=float)
    for ch_i, ch in enumerate(feature_channels):
        ch_key = _canonical_channel_name(ch).upper()
        base_col = ch_i * len(base_order)
        for off, feat in enumerate(base_order):
            src = mapping.get((ch_key, feat), None)
            if src is None:
                continue
            out[:, base_col + off] = X_saved[:, src]
    return np.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

# Build features (FOOOF or flattened PSD)
X_list, y_list, subject_ids, epoch_time_list = [], [], [], []
for rec, aligned in zip(records, psd_aligned):
    feats = None
    if USE_FOOOF and ONE_MAIN_FOOOF:
        feats = _load_saved_one_main_fooof_features(rec, feature_channels=feature_channels, psd_freqs=psd_freqs)
        if feats is not None:
            try:
                print(f"[SAVED_FOOOF] {Path(str(rec.get('file',''))).name} -> {Path(str(rec.get('saved_fooof_npz'))).name}")
            except Exception:
                pass
    if feats is None:
        if USE_FOOOF:
            if ONE_MAIN_FOOOF:
                feats = compute_one_main_fooof_features(psd_freqs, aligned, subject_id=int(rec['subject']), alpha_profile_map=SUBJECT_ALPHA_PROFILE)
            else:
                feats = compute_fooof_features(psd_freqs, aligned)
        else:
            flat = aligned.reshape(aligned.shape[0], -1)
            feats = np.nan_to_num(flat, nan=0.0, posinf=0.0, neginf=0.0)
    X_list.append(feats)
    labels = np.asarray(rec.get('labels', []), dtype=int)
    if labels.shape[0] != feats.shape[0]:
        labels = np.full(feats.shape[0], int(rec.get('label', -1)), dtype=int)
    y_list.append(labels)
    subject_ids.extend([rec['subject']] * feats.shape[0])
    epoch_time_list.append(np.asarray(rec['epoch_time_indices'], dtype=int))

psd_cube = np.concatenate(psd_aligned, axis=0)
X_combined = np.vstack(X_list).astype(float)
y_combined = np.concatenate(y_list).astype(int)
subject_ids = np.array(subject_ids, dtype=int)
epoch_time_indices = np.concatenate(epoch_time_list).astype(int)

# -------------------- REPORT --------------------
print(f"Features collected: {X_combined.shape[0]} epochs × {X_combined.shape[1]} features.")
class_names = {0: "EO", 1: "EC"}
counts = {class_names.get(int(cls), str(cls)): int(cnt) for cls, cnt in zip(*np.unique(y_combined, return_counts=True))}
print("Class counts:", counts)
print("FOOOF backend:", FOOOF_BACKEND)
print(f"Final feature channel list ({len(feature_channels)} channels): {feature_channels}")
if dropped_due_to_nonfinite:
    print(f"Dropped epochs due to NaN/Inf: {dropped_due_to_nonfinite}")
if COMBINE_ADJACENT_EPOCHS:
    print(f"Combined {paired_epochs_created} pairs of adjacent epochs; dropped {paired_epochs_dropped} singles that could not be paired.")
print("Epoch statistics (found → used):")
for cls in sorted(class_epoch_found):
    label_name = class_names.get(cls, str(cls))
    found = class_epoch_found[cls]
    kept = class_epoch_kept[cls]
    print(f"  {label_name}: {found} found, {kept} used")

# Optional QC: channels missing per file (when using UNION)
if missing_report:
    for fname, n_miss in missing_report:
        print(f"{fname}: {n_miss} / {len(feature_channels)} channels missing in this recording")

# Derive TARGET_CHANNELS for plotting or later use
if ALL_CHANNELS:
    preferred = [
        ch for ch in (
            "O1", "O2",
            "P3", "P4", "P7", "P8", "Pz",
            # "F3", "F4", "C3", "C4", "F7", "F8", "T7", "T8", "Fz", "Cz",
            # "Fp1", "Fp2",
        )
        if ch in feature_channels
    ]
    TARGET_CHANNELS = preferred if preferred else feature_channels[:2]
else:
    TARGET_CHANNELS = feature_channels

PSD_META = dict(n_channels=len(feature_channels), n_freqs=psd_cube.shape[-1], channels=feature_channels, freqs=psd_freqs)

# Persist feature-axis metadata for downstream labeling notebooks
try:
    np.save(outpath('feature_channels.npy'), np.asarray(feature_channels, dtype=object))
    np.save(outpath('psd_freqs.npy'), np.asarray(psd_freqs, dtype=float))
    if 'USE_FOOOF' in globals() and USE_FOOOF:
        np.save(outpath('fooof_selected_features.npy'), np.asarray(list(FOOOF_SELECTED_FEATURES), dtype=object))
    print('Saved feature metadata:', outpath('feature_channels.npy').name, outpath('psd_freqs.npy').name)
except Exception as _exc:
    warnings.warn(f'Could not save feature metadata: {_exc}')


### Diagnostic: subject-level ROI PSD and main alpha peak (ONE_MAIN_FOOOF)

This cell is a *visual check* of the **subject-level alpha profile** used in `ONE_MAIN_FOOOF`.

It:
- Picks a subject that has an entry in `SUBJECT_ALPHA_PROFILE`.
- Averages PSDs across the ROI channels (`ALPHA_PROFILE_ROI`).
- Fits FOOOF/specparam to the averaged spectrum.
- Plots the empirical PSD, the reconstructed aperiodic fit, and the strongest alpha-range peak.

Run it only after feature extraction when `USE_FOOOF=True` and `ONE_MAIN_FOOOF=True`.
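
The peak-selection step can be sketched in isolation. This is a minimal illustration, assuming peak rows are laid out as `(center frequency, power, bandwidth)` as in `peak_params_`. The helper name `select_main_alpha_peak` and the `(7.0, 14.0)` default range are hypothetical and are not defined elsewhere in this notebook.

```python
import numpy as np

def select_main_alpha_peak(peaks, alpha_range=(7.0, 14.0)):
    """Return (cf, power, bw) of the strongest peak inside alpha_range, or None.

    Hypothetical helper; `peaks` is assumed to be an (n_peaks, 3) array.
    """
    peaks = np.asarray(peaks, dtype=float).reshape(-1, 3)
    lo, hi = alpha_range
    in_band = (peaks[:, 0] >= lo) & (peaks[:, 0] <= hi)
    if not np.any(in_band):
        return None
    subset = peaks[in_band]
    # "Strongest" means the largest value in the power column
    cf, power, bw = subset[np.argmax(subset[:, 1])]
    return float(cf), float(power), float(bw)

# Synthetic peaks: one theta, two alpha candidates, one beta
demo_peaks = np.array([
    [6.0, 1.0, 2.0],
    [9.5, 0.4, 1.5],
    [11.0, 0.9, 2.0],
    [20.0, 1.5, 3.0],
])
main_peak = select_main_alpha_peak(demo_peaks)
print(main_peak)
```

Peaks outside the alpha range are ignored even when they are stronger, which is why the beta peak in the synthetic example is not selected.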


In [None]:
# The plot shows: original PSD, aperiodic fit, and the Gaussian alpha peak component.

import numpy as np
import matplotlib.pyplot as plt
import math

if not ("USE_FOOOF" in globals() and USE_FOOOF and "ONE_MAIN_FOOOF" in globals() and ONE_MAIN_FOOOF):
    print("This plot is intended for ONE_MAIN_FOOOF mode after feature extraction.")
else:
    if not SUBJECT_ALPHA_PROFILE:
        print("No SUBJECT_ALPHA_PROFILE available – run the feature extraction cell first.")
    else:
        # Pick the first subject with an alpha profile
        subj_example = sorted(SUBJECT_ALPHA_PROFILE.keys())[0]
        alpha_cf, alpha_bw = SUBJECT_ALPHA_PROFILE[subj_example]

        # Determine ROI channels used for the profile
        roi_names = [ch for ch in ALPHA_PROFILE_ROI if ch in feature_channels]
        if not roi_names:
            roi_names = feature_channels
        roi_idx = [feature_channels.index(ch) for ch in roi_names]

        # Gather all PSDs for this subject and ROI channels
        subj_mask = (subject_ids == subj_example)
        if not np.any(subj_mask):
            print("No epochs found in psd_cube for subject", subj_example)
        else:
            data = psd_cube[subj_mask][:, roi_idx, :]  # (n_epochs, n_roi, n_freqs)
            mean_spectrum = data.mean(axis=(0, 1))
            if not np.any(np.isfinite(mean_spectrum)):
                print("Mean spectrum for subject", subj_example, "is non-finite.")
            else:
                freqs_arr = np.asarray(psd_freqs, float)

                # ---- choose FOOOF fit range (broad, like the report) ----
                if isinstance(FOOOF_SETTINGS, dict) and "freq_range" in FOOOF_SETTINGS and FOOOF_SETTINGS["freq_range"] is not None:
                    fit_lo, fit_hi = FOOOF_SETTINGS["freq_range"]
                else:
                    fit_lo, fit_hi = freqs_arr[0], freqs_arr[-1]

                fit_mask = (freqs_arr >= fit_lo) & (freqs_arr <= fit_hi)
                freqs_fit = freqs_arr[fit_mask]
                spec_fit = mean_spectrum[fit_mask]

                # ---- Fit FOOOF/specparam on the averaged spectrum ----
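                # freq_range (if present) only selects the fit window above; it is not a model constructor argument
                model = SpectralModel(**{k: v for k, v in dict(FOOOF_SETTINGS).items() if k != "freq_range"})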
                model.fit(freqs_fit, spec_fit)

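                # Frequency axis and spectrum as the model sees them (fooof stores log10 power,
                # so the fallback is converted to log10 to match the aperiodic curve below)
                freqs_plot = np.asarray(getattr(model, "freqs", freqs_fit))
                psd_plot   = np.asarray(getattr(model, "power_spectrum", np.log10(np.maximum(spec_fit, 1e-30))))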

                # ---- Reconstruct aperiodic fit from aperiodic_params_ ----
                ap_params = np.asarray(getattr(model, "aperiodic_params_", []), float)
                if ap_params.size == 0:
                    ap_fit = np.zeros_like(freqs_plot, dtype=float)
                else:
                    if ap_params.size == 2:
                        # fixed mode: [offset, exponent]
                        offset, exponent = ap_params
                        ap_fit = offset - exponent * np.log10(freqs_plot)
                    elif ap_params.size == 3:
                        # knee mode: [offset, knee, exponent]; log10 P = offset - log10(knee + f**exponent)
                        offset, knee, exponent = ap_params
                        ap_fit = offset - np.log10(knee + freqs_plot**exponent)
                    else:
                        # fallback: flat
                        ap_fit = np.zeros_like(freqs_plot, dtype=float)

                # ---- Build Gaussian for the main alpha peak ----
                gauss_main = None
                peaks = np.asarray(getattr(model, "peak_params_", []), float)
                if peaks.size:
                    lo_alpha, hi_alpha = ALPHA_PROFILE_RANGE
                    # pick strongest peak inside alpha range
                    mask_peaks = (peaks[:, 0] >= lo_alpha) & (peaks[:, 0] <= hi_alpha)
                    if np.any(mask_peaks):
                        subset = peaks[mask_peaks]
                        best = subset[np.argmax(subset[:, 1])]
                        cf, amp, bw = map(float, best[:3])
                        if bw > 0:
                            # fooof 1.x reports BW as the two-sided width (2 * Gaussian std), not FWHM
                            sigma = bw / 2.0
                            gauss_main = amp * np.exp(-0.5 * ((freqs_plot - cf) / sigma) ** 2)

                # ---- Prepare plot ----
                plt.figure(figsize=(7, 4))

                # Original averaged PSD (as FOOOF sees it)
                plt.plot(freqs_plot, psd_plot, label="Averaged PSD", color="#1f77b4")

                # Aperiodic fit
                plt.plot(freqs_plot, ap_fit, label="Aperiodic fit", color="#ff7f0e", linestyle="--")

                # Aperiodic + main alpha peak
                if gauss_main is not None:
                    plt.plot(freqs_plot, ap_fit + gauss_main,
                             label="Aperiodic + main alpha", color="#2ca02c")

                plt.xlabel("Frequency (Hz)")
                plt.ylabel("Power (log10 or model units)")
                plt.title(f"Subject {subj_example}: averaged ROI PSD and main FOOOF fit")
                plt.grid(alpha=0.3)
                plt.legend()
                plt.tight_layout()
                plt.show()


### Verification: FOOOF/specparam fitting behavior

These plots help verify that the FOOOF/specparam model is behaving sensibly.

Typical checks:
- The modeled spectrum tracks the empirical PSD over the fit range.
- Detected peaks fall in plausible frequency ranges.
- The selected alpha-related peak/feature is consistent across epochs.

Use this when debugging feature behavior or comparing settings.
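
The first check (the model tracking the empirical PSD) can also be expressed as a number. The sketch below computes a squared Pearson correlation between the log10 empirical spectrum and a log10 model spectrum over the fit range. `fit_r_squared` is a hypothetical helper, and it assumes the model spectrum is on a log10 scale and sampled on the input frequencies that fall inside the fit range.

```python
import numpy as np

def fit_r_squared(freqs, power_linear, model_log10, fit_range):
    """Squared correlation between log10(PSD) and a log10 model fit.

    Hypothetical helper. Returns NaN when the two arrays cannot be aligned.
    """
    freqs = np.asarray(freqs, dtype=float)
    lo, hi = fit_range
    mask = (freqs >= lo) & (freqs <= hi)
    emp_log10 = np.log10(np.maximum(np.asarray(power_linear, dtype=float)[mask], 1e-30))
    model_log10 = np.asarray(model_log10, dtype=float).ravel()
    if emp_log10.size != model_log10.size or emp_log10.size < 2:
        return float("nan")
    r = np.corrcoef(emp_log10, model_log10)[0, 1]
    return float(r ** 2)

# Synthetic 1/f spectrum with a "model" that matches it exactly
freqs = np.arange(1.0, 40.5, 0.5)
power = 10.0 ** (1.0 - 1.5 * np.log10(freqs))
fit_range = (3.0, 40.0)
in_range = (freqs >= fit_range[0]) & (freqs <= fit_range[1])
model_log10 = 1.0 - 1.5 * np.log10(freqs[in_range])

r2 = fit_r_squared(freqs, power, model_log10, fit_range)
print(round(r2, 6))
```

A value well below 1 for many epochs suggests revisiting the fit range or peak settings before trusting the extracted features.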


In [None]:
# -------------------- FOOOF VERIFICATION PLOTS --------------------
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Optional deps that might be used here
try:
    import seaborn as sns
except Exception:
    sns = None

# ---- safety & defaults ----
# Gracefully handle FitError not being imported
try:
    _FitErrorBase = FitError  # noqa: F821
except Exception:
    class _FitErrorBase(Exception):
        pass

# Ensure an alpha range exists for model fitting
if 'ALPHA_FREQ_RANGE' not in globals() or ALPHA_FREQ_RANGE is None:
    if 'ALPHA_BAND' in globals() and ALPHA_BAND is not None:
        ALPHA_FREQ_RANGE = ALPHA_BAND
    else:
        ALPHA_FREQ_RANGE = (7.0, 14.0)  # sensible default

# Helper to extract the modeled spectrum across fooof/specparam versions
def _get_modeled_spectrum(model):
    # Known attr variants across versions
    for name in ('fooofed_spectrum_', 'modeled_spectrum_', 'model_spectrum_', 'model_spectrum__'):
        if hasattr(model, name):
            return getattr(model, name)
    # API method on some versions
    get_fun = getattr(model, 'get_model_spectrum', None)
    if callable(get_fun):
        try:
            return get_fun()
        except Exception:
            pass
    return None

# Helper to extract the frequency vector used by the model
def _get_model_freqs(model, psd_freqs, modeled, alpha_range):
    psd_freqs = np.asarray(psd_freqs, dtype=float)
    lo, hi = alpha_range
    mask = (psd_freqs >= lo) & (psd_freqs <= hi)
    if modeled is None:
        # Nothing to align with; return the input freqs inside the fit range
        return psd_freqs[mask]
    n_model = np.asarray(modeled).size
    for attr in ("freqs", "freqs_", "model_freqs_", "model_freqs"):
        if hasattr(model, attr):
            freqs = np.asarray(getattr(model, attr))
            if freqs.size == n_model:
                return freqs
    # Fallback: try to map to the input freqs via the fit range
    if mask.sum() == n_model:
        return psd_freqs[mask]
    # Last resort: evenly space within the range with the correct length
    return np.linspace(lo, hi, n_model)

# ---- main guard ----
if not USE_FOOOF or SpectralModel is None:
    print("FOOOF verification skipped because USE_FOOOF=False or no compatible backend is available.")
else:
    rng = np.random.default_rng(42)
    class_lookup = {0: "Eyes Open", 1: "Eyes Closed"}

    # Pick channels present in PSD_META; fall back to the first channel if none match
    channels_for_plot = [ch for ch in TARGET_CHANNELS if ch in PSD_META['channels']]
    if not channels_for_plot:
        channels_for_plot = [PSD_META['channels'][0]]
    channel_indices = [PSD_META['channels'].index(ch) for ch in channels_for_plot]

    # Sample a few epochs per class
    samples_per_class = 2
    sample_records = []
    for label in (0, 1):
        idx_pool = np.where(y_combined == label)[0]
        if idx_pool.size == 0:
            continue
        take = min(samples_per_class, idx_pool.size)
        chosen = rng.choice(idx_pool, size=take, replace=False)
        for epoch_idx in chosen:
            sample_records.append({
                "epoch_index": int(epoch_idx),
                "label": label,
                "label_name": class_lookup.get(label, str(label)),
                "subject_id": int(subject_ids[epoch_idx]),
            })
    if not sample_records:
        raise RuntimeError("Could not locate any epochs for FOOOF verification plots.")

    # Fit per sample/channel and collect outputs
    plot_payloads = []
    for rec in sample_records:
        channel_payload = {}
        for ch_name, ch_idx in zip(channels_for_plot, channel_indices):
            spectrum = psd_cube[rec["epoch_index"], ch_idx, :]
            if spectrum.ndim != 1 or not np.any(np.isfinite(spectrum)):
                continue
            try:
                model = SpectralModel(**FOOOF_SETTINGS)
                # Most implementations expect linear power; keep as-is
                model.fit(psd_freqs, spectrum, freq_range=ALPHA_FREQ_RANGE)

                # Peaks: (CF, Amp, BW) rows; empty if none
                peaks = np.asarray(getattr(model, "peak_params_", []))

                # Modeled spectrum (log or linear depending on backend)
                modeled = _get_modeled_spectrum(model)

                # Frequencies the model actually used (length must match modeled)
                fit_freqs = _get_model_freqs(model, psd_freqs, modeled, ALPHA_FREQ_RANGE)
            except (_FitErrorBase, RuntimeError, ValueError, np.linalg.LinAlgError, FloatingPointError):
                peaks, modeled, fit_freqs = np.empty((0, 3)), None, None

            channel_payload[ch_name] = {
                "spectrum": spectrum,
                "modeled": modeled,
                "fit_freqs": fit_freqs,
                "peaks": peaks,
            }
        plot_payloads.append({**rec, "channels": channel_payload})

    # Plot PSDs and fits with alpha markers
    n_cols = len(plot_payloads)
    n_rows = len(channels_for_plot)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.2 * n_rows), sharex=True, sharey=False)
    if n_rows == 1:
        axes = np.expand_dims(axes, axis=0)
    if n_cols == 1:
        axes = np.expand_dims(axes, axis=1)

    # Use ALPHA_BAND for shading if available; else ALPHA_FREQ_RANGE
    _alpha_lo, _alpha_hi = (ALPHA_BAND if 'ALPHA_BAND' in globals() and ALPHA_BAND is not None
                            else ALPHA_FREQ_RANGE)

    for col, payload in enumerate(plot_payloads):
        for row, ch_name in enumerate(channels_for_plot):
            ax = axes[row, col]
            channel_info = payload["channels"].get(ch_name)
            if not channel_info:
                ax.text(0.5, 0.5, f"{ch_name} missing", ha="center", va="center", transform=ax.transAxes)
                ax.set_axis_off()
                continue

            # Empirical PSD (log10 for readability)
            y_emp = np.log10(np.maximum(channel_info["spectrum"], 1e-30))
            ax.plot(psd_freqs, y_emp, label="PSD", linewidth=1.5, color="#1f77b4")

            # Modeled fit (use model's own frequency vector)
            if channel_info["modeled"] is not None:
                modeled = np.asarray(channel_info["modeled"]).squeeze()
                if modeled.ndim != 1:
                    modeled = modeled.reshape(-1)
                x_mod = channel_info.get("fit_freqs")
                if x_mod is None or np.asarray(x_mod).shape[0] != modeled.shape[0]:
                    lo, hi = (_alpha_lo, _alpha_hi)
                    x_mod = np.linspace(lo, hi, modeled.shape[0])

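                # Keep y-scale consistent with y_emp (log10).
                # fooof stores the model fit in log10 power, so this heuristic only guards
                # against a backend returning linear power. Caveat: linear PSDs in V^2/Hz
                # are also far below 5 and would be mistaken for log10 values.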
                m_is_log = (np.nanmax(modeled) < 5.0)  # typical log10 PSD range
                y_mod = modeled if m_is_log else np.log10(np.maximum(modeled, 1e-30))

                ax.plot(np.asarray(x_mod), y_mod, linestyle="--", linewidth=1.2, color="#ff7f0e", label="FOOOF fit")

            # Mark peaks within the shaded alpha band (for context)
            peaks = channel_info["peaks"]
            mask = (peaks[:, 0] >= _alpha_lo) & (peaks[:, 0] <= _alpha_hi) if peaks.size else np.zeros(0, dtype=bool)
            if mask.any():
                ax.vlines(peaks[mask][:, 0], ymin=np.min(y_emp), ymax=np.max(y_emp),
                          colors="#d62728", linestyles=":", label="Alpha peaks (in band)")
            # Highlight the exact peak used as the feature (may be outside 8–12 Hz)
            sel_cf = None
            if peaks.size:
                # Selection logic mirrors compute_fooof_features via _select_alpha_peak
                try:
                    if "_select_alpha_peak" in globals():
                        chosen = _select_alpha_peak(peaks)
                    else:
                        chosen = None
                    if chosen is not None:
                        sel_cf = float(chosen[0])
                except Exception:
                    sel_cf = None
            if sel_cf is not None:
                ax.vlines([sel_cf], ymin=np.min(y_emp), ymax=np.max(y_emp), colors="#e41a1c", linestyles="-", linewidth=2.0, label="Selected feature peak")

            # Shade alpha band
            ax.axvspan(_alpha_lo, _alpha_hi, color="#ffbf00", alpha=0.15)

            if row == 0:
                ax.set_title(f"{payload['label_name']} — subj {payload['subject_id']} — epoch #{payload['epoch_index']}")
            ax.set_ylabel(f"{ch_name}: log10 power")
            ax.set_xlabel("Frequency (Hz)")
            ax.grid(alpha=0.2)

    handles, labels = axes[0, 0].get_legend_handles_labels()
    if handles:
        # Avoid duplicate legend entries
        uniq = dict(zip(labels, handles))
        fig.legend(uniq.values(), uniq.keys(), loc="upper right")
    fig.suptitle("FOOOF verification: PSDs with detected peaks", fontsize=14)
    plt.tight_layout(rect=(0, 0, 1, 0.97))
    plt.show()

    # ----- Distribution of detected alpha-center features from your feature matrix -----
    feature_stride = 5  # [offset, exponent, center, amp, bw] → center is index +2
    channel_to_feat = {ch: idx for idx, ch in enumerate(feature_channels)}
    peak_records = []

    for ch_name, feat_idx in channel_to_feat.items():
        center_col = feat_idx * feature_stride + 2
        if center_col >= X_combined.shape[1]:
            continue
        center_values = X_combined[:, center_col]
        mask = np.isfinite(center_values) & (center_values > 0)
        if not np.any(mask):
            continue
        for freq, label in zip(center_values[mask], y_combined[mask]):
            peak_records.append({
                "channel": ch_name,
                "center_freq": float(freq),
                "label": class_lookup.get(int(label), str(label)),
            })

    if peak_records and sns is not None:
        alpha_df = pd.DataFrame(peak_records)
        plt.figure(figsize=(8, 4))
        lo_hist, hi_hist = ALPHA_FREQ_RANGE
        sns.histplot(alpha_df, x="center_freq", hue="label", multiple="stack",
                     bins=np.linspace(lo_hist, hi_hist, 25))
        plt.xlabel("Detected alpha peak center frequency (Hz)")
        plt.ylabel("Count")
        plt.title("Distribution of detected alpha peaks (all epochs)")
        plt.tight_layout()
        plt.show()
    elif not peak_records:
        print("Alpha peak distribution skipped – the feature matrix did not contain valid center frequencies.")
    else:
        print("Seaborn not available; skipping alpha peak distribution plot.")
# ---- end verification ----


### FOOOF Plot

This plot is a **sanity check** for FOOOF/specparam-based features.

It loads one epoch per condition (EC and EO) for a single subject, then:
- Overlays the empirical PSD with the fitted aperiodic/background component.
- Shows detected peaks and the selected alpha-related peak/features.

Use it when `USE_FOOOF=True` (especially with `ONE_MAIN_FOOOF=True`) to confirm that the fitting range, channel selection, and peak behavior look reasonable.
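
The aperiodic/background component that gets overlaid can be reconstructed directly from the fitted parameters. A minimal sketch, assuming the parameters follow the fooof layout (`[offset, exponent]` in fixed mode, `[offset, knee, exponent]` in knee mode) and that the curve is wanted in log10 power. `aperiodic_log10` is a hypothetical helper name.

```python
import numpy as np

def aperiodic_log10(freqs, params):
    """Aperiodic curve in log10 power for fixed (2 params) or knee (3 params) mode."""
    p = np.asarray(params, dtype=float).ravel()
    f = np.clip(np.asarray(freqs, dtype=float), 1e-6, None)  # avoid log10(0)
    if p.size == 2:
        offset, exponent = p
        return offset - exponent * np.log10(f)
    if p.size == 3:
        offset, knee, exponent = p
        return offset - np.log10(knee + f ** exponent)
    raise ValueError(f"Expected 2 or 3 aperiodic parameters, got {p.size}")

freqs = np.linspace(3.0, 40.0, 75)
fixed_curve = aperiodic_log10(freqs, (1.0, 2.0))
knee_curve = aperiodic_log10(freqs, (1.0, 0.0, 2.0))  # knee = 0 reduces to fixed mode
print(np.allclose(fixed_curve, knee_curve))
```

To show the curve on a linear y-axis, raise 10 to the returned values.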


In [None]:
# ---- Drop-in: PSD + FOOOF/specparam overlays for selected channels (EC vs EO) ----
# Implements:
# • Plot PSD for selected channels.
# • If FOOOF/specparam is enabled, also overlay the FULL model fit and the APERIODIC fit.
# • Toggle y-axis between log10 power and linear power via FOOOF_VIS_LOG_POWER.
# • Two side-by-side columns: Eyes Closed (EC) and Eyes Open (EO).
#
# Assumes the following exist in your notebook:
#   USE_FOOOF, SpectralModel, FOOOF_SETTINGS, TARGET_CHANNELS, PSD_KWARGS,
#   subject_ids, records (preferred) or NEW_DATA / new_subject_pairs,
#   eyes_closed_files, eyes_open_files, parse_subject_id
# and imports: numpy as np, matplotlib.pyplot as plt, re, warnings, mne, loadmat, Path

# -------------------- CONFIG --------------------
FOOOF_VIS_SUBJECT = int(subject_ids[0])
FOOOF_VIS_EPOCH_RANK = {"EC": 0, "EO": 0}
FOOOF_VIS_FREQ_RANGE = (3.0, 40.0)
FOOOF_VIS_LOG_POWER = True  # Toggle: True = log10(y), False = linear(y)

# -------------------- UTILS ---------------------
# Some environments don't have FitError imported; use a safe fallback
try:
    _FitErrorBase = FitError  # noqa: F821
except Exception:
    class _FitErrorBase(Exception):
        pass

def _locate_subject_file(file_list, subj_id):
    for candidate in file_list:
        try:
            if parse_subject_id(candidate) == subj_id:
                return candidate
        except ValueError:
            continue
    return None

def _get_modeled_spectrum(model):
    """Return the model's full fit spectrum (may be log10 or linear depending on backend)."""
    for name in ('fooofed_spectrum_', 'modeled_spectrum_', 'model_spectrum_', 'model_spectrum__'):
        if hasattr(model, name):
            return getattr(model, name)
    fn = getattr(model, 'get_model_spectrum', None)
    if callable(fn):
        try:
            return fn()
        except Exception:
            pass
    return None

def _get_model_freqs(model, freqs_input, modeled, fit_range):
    """Return the frequency vector that corresponds to the modeled array length."""
    modeled = np.asarray(modeled) if modeled is not None else None
    for attr in ("freqs", "freqs_", "model_freqs_", "model_freqs"):
        if hasattr(model, attr):
            f = np.asarray(getattr(model, attr))
            if modeled is None or f.size == modeled.size:
                return f
    # Fallbacks:
    lo, hi = fit_range
    mask = (freqs_input >= lo) & (freqs_input <= hi)
    if modeled is not None and mask.sum() == modeled.size:
        return freqs_input[mask]
    if modeled is not None:
        return np.linspace(lo, hi, modeled.size)
    return freqs_input[(freqs_input >= lo) & (freqs_input <= hi)]

def _compute_aperiodic_curve(model, freqs_for_curve):
    """
    Compute aperiodic-only curve on freqs_for_curve.
    Works for both fixed and knee modes if params are exposed.
    Returns array on LOG10 scale (to match fooof's internal representation).
    """
    params = getattr(model, "aperiodic_params_", None)
    mode = getattr(model, "aperiodic_mode_", getattr(model, "aperiodic_mode", "fixed"))
    if params is None:
        # Some versions have a direct method:
        fn = getattr(model, "get_aperiodic", None)
        if callable(fn):
            try:
                ap = fn()
                return np.asarray(ap)
            except Exception:
                return None
        return None

    p = np.asarray(params).ravel()
    f = np.asarray(freqs_for_curve, dtype=float)
    f = np.clip(f, 1e-6, None)  # avoid log10(0)

    try:
        if (isinstance(mode, str) and "knee" in mode.lower()) or p.size >= 3:
            # Knee model: log10 P = offset - log10(knee + f**exponent)
            offset = p[0]; knee = p[1]; exponent = p[2]
            log10_ap = offset - np.log10(knee + np.power(f, exponent))
        else:
            # Fixed model: log10 P = offset - exponent*log10(f)
            offset = p[0]; exponent = p[1]
            log10_ap = offset - exponent * np.log10(f)
        return log10_ap
    except Exception:
        return None

def _safe_log10(arr):
    return np.log10(np.maximum(np.asarray(arr, dtype=float), 1e-30))

def _load_epoch_payload(file_path, rank_within_marked, label_hint):
    if file_path is None:
        raise RuntimeError(f"No {label_hint} file matched subject {FOOOF_VIS_SUBJECT}.")
    path_obj = Path(file_path).resolve()
    if path_obj.suffix.lower() == '.fif':
        epochs = mne.read_epochs(str(path_obj), preload=False, verbose='ERROR')
        if '_rename_epochs_channels_canonical' in globals():
            epochs = _rename_epochs_channels_canonical(epochs)
    else:
        epochs = mne.io.read_epochs_eeglab(str(path_obj), verbose='ERROR')

    available_channels = [ch for ch in TARGET_CHANNELS if ch in epochs.ch_names]
    if not available_channels:
        raise RuntimeError(f"{path_obj.name} does not contain the requested channels {TARGET_CHANNELS}.")

    if path_obj.suffix.lower() == '.fif':
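        # Accept both long ("Eyes Closed") and short ("EC") hints, case-insensitively
        _hint = str(label_hint).strip().lower()
        want_label = 1 if ('closed' in _hint or _hint == 'ec') else 0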
        if '_labels_from_epochs_events' in globals():
            labels_a = _labels_from_epochs_events(epochs)
        else:
            labels_a = np.full(len(epochs), -1, dtype=int)
        union_labels = labels_a
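        try:
            # Derive the doctor-b sibling file name ('sub<N>a...' -> 'sub<N>b...').
            # Raw strings need single backslashes here; '\\d' would match a literal backslash.
            b_stem = re.sub(r'(sub\d+)a', r'\g<1>b', path_obj.stem, flags=re.IGNORECASE)
            b_path = path_obj.with_name(b_stem + path_obj.suffix) if b_stem != path_obj.stem else None
        except Exception:
            b_path = None
        if b_path is not None and b_path.exists() and '_labels_from_epochs_events' in globals():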
            try:
                epochs_b = mne.read_epochs(str(b_path), preload=False, verbose='ERROR')
                labels_b = _labels_from_epochs_events(epochs_b)
                if labels_b.shape == labels_a.shape:
                    union_labels = labels_a.copy()
                    take_from_b = (union_labels < 0)
                    union_labels[take_from_b] = labels_b[take_from_b]
                    conflict_mask = (labels_a >= 0) & (labels_b >= 0) & (labels_a != labels_b)
                    union_labels[conflict_mask] = -1
            except Exception:
                pass
        keep_indices = np.where(union_labels == want_label)[0]
    else:
        # Read EEGLAB 'reject' marks if present
        try:
            mat = loadmat(str(path_obj), struct_as_record=False, squeeze_me=True)
            reject_block = mat.get("reject", None)
            if reject_block is not None and hasattr(reject_block, "rejmanual"):
                labels = np.array(reject_block.rejmanual, dtype=int).ravel()
            elif reject_block is not None:
                labels = np.array(reject_block, dtype=int).ravel()
            else:
                labels = np.zeros(len(epochs), dtype=int)
        except Exception as exc:
            warnings.warn(f"Falling back to unlabelled epochs for {path_obj.name}: {exc}")
            labels = np.zeros(len(epochs), dtype=int)

        keep_indices = np.where(labels == 0)[0]
    if keep_indices.size == 0:
        keep_indices = np.arange(len(epochs))
    if rank_within_marked >= keep_indices.size:
        raise IndexError(
            f"Requested epoch rank {rank_within_marked} exceeds the available marked epochs (n={keep_indices.size})."
        )

    target_epoch = int(keep_indices[rank_within_marked])
    single_epoch = epochs[target_epoch:target_epoch + 1]
    psd = single_epoch.compute_psd(**PSD_KWARGS)
    freqs = psd.freqs
    spectra = psd.get_data()[0]

    fooof_enabled = bool(USE_FOOOF and SpectralModel is not None)

    fooof_full, fooof_aper, fooof_freqs = {}, {}, {}
    for ch in available_channels:
        idx = epochs.ch_names.index(ch)
        fooof_full[ch] = None
        fooof_aper[ch] = None
        fooof_freqs[ch] = None
        if not fooof_enabled:
            continue

        spectrum = spectra[idx]
        try:
            model = SpectralModel(**FOOOF_SETTINGS)
            model.fit(freqs, spectrum, freq_range=FOOOF_VIS_FREQ_RANGE)

            modeled = _get_modeled_spectrum(model)          # may be log10 or linear
            fit_freqs = _get_model_freqs(model, freqs, modeled, FOOOF_VIS_FREQ_RANGE)

            # Ensure 1D & same length for frequency & modeled
            if modeled is not None:
                modeled = np.asarray(modeled).reshape(-1)
                if fit_freqs.shape[0] != modeled.shape[0]:
                    # Last-resort alignment: interpolate modeled to fit_freqs length if needed
                    # (but usually _get_model_freqs ensures equality)
                    x_tmp = np.linspace(fit_freqs.min(), fit_freqs.max(), modeled.shape[0])
                    modeled = np.interp(fit_freqs, x_tmp, modeled)

            # Aperiodic curve (log10 scale)
            log10_ap = _compute_aperiodic_curve(model, fit_freqs)

            fooof_full[ch] = modeled
            fooof_aper[ch] = log10_ap  # store LOG10 curve; we'll convert later if needed
            fooof_freqs[ch] = fit_freqs

        except (_FitErrorBase, RuntimeError, ValueError, np.linalg.LinAlgError, FloatingPointError) as exc:
            warnings.warn(f"FOOOF failed for {path_obj.name} channel {ch}: {exc}")

    payload = {
        "file": str(path_obj),
        "epoch_index": target_epoch,
        "freqs": freqs,
        "spectra": {ch: spectra[epochs.ch_names.index(ch)] for ch in available_channels},
        "fooof_full": fooof_full,
        "fooof_aper": fooof_aper,     # LOG10 curve if present
        "fooof_freqs": fooof_freqs,
        "available_channels": available_channels,
        "sfreq": float(epochs.info['sfreq']),
        "time_series": single_epoch.get_data()[0][[epochs.ch_names.index(ch) for ch in available_channels], :],
        "label": label_hint,
    }
    return payload

# -------------------- LOAD EC/EO PAYLOADS ---------------------
# Prefer locating files via `records` (works even when NEW+OLD are both loaded).
closed_file = None
open_file = None
try:
    _sid = int(FOOOF_VIS_SUBJECT)
except Exception:
    _sid = None
if _sid is not None and 'records' in globals():
    recs_subj = [r for r in records if int(r.get('subject', -1)) == _sid]
    if recs_subj:
        if any(str(r.get('dataset', '')) == 'old' for r in recs_subj):
            rec_ec = next((r for r in recs_subj if int(r.get('label', -1)) == 1), None)
            rec_eo = next((r for r in recs_subj if int(r.get('label', -1)) == 0), None)
            closed_file = rec_ec['file'] if rec_ec is not None else recs_subj[0]['file']
            open_file = rec_eo['file'] if rec_eo is not None else recs_subj[0]['file']
        else:
            # NEW dataset: one file contains both EC/EO labeled epochs
            closed_file = recs_subj[0]['file']
            open_file = recs_subj[0]['file']

# Fallback to legacy file-location logic if needed
if closed_file is None or open_file is None:
    if NEW_DATA:
        subj_file = None
        for sid, fa, fb in new_subject_pairs:
            if int(sid) == int(FOOOF_VIS_SUBJECT):
                subj_file = fa
                break
        closed_file = subj_file
        open_file = subj_file
    else:
        closed_file = _locate_subject_file(eyes_closed_files, FOOOF_VIS_SUBJECT)
        open_file   = _locate_subject_file(eyes_open_files,  FOOOF_VIS_SUBJECT)

closed_payload = _load_epoch_payload(closed_file, FOOOF_VIS_EPOCH_RANK['EC'], 'Eyes Closed')
open_payload   = _load_epoch_payload(open_file,  FOOOF_VIS_EPOCH_RANK['EO'], 'Eyes Open')

shared_channels = [ch for ch in TARGET_CHANNELS if (ch in closed_payload['spectra'] or ch in open_payload['spectra'])]
if not shared_channels:
    shared_channels = list(closed_payload['spectra'].keys()) or list(open_payload['spectra'].keys())

conditions = [("Eyes Closed", closed_payload), ("Eyes Open", open_payload)]

# -------------------- PLOTTING ---------------------
# squeeze=False always returns a 2D axes array, so axes[row, col] works for a single channel too
fig, axes = plt.subplots(
    len(shared_channels),
    len(conditions),
    figsize=(12, 4 * len(shared_channels)),
    sharex=False,
    squeeze=False,
)

for row, ch in enumerate(shared_channels):
    for col, (label, payload) in enumerate(conditions):
        ax = axes[row, col]
        spectrum = payload['spectra'].get(ch)
        if spectrum is None:
            ax.text(0.5, 0.5, f"{ch} missing", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        freqs = payload['freqs']
        # PSD (empirical)
        if FOOOF_VIS_LOG_POWER:
            y_emp = _safe_log10(spectrum)
            y_label = "log10 Power"
        else:
            y_emp = np.asarray(spectrum)
            y_label = "Power (AU)"
        mask = (freqs >= FOOOF_VIS_FREQ_RANGE[0]) & (freqs <= FOOOF_VIS_FREQ_RANGE[1])
        ax.plot(freqs[mask], y_emp[mask], label=f"PSD ({ch})", linewidth=1.5)

        # Full model overlay (if available)
        model_freqs = payload['fooof_freqs'].get(ch)
        full_curve  = payload['fooof_full'].get(ch)

        if model_freqs is not None and full_curve is not None:
            model_freqs = np.asarray(model_freqs).reshape(-1)
            full_curve  = np.asarray(full_curve).reshape(-1)
            # If lengths still mismatch, resample y to x via interpolation
            if model_freqs.shape[0] != full_curve.shape[0]:
                x_tmp = np.linspace(model_freqs.min(), model_freqs.max(), full_curve.shape[0])
                full_curve = np.interp(model_freqs, x_tmp, full_curve)

            # Determine whether the model curve is already log10 or linear.
            # Heuristic only: linear power in V^2/Hz is also far below 5, so a linear
            # curve in those units would be misread as log10.
            m_is_log = (np.nanmax(full_curve) < 5.0)
            log_full = full_curve if m_is_log else _safe_log10(full_curve)
            # Plot in the requested scale regardless of the scale the curve arrived in
            y_full = log_full if FOOOF_VIS_LOG_POWER else np.power(10.0, log_full)

            ax.plot(model_freqs, y_full, linestyle='--', linewidth=1.2, label='Model fit')

        # Aperiodic overlay (if available)
        ap_log10 = payload['fooof_aper'].get(ch)  # stored as LOG10 if computed
        if model_freqs is not None and ap_log10 is not None:
            ap_log10 = np.asarray(ap_log10).reshape(-1)

            # If lengths mismatch against model_freqs, interpolate
            if ap_log10.shape[0] != np.asarray(model_freqs).shape[0]:
                x_tmp = np.linspace(model_freqs.min(), model_freqs.max(), ap_log10.shape[0])
                ap_log10 = np.interp(model_freqs, x_tmp, ap_log10)

            y_ap = ap_log10
            if not FOOOF_VIS_LOG_POWER:
                y_ap = np.power(10.0, y_ap)

            ax.plot(model_freqs, y_ap, linestyle=':', linewidth=1.2, label='Aperiodic fit')

        ax.set_title(f"{label} — {Path(payload['file']).name} — epoch {payload['epoch_index']} — {ch}")
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel(y_label)
        ax.grid(alpha=0.2)
        ax.legend(loc='upper right')

fig.suptitle(f"Subject {FOOOF_VIS_SUBJECT}: PSD vs. Model & Aperiodic fits", fontsize=14, y=0.98)
plt.tight_layout(rect=(0, 0, 1, 0.96))
plt.show()

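# Keep a small summary of what was just plotted; the spectrogram cell reads `selected_epoch_data`.
# `avg_signal` is the mean time series across the channels available in each payload.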
selected_epoch_data = {
    "subject_id": FOOOF_VIS_SUBJECT,
    "sfreq": float(closed_payload['sfreq']),
    "channels": shared_channels,
    "eyes_closed": {
        "avg_signal": closed_payload['time_series'].mean(axis=0),
        "epoch_index": closed_payload['epoch_index'],
        "file": closed_payload['file'],
    },
    "eyes_open": {
        "avg_signal": open_payload['time_series'].mean(axis=0),
        "epoch_index": open_payload['epoch_index'],
        "file": open_payload['file'],
    },
}
# ---- End drop-in ----


### Spectrogram plot

This visualization shows **time–frequency structure** for a selected epoch.

It is useful to:
- Spot obvious artifacts (broadband bursts, line noise harmonics).
- Compare EO vs EC structure qualitatively.
- Confirm that epoch duration and sampling rate are interpreted correctly.
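
The last point can be checked numerically. The sketch below is illustrative only: it uses a synthetic signal and an assumed sampling rate of 200 Hz (the notebook reads the real rate from the data), and shows how a 0.5 s window and 0.05 s step translate into STFT segment length, overlap, and frequency resolution.

```python
import numpy as np
from scipy.signal import stft

# Assumed values for illustration; the notebook takes sfreq from the loaded epochs.
sfreq = 200.0       # sampling rate in Hz (assumption)
window_sec = 0.5    # STFT window length in seconds
step_sec = 0.05     # hop between successive windows in seconds

nperseg = int(round(window_sec * sfreq))  # samples per window
hop = int(round(step_sec * sfreq))        # samples between window starts
noverlap = nperseg - hop                  # scipy's stft takes overlap, not hop

# Synthetic 2 s "epoch": a 10 Hz alpha-like oscillation plus a little noise.
rng = np.random.default_rng(0)
t = np.arange(int(2 * sfreq)) / sfreq
signal = np.sin(2 * np.pi * 10.0 * t) + 0.1 * rng.standard_normal(t.size)

freqs, times, Zxx = stft(signal, fs=sfreq, nperseg=nperseg, noverlap=noverlap)
freq_resolution = freqs[1] - freqs[0]  # equals sfreq / nperseg, i.e. 1 / window_sec
peak_freq = freqs[np.argmax(np.mean(np.abs(Zxx) ** 2, axis=1))]
print(f"nperseg={nperseg}, hop={hop}, resolution={freq_resolution:.2f} Hz, peak={peak_freq:.1f} Hz")
```

If the sampling rate were misread, the known oscillation would appear at the wrong frequency, which is the kind of mistake the spectrogram helps catch. Note that a 0.5 s window limits the frequency resolution to about 2 Hz.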


In [None]:
# ---- Drop-in: Spectrogram visualization for selected epoch data ----

if 'selected_epoch_data' not in globals():
    raise RuntimeError("Run the PSD/FOOOF visualization cell first to populate selected_epoch_data.")

sfreq = selected_epoch_data['sfreq']
window = 0.5
step = 0.05
fmax = 45.0


def _spectrogram(signal, sfreq, window_sec, step_sec, fmax_hz):
    signal = np.asarray(signal, float)
    if signal.size == 0:
        return None
    nperseg = max(32, int(round(window_sec * sfreq)))
    hop = max(1, int(round(step_sec * sfreq)))
    nperseg = min(nperseg, signal.size)
    noverlap = max(0, nperseg - hop)
    freqs, times, Zxx = stft(signal - np.mean(signal), fs=sfreq, nperseg=nperseg, noverlap=noverlap, boundary='zeros', padded=True)
    if freqs.size == 0 or times.size == 0:
        return None
    mask = freqs <= fmax_hz
    if not np.any(mask):
        mask = slice(None)
    freqs = freqs[mask]
    Zxx = Zxx[mask]
    power = 10.0 * np.log10(np.maximum(np.abs(Zxx) ** 2, 1e-30))
    return times, freqs, power

spectra_payloads = {
    'Eyes Closed': selected_epoch_data['eyes_closed'],
    'Eyes Open': selected_epoch_data['eyes_open'],
}
results = {}
all_power = []
for label, payload in spectra_payloads.items():
    spec = _spectrogram(payload['avg_signal'], sfreq, window, step, fmax)
    results[label] = spec
    if spec is not None:
        all_power.append(spec[2].ravel())
if not any(results.values()):
    raise RuntimeError("STFT computation failed for both conditions.")
if all_power:
    concatenated = np.concatenate(all_power)
    vmin, vmax = np.percentile(concatenated, [5, 95])
else:
    vmin = vmax = None

fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
axes = np.atleast_1d(axes)
for ax, (label, payload) in zip(axes, spectra_payloads.items()):
    spec = results[label]
    if spec is None:
        ax.text(0.5, 0.5, 'No spectrogram', ha='center', va='center', transform=ax.transAxes)
        ax.set_axis_off()
        continue
    times, freqs, power = spec
    mesh = ax.pcolormesh(times, freqs, power, shading='auto', cmap='magma', vmin=vmin, vmax=vmax)
    ax.set_title(f"{label} — epoch {payload['epoch_index']}")
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    fig.colorbar(mesh, ax=ax, pad=0.02, label='Power (dB)')
fig.suptitle(
    f"Subject {selected_epoch_data['subject_id']} — spectrogram of averaged {', '.join(selected_epoch_data['channels'])}",
    fontsize=14,
    y=0.98,
)
plt.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()


## Training the logistic regression model

This section trains and evaluates a **subject-wise** EC/EO classifier.

Key ideas:
- Splits are done by **subject**, not by epoch, to prevent leakage.
- An inner loop chooses hyperparameters: `C`, the penalty when `TUNE_LOGREG_PENALTY` is enabled, and the PSD bin count when binning is enabled.
- Optional temporal smoothing (`USE_TIME_ADJUSTMENT`) can be evaluated on out-of-fold predictions.

Run this after feature extraction has produced `X_combined`, `y_combined`, and `subject_ids`.
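
A minimal sketch of what "split by subject" means in practice, using a toy `subject_ids_demo` array (hypothetical, standing in for the notebook's `subject_ids`). The split is drawn over unique subjects and then mapped back to epoch indices, so no subject contributes epochs to both sides.

```python
import numpy as np
from sklearn.model_selection import LeaveOneOut

# Toy stand-in for `subject_ids`: one subject ID per epoch (hypothetical data).
subject_ids_demo = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3])
unique_subjects_demo = np.unique(subject_ids_demo)

splits = []
for train_sub_idx, test_sub_idx in LeaveOneOut().split(unique_subjects_demo):
    train_subjects = unique_subjects_demo[train_sub_idx]
    test_subjects = unique_subjects_demo[test_sub_idx]
    # Map subject-level membership back to epoch-level indices.
    train_idx = np.where(np.isin(subject_ids_demo, train_subjects))[0]
    test_idx = np.where(np.isin(subject_ids_demo, test_subjects))[0]
    # No subject may appear on both sides of the split.
    assert not set(subject_ids_demo[train_idx]) & set(subject_ids_demo[test_idx])
    splits.append((train_idx, test_idx))
    print(f"test subject {test_subjects[0]}: {test_idx.size} test epochs, {train_idx.size} train epochs")
```

Splitting the epochs directly would instead let neighbouring epochs from the same recording land in both train and test, which tends to inflate accuracy.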


The inner loop uses **leave-one-subject-out (LOSO)** validation on the training subjects.

**What the next code cell does**
- Builds the outer evaluation split(s) based on `CV_LEVEL`.
- For each outer split:
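  - Runs an inner LOSO sweep over `C_GRID`, the penalty grid, and `FREQ_BIN_OPTIONS` (when PSD+binning is enabled). Each inner fold keeps the combination with the lowest validation log-loss.
  - Fits the final logistic regression model with the values chosen most often across the inner folds.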
  - Stores per-epoch predictions and fold metadata for later plots/exports.
- Computes aggregated metrics and saves artifacts (CSV/NPZ/joblib) under `outputs/<config_tag>/`.
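
The selection rule in the inner sweep can be sketched in isolation. The numbers below are made up for illustration: each inner fold picks the candidate with the lowest validation log-loss, and the outer fold then uses whichever candidate was picked most often.

```python
from collections import Counter

# Hypothetical validation log-loss per inner fold and candidate C (made-up numbers).
inner_fold_losses = {
    "val_subject_1": {0.01: 0.61, 0.1: 0.48, 1.0: 0.52},
    "val_subject_2": {0.01: 0.66, 0.1: 0.50, 1.0: 0.47},
    "val_subject_3": {0.01: 0.58, 0.1: 0.45, 1.0: 0.49},
}

# Step 1: each inner fold keeps the candidate with the lowest log-loss.
winners = [min(losses, key=losses.get) for losses in inner_fold_losses.values()]

# Step 2: the outer fold uses the most frequently chosen candidate.
selected_C = Counter(winners).most_common(1)[0][0]
print(winners, selected_C)
```

With `Counter.most_common`, ties go to the candidate that was encountered first, so the order of the inner folds can decide close votes.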

**Important inputs/toggles**
- Feature mode: `USE_FOOOF` with `FOOOF_SELECTED_FEATURES`, or PSD settings (`USE_FREQ_BINNING`, `FREQ_BIN_OPTIONS`, `PSD_FEATURE_RANGE`).
- CV settings: `CV_LEVEL`, `CV_TEST_SUBJECTS_PER_SPLIT`, `CV_REPEAT_COUNT`, `CV_RANDOM_SEED`.
- Temporal smoothing: `USE_TIME_ADJUSTMENT`, `MIN_RUN_LENGTH`, `TIME_AXIS_MODE`.
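
To illustrate the idea behind temporal smoothing, here is a simplified sketch. It is not the notebook's `_smooth_labels_by_run`: it measures run length in samples rather than along a time axis, skips edge handling, and uses a hypothetical `min_run` argument in the role that `MIN_RUN_LENGTH` is assumed to play.

```python
import numpy as np

def smooth_short_interior_runs(labels, min_run):
    """Flip interior runs shorter than `min_run` that sit between two runs of the other label."""
    labels = np.asarray(labels, dtype=int)
    out = labels.copy()
    if labels.size == 0:
        return out
    # Run boundaries: indices where the label changes, plus both ends.
    change = np.flatnonzero(np.diff(labels)) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [labels.size]))
    for k in range(1, len(starts) - 1):  # skip the first and last run
        s, e = starts[k], ends[k]
        prev_val, next_val = labels[s - 1], labels[e]
        if prev_val == next_val and (e - s) < min_run:
            out[s:e] = prev_val
    return out

raw = [1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
smoothed = smooth_short_interior_runs(raw, min_run=2)
print(smoothed.tolist())
```

Isolated single-epoch flips are absorbed into their neighbours, while runs of at least `min_run` epochs are left alone.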


In [None]:
# ---- Cross-validated logistic regression with inner LOSO hyper-parameter tuning ----

C_grid = C_GRID
n_bins_grid = FREQ_BIN_OPTIONS if (not USE_FOOOF and USE_FREQ_BINNING) else [None]

# Logistic regression penalty grid (L1 vs L2)
try:
    penalty_grid = list(LOGREG_PENALTY_OPTIONS) if TUNE_LOGREG_PENALTY else [str(LOGREG_PENALTY_FIXED)]
except Exception:
    penalty_grid = ["l2"]
penalty_grid = [str(p).lower() for p in penalty_grid if p is not None]
if not penalty_grid:
    penalty_grid = ["l2"]

def _make_logreg(C_value: float, penalty: str):
    pen = str(penalty).lower()
    # L1 requires a solver that supports it (saga/liblinear). Use saga so L1 and L2 are comparable.
    use_saga = (pen == "l1") or (pen == "elasticnet") or (pen == "l2" and ("TUNE_LOGREG_PENALTY" in globals() and TUNE_LOGREG_PENALTY))
    solver = "saga" if use_saga else "lbfgs"
    kwargs = dict(C=float(C_value), penalty=pen, solver=solver, max_iter=int(LOGREG_MAX_ITER), class_weight=CLASS_WEIGHT)
    if solver == "saga":
        kwargs["random_state"] = int(CV_RANDOM_SEED)
    return LogisticRegression(**kwargs)

unique_subjects = np.unique(subject_ids)
if unique_subjects.size < 2:
    raise ValueError("Need at least two unique subjects for cross-validation.")

def _prepare_psd_matrix(psd_block, n_bins):
    """Reduce PSD to (n_samples, n_features) for a given n_bins.
    If PSD_FEATURE_RANGE is set, only keep bins whose frequency span
    overlaps the requested range.
    """
    reduced = reduce_freq_resolution(psd_block, n_bins)  # (n_samples, n_channels, n_bins)
    n_samples, n_channels, n_bins_eff = reduced.shape
    if PSD_FEATURE_RANGE is not None:
        fmin_sel, fmax_sel = PSD_FEATURE_RANGE
        freqs = np.asarray(PSD_META['freqs'], float)
        n_freqs = freqs.size
        if n_freqs == 0:
            return reduced.reshape(n_samples, n_channels * n_bins_eff)
        bin_size = n_freqs // n_bins
        if bin_size <= 0:
            raise ValueError(f"n_bins={n_bins} is too high for n_freqs={n_freqs}")
        keep_mask = []
        for b in range(n_bins):
            start = b * bin_size
            end = (b + 1) * bin_size - 1
            if start < 0 or end >= n_freqs:
                keep_mask.append(False)
                continue
            f_lo = freqs[start]
            f_hi = freqs[end]
            # Keep this bin if it overlaps the requested range
            keep_mask.append(not (f_hi < fmin_sel or f_lo > fmax_sel))
        keep_mask = np.asarray(keep_mask, dtype=bool)
        if not np.any(keep_mask):
            raise ValueError(
                f"PSD_FEATURE_RANGE={PSD_FEATURE_RANGE} excluded all bins for n_bins={n_bins}"
            )
        reduced = reduced[:, :, keep_mask]
        n_bins_eff = reduced.shape[2]
    return reduced.reshape(n_samples, n_channels * n_bins_eff)

def _get_fooof_feature_layout(channels):
    """Return (indices, names) for the selected FOOOF features.

    The base per-channel order is [offset, exponent, alpha_cf, alpha_amp, alpha_bw].
    FOOOF_SELECTED_FEATURES controls which of these are kept.
    """
    base_order = ["offset", "exponent", "alpha_cf", "alpha_amp", "alpha_bw"]
    try:
        selected = list(FOOOF_SELECTED_FEATURES)
    except Exception:
        selected = base_order
    selected_set = {s for s in selected if s in base_order}
    if not selected_set:
        selected_set = set(base_order)
    indices = []
    names = []
    stride = len(base_order)
    for ch_idx, ch in enumerate(channels):
        base_col = ch_idx * stride
        for offset_idx, feat_name in enumerate(base_order):
            if feat_name not in selected_set:
                continue
            indices.append(base_col + offset_idx)
            names.append(f"{ch}_{feat_name}")
    return indices, names

def _smooth_labels_by_run(
    labels: np.ndarray,
    positions: np.ndarray,
    min_run_interior: int,
    min_run_edge: int,
    use_edge_smoothing: bool = True,
) -> np.ndarray:
    """Run-length smoothing using a time axis.

    Interior runs (surrounded on both sides by the same opposite label)
    whose span in *positions* is < min_run_interior are flipped.

    When use_edge_smoothing is True, short edge runs (at the start or end)
    whose span in *positions* is < min_run_edge are replaced by the
    majority label within a time window of size min_run_edge.
    """
    labels = np.asarray(labels, dtype=int)
    positions = np.asarray(positions, dtype=int)
    if labels.size == 0 or positions.size != labels.size:
        return labels.copy()
    # Identify runs as (start, end, value) with end exclusive
    runs = []
    start = 0
    current = labels[0]
    for i in range(1, labels.size):
        if labels[i] != current:
            runs.append((start, i, current))
            start = i
            current = labels[i]
    runs.append((start, labels.size, current))
    out = labels.copy()
    if not runs:
        return out

    # Interior runs
    if min_run_interior > 1 and len(runs) >= 3:
        for idx in range(1, len(runs) - 1):
            s, e, val = runs[idx]
            prev_val = runs[idx - 1][2]
            next_val = runs[idx + 1][2]
            # Span in original time index (inclusive)
            span = int(positions[e - 1] - positions[s] + 1)
            if prev_val == next_val and prev_val != val and span < min_run_interior:
                out[s:e] = prev_val

    # Edge runs (first and last) by majority vote in a time window
    if use_edge_smoothing and min_run_edge > 1:
        # First run
        if len(runs) >= 1:
            s0, e0, v0 = runs[0]
            span0 = int(positions[e0 - 1] - positions[s0] + 1)
            if span0 < min_run_edge:
                t0 = int(positions[s0])
                t_edge = t0 + min_run_edge - 1
                win_idx = (positions >= t0) & (positions <= t_edge)
                if np.any(win_idx):
                    labels_win = out[win_idx]
                    vals, counts = np.unique(labels_win, return_counts=True)
                    maj = int(vals[np.argmax(counts)])
                    if maj != v0:
                        out[s0:e0] = maj

        # Last run
        if len(runs) >= 2:
            sn, en, vn = runs[-1]
            spann = int(positions[en - 1] - positions[sn] + 1)
            if spann < min_run_edge:
                t_end = int(positions[en - 1])
                t_start = t_end - min_run_edge + 1
                win_idx = (positions >= t_start) & (positions <= t_end)
                if np.any(win_idx):
                    labels_win = out[win_idx]
                    vals, counts = np.unique(labels_win, return_counts=True)
                    maj = int(vals[np.argmax(counts)])
                    if maj != vn:
                        out[sn:en] = maj

    return out

def _get_base_feature_names_for_mode(n_bins=None):
    """Return feature names for the current feature mode.
    When USE_FOOOF is True, names reflect per-channel FOOOF parameters.
    When USE_FOOOF is False, names reflect binned PSD features.
    """
    if USE_FOOOF:
        _, names = _get_fooof_feature_layout(feature_channels)
        return names
    else:
        if n_bins is None:
            n_bins = FREQ_BIN_OPTIONS[0]
        freqs = np.asarray(PSD_META['freqs'], float)
        n_freqs = freqs.size
        if n_freqs == 0:
            return []
        bin_size = n_freqs // n_bins
        if bin_size <= 0:
            raise ValueError(f"n_bins={n_bins} is too high for n_freqs={n_freqs}")
        names = []
        for ch in PSD_META['channels']:
            for b in range(n_bins):
                start = b * bin_size
                end = (b + 1) * bin_size - 1
                if start < 0 or end >= n_freqs:
                    continue
                f_lo = freqs[start]
                f_hi = freqs[end]
                if PSD_FEATURE_RANGE is not None:
                    fmin_sel, fmax_sel = PSD_FEATURE_RANGE
                    if f_hi < fmin_sel or f_lo > fmax_sel:
                        continue
                names.append(f"{ch}_PSD_{f_lo:.2f}-{f_hi:.2f}Hz")
        return names

def _build_outer_plan(subjects):
    group_size = min(max(1, CV_TEST_SUBJECTS_PER_SPLIT), subjects.size)
    repeats = 1 if CV_LEVEL in (1, 2) else max(1, CV_REPEAT_COUNT)
    plan = []
    rng = np.random.default_rng(CV_RANDOM_SEED)
    if CV_LEVEL == 1:
        test_subjects = rng.choice(subjects, size=group_size, replace=False)
        plan.append({"repeat": 0, "fold": 1, "test_subjects": np.sort(test_subjects)})
        return plan
    if CV_LEVEL == 4:
        # Fixed test split based on configured subject IDs
        fixed = np.array(FIXED_TEST_SUBJECTS_LEVEL4, dtype=int)
        available = np.intersect1d(subjects, fixed)
        if available.size == 0:
            raise ValueError("CV_LEVEL=4: none of FIXED_TEST_SUBJECTS_LEVEL4 are present in this dataset.")
        plan.append({"repeat": 0, "fold": 1, "test_subjects": np.sort(available)})
        return plan
    for rep in range(repeats):
        perm = subjects.copy()
        rng_rep = np.random.default_rng(CV_RANDOM_SEED + rep)
        rng_rep.shuffle(perm)
        start = 0
        fold_idx = 0
        while start < perm.size:
            fold_idx += 1
            group = np.sort(perm[start:start + group_size])
            plan.append({"repeat": rep, "fold": fold_idx, "test_subjects": group})
            start += group_size
    return plan

def _run_logreg_fold(plan_entry):
    test_subjects_fold = plan_entry['test_subjects']
    train_subjects_fold = np.setdiff1d(unique_subjects, test_subjects_fold)
    train_idx_fold = np.where(np.isin(subject_ids, train_subjects_fold))[0]
    test_idx_fold = np.where(np.isin(subject_ids, test_subjects_fold))[0]
    if train_idx_fold.size == 0 or test_idx_fold.size == 0:
        raise ValueError("Empty train/test split encountered. Adjust CV settings.")
    y_train_fold = y_combined[train_idx_fold]
    y_test_fold = y_combined[test_idx_fold]
    train_subj_ids = subject_ids[train_idx_fold]
    # train_subj_ids holds one ID per epoch, so count unique subjects rather than epochs
    if np.unique(train_subj_ids).size < 2:
        raise ValueError("Need at least two training subjects for the inner LOSO loop.")

    if USE_FOOOF:
        # Select only the requested FOOOF feature types
        fooof_indices, _ = _get_fooof_feature_layout(feature_channels)
        if not fooof_indices:
            raise RuntimeError(
                "FOOOF_SELECTED_FEATURES produced no valid features; "
                "check FOOOF_SELECTED_FEATURES at the top of the notebook."
            )
        base_train = X_combined[train_idx_fold]
        base_test = X_combined[test_idx_fold]
        train_matrix = base_train[:, fooof_indices]
        test_matrix = base_test[:, fooof_indices]
        train_cache = {None: train_matrix}
        test_cache = {None: test_matrix}
    else:
        train_psd = psd_cube[train_idx_fold]
        test_psd = psd_cube[test_idx_fold]
        train_cache = {}
        test_cache = {}

        def _get_cache(cache, data_block, n_bins):
            if n_bins not in cache:
                cache[n_bins] = _prepare_psd_matrix(data_block, n_bins)
            return cache[n_bins]

    # Inner LOSO over the training subjects of this outer fold
    inner_loo = LeaveOneOut()
    inner_subjects = np.unique(train_subj_ids)
    best_C_list, best_bins_list, best_penalty_list = [], [], []
    val_records = []
    for inner_train_idx, val_sub_idx in inner_loo.split(inner_subjects):
        inner_train_subjects = inner_subjects[inner_train_idx]
        val_subject = inner_subjects[val_sub_idx[0]]
        mask_train = np.where(np.isin(train_subj_ids, inner_train_subjects))[0]
        mask_val = np.where(train_subj_ids == val_subject)[0]
        best_score = -np.inf
        best_params = None
        best_val_acc = 0.0
        for C in C_grid:
            for n_bins in n_bins_grid:
                if USE_FOOOF:
                    X_inner = train_matrix[mask_train]
                    X_val = train_matrix[mask_val]
                else:
                    n_bins_eval = n_bins if n_bins is not None else FREQ_BIN_OPTIONS[0]
                    X_inner = _get_cache(train_cache, train_psd, n_bins_eval)[mask_train]
                    X_val = _get_cache(train_cache, train_psd, n_bins_eval)[mask_val]

                imputer = SimpleImputer(strategy="constant", fill_value=0.0)
                scaler = StandardScaler()
                X_inner_imp = imputer.fit_transform(X_inner)
                X_val_imp = imputer.transform(X_val)
                X_inner_scaled = scaler.fit_transform(X_inner_imp)
                X_val_scaled = scaler.transform(X_val_imp)

                # Optional component analysis (PCA/ICA) in the inner loop
                X_inner_proc = X_inner_scaled
                X_val_proc = X_val_scaled
                if USE_COMPONENT_ANALYSIS:
                    method = str(COMPONENT_METHOD).lower()
                    n_features_here = X_inner_scaled.shape[1]
                    max_default = min(50, n_features_here)
                    n_comp = COMPONENT_N_COMPONENTS if COMPONENT_N_COMPONENTS is not None else max_default
                    n_comp = max(1, min(int(n_comp), n_features_here))

                    if method == "pca":
                        comp_model_inner = PCA(n_components=n_comp, random_state=CV_RANDOM_SEED)
                    elif method == "ica":
                        comp_model_inner = FastICA(n_components=n_comp, random_state=CV_RANDOM_SEED, max_iter=500)
                    else:
                        comp_model_inner = None

                    if comp_model_inner is not None:
                        X_inner_proc = comp_model_inner.fit_transform(X_inner_scaled)
                        X_val_proc = comp_model_inner.transform(X_val_scaled)

                for penalty in penalty_grid:
                    try:
                        clf = _make_logreg(C, penalty)
                        clf.fit(X_inner_proc, y_train_fold[mask_train])
                        probs = clf.predict_proba(X_val_proc)
                        score = -log_loss(y_train_fold[mask_val], probs)
                        val_acc = accuracy_score(y_train_fold[mask_val], clf.predict(X_val_proc))
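                    except Exception as exc:
                        # Skip this combination, but report it so failures are not silent
                        warnings.warn(
                            f"Inner fit skipped for C={C}, n_bins={n_bins}, penalty={penalty}: {exc}"
                        )
                        continue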
                    if score > best_score:
                        best_score = score
                        best_params = {"C": C, "n_bins": n_bins, "penalty": penalty}
                        best_val_acc = val_acc

        if best_params is None:
            raise RuntimeError("Inner loop failed to find hyper-parameters.")
        best_C_list.append(best_params['C'])
        if best_params.get('n_bins', None) is not None:
            best_bins_list.append(best_params['n_bins'])
        best_penalty_list.append(str(best_params.get('penalty', 'l2')).lower())
        val_records.append({"subject": int(val_subject), "accuracy": float(best_val_acc)})

    selected_C_fold = Counter(best_C_list).most_common(1)[0][0]
    selected_penalty_fold = Counter(best_penalty_list).most_common(1)[0][0] if best_penalty_list else penalty_grid[0]
    if best_bins_list:
        selected_bins_fold = Counter(best_bins_list).most_common(1)[0][0]
    else:
        selected_bins_fold = None
    if USE_FOOOF:
        X_train_features = train_matrix
        X_test_features = test_matrix
    else:
        n_bins_final = selected_bins_fold if selected_bins_fold is not None else FREQ_BIN_OPTIONS[0]
        X_train_features = _get_cache(train_cache, train_psd, n_bins_final)
        X_test_features = _get_cache(test_cache, test_psd, n_bins_final)

    imputer_final = SimpleImputer(strategy="constant", fill_value=0.0)
    scaler_final = StandardScaler()
    X_train_imp = imputer_final.fit_transform(X_train_features)
    X_test_imp = imputer_final.transform(X_test_features)
    X_train_scaled = scaler_final.fit_transform(X_train_imp)
    X_test_scaled = scaler_final.transform(X_test_imp)

    # Diagnostics: feature names before component analysis / PCA
    base_feature_names = _get_base_feature_names_for_mode(
        selected_bins_fold if (not USE_FOOOF) else None
    )
    print(
        f"Fold r{plan_entry['repeat']} f{plan_entry['fold']} - input features for component analysis/logistic (n={len(base_feature_names)}):"
    )
    for fname in base_feature_names:
        print(f"  {fname}")

    # Optional component analysis for the final model in this fold
    X_train_final = X_train_scaled
    X_test_final = X_test_scaled
    component_model_final = None
    component_method_final = None
    if USE_COMPONENT_ANALYSIS:
        method = str(COMPONENT_METHOD).lower()
        component_method_final = method
        n_features_here = X_train_scaled.shape[1]
        max_default = min(50, n_features_here)
        n_comp = COMPONENT_N_COMPONENTS if COMPONENT_N_COMPONENTS is not None else max_default
        n_comp = max(1, min(int(n_comp), n_features_here))

        if method == "pca":
            component_model_final = PCA(n_components=n_comp, random_state=CV_RANDOM_SEED)
        elif method == "ica":
            component_model_final = FastICA(n_components=n_comp, random_state=CV_RANDOM_SEED, max_iter=500)

        if component_model_final is not None:
            X_train_final = component_model_final.fit_transform(X_train_scaled)
            X_test_final = component_model_final.transform(X_test_scaled)

    # Diagnostics: feature names actually used for logistic regression training
    if component_model_final is not None and component_method_final in ("pca", "ica"):
        prefix = "PC" if component_method_final == "pca" else "IC"
        training_feature_names = [f"{prefix}{i+1}" for i in range(X_train_final.shape[1])]
    else:
        training_feature_names = base_feature_names
    print(
        f"Fold r{plan_entry['repeat']} f{plan_entry['fold']} - features used for logistic regression training (n={len(training_feature_names)}):"
    )
    for fname in training_feature_names:
        print(f"  {fname}")

    clf_final = _make_logreg(selected_C_fold, selected_penalty_fold)
    clf_final.fit(X_train_final, y_train_fold)
    y_pred_fold = clf_final.predict(X_test_final)
    y_proba_fold = clf_final.predict_proba(X_test_final)[:, 1]

    fold_result = {
        "repeat": plan_entry['repeat'],
        "fold": plan_entry['fold'],
        "test_subjects": test_subjects_fold,
        "train_subjects": train_subjects_fold,
        "train_idx": train_idx_fold,
        "test_idx": test_idx_fold,
        "y_test": y_test_fold,
        "y_pred": y_pred_fold,
        "y_proba": y_proba_fold,
        "subject_ids": subject_ids[test_idx_fold],
        "time_idx": epoch_time_indices[test_idx_fold],
        "acc": accuracy_score(y_test_fold, y_pred_fold),
        "conf_matrix": confusion_matrix(y_test_fold, y_pred_fold),
        "selected_C": selected_C_fold,
        "selected_penalty": selected_penalty_fold,
        "selected_n_bins": selected_bins_fold,
        "imputer": imputer_final,
        "scaler": scaler_final,
        "model": clf_final,
        "component_model": component_model_final,
        "X_train_features": X_train_features,
        "X_test_features": X_test_features,
        "val_records": val_records,
    }
    return fold_result

# Decide how to form the outer test set.
# - Default: subject-wise CV on the currently loaded dataset(s)
# - CROSS_DATASET_TEST: train on the primary dataset (chosen by NEW_DATA) and test on the
#   other dataset as a single held-out fold
subjects_new = np.array(sorted({int(r['subject']) for r in records if str(r.get('dataset', '')) == 'new'}), dtype=int)
subjects_old = np.array(sorted({int(r['subject']) for r in records if str(r.get('dataset', '')) == 'old'}), dtype=int)

if (CROSS_DATASET_TEST or USE_BOTH_DATASETS) and CV_LEVEL == 4:
    raise ValueError(
        "CV_LEVEL=4 uses a fixed list of subject IDs; when mixing datasets (offset IDs), this is ambiguous. "
        "Use CV_LEVEL=1/2/3, or update FIXED_TEST_SUBJECTS_LEVEL4 to match the combined subject_ids."
    )

if CROSS_DATASET_TEST:
    if NEW_DATA:
        test_subjects_external = subjects_old
        if test_subjects_external.size == 0:
            raise RuntimeError("CROSS_DATASET_TEST=True but OLD dataset subjects are empty. Resolve .set paths or disable the toggle.")
    else:
        test_subjects_external = subjects_new
        if test_subjects_external.size == 0:
            raise RuntimeError("CROSS_DATASET_TEST=True but NEW dataset subjects are empty. Resolve processed .fif paths or disable the toggle.")
    outer_plan = [{"repeat": 0, "fold": 1, "test_subjects": np.sort(test_subjects_external)}]
else:
    outer_plan = _build_outer_plan(unique_subjects)
print(f"Running logistic regression with {len(outer_plan)} outer folds (level={CV_LEVEL}).")
logreg_cv_folds = []
val_subject_buffer, val_accuracy_buffer = [], []
for plan_entry in outer_plan:
    fold_result = _run_logreg_fold(plan_entry)
    logreg_cv_folds.append(fold_result)
    val_subject_buffer.extend([rec['subject'] for rec in fold_result['val_records']])
    val_accuracy_buffer.extend([rec['accuracy'] for rec in fold_result['val_records']])
    print(f"  Fold r{plan_entry['repeat']} f{plan_entry['fold']}: test subjects {fold_result['test_subjects']} — acc={fold_result['acc']:.3f}")
if not logreg_cv_folds:
    raise RuntimeError("Logistic regression did not run any folds. Check CV configuration.")

# ---- Summary of feature and component counts ----
# Select the best fold first so that every count reported below refers to the same fold.
primary_fold = max(logreg_cv_folds, key=lambda fold: fold['acc'])
base_feature_names = _get_base_feature_names_for_mode(
    None if USE_FOOOF else (primary_fold['selected_n_bins'] or FREQ_BIN_OPTIONS[0])
)
n_base_features = len(base_feature_names)
component_model = primary_fold.get('component_model')
if component_model is not None and USE_COMPONENT_ANALYSIS:
    if hasattr(component_model, 'components_'):
        n_components_fitted = component_model.components_.shape[0]
    elif hasattr(component_model, 'n_components_'):
        n_components_fitted = int(component_model.n_components_)
    else:
        n_components_fitted = primary_fold['X_train_features'].shape[1]
    # The logistic model sees all columns of X_train_final
    n_components_used = primary_fold['model'].coef_.shape[1]
else:
    n_components_fitted = 0
    n_components_used = 0
print("\n[Logistic regression summary]")
print(f"  Total base input features (before PCA/ICA): {n_base_features}")
if USE_COMPONENT_ANALYSIS and component_model is not None:
    print(f"  Components fitted in best fold: {n_components_fitted}")
    print(f"  Components used by logistic model: {n_components_used}")
else:
    print("  Component analysis disabled for the final model (using base features directly).")

np.save(outpath("val_subject_ids.npy"), np.array(val_subject_buffer, dtype=int))
np.save(outpath("val_accuracies.npy"), np.array(val_accuracy_buffer, dtype=float))
max_group = max(len(entry['test_subjects']) for entry in outer_plan)
cv_matrix = -np.ones((len(outer_plan), max_group), dtype=int)
for row, entry in enumerate(outer_plan):
    arr = entry['test_subjects']
    cv_matrix[row, :arr.size] = arr
np.save(outpath("cv_test_subjects.npy"), cv_matrix)
primary_fold = max(logreg_cv_folds, key=lambda fold: fold['acc'])
logreg_primary_fold = primary_fold
logreg_primary_test_subjects = primary_fold['test_subjects']
logreg_covered_subjects = np.unique(np.concatenate([fold['subject_ids'] for fold in logreg_cv_folds]))

# Persist model artefacts from the best-performing outer fold
final_model = primary_fold['model']
final_scaler = primary_fold['scaler']
final_imputer = primary_fold['imputer']
final_component = primary_fold.get('component_model')
final_C = primary_fold['selected_C']
final_n_bins = primary_fold['selected_n_bins']
np.save(outpath("test_subjects.npy"), primary_fold['test_subjects'])
np.save(outpath("selected_C_lr.npy"), np.array([final_C], dtype=float))
if final_n_bins is not None:
    np.save(outpath("final_n_bins_lr.npy"), np.array([final_n_bins], dtype=int))
joblib.dump(final_model, outpath("final_model_lr.pkl"))
joblib.dump(final_scaler, outpath("final_scaler_lr.pkl"))
joblib.dump(final_imputer, outpath("final_imputer_lr.pkl"))
if final_component is not None:
    joblib.dump(final_component, outpath("final_component_lr.pkl"))

# Assemble aggregated predictions across all folds for downstream diagnostics
test_idx = np.concatenate([fold['test_idx'] for fold in logreg_cv_folds])
y_test = np.concatenate([fold['y_test'] for fold in logreg_cv_folds])
y_pred = np.concatenate([fold['y_pred'] for fold in logreg_cv_folds])
y_proba = np.concatenate([fold['y_proba'] for fold in logreg_cv_folds])
# X_test features may have different widths across folds (different n_bins).
# Try to stack; if shapes differ, keep as a list for QC stats.
try:
    X_test = np.vstack([fold['X_test_features'] for fold in logreg_cv_folds])
except Exception:
    X_test = [fold['X_test_features'] for fold in logreg_cv_folds]
time_idx_all = np.concatenate([fold['time_idx'] for fold in logreg_cv_folds])
logreg_predictions_df = pd.DataFrame({
    'fold': [f"r{fold['repeat']}_f{fold['fold']}" for fold in logreg_cv_folds for _ in range(fold['y_test'].size)],
    'subject_id': np.concatenate([fold['subject_ids'] for fold in logreg_cv_folds]),
    'epoch_idx': test_idx,
    'time_idx': time_idx_all,
    'y_true': y_test,
    'y_pred': y_pred,
    'prob_ec': y_proba,
})
train_idx = primary_fold['train_idx']
train_subjects = primary_fold['train_subjects']
test_subjects = logreg_covered_subjects
y_train = y_combined[train_idx]
X_train = primary_fold['X_train_features']

# Optional: tune run-length smoothing length on out-of-fold predictions
RUN_LENGTH_TUNED = False
if USE_TIME_ADJUSTMENT and LENGTH_TUNING:
    from sklearn.metrics import accuracy_score as _acc_score
    from sklearn.metrics import balanced_accuracy_score as _bal_acc_score

    def _score_smoothing(y_true_arr, y_pred_arr):
        metric = str(LENGTH_TUNING_METRIC).lower().strip()
        if metric == "accuracy":
            return float(_acc_score(y_true_arr, y_pred_arr))
        if metric == "balanced_accuracy":
            return float(_bal_acc_score(y_true_arr, y_pred_arr))
        raise ValueError(f"Unknown LENGTH_TUNING_METRIC: {LENGTH_TUNING_METRIC!r}")

    def _smooth_all_with(L: int) -> np.ndarray:
        out = np.empty_like(logreg_predictions_df['y_pred'].to_numpy())
        for subj in np.unique(logreg_predictions_df['subject_id']):
            mask = logreg_predictions_df['subject_id'] == subj
            df_subj = logreg_predictions_df.loc[mask].sort_values('time_idx')
            smoothed = _smooth_labels_by_run(
                df_subj['y_pred'].to_numpy(),
                df_subj['time_idx'].to_numpy(),
                int(L),
                int(L),
                use_edge_smoothing=USE_EDGE_SMOOTHING,
            )
            out[df_subj.index.to_numpy()] = smoothed
        return out

    grid = [int(x) for x in (LENGTH_GRID if LENGTH_GRID is not None else [])]
    grid = sorted({x for x in grid if x >= 1})
    if not grid:
        grid = [1]
    best_L = None
    best_score = -1.0
    y_true_all = logreg_predictions_df['y_true'].to_numpy()
    for L in grid:
        y_sm = _smooth_all_with(L)
        s = _score_smoothing(y_true_all, y_sm)
        if s > best_score:
            best_score = s
            best_L = L
    MIN_RUN_LENGTH = int(best_L)
    MIN_RUN_LENGTH_EDGE = int(best_L)
    RUN_LENGTH_TUNED = True
    print(f"[Temporal smoothing tuning] metric={LENGTH_TUNING_METRIC}, grid={grid} -> selected L={best_L} (score={best_score:.3f})")
else:
    if USE_TIME_ADJUSTMENT:
        print(f"[Temporal smoothing] tuning disabled; using MIN_RUN_LENGTH={MIN_RUN_LENGTH}, MIN_RUN_LENGTH_EDGE={MIN_RUN_LENGTH_EDGE}")

# Expose the selected smoothing lengths for later display
SELECTED_RUN_LENGTH = int(MIN_RUN_LENGTH) if USE_TIME_ADJUSTMENT else None
SELECTED_RUN_LENGTH_EDGE = int(MIN_RUN_LENGTH_EDGE) if USE_TIME_ADJUSTMENT else None

# Optional temporal smoothing of predicted labels
if USE_TIME_ADJUSTMENT:
    smoothed_all = np.empty_like(y_pred)
    for subj in np.unique(logreg_predictions_df['subject_id']):
        mask = logreg_predictions_df['subject_id'] == subj
        df_subj = logreg_predictions_df.loc[mask].sort_values('time_idx')
        smoothed = _smooth_labels_by_run(
            df_subj['y_pred'].to_numpy(),
            df_subj['time_idx'].to_numpy(),
            MIN_RUN_LENGTH,
            MIN_RUN_LENGTH_EDGE,
            use_edge_smoothing=USE_EDGE_SMOOTHING,
        )
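        # Map back into the global array: df_subj.index holds the original row labels,
        # which equal row positions because logreg_predictions_df uses the default RangeIndex.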
        smoothed_all[df_subj.index.to_numpy()] = smoothed
    logreg_predictions_df['y_pred_smooth'] = smoothed_all
else:
    logreg_predictions_df['y_pred_smooth'] = logreg_predictions_df['y_pred'].to_numpy()

# Choose which predictions to use for evaluation
if USE_TIME_ADJUSTMENT:
    eval_pred = logreg_predictions_df['y_pred_smooth'].to_numpy()
else:
    eval_pred = y_pred
acc = accuracy_score(y_test, eval_pred)
report = classification_report(y_test, eval_pred)
conf_matrix = confusion_matrix(y_test, eval_pred)

# Save per-epoch probabilities for this configuration
config_tag = "fooof" if USE_FOOOF else "psd"
np.save(outpath(f"{config_tag}_epoch_idx.npy"), test_idx)
np.save(outpath(f"{config_tag}_time_idx.npy"), epoch_time_indices[test_idx])
np.save(outpath(f"{config_tag}_y_true.npy"), y_test)
np.save(outpath(f"{config_tag}_prob_ec.npy"), y_proba)


### Temporal smoothing: raw vs smoothed predictions

If `USE_TIME_ADJUSTMENT=True`, the training cell adds a *smoothed* prediction per epoch (`y_pred_smooth`) based on run-length rules; otherwise `y_pred_smooth` is a copy of the raw predictions.

This plotting cell:
- Visualizes per-subject prediction timelines.
- Lets you compare raw predictions vs smoothed predictions.

It requires outputs from the logistic regression CV cell (predictions + time indices).
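
As a rough illustration of what a run-length rule can look like, the sketch below relabels any run of identical predictions shorter than `min_run` so that it matches the run before it. The function `smooth_short_runs` is hypothetical and only a stand-in: the notebook's own `_smooth_labels_by_run` is defined in an earlier cell and may apply different rules, particularly at the edges of a recording.

```python
import numpy as np

def smooth_short_runs(labels, min_run):
    """Hypothetical helper: absorb runs shorter than min_run into the preceding run."""
    labels = np.asarray(labels).copy()
    if labels.size == 0 or min_run <= 1:
        return labels
    # Indices where the label changes mark the start of a new run
    change = np.flatnonzero(np.diff(labels) != 0) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [labels.size]))
    # The first run has no predecessor, so it is always kept as-is
    for k in range(1, len(starts)):
        s, e = starts[k], ends[k]
        if e - s < min_run:
            labels[s:e] = labels[s - 1]
    return labels

raw = np.array([0, 0, 0, 1, 0, 0, 1, 1, 1, 1])
smoothed = smooth_short_runs(raw, min_run=2)
print(smoothed)
```

Here the isolated `1` at position 3 is treated as a flicker and absorbed into the surrounding EO run, while the longer runs are left untouched.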


In [None]:
# ---- Temporal label plots: raw vs smoothed ----
import numpy as np
import matplotlib.pyplot as plt

if 'logreg_predictions_df' not in globals():
    raise RuntimeError("Run the logistic regression CV cell first to populate logreg_predictions_df.")

# Ensure we have the smoothed column available
if 'y_pred_smooth' not in logreg_predictions_df.columns:
    logreg_predictions_df['y_pred_smooth'] = logreg_predictions_df['y_pred'].to_numpy()

# Plot configuration: choose subjects and time window
# plot_subjects can be "all" or a list like [23, 24]
plot_subjects = "all" # [10213, 10175, 10139, 10136, 10135, 10002]
# time_xlim can be None or a tuple like (0, 1000)
time_xlim = None

subjects = {}
for sid in np.sort(logreg_predictions_df['subject_id'].unique()):
    subjects[int(sid)] = None

if isinstance(plot_subjects, str) and plot_subjects.strip().lower() == "all":
    subject_ids_sorted = sorted(subjects.keys())
else:
    try:
        subj_list = list(plot_subjects)
    except TypeError:
        subj_list = [plot_subjects]
    subject_ids_sorted = [int(s) for s in subj_list]

colors = {0: "blue", 1: "red"}
time_xlabel = {
    "append_files": "Epoch timeline index (files appended)",
    "align_conditions": "Epoch number (EO/EC aligned)",
    "interleave_conditions": "Epoch timeline index (EO→EC interleaved)",
}.get(TIME_AXIS_MODE, "Epoch timeline index")

# Helper to build a figure for a given column name

def _plot_labels_per_subject(column_name: str, title: str):
    fig, ax = plt.subplots(figsize=(16, 6))
    for i, subj_id in enumerate(subject_ids_sorted):
        df_subj = logreg_predictions_df[logreg_predictions_df['subject_id'] == subj_id].copy()
        df_subj = df_subj.sort_values('time_idx')
        x_base = df_subj['time_idx'].to_numpy()
        y = np.full_like(x_base, i)
        labels = df_subj[column_name].to_numpy()
        for label_val in [0, 1]:
            idx = labels == label_val
            if not np.any(idx):
                continue
            x = x_base
            if TIME_AXIS_MODE == "align_conditions":
                x = x_base + (-0.15 if label_val == 0 else 0.15)
            ax.plot(x[idx], y[idx], 'o', color=colors[label_val], markersize=4)
    ax.set_yticks(range(len(subject_ids_sorted)))
    ax.set_yticklabels(subject_ids_sorted)
    if time_xlim is not None:
        ax.set_xlim(time_xlim)
    ax.set_xlabel(time_xlabel)
    ax.set_ylabel("Subject ID")
    ax.set_title(title)
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label='Eyes Open', markerfacecolor='blue', markersize=6),
        plt.Line2D([0], [0], marker='o', color='w', label='Eyes Closed', markerfacecolor='red', markersize=6),
    ]
    ax.legend(handles=legend_elements, title="Epoch Labels", bbox_to_anchor=(1.01, 1), loc='upper left')
    ax.grid(axis='x', linestyle='--', alpha=0.3)
    plt.tight_layout()
    plt.show()

# Plot raw predictions
_plot_labels_per_subject('y_pred', "Raw classifier labels per subject")

# Plot smoothed predictions
_plot_labels_per_subject('y_pred_smooth', f"Smoothed classifier labels per subject (min_run_length={MIN_RUN_LENGTH}, USE_TIME_ADJUSTMENT={USE_TIME_ADJUSTMENT})")

# Plot true labels
_plot_labels_per_subject('y_true', "True labels per subject")

# ---- Per-subject zoom plots (raw vs smoothed) ----
# Set a subject ID here to inspect in detail
zoom_subject_id = subject_ids_sorted[0] if subject_ids_sorted else None
if zoom_subject_id is not None:
    df_zoom = logreg_predictions_df[logreg_predictions_df['subject_id'] == zoom_subject_id].copy()
    df_zoom = df_zoom.sort_values('time_idx')
    x_base = df_zoom['time_idx'].to_numpy()
    raw = df_zoom['y_pred'].to_numpy()
    smooth = df_zoom['y_pred_smooth'].to_numpy()

    fig, ax = plt.subplots(figsize=(16, 4))
    for label_val, marker, label_name in [(0, 'o', 'EO raw'), (1, 'o', 'EC raw')]:
        idx = raw == label_val
        if np.any(idx):
            x = x_base
            if TIME_AXIS_MODE == "align_conditions":
                x = x_base + (-0.15 if label_val == 0 else 0.15)
            ax.plot(x[idx], raw[idx] + 0.0, marker, color=colors[label_val], markersize=4, linestyle='None', label=label_name)
    for label_val, marker, label_name in [(0, 'x', 'EO smooth'), (1, 'x', 'EC smooth')]:
        idx = smooth == label_val
        if np.any(idx):
            x = x_base
            if TIME_AXIS_MODE == "align_conditions":
                x = x_base + (-0.15 if label_val == 0 else 0.15)
            ax.plot(x[idx], smooth[idx] + 0.1, marker, color=colors[label_val], markersize=5, linestyle='None', label=label_name)
    ax.set_xlabel(time_xlabel)
    ax.set_ylabel("Label")
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['EO', 'EC'])
    ax.set_title(f"Raw vs smoothed labels for subject {zoom_subject_id}")
    ax.legend(loc='upper right')
    ax.grid(axis='x', linestyle='--', alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("No subjects available for zoom plot.")


### Component analysis helper: PCA elbow plot

When `USE_COMPONENT_ANALYSIS=True`, features are transformed with PCA/ICA before classification.

This cell:
- Fits PCA on the (imputed + scaled) **training** features from the primary fold.
- Plots cumulative explained variance vs number of components.

Use it to choose `COMPONENT_N_COMPONENTS` (e.g., the smallest k that explains ~90% of the variance).
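
The "smallest k" rule can also be read off numerically instead of by eye. The sketch below uses synthetic data as a stand-in for the scaled training features; `X_demo`, `target`, and `k` are illustrative names, not notebook globals.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
# Synthetic stand-in: 200 epochs x 10 features driven by 3 latent factors plus small noise
latent = rng.normal(size=(200, 3))
mixing = rng.normal(size=(3, 10))
X_demo = latent @ mixing + 0.05 * rng.normal(size=(200, 10))

pca = PCA().fit(X_demo)
cumvar = np.cumsum(pca.explained_variance_ratio_)

# First position where cumulative variance reaches the target, as a 1-based component count
target = 0.90
k = int(np.searchsorted(cumvar, target) + 1)
k = min(k, cumvar.size)  # guard against floating-point shortfall at the last component
print(f"Smallest k reaching {target:.0%} variance: {k}")
```

A value chosen this way is only a starting point; the elbow in the plot may suggest a smaller or larger k.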


In [None]:
# ---- Elbow plot for component analysis (PCA) ----
if not USE_COMPONENT_ANALYSIS:
    print("Component analysis is disabled (USE_COMPONENT_ANALYSIS=False). Enable it in the config cell to use PCA/ICA.")
else:
    if 'primary_fold' not in globals():
        raise RuntimeError("Run the logistic regression CV cell first.")

    X_train_raw = primary_fold['X_train_features']
    imputer = primary_fold['imputer']
    scaler = primary_fold['scaler']

    X_train_imp = imputer.transform(X_train_raw)
    X_train_scaled = scaler.transform(X_train_imp)

    n_samples, n_features = X_train_scaled.shape
    # PCA cannot fit more components than min(n_samples, n_features)
    max_components = min(ELBOW_MAX_COMPONENTS, n_features, n_samples)

    pca_elbow = PCA(n_components=max_components, random_state=CV_RANDOM_SEED)
    pca_elbow.fit(X_train_scaled)
    cumvar = np.cumsum(pca_elbow.explained_variance_ratio_)

    plt.figure(figsize=(6, 4))
    plt.plot(range(1, max_components + 1), cumvar, marker='o')
    plt.axhline(0.9, color='gray', linestyle='--', label='90% variance')
    plt.xlabel('Number of PCA components')
    plt.ylabel('Cumulative explained variance')
    plt.title('Elbow plot for PCA on training features')
    plt.grid(alpha=0.2)
    plt.legend()
    plt.show()

    print('Example: set COMPONENT_N_COMPONENTS to the smallest k where the curve bends or the cumulative variance reaches about 0.9.')


### Export: per-fold CV summary

This cell collects fold-level results from `logreg_cv_folds` into a single table.

It typically includes:
- Fold identifiers (repeat/fold).
- Test subject sets.
- Accuracy and selected hyperparameters.

It saves the summary as a CSV under the current `outputs/<config_tag>/` directory.
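
A minimal sketch of the per-fold raw-versus-smoothed comparison, assuming a predictions table with the same column names as `logreg_predictions_df` (`fold`, `y_true`, `y_pred`, `y_pred_smooth`). The toy values below are made up purely for illustration.

```python
import pandas as pd

# Toy predictions table; values are illustrative only
demo = pd.DataFrame({
    "fold": ["r0_f1"] * 4 + ["r0_f2"] * 4,
    "y_true": [0, 0, 1, 1, 0, 1, 1, 1],
    "y_pred": [0, 1, 1, 1, 0, 1, 0, 1],
    "y_pred_smooth": [0, 0, 1, 1, 0, 1, 1, 1],
})

# Mark each epoch as correct/incorrect, then average within each fold
per_fold = (
    demo.assign(
        correct_raw=demo["y_true"] == demo["y_pred"],
        correct_smoothed=demo["y_true"] == demo["y_pred_smooth"],
    )
    .groupby("fold")[["correct_raw", "correct_smoothed"]]
    .mean()
    .rename(columns={"correct_raw": "accuracy_raw", "correct_smoothed": "accuracy_smoothed"})
)
per_fold["delta_accuracy"] = per_fold["accuracy_smoothed"] - per_fold["accuracy_raw"]
print(per_fold)
```

Averaging boolean "correct" columns avoids a custom `groupby.apply` and keeps the result as one row per fold.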


In [None]:
# Summarize the per-fold outer-CV metrics in a table and persist them for later inspection.
# Also report how much TIME_ADJUSTMENT (run-length smoothing) changes accuracy.
if 'logreg_cv_folds' not in globals():
    raise RuntimeError("Run the logistic regression cell before requesting a CV summary.")

summary_rows = []
for fold in logreg_cv_folds:
    summary_rows.append({
        'repeat': fold['repeat'],
        'fold': fold['fold'],
        'test_subjects': ','.join(map(str, np.sort(fold['test_subjects']))),
        # Keep the historical column name for backwards compatibility
        'accuracy': fold.get('acc', None),
        'selected_C': fold.get('selected_C', None),
        'selected_n_bins': fold.get('selected_n_bins', None) if fold.get('selected_n_bins', None) is not None else 'FOOOF',
        'n_test_epochs': int(fold['y_test'].size),
    })

summary_df = pd.DataFrame(summary_rows).sort_values(['repeat', 'fold']).reset_index(drop=True)

# If per-epoch predictions are available, compute fold accuracies with/without smoothing
if 'logreg_predictions_df' in globals() and isinstance(logreg_predictions_df, pd.DataFrame):
    dfp = logreg_predictions_df.copy()
    needed = {'fold', 'y_true', 'y_pred', 'y_pred_smooth'}
    if needed.issubset(set(dfp.columns)):
        fold_acc = (
            dfp.groupby('fold')
            .apply(lambda g: pd.Series({
                'accuracy_raw': float((g['y_true'].to_numpy() == g['y_pred'].to_numpy()).mean()),
                'accuracy_smoothed': float((g['y_true'].to_numpy() == g['y_pred_smooth'].to_numpy()).mean()),
                'delta_accuracy': float((g['y_true'].to_numpy() == g['y_pred_smooth'].to_numpy()).mean() - (g['y_true'].to_numpy() == g['y_pred'].to_numpy()).mean()),
            }))
            .reset_index()
        )
        # fold id in dfp is like r{repeat}_f{fold}
        tmp = summary_df.copy()
        tmp['fold_id'] = tmp.apply(lambda r: f"r{int(r['repeat'])}_f{int(r['fold'])}", axis=1)
        merged = tmp.merge(fold_acc, left_on='fold_id', right_on='fold', how='left', suffixes=('', '_y'))
        merged = merged.drop(columns=['fold_id', 'fold_y'])
        summary_df = merged

        # Overall summary
        try:
            from sklearn.metrics import accuracy_score, balanced_accuracy_score
            y_true_all = dfp['y_true'].to_numpy(dtype=int)
            y_pred_raw = dfp['y_pred'].to_numpy(dtype=int)
            y_pred_sm = dfp['y_pred_smooth'].to_numpy(dtype=int)

            overall = {
                'use_time_adjustment': bool(globals().get('USE_TIME_ADJUSTMENT', False)),
                'run_length_tuned': bool(globals().get('RUN_LENGTH_TUNED', False)),
                'selected_run_length': globals().get('SELECTED_RUN_LENGTH', None),
                'selected_run_length_edge': globals().get('SELECTED_RUN_LENGTH_EDGE', None),
                'accuracy_raw': float(accuracy_score(y_true_all, y_pred_raw)),
                'accuracy_smoothed': float(accuracy_score(y_true_all, y_pred_sm)),
                'delta_accuracy': float(accuracy_score(y_true_all, y_pred_sm) - accuracy_score(y_true_all, y_pred_raw)),
                'balanced_accuracy_raw': float(balanced_accuracy_score(y_true_all, y_pred_raw)),
                'balanced_accuracy_smoothed': float(balanced_accuracy_score(y_true_all, y_pred_sm)),
                'delta_balanced_accuracy': float(balanced_accuracy_score(y_true_all, y_pred_sm) - balanced_accuracy_score(y_true_all, y_pred_raw)),
                'n_epochs_total': int(len(y_true_all)),
                'n_subjects_total': int(dfp['subject_id'].nunique()) if 'subject_id' in dfp.columns else None,
            }
            overall_df = pd.DataFrame([overall])
            overall_path = outpath('time_adjustment_summary.csv')
            overall_df.to_csv(overall_path, index=False)
            print('Saved TIME_ADJUSTMENT summary to', overall_path.resolve())
            display(overall_df)
        except Exception as exc:
            print('Could not compute TIME_ADJUSTMENT summary:', exc)

# Persist
display(summary_df)
summary_path = outpath('logreg_cv_summary.csv')
summary_df.to_csv(summary_path, index=False)
print(f"Saved logistic CV summary to {summary_path.resolve()}")


### Visualization: scalp map of model weights

This cell converts logistic regression coefficients into **per-channel weights** and plots them on a scalp topography.

- In PSD mode, coefficients are aggregated across frequency bins per channel.
- In FOOOF mode, coefficients are aggregated across the selected per-channel FOOOF features.

Use it to interpret which channels drive the decision boundary.
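
The aggregation step can be sketched as a reshape followed by a sum, assuming the coefficient vector is laid out channel-major (`[ch0_bins, ch1_bins, ...]`). The channel names, bin count, and coefficient values below are placeholders, not values from the notebook.

```python
import numpy as np

# Hypothetical layout: 3 channels x 4 frequency bins, flattened channel-major
channels = ["O1", "Oz", "O2"]
n_bins = 4
coef = np.arange(len(channels) * n_bins, dtype=float) - 5.0  # placeholder coefficients

# Each row of the reshaped array holds one channel's bins; summing preserves the sign
per_channel = coef.reshape(len(channels), n_bins).sum(axis=1)
weights = dict(zip(channels, per_channel))
print(weights)
```

Because positive and negative bin weights can cancel in a signed sum, a channel with a small aggregate weight is not necessarily uninformative.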


In [None]:
# ---- Scalp map (universal: works for any CV level and USE_FOOOF True/False) ----
import numpy as np
import matplotlib.pyplot as plt
import mne

# -------- helpers --------
def _extract_binary_coef(model):
    # Check for the attribute before converting: np.asarray(None) is never None
    coef = getattr(model, "coef_", None)
    if coef is None:
        raise RuntimeError("final_model has no coef_.")
    coef = np.asarray(coef)
    if coef.ndim == 1:
        return coef
    if coef.ndim == 2 and coef.shape[0] == 1:
        return coef[0]
    if coef.ndim == 2 and coef.shape[0] > 1:
        # multiclass: class 1 versus the mean of the remaining classes
        others = np.vstack([coef[:1], coef[2:]])
        return coef[1] - others.mean(axis=0)
    raise ValueError(f"Unexpected coef_ shape: {coef.shape}")

def _per_channel_weights_fooof(coef, feature_channels, default_stride=5):
    if not feature_channels:
        raise RuntimeError("feature_channels missing while USE_FOOOF=True.")
    n_ch = len(feature_channels)
    stride = coef.size // n_ch if coef.size % n_ch == 0 else default_stride
    if stride * n_ch > coef.size:
        stride = default_stride
    w = []
    for i in range(n_ch):
        s, e = i*stride, min((i+1)*stride, coef.size)
        w.append(float(np.sum(coef[s:e])))
    return np.asarray(w, float), list(feature_channels)

def _per_channel_weights_psd(coef, final_n_bins, PSD_META, TARGET_CHANNELS=None, FREQ_BIN_OPTIONS=None):
    # bins
    if final_n_bins is not None:
        n_bins = int(final_n_bins)
    elif FREQ_BIN_OPTIONS:
        n_bins = int(FREQ_BIN_OPTIONS[0])
    else:
        n_bins = None
    # channels
    psd_ch = list(PSD_META['channels']) if isinstance(PSD_META, dict) and 'channels' in PSD_META else None
    if not psd_ch:
        raise RuntimeError("PSD_META['channels'] missing or empty.")
    cand = list(TARGET_CHANNELS) if TARGET_CHANNELS else psd_ch
    if n_bins is not None:
        if coef.size % n_bins != 0:
            # try infer from candidate channels
            n_bins = coef.size // len(cand) if len(cand) and coef.size % len(cand) == 0 else max(1, coef.size // len(psd_ch))
        n_ch = coef.size // n_bins
    else:
        if len(cand) and coef.size % len(cand) == 0:
            n_ch = len(cand); n_bins = coef.size // n_ch
        elif coef.size % len(psd_ch) == 0:
            n_ch = len(psd_ch); n_bins = coef.size // n_ch
        else:
            raise ValueError("Cannot infer n_bins/channels from coef size.")
    ch_names = cand if len(cand) == n_ch else psd_ch[:n_ch]
    w = []
    for i in range(n_ch):
        s, e = i*n_bins, (i+1)*n_bins
        w.append(float(np.sum(coef[s:e])))
    return np.asarray(w, float), ch_names

# -------- build weights (works regardless of CV level) --------
coef_vec = _extract_binary_coef(final_model)

if 'USE_FOOOF' in globals() and USE_FOOOF:
    # FOOOF layout → per-channel feature blocks
    if 'feature_channels' not in globals():
        raise RuntimeError("feature_channels not defined (required for USE_FOOOF=True).")
    per_channel_weights, channel_names = _per_channel_weights_fooof(coef_vec, feature_channels)
else:
    # PSD-binned layout → [ch0_bins, ch1_bins, ...]
    per_channel_weights, channel_names = _per_channel_weights_psd(
        coef_vec,
        final_n_bins if 'final_n_bins' in globals() else None,
        PSD_META,
        TARGET_CHANNELS=TARGET_CHANNELS if 'TARGET_CHANNELS' in globals() else None,
        FREQ_BIN_OPTIONS=FREQ_BIN_OPTIONS if 'FREQ_BIN_OPTIONS' in globals() else None
    )

if len(channel_names) != per_channel_weights.size:
    raise ValueError(f"Channel/weight mismatch: {len(channel_names)} vs {per_channel_weights.size}")

# -------- plot topomap (API-compatible across MNE versions) --------
sfreq = float(closed_payload.get('sfreq', 256.0)) if 'closed_payload' in globals() else 256.0
info = mne.create_info(ch_names=channel_names, sfreq=sfreq, ch_types='eeg')

# montage
montage = None
try:
    m_name = PSD_META.get('montage', 'standard_1020') if isinstance(PSD_META, dict) else 'standard_1020'
    montage = mne.channels.make_standard_montage(m_name)
    info.set_montage(montage, on_missing='ignore')
except Exception:
    pass

vmin_val = np.percentile(per_channel_weights, 5)
vmax_val = np.percentile(per_channel_weights, 95)

fig, ax = plt.subplots(figsize=(6, 5))
plot_kwargs = dict(axes=ax, names=channel_names, show=False, outlines='head', cmap='RdBu_r')

try:
    im, _ = mne.viz.plot_topomap(per_channel_weights, info, vlim=(vmin_val, vmax_val), **plot_kwargs)
except TypeError:
    norm = plt.Normalize(vmin=vmin_val, vmax=vmax_val)
    im, _ = mne.viz.plot_topomap(per_channel_weights, info, norm=norm, **plot_kwargs)

cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cb.set_label("LogReg weight (signed)")
ax.set_title("Logistic Regression — Electrode weights (signed)")
plt.tight_layout()
plt.show()
# ---- End universal scalp map ----


### Quick readout of aggregate metrics

This cell prints the aggregate metrics computed in the training/CV step:
- Accuracy
- Classification report
- Confusion matrix

It also echoes key run settings (class weights and, in PSD mode, the final binning choice).


In [None]:
print("Accuracy:", acc)
print("Classification Report:\n", report)
print("Confusion Matrix:\n", conf_matrix)
if not USE_FOOOF:
    print(f"Final frequency bins used: {final_n_bins}")
print(f"Class weight setting: {CLASS_WEIGHT}")


### Diagnostics: split composition and feature health

This cell prints:
- How many subjects were used in train vs test (and how many folds covered which subjects).
- Chosen hyperparameters.
- Class balance in train/test.
- Basic feature health checks (percent NaN rows, all-zero rows).

Use it to catch silent failures early (e.g., missing labels, degenerate feature extraction).


In [None]:
# --- Diagnostics / sanity checks ---
print()
print(f"Subjects: total={unique_subjects.size}, primary-train={np.unique(train_subjects).size}, primary-test={logreg_primary_test_subjects.size}")
if CV_LEVEL > 1:
    print(f"Cross-validation covered {test_subjects.size} unique held-out subjects across {len(logreg_cv_folds)} folds.")
    print("Primary fold test subjects (model artefacts saved from this split):", np.sort(logreg_primary_test_subjects))
else:
    print("Held-out test subjects:", np.sort(logreg_primary_test_subjects))
print("Selected final C (mode across inner folds):", final_C)
if not USE_FOOOF:
    print("Selected final n_bins (mode across inner folds):", final_n_bins)
print("Train class counts:", Counter(y_train))
print("Test  class counts:", Counter(y_test))

def pct_nan(x):
    """Percent of rows that contain any NaN.
    Accepts ndarray or list of ndarrays (ragged across folds).
    """
    if isinstance(x, list):
        total = sum(arr.shape[0] for arr in x if isinstance(arr, np.ndarray))
        count = sum(np.isnan(arr).any(axis=1).sum() for arr in x if isinstance(arr, np.ndarray))
        return 100.0 * (count / total) if total else 0.0
    return 100.0 * np.isnan(x).any(axis=1).mean()

def pct_zero_row(x):
    """Percent of rows that are all-zero.
    Accepts ndarray or list of ndarrays (ragged across folds).
    """
    if isinstance(x, list):
        total = sum(arr.shape[0] for arr in x if isinstance(arr, np.ndarray))
        count = 0
        for arr in x:
            if not isinstance(arr, np.ndarray):
                continue
            count += np.all(arr == 0.0, axis=1).sum()
        return 100.0 * (count / total) if total else 0.0
    return 100.0 * (np.all(x == 0.0, axis=1)).mean()

print(f"% rows with any NaN in X_train: {pct_nan(X_train):.2f}%")
print(f"% rows that are all-zero features in X_train: {pct_zero_row(X_train):.2f}%")
print(f"% rows with any NaN in X_test:  {pct_nan(X_test):.2f}%")
print(f"% rows that are all-zero features in X_test:  {pct_zero_row(X_test):.2f}%")
print("=== Aggregated test results (all folds) ===")
print(f"Accuracy: {acc:.3f}")
print("Confusion matrix:")
print(conf_matrix)
print("Classification report:")
print(report)


### ROC curve

This cell plots the ROC curve using predicted probabilities for class **EC (label 1)**.

It is mainly useful when:
- You want a threshold-independent view of separability.
- You are comparing feature modes (PSD vs FOOOF) or smoothing settings.

It assumes the previous training cell has produced probability outputs.
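
For reference, a small sketch of how the ROC curve and its AUC are computed from labels and class-1 probabilities with scikit-learn. The arrays `y_true_demo` and `prob_ec_demo` are toy values, not notebook outputs.

```python
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score

# Toy labels (1 = EC) and predicted probabilities for class 1
y_true_demo = np.array([0, 0, 1, 1])
prob_ec_demo = np.array([0.1, 0.4, 0.35, 0.8])

# roc_curve sweeps the decision threshold; no single cut-off is assumed
fpr, tpr, thresholds = roc_curve(y_true_demo, prob_ec_demo)
auc_demo = roc_auc_score(y_true_demo, prob_ec_demo)
print(f"AUC = {auc_demo:.2f}")
```

The AUC equals the fraction of (EO, EC) epoch pairs in which the EC epoch receives the higher probability, which is why it does not depend on the 0.5 threshold used for hard labels.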


In [None]:
# Get predicted probabilities for class 1 (Eyes Closed)
if 'y_proba' not in globals():
    raise RuntimeError("Logistic regression probabilities are unavailable. Run the training cell first.")

# Compute ROC curve and AUC
fpr, tpr, _ = roc_curve(y_test, y_proba)
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.figure(figsize=(6, 5))
plt.plot(fpr, tpr, label=f"Logistic Regression (AUC = {roc_auc:.2f})", linewidth=2)
plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve - Logistic Regression")
plt.legend(loc="lower right")
plt.grid(alpha=0.2)
plt.show()


### Plotting test accuracy per subject (LOSO)

This plot summarizes **how performance varies across individuals**.

It loads per-subject validation results (saved by the training/CV step) and:
- Orders subjects by ID so individual bars are easy to look up.
- Shows which subjects are hardest/easiest.

Use it to identify outliers (e.g., bad recordings or labeling mismatches).
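
A minimal sketch of computing per-subject accuracy from an epoch-level predictions table and ranking subjects from hardest to easiest. The column names mirror `logreg_predictions_df`; the data values are invented for illustration.

```python
import pandas as pd

# Toy epoch-level predictions for three subjects
demo = pd.DataFrame({
    "subject_id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3],
    "y_true":     [0, 0, 1, 1, 0, 0, 1, 1, 0, 1],
    "y_pred":     [0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
})

# Fraction of correctly classified epochs per subject, lowest accuracy first
per_subject_acc = (
    (demo["y_true"] == demo["y_pred"])
    .groupby(demo["subject_id"])
    .mean()
    .sort_values()
)
hardest = int(per_subject_acc.index[0])
print(per_subject_acc)
```

Ranking by accuracy makes outliers stand out immediately, whereas ordering by ID is more convenient when looking up a specific participant.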


In [None]:
# Load saved data
val_accuracies = np.load(outpath("val_accuracies.npy"))
val_subject_ids = np.load(outpath("val_subject_ids.npy"))

# Sort by subject ID for better visualization
sorted_indices = np.argsort(val_subject_ids)
sorted_subjects = val_subject_ids[sorted_indices]
sorted_accuracies = val_accuracies[sorted_indices]

plt.figure(figsize=(12, 6))
plt.bar(sorted_subjects.astype(str), sorted_accuracies, color="#2d4987")
plt.title("Accuracy per Subject (LOSO) - Logistic Regression")
plt.xlabel("Subject ID")
plt.ylabel("Accuracy")
plt.ylim(0, 1)
plt.xticks(rotation=45)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()


### Printing the chosen hyperparameters

This cell prints the hyperparameters selected by the inner loop.

Typically:
- `C` is always selected.
- If PSD binning is enabled, `n_bins` (the PSD frequency bin count) may also be selected.

This is useful for reporting and for keeping runs reproducible.


In [None]:
print("Best hyperparameters from inner loop:")
print(f" C: {final_C}")
if not USE_FOOOF:
    print(f" n_bins: {final_n_bins}")
if USE_TIME_ADJUSTMENT:
    print(f" smoothing_L: {SELECTED_RUN_LENGTH} (tuned={RUN_LENGTH_TUNED})")
    if USE_EDGE_SMOOTHING:
        print(f" smoothing_L_edge: {SELECTED_RUN_LENGTH_EDGE}")


### Hold-out performance: per-subject accuracy + confusion matrix

This section focuses on **interpretability of the evaluation**:
- Aggregates predictions by held-out subject.
- Plots per-subject accuracy in the hold-out split.
- Displays a confusion matrix to show EO/EC error balance.

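Run it after the training/CV step has produced `logreg_predictions_df` and `conf_matrix`. Note that the per-subject metrics here use the raw `y_pred`, while `conf_matrix` reflects the smoothed predictions when `USE_TIME_ADJUSTMENT=True`.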


In [None]:
# 1. Group predictions by test subject
if 'logreg_predictions_df' not in globals():
    raise RuntimeError("Run the logistic regression cell to populate prediction records before computing per-subject metrics.")
subject_metrics = []
for subj_id, group in logreg_predictions_df.groupby('subject_id'):
    y_true_subj = group['y_true'].values
    y_pred_subj = group['y_pred'].values
    acc_subj = accuracy_score(y_true_subj, y_pred_subj)
    prec_subj = precision_score(y_true_subj, y_pred_subj, zero_division=0)
    rec_subj = recall_score(y_true_subj, y_pred_subj, zero_division=0)
    subject_metrics.append({
        'subject': int(subj_id),
        'accuracy': acc_subj,
        'precision': prec_subj,
        'recall': rec_subj,
    })
subject_metrics = sorted(subject_metrics, key=lambda item: item['subject'])

# 2. Print metrics per subject
print("Per-subject performance:")
for metrics in subject_metrics:
    print(f"Subject {metrics['subject']}: Accuracy = {metrics['accuracy']:.2f}, Precision = {metrics['precision']:.2f}, Recall = {metrics['recall']:.2f}")

# 3. Plot accuracy per subject
plt.figure(figsize=(14, 5))
subject_ids_sorted = [m['subject'] for m in subject_metrics]
accuracies = [m['accuracy'] for m in subject_metrics]
sns.barplot(x=subject_ids_sorted, y=accuracies, color="#2d4987")
plt.ylim(0, 1)
plt.title("Accuracy per Test Subject")
plt.xlabel("Subject ID")
plt.ylabel("Accuracy")
plt.tight_layout()
plt.show()

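# 4. Combined confusion matrix (pooled over all held-out subjects)
if 'conf_matrix' not in globals():
    raise RuntimeError("Run the logistic regression cell to compute conf_matrix before plotting it.")
plt.figure(figsize=(5, 4))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False)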
plt.title("Confusion Matrix (All Test Subjects)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()


### Export: per-subject metrics

This cell saves a per-subject performance table (accuracy, precision, recall) to `logreg_per_subject_metrics.csv` and prints the three lowest-accuracy subjects.

It is useful for:
- Reporting which participants drive errors.
- Selecting “hard” subjects for the PSD/FOOOF diagnostic plots.
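
A small sketch of the export and the "hard subject" selection, using a made-up metrics table and a temporary directory instead of the notebook's `outpath` helper. The subject ids, metric values, and file name are illustrative assumptions.

```python
import tempfile
from pathlib import Path

import pandas as pd

# Made-up per-subject metrics with the same columns as the exported table.
toy_metrics = pd.DataFrame({
    "subject":   [101, 102, 103, 104],
    "accuracy":  [0.95, 0.60, 0.80, 0.55],
    "precision": [0.96, 0.58, 0.82, 0.50],
    "recall":    [0.94, 0.65, 0.78, 0.60],
})

# The two lowest-accuracy subjects are candidates for the diagnostic plots.
hard_subjects = toy_metrics.nsmallest(2, "accuracy")["subject"].tolist()

# Round-trip through CSV to confirm the table survives export unchanged.
with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = Path(tmp_dir) / "toy_per_subject_metrics.csv"
    toy_metrics.to_csv(csv_path, index=False)
    reloaded = pd.read_csv(csv_path)

print(hard_subjects)
print(reloaded.shape)
```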


In [None]:
# Persist the per-subject metrics and highlight the most challenging participants.
if 'subject_metrics' not in globals():
    raise RuntimeError("Compute subject_metrics in the previous cell before exporting them.")
subject_df = pd.DataFrame(subject_metrics).sort_values('accuracy', ascending=False).reset_index(drop=True)
display(subject_df)
worst = subject_df.nsmallest(3, 'accuracy')
if not worst.empty:
    print("Most challenging subjects (lowest accuracy):")
    for _, row in worst.iterrows():
        print(f"  Subject {int(row['subject'])}: accuracy={row['accuracy']:.2f}, precision={row['precision']:.2f}, recall={row['recall']:.2f}")
export_path = outpath('logreg_per_subject_metrics.csv')
subject_df.to_csv(export_path, index=False)
print(f"Saved per-subject metrics to {export_path.resolve()}")


### Diagnostic: average PSD (hard subject) for EC vs EO

This cell selects the lowest-accuracy subject (among those with at least `DIFFICULTY_MIN_TEST_EPOCHS` test epochs) and compares:
- Mean PSD for EO epochs vs EC epochs.
- Optionally a subset of channels (`TARGET_CHANNELS`) if available.

It helps interpret whether the model is struggling due to subtle spectral differences.
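
The mean and SEM computation behind this plot can be sketched on a synthetic PSD cube. The data are random, not EEG. The asymmetric log-scale band shown at the end is one possible alternative to the symmetric band drawn by the cell below, not what that cell does.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic cube with the same layout as psd_cube: (epochs, channels, freqs).
toy_cube = rng.uniform(0.5, 2.0, size=(20, 4, 50))

# Average over channels within each epoch, then summarize across epochs.
per_epoch = np.nanmean(toy_cube, axis=1)               # (n_epochs, n_freqs)
mean_psd = np.nanmean(per_epoch, axis=0)               # (n_freqs,)
sem = np.nanstd(per_epoch, axis=0, ddof=1) / np.sqrt(per_epoch.shape[0])

# On a log10 axis, mean +/- SEM is not symmetric around log10(mean):
# the lower half of the band is wider than the upper half.
floor = 1e-30
log_mean = np.log10(np.maximum(mean_psd, floor))
upper_width = np.log10(np.maximum(mean_psd + sem, floor)) - log_mean
lower_width = log_mean - np.log10(np.maximum(mean_psd - sem, floor))

print(mean_psd.shape, float(upper_width.max()), float(lower_width.max()))
```

When the SEM is small relative to the mean, the two widths are nearly equal and a symmetric band is a reasonable approximation.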


In [None]:
# ---- Plot average PSD for a challenging subject: EC vs EO (drop-in) ----
import numpy as np
import matplotlib.pyplot as plt

# ---------- Config ----------
PSD_PLOT_FREQ_RANGE = (3.0, 40.0)     # Hz range to display
PSD_PLOT_LOG_POWER  = True            # True = plot log10 power
DIFFICULTY_MIN_TEST_EPOCHS = 6        # need at least this many test epochs to trust "difficulty"
USE_TARGET_CHANNELS_ONLY = True       # average only across TARGET_CHANNELS if available

# ---------- Pick a "challenging" subject ----------
def _pick_hard_subject():
    # Prefer the predictions DF assembled after CV
    if 'logreg_predictions_df' in globals():
        df = logreg_predictions_df.copy()
        # Per-subject epoch count and accuracy (swap in balanced accuracy if you prefer)
        df['correct'] = df['y_true'].to_numpy() == df['y_pred'].to_numpy()
        grp_all = df.groupby('subject_id').agg(n=('correct', 'size'), acc=('correct', 'mean'))
        # Keep subjects with enough test epochs; fall back to all subjects if none qualify
        grp = grp_all[grp_all['n'] >= DIFFICULTY_MIN_TEST_EPOCHS]
        if grp.empty:
            grp = grp_all
        hard_subj = int(grp['acc'].idxmin())
        hard_acc  = float(grp.loc[hard_subj, 'acc'])
        return hard_subj, hard_acc, int(grp.loc[hard_subj, 'n'])
    # Secondary fallback: pick a subject with the smallest margin between class counts
    ids = np.asarray(subject_ids)
    subj_list = np.unique(ids)
    best = None
    for s in subj_list:
        idx = np.where(ids == s)[0]
        if idx.size == 0:
            continue
        y_sub = y_combined[idx]
        n = idx.size
        # "difficulty" proxy: how close EC and EO counts are (more balanced → harder sometimes)
        margin = abs((y_sub == 1).sum() - (y_sub == 0).sum()) / n
        score = margin  # smaller is "harder"
        if best is None or score < best[0]:
            best = (score, int(s), n)
    if best is None:
        raise RuntimeError("Could not determine a challenging subject.")
    return best[1], np.nan, best[2]

hard_subject, hard_acc, hard_n = _pick_hard_subject()
print(f"Selected challenging subject: {hard_subject} (test n={hard_n}, acc={hard_acc if not np.isnan(hard_acc) else 'n/a'})")

# ---------- Prepare data for that subject ----------
# Channel set for averaging
if USE_TARGET_CHANNELS_ONLY and 'TARGET_CHANNELS' in globals() and isinstance(TARGET_CHANNELS, (list, tuple)) and len(TARGET_CHANNELS) > 0:
    chs = [ch for ch in TARGET_CHANNELS if ch in PSD_META['channels']]
    if not chs:
        chs = list(PSD_META['channels'])
else:
    chs = list(PSD_META['channels'])

ch_idx = [PSD_META['channels'].index(c) for c in chs]

# Indices for the subject & each condition
subj_mask = (np.asarray(subject_ids) == hard_subject)
eo_mask = subj_mask & (np.asarray(y_combined) == 0)  # Eyes Open label assumed 0
ec_mask = subj_mask & (np.asarray(y_combined) == 1)  # Eyes Closed label assumed 1

if not np.any(eo_mask) or not np.any(ec_mask):
    raise RuntimeError(f"Subject {hard_subject} lacks EO or EC epochs.")

# Slice PSD cube: (epochs, channels, freqs)
psd_eo = psd_cube[eo_mask][:, ch_idx, :]
psd_ec = psd_cube[ec_mask][:, ch_idx, :]

# Average jointly over epochs and channels (same as a mean of per-epoch means when no values are NaN)
mean_psd_eo = np.nanmean(psd_eo, axis=(0,1))
mean_psd_ec = np.nanmean(psd_ec, axis=(0,1))

# Optional CI (SEM across epochs) to visualize spread
def _sem_over_epochs(block):  # block shape: (n_epochs, n_channels, n_freqs)
    if block.shape[0] <= 1:
        return np.zeros(block.shape[-1], dtype=float)
    # First average across channels per epoch, then compute SEM across epochs
    per_epoch = np.nanmean(block, axis=1)  # (n_epochs, n_freqs)
    return np.nanstd(per_epoch, axis=0, ddof=1) / np.sqrt(per_epoch.shape[0])

sem_eo = _sem_over_epochs(psd_eo)
sem_ec = _sem_over_epochs(psd_ec)

# Frequency mask for plotting range
freqs = psd_freqs
fmask = (freqs >= PSD_PLOT_FREQ_RANGE[0]) & (freqs <= PSD_PLOT_FREQ_RANGE[1])

# Transform to y-scale
def _to_plot_scale(arr):
    return np.log10(np.maximum(arr, 1e-30)) if PSD_PLOT_LOG_POWER else arr

y_eo = _to_plot_scale(mean_psd_eo)
y_ec = _to_plot_scale(mean_psd_ec)
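# SEM band width on the plot scale: measured on the upper side (mean + SEM) and drawn
# symmetrically below, which is only approximate when plotting log10 power.
y_eo_sem = _to_plot_scale(mean_psd_eo + sem_eo) - _to_plot_scale(mean_psd_eo)
y_ec_sem = _to_plot_scale(mean_psd_ec + sem_ec) - _to_plot_scale(mean_psd_ec)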

# ---------- Plot ----------
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
titles = ["Eyes Closed (EC)", "Eyes Open (EO)"]
data   = [(y_ec, y_ec_sem), (y_eo, y_eo_sem)]

for ax, title, (y, y_sem) in zip(axes, titles, data):
    ax.plot(freqs[fmask], y[fmask], linewidth=2, label="Mean PSD")
    # shaded SEM
    ax.fill_between(freqs[fmask], (y - y_sem)[fmask], (y + y_sem)[fmask], alpha=0.2, label="±1 SEM")
    # mark alpha band if available
    if 'ALPHA_BAND' in globals() and ALPHA_BAND is not None:
        ax.axvspan(ALPHA_BAND[0], ALPHA_BAND[1], color="#ffbf00", alpha=0.15, label="Alpha band")
    ax.set_title(f"{title} — Subject {hard_subject}")
    ax.set_xlabel("Frequency (Hz)")
    ax.grid(alpha=0.25)

axes[0].set_ylabel("log10 Power" if PSD_PLOT_LOG_POWER else "Power (AU)")
# de-duplicate legend entries
handles, labels = axes[0].get_legend_handles_labels()
uniq = dict(zip(labels, handles))
axes[1].legend(uniq.values(), uniq.keys(), loc="upper right")

fig.suptitle("Challenging subject: average PSD (EC vs EO)", fontsize=14)
plt.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()
# ---- End drop-in ----


### Diagnostic: PSDs for misclassified epochs

This cell drills into **two hard examples**:
- An EC epoch misclassified as EO.
- An EO epoch misclassified as EC.

For each, it plots PSD curves across selected channels so you can visually compare spectral structure near the alpha band and beyond.
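
The cell below picks, for each error type, the misclassified epoch whose predicted EC probability is closest to 0.5. A minimal sketch of that selection rule on made-up records follows. The function name `pick_ambiguous` and the `threshold` parameter are illustrative, and the record keys mirror those used in the cell.

```python
# Made-up prediction records; 0 = EO, 1 = EC, prob_ec = P(label == 1).
toy_records = [
    {"epoch_idx": 10, "y_true": 1, "y_pred": 0, "prob_ec": 0.10},
    {"epoch_idx": 11, "y_true": 1, "y_pred": 0, "prob_ec": 0.45},
    {"epoch_idx": 12, "y_true": 0, "y_pred": 1, "prob_ec": 0.90},
    {"epoch_idx": 13, "y_true": 0, "y_pred": 1, "prob_ec": 0.55},
    {"epoch_idx": 14, "y_true": 1, "y_pred": 1, "prob_ec": 0.80},  # correct, never selected
]


def pick_ambiguous(records, true_label, pred_label, threshold=0.5):
    """Return the matching record whose probability is nearest the threshold, or None."""
    candidates = [
        rec for rec in records
        if rec["y_true"] == true_label and rec["y_pred"] == pred_label
    ]
    if not candidates:
        return None
    return min(candidates, key=lambda rec: abs(rec["prob_ec"] - threshold))


toy_ec_to_eo = pick_ambiguous(toy_records, true_label=1, pred_label=0)
toy_eo_to_ec = pick_ambiguous(toy_records, true_label=0, pred_label=1)
print(toy_ec_to_eo["epoch_idx"], toy_eo_to_ec["epoch_idx"])
```

Choosing the most ambiguous errors highlights borderline epochs; replacing `min` with `max` would instead surface the most confidently wrong predictions.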


In [None]:

# ---- Misclassified epochs PSDs for a challenging subject (EC→EO and EO→EC) ----
import numpy as np
import matplotlib.pyplot as plt

# -------- Config --------
MISCLASS_FREQ_RANGE = (3.0, 40.0)   # Hz range to show
MISCLASS_LOG_POWER  = True          # plot log10 power if True
USE_TARGET_CHANNELS_ONLY = True     # use TARGET_CHANNELS subset if available

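# -------- Collect per-epoch predictions (with global epoch indices) --------
if 'logreg_cv_folds' not in globals():
    raise RuntimeError("Run the logistic regression CV cell to populate logreg_cv_folds before plotting misclassified epochs.")
mis_records = []  # each: dict(epoch_idx, subject_id, y_true, y_pred, prob_ec, fold)
for fold in logreg_cv_folds: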
    idxs = fold['test_idx']
    subs = fold['subject_ids']
    y_t  = fold['y_test']
    y_p  = fold['y_pred']
    p_ec = fold['y_proba']  # probability for class "EC" (label 1)
    for i in range(len(idxs)):
        if y_t[i] != y_p[i]:  # misclassified
            mis_records.append({
                "epoch_idx": int(idxs[i]),
                "subject_id": int(subs[i]),
                "y_true": int(y_t[i]),
                "y_pred": int(y_p[i]),
                "prob_ec": float(p_ec[i]),
                "fold": f"r{fold['repeat']}_f{fold['fold']}",
            })

if not mis_records:
    raise RuntimeError("No misclassifications found across folds; cannot make the requested figure.")

# -------- Pick a subject that has BOTH types of mistakes --------
from collections import defaultdict
by_subj = defaultdict(list)
for rec in mis_records:
    by_subj[rec["subject_id"]].append(rec)

candidate_subject = None
for sid, recs in by_subj.items():
    has_ec_to_eo = any(r["y_true"] == 1 and r["y_pred"] == 0 for r in recs)
    has_eo_to_ec = any(r["y_true"] == 0 and r["y_pred"] == 1 for r in recs)
    if has_ec_to_eo and has_eo_to_ec:
        candidate_subject = sid
        break
if candidate_subject is None:
    # fallback: no subject has both error types, so take the subject with the most mistakes;
    # the missing error type is then filled in from other subjects further down
    candidate_subject = max(by_subj.items(), key=lambda kv: len(kv[1]))[0]

# Extract one EC→EO and one EO→EC example for that subject (choose the one closest to decision boundary)
def _closest_to_boundary(records, true_label, pred_label):
    # distance to 0.5 for prob_ec (smaller = more ambiguous)
    cand = [r for r in records if r["y_true"] == true_label and r["y_pred"] == pred_label]
    if not cand:
        return None
    return min(cand, key=lambda r: abs(r["prob_ec"] - 0.5))

subject_recs = by_subj[candidate_subject]
ec_to_eo = _closest_to_boundary(subject_recs, true_label=1, pred_label=0)
eo_to_ec = _closest_to_boundary(subject_recs, true_label=0, pred_label=1)

# If one side is missing for this subject, take the most ambiguous misclassified epoch of that kind across all subjects
if ec_to_eo is None:
    ec_to_eo = _closest_to_boundary(mis_records, true_label=1, pred_label=0)
if eo_to_ec is None:
    eo_to_ec = _closest_to_boundary(mis_records, true_label=0, pred_label=1)
if ec_to_eo is None or eo_to_ec is None:
    raise RuntimeError("Could not find both a misclassified EC epoch and a misclassified EO epoch.")

# -------- Pull PSDs for those epochs --------
# Channel selection
if USE_TARGET_CHANNELS_ONLY and 'TARGET_CHANNELS' in globals() and TARGET_CHANNELS:
    channels = [ch for ch in TARGET_CHANNELS if ch in PSD_META['channels']]
    if not channels:
        channels = list(PSD_META['channels'])
else:
    channels = list(PSD_META['channels'])
ch_idx = [PSD_META['channels'].index(c) for c in channels]

# Helper to get PSDs for an epoch (returns (freqs, 2D array: n_channels x n_freqs))
def _epoch_psds(epoch_idx):
    freqs = psd_freqs
    spectra = psd_cube[epoch_idx, ch_idx, :]  # shape (n_channels, n_freqs)
    return freqs, spectra

f_ec2eo, psd_ec2eo = _epoch_psds(ec_to_eo["epoch_idx"])
f_eo2ec, psd_eo2ec = _epoch_psds(eo_to_ec["epoch_idx"])

# -------- Plotting --------
def _to_plot_scale(arr):
    return np.log10(np.maximum(arr, 1e-30)) if MISCLASS_LOG_POWER else arr

fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# Frequency window
mask_ec2eo = (f_ec2eo >= MISCLASS_FREQ_RANGE[0]) & (f_ec2eo <= MISCLASS_FREQ_RANGE[1])
mask_eo2ec = (f_eo2ec >= MISCLASS_FREQ_RANGE[0]) & (f_eo2ec <= MISCLASS_FREQ_RANGE[1])

# Left: EC epoch classified as EO
ax = axes[0]
for i, ch in enumerate(channels):
    ax.plot(f_ec2eo[mask_ec2eo], _to_plot_scale(psd_ec2eo[i, :])[mask_ec2eo], linewidth=1.0, alpha=0.9, label=ch)
if 'ALPHA_BAND' in globals() and ALPHA_BAND is not None:
    ax.axvspan(ALPHA_BAND[0], ALPHA_BAND[1], color="#ffbf00", alpha=0.12)
ax.set_title(f"EC epoch classified as EO — subj {ec_to_eo['subject_id']} — idx {ec_to_eo['epoch_idx']}")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("log10 Power" if MISCLASS_LOG_POWER else "Power (AU)")
ax.grid(alpha=0.25)
# no legend on this panel: the channel legend is drawn once, on the right-hand panel

# Right: EO epoch classified as EC
ax = axes[1]
for i, ch in enumerate(channels):
    ax.plot(f_eo2ec[mask_eo2ec], _to_plot_scale(psd_eo2ec[i, :])[mask_eo2ec], linewidth=1.0, alpha=0.9, label=ch)
if 'ALPHA_BAND' in globals() and ALPHA_BAND is not None:
    ax.axvspan(ALPHA_BAND[0], ALPHA_BAND[1], color="#ffbf00", alpha=0.12)
ax.set_title(f"EO epoch classified as EC — subj {eo_to_ec['subject_id']} — idx {eo_to_ec['epoch_idx']}")
ax.set_xlabel("Frequency (Hz)")
ax.grid(alpha=0.25)
# single legend for both (avoid clutter)
handles, labels = axes[0].get_legend_handles_labels()
uniq = dict(zip(labels, handles))
axes[1].legend(uniq.values(), uniq.keys(), loc='upper right', fontsize=8, ncol=1, frameon=False)

fig.suptitle("Misclassified epochs: PSDs for all selected channels", fontsize=14)
plt.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()
# End misclassified-epoch PSD plot


### ONE_MAIN_FOOOF deep-dive: subject ROI spectrum + misclassified examples

This cell combines a subject-level FOOOF view with misclassification examples.

It:
- Chooses a subject (by default the lowest-accuracy subject).
- Fits FOOOF/specparam to the subject’s ROI-averaged spectrum.
- Visualizes example misclassified epochs for that subject.

Run it after feature extraction and after the logistic regression CV cell has produced per-epoch predictions.
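
The overlay in the first plot is the aperiodic component plus one Gaussian peak, both in log10 power. The sketch below rebuilds such a curve from parameters using NumPy only. The parameter values are made up, the helper names are hypothetical, and it assumes the FOOOF/specparam convention that the reported peak bandwidth (BW) is twice the Gaussian standard deviation.

```python
import numpy as np


def aperiodic_fixed(freqs, offset, exponent):
    """Aperiodic component in log10 power for the no-knee ('fixed') mode."""
    return offset - exponent * np.log10(freqs)


def gaussian_peak(freqs, center, height, std):
    """Gaussian bump added on top of the aperiodic component (log10 power)."""
    return height * np.exp(-0.5 * ((freqs - center) / std) ** 2)


# Made-up parameters for illustration only.
toy_freqs = np.linspace(3.0, 40.0, 149)        # 0.25 Hz resolution
toy_offset, toy_exponent = 1.0, 1.5
toy_cf, toy_pw, toy_bw = 10.0, 0.8, 2.0        # BW assumed to be 2 * std

toy_aperiodic = aperiodic_fixed(toy_freqs, toy_offset, toy_exponent)
toy_peak = gaussian_peak(toy_freqs, toy_cf, toy_pw, std=toy_bw / 2.0)
toy_model = toy_aperiodic + toy_peak

# The peak contribution is largest at the center frequency.
peak_freq = toy_freqs[np.argmax(toy_peak)]
print(peak_freq, float(toy_peak.max()))
```

Comparing `toy_model` against `toy_aperiodic` in the alpha range shows how much of the spectrum the single main peak accounts for.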


In [None]:
# ---- Main FOOOF + misclassified-epoch analysis (ONE_MAIN_FOOOF) ----
# Creates 3 plots (if available):
#  1) Subject-level averaged ROI PSD + aperiodic + main alpha peak
#  2) Example EC epoch misclassified as EO
#  3) Example EO epoch misclassified as EC

import numpy as np
import matplotlib.pyplot as plt
import math

# -------------------- CONFIG --------------------
MAIN_FOOOF_SUBJECT = None  # None -> auto-select lowest-accuracy subject; or set an int subject id
MIS_EXAMPLE_STRATEGY = "most_confident"  # "most_confident" or "first"

if not ("USE_FOOOF" in globals() and USE_FOOOF and "ONE_MAIN_FOOOF" in globals() and ONE_MAIN_FOOOF):
    print("This analysis is intended for ONE_MAIN_FOOOF mode (USE_FOOOF=True, ONE_MAIN_FOOOF=True) after running CV.")
elif "logreg_predictions_df" not in globals():
    print("No logreg_predictions_df found – run the logistic regression cell first.")
elif "psd_cube" not in globals() or "psd_freqs" not in globals() or "feature_channels" not in globals():
    print("No PSD data found – run the feature extraction cell first.")
elif "SpectralModel" not in globals() or SpectralModel is None:
    print("SpectralModel (specparam/FOOOF) is not available in this environment.")
else:
    # Choose evaluation predictions (smoothed vs raw)
    pred_col = 'y_pred'
    if 'USE_TIME_ADJUSTMENT' in globals() and USE_TIME_ADJUSTMENT and 'y_pred_smooth' in logreg_predictions_df.columns:
        pred_col = 'y_pred_smooth'

    df = logreg_predictions_df.copy()
    df['y_eval'] = df[pred_col].astype(int)
    df['y_true'] = df['y_true'].astype(int)

    # Auto-select subject with the lowest accuracy (among subjects present in predictions)
    # Mean of the per-epoch "correct" flag within each subject (avoids groupby.apply on the whole frame)
    subj_acc = (df['y_true'] == df['y_eval']).groupby(df['subject_id']).mean().astype(float)
    subj_acc = subj_acc.sort_values(ascending=True)
    if subj_acc.empty:
        raise RuntimeError("No subject predictions available to compute accuracy.")
    auto_subject = int(subj_acc.index[0])
    subj_id = int(MAIN_FOOOF_SUBJECT) if MAIN_FOOOF_SUBJECT is not None else auto_subject
    if subj_id not in subj_acc.index:
        raise ValueError(f"Subject {subj_id} not found in predictions. Available subjects: {list(map(int, subj_acc.index[:10]))}...")
    print(f"Selected subject: {subj_id} (accuracy={subj_acc.loc[subj_id]:.3f}, auto_lowest={auto_subject})")

    # ROI selection (same logic used for alpha profile)
    roi_names = [ch for ch in (ALPHA_PROFILE_ROI if 'ALPHA_PROFILE_ROI' in globals() else []) if ch in feature_channels]
    if not roi_names:
        roi_names = list(feature_channels)
    roi_idx = [feature_channels.index(ch) for ch in roi_names]

    freqs_arr = np.asarray(psd_freqs, float)

    def _fit_main_fooof(freqs_fit, spectrum_fit):
        model = SpectralModel(**FOOOF_SETTINGS) if 'FOOOF_SETTINGS' in globals() else SpectralModel()
        # Fit on provided spectrum (already restricted if caller masked)
        model.fit(freqs_fit, spectrum_fit)
        freqs_plot = np.asarray(getattr(model, 'freqs', freqs_fit))
        psd_plot = np.asarray(getattr(model, 'power_spectrum', spectrum_fit))

        # Reconstruct aperiodic fit from aperiodic_params_
        ap_params = np.asarray(getattr(model, 'aperiodic_params_', []), float)
        if ap_params.size == 0:
            ap_fit = np.zeros_like(freqs_plot, dtype=float)
        else:
            if ap_params.size == 2:
                offset, exponent = ap_params
                ap_fit = offset - exponent * np.log10(freqs_plot)
            elif ap_params.size == 3:
                offset, knee, exponent = ap_params
                ap_fit = offset - np.log10(knee + freqs_plot ** exponent)
            else:
                ap_fit = np.zeros_like(freqs_plot, dtype=float)

        # Main alpha peak (strongest peak inside ALPHA_PROFILE_RANGE)
        gauss_main = None
        peaks = np.asarray(getattr(model, 'peak_params_', []), float)
        if peaks.size:
            lo_alpha, hi_alpha = (ALPHA_PROFILE_RANGE if 'ALPHA_PROFILE_RANGE' in globals() else (8.0, 12.0))
            mask_peaks = (peaks[:, 0] >= lo_alpha) & (peaks[:, 0] <= hi_alpha)
            if np.any(mask_peaks):
                subset = peaks[mask_peaks]
                best = subset[np.argmax(subset[:, 1])]
                cf, amp, bw = map(float, best[:3])
                if bw > 0:
                    # peak_params_ reports BW as the two-sided width (2 * Gaussian std) in FOOOF/specparam
                    sigma = bw / 2.0
                    gauss_main = amp * np.exp(-0.5 * ((freqs_plot - cf) / sigma) ** 2)
        return freqs_plot, psd_plot, ap_fit, gauss_main, model

    def _plot_fooof_overlay(freqs_plot, psd_plot, ap_fit, gauss_main, title):
        plt.figure(figsize=(8, 4))
        plt.plot(freqs_plot, psd_plot, label='PSD', color='#1f77b4')
        plt.plot(freqs_plot, ap_fit, label='Aperiodic fit', color='#ff7f0e', linestyle='--')
        if gauss_main is not None:
            plt.plot(freqs_plot, ap_fit + gauss_main, label='Aperiodic + main alpha', color='#2ca02c')
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power (model units)')
        plt.title(title)
        plt.grid(alpha=0.25)
        plt.legend()
        plt.tight_layout()
        plt.show()

    # -------------------- Plot 1: subject mean ROI PSD --------------------
    subj_mask = (np.asarray(subject_ids) == subj_id)  # asarray so a plain list still compares element-wise
    if not np.any(subj_mask):
        raise RuntimeError(f"No epochs found in psd_cube for subject {subj_id}.")
    subj_cube = psd_cube[subj_mask][:, roi_idx, :]  # (n_epochs, n_roi, n_freqs)
    mean_spectrum = np.nanmean(subj_cube, axis=(0, 1))
    if not np.any(np.isfinite(mean_spectrum)):
        raise RuntimeError(f"Mean spectrum for subject {subj_id} is non-finite.")

    fit_lo, fit_hi = (ALPHA_FREQ_RANGE if 'ALPHA_FREQ_RANGE' in globals() else (freqs_arr[0], freqs_arr[-1]))
    fit_lo = max(float(fit_lo), float(freqs_arr[0]))
    fit_hi = min(float(fit_hi), float(freqs_arr[-1]))
    fit_mask = (freqs_arr >= fit_lo) & (freqs_arr <= fit_hi)
    freqs_fit = freqs_arr[fit_mask]
    spec_fit = mean_spectrum[fit_mask]

    freqs_plot, psd_plot, ap_fit, gauss_main, _ = _fit_main_fooof(freqs_fit, spec_fit)
    _plot_fooof_overlay(
        freqs_plot,
        psd_plot,
        ap_fit,
        gauss_main,
        title=f"Subject {subj_id}: averaged ROI PSD + main FOOOF (ROI={roi_names})",
    )

    # -------------------- Plot 2/3: example misclassifications for this subject --------------------
    df_subj = df[df['subject_id'] == subj_id].copy()
    if df_subj.empty:
        print(f"No predictions found for subject {subj_id}.")
    else:
        # Helper to choose an example row
        def _pick_example(mask):
            cand = df_subj.loc[mask].copy()
            if cand.empty:
                return None
            if MIS_EXAMPLE_STRATEGY == 'first':
                return cand.iloc[0]
            # most_confident: pick the wrong prediction with highest model confidence
            # prob_ec is P(class==1). If predicted 1 -> confidence=prob_ec; predicted 0 -> confidence=1-prob_ec
            prob = cand['prob_ec'].astype(float).to_numpy()
            pred = cand['y_eval'].astype(int).to_numpy()
            conf = np.where(pred == 1, prob, 1.0 - prob)
            return cand.iloc[int(np.argmax(conf))]

        # EC (true=1) misclassified as EO (pred=0)
        row_ec2eo = _pick_example((df_subj['y_true'] == 1) & (df_subj['y_eval'] == 0))
        if row_ec2eo is None:
            print(f"No EC→EO misclassifications found for subject {subj_id}.")
        else:
            idx = int(row_ec2eo['epoch_idx'])
            spec_epoch = np.nanmean(psd_cube[idx][roi_idx, :], axis=0)
            freqs_plot2, psd_plot2, ap_fit2, gauss2, _ = _fit_main_fooof(freqs_fit, spec_epoch[fit_mask])
            _plot_fooof_overlay(
                freqs_plot2,
                psd_plot2,
                ap_fit2,
                gauss2,
                title=(
                    f"Misclassified EC→EO (subject {subj_id}) — epoch_idx {idx} — "
                    f"prob_ec={float(row_ec2eo['prob_ec']):.3f} ({pred_col})"
                ),
            )

        # EO (true=0) misclassified as EC (pred=1)
        row_eo2ec = _pick_example((df_subj['y_true'] == 0) & (df_subj['y_eval'] == 1))
        if row_eo2ec is None:
            print(f"No EO→EC misclassifications found for subject {subj_id}.")
        else:
            idx = int(row_eo2ec['epoch_idx'])
            spec_epoch = np.nanmean(psd_cube[idx][roi_idx, :], axis=0)
            freqs_plot3, psd_plot3, ap_fit3, gauss3, _ = _fit_main_fooof(freqs_fit, spec_epoch[fit_mask])
            _plot_fooof_overlay(
                freqs_plot3,
                psd_plot3,
                ap_fit3,
                gauss3,
                title=(
                    f"Misclassified EO→EC (subject {subj_id}) — epoch_idx {idx} — "
                    f"prob_ec={float(row_eo2ec['prob_ec']):.3f} ({pred_col})"
                ),
            )
# ---- End main-FOOOF misclassification analysis ----
