# Classifier Results Visualization

This notebook visualizes and analyzes results from the EC/EO (eyes-closed / eyes-open) classifier pipeline on the old (30-subject) and new (100-subject) clinical datasets.

It includes:
- Data distributions (age, alpha power)
- FOOOF center frequency analysis
- Model performance metrics (ROC, confusion matrices, accuracy)
- PSD vs FOOOF comparisons
- Statistical tests comparing models

## 1. Setup and Configuration

In [None]:
from __future__ import annotations

import json
import os
import platform
import re
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import mne
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_curve, roc_auc_score, precision_score, recall_score, f1_score
)
from scipy import stats
from scipy.io import loadmat
from statsmodels.stats.contingency_tables import mcnemar
import joblib

# Plot styling shared by every figure in this notebook
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6), 'font.size': 10})


In [None]:
# -----------------
# Clinical-only Windows configuration (explicit paths)
# -----------------

def _is_windows() -> bool:
    return platform.system() == "Windows"

def _detect_notebook_path() -> Optional[Path]:
    """Detect the current notebook path (best-effort)."""
    try:
        vsc = globals().get("__vsc_ipynb_file__", None)
        if vsc:
            p = Path(str(vsc)).expanduser()
            if p.suffix.lower() == ".ipynb" and p.exists():
                return p.resolve()
    except Exception:
        pass
    for key in ("NOTEBOOK_PATH", "IPYNB_PATH"):
        v = os.getenv(key)
        if v:
            try:
                p = Path(v).expanduser()
                if p.suffix.lower() == ".ipynb" and p.exists():
                    return p.resolve()
            except Exception:
                pass
    return None

NOTEBOOK_PATH = _detect_notebook_path()
NOTEBOOK_DIR = NOTEBOOK_PATH.parent if NOTEBOOK_PATH is not None else Path.cwd().resolve()

# Where `EC_EO_Classifier.ipynb` writes model outputs (within this repo by default)
OUTPUTS_ROOT = NOTEBOOK_DIR / "outputs"
OUTPUTS_ROOT.mkdir(parents=True, exist_ok=True)

# ---------- Clinical raw data roots (Windows) ----------
# These are raw strings (r"..."), so each path separator is written as a SINGLE
# backslash. Doubling it inside a raw string would embed two literal backslashes
# in every path.
# Old dataset: 30 subjects (Open_marked + Closed_marked)
OLD_OPEN_DIR = r"E:\Saxe_sandkasse\30EOEC_filer\Open_marked"
OLD_CLOSED_DIR = r"E:\Saxe_sandkasse\30EOEC_filer\Closed_marked"

# New dataset: 100 subjects (preprocessed epoch FIFs)
NEW_PROCESSED_DIR = r"G:\ChristianMusaeus\New_EEG\Processed"

# Saved ONE_MAIN_FOOOF cache
SAVED_FOOOF_DIR = r"G:\ChristianMusaeus\saved_fooof"

# Metadata
METADATA_OLD_CSV = r"G:\ChristianMusaeus\metadata_time_filtered.csv"
METADATA_NEW_CSV = r"G:\ChristianMusaeus\EEG_sub_data_pseudoanonym.csv"

# Notebook-specific outputs (figures + cached intermediates)
VIZ_OUT_DIR = OUTPUTS_ROOT / "visualizations" / "clinical_classifier_results"
VIZ_CACHE_DIR = VIZ_OUT_DIR / "cache"
VIZ_OUT_DIR.mkdir(parents=True, exist_ok=True)
VIZ_CACHE_DIR.mkdir(parents=True, exist_ok=True)

print("Notebook dir:", NOTEBOOK_DIR)
print("Outputs root:", OUTPUTS_ROOT)
print("Viz outputs:", VIZ_OUT_DIR)
print("Platform:", platform.system())
if not _is_windows():
    print("WARNING: This notebook is configured for Windows clinical paths.")

print("Old open dir:", OLD_OPEN_DIR)
print("Old closed dir:", OLD_CLOSED_DIR)
print("New processed dir:", NEW_PROCESSED_DIR)
print("Saved fooof dir:", SAVED_FOOOF_DIR)
print("Metadata old:", METADATA_OLD_CSV)
print("Metadata new:", METADATA_NEW_CSV)


In [None]:
# -----------------
# Helper Functions for Data Loading
# -----------------

def load_model_outputs(output_dir: Path, is_fooof: Optional[bool] = None) -> Dict:
    """Load the saved artifacts of one model configuration directory.

    Parameters
    ----------
    output_dir : Path
        Configuration directory written by the classifier notebook (``str`` is
        accepted too).
    is_fooof : bool, optional
        Whether the config used FOOOF features. When ``None`` it is inferred
        from the directory name containing ``"__fooof__"``.

    Returns
    -------
    dict
        Keys ``epoch_idx``, ``y_true``, ``prob_ec``, ``time_idx`` (from
        ``{prefix}_<key>.npy`` with prefix ``fooof``/``psd``), ``cv_summary``,
        ``per_subject_metrics`` (CSV-backed DataFrames), ``test_subjects``,
        ``cv_test_subjects``, ``val_accuracies``, ``val_subject_ids`` (npy),
        plus ``is_fooof``, ``output_dir`` and ``model_name``. Any artifact that
        does not exist on disk is stored as ``None``.
    """
    output_dir = Path(output_dir)
    if is_fooof is None:
        is_fooof = "__fooof__" in output_dir.name

    prefix = "fooof" if is_fooof else "psd"

    def _load_npy(filename: str):
        # Missing artifacts are expected (not every config writes every file).
        filepath = output_dir / filename
        return np.load(filepath) if filepath.exists() else None

    def _load_csv(filename: str):
        csv_path = output_dir / filename
        return pd.read_csv(csv_path) if csv_path.exists() else None

    data: Dict = {}

    # Per-epoch predictions; file names are "<prefix>_<key>.npy"
    for key in ("epoch_idx", "y_true", "prob_ec", "time_idx"):
        data[key] = _load_npy(f"{prefix}_{key}.npy")

    # Summary tables
    data["cv_summary"] = _load_csv("logreg_cv_summary.csv")
    per_subject = _load_csv("logreg_per_subject_metrics.csv")
    data["per_subject_metrics"] = (
        normalize_per_subject_metrics(per_subject) if per_subject is not None else None
    )

    # Subject splits / validation results; file names are "<key>.npy"
    for key in ("test_subjects", "cv_test_subjects", "val_accuracies", "val_subject_ids"):
        data[key] = _load_npy(f"{key}.npy")

    data["is_fooof"] = is_fooof
    data["output_dir"] = output_dir
    data["model_name"] = output_dir.name

    return data

def pretty_model_name(config_name: str) -> str:
    """Return a compact "dataset | features | penalty" label for plots.

    The dataset part comes from cross-dataset markers or the name prefix (falling
    back to the text before the first "__"). The feature part is PSD or FOOOF,
    with ONE_MAIN FOOOF called out. The penalty part is appended only when a
    penalty marker is present.
    """
    name = str(config_name)

    # Cross-dataset markers take precedence over the plain dataset prefixes.
    if "train_old_test_new" in name:
        dataset_label = "Train old → Test new"
    elif "train_new_test_old" in name:
        dataset_label = "Train new → Test old"
    else:
        prefix_labels = (
            ("old_dataset", "Old dataset"),
            ("new_dataset", "New dataset"),
            ("combined_datasets", "Combined datasets"),
        )
        dataset_label = next(
            (label for prefix, label in prefix_labels if name.startswith(prefix)),
            name.split("__")[0],
        )

    if "__fooof__" not in name:
        feature_label = "PSD"
    elif "__one_main_fooof__" in name:
        feature_label = "FOOOF (ONE_MAIN)"
    else:
        feature_label = "FOOOF"

    penalty_markers = (
        ("__tune_penalty", "tuned"),
        ("__pen_l1", "L1"),
        ("__pen_l2", "L2"),
    )
    penalty_label = next((label for token, label in penalty_markers if token in name), None)

    pieces = [dataset_label, feature_label]
    if penalty_label:
        pieces.append(f"pen {penalty_label}")
    return " | ".join(pieces)



def model_short_key(config_name: str) -> str:
    """Return a very short key ("old_psd", "new_fooof", ...) for tight plots.

    Names not starting with ``old_dataset``/``new_dataset`` fall back to the text
    before the first "__".
    """
    name = str(config_name)
    key_table = (
        ("old_dataset", "old_fooof", "old_psd"),
        ("new_dataset", "new_fooof", "new_psd"),
    )
    for prefix, fooof_key, psd_key in key_table:
        if name.startswith(prefix):
            return fooof_key if "__fooof__" in name else psd_key
    return name.split("__")[0]

def normalize_per_subject_metrics(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy whose subject column is named ``subject_id`` (nullable Int64).

    A ``subject`` column is renamed to ``subject_id`` unless ``subject_id``
    already exists. Non-numeric ids become ``<NA>``. ``None`` or a zero-row
    frame is returned as-is (not copied).
    """
    if df is None or not len(df):
        return df
    out = df.copy()
    existing = set(out.columns)
    if "subject" in existing and "subject_id" not in existing:
        out = out.rename(columns={"subject": "subject_id"})
    if "subject_id" in out.columns:
        numeric_ids = pd.to_numeric(out["subject_id"], errors="coerce")
        out["subject_id"] = numeric_ids.astype("Int64")
    return out

def load_metadata_old(path: str | Path | None = None) -> pd.DataFrame:
    """Load old dataset metadata from CSV (clinical path by default).

    Parameters
    ----------
    path : str | Path | None
        CSV to read; defaults to ``METADATA_OLD_CSV``.

    Returns
    -------
    pd.DataFrame
        Metadata with columns normalized where present. ``Subject_ID``, ``Age``
        and ``Sex`` become ``subject_id``, ``age`` and ``sex``. ``subject_id``
        is nullable Int64, ``age`` is numeric, and ``sex`` is whitespace-stripped
        with missing values kept as missing.

    Raises
    ------
    FileNotFoundError
        If the CSV does not exist.
    """
    path = Path(str(METADATA_OLD_CSV if path is None else path))
    if not path.exists():
        raise FileNotFoundError(f"Old metadata not found: {path}")
    df = pd.read_csv(path)

    # Many clinical CSV exports are semicolon-separated; if we only got 1 column,
    # try again with ';' (keeps explicit path requirement, just more robust parsing).
    if df.shape[1] == 1:
        try:
            df_semi = pd.read_csv(path, sep=';', engine='python')
        except Exception as exc:
            warnings.warn(f"{path.name}: semicolon re-parse failed ({exc}); keeping comma parse")
        else:
            if df_semi.shape[1] > 1:
                df = df_semi

    # Normalize column names (same conventions as load_metadata_new)
    renames = {
        src: dst
        for src, dst in (("Subject_ID", "subject_id"), ("Age", "age"), ("Sex", "sex"))
        if dst not in df.columns and src in df.columns
    }
    if renames:
        df = df.rename(columns=renames)

    if "sex" in df.columns:
        # Strip whitespace but keep missing values missing (astype(str) would turn
        # NaN into the literal string "nan").
        df["sex"] = df["sex"].map(lambda v: v if pd.isna(v) else str(v).strip())
    if "subject_id" in df.columns:
        df["subject_id"] = pd.to_numeric(df["subject_id"], errors="coerce").astype("Int64")
    if "age" in df.columns:
        df["age"] = pd.to_numeric(df["age"], errors="coerce")
    return df


def load_metadata_new(path: str | Path | None = None) -> pd.DataFrame:
    """Load new dataset metadata from CSV (clinical path by default).

    Expected: subject ids may appear as strings like "Sub001" or "Ros_Sub001".
    This normalizes them to numeric Int64 (e.g. 1) to match .fif parsing.

    Parameters
    ----------
    path : str | Path | None
        CSV to read; defaults to ``METADATA_NEW_CSV``.

    Returns
    -------
    pd.DataFrame
        Metadata with ``subject_id`` (Int64; the first column is used when no
        column is named ``subject_id``), ``age`` (from ``Age``/``Y``, numeric),
        and ``sex`` (from ``Sex``/``M/F``/``F/M``; stripped, ``M``/``F`` mapped
        to ``Male``/``Female``, missing kept missing) where present.

    Raises
    ------
    FileNotFoundError
        If the CSV does not exist.
    """
    path = Path(str(METADATA_NEW_CSV if path is None else path))
    if not path.exists():
        raise FileNotFoundError(f"New metadata not found: {path}")

    # Try comma first, then semicolon; keep the parse that yields more columns.
    df = pd.read_csv(path)
    try:
        df_semi = pd.read_csv(path, sep=';', engine='python')
    except Exception as exc:
        warnings.warn(f"{path.name}: semicolon parse failed ({exc}); using comma parse")
    else:
        if df_semi.shape[1] > df.shape[1]:
            df = df_semi

    # Common export layout: the first column holds the subject id.
    if 'subject_id' not in df.columns:
        df = df.rename(columns={df.columns[0]: 'subject_id'})

    # "Sub001" / "Ros_Sub001" -> 1 (first run of digits); no digits -> <NA>
    sid = df['subject_id'].astype(str).str.strip().str.extract(r'(\d+)', expand=False)
    df['subject_id'] = pd.to_numeric(sid, errors='coerce').astype('Int64')

    # Rename the first matching alias column (only if the target is absent).
    for target, aliases in (('age', ('Age', 'Y')), ('sex', ('Sex', 'M/F', 'F/M'))):
        if target in df.columns:
            continue
        source = next((c for c in aliases if c in df.columns), None)
        if source is not None:
            df = df.rename(columns={source: target})

    if 'sex' in df.columns:
        # Strip whitespace but keep missing values missing (astype(str) would turn
        # NaN into the literal string "nan").
        df['sex'] = df['sex'].map(lambda v: v if pd.isna(v) else str(v).strip())
        df['sex'] = df['sex'].replace({'M': 'Male', 'F': 'Female'})

    if 'age' in df.columns:
        df['age'] = pd.to_numeric(df['age'], errors='coerce')

    return df

print("Helper functions defined.")


In [None]:
# -----------------
# Load Metadata
# -----------------

metadata_old = load_metadata_old()
metadata_new = load_metadata_new()

# Report row counts first, then the column lists of non-empty tables.
_metadata_frames = (("Old", metadata_old), ("New", metadata_new))
for _label, _frame in _metadata_frames:
    print(f"{_label} metadata: {len(_frame)} rows")
for _label, _frame in _metadata_frames:
    if len(_frame) > 0:
        print(f"{_label} metadata columns: {_frame.columns.tolist()}")

In [None]:
# -----------------
# Load the two *clinical* "small" datasets (same discovery logic as EC_EO_Classifier.ipynb)
# -----------------

def _collect_set_files(directory: Path) -> list[Path]:
    if directory is None or not Path(directory).exists():
        return []
    directory = Path(directory)
    files = list(directory.rglob('*.set')) + list(directory.rglob('*.SET'))
    return sorted({f.resolve() for f in files})

def _parse_subject_id_from_path(p: Path) -> int:
    m = re.search(r"(\d{3,})", p.stem)
    if m:
        return int(m.group(1))
    raise ValueError(f"Could not parse subject id from: {p.name}")

def load_old_open_closed_pairs(open_dir: str, closed_dir: str, limit_subjects: int = 30):
    """Pair eyes-open and eyes-closed ``.set`` files by parsed subject id.

    Only subjects present in both directories are kept. They are sorted by id
    and truncated to the first ``limit_subjects`` when that is truthy. Files
    whose id cannot be parsed are ignored. When two files share an id, the one
    seen later wins.

    Returns a list of ``(subject_id, open_path, closed_path)`` tuples.
    """
    def _index_by_subject(files):
        by_sid = {}
        for f in files:
            try:
                sid = _parse_subject_id_from_path(f)
            except Exception:
                continue
            by_sid[sid] = f
        return by_sid

    open_by_sid = _index_by_subject(_collect_set_files(Path(open_dir)))
    closed_by_sid = _index_by_subject(_collect_set_files(Path(closed_dir)))

    shared = sorted(open_by_sid.keys() & closed_by_sid.keys())
    if limit_subjects and len(shared) > int(limit_subjects):
        shared = shared[: int(limit_subjects)]

    return [(sid, open_by_sid[sid], closed_by_sid[sid]) for sid in shared]

def load_new_subject_pairs(processed_dir: str):
    """Group ``*_epo.fif`` files in *processed_dir* by subject and rater (a/b).

    File stems must contain ``sub<digits><a|b>`` (case-insensitive). Subjects
    are returned sorted by id as ``(sid, rater_a_path, rater_b_path)``. When
    only one rater file exists the tuple is ``(sid, that_path, None)``.
    Summary lines are printed.
    """
    root = Path(processed_dir)
    candidates = [*root.glob('*_epo.fif'), *root.glob('*_epo.FIF')]

    by_subject: dict[int, dict[str, Path]] = {}
    for fif in candidates:
        hit = re.search(r"sub(\d+)([ab])", fif.stem, flags=re.IGNORECASE)
        if hit is None:
            continue
        by_subject.setdefault(int(hit.group(1)), {})[hit.group(2).lower()] = fif.resolve()

    entries: list[tuple[int, Path, Path | None]] = []
    n_paired = n_single = 0
    for sid in sorted(by_subject):
        raters = by_subject[sid]
        has_a, has_b = 'a' in raters, 'b' in raters
        if has_a and has_b:
            entries.append((sid, raters['a'], raters['b']))
            n_paired += 1
        elif has_a or has_b:
            entries.append((sid, raters['a' if has_a else 'b'], None))
            n_single += 1

    print(f"NEW subjects found: total={len(entries)} (paired={n_paired}, single={n_single})")
    if entries:
        print("Example NEW subject entry:", entries[0])
    return entries

# Resolve clinical datasets
old_pairs = load_old_open_closed_pairs(OLD_OPEN_DIR, OLD_CLOSED_DIR, limit_subjects=30)
new_pairs = load_new_subject_pairs(NEW_PROCESSED_DIR)

old_subject_ids_small = [sid for (sid, _, _) in old_pairs]
new_subject_ids_small = [sid for (sid, _, _) in new_pairs]

print(f"OLD small set: subjects={len(old_subject_ids_small)} | files={len(old_pairs)*2}")
print(f"NEW small set: subjects={len(new_subject_ids_small)} | fif entries={len(new_pairs)}")

# Build old set_files ordering exactly like EC_EO_Classifier: open files first, then closed files
old_open_files_small = [p_open for (_, p_open, _) in old_pairs]
old_closed_files_small = [p_closed for (_, _, p_closed) in old_pairs]
old_set_files_ordered = old_open_files_small + old_closed_files_small


def _filter_metadata_to_subjects(meta: pd.DataFrame, subject_ids) -> pd.DataFrame:
    """Keep only metadata rows whose ``subject_id`` is in *subject_ids*.

    Empty metadata is returned unchanged. Metadata without a ``subject_id``
    column cannot be matched, so an empty frame with the same columns is
    returned and a warning is emitted. The previous ``meta.get(...).isin``
    call crashed with AttributeError in that case.
    """
    if meta is None or len(meta) == 0:
        return meta
    if 'subject_id' not in meta.columns:
        warnings.warn("metadata has no 'subject_id' column; cannot filter to the small set")
        return meta.iloc[0:0].copy()
    return meta[meta['subject_id'].isin(subject_ids)].copy()


# Filter metadata to the small sets
metadata_old_small = _filter_metadata_to_subjects(metadata_old, old_subject_ids_small)
metadata_new_small = _filter_metadata_to_subjects(metadata_new, new_subject_ids_small)

print(f"Filtered old metadata: {len(metadata_old_small)} / {len(metadata_old)}")
print(f"Filtered new metadata: {len(metadata_new_small)} / {len(metadata_new)}")


In [None]:
# -----------------
# Build (and cache) an epoch-index → (subject, file, epoch) map for the clinical small sets
# This is required to make per-subject epoch plots, worst-subject plots, and disagreement-epoch PSD plots real.
# -----------------

# Reuse pickled index maps from VIZ_CACHE_DIR when they exist.
USE_INDEX_MAP_CACHE = True

# Channels used for the "occipital" picks. NOTE: despite the name this list also
# contains parietal sites (P3/P4/P7/P8/Pz).
STANDARD_OCCIPITAL = ["O1", "O2", "P3", "P4", "P7", "P8", "Pz"]
# Upper-cased copy for case-insensitive membership tests.
STANDARD_OCCIPITAL_SET = set(map(str.upper, STANDARD_OCCIPITAL))

def canonical_channel_name(ch_name: str) -> str:
    """Normalize a channel label.

    The label is trimmed, a leading "EEG " and a trailing "-REF" are removed
    (both case-insensitive), and all remaining whitespace is deleted.
    """
    cleaned = str(ch_name).strip()
    for pattern in (r"^EEG\s+", r"-REF$"):
        cleaned = re.sub(pattern, "", cleaned, flags=re.IGNORECASE)
    return re.sub(r"\s+", "", cleaned)

def rename_epochs_channels_canonical(epochs):
    """Rename channels in place to their canonical names and return *epochs*.

    Nothing is renamed when canonicalization would produce duplicate names.
    """
    canonical = [canonical_channel_name(ch) for ch in epochs.ch_names]
    if len(canonical) != len(set(canonical)):
        # Renaming would collide; leave the object untouched.
        return epochs
    mapping = {
        old: new
        for old, new in zip(epochs.ch_names, canonical)
        if new != old
    }
    if mapping:
        epochs.rename_channels(mapping)
    return epochs

def labels_from_epochs_events(epochs) -> np.ndarray:
    """Map each epoch's event to a class label: 0 = EO, 1 = EC, -1 = other.

    The event name is looked up from ``epochs.event_id`` via the event code in
    ``epochs.events[:, 2]``. The label is decided by the upper-cased name's
    "EO"/"EC" prefix.
    """
    names_by_code = {int(code): str(name).upper() for name, code in epochs.event_id.items()}
    labels = np.full(len(epochs), -1, dtype=int)
    prefix_to_label = (("EO", 0), ("EC", 1))
    for row, code in enumerate(epochs.events[:, 2].astype(int)):
        event_name = names_by_code.get(int(code), "")
        for prefix, value in prefix_to_label:
            if event_name.startswith(prefix):
                labels[row] = value
                break
    return labels

def load_rejmanual_vector(set_path: Path, n_epochs_expected: int) -> np.ndarray | None:
    """Return reject vector (1=reject, 0=keep) or None.

    The ``.set`` file is opened with ``scipy.io.loadmat``. ``rejmanual`` is
    looked up first under a top-level ``reject`` entry, then under
    ``EEG.reject``.

    Parameters
    ----------
    set_path : Path
        Path to the EEGLAB ``.set`` file.
    n_epochs_expected : int
        Number of epochs MNE read from the same file; the vector must match it.

    Returns
    -------
    np.ndarray | None
        Integer vector with one entry per epoch (1 = reject, 0 = keep). Returns
        ``None`` when ``loadmat`` fails, when no ``rejmanual`` field is found,
        or when its length differs from ``n_epochs_expected`` (a warning is
        emitted in that last case).

    NOTE(review): the epochs kept here determine the ``global_idx`` numbering
    of the index map. This presumably has to match EC_EO_Classifier.ipynb
    exactly, so keep the two in sync.
    NOTE(review): EEGLAB commonly stores an empty ``rejmanual`` when no epochs
    were marked. That case fails the length check and returns ``None``, and the
    caller then skips the whole file. Confirm this is intended.
    """
    try:
        # struct_as_record=False -> MATLAB structs become attribute-access objects;
        # squeeze_me=True drops singleton dimensions.
        mat = loadmat(str(set_path), struct_as_record=False, squeeze_me=True)
    except Exception:
        # Best-effort: any load failure is reported as "no vector".
        return None
    rej = None
    # Layout 1: fields stored at the top level of the .set file.
    block = mat.get("reject", None)
    if block is not None and hasattr(block, "rejmanual"):
        rej = np.array(block.rejmanual)
    # Layout 2: fields nested under an EEG struct.
    if rej is None and "EEG" in mat:
        EEG = mat["EEG"]
        reject_section = getattr(EEG, "reject", None)
        if reject_section is not None:
            if hasattr(reject_section, "rejmanual"):
                rej = np.array(reject_section.rejmanual)
            elif isinstance(reject_section, dict) and "rejmanual" in reject_section:
                # Defensive: dict-shaped reject section.
                rej = np.array(reject_section["rejmanual"])
    if rej is None:
        return None
    rej = np.asarray(rej).ravel().astype(int)
    # Must align 1:1 with the epochs MNE loaded, otherwise the mask is unusable.
    if rej.size != int(n_epochs_expected):
        warnings.warn(f"{set_path.name}: rejmanual length {rej.size} != n_epochs {n_epochs_expected}")
        return None
    # Any non-zero mark counts as "reject".
    return (rej != 0).astype(int)

def build_index_map_old(set_files_ordered: list[Path]) -> pd.DataFrame:
    """Build the global_idx -> (subject, file, epoch) map for the OLD dataset.

    Files are processed in the given order; ``global_idx`` is a running counter
    over the kept epochs. Therefore both the file order and the keep rules
    (rejmanual == 0 and all samples finite) presumably must match the classifier
    notebook exactly, otherwise predictions map to the wrong epochs.

    Parameters
    ----------
    set_files_ordered : list[Path]
        EEGLAB ``.set`` files (built above as all open files, then all closed).

    Returns
    -------
    pd.DataFrame
        One row per kept epoch with columns ``global_idx``, ``dataset``,
        ``subject_id``, ``file``, ``epoch_orig_idx``, ``y_true``, ``condition``.
        With no rows at all, the frame has no columns.
    """
    rows = []
    cursor = 0  # next global_idx to assign
    for i, path in enumerate(set_files_ordered, start=1):
        # Label comes from the path: EO (0) for open-marked / eyesopen paths, else EC (1).
        label = 0 if ("open_marked" in str(path).lower() or "eyesopen" in str(path).lower()) else 1
        try:
            epochs = mne.io.read_epochs_eeglab(str(path), verbose='ERROR')
        except Exception as exc:
            warnings.warn(f"Failed to read {path}: {exc}")
            continue
        epochs = rename_epochs_channels_canonical(epochs)
        n_epochs = len(epochs)
        subj = _parse_subject_id_from_path(path)

        # Files without a usable rejmanual vector are dropped entirely.
        rej = load_rejmanual_vector(path, n_epochs_expected=n_epochs)
        if rej is None:
            warnings.warn(f"{path.name}: missing rejmanual; skipping")
            continue
        keep_mask = (rej == 0)
        if not keep_mask.any():
            continue

        # Additionally drop epochs containing any non-finite sample.
        data = epochs.get_data()
        finite_mask = np.all(np.isfinite(data), axis=(1,2))
        keep = keep_mask & finite_mask
        kept_idx = np.flatnonzero(keep)
        for epoch_orig_idx in kept_idx:
            rows.append({
                'global_idx': int(cursor),
                'dataset': 'old',
                'subject_id': int(subj),
                'file': str(path),
                'epoch_orig_idx': int(epoch_orig_idx),
                'y_true': int(label),
                'condition': 'EC' if int(label)==1 else 'EO',
            })
            cursor += 1

        # Progress (skipped for iterations that hit a `continue` above).
        if (i % 10) == 0:
            print(f"  [old] {i}/{len(set_files_ordered)} files | samples so far: {cursor}")

    return pd.DataFrame(rows)

def build_index_map_new(subject_pairs: list[tuple[int, Path, Path | None]]) -> pd.DataFrame:
    """Build the global_idx -> (subject, file, epoch) map for the NEW dataset.

    Epoch data always comes from the first file of each pair. Labels come from
    its event names, merged with the second rater's file when available:
    - the second rater fills in epochs the first left unlabeled;
    - epochs labeled differently by both raters are dropped (-1).
    ``global_idx`` is a running counter over kept epochs (labeled and all
    samples finite). It presumably must match the classifier notebook's
    ordering exactly.

    Parameters
    ----------
    subject_pairs : list[tuple[int, Path, Path | None]]
        ``(subject_id, file_a, file_b_or_None)`` as returned by
        ``load_new_subject_pairs``.

    Returns
    -------
    pd.DataFrame
        One row per kept epoch with columns ``global_idx``, ``dataset``,
        ``subject_id``, ``file``, ``file_b``, ``epoch_orig_idx``, ``y_true``,
        ``condition``. With no rows at all, the frame has no columns.
    """
    rows = []
    cursor = 0  # next global_idx to assign
    for i, (sid, file_a, file_b) in enumerate(subject_pairs, start=1):
        path = Path(file_a)
        try:
            epochs = mne.read_epochs(str(path), preload=False, verbose='ERROR')
        except Exception as exc:
            warnings.warn(f"Failed to read NEW epochs {path}: {exc}")
            continue
        epochs = rename_epochs_channels_canonical(epochs)
        labels_a = labels_from_epochs_events(epochs)
        labels_all = labels_a

        path_b = Path(file_b) if file_b else None
        if path_b is not None:
            if not path_b.exists():
                path_b = None
            else:
                try:
                    epochs_b = mne.read_epochs(str(path_b), preload=False, verbose='ERROR')
                    labels_b = labels_from_epochs_events(epochs_b)
                    # NOTE(review): on a shape mismatch rater b is silently ignored,
                    # but path_b is still recorded in 'file_b' below.
                    if labels_b.shape == labels_a.shape:
                        union = labels_a.copy()
                        # Fill epochs unlabeled by rater a with rater b's label.
                        take_from_b = (union < 0)
                        union[take_from_b] = labels_b[take_from_b]
                        # Disagreements between two valid labels are discarded.
                        conflict = (labels_a >= 0) & (labels_b >= 0) & (labels_a != labels_b)
                        union[conflict] = -1
                        labels_all = union
                except Exception:
                    path_b = None

        keep_mask_labels = (labels_all >= 0)
        if not keep_mask_labels.any():
            continue

        # get_data() loads the samples even though the file was opened with preload=False.
        data = epochs.get_data()
        finite_mask = np.all(np.isfinite(data), axis=(1,2))
        keep = keep_mask_labels & finite_mask
        kept_idx = np.flatnonzero(keep)

        for epoch_orig_idx in kept_idx:
            y = int(labels_all[epoch_orig_idx])
            rows.append({
                'global_idx': int(cursor),
                'dataset': 'new',
                'subject_id': int(sid),
                'file': str(path),
                'file_b': str(path_b) if path_b is not None else None,
                'epoch_orig_idx': int(epoch_orig_idx),
                'y_true': int(y),
                'condition': 'EC' if int(y)==1 else 'EO',
            })
            cursor += 1

        # Progress (skipped for iterations that hit a `continue` above).
        if (i % 25) == 0:
            print(f"  [new] {i}/{len(subject_pairs)} subjects | samples so far: {cursor}")

    return pd.DataFrame(rows)

# Cache + build
idx_old_cache = VIZ_CACHE_DIR / 'index_map_old.pkl'
idx_new_cache = VIZ_CACHE_DIR / 'index_map_new.pkl'


def _load_or_build_index_map(cache_path: Path, name: str, build) -> pd.DataFrame:
    """Return a cached index map, or build it via *build()* and cache it.

    Empty maps are never written, and an empty cached map is ignored and
    rebuilt. An empty map typically means the raw files were not reachable;
    caching it would poison every later run.
    """
    if USE_INDEX_MAP_CACHE and cache_path.exists():
        cached = pd.read_pickle(cache_path)  # local, notebook-written cache
        if len(cached):
            print(f'Loaded {name} cache:', cache_path)
            return cached
        print(f'Ignoring empty {name} cache:', cache_path)
    print(f'Building {name}...')
    built = build()
    if len(built):
        built.to_pickle(cache_path)
        print('Wrote:', cache_path, '| rows:', len(built))
    else:
        print(f'WARNING: {name} is empty (check data paths); not writing cache.')
    return built


index_map_old = _load_or_build_index_map(
    idx_old_cache, 'index_map_old', lambda: build_index_map_old(old_set_files_ordered)
)
index_map_new = _load_or_build_index_map(
    idx_new_cache, 'index_map_new', lambda: build_index_map_new(new_pairs)
)

print('index_map_old rows:', len(index_map_old))
print('index_map_new rows:', len(index_map_new))

def attach_subject_mapping(model_data: Dict, index_map: pd.DataFrame) -> pd.DataFrame:
    """Attach subject/file/epoch info to model outputs using ``global_idx``.

    Parameters
    ----------
    model_data : dict
        Output of ``load_model_outputs``. ``epoch_idx``, ``y_true`` and
        ``prob_ec`` are required; ``time_idx`` is optional.
    index_map : pd.DataFrame
        Map with a ``global_idx`` column plus optional ``subject_id``, ``file``,
        ``epoch_orig_idx``, ``y_true``, ``condition``.

    Returns
    -------
    pd.DataFrame
        One row per prediction with columns ``global_idx``, ``y_true_model``,
        ``prob_ec``, (``time_idx``), ``y_pred`` (``prob_ec >= 0.5``), plus the
        left-merged index-map columns. The frame is empty when the required
        arrays are missing. When the index map has no ``global_idx`` column
        (e.g. it is empty), the unmapped predictions are returned with a
        warning instead of raising KeyError.
    """
    epoch_idx = model_data.get('epoch_idx')
    y_true = model_data.get('y_true')
    prob_ec = model_data.get('prob_ec')
    time_idx = model_data.get('time_idx')

    if epoch_idx is None or y_true is None or prob_ec is None:
        return pd.DataFrame()

    df = pd.DataFrame({
        'global_idx': np.asarray(epoch_idx, dtype=int),
        'y_true_model': np.asarray(y_true, dtype=int),
        'prob_ec': np.asarray(prob_ec, dtype=float),
    })
    if time_idx is not None:
        df['time_idx'] = np.asarray(time_idx, dtype=int)

    # Hard decision at 0.5 on P(EC).
    df['y_pred'] = (df['prob_ec'] >= 0.5).astype(int)

    if index_map is None or 'global_idx' not in index_map.columns:
        print(f"WARNING: {len(df)} predictions could not be mapped to subjects "
              f"(index map is empty or has no 'global_idx' column).")
        return df

    meta_cols = [c for c in ['global_idx','subject_id','file','epoch_orig_idx','y_true','condition'] if c in index_map.columns]
    merged = df.merge(index_map[meta_cols], on='global_idx', how='left')

    missing = merged['subject_id'].isna().sum() if 'subject_id' in merged.columns else len(merged)
    if missing:
        print(f"WARNING: {missing} predictions could not be mapped to subjects (check dataset ordering vs model output).")

    # Sanity check: the model's saved label must agree with the index map's label
    # for the same global_idx; disagreement indicates an ordering mismatch.
    if 'y_true' in merged.columns:
        mapped = merged['y_true'].notna()
        n_mismatch = int(
            (merged.loc[mapped, 'y_true'].astype(int) != merged.loc[mapped, 'y_true_model']).sum()
        )
        if n_mismatch:
            print(f"WARNING: {n_mismatch} mapped predictions have y_true != index-map label "
                  f"(possible ordering mismatch).")

    return merged


In [None]:
# -----------------
# Load Model Outputs from 4 Configurations
# -----------------

# Config directory names under OUTPUTS_ROOT (dataset / features / options / penalty).
MODEL_CONFIGS = [
    "old_dataset__fooof__allch__cv2__time_align_conditions__one_main_fooof__mainfooof_all_epochs__pen_l2",
    "old_dataset__no_fooof__allch__cv2__time_align_conditions__pen_l2",
    "new_dataset__fooof__allch__cv2__time_align_conditions__one_main_fooof__mainfooof_all_epochs__pen_l2",
    "new_dataset__no_fooof__allch__cv2__time_align_conditions__pen_l2",
]

# config name -> dict from load_model_outputs (only configs found on disk)
model_outputs = {}

for config_name in MODEL_CONFIGS:
    config_dir = OUTPUTS_ROOT / config_name
    if not config_dir.exists():
        print(f"  ✗ Not found: {config_dir}")
        continue
    print(f"Loading {config_name}...")
    model_outputs[config_name] = load_model_outputs(config_dir)
    print("  ✓ Loaded")

print(f"\nLoaded {len(model_outputs)} model configurations")

In [None]:
# -----------------
# Load FOOOF Center Frequencies (subject-level) from saved_fooof
# Prefer alpha_profiles.csv / feature_manifest.csv produced by Precompute_ONE_MAIN_FOOOF.
# -----------------

saved_root = Path(str(SAVED_FOOOF_DIR))


def _read_profile_csvs(root: Path, filename: str) -> list[pd.DataFrame]:
    """Read every *filename* under *root* (sorted for determinism).

    Each frame is tagged with ``source`` (its CSV path) and ``config_dir`` (the
    parent directory name). Unreadable files are skipped with a warning rather
    than silently.
    """
    frames = []
    for csv_path in sorted(root.rglob(filename)):
        try:
            frame = pd.read_csv(csv_path)
        except Exception as exc:
            warnings.warn(f"Skipping unreadable {csv_path}: {exc}")
            continue
        frame['source'] = str(csv_path)
        frame['config_dir'] = str(csv_path.parent.name)
        frames.append(frame)
    return frames


fooof_profiles = []
if saved_root.exists():
    # Fast path: alpha_profiles.csv files
    fooof_profiles = _read_profile_csvs(saved_root, 'alpha_profiles.csv')
    # Fallback: feature_manifest.csv (has per-file alpha_cf/alpha_bw)
    if not fooof_profiles:
        fooof_profiles = _read_profile_csvs(saved_root, 'feature_manifest.csv')

if not fooof_profiles:
    print(f"No saved_fooof profile CSVs found under: {saved_root}")
    profiles_df = pd.DataFrame()
else:
    profiles_df = pd.concat(fooof_profiles, ignore_index=True)

# Normalize columns
if len(profiles_df):
    if 'subject' in profiles_df.columns and 'subject_id' not in profiles_df.columns:
        profiles_df = profiles_df.rename(columns={'subject': 'subject_id'})
    if 'subject_id' in profiles_df.columns:
        profiles_df['subject_id'] = pd.to_numeric(profiles_df['subject_id'], errors='coerce').astype('Int64')

    # alpha_cf may be stored as alpha_cf or cf depending on file
    if 'alpha_cf' not in profiles_df.columns:
        for cand in ['cf', 'center_frequency', 'alpha_center_frequency']:
            if cand in profiles_df.columns:
                profiles_df = profiles_df.rename(columns={cand: 'alpha_cf'})
                break

    if 'alpha_cf' in profiles_df.columns:
        profiles_df['alpha_cf'] = pd.to_numeric(profiles_df['alpha_cf'], errors='coerce')

# Split to old/new using subject_id membership (robust even if filenames don't encode dataset)
fooof_cf_by_subject_old = pd.Series(dtype=float)
fooof_cf_by_subject_new = pd.Series(dtype=float)

if len(profiles_df) and 'alpha_cf' in profiles_df.columns and 'subject_id' in profiles_df.columns:
    old_set = {int(x) for x in old_subject_ids_small}
    new_set = {int(x) for x in new_subject_ids_small}
    small_set = old_set.union(new_set)

    # Membership-based splitting is ambiguous when both datasets use the same id.
    shared_ids = sorted(old_set & new_set)
    if shared_ids:
        warnings.warn(f"{len(shared_ids)} subject id(s) occur in both old and new sets; "
                      f"their CF is assigned to both (first ids: {shared_ids[:10]})")
    # Rows from every matched saved_fooof directory are averaged together.
    if 'config_dir' in profiles_df.columns and profiles_df['config_dir'].nunique() > 1:
        print(f"NOTE: averaging alpha_cf across {profiles_df['config_dir'].nunique()} saved_fooof directories")

    df_small = profiles_df[profiles_df['subject_id'].isin(list(small_set))].copy()
    cf_by_subject = df_small.groupby('subject_id')['alpha_cf'].mean()
    fooof_cf_by_subject_old = cf_by_subject[cf_by_subject.index.isin(list(old_set))].dropna()
    fooof_cf_by_subject_new = cf_by_subject[cf_by_subject.index.isin(list(new_set))].dropna()

print(f"Saved-FOOOF CF subjects (old): {len(fooof_cf_by_subject_old)}")
print(f"Saved-FOOOF CF subjects (new): {len(fooof_cf_by_subject_new)}")


## 3. Data Visualization

### 3.1 Age Distribution

In [None]:
# -----------------
# Age Distribution Visualization (small clinical sets only)
# -----------------

fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# One entry per panel:
# (axis, metadata, bins, colour, title stem, empty-panel text, y-label or None)
panel_specs = [
    (axes[0], metadata_old_small, 20, 'steelblue',
     "Old dataset (small set) age distribution", 'No old metadata for small set', 'Number of subjects'),
    (axes[1], metadata_new_small, 30, 'coral',
     "New dataset (small set) age distribution", 'No new metadata for small set', None),
]

for ax, meta, n_bins, colour, title_stem, empty_msg, ylabel in panel_specs:
    if len(meta) > 0 and 'age' in meta.columns:
        ages = meta['age'].dropna().astype(float)
        ax.hist(ages, bins=n_bins, color=colour, edgecolor='black', alpha=0.75)
        ax.set_title(f"{title_stem} (N={len(ages)})")
    else:
        ax.text(0.5, 0.5, empty_msg, ha='center', va='center', transform=ax.transAxes)
    ax.set_xlabel('Age')
    if ylabel is not None:
        ax.set_ylabel(ylabel)  # right panel shares the y-axis, so only the left is labelled
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### 3.2 Alpha Power Analysis

Note: Alpha power is computed from the raw epoch data, so the clinical data files must be accessible. When they are, this section computes Welch PSDs directly from the epochs; otherwise it relies on previously computed values cached in the visualization cache directory.

In [None]:
# -----------------
# Alpha Power Computation and Visualization (clinical raw files)
# -----------------

# Alpha band limits in Hz
ALPHA_BAND = (8.0, 13.0)
# Per-class epoch counts (presumably target / minimum epochs per condition — confirm where used)
N_EPOCHS_PER_CLASS, MIN_EPOCHS_PER_CLASS = 60, 6
# Target Welch segment length in seconds
PSD_TARGET_SECS = 2.0

# Consistent EC/EO colors across all panels
COND_ORDER = ['EO', 'EC']
COND_PALETTE = dict(EO='tab:blue', EC='tab:red')

alpha_cache = VIZ_CACHE_DIR.joinpath('alpha_power_small_sets.csv')  # cached alpha-power table

def _psd_array_welch_clean(data: np.ndarray, sfreq: float, fmin: float, fmax: float, target_secs: float = 2.0):
    """Welch PSD with segments of about ``target_secs`` seconds.

    Parameters
    ----------
    data : np.ndarray
        Signal with time on the last axis (here (n_epochs, n_channels, n_times)).
    sfreq : float
        Sampling frequency in Hz.
    fmin, fmax : float
        Frequency range (Hz) of the returned bins.
    target_secs : float
        Desired segment length; clipped to [8 samples, n_times].

    Returns
    -------
    (psds, freqs)
        As returned by ``mne.time_frequency.psd_array_welch`` (psds keeps the
        leading axes of ``data`` and adds a frequency axis).

    ``n_fft`` is set equal to ``n_per_seg``. MNE's default ``n_fft=256``
    silently caps ``n_per_seg`` at 256 samples. It then raises as soon as
    ``n_overlap`` (half the requested segment) reaches that cap, which happens
    for any sfreq above 128 Hz with 2 s segments.
    """
    n_times = data.shape[-1]
    n_per_seg = max(8, min(n_times, int(round(target_secs * sfreq))))
    # 50% overlap for all but tiny segments.
    n_overlap = n_per_seg // 2 if n_per_seg >= 16 else 0
    psds, freqs = mne.time_frequency.psd_array_welch(
        data,
        sfreq=sfreq,
        fmin=float(fmin),
        fmax=float(fmax),
        n_fft=n_per_seg,
        n_per_seg=n_per_seg,
        n_overlap=n_overlap,
        window='hann',
        average='mean',
        verbose=False,
    )
    return psds, freqs

def _alpha_power_epochwise(epochs, epoch_indices: np.ndarray, picks: np.ndarray) -> float:
    """Mean alpha-band PSD over the selected epochs, channels and alpha bins.

    Returns NaN when no epochs or channels are selected. The result is
    multiplied by 1e12, presumably converting V²/Hz to µV²/Hz (MNE returns EEG
    in volts).
    """
    if epoch_indices.size == 0 or picks.size == 0:
        return float('nan')
    selected = epochs.get_data()[epoch_indices][:, picks, :]
    psds, _freqs = _psd_array_welch_clean(
        selected,
        sfreq=float(epochs.info['sfreq']),
        fmin=ALPHA_BAND[0],
        fmax=ALPHA_BAND[1],
        target_secs=PSD_TARGET_SECS,
    )
    # psds: (n_epochs, n_channels, n_freqs) -> average over frequency bins first
    per_epoch_channel = psds.mean(axis=-1)
    return float(np.nanmean(per_epoch_channel)) * 1e12


# Relative alpha power = (alpha-band power) / (total 1-40 Hz power);
# REL_BAND is the denominator band in Hz.
REL_BAND: tuple[float, float] = (1.0, 40.0)

def _band_power_epochwise(epochs, epoch_indices: np.ndarray, picks: np.ndarray, band: tuple[float,float]) -> float:
    """Summed PSD over the bins of *band*, averaged over selected epochs and channels.

    The bin sum is not multiplied by the bin width, so it is proportional to
    (not equal to) band power; the factor cancels in ratios computed with the
    same settings. Returns NaN when no epochs or channels are selected.
    """
    if epoch_indices.size == 0 or picks.size == 0:
        return float('nan')
    selected = epochs.get_data()[epoch_indices][:, picks, :]
    psds, _freqs = _psd_array_welch_clean(
        selected,
        sfreq=float(epochs.info['sfreq']),
        fmin=float(band[0]),
        fmax=float(band[1]),
        target_secs=PSD_TARGET_SECS,
    )
    summed_per_epoch_channel = psds.sum(axis=-1)
    return float(np.nanmean(summed_per_epoch_channel))

def _relative_alpha_epochwise(epochs, epoch_indices: np.ndarray, picks: np.ndarray) -> float:
    """Alpha-band power divided by REL_BAND (1-40 Hz) power for the selection.

    Returns NaN if either power is non-finite or the denominator is zero.
    """
    alpha_pow = _band_power_epochwise(epochs, epoch_indices, picks, ALPHA_BAND)
    total_pow = _band_power_epochwise(epochs, epoch_indices, picks, REL_BAND)
    ratio_defined = np.isfinite(alpha_pow) and np.isfinite(total_pow) and total_pow != 0
    return float(alpha_pow / total_pow) if ratio_defined else float('nan')

def _picks_all_eeg(epochs) -> np.ndarray:
    """Indices of EEG channels, excluding bads unless that would leave none."""
    eeg_only = dict(eeg=True, meg=False, stim=False, eog=False)
    picks = mne.pick_types(epochs.info, exclude='bads', **eeg_only)
    if len(picks) == 0:
        # Every EEG channel is marked bad: fall back to including them.
        picks = mne.pick_types(epochs.info, exclude=[], **eeg_only)
    return np.asarray(picks, dtype=int)

def _picks_occipital(epochs) -> np.ndarray:
    """Indices of channels whose canonical, upper-cased name is in STANDARD_OCCIPITAL_SET.

    Note that set also contains parietal sites (P3/P4/P7/P8/Pz).
    """
    indices = [
        idx
        for idx, ch in enumerate(epochs.ch_names)
        if canonical_channel_name(ch).upper() in STANDARD_OCCIPITAL_SET
    ]
    return np.asarray(indices, dtype=int)

def _load_epochs_for_file(path_str: str):
    """Load an epochs file and canonicalise its channel names.

    ``.set`` files (any case) go through the EEGLAB reader; every other extension
    is handed to ``mne.read_epochs`` without preloading.
    """
    path = Path(str(path_str))
    target = str(path)
    is_eeglab = path.suffix.lower() == '.set'
    loaded = (
        mne.io.read_epochs_eeglab(target, verbose='ERROR')
        if is_eeglab
        else mne.read_epochs(target, preload=False, verbose='ERROR')
    )
    return rename_epochs_channels_canonical(loaded)


# ---- Attach metadata (age/sex) without overwriting new-set ages ----

def _build_metadata_small(by_dataset: bool = False) -> pd.DataFrame:
    """Combine ``metadata_old_small`` and ``metadata_new_small`` into one table.

    Only ``subject_id``, ``age`` and ``sex`` are kept (whichever exist). Sources that
    are None, empty, not DataFrame-like, or lack ``subject_id`` are skipped.
    ``subject_id`` is coerced to nullable ``Int64`` *before* de-duplication, so e.g.
    ``'7'`` and ``7`` are recognised as the same subject. Rows whose id cannot be
    parsed are dropped, and ``age`` is coerced to numeric.

    Parameters
    ----------
    by_dataset : bool, default False
        False: one row per ``subject_id`` across both sources; the old set wins on
        clashes (the historical behaviour).
        True: add a ``dataset`` column ('old' / 'new') and keep one row per
        (dataset, subject_id). A subject id used in both datasets then keeps its
        own dataset's age/sex.

    Returns
    -------
    pd.DataFrame
        An empty frame with columns ``subject_id``, ``age`` and ``sex`` when no
        source is usable.
    """
    parts = []
    for ds_label, src in (('old', metadata_old_small), ('new', metadata_new_small)):
        try:
            if src is None or not len(src) or 'subject_id' not in src.columns:
                continue
        except (AttributeError, TypeError):
            continue  # best-effort: skip anything that is not DataFrame-like
        cols = [c for c in ('subject_id', 'age', 'sex') if c in src.columns]
        part = src[cols].copy()
        part['subject_id'] = pd.to_numeric(part['subject_id'], errors='coerce').astype('Int64')
        if by_dataset:
            part['dataset'] = ds_label
        parts.append(part)

    if not parts:
        return pd.DataFrame(columns=['subject_id', 'age', 'sex'])

    meta = pd.concat(parts, ignore_index=True).dropna(subset=['subject_id'])
    key = ['dataset', 'subject_id'] if by_dataset else ['subject_id']
    meta = meta.drop_duplicates(subset=key).copy()
    if 'age' in meta.columns:
        meta['age'] = pd.to_numeric(meta['age'], errors='coerce')
    return meta

def _attach_metadata_small(alpha_df: pd.DataFrame) -> pd.DataFrame:
    """Left-join subject age/sex onto ``alpha_df``; values already present take precedence.

    If ``alpha_df`` has a ``dataset`` column, metadata is matched on
    (dataset, subject_id). A subject id that exists in both the old and the new set
    therefore picks up its own dataset's age/sex, instead of whichever set was
    listed first. Without a ``dataset`` column, matching is on ``subject_id`` alone.
    ``alpha_df`` is returned unchanged when it is empty or no metadata is available.
    """
    if alpha_df is None or len(alpha_df) == 0:
        return alpha_df
    by_dataset = 'dataset' in alpha_df.columns
    meta = _build_metadata_small(by_dataset=by_dataset)
    if meta is None or len(meta) == 0:
        return alpha_df

    out = alpha_df.copy()
    # Drop legacy merge artifacts if present
    legacy = ['age_meta', 'sex_meta', 'age_meta2', 'sex_meta2',
              'age_meta_x', 'age_meta_y', 'sex_meta_x', 'sex_meta_y']
    out = out.drop(columns=[c for c in legacy if c in out.columns])

    out['subject_id'] = pd.to_numeric(out['subject_id'], errors='coerce').astype('Int64')
    keys = ['dataset', 'subject_id'] if by_dataset else ['subject_id']
    out = out.merge(meta, on=keys, how='left', suffixes=('', '_meta'))

    # Coalesce: keep values already present (e.g. from the cache); fill gaps from metadata.
    for col in ['age', 'sex']:
        alt = f'{col}_meta'
        if col in out.columns and alt in out.columns:
            out[col] = out[col].fillna(out[alt])
            out = out.drop(columns=[alt])

    return out

# ---- Compute (or load cached) per-subject alpha power ----

# Fixed column order for the alpha cache. An explicit header means a run that yields
# no rows still writes a parseable CSV; a column-less frame would be written without
# a header, and pd.read_csv would fail on it next time.
_ALPHA_CACHE_COLUMNS = [
    'dataset', 'subject_id', 'condition',
    'mean_alpha_all', 'mean_alpha_occ', 'mean_rel_all', 'mean_rel_occ',
]
_ALPHA_CONDITIONS = ((0, 'EO'), (1, 'EC'))  # (y_true label, condition name)


def _alpha_epoch_selection(rows: pd.DataFrame) -> np.ndarray:
    """Sorted original epoch indices for one subject/condition, capped at N_EPOCHS_PER_CLASS."""
    return np.sort(rows['epoch_orig_idx'].to_numpy(dtype=int))[:N_EPOCHS_PER_CLASS]


def _alpha_record(dataset: str, sid: int, cond_name: str, epochs, epoch_idxs: np.ndarray,
                  picks_all: np.ndarray, picks_occ: np.ndarray) -> dict:
    """One cache row: absolute and relative alpha for all-EEG and occipital picks.

    Occipital values are NaN when no occipital channels were found.
    """
    has_occ = bool(picks_occ.size)
    return dict(
        dataset=dataset,
        subject_id=int(sid),
        condition=str(cond_name),
        mean_alpha_all=_alpha_power_epochwise(epochs, epoch_idxs, picks_all),
        mean_alpha_occ=_alpha_power_epochwise(epochs, epoch_idxs, picks_occ) if has_occ else float('nan'),
        mean_rel_all=_relative_alpha_epochwise(epochs, epoch_idxs, picks_all),
        mean_rel_occ=_relative_alpha_epochwise(epochs, epoch_idxs, picks_occ) if has_occ else float('nan'),
    )


alpha_df = None
if alpha_cache.exists():
    try:
        alpha_df = pd.read_csv(alpha_cache)
    except pd.errors.EmptyDataError:
        # Header-less cache left by an older run that produced no rows: recompute.
        print('Alpha cache is empty; recomputing:', alpha_cache)

if alpha_df is not None:
    alpha_df = _attach_metadata_small(alpha_df)
    # If cache was created by an older version (new-set ages ended up in age_meta2), this fixes it.
    if 'dataset' in alpha_df.columns and 'age' in alpha_df.columns:
        new_ages = alpha_df.loc[alpha_df['dataset'] == 'new', 'age']
        # Require new-set rows to exist: .all() on an empty selection is True.
        if len(new_ages) and new_ages.isna().all() and len(_build_metadata_small()):
            print('WARNING: new dataset ages missing after attach; check METADATA_NEW_CSV subject_id normalization.')
    # Rewrite cache so future runs have consistent columns
    alpha_df.to_csv(alpha_cache, index=False)
    print('Loaded alpha cache:', alpha_cache, '| rows:', len(alpha_df))
else:
    print('Computing alpha power for small sets... (this can take a while the first time)')
    records = []

    # Old dataset: use index_map_old (already filtered to keep/finite).
    # Each condition is a separate file, so epochs are loaded per condition.
    for sid in sorted(set(index_map_old['subject_id'].dropna().astype(int).tolist())):
        for cond_label, cond_name in _ALPHA_CONDITIONS:
            rows = index_map_old[(index_map_old['subject_id'] == sid) & (index_map_old['y_true'] == cond_label)]
            if rows.empty:
                continue
            epoch_idxs = _alpha_epoch_selection(rows)
            if epoch_idxs.size < MIN_EPOCHS_PER_CLASS:
                continue
            epochs = _load_epochs_for_file(rows['file'].iloc[0])
            records.append(_alpha_record('old', sid, cond_name, epochs, epoch_idxs,
                                         _picks_all_eeg(epochs), _picks_occipital(epochs)))

    # New dataset: the subject's first listed file is loaded once and used for both
    # conditions (assumes one file per subject holds EO and EC).
    for sid in sorted(set(index_map_new['subject_id'].dropna().astype(int).tolist())):
        rows_sub = index_map_new[index_map_new['subject_id'] == sid]
        if rows_sub.empty:
            continue
        epochs = _load_epochs_for_file(rows_sub['file'].iloc[0])
        picks_all = _picks_all_eeg(epochs)
        picks_occ = _picks_occipital(epochs)
        for cond_label, cond_name in _ALPHA_CONDITIONS:
            epoch_idxs = _alpha_epoch_selection(rows_sub[rows_sub['y_true'] == cond_label])
            if epoch_idxs.size < MIN_EPOCHS_PER_CLASS:
                continue
            records.append(_alpha_record('new', sid, cond_name, epochs, epoch_idxs, picks_all, picks_occ))

    alpha_df = pd.DataFrame(records, columns=_ALPHA_CACHE_COLUMNS)

    alpha_df = _attach_metadata_small(alpha_df)

    alpha_df.to_csv(alpha_cache, index=False)
    print('Wrote alpha cache:', alpha_cache)

# Plot absolute and relative alpha power vs age: one 1x2 (old | new) figure per metric.

def _plot_alpha_vs_age(df_all: pd.DataFrame, metric: str, title: str, *, relative: bool) -> None:
    """Line plot of *metric* against age per condition; old dataset left, new right.

    A panel shows a placeholder when the required columns are missing or no row has an age.
    """
    if relative:
        ylabel = (f'Relative alpha power ({ALPHA_BAND[0]:g}–{ALPHA_BAND[1]:g} / '
                  f'{REL_BAND[0]:g}–{REL_BAND[1]:g} Hz)')
        missing_what, title_suffix = 'relative alpha+age', ' (relative)'
    else:
        ylabel = 'Mean alpha power (µV²/Hz)'
        missing_what, title_suffix = 'alpha+age', ''

    # Check columns up front so a missing age column (e.g. no metadata) yields
    # placeholders instead of a KeyError.
    has_cols = {'dataset', 'age', metric}.issubset(df_all.columns)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)
    for ax, ds in zip(axes, ('old', 'new')):
        df_ds = (df_all[(df_all['dataset'] == ds) & df_all['age'].notna()].copy()
                 if has_cols else pd.DataFrame())
        if df_ds.empty:
            ax.text(0.5, 0.5, f'No {ds} {missing_what} data', ha='center', va='center', transform=ax.transAxes)
            continue
        sns.lineplot(x='age', y=metric, hue='condition', hue_order=COND_ORDER, palette=COND_PALETTE,
                     data=df_ds, estimator='mean', errorbar=('ci', 95), marker='o', linewidth=2, ax=ax)
        ax.set_title(f"{ds.upper()} dataset: {title}{title_suffix}")
        ax.set_xlabel('Age')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
        ax.legend(title='Condition')
    plt.tight_layout()
    plt.show()


if alpha_df is None or len(alpha_df) == 0:
    print('No alpha power rows computed.')
else:
    for metric, title in [("mean_alpha_all", "All channels"), ("mean_alpha_occ", "Occipital ROI")]:
        _plot_alpha_vs_age(alpha_df, metric, title, relative=False)
    # Relative alpha plots (same layout as absolute)
    for metric, title in [("mean_rel_all", "All channels"), ("mean_rel_occ", "Occipital ROI")]:
        _plot_alpha_vs_age(alpha_df, metric, title, relative=True)


### 3.3 Labeled Epochs Visualization

In [None]:
# -----------------
# Labeled Epochs Visualization
# -----------------

def count_epochs_by_class(model_data: Dict) -> Dict:
    """Tally EC (label 1) and EO (label 0) epochs in the true labels and in the predictions.

    Predictions are ``prob_ec >= 0.5``. When ``y_true`` or ``prob_ec`` is missing,
    the corresponding counts stay at zero.
    """
    def _tally(labels) -> Dict:
        return {"EC": int(np.sum(labels == 1)), "EO": int(np.sum(labels == 0))}

    truth = model_data.get("y_true")
    scores = model_data.get("prob_ec")
    return {
        "true": {"EC": 0, "EO": 0} if truth is None else _tally(truth),
        "pred": {"EC": 0, "EO": 0} if scores is None else _tally((scores >= 0.5).astype(int)),
    }

# Locate the four model outputs by folder-name key. Each folder goes to the first key
# (in the order below) that it contains; if several folders match the same key, the
# last one seen wins.
_model_slots = {
    "old_dataset__fooof": None,
    "old_dataset__no_fooof": None,
    "new_dataset__fooof": None,
    "new_dataset__no_fooof": None,
}
for name, data in model_outputs.items():
    _slot = next((k for k in _model_slots if k in name), None)
    if _slot is not None:
        _model_slots[_slot] = data

old_fooof_model = _model_slots["old_dataset__fooof"]
old_psd_model = _model_slots["old_dataset__no_fooof"]
new_fooof_model = _model_slots["new_dataset__fooof"]
new_psd_model = _model_slots["new_dataset__no_fooof"]

# True vs predicted EC/EO epoch counts per dataset, from the FOOOF model outputs.
# Uses count_epochs_by_class so the 0.5 threshold lives in one place.
_no_counts = {"true": {"EC": 0, "EO": 0}, "pred": {"EC": 0, "EO": 0}}
_fooof_counts = [count_epochs_by_class(m) if m else _no_counts
                 for m in (old_fooof_model, new_fooof_model)]

datasets = ["Old", "New"]
ec_counts = [c["true"]["EC"] for c in _fooof_counts]
eo_counts = [c["true"]["EO"] for c in _fooof_counts]
ec_counts_pred = [c["pred"]["EC"] for c in _fooof_counts]
eo_counts_pred = [c["pred"]["EO"] for c in _fooof_counts]

x = np.arange(len(datasets))
width = 0.35
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
for ax, ec, eo, panel_title in [
    (axes[0], ec_counts, eo_counts, "True Labels: EC vs EO"),
    (axes[1], ec_counts_pred, eo_counts_pred, "Predictions: EC vs EO"),
]:
    ax.bar(x - width/2, ec, width, label="EC", color="red", alpha=0.7)
    ax.bar(x + width/2, eo, width, label="EO", color="blue", alpha=0.7)
    ax.set_xlabel("Dataset")
    ax.set_ylabel("Number of Epochs")
    ax.set_title(panel_title)
    ax.set_xticks(x)
    ax.set_xticklabels(datasets)
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")

# Force identical y-axis limits
max_y = max(ec_counts + eo_counts + ec_counts_pred + eo_counts_pred)
for ax in axes:
    ax.set_ylim(0, max_y * 1.05 if max_y else 1)

plt.tight_layout()
plt.show()


In [None]:
# -----------------
# Per-Subject Labeled Epochs Visualization (real counts via index_map)
# -----------------

def _pick_model(model_outputs: Dict, prefix: str) -> Dict | None:
    for name, data in model_outputs.items():
        if prefix in name:
            return data
    return None

old_model = _pick_model(model_outputs, 'old_dataset__fooof')
new_model = _pick_model(model_outputs, 'new_dataset__fooof')

fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

for ax, model_data, idx_map, title in [
    (axes[0], old_model, index_map_old, 'OLD (30 subjects)'),
    (axes[1], new_model, index_map_new, 'NEW (100 subjects)'),
]:
    if model_data is None:
        ax.text(0.5,0.5,'Missing model outputs',ha='center',va='center',transform=ax.transAxes)
        continue
    df = attach_subject_mapping(model_data, idx_map)
    if df.empty or not {'subject_id', 'y_pred'}.issubset(df.columns):
        ax.text(0.5,0.5,'Could not map predictions to subjects',ha='center',va='center',transform=ax.transAxes)
        continue

    # Count predicted labels per subject. The reindex after the rename guarantees that
    # both EO_pred and EC_pred exist (zero-filled) even if one class is never predicted.
    counts = (
        df.dropna(subset=['subject_id'])
          .assign(subject_id=lambda d: d['subject_id'].astype(int))
          .groupby(['subject_id', 'y_pred']).size().unstack(fill_value=0)
          .rename(columns={0: 'EO_pred', 1: 'EC_pred'})
          .reindex(columns=['EO_pred', 'EC_pred'], fill_value=0)
          .reset_index()
          .sort_values('subject_id')
    )

    # Only the first subjects (by id) are drawn, to keep the x axis readable:
    # at most 30 for OLD and 50 for NEW.
    max_subjects = 30 if title.startswith('OLD') else 50
    counts_plot = counts.head(max_subjects)

    x = np.arange(len(counts_plot))
    width = 0.42
    ax.bar(x - width/2, counts_plot['EC_pred'], width, label='EC', color='red', alpha=0.7)
    ax.bar(x + width/2, counts_plot['EO_pred'], width, label='EO', color='blue', alpha=0.7)
    ax.set_title(f"{title}: predicted EC/EO epochs per subject (first {len(counts_plot)})")
    ax.set_xlabel('Subject')
    ax.set_ylabel('Epoch count')
    ax.set_xticks(x)
    ax.set_xticklabels([str(int(s)) for s in counts_plot['subject_id'].tolist()], rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')
    ax.legend()

# Align y-limits
ylim = max(ax.get_ylim()[1] for ax in axes)
for ax in axes:
    ax.set_ylim(0, ylim)

plt.tight_layout()
plt.show()


## 4. FOOOF Visualization

### 4.1 Center Frequency Distribution

In [None]:
# -----------------
# FOOOF Center Frequency Distribution (subject-level)
# -----------------

def _finite_cf_values(cf_series) -> np.ndarray:
    """Finite center-frequency values of a Series-like object as floats (empty if unavailable).

    Non-numeric entries are coerced to NaN and dropped, so np.isfinite never sees object dtype.
    """
    if not hasattr(cf_series, 'values'):
        return np.array([])
    vals = (pd.to_numeric(pd.Series(np.ravel(cf_series.values)), errors='coerce')
            .to_numpy(dtype=float, na_value=np.nan))
    return vals[np.isfinite(vals)]


cf_old = _finite_cf_values(fooof_cf_by_subject_old)
cf_new = _finite_cf_values(fooof_cf_by_subject_new)

fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)
for ax, cf, label, color in [
    (axes[0], cf_old, 'old', 'steelblue'),
    (axes[1], cf_new, 'new', 'coral'),
]:
    if not cf.size:
        ax.text(0.5, 0.5, f'No CF data for {label} dataset', ha='center', va='center', transform=ax.transAxes)
        continue
    ax.hist(cf, bins=25, color=color, edgecolor='black', alpha=0.75)
    ax.set_title(f"{label.capitalize()} dataset CF (N={cf.size})")
    ax.set_xlabel('Center frequency (Hz)')
    if ax is axes[0]:
        ax.set_ylabel('Count')  # shared y axis: label only the left panel
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### 4.2 Center Frequency vs Age

In [None]:
# -----------------
# FOOOF Center Frequency vs Age (subject-level)
# -----------------

def _cf_age_frame(cf_series, metadata) -> Optional[pd.DataFrame]:
    """Join subject-level alpha CF with age; None when either input is unavailable.

    Returns columns ``subject_id``, ``alpha_cf`` and ``age``, restricted to rows where
    both ``alpha_cf`` and ``age`` are present. The result may have zero rows.
    """
    if cf_series is None or not len(cf_series):
        return None
    if metadata is None or not len(metadata) or not {'subject_id', 'age'}.issubset(metadata.columns):
        return None
    cf = pd.DataFrame({'subject_id': cf_series.index.astype(int), 'alpha_cf': cf_series.values})
    ages = metadata[['subject_id', 'age']].copy()
    # Coerce ids and ages to numbers: the merge keys must share an integer dtype even
    # when metadata stores ids as strings, and regplot needs a numeric age.
    ages['subject_id'] = pd.to_numeric(ages['subject_id'], errors='coerce')
    ages['age'] = pd.to_numeric(ages['age'], errors='coerce')
    ages = ages.dropna(subset=['subject_id']).copy()
    ages['subject_id'] = ages['subject_id'].astype(int)
    merged = cf.merge(ages, on='subject_id', how='left')
    return merged[merged['age'].notna() & merged['alpha_cf'].notna()]


fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

for ax, cf_series, meta_df, label in [
    (axes[0], fooof_cf_by_subject_old, metadata_old_small, 'old'),
    (axes[1], fooof_cf_by_subject_new, metadata_new_small, 'new'),
]:
    df = _cf_age_frame(cf_series, meta_df)
    if df is None:
        ax.text(0.5,0.5,f'Missing {label} CF/metadata',ha='center',va='center',transform=ax.transAxes)
        continue
    if not len(df):
        ax.text(0.5,0.5,f'No age+CF rows ({label})',ha='center',va='center',transform=ax.transAxes)
        continue
    sns.regplot(x='age', y='alpha_cf', data=df, ax=ax, scatter_kws={'alpha':0.5, 's':25}, line_kws={'color':'black'})
    ax.set_title(f'{label.capitalize()} dataset: CF vs Age')
    ax.set_xlabel('Age')
    ax.set_ylabel('Center frequency (Hz)')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## 5. Model Performance Analysis

For each model configuration, we'll analyze:
- Aggregated test results (accuracy, classification report)
- ROC curve (with AUC)
- Confusion matrix
- Per-subject accuracy bar charts (training/validation subjects and test subjects)
- Accuracy before/after time adjustment smoothing

In [None]:
# -----------------
# Model Performance Analysis Helper Functions
# -----------------

def analyze_model_performance(model_data: Dict, model_name: str):
    """Print aggregated test metrics for one model and return them.

    Predictions are ``prob_ec >= 0.5`` (EC = 1, EO = 0). Metrics use the fixed label
    set ``[0, 1]``, so the classification report and the confusion matrix stay
    two-class even when a model's labels or predictions contain only one class.
    The matrix is 2x2: rows are true EO/EC, columns are predicted EO/EC.

    Returns
    -------
    dict or None
        None when ``y_true`` or ``prob_ec`` is missing. Otherwise a dict with keys
        ``accuracy``, ``auc`` (None if ROC AUC cannot be computed, e.g. for a
        single-class ``y_true``), ``cm``, ``y_true``, ``y_pred`` and ``prob_ec``.
    """
    print(f"\n{'='*60}")
    print(f"Model: {pretty_model_name(model_name)}")
    print(f"  folder: {model_name}")
    print(f"{'='*60}")

    if model_data.get("y_true") is None or model_data.get("prob_ec") is None:
        print("Missing prediction data (y_true or prob_ec)")
        return None

    y_true = np.asarray(model_data["y_true"])
    prob_ec = np.asarray(model_data["prob_ec"], dtype=float)
    y_pred = (prob_ec >= 0.5).astype(int)
    class_labels = [0, 1]  # EO, EC

    # Aggregated test results
    accuracy = accuracy_score(y_true, y_pred)
    print("\nAggregated Test Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=class_labels,
                                target_names=["EO", "EC"], digits=4, zero_division=0))

    # ROC AUC (undefined when y_true contains a single class)
    try:
        auc = roc_auc_score(y_true, prob_ec)
        print(f"\nROC AUC: {auc:.4f}")
    except Exception as e:
        print(f"\nROC AUC: Error computing - {e}")
        auc = None

    # Confusion Matrix (always 2x2 thanks to the explicit labels)
    cm = confusion_matrix(y_true, y_pred, labels=class_labels)
    print("\nConfusion Matrix:")
    print(cm)

    # Time adjustment smoothing (if available in cv_summary)
    if model_data.get("cv_summary") is not None:
        df_summary = model_data["cv_summary"]
        if "accuracy_raw" in df_summary.columns and "accuracy_smoothed" in df_summary.columns:
            acc_raw = df_summary["accuracy_raw"].mean()
            acc_smooth = df_summary["accuracy_smoothed"].mean()
            delta = acc_smooth - acc_raw
            print("\nTime Adjustment Smoothing:")
            print(f"  Accuracy (raw): {acc_raw:.4f}")
            print(f"  Accuracy (smoothed): {acc_smooth:.4f}")
            print(f"  Delta: {delta:.4f}")

    return {
        "accuracy": accuracy,
        "auc": auc,
        "cm": cm,
        "y_true": y_true,
        "y_pred": y_pred,
        "prob_ec": prob_ec,
    }

In [None]:
# -----------------
# Analyze Each Model (and print chosen hyperparameters)
# -----------------

model_performances = {}


def _hparam_sort_key(value: str):
    """Sort key: numeric strings by value ('2' before '10'), other strings lexically after them."""
    try:
        return (0, float(value), value)
    except ValueError:
        return (1, 0.0, value)


def _print_hyperparams(model_name: str, model_data: Dict):
    """Print the most common value and the distinct values of each selected hyperparameter.

    Reads the ``selected_C`` / ``selected_n_bins`` columns of ``model_data['cv_summary']``
    (presumably one row per CV fold). Prints a short note instead when the summary
    or those columns are missing.
    """
    cv = model_data.get('cv_summary')
    if cv is None or len(cv) == 0:
        print(f"{model_name}: (no cv_summary)")
        return

    cols = [c for c in ['selected_C','selected_n_bins'] if c in cv.columns]
    if not cols:
        print(f"{model_name}: (cv_summary missing selected hyperparam columns)")
        return

    print("\n" + "-"*80)
    print(f"Hyperparameters for: {model_short_key(model_name)}")
    for c in cols:
        vals = cv[c].dropna().astype(str)
        if len(vals) == 0:
            continue
        mode = vals.value_counts().index[0]
        uniq = sorted(vals.unique().tolist(), key=_hparam_sort_key)
        print(f"  {c}: mode={mode} | unique={uniq}")


for model_name, model_data in model_outputs.items():
    _print_hyperparams(model_name, model_data)
    perf = analyze_model_performance(model_data, model_name)
    if perf is not None:
        model_performances[model_name] = perf


### 5.1 ROC Curves

In [None]:
# -----------------
# ROC Curves for All Models
# -----------------

# One ROC panel per model: a 2x2 grid that grows by rows when there are more than four models.
n_models = len(model_performances)
n_rows = max(2, int(np.ceil(n_models / 2)))
fig, axes = plt.subplots(n_rows, 2, figsize=(14, 6 * n_rows))
axes = axes.flatten()

colors = {"old": "steelblue", "new": "coral"}
styles = {"fooof": "-", "no_fooof": "--"}

for ax, (model_name, perf) in zip(axes, model_performances.items()):
    y_true = perf["y_true"]
    prob_ec = perf["prob_ec"]

    fpr, tpr, _ = roc_curve(y_true, prob_ec)
    auc_score = perf.get("auc")
    # Test against None (not truthiness) so an AUC of exactly 0.0 is still reported.
    roc_label = f"ROC (AUC = {auc_score:.3f})" if auc_score is not None else "ROC"

    ax.plot(fpr, tpr, linewidth=2, label=roc_label)
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label="Random")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(pretty_model_name(model_name))
    ax.legend()
    ax.grid(True, alpha=0.3)

# Hide panels without a model.
for ax in axes[n_models:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

### 5.2 Confusion Matrices

In [None]:
# -----------------
# Confusion Matrices for All Models
# -----------------

# One confusion-matrix panel per model: a 2x2 grid that grows by rows beyond four models.
n_models = len(model_performances)
n_rows = max(2, int(np.ceil(n_models / 2)))
fig, axes = plt.subplots(n_rows, 2, figsize=(14, 6 * n_rows))
axes = axes.flatten()

for ax, (model_name, perf) in zip(axes, model_performances.items()):
    cm = perf["cm"]
    # EO/EC names only fit a 2x2 matrix; otherwise fall back to seaborn's automatic ticks.
    ticks = ["EO", "EC"] if np.shape(cm) == (2, 2) else "auto"

    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax,
                xticklabels=ticks, yticklabels=ticks)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(pretty_model_name(model_name))

# Hide panels without a model.
for ax in axes[n_models:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

### 5.3 Per-Subject Accuracy

In [None]:
# -----------------
# Accuracy per subject (bar charts: one bar per subject)
# -----------------

def _bar_by_subject(ax, df: pd.DataFrame, title: str, color: str = 'steelblue'):
    """Draw one bar per subject on *ax*; duplicate subjects are averaged.

    A missing or empty *df* produces a 'No data' placeholder. Rows lacking
    ``subject_id`` or ``accuracy`` are ignored, subjects are ordered by id, and
    when there are many subjects only every k-th tick is labelled (at most 25 labels).
    """
    if df is None or not len(df):
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)
        return

    clean = df.dropna(subset=['subject_id', 'accuracy']).copy()
    clean['subject_id'] = clean['subject_id'].astype(int)
    per_subject = (
        clean.groupby('subject_id', as_index=False)['accuracy']
        .mean()
        .sort_values('subject_id')
    )
    positions = np.arange(len(per_subject))
    heights = per_subject['accuracy'].to_numpy(dtype=float)
    ids = per_subject['subject_id'].to_numpy()

    ax.bar(positions, heights, color=color, alpha=0.85)
    ax.set_title(title)
    ax.set_xlabel('subject_id')
    ax.set_ylabel('accuracy')
    ax.set_ylim(0, 1.02)
    ax.grid(True, alpha=0.3, axis='y')

    stride = max(1, int(np.ceil(len(per_subject) / 25)))
    ax.set_xticks(positions[::stride])
    ax.set_xticklabels([str(int(s)) for s in ids[::stride]], rotation=45, ha='right')

def _draw_train_val_panel(ax, model_name, model_data):
    """Per-subject validation accuracy bars for one model, or a placeholder if unavailable."""
    short = model_short_key(model_name)
    ids, accs = model_data.get('val_subject_ids'), model_data.get('val_accuracies')
    if ids is None or accs is None:
        ax.text(0.5,0.5,'Missing val_subject_ids/val_accuracies',ha='center',va='center',transform=ax.transAxes)
        ax.set_title(short)
        return
    frame = pd.DataFrame({'subject_id': np.asarray(ids, dtype=int), 'accuracy': np.asarray(accs, dtype=float)})
    _bar_by_subject(ax, frame, f"{short} (train/val)")


def _draw_test_panel(ax, model_name, model_data):
    """Per-subject test accuracy bars for one model, or a placeholder if unavailable/malformed."""
    short = model_short_key(model_name)
    metrics = model_data.get('per_subject_metrics')
    if metrics is None or len(metrics) == 0:
        ax.text(0.5,0.5,'Missing per-subject metrics',ha='center',va='center',transform=ax.transAxes)
        ax.set_title(short)
        return
    metrics = normalize_per_subject_metrics(metrics)
    if not {'subject_id', 'accuracy'}.issubset(metrics.columns):
        ax.text(0.5,0.5,'Bad per-subject metrics schema',ha='center',va='center',transform=ax.transAxes)
        ax.set_title(short)
        return
    _bar_by_subject(ax, metrics[['subject_id','accuracy']], f"{short} (test)")


# Training/validation subjects first, then test subjects: one 2x2 figure each, one panel per model.
for _draw_panel in (_draw_train_val_panel, _draw_test_panel):
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    axes = axes.flatten()
    for ax, (model_name, model_data) in zip(axes, model_outputs.items()):
        _draw_panel(ax, model_name, model_data)
    plt.tight_layout()
    plt.show()


In [None]:
# -----------------
# Accuracy Before/After Time Adjustment Smoothing
# -----------------

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()

# Fill panels in order with models that report both raw and smoothed accuracy.
ax_idx = 0
for model_name, model_data in model_outputs.items():
    if ax_idx >= len(axes):
        break
    df = model_data.get("cv_summary")
    if df is None or "accuracy_raw" not in df.columns or "accuracy_smoothed" not in df.columns:
        continue
    ax = axes[ax_idx]
    ax_idx += 1

    acc_raw = df["accuracy_raw"].values
    acc_smooth = df["accuracy_smoothed"].values
    delta = acc_smooth - acc_raw

    x = np.arange(len(acc_raw))
    width = 0.35

    ax.bar(x - width/2, acc_raw, width, label="Raw", color="steelblue", alpha=0.7)
    ax.bar(x + width/2, acc_smooth, width, label="Smoothed", color="coral", alpha=0.7)
    ax.set_xlabel("Fold/Subject")
    ax.set_ylabel("Accuracy")
    ax.set_title(f"{pretty_model_name(model_name)}\nRaw vs Smoothed Accuracy")
    ax.grid(True, alpha=0.3, axis="y")

    # Delta plot on a secondary y axis
    ax2 = ax.twinx()
    ax2.plot(x, delta, "g-o", markersize=4, label="Delta", alpha=0.6)
    ax2.axhline(0, color="black", linestyle="--", linewidth=0.5)
    ax2.set_ylabel("Delta Accuracy", color="green")
    ax2.tick_params(axis="y", labelcolor="green")

    # One legend covering both y axes, so the Delta line is labelled as well. It is drawn
    # on the twin axis so that the bars cannot cover it.
    bar_handles, bar_labels = ax.get_legend_handles_labels()
    line_handles, line_labels = ax2.get_legend_handles_labels()
    ax2.legend(bar_handles + line_handles, bar_labels + line_labels)

# Hide panels with no model to show.
for ax in axes[ax_idx:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

### 5.4 Cross-Model Comparison (Box Plots)

In [None]:
# -----------------
# Box Plot: Accuracies per Subject for Each Model (short labels + outliers)
# -----------------

per_model = []
labels = []
per_model_df = {}

for model_name, model_data in model_outputs.items():
    dfm = model_data.get('per_subject_metrics')
    if dfm is None or len(dfm) == 0:
        continue
    dfm = normalize_per_subject_metrics(dfm)
    if 'subject_id' not in dfm.columns or 'accuracy' not in dfm.columns:
        continue
    # Also drop rows without an accuracy. matplotlib's boxplot does not skip NaNs,
    # so a single NaN would blank out the model's box.
    dfm = dfm.dropna(subset=['subject_id', 'accuracy']).copy()
    if dfm.empty:
        continue
    dfm['subject_id'] = dfm['subject_id'].astype(int)

    key = model_short_key(model_name)
    per_model.append(dfm['accuracy'].to_numpy(dtype=float))
    labels.append(key)
    per_model_df[key] = dfm[['subject_id','accuracy']].copy()

if not per_model:
    print('No per-subject metrics available for boxplot.')
else:
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    try:
        # Matplotlib >= 3.9 calls this argument `tick_labels` (`labels` is deprecated).
        bp = ax.boxplot(per_model, tick_labels=labels, patch_artist=True)
    except TypeError:
        bp = ax.boxplot(per_model, labels=labels, patch_artist=True)
    for patch in bp['boxes']:
        patch.set_facecolor('lightblue')
        patch.set_alpha(0.75)
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy per subject (test)')
    ax.set_ylim(0, 1.02)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()

    # Print outliers (1.5*IQR rule)
    print('\nOutliers per model (1.5*IQR):')
    for key, dfm in per_model_df.items():
        acc = dfm['accuracy'].to_numpy(dtype=float)
        q1 = np.nanpercentile(acc, 25)
        q3 = np.nanpercentile(acc, 75)
        iqr = q3 - q1
        lo = q1 - 1.5 * iqr
        hi = q3 + 1.5 * iqr
        out = dfm[(dfm['accuracy'] < lo) | (dfm['accuracy'] > hi)].sort_values('accuracy')
        if out.empty:
            print(f"  {key}: (none)")
        else:
            print(f"  {key}: {len(out)} outlier subject(s)")
            print(out.to_string(index=False))


## 6. Challenging Subjects Analysis

In [None]:
# -----------------
# Top 5 Worst-Performing Subjects per Model
# -----------------

def get_worst_subjects(model_data: Dict, metadata_df: pd.DataFrame, n=5):
    """Return the *n* lowest-accuracy test subjects for one model, with age when available.

    Reads ``model_data['per_subject_metrics']`` and passes it through
    ``normalize_per_subject_metrics``. Returns None when metrics are missing or lack
    ``subject_id``/``accuracy``. Otherwise returns a frame with ``subject_id`` and
    ``accuracy``, plus ``age`` when ``metadata_df`` provides ``subject_id`` and ``age``.
    """
    df = model_data.get('per_subject_metrics')
    if df is None or len(df) == 0:
        return None
    df = normalize_per_subject_metrics(df)
    if 'subject_id' not in df.columns or 'accuracy' not in df.columns:
        return None

    out = df.dropna(subset=['subject_id']).copy()
    out['subject_id'] = out['subject_id'].astype(int)
    # nsmallest requires a numeric column; non-numeric accuracies become NaN.
    out['accuracy'] = pd.to_numeric(out['accuracy'], errors='coerce')
    out = out.nsmallest(int(n), 'accuracy')[['subject_id','accuracy']]

    if (metadata_df is not None and len(metadata_df)
            and {'subject_id', 'age'}.issubset(metadata_df.columns)):
        ages = metadata_df[['subject_id', 'age']].copy()
        # Match the int dtype of `out`, since metadata ids may be strings. Keep one row
        # per subject so the merge cannot duplicate a worst-subject row.
        ages['subject_id'] = pd.to_numeric(ages['subject_id'], errors='coerce')
        ages = ages.dropna(subset=['subject_id']).drop_duplicates(subset=['subject_id']).copy()
        ages['subject_id'] = ages['subject_id'].astype(int)
        out = out.merge(ages, on='subject_id', how='left')

    return out

for model_name, model_data in model_outputs.items():
    key = model_short_key(model_name)
    is_old = key.startswith('old_')
    meta = metadata_old_small if is_old else metadata_new_small

    worst = get_worst_subjects(model_data, meta, n=5)
    print('\n' + '='*60)
    print(f"Top 5 worst subjects: {key}")
    print('='*60)

    # Check for missing results before decorating the table.
    if worst is None or len(worst) == 0:
        print('No per-subject metrics available')
        continue

    # FOOOF models: add each subject's alpha center frequency for context (best-effort).
    if model_data.get('is_fooof'):
        try:
            cf_series = fooof_cf_by_subject_old if is_old else fooof_cf_by_subject_new
            if cf_series is not None and len(cf_series):
                cf_map = {int(k): float(v) for k, v in zip(cf_series.index.astype(int), cf_series.values) if pd.notna(v)}
                worst['alpha_cf'] = worst['subject_id'].astype(int).map(cf_map)
        except (NameError, TypeError, ValueError) as err:
            print(f"  (alpha CF unavailable: {err})")

    print(worst.to_string(index=False))


## 7. PSD vs FOOOF Comparison

### 7.1 Confidence Comparison Plots

In [None]:
# -----------------
# PSD vs FOOOF Confidence Comparison (Old and New Datasets)
# -----------------

def plot_psd_vs_fooof(fooof_data: Dict, psd_data: Dict, dataset_name: str):
    """Scatter FOOOF-model P(EC) against PSD-model P(EC), with marginal histograms.

    Both models' outputs must provide ``epoch_idx``, ``y_true`` and ``prob_ec``.
    When the epoch indices differ, the predictions are paired by epoch index before
    plotting. Points and histograms are coloured by the FOOOF model's ``y_true``
    (EO = 0 blue, EC = 1 red). Prints a message and returns without plotting when
    data is missing or the two models share no epochs.
    """
    fooof_idx = fooof_data.get("epoch_idx")
    fooof_y = fooof_data.get("y_true")
    fooof_p = fooof_data.get("prob_ec")

    psd_idx = psd_data.get("epoch_idx")
    psd_y = psd_data.get("y_true")
    psd_p = psd_data.get("prob_ec")

    if fooof_idx is None or fooof_y is None or fooof_p is None:
        print(f"Skipping {dataset_name} - missing FOOOF data")
        return
    if psd_idx is None or psd_y is None or psd_p is None:
        print(f"Skipping {dataset_name} - missing PSD data")
        return

    fooof_idx, fooof_y, fooof_p = np.asarray(fooof_idx), np.asarray(fooof_y), np.asarray(fooof_p)
    psd_idx, psd_y, psd_p = np.asarray(psd_idx), np.asarray(psd_y), np.asarray(psd_p)

    # Sanity checks: same epochs and labels
    if not np.array_equal(fooof_idx, psd_idx):
        print(f"Warning: {dataset_name} - epoch indices don't match, aligning...")
        # Pair predictions by epoch index. intersect1d(return_indices=True) returns, for
        # each common index, its position in BOTH arrays, so the pairs line up even when
        # the two models stored epochs in different orders. np.isin masks would keep
        # each array's own order and mis-pair them. Duplicate indices keep their first
        # occurrence.
        _, f_pos, p_pos = np.intersect1d(fooof_idx, psd_idx, return_indices=True)
        fooof_y, fooof_p = fooof_y[f_pos], fooof_p[f_pos]
        psd_y, psd_p = psd_y[p_pos], psd_p[p_pos]

        if not np.array_equal(fooof_y, psd_y):
            print(f"Warning: {dataset_name} - labels don't match after alignment")

    if fooof_y.size == 0:
        print(f"Skipping {dataset_name} - no common epochs between FOOOF and PSD outputs")
        return

    y_true = fooof_y
    x = fooof_p  # FOOOF model P(EC)
    y = psd_p    # PSD model P(EC)

    # Masks and colors: EO=0 (blue), EC=1 (red)
    eo_mask = (y_true == 0)
    ec_mask = (y_true == 1)

    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4],
                           wspace=0.05, hspace=0.05)

    ax_scatter = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax_scatter)
    ax_histy = fig.add_subplot(gs[1, 1], sharey=ax_scatter)

    # Scatter
    ax_scatter.scatter(x[eo_mask], y[eo_mask], c="blue", alpha=0.5, label="EO", s=20)
    ax_scatter.scatter(x[ec_mask], y[ec_mask], c="red", alpha=0.5, label="EC", s=20)
    ax_scatter.set_xlabel("P(EC) – FOOOF features")
    ax_scatter.set_ylabel("P(EC) – PSD features")
    ax_scatter.legend(loc="lower right")
    ax_scatter.grid(alpha=0.2)
    ax_scatter.set_title(f"{dataset_name}: FOOOF vs PSD Predictions")

    # Marginal histograms (x-axis)
    ax_histx.hist(x[eo_mask], bins=30, color="blue", alpha=0.4, density=True)
    ax_histx.hist(x[ec_mask], bins=30, color="red", alpha=0.4, density=True)
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histx.set_ylabel("Density")

    # Marginal histograms (y-axis)
    ax_histy.hist(y[eo_mask], bins=30, orientation="horizontal", color="blue", alpha=0.4, density=True)
    ax_histy.hist(y[ec_mask], bins=30, orientation="horizontal", color="red", alpha=0.4, density=True)
    ax_histy.tick_params(axis="y", labelleft=False)
    ax_histy.set_xlabel("Density")

    plt.show()

# Compare FOOOF vs PSD predictions for each dataset where both model outputs exist.
for _fooof_out, _psd_out, _ds_label in (
    (old_fooof_model, old_psd_model, "Old Dataset"),
    (new_fooof_model, new_psd_model, "New Dataset"),
):
    if _fooof_out and _psd_out:
        plot_psd_vs_fooof(_fooof_out, _psd_out, _ds_label)

### 7.2 Disagreement Analysis: Example Epoch where Models Disagree

In [None]:
# -----------------
# Find and Plot Epoch where FOOOF is confident but PSD predicts opposite (with PSD visualization)
# -----------------

def _get_model_by_prefix(prefix: str) -> Dict | None:
    """Return the first ``model_outputs`` value whose key contains *prefix*.

    Despite the name, the match is a substring test (``prefix in key``), applied
    in the dict's iteration order. Returns ``None`` when no key matches.
    """
    return next(
        (outputs for key, outputs in model_outputs.items() if prefix in key),
        None,
    )

def _find_confident_disagreement(fooof_data: Dict, psd_data: Dict, idx_map: pd.DataFrame):
    df_f = attach_subject_mapping(fooof_data, idx_map)
    if df_f.empty:
        print('No mapped FOOOF predictions (attach_subject_mapping returned empty).')
        return None

    df_p = pd.DataFrame({
        'global_idx': np.asarray(psd_data.get('epoch_idx'), dtype=int),
        'prob_ec_psd': np.asarray(psd_data.get('prob_ec'), dtype=float),
    })

    merged = df_f.merge(df_p, on='global_idx', how='inner')
    if merged.empty:
        print('No overlap between FOOOF and PSD epoch_idx arrays after mapping.')
        return None

    merged['y_pred_psd'] = (merged['prob_ec_psd'] >= 0.5).astype(int)
    merged['confident_fooof'] = (merged['prob_ec'] > 0.90) | (merged['prob_ec'] < 0.10)
    merged['disagree'] = (merged['y_pred'] != merged['y_pred_psd']) & merged['confident_fooof']

    cand = merged[merged['disagree']]
    print(f"Disagreement candidates: {len(cand)} / merged={len(merged)}")
    if cand.empty:
        return None
    return cand.iloc[0].to_dict()

def _plot_epoch_psd(path: str, epoch_orig_idx: int, title: str):
    """Plot the Welch PSD (1-45 Hz) of one epoch: all-channel mean and occipital ROI.

    The file is read with ``mne.io.read_epochs_eeglab`` for ``.set`` paths and
    ``mne.read_epochs`` otherwise, then passed through
    ``rename_epochs_channels_canonical`` (defined elsewhere in this notebook).
    The "All channels" curve averages the EEG channels not marked bad. If every
    EEG channel is bad, it falls back to all EEG channels. The occipital curve
    averages the subset of those same channels whose canonical, upper-cased name
    is in ``STANDARD_OCCIPITAL_SET``.

    Args:
        path: epochs file path (EEGLAB ``.set`` or an MNE-readable epochs file).
        epoch_orig_idx: 0-based position of the epoch within the file.
        title: figure title.

    Raises:
        IndexError: if ``epoch_orig_idx`` is outside ``[0, n_epochs)``.
    """
    path_str = str(path)
    if path_str.lower().endswith('.set'):
        epochs = mne.io.read_epochs_eeglab(path, verbose='ERROR')
    else:
        epochs = mne.read_epochs(path, verbose='ERROR')
    epochs = rename_epochs_channels_canonical(epochs)

    epoch_orig_idx = int(epoch_orig_idx)
    all_data = epochs.get_data()
    # An out-of-range index used to yield an empty slice and a NaN plot.
    if not 0 <= epoch_orig_idx < all_data.shape[0]:
        raise IndexError(
            f"epoch_orig_idx={epoch_orig_idx} is out of range for {path_str} "
            f"({all_data.shape[0]} epochs)"
        )
    data = all_data[epoch_orig_idx:epoch_orig_idx + 1]
    sfreq = float(epochs.info['sfreq'])

    picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude='bads')
    if len(picks_all) == 0:
        picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude=[])

    # 2 s Welch segments with 1 s overlap, clipped to the epoch length.
    n_per_seg = int(min(data.shape[-1], max(8, int(round(2.0 * sfreq)))))
    n_overlap = 0 if n_per_seg <= 1 else min(int(round(1.0 * sfreq)), n_per_seg - 1)

    # n_fft must be given explicitly. MNE's default n_fft=256 silently shrinks
    # n_per_seg to 256. For sfreq >= 256 Hz the 1 s overlap then reaches the
    # segment length and psd_array_welch raises ValueError. Below that, the
    # 2 s window is truncated.
    psd, freqs = mne.time_frequency.psd_array_welch(
        data[:, picks_all, :], sfreq=sfreq, fmin=1.0, fmax=45.0,
        n_fft=n_per_seg, n_per_seg=n_per_seg, n_overlap=n_overlap,
        average='mean', window='hann', verbose=False,
    )
    # psd is (1 epoch, len(picks_all), n_freqs).
    mean_psd = psd.mean(axis=(0, 1))

    # Occipital ROI: positions *within picks_all*, so bad/non-EEG channels are
    # excluded exactly as for the all-channel curve. Welch is per channel, so
    # reusing the computed PSD equals a separate call on those channels.
    names = [canonical_channel_name(ch).upper() for ch in epochs.ch_names]
    occ_pos = [pos for pos, ch_idx in enumerate(picks_all) if names[ch_idx] in STANDARD_OCCIPITAL_SET]
    occ_psd = psd[:, occ_pos, :].mean(axis=(0, 1)) if occ_pos else None

    plt.figure(figsize=(10, 4))
    plt.plot(freqs, mean_psd, label='All channels (avg)')
    if occ_psd is not None:
        plt.plot(freqs, occ_psd, label='Occipital ROI (avg)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('PSD')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

def _report_disagreement(tag: str, fooof_out: Dict, psd_out: Dict, idx_map: pd.DataFrame):
    """Print and plot the first confident FOOOF-vs-PSD disagreement of one dataset.

    *tag* ("OLD" / "NEW") prefixes the printed summary and the plot title.
    Returns the disagreement row (a dict), or the falsy search result when no
    disagreement was found.
    """
    found = _find_confident_disagreement(fooof_out, psd_out, idx_map)
    if not found:
        print(f'No confident disagreement found for {tag} dataset')
        return found
    shown_keys = ['subject_id', 'epoch_orig_idx', 'file', 'y_true_model', 'prob_ec', 'prob_ec_psd']
    print(f'{tag} disagreement:', {key: found.get(key) for key in shown_keys})
    subj = int(found['subject_id'])
    epoch_no = int(found['epoch_orig_idx'])
    _plot_epoch_psd(found['file'], found['epoch_orig_idx'],
                    f"{tag} disagreement | subj {subj} | epoch {epoch_no}")
    return found


# Old dataset
old_fooof = _get_model_by_prefix('old_dataset__fooof')
old_psd = _get_model_by_prefix('old_dataset__no_fooof')
if old_fooof and old_psd:
    row = _report_disagreement('OLD', old_fooof, old_psd, index_map_old)

# New dataset
new_fooof = _get_model_by_prefix('new_dataset__fooof')
new_psd = _get_model_by_prefix('new_dataset__no_fooof')
if new_fooof and new_psd:
    row = _report_disagreement('NEW', new_fooof, new_psd, index_map_new)


## 8. Confident Model Performance

Analyzing performance only on predictions where the model is confident: P(EC) > 0.90 (confident EC) or P(EC) < 0.10 (confident EO).

In [None]:
# -----------------
# Confident Model Performance Analysis
# -----------------

def analyze_confident_predictions(model_data: Dict, model_name: str,
                                  hi: float = 0.90, lo: float = 0.10):
    """Analyze performance on confident predictions only.

    A prediction is "confident" when P(EC) > ``hi`` (confident EC) or
    P(EC) < ``lo`` (confident EO). Accuracy, a per-class report and ROC AUC are
    printed for that subset.

    Args:
        model_data: dict with ``y_true`` (0 = EO, 1 = EC) and ``prob_ec`` (P(EC))
            sequences; lists and numpy arrays are both accepted.
        model_name: label used in the printed output.
        hi: upper confidence threshold on P(EC) (default 0.90).
        lo: lower confidence threshold on P(EC) (default 0.10).

    Returns:
        None when either input is missing or no prediction is confident;
        otherwise a dict with ``accuracy``, ``auc`` (None when undefined),
        ``n_confident`` and the filtered ``y_true``, ``y_pred`` (P(EC) >= 0.5)
        and ``prob_ec`` arrays.
    """
    if model_data.get("y_true") is None or model_data.get("prob_ec") is None:
        return None

    # asarray so list inputs support the element-wise comparisons and masking.
    y_true = np.asarray(model_data["y_true"])
    prob_ec = np.asarray(model_data["prob_ec"], dtype=float)

    # Confident EC: P(EC) > hi; confident EO: P(EC) < lo.
    confident_mask = (prob_ec > hi) | (prob_ec < lo)
    n_confident = int(np.sum(confident_mask))

    if n_confident == 0:
        print(f"\n{model_name}: No confident predictions found")
        return None

    y_true_conf = y_true[confident_mask]
    prob_ec_conf = prob_ec[confident_mask]
    y_pred_conf = (prob_ec_conf >= 0.5).astype(int)

    print(f"\n{'='*60}")
    print(f"Confident Predictions: {model_name}")
    print(f"{'='*60}")
    print(f"Total confident predictions: {n_confident} / {len(y_true)} ({100*n_confident/len(y_true):.1f}%)")

    # Aggregated results
    accuracy = accuracy_score(y_true_conf, y_pred_conf)
    print(f"\nAccuracy on confident predictions: {accuracy:.4f}")

    # labels=[0, 1] keeps the report aligned with the two target names even when
    # the confident subset holds a single class (which otherwise raises ValueError).
    print(f"\nClassification Report:")
    print(classification_report(y_true_conf, y_pred_conf, labels=[0, 1],
                                target_names=["EO", "EC"], digits=4, zero_division=0))

    # ROC AUC is undefined with only one class in y_true.
    auc = None
    if np.unique(y_true_conf).size < 2:
        print("\nROC AUC: undefined (only one class among confident predictions)")
    else:
        try:
            auc = roc_auc_score(y_true_conf, prob_ec_conf)
            print(f"\nROC AUC: {auc:.4f}")
        except ValueError as e:
            print(f"\nROC AUC: Error - {e}")
            auc = None

    # TODO: per-subject accuracy would require filtering per_subject_metrics to
    # the confident epochs.

    return {
        "accuracy": accuracy,
        "auc": auc,
        "n_confident": n_confident,
        "y_true": y_true_conf,
        "y_pred": y_pred_conf,
        "prob_ec": prob_ec_conf,
    }

# Run the confident-subset analysis for every model; models for which
# analyze_confident_predictions returns None are left out.
confident_performances = {}
for model_name, model_data in model_outputs.items():
    perf = analyze_confident_predictions(model_data, model_name)
    if perf is None:
        continue
    confident_performances[model_name] = perf

# One EO/EC confusion matrix per model, laid out on a 2-column grid.
if confident_performances:
    keys = list(confident_performances)
    n = len(keys)
    ncols = 2
    nrows = -(-n // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 4 * nrows))
    axes = np.atleast_1d(axes).ravel()

    for ax, model_name in zip(axes, keys):
        perf = confident_performances[model_name]
        cm = confusion_matrix(perf['y_true'], perf['y_pred'], labels=[0, 1])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=['EO', 'EC'], yticklabels=['EO', 'EC'], ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        n_shown = perf.get('n_confident', len(perf['y_true']))
        ax.set_title(f"{model_short_key(model_name)} | n={n_shown}")

    # Blank out grid cells that have no model.
    for spare_ax in axes[n:]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()

# ROC curves for the confident subsets: one panel per model on a 2-column grid
# sized to the number of models (a fixed 2x2 grid silently dropped models
# beyond the fourth and left unused panels empty).
if confident_performances:
    roc_keys = list(confident_performances.keys())
    n_cols = 2
    n_rows = int(np.ceil(len(roc_keys) / n_cols))
    # 7x6-inch panels, i.e. the original 14x12 figure for a 2x2 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 6 * n_rows))
    axes = np.atleast_1d(axes).ravel()

    for ax, model_name in zip(axes, roc_keys):
        perf = confident_performances[model_name]
        y_true = perf["y_true"]
        prob_ec = perf["prob_ec"]
        auc_score = perf.get("auc")

        if np.unique(y_true).size < 2:
            # With one class the ROC curve is undefined (roc_curve warns and
            # returns NaN rates), so annotate the panel instead of drawing it.
            ax.text(0.5, 0.5, "ROC undefined:\nonly one class among\nconfident predictions",
                    ha="center", va="center", transform=ax.transAxes)
        else:
            fpr, tpr, _ = roc_curve(y_true, prob_ec)
            # `is not None` so a legitimate AUC of 0.0 is still displayed.
            roc_label = f"ROC (AUC = {auc_score:.3f})" if auc_score is not None else "ROC"
            ax.plot(fpr, tpr, linewidth=2, label=roc_label)
        ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label="Random")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(f"{pretty_model_name(model_name)}\n(Confident Predictions Only)")
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide grid cells that have no model.
    for ax in axes[len(roc_keys):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## 9. Statistical Tests

Comparing models to each other and to a baseline (majority class predictor).

In [None]:
# -----------------
# Statistical Tests: Model Comparisons
# -----------------

from collections import Counter

def get_baseline_predictions(y_true):
    """Return majority-class baseline predictions for *y_true*.

    The output has one entry per label, all equal to the most frequent label.
    Ties go to the label encountered first (``Counter.most_common`` ordering).
    An empty input yields an empty array instead of raising IndexError.
    """
    counts = Counter(y_true)
    if not counts:
        # No labels -> no majority class; return an empty prediction vector.
        return np.asarray(y_true)[:0].copy()
    majority_class = counts.most_common(1)[0][0]
    return np.array([majority_class] * len(y_true))

def run_mcnemar_test(y_true, pred_a, pred_b, name_a, name_b):
    """Run McNemar's test comparing two models."""
    # Contingency table: [[both correct, a only], [b only, both wrong]]
    a_correct = (pred_a == y_true)
    b_correct = (pred_b == y_true)
    
    both_correct = np.sum(a_correct & b_correct)
    a_only = np.sum(a_correct & ~b_correct)
    b_only = np.sum(~a_correct & b_correct)
    both_wrong = np.sum(~a_correct & ~b_correct)
    
    table = [[both_correct, a_only],
             [b_only, both_wrong]]
    
    try:
        result = mcnemar(table, exact=True)
        return {
            "table": table,
            "statistic": result.statistic,
            "pvalue": result.pvalue,
        }
    except Exception as e:
        print(f"Error in McNemar test: {e}")
        return None

def run_wilcoxon_test(acc_a, acc_b, name_a, name_b):
    """Run a Wilcoxon signed-rank test on paired per-subject accuracies.

    Pairs where either value is NaN are dropped first. Inputs are coerced to
    float arrays, so plain lists work (boolean-mask indexing of a list raised
    TypeError before).

    Args:
        acc_a, acc_b: paired accuracies (same subject order).
        name_a, name_b: model labels (kept for the call interface; not used).

    Returns:
        ``{"statistic", "pvalue"}``, or None when fewer than 3 complete pairs
        remain or scipy raises (the error is printed; best-effort).
    """
    acc_a = np.asarray(acc_a, dtype=float)
    acc_b = np.asarray(acc_b, dtype=float)
    keep = ~(np.isnan(acc_a) | np.isnan(acc_b))
    acc_a_clean = acc_a[keep]
    acc_b_clean = acc_b[keep]

    if acc_a_clean.size < 3:
        return None

    try:
        res = stats.wilcoxon(acc_a_clean, acc_b_clean)
    except Exception as e:
        print(f"Error in Wilcoxon test: {e}")
        return None
    return {
        "statistic": res.statistic,
        "pvalue": res.pvalue,
    }

print("\n" + "=" * 80)
print("STATISTICAL TESTS: Model Comparisons")
print("=" * 80)

# --- Each model vs. a majority-class baseline (epoch-level McNemar) ---
print("\n" + "-" * 80)
print("Model vs Baseline (Majority Class)")
print("-" * 80)

baseline_results = {}
for model_name, perf in model_performances.items():
    labels = perf["y_true"]
    preds = perf["y_pred"]
    majority_preds = get_baseline_predictions(labels)

    baseline_acc = accuracy_score(labels, majority_preds)
    model_acc = perf["accuracy"]

    print(f"\n{model_name}:")
    print(f"  Model accuracy: {model_acc:.4f}")
    print(f"  Baseline accuracy: {baseline_acc:.4f}")
    print(f"  Improvement: {model_acc - baseline_acc:.4f}")

    result = run_mcnemar_test(labels, preds, majority_preds, model_name, "Baseline")
    if result is not None:
        print(f"  McNemar test vs baseline:")
        print(f"    Statistic: {result['statistic']:.4f}")
        print(f"    p-value: {result['pvalue']:.6f}")
        print(f"    Significant: {'Yes' if result['pvalue'] < 0.05 else 'No'} (α=0.05)")

    baseline_results[model_name] = {
        "baseline_acc": baseline_acc,
        "model_acc": model_acc,
        "mcnemar": result,
    }

# --- Pairwise epoch-level comparisons (McNemar), only on identical test sets ---
print("\n" + "-" * 80)
print("Pairwise Model Comparisons (McNemar's Test)")
print("-" * 80)

model_list = list(model_performances.items())
for pos, (name_a, perf_a) in enumerate(model_list):
    for name_b, perf_b in model_list[pos + 1:]:
        labels_a = perf_a["y_true"]
        labels_b = perf_b["y_true"]
        # Same length and identical labels is taken as "same test set".
        same_test_set = len(labels_a) == len(labels_b) and np.array_equal(labels_a, labels_b)
        if not same_test_set:
            continue
        print(f"\n{name_a} vs {name_b}:")
        result = run_mcnemar_test(labels_a, perf_a["y_pred"], perf_b["y_pred"], name_a, name_b)
        if result is not None:
            print(f"  McNemar statistic: {result['statistic']:.4f}")
            print(f"  p-value: {result['pvalue']:.6f}")
            print(f"  Significant: {'Yes' if result['pvalue'] < 0.05 else 'No'} (α=0.05)")

# --- Per-subject comparisons (Wilcoxon signed-rank on per-subject accuracy) ---
print("\n" + "-" * 80)
print("Per-Subject Model Comparisons (Wilcoxon Signed-Rank Test)")
print("-" * 80)

# model name -> {subject_id: accuracy}
subject_accuracies = {}
for model_name, model_data in model_outputs.items():
    per_subject = model_data.get("per_subject_metrics")
    if per_subject is None or not {"subject_id", "accuracy"}.issubset(per_subject.columns):
        continue
    subject_accuracies[model_name] = per_subject.set_index("subject_id")["accuracy"].to_dict()

if len(subject_accuracies) >= 2:
    model_names = list(subject_accuracies.keys())
    for pos, name_a in enumerate(model_names):
        for name_b in model_names[pos + 1:]:
            common_subjects = sorted(set(subject_accuracies[name_a]) & set(subject_accuracies[name_b]))
            if len(common_subjects) < 3:
                continue

            acc_a = np.array([subject_accuracies[name_a][s] for s in common_subjects])
            acc_b = np.array([subject_accuracies[name_b][s] for s in common_subjects])

            print(f"\n{name_a} vs {name_b}:")
            print(f"  Common subjects: {len(common_subjects)}")
            print(f"  Mean accuracy A: {np.mean(acc_a):.4f}")
            print(f"  Mean accuracy B: {np.mean(acc_b):.4f}")

            result = run_wilcoxon_test(acc_a, acc_b, name_a, name_b)
            if result is not None:
                print(f"  Wilcoxon statistic: {result['statistic']:.4f}")
                print(f"  p-value: {result['pvalue']:.6f}")
                print(f"  Significant: {'Yes' if result['pvalue'] < 0.05 else 'No'} (α=0.05)")

print("\n" + "=" * 80)
print("Statistical Tests Complete")
print("=" * 80)


In [None]:
# -----------------
# Statistical tests comparing the different models to each other.
# Does it make sense to compare all model pairs, or only the models trained on the same dataset (pairwise model comparisons)? Apply a Bonferroni correction?
#

In [None]:
# Specific subjects analysis
# Analyze test accuracies for a hand-picked set of subjects with the OLD-dataset models.

SPEC_SUBJECTS = [10135, 10171, 10193, 10203, 10204]
SPECIFIC_USE = 'fooof_only'  # 'fooof_only' or 'both'
USE_PSD = SPECIFIC_USE not in {'fooof_only', 'fooof'}

old_fooof = _get_model_by_prefix('old_dataset__fooof')
old_psd = _get_model_by_prefix('old_dataset__no_fooof') if USE_PSD else None


def _spec_arrays(df: pd.DataFrame):
    """Return (y_true, P(EC), y_pred at threshold 0.5) from a mapped prediction frame."""
    y_true = df['y_true_model'].to_numpy(dtype=int)
    prob = df['prob_ec'].to_numpy(dtype=float)
    return y_true, prob, (prob >= 0.5).astype(int)


def _has_both_classes(y_true) -> bool:
    """True when at least two distinct labels occur, i.e. ROC/AUC are defined."""
    return np.unique(y_true).size >= 2


def _summ(df: pd.DataFrame, name: str):
    """Print accuracy / confusion / AUC for *df* and return per-subject accuracy.

    Returns a DataFrame with ``subject_id`` and ``accuracy`` columns, or None
    when *df* is missing or empty.
    """
    if df is None or df.empty:
        print(name, ': no rows')
        return None
    y_true, prob, y_pred = _spec_arrays(df)
    print("\n" + "="*60)
    print(name)
    print("="*60)
    print('Accuracy:', accuracy_score(y_true, y_pred))
    # labels=[0, 1] keeps the matrix 2x2 (EO, EC) even if a class is absent.
    print('Confusion:\n', confusion_matrix(y_true, y_pred, labels=[0, 1]))
    if _has_both_classes(y_true):
        try:
            print('AUC:', roc_auc_score(y_true, prob))
        except ValueError as err:  # e.g. NaN probabilities
            print('AUC: could not be computed -', err)
    else:
        print('AUC: undefined (only one class present)')

    # Per-subject accuracy = mean of per-epoch correctness. Avoids
    # DataFrameGroupBy.apply over the grouping column, which pandas deprecates.
    per_subj = (
        df.assign(_correct=(y_true == y_pred))
        .groupby('subject_id')['_correct']
        .mean()
        .rename('accuracy')
        .reset_index()
    )
    print('\nPer-subject accuracy:')
    print(per_subj.sort_values('subject_id').to_string(index=False))
    return per_subj


if old_fooof is None:
    print('Missing old FOOOF model outputs.')
elif USE_PSD and old_psd is None:
    print('Missing old PSD model outputs (set SPECIFIC_USE=fooof_only to skip).')
else:
    df_fooof = attach_subject_mapping(old_fooof, index_map_old)
    df_fooof = df_fooof[df_fooof['subject_id'].isin(SPEC_SUBJECTS)].copy()

    df_psd = pd.DataFrame()
    if old_psd is not None:
        df_psd = attach_subject_mapping(old_psd, index_map_old)
        df_psd = df_psd[df_psd['subject_id'].isin(SPEC_SUBJECTS)].copy()

    print('Rows (fooof):', len(df_fooof), '| Rows (psd):', len(df_psd))

    per_fooof = _summ(df_fooof, 'old_fooof (specific subjects)')
    per_psd = _summ(df_psd, 'old_psd (specific subjects)') if USE_PSD else None

    # ROC curves (skipped per model when only one class is present).
    roc_inputs = [('old_fooof', df_fooof)]
    if USE_PSD:
        roc_inputs.append(('old_psd', df_psd))
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    for label, df in roc_inputs:
        if df is None or df.empty:
            continue
        y_true, prob, _ = _spec_arrays(df)
        if not _has_both_classes(y_true):
            print(f'{label}: ROC skipped (only one class present)')
            continue
        fpr, tpr, _ = roc_curve(y_true, prob)
        ax.plot(fpr, tpr, label=label)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.4)
    ax.set_xlabel('False positive rate')
    ax.set_ylabel('True positive rate')
    ax.set_title('ROC (specific subjects, OLD dataset)')
    ax.grid(True, alpha=0.3)
    if ax.get_legend_handles_labels()[0]:  # avoid "no artists with labels" warning
        ax.legend()
    plt.tight_layout()
    plt.show()

    # Confusion matrices
    if USE_PSD:
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        for ax, df, label in [(axes[0], df_fooof, 'old_fooof'), (axes[1], df_psd, 'old_psd')]:
            if df is None or df.empty:
                ax.axis('off')
                continue
            y_true, _, y_pred = _spec_arrays(df)
            # labels=[0, 1]: the two tick labels below require a 2x2 matrix.
            cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                        xticklabels=['EO', 'EC'], yticklabels=['EO', 'EC'])
            ax.set_title(label)
            ax.set_xlabel('Predicted')
            ax.set_ylabel('True')
        plt.tight_layout()
        plt.show()

    # Accuracy per subject plot
    if per_fooof is not None and (not USE_PSD or per_psd is not None):
        if USE_PSD:
            merged = per_fooof.merge(per_psd, on='subject_id', how='outer',
                                     suffixes=('_fooof', '_psd')).sort_values('subject_id')
            x = np.arange(len(merged))
            width = 0.35
            plt.figure(figsize=(10, 4))
            plt.bar(x - width/2, merged['accuracy_fooof'], width, label='old_fooof')
            plt.bar(x + width/2, merged['accuracy_psd'], width, label='old_psd')
            plt.xticks(x, [str(int(s)) for s in merged['subject_id']], rotation=45, ha='right')
            plt.ylim(0, 1.02)
            plt.ylabel('accuracy')
            plt.title('Per-subject accuracy (specific subjects)')
            plt.grid(True, alpha=0.3, axis='y')
            plt.legend()
            plt.tight_layout()
            plt.show()
        else:
            merged = per_fooof.sort_values('subject_id')
            x = np.arange(len(merged))
            plt.figure(figsize=(8, 4))
            plt.bar(x, merged['accuracy'], label='old_fooof')
            plt.xticks(x, [str(int(s)) for s in merged['subject_id']], rotation=45, ha='right')
            plt.ylim(0, 1.02)
            plt.ylabel('accuracy')
            plt.title('Per-subject accuracy (specific subjects, fooof only)')
            plt.grid(True, alpha=0.3, axis='y')
            plt.legend()
            plt.tight_layout()
            plt.show()