# EEG Classification Analysis

This notebook performs a comprehensive analysis of a labeled EEG dataset with FOOOF parameters:
- Mean center frequency (FOOOF CF) vs. age
- Selection of the most confident EC/EO epochs (default: top 60 per subject and class)
- Mean absolute alpha power calculations
- Visualizations and statistics

## 1. Imports and Configuration

This notebook analyzes classifier labeling outputs and relates them to age/sex metadata.

Key ideas:
- Epoch selection is based on classifier confidence (from `prob_ec`).
- Relative alpha is computed from the EEG time series using Welch PSD (alpha 8–13 Hz divided by total 1–40 Hz).
- Expensive steps write caches under the run-specific `OUTPUT_DIR` so reruns can be faster.


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from mne.time_frequency import psd_array_welch
import os
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import glob
import re

# Global plotting defaults for every figure produced by this notebook:
# seaborn's "whitegrid" theme and a 12x7 inch default figure size.
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 7)
from scipy.interpolate import make_interp_spline
from scipy import stats


In [None]:
# --- Configuration (repo-connected + Windows/WSL friendly) ---
import platform

def _detect_notebook_path() -> Optional[Path]:
    """Detect the current notebook path (best-effort)."""
    try:
        vsc = globals().get("__vsc_ipynb_file__", None)
        if vsc:
            p = Path(str(vsc)).expanduser()
            if p.suffix.lower() == ".ipynb" and p.exists():
                return p.resolve()
    except Exception:
        pass
    for key in ("NOTEBOOK_PATH", "IPYNB_PATH"):
        v = os.getenv(key)
        if v:
            try:
                p = Path(v).expanduser()
                if p.suffix.lower() == ".ipynb" and p.exists():
                    return p.resolve()
            except Exception:
                pass
    cwd = Path.cwd().resolve()
    for _ in range(6):
        cand = cwd / "New_EEG" / "analysis_script.ipynb"
        if cand.exists():
            return cand.resolve()
        cwd = cwd.parent
    return None

def _is_wsl() -> bool:
    try:
        return "microsoft" in platform.uname().release.lower()
    except Exception:
        return False

def resolve_windows_path(p: str) -> Path:
    """Translate a Windows drive path to its ``/mnt/<drive>/...`` form on WSL.

    Only applies on WSL and only to inputs of the form ``X:\\rest`` or
    ``X:/rest``; backslashes in the remainder become forward slashes and the
    drive letter is lower-cased. Every other input is wrapped in ``Path``
    unchanged.
    """
    text = str(p)
    if not _is_wsl():
        return Path(text)
    drive_match = re.match(r"^([A-Za-z]):[\\/](.*)$", text)
    if drive_match is None:
        return Path(text)
    drive_letter, remainder = drive_match.groups()
    return Path("/mnt/" + drive_letter.lower() + "/" + remainder.replace("\\", "/"))

NOTEBOOK_PATH = _detect_notebook_path()
NOTEBOOK_DIR = NOTEBOOK_PATH.parent if NOTEBOOK_PATH is not None else Path.cwd().resolve()
REPO_ROOT = NOTEBOOK_DIR.parent if NOTEBOOK_DIR.name == "New_EEG" else NOTEBOOK_DIR


def _env_flag(name: str, default: str = "1") -> bool:
    """Read an on/off toggle from the environment.

    Only an explicit "off" value disables the flag: "0", "false", "no" or "off",
    compared case-insensitively after stripping whitespace (so "FALSE" works too).
    Any other value, or *default* when the variable is unset, enables it.
    """
    return os.getenv(name, default).strip().lower() not in {"0", "false", "no", "off"}


# Which labeling run to analyze
RUN_FOLDER = os.getenv(
    "RUN_FOLDER",
    "old_dataset__fooof__allch__cv2__time_align_conditions__one_main_fooof__mainfooof_all_epochs__pen_l2",
)
_is_new_dataset = str(RUN_FOLDER).startswith("new_dataset")

# --- Analysis toggles ---
# Drop center-frequency datapoints where cf==0 (often indicates a failed/empty fit)
DROP_ZERO_CF = _env_flag("DROP_ZERO_CF", "1")
ZERO_CF_EPS = float(os.getenv("ZERO_CF_EPS", "0"))  # set e.g. 1e-6 to drop near-zero too

# Label mapping used throughout this notebook (from Label_with_EC_EO_Classifier):
# prob_ec = P(EC), so label 1 == EC and label 0 == EO
LABEL_EO = 0
LABEL_EC = 1

# Epoch-selection settings (section 6)
CONF_THRESH = float(os.getenv("CONF_THRESH", "0.90"))  # symmetric rule uses prob_ec >= t for EC and <= (1-t) for EO
USE_SELECTION_CACHE = _env_flag("USE_SELECTION_CACHE", "1")

# Epoch selection mode:
# - "threshold": keep ALL epochs above CONF_THRESH (symmetric rule on prob_ec)
# - "top_k": keep TOP_K most confident epochs per subject per class (EC and EO)
SELECTION_MODE = os.getenv("SELECTION_MODE", "top_k").strip().lower()  # threshold | top_k
TOP_K = int(os.getenv("TOP_K", "60"))

# Alpha power computation settings (section 7)
USE_ALPHA_CACHE = _env_flag("USE_ALPHA_CACHE", "1")
PSD_ALPHA_FMIN = 8.0
PSD_ALPHA_FMAX = 13.0
PSD_TOTAL_FMIN = 1.0
PSD_TOTAL_FMAX = 40.0
# Backwards-compatible aliases
PSD_FMIN = PSD_ALPHA_FMIN
PSD_FMAX = PSD_ALPHA_FMAX
PSD_N_FFT = 200  # matches Project-main/mean_alpha_power.ipynb defaults
PSD_CHUNK_EPOCHS = 200  # process epochs in chunks to limit memory
OCCIPITAL_ROI = ["O1", "O2"]  # matches Project-main/mean_alpha_power.ipynb Step 5

# If channels are generic ("Ch1".."Ch19"), optionally map them to standard 10–20 names.
# NOTE: This assumes the dataset uses the common 19-channel ordering (Ch1..Ch19).
AUTO_RENAME_CH1_TO_1020 = _env_flag("AUTO_RENAME_CH1_TO_1020", "1")
CH1_TO_1020_ORDER_19 = [
    "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2",
    "F7", "F8", "T7", "T8", "P7", "P8", "Fz", "Cz", "Pz",
]

# NOTE: the Windows defaults below use single backslashes inside raw strings.
# A raw string such as r"G:\\dir" literally contains two backslashes.

# Labeling outputs root (external default on Windows; repo fallback)
LABELING_DIR = os.getenv("LABELING_DIR", r"G:\ChristianMusaeus\labeling")
labeling_root = resolve_windows_path(LABELING_DIR)
if labeling_root.name.lower() == "labeling":
    labeling_root = labeling_root / "preprocessed_setfiles"
if not labeling_root.exists():
    labeling_root = NOTEBOOK_DIR / "outputs" / "labeling" / "preprocessed_setfiles"

# Labeled predictions CSV (override with LABEL_PREDICTIONS_CSV)
LABELED_DATASET_PATH = os.getenv(
    "LABEL_PREDICTIONS_CSV",
    str(labeling_root / RUN_FOLDER / "label_predictions.csv"),
)

# Saved ONE_MAIN_FOOOF cache root (override with SAVED_FOOOF_DIR)
_saved_fooof_env = os.getenv("SAVED_FOOOF_DIR", "") or os.getenv("SAVED_FOOOF_ROOT", "")
if _saved_fooof_env:
    SAVED_FOOOF_PATH = str(resolve_windows_path(_saved_fooof_env))
else:
    cand_win = Path(r"G:\ChristianMusaeus\saved_fooof")
    if platform.system() == "Windows" and cand_win.exists():
        SAVED_FOOOF_PATH = str(cand_win)
    else:
        SAVED_FOOOF_PATH = str(NOTEBOOK_DIR / "outputs" / "saved_fooof")

# Metadata (old/new external defaults; override with METADATA_CSV)
METADATA_OLD_DEFAULT = os.getenv("METADATA_OLD_CSV", r"G:\ChristianMusaeus\metadata_time_filtered.csv")
METADATA_NEW_DEFAULT = os.getenv("METADATA_NEW_CSV", r"G:\ChristianMusaeus\EEG_sub_data_pseudoanonym.csv")
default_meta = METADATA_NEW_DEFAULT if _is_new_dataset else METADATA_OLD_DEFAULT
METADATA_PATH = os.getenv("METADATA_CSV", default_meta)
meta_path_obj = resolve_windows_path(METADATA_PATH)
if not meta_path_obj.exists():
    # Repo fallback (old metadata is available in-repo; new metadata may not be)
    cand_repo = REPO_ROOT / "data" / "metadata" / "metadata_time_filtered.csv"
    if cand_repo.exists():
        if _is_new_dataset:
            # The in-repo file holds the OLD dataset's metadata. Subject ids of a
            # new_dataset run may not refer to the same people, so ages could be
            # wrong. Keep the fallback but make it impossible to miss.
            print(
                f"WARNING: metadata not found at {meta_path_obj}; falling back to OLD-dataset "
                f"metadata {cand_repo} for a new_dataset run. Set METADATA_CSV to the correct file."
            )
        meta_path_obj = cand_repo
METADATA_PATH = str(meta_path_obj)

# Output folder (always under this notebook's outputs)
OUTPUTS_ROOT = NOTEBOOK_DIR / "outputs"
OUTPUT_DIR = str(OUTPUTS_ROOT / "analysis_script" / RUN_FOLDER)
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Notebook dir: {NOTEBOOK_DIR}")
print(f"Outputs root: {OUTPUTS_ROOT}")
print(f"Run folder: {RUN_FOLDER}")
print(f"Label predictions: {LABELED_DATASET_PATH}")
print(f"Saved FOOOF root: {SAVED_FOOOF_PATH}")
print(f"Metadata: {METADATA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")

## 2. Helper Functions

In [None]:
def load_labeled_data(path: str) -> pd.DataFrame:
    """Load labeled predictions from a single CSV or a directory of CSVs.

    Args:
        path: A CSV file, or a directory whose top-level ``*.csv`` files are
            read (in sorted filename order) and concatenated.

    Returns:
        All rows as one DataFrame (index reset when concatenating). For a
        per-file CSV with neither a ``subject_id`` nor a ``Test subject ID``
        column, ``subject_id`` is filled from the first ``sub-...`` token of the
        underscore-separated filename stem, when one exists.

    Raises:
        FileNotFoundError: *path* does not exist, or the directory has no CSVs.
    """
    if os.path.isfile(path):
        return pd.read_csv(path)
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Path not found: {path}")

    # glob gives no ordering guarantee; sort so the concatenated row order is
    # reproducible across machines and reruns.
    csv_files = sorted(glob.glob(os.path.join(path, "*.csv")))
    if not csv_files:
        raise FileNotFoundError(f"No CSV files found in {path}")

    frames = []
    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        if 'subject_id' not in df.columns and 'Test subject ID' not in df.columns:
            # e.g. "sub-001_predictions.csv" -> "sub-001"
            stem = Path(csv_file).stem
            subject_token = next((part for part in stem.split('_') if part.startswith('sub-')), None)
            if subject_token is not None:
                df['subject_id'] = subject_token
        frames.append(df)
    return pd.concat(frames, ignore_index=True)

def normalize_subject_id(subject_id) -> Optional[str]:
    """Normalize a subject identifier to the ``sub-XXX`` convention.

    - Missing values (``None``/NaN) -> ``None``.
    - Strings already starting with ``sub-`` are returned stripped, otherwise
      unchanged.
    - Numeric values (``1``, ``"1"``, ``1.0``) -> ``"sub-001"`` (zero-padded to
      at least three digits).
    - Other strings -> ``sub-`` plus their first run of digits, zero-padded
      (``"Sub12"`` -> ``"sub-012"``). Strings without digits are returned stripped.
    """
    if pd.isna(subject_id):
        return None

    subject_str = str(subject_id).strip()
    if subject_str.startswith('sub-'):
        return subject_str

    try:
        # int(float(...)) accepts "7", "7.0" and numeric types alike.
        return f"sub-{int(float(subject_str)):03d}"
    except (ValueError, TypeError, OverflowError):
        # OverflowError: "inf"/"-inf" parse as floats but cannot become ints.
        pass

    match = re.search(r'\d+', subject_str)
    if match:
        return f"sub-{int(match.group()):03d}"
    return subject_str

In [None]:
def _extract_subject_num(subject_id) -> Optional[int]:
    # Extract numeric subject id from strings like 'sub-001', 'Sub001', 1.
    if subject_id is None or (isinstance(subject_id, float) and np.isnan(subject_id)):
        return None
    s = str(subject_id).strip()
    m = re.search(r"(\d+)", s)
    if not m:
        return None
    try:
        return int(m.group(1))
    except Exception:
        return None

# Lazily built {subject key -> .npz path} index. find_saved_fooof_npz uses it
# as a fallback when its glob patterns find nothing. _SAVED_FOOOF_INDEX_ROOT
# records which root directory the index was built for; the index is rebuilt
# whenever a different root is requested.
_SAVED_FOOOF_INDEX = None
_SAVED_FOOOF_INDEX_ROOT = None

def _build_saved_fooof_index(saved_fooof_path: str) -> Dict[str, str]:
    # Build an index {subject_key -> npz_path} by scanning saved_fooof_path.
    root = Path(saved_fooof_path)
    index: Dict[str, str] = {}
    if not root.exists():
        return index

    npz_files = list(root.rglob('*.npz'))
    for p in npz_files:
        p_str = str(p).lower()
        name = p.name.lower()

        m = re.search(r"(?:subject|sub)[-_]?(\d{1,6})", name)
        if not m:
            m = re.search(r"(\d{3})", name)
        if not m:
            continue

        try:
            n = int(m.group(1))
        except Exception:
            continue

        z3 = f"{n:03d}"
        keys = [
            str(n),
            z3,
            f"sub-{z3}",
            f"sub_{z3}",
            f"sub{z3}",
            f"subject_{n}",
            f"subject_{z3}",
        ]

        for k in keys:
            if k not in index:
                index[k] = str(p)
            else:
                prev = index[k].lower()
                if 'cache_b' in p_str and 'cache_b' not in prev:
                    index[k] = str(p)

    print(f"Indexed {len(npz_files)} .npz files from: {saved_fooof_path}")
    print(f"Index contains {len(index)} subject keys")
    return index

def find_saved_fooof_npz(subject_id, saved_fooof_path: str) -> Optional[str]:
    """Locate the saved FOOOF ``.npz`` file for *subject_id* under *saved_fooof_path*.

    Fast path: a recursive glob for common filename spellings (``subject_1``,
    ``subject_001``, ``sub-001``, ``sub_001``, ``sub001``), tried in that order.
    A hit only counts if the token is not directly followed by another digit,
    so subject 1 never resolves to a ``subject_10``/``subject_12`` file. Among
    valid hits, paths containing ``cache_b`` are preferred, then the first in
    sorted order.

    Fallback: the filename index from ``_build_saved_fooof_index``. It is
    cached in module globals and rebuilt when the root changes.

    Returns the path as a string, or ``None`` when the root is missing, no
    subject number can be extracted, or nothing matches.
    """
    if not os.path.exists(saved_fooof_path):
        print(f"Saved FOOOF path not found: {saved_fooof_path}")
        return None

    subject_num = _extract_subject_num(subject_id)
    if subject_num is None:
        return None

    z3 = f"{subject_num:03d}"

    # Filename tokens for the fast glob path, in priority order.
    tokens = [
        f"subject_{subject_num}",
        f"subject_{z3}",
        f"sub-{z3}",
        f"sub_{z3}",
        f"sub{z3}",
    ]
    # Escape the root so characters such as '[' in directory names are not
    # interpreted as glob syntax.
    search_root = glob.escape(saved_fooof_path)

    for token in tokens:
        hits = glob.glob(os.path.join(search_root, "**", f"*{token}*.npz"), recursive=True)
        # "*subject_1*" also matches subject_10, subject_123, ...: keep only hits
        # where the token is not immediately followed by another digit.
        exact = re.compile(re.escape(token) + r"(?!\d)", re.IGNORECASE)
        hits = sorted(h for h in hits if exact.search(os.path.basename(h)))
        if hits:
            cache_b = [h for h in hits if 'cache_b' in h.lower()]
            return cache_b[0] if cache_b else hits[0]

    # Robust fallback: indexed scan (built once per root)
    global _SAVED_FOOOF_INDEX, _SAVED_FOOOF_INDEX_ROOT
    if _SAVED_FOOOF_INDEX is None or _SAVED_FOOOF_INDEX_ROOT != saved_fooof_path:
        _SAVED_FOOOF_INDEX = _build_saved_fooof_index(saved_fooof_path)
        _SAVED_FOOOF_INDEX_ROOT = saved_fooof_path

    keys = [
        f"sub-{z3}",
        f"sub_{z3}",
        f"sub{z3}",
        z3,
        str(subject_num),
        f"subject_{subject_num}",
        f"subject_{z3}",
    ]
    for k in keys:
        if k in _SAVED_FOOOF_INDEX:
            return _SAVED_FOOOF_INDEX[k]

    return None

def load_fooof_data(npz_path: str) -> Tuple[Optional[np.ndarray], Optional[List[str]], Optional[float]]:
    """Load a saved FOOOF ``.npz`` file.

    Returns ``(X, feature_names, alpha_cf)``:
      - ``X``: the ``'X'`` array cast to float, or ``None`` if the key is absent.
      - ``feature_names``: ``'feature_names'`` flattened to a list of str, or ``None``.
      - ``alpha_cf``: NaN-mean of the first key present among ``alpha_cf``,
        ``cf``, ``center_frequency`` and ``alpha_center_frequency``. ``None``
        when none is present or that array is empty.

    Any error while reading is printed and ``(None, None, None)`` is returned.
    """
    X = None
    feature_names = None
    alpha_cf = None
    try:
        # The context manager closes the archive's file handle. An open handle
        # would otherwise stay around, and on Windows it blocks moving/deleting
        # the file. allow_pickle: assumes these are trusted local caches.
        with np.load(npz_path, allow_pickle=True) as data:
            if 'X' in data:
                X = np.asarray(data['X'], dtype=float)
            if 'feature_names' in data:
                feature_names = [str(x) for x in np.asarray(data['feature_names']).ravel().tolist()]

            # Support a few possible key names; only the first one present is used.
            for k in ('alpha_cf', 'cf', 'center_frequency', 'alpha_center_frequency'):
                if k in data:
                    arr = np.asarray(data[k])
                    if arr.size:
                        alpha_cf = float(np.nanmean(arr))
                    break
    except Exception as e:
        print(f"Error loading {npz_path}: {e}")
        return None, None, None

    return X, feature_names, alpha_cf


In [None]:
def extract_alpha_power_from_features(X: np.ndarray, feature_names: List[str], channel_filter: Optional[List[str]] = None) -> np.ndarray:
    """Per-epoch mean absolute ``*_alpha_amp`` feature value.

    Args:
        X: Feature matrix; rows are epochs, columns align with *feature_names*.
        feature_names: Column names. Alpha-amplitude columns end in
            ``_alpha_amp`` (e.g. ``"O1_alpha_amp"``).
        channel_filter: Optional channel names to keep, compared
            case-insensitively (e.g. occipital channels). ``None`` keeps all.

    Returns:
        A 1-D array with one value per row of *X*: the NaN-mean over the
        selected channels of ``abs(alpha_amp)``. Empty if *X* or
        *feature_names* is ``None``; all zeros if no matching ``_alpha_amp``
        column exists.
    """
    if X is None or feature_names is None:
        return np.array([])

    suffix = '_alpha_amp'
    # Upper-case the filter once instead of rebuilding it for every feature name.
    wanted = None if channel_filter is None else {c.upper() for c in channel_filter}

    # Channel = name minus the trailing suffix only (e.g. "O1_alpha_amp" -> "O1").
    alpha_amp_indices = [
        idx
        for idx, name in enumerate(feature_names)
        if name.endswith(suffix)
        and (wanted is None or name[: -len(suffix)].upper() in wanted)
    ]

    if not alpha_amp_indices:
        # No alpha_amp features found
        return np.zeros(X.shape[0])

    # Absolute amplitude, averaged across the selected channels.
    return np.nanmean(np.abs(X[:, alpha_amp_indices]), axis=1)

def get_occipital_channels() -> List[str]:
    """Return a fresh list of upper-case occipital/parieto-occipital channel names."""
    occipital = ('O1', 'O2', 'OZ', 'PO3', 'POZ', 'PO4')
    return list(occipital)

def canonical_channel_name(ch_name: str) -> str:
    """Normalize an EEG channel label for name-based matching.

    The steps are:
      1. strip surrounding whitespace;
      2. drop a leading ``EEG`` prefix followed by whitespace (case-insensitive);
      3. drop a trailing ``-REF`` suffix (case-insensitive);
      4. remove all remaining whitespace.

    Letter case is otherwise preserved.
    """
    substitutions = (
        (r"^EEG\s+", re.IGNORECASE),
        (r"-REF$", re.IGNORECASE),
        (r"\s+", 0),
    )
    cleaned = str(ch_name).strip()
    for pattern, flags in substitutions:
        cleaned = re.sub(pattern, "", cleaned, flags=flags)
    return cleaned


def maybe_rename_ch1_to_1020(epochs):
    """Rename generic ``Ch1``..``Ch19`` channels to 10-20 labels, in place.

    The rename is gated by the module-level ``AUTO_RENAME_CH1_TO_1020`` flag.
    Target labels come from the module-level ``CH1_TO_1020_ORDER_19`` list
    (``ChN`` -> ``order[N-1]``). It is applied only when all of these hold:
      - there are exactly 19 channels;
      - every name matches ``chN`` (case-insensitive);
      - the numbers are exactly 1..19;
      - the order list has 19 distinct labels.

    In every other case, or on any error while reading or renaming, the
    channels are left untouched. Only safe when the dataset's channel ordering
    is known to match the order list.

    Returns the same *epochs* object that was passed in.
    """
    if not globals().get('AUTO_RENAME_CH1_TO_1020', False):
        return epochs

    try:
        names = [str(n).strip() for n in epochs.ch_names]
    except Exception:
        return epochs
    if len(names) != 19:
        return epochs

    # Accept CH1/ch1/Ch1
    matches = [re.fullmatch(r"(?i)ch(\d+)", n) for n in names]
    if not all(matches):
        return epochs
    nums = [int(m.group(1)) for m in matches]
    if set(nums) != set(range(1, 20)):
        return epochs

    order = globals().get('CH1_TO_1020_ORDER_19', None)
    if not order or len(order) != 19:
        return epochs

    mapping = {orig: str(order[num - 1]) for orig, num in zip(names, nums)}
    if len(set(mapping.values())) != len(mapping.values()):
        # The order list contains duplicate labels; renaming would collide.
        return epochs

    try:
        epochs.rename_channels(mapping)
    except Exception:
        pass
    return epochs

def load_epochs_any(path_str: str):
    """Load epoched EEG from an EEGLAB ``.set`` file or an MNE ``.fif`` epochs file.

    The path first goes through ``resolve_windows_path`` (WSL drive mapping).
    A ``.set`` suffix (case-insensitive) is read with
    ``mne.io.read_epochs_eeglab``; anything else with ``mne.read_epochs``.

    Channel names are then normalized in two steps:
      1. canonicalized (``EEG `` prefix, ``-REF`` suffix and whitespace
         removed). This is best-effort: a failed rename, e.g. a name
         collision, keeps the original names.
      2. optionally mapped from generic ``Ch1``..``Ch19`` to 10-20 labels
         (``maybe_rename_ch1_to_1020``).

    Canonicalizing first lets labels such as ``"EEG Ch1-REF"`` qualify for the
    Ch->10-20 mapping, which only recognizes bare ``ChN`` names.
    """
    p = resolve_windows_path(path_str)
    if p.suffix.lower() == '.set':
        epochs = mne.io.read_epochs_eeglab(str(p), verbose='ERROR')
    else:
        epochs = mne.read_epochs(str(p), verbose='ERROR')

    try:
        epochs.rename_channels({ch: canonical_channel_name(ch) for ch in epochs.ch_names})
    except Exception:
        # Best-effort: keep the original names if canonicalization fails.
        pass

    return maybe_rename_ch1_to_1020(epochs)


def mean_relative_alpha_psd(epochs, epoch_indices: np.ndarray, picks_occ: np.ndarray) -> tuple[float, float, int]:
    """Mean relative alpha power over a subset of epochs (Welch PSD).

    Relative alpha = power in [PSD_ALPHA_FMIN, PSD_ALPHA_FMAX] divided by power
    in [PSD_TOTAL_FMIN, PSD_TOTAL_FMAX]. Band power is the sum of the PSD bins
    inside the band. Both bands come from the same PSD, so the bin width
    cancels in the ratio.

    Per channel, band powers are averaged over the selected epochs. The channel
    means are then averaged before taking the ratio (ratio of means).

    Args:
        epochs: MNE Epochs object. It must expose ``info``, ``ch_names``,
            integer-array indexing and ``get_data()``.
        epoch_indices: Integer positions of the epochs to include.
        picks_occ: ROI channel indices into the full channel list. Only those
            that are also EEG picks are used. May be None or empty.

    Returns:
        ``(relative_all_channels, relative_roi, n_epochs_used)``. Returns NaNs
        and 0 when nothing could be computed. ``relative_roi`` is NaN when no
        ROI channel is among the EEG picks.
    """
    epoch_indices = np.asarray(epoch_indices, dtype=int)
    if epoch_indices.size == 0:
        return float('nan'), float('nan'), 0

    picks_all = mne.pick_types(epochs.info, eeg=True, meg=False, stim=False, eog=False, exclude='bads')
    if len(picks_all) == 0:
        picks_all = np.arange(len(epochs.ch_names), dtype=int)

    # Map occipital picks (original channel indices) to indices within picks_all
    occ_in_picks_all = []
    if picks_occ is not None and len(picks_occ) > 0:
        idx_map = {int(full): i for i, full in enumerate(picks_all)}
        occ_in_picks_all = [idx_map[int(x)] for x in np.asarray(picks_occ, dtype=int) if int(x) in idx_map]

    sfreq = float(epochs.info['sfreq'])
    # One PSD over the union of both bands; each band is then selected by mask.
    # psd_array_welch keeps bins with fmin <= f <= fmax, so the masks select the
    # same bins a separate per-band call would return.
    fmin = float(min(PSD_ALPHA_FMIN, PSD_TOTAL_FMIN))
    fmax = float(max(PSD_ALPHA_FMAX, PSD_TOTAL_FMAX))
    chunk_size = int(PSD_CHUNK_EPOCHS)

    sum_alpha_ch = None
    sum_total_ch = None
    n_used = 0

    for start in range(0, int(epoch_indices.size), chunk_size):
        chunk = epoch_indices[start:start + chunk_size]
        try:
            data = epochs[chunk].get_data()
        except Exception:
            data = epochs.get_data()[chunk]

        data = data[:, picks_all, :]

        # Welch rejects n_fft > n_times; clamp for short epochs instead of failing.
        n_fft = min(int(PSD_N_FFT), int(data.shape[-1]))
        psds, freqs = psd_array_welch(
            data,
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            n_fft=n_fft,
            verbose=False,
        )

        alpha_mask = (freqs >= PSD_ALPHA_FMIN) & (freqs <= PSD_ALPHA_FMAX)
        total_mask = (freqs >= PSD_TOTAL_FMIN) & (freqs <= PSD_TOTAL_FMAX)

        # Sum across frequency bins to approximate band power: (epochs, channels)
        pow_a = psds[..., alpha_mask].sum(axis=-1)
        pow_t = psds[..., total_mask].sum(axis=-1)

        if sum_alpha_ch is None:
            sum_alpha_ch = np.zeros(pow_a.shape[1], dtype=np.float64)
            sum_total_ch = np.zeros(pow_t.shape[1], dtype=np.float64)

        sum_alpha_ch += np.nansum(pow_a, axis=0)
        sum_total_ch += np.nansum(pow_t, axis=0)
        n_used += int(pow_a.shape[0])

    if sum_alpha_ch is None or sum_total_ch is None or n_used == 0:
        return float('nan'), float('nan'), 0

    mean_alpha_ch = sum_alpha_ch / float(n_used)
    mean_total_ch = sum_total_ch / float(n_used)

    rel_all = float(np.nanmean(mean_alpha_ch) / np.nanmean(mean_total_ch)) if np.nanmean(mean_total_ch) else float('nan')

    rel_occ = float('nan')
    if occ_in_picks_all:
        denom = float(np.nanmean(mean_total_ch[occ_in_picks_all]))
        if denom:
            rel_occ = float(np.nanmean(mean_alpha_ch[occ_in_picks_all]) / denom)

    return rel_all, rel_occ, n_used


## 3. Load Data

In [None]:
# Load the classifier's label predictions: one CSV, or every CSV in a directory.
print("Loading labeled data...")
try:
    labeled_df = load_labeled_data(LABELED_DATASET_PATH)
    print(f"Loaded {len(labeled_df)} labeled epochs")
    print(f"Columns: {labeled_df.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(labeled_df.head())
except Exception as e:
    # Report, then re-raise: later cells cannot run without labeled_df.
    print(f"Error loading labeled data: {e}")
    raise

In [None]:
# Standardize column names so later cells can rely on:
#   subject_id (sub-XXX), epoch_idx, prediction, probability (+ prob_ec when present)
if 'Test subject ID' in labeled_df.columns:
    labeled_df['subject_id'] = labeled_df['Test subject ID'].apply(normalize_subject_id)
elif 'subject_id' in labeled_df.columns:
    labeled_df['subject_id'] = labeled_df['subject_id'].apply(normalize_subject_id)
else:
    # Every later step groups/merges on subject_id. Fail here with a clear
    # message instead of a KeyError further down.
    raise ValueError("Could not find 'Test subject ID' or 'subject_id' column")

if 'Epoch number' in labeled_df.columns:
    labeled_df['epoch_idx'] = labeled_df['Epoch number']
elif 'epoch' in labeled_df.columns:
    labeled_df['epoch_idx'] = labeled_df['epoch']
elif 'epoch_idx' not in labeled_df.columns:
    # Assume epochs are sequential from 0.
    # NOTE(review): this is the row index of the whole table, not a per-subject
    # counter; only correct if that matches how epochs are numbered.
    labeled_df['epoch_idx'] = labeled_df.index

if 'Label' in labeled_df.columns:
    labeled_df['prediction'] = labeled_df['Label']
elif 'prediction' not in labeled_df.columns:
    raise ValueError("Could not find 'Label' or 'prediction' column")

# Keep both probability conventions when available:
# - prob_ec: probability of EC (label=1)
# - probability: probability of the predicted label (for confidence sorting)
if 'prob_ec' in labeled_df.columns:
    labeled_df['prob_ec'] = pd.to_numeric(labeled_df['prob_ec'], errors='coerce')

if 'Probability' in labeled_df.columns:
    labeled_df['probability'] = pd.to_numeric(labeled_df['Probability'], errors='coerce')
elif 'probability' in labeled_df.columns:
    labeled_df['probability'] = pd.to_numeric(labeled_df['probability'], errors='coerce')
elif 'prob_ec' in labeled_df.columns:
    # prob_ec = P(EC). Label mapping: 1=EC, 0=EO.
    # Probability of predicted label is prob_ec for EC predictions, and 1-prob_ec for EO predictions.
    # pd.to_numeric (unlike astype(int)) tolerates missing predictions; those rows
    # get a NaN probability. Non-numeric labels still raise.
    _pred_numeric = pd.to_numeric(labeled_df['prediction'])
    labeled_df['probability'] = np.where(
        _pred_numeric.isna(),
        np.nan,
        np.where(
            _pred_numeric == LABEL_EC,
            labeled_df['prob_ec'],
            1.0 - labeled_df['prob_ec'],
        ),
    )
else:
    raise ValueError("Could not find 'Probability', 'probability', or 'prob_ec' column")

print(f"Standardized columns. Unique subjects: {labeled_df['subject_id'].nunique()}")
print(f"Label distribution: {labeled_df['prediction'].value_counts().to_dict()}")

# Sanity check: if prob_ec exists, verify probability is the probability of the predicted label
if 'prob_ec' in labeled_df.columns:
    df_tmp = labeled_df.dropna(subset=["prob_ec","prediction","probability"]).copy()
    df_tmp["prediction"] = df_tmp["prediction"].astype(int)
    ec = df_tmp[df_tmp["prediction"] == LABEL_EC]
    eo = df_tmp[df_tmp["prediction"] == LABEL_EO]
    if len(ec):
        print("Sanity check (EC preds): mean(probability - prob_ec)=", float((ec["probability"]-ec["prob_ec"]).mean()))
    if len(eo):
        print("Sanity check (EO preds): mean(probability - (1-prob_ec))=", float((eo["probability"]-(1.0-eo["prob_ec"])).mean()))



In [None]:
# Load per-subject metadata and normalize it to the two columns used downstream:
# 'subject_id' (normalized via normalize_subject_id) and 'Age'.
print("\nLoading metadata...")
try:
    metadata_df = pd.read_csv(METADATA_PATH)
    print(f"Loaded {len(metadata_df)} metadata entries")
    print(f"Columns: {metadata_df.columns.tolist()}")
    
    # Normalize subject ID column
    if 'Subject_ID' in metadata_df.columns:
        metadata_df['subject_id'] = metadata_df['Subject_ID'].apply(normalize_subject_id)
    elif 'subject_id' in metadata_df.columns:
        metadata_df['subject_id'] = metadata_df['subject_id'].apply(normalize_subject_id)
    else:
        # Fallback: use first column
        # NOTE(review): assumes the first column holds subject ids; verify for new metadata files.
        first = metadata_df.columns[0]
        metadata_df['subject_id'] = metadata_df[first].apply(normalize_subject_id)
    
    # Ensure Age column exists (supports old CSV: age/Age, and new: Y)
    if 'Age' not in metadata_df.columns:
        if 'age' in metadata_df.columns:
            metadata_df['Age'] = metadata_df['age']
        elif 'Y' in metadata_df.columns:
            metadata_df['Age'] = metadata_df['Y']
        else:
            raise ValueError("Could not find an age column (Age/age/Y) in metadata")
    
    print(f"\nFirst few rows:")
    print(metadata_df[['subject_id', 'Age']].head())
except Exception as e:
    # Report, then re-raise: later cells need metadata_df.
    print(f"Error loading metadata: {e}")
    raise

In [None]:
# Merge labeled data with metadata
cols = ['subject_id', 'epoch_idx', 'prediction', 'probability']
if 'prob_ec' in labeled_df.columns:
    cols.append('prob_ec')
if 'file' in labeled_df.columns:
    cols.append('file')

analysis_df = pd.merge(
    labeled_df[cols],
    metadata_df[['subject_id', 'Age']],
    on='subject_id',
    how='left'
)

print(f"Merged data: {len(analysis_df)} entries")
print(f"Subjects with metadata: {analysis_df['Age'].notna().sum() / len(analysis_df) * 100:.1f}%")
print(f"Unique subjects: {analysis_df['subject_id'].nunique()}")


## 4. Extract FOOOF Parameters (CF only)

We use the `saved_fooof` cache to extract **alpha center frequency (CF)** per subject.

Important:
- This is a **single value per subject** (not per epoch).
- We intentionally avoid loading the full per-epoch feature matrix, because it is large and not needed for the PSD-based alpha power computation later.


In [None]:
print("Extracting subject-level alpha center frequency (CF) from saved FOOOF cache...")

# This step only extracts the subject-level alpha center frequency (alpha_cf).
# Only the CF array is read from each .npz; the per-epoch feature matrix (X),
# which would be slow and memory-heavy, is never accessed.

USE_CF_CACHE = os.getenv('USE_CF_CACHE', '1').strip() not in {'0','false','False'}
cf_cache = os.path.join(OUTPUT_DIR, 'cf_by_subject.csv.gz')

subject_cf_map = {}


def _read_alpha_cf(npz_path: str) -> Optional[float]:
    """Return the subject-level alpha CF stored in *npz_path*.

    Uses the first key present among 'alpha_cf', 'cf', 'center_frequency' and
    'alpha_center_frequency' (the same order as load_fooof_data) and returns
    its NaN-mean. Returns None when no key is present, the array is empty, or
    the file cannot be read (the error is printed).
    """
    try:
        # allow_pickle mirrors load_fooof_data; assumes trusted local caches.
        with np.load(npz_path, allow_pickle=True) as data:
            for key in ('alpha_cf', 'cf', 'center_frequency', 'alpha_center_frequency'):
                if key in data:
                    arr = np.asarray(data[key])
                    return float(np.nanmean(arr)) if arr.size else None
    except Exception as e:
        print(f"Error loading {npz_path}: {e}")
    return None


if USE_CF_CACHE and os.path.exists(cf_cache):
    cf_df = pd.read_csv(cf_cache, compression='gzip')
    if {'subject_id','cf'}.issubset(cf_df.columns):
        subject_cf_map = dict(zip(cf_df['subject_id'].astype(str), pd.to_numeric(cf_df['cf'], errors='coerce')))
        subject_cf_map = {k: float(v) for k,v in subject_cf_map.items() if pd.notna(v)}
        print(f"Loaded cached CF for {len(subject_cf_map)} subjects: {cf_cache}")
    else:
        print(f"CF cache missing required columns; ignoring: {cf_cache}")

if not subject_cf_map:
    unique_subjects = sorted(analysis_df['subject_id'].dropna().unique())

    for i, subject_id in enumerate(unique_subjects, start=1):
        npz_path = find_saved_fooof_npz(subject_id, SAVED_FOOOF_PATH)
        alpha_cf = _read_alpha_cf(npz_path) if npz_path else None
        if alpha_cf is not None and not np.isnan(alpha_cf):
            subject_cf_map[str(subject_id)] = float(alpha_cf)

        # Placed after the lookup so subjects without a CF still trigger the report.
        if i % 50 == 0:
            print(f"  processed {i}/{len(unique_subjects)} subjects... (CF found for {len(subject_cf_map)})")

    cf_df = pd.DataFrame({
        'subject_id': list(subject_cf_map.keys()),
        'cf': list(subject_cf_map.values()),
    })
    cf_df.to_csv(cf_cache, index=False, compression='gzip')
    print(f"Saved CF cache: {cf_cache}")

print()
print(f"Extracted CF for {len(subject_cf_map)} subjects")


## 5. Plot: Mean Center Frequency vs Age

In [None]:
# Create dataframe for CF vs Age.
# Age comes from the first metadata row per subject_id, the same row that a
# per-subject filter + iloc[0] would pick. A dict lookup replaces one full
# scan of metadata_df per subject.
_age_by_subject = (
    metadata_df.drop_duplicates(subset='subject_id', keep='first')
    .set_index('subject_id')['Age']
    .to_dict()
)
cf_data = [
    {'subject_id': subject_id, 'cf': cf, 'Age': _age_by_subject[subject_id]}
    for subject_id, cf in subject_cf_map.items()
    if subject_id in _age_by_subject and pd.notna(_age_by_subject[subject_id])
]

cf_df = pd.DataFrame(cf_data)

if len(cf_df) > 0:
    # Optional: drop cf == 0 datapoints (can heavily affect mean/CI)
    cf_df['cf'] = pd.to_numeric(cf_df['cf'], errors='coerce')
    cf_df['Age'] = pd.to_numeric(cf_df['Age'], errors='coerce')

    before_n = len(cf_df)
    before_mean = cf_df['cf'].replace([np.inf, -np.inf], np.nan).dropna().mean()

    if DROP_ZERO_CF:
        if ZERO_CF_EPS > 0:
            cf_df = cf_df[cf_df['cf'].abs() > float(ZERO_CF_EPS)]
        else:
            cf_df = cf_df[cf_df['cf'] != 0]

    after_n = len(cf_df)
    after_mean = cf_df['cf'].replace([np.inf, -np.inf], np.nan).dropna().mean()

    print(f"CF rows before filter: {before_n} | after: {after_n} | DROP_ZERO_CF={DROP_ZERO_CF} | ZERO_CF_EPS={ZERO_CF_EPS}")
    # A NaN mean formats as "nan", so no guard is needed here.
    print(f"Mean CF before: {before_mean:.2f} | after: {after_mean:.2f}")

    # The regression fit and CI need finite numbers; rows with NaN/inf (e.g.
    # unparsable ages) are left out of the plot and summary stats.
    cf_plot_df = cf_df.replace([np.inf, -np.inf], np.nan).dropna(subset=['Age', 'cf'])

    if cf_plot_df.empty:
        print("No finite CF/Age pairs left to plot after filtering")
    else:
        plt.figure(figsize=(10, 6))
        sns.regplot(x='Age', y='cf', data=cf_plot_df, scatter_kws={'alpha': 0.3}, label='Data points')
        sns.lineplot(x='Age', y='cf', data=cf_plot_df, estimator='mean', errorbar=('ci', 95),
                     color='red', linewidth=2, label='Mean with 95% CI')
        plt.title('Mean Center Frequency (FOOOF CF) vs Age (95% CI)', fontsize=14, fontweight='bold')
        plt.xlabel('Age', fontsize=12)
        plt.ylabel('Mean Center Frequency (Hz)', fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()

        fname = 'mean_cf_vs_age.png' if not DROP_ZERO_CF else 'mean_cf_vs_age_drop0.png'
        output_path = os.path.join(OUTPUT_DIR, fname)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Plot saved to {output_path}")
        print(f"Mean CF: {cf_plot_df['cf'].mean():.2f} Hz, SD: {cf_plot_df['cf'].std():.2f} Hz")
else:
    print("No CF data available for plotting")


## 6. Select EC and EO Epochs

This step decides *which epochs* to use for downstream analyses. You can choose between two modes:

- `SELECTION_MODE = 'threshold'`: keep **all** epochs above the confidence threshold using the symmetric rule on `prob_ec = P(EC)`:
  - EC kept if `prob_ec >= CONF_THRESH`
  - EO kept if `prob_ec <= 1 - CONF_THRESH`
- `SELECTION_MODE = 'top_k'`: keep the **TOP_K most confident** EC epochs and the **TOP_K most confident** EO epochs per subject (sorted by probability of the predicted label).

The selected epoch list is cached under `OUTPUT_DIR` so reruns don’t need to re-filter millions of rows.


In [None]:
print("Selecting EC/EO epochs...")

# We support two selection modes:
# - SELECTION_MODE='threshold': keep ALL epochs above the confidence threshold (symmetric rule on prob_ec)
# - SELECTION_MODE='top_k': keep TOP_K most confident epochs per subject per class
#
# prob_ec is always P(EC). With label mapping: 1=EC, 0=EO.
# The confidence of the predicted label is:
#   probability = prob_ec           if prediction==1 (EC)
#   probability = 1 - prob_ec       if prediction==0 (EO)

mode = str(SELECTION_MODE).strip().lower()
if mode not in {'threshold','top_k','topk'}:
    raise ValueError(f"Unknown SELECTION_MODE: {SELECTION_MODE!r} (use 'threshold' or 'top_k')")


def _conf_tag(thresh: float) -> str:
    """Filename tag for a confidence threshold.

    Whole percentages keep the historical form (0.90 -> "90"), so existing
    caches are still found. Other values get extra precision (0.905 -> "90p5").
    Plain int(round(t*100)) maps 0.905 to "90" via banker's rounding, so
    distinct thresholds would share one cache file.
    """
    pct = float(thresh) * 100.0
    if abs(pct - round(pct)) < 1e-9:
        return str(int(round(pct)))
    return f"{pct:.6g}".replace('.', 'p')


if mode in {'top_k','topk'}:
    cache_path = os.path.join(OUTPUT_DIR, f"selected_epochs_topk{int(TOP_K)}.csv.gz")
else:
    cache_path = os.path.join(OUTPUT_DIR, f"selected_epochs_threshold_conf{_conf_tag(CONF_THRESH)}.csv.gz")

if USE_SELECTION_CACHE and os.path.exists(cache_path):
    top_epochs_df = pd.read_csv(cache_path, compression='gzip')
    print(f"Loaded cached selection: {cache_path}")
else:
    if 'prob_ec' not in analysis_df.columns:
        raise ValueError("analysis_df is missing 'prob_ec'. Ensure the labeled predictions include prob_ec.")

    cols = ['subject_id','epoch_idx','prediction','prob_ec']
    if 'file' in analysis_df.columns:
        cols.append('file')
    df = analysis_df[cols].copy()

    df['prob_ec'] = pd.to_numeric(df['prob_ec'], errors='coerce')
    df['prediction'] = pd.to_numeric(df['prediction'], errors='coerce')
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=['subject_id','epoch_idx','prediction','prob_ec'])
    df['prediction'] = df['prediction'].astype(int)

    # probability of predicted label
    df['probability'] = np.where(df['prediction'] == LABEL_EC, df['prob_ec'], 1.0 - df['prob_ec'])

    if mode == 'threshold':
        ec = df[(df['prediction'] == LABEL_EC) & (df['prob_ec'] >= float(CONF_THRESH))].copy()
        eo = df[(df['prediction'] == LABEL_EO) & (df['prob_ec'] <= float(1.0 - CONF_THRESH))].copy()
    else:
        # Top-K per subject per class based on probability (ties keep input row order)
        ec = df[df['prediction'] == LABEL_EC].sort_values(['subject_id','probability'], ascending=[True, False]).groupby('subject_id', as_index=False).head(int(TOP_K)).copy()
        eo = df[df['prediction'] == LABEL_EO].sort_values(['subject_id','probability'], ascending=[True, False]).groupby('subject_id', as_index=False).head(int(TOP_K)).copy()

    ec['class'] = 'EC'
    eo['class'] = 'EO'

    top_epochs_df = pd.concat([ec, eo], ignore_index=True)

    keep_cols = ['subject_id','epoch_idx','class','prediction','prob_ec','probability']
    if 'file' in top_epochs_df.columns:
        keep_cols.append('file')
    top_epochs_df = top_epochs_df[keep_cols]

    if USE_SELECTION_CACHE:
        top_epochs_df.to_csv(cache_path, index=False, compression='gzip')
        print(f"Saved cached selection: {cache_path}")

# Report
if top_epochs_df is None or len(top_epochs_df) == 0:
    print("No epochs selected")
else:
    n_total = len(top_epochs_df)
    n_ec = int((top_epochs_df['class'] == 'EC').sum())
    n_eo = int((top_epochs_df['class'] == 'EO').sum())
    n_subj = top_epochs_df['subject_id'].nunique()
    print(f"Mode={mode} | Selected epochs: total={n_total} | EC={n_ec} | EO={n_eo} | subjects={n_subj}")

    # Class/prediction mapping check
    try:
        print("Class→unique prediction labels:")
        print(top_epochs_df.groupby('class')['prediction'].unique().to_dict())
    except Exception:
        pass

    print("First few rows:")
    print(top_epochs_df.head(10))


## 7. Calculate Mean Relative Alpha (PSD-based)

This step computes **relative alpha** from the EEG time series using Welch PSD (`psd_array_welch`).

- Alpha band: 8–13 Hz (`PSD_ALPHA_FMIN`–`PSD_ALPHA_FMAX`)
- Relative alpha is unitless: bandpower(8–13) / bandpower(1–40)
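- Occipital ROI: the channels listed in `OCCIPITAL_ROI` (`O1`, `O2`), reported in a separate column (`mean_rel_alpha_occipital_channels`) next to the all-channel mean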

Results are cached under `OUTPUT_DIR` with a filename that depends on `SELECTION_MODE` (and `CONF_THRESH` or `TOP_K`).


In [None]:
print("Calculating mean relative alpha (8–13 / 1–40 Hz) from EEG PSD for selected epochs...")

# Computes relative alpha from EEG using Welch PSD.
# Definition: bandpower(8–13 Hz) / bandpower(1–40 Hz) (unitless).
# The per-epoch PSD work is done by mean_relative_alpha_psd (defined in an earlier cell).
# This cell groups the selected epochs by subject (and source file), loads each recording
# once, and stores one row per (subject, class) in alpha_power_df.

mode = str(SELECTION_MODE).strip().lower()
if mode in {'top_k', 'topk'}:
    alpha_cache = os.path.join(OUTPUT_DIR, f"relative_alpha_psd_topk{int(TOP_K)}.csv.gz")
else:
    alpha_cache = os.path.join(OUTPUT_DIR, f"relative_alpha_psd_threshold_conf{int(round(CONF_THRESH*100))}.csv.gz")

if USE_ALPHA_CACHE and os.path.exists(alpha_cache):
    alpha_power_df = pd.read_csv(alpha_cache, compression='gzip')
    print(f"Loaded cached alpha results: {alpha_cache}")
else:
    if 'top_epochs_df' not in globals() or top_epochs_df is None or len(top_epochs_df) == 0:
        print('No epochs selected (top_epochs_df is empty). Skipping alpha power computation.')
        alpha_power_df = pd.DataFrame()
    else:
        required = {'subject_id', 'epoch_idx', 'class'}
        if not required.issubset(set(top_epochs_df.columns)):
            print('top_epochs_df columns:', top_epochs_df.columns.tolist())
            raise KeyError(f"top_epochs_df must contain {sorted(required)}")

        has_file = 'file' in top_epochs_df.columns
        records = []
        n_groups = 0
        n_no_file = 0
        roi_set = {c.upper() for c in OCCIPITAL_ROI}  # loop-invariant

        group_cols = ['subject_id'] + (['file'] if has_file else [])
        # dropna=False keeps rows whose 'file' is missing so the analysis_df fallback below
        # can still resolve them (by default groupby silently drops NaN keys).
        for keys, df_sub in top_epochs_df.groupby(group_cols, dropna=False):
            n_groups += 1
            # Grouping by a list yields tuple keys; pandas >= 2.0 does so even for a
            # one-element list, so normalise to a tuple before unpacking.
            if not isinstance(keys, tuple):
                keys = (keys,)
            subject_id = keys[0]
            file_path = keys[1] if has_file else None
            if file_path is not None and pd.isna(file_path):
                file_path = None

            # Fall back to the first known file for this subject in analysis_df.
            if not file_path and 'file' in analysis_df.columns:
                cand = analysis_df[analysis_df['subject_id'] == subject_id]['file'].dropna()
                file_path = cand.iloc[0] if len(cand) else None

            if not file_path:
                n_no_file += 1
                continue

            try:
                epochs = load_epochs_any(str(file_path))
            except Exception as e:
                print(f"Could not load epochs for subject {subject_id}: {e}")
                continue

            # Match ROI channels case-insensitively after canonicalising channel names.
            ch_can = [canonical_channel_name(ch).upper() for ch in epochs.ch_names]
            picks_occ = np.array([i for i, name in enumerate(ch_can) if name in roi_set], dtype=int)

            age_row = metadata_df[metadata_df['subject_id'] == subject_id]
            age = age_row.iloc[0]['Age'] if (not age_row.empty and pd.notna(age_row.iloc[0]['Age'])) else None

            for cls in ['EC', 'EO']:
                ep_idx = pd.to_numeric(df_sub[df_sub['class'] == cls]['epoch_idx'], errors='coerce').dropna().astype(int).unique()
                if ep_idx.size == 0:
                    continue

                mean_all, mean_occ, n_used = mean_relative_alpha_psd(epochs, ep_idx, picks_occ)
                if n_used == 0:
                    continue

                records.append({
                    'subject_id': subject_id,
                    'class': cls,
                    'Age': age,
                    'mean_rel_alpha_all_channels': mean_all,
                    'mean_rel_alpha_occipital_channels': mean_occ,
                    'n_epochs': n_used,
                })

            if n_groups % 50 == 0:
                print(f"Processed {n_groups} subject/file groups...")

        if n_no_file:
            print(f"Skipped {n_no_file} subject/file groups with no resolvable EEG file path.")

        alpha_power_df = pd.DataFrame.from_records(records)

        # Only cache non-empty results: an empty frame is written as a column-less CSV,
        # which pd.read_csv rejects (EmptyDataError) on the next run.
        if USE_ALPHA_CACHE and not alpha_power_df.empty:
            alpha_power_df.to_csv(alpha_cache, index=False, compression='gzip')
            print(f"Saved cached alpha results: {alpha_cache}")

if alpha_power_df is not None and len(alpha_power_df) > 0:
    output_path_all = os.path.join(OUTPUT_DIR, 'mean_relative_alpha_all_channels.csv')
    alpha_power_df[['subject_id', 'class', 'Age', 'mean_rel_alpha_all_channels', 'n_epochs']].to_csv(output_path_all, index=False)
    print(f"Saved mean alpha power (all channels) to {output_path_all}")

    output_path_occ = os.path.join(OUTPUT_DIR, 'mean_relative_alpha_occipital_channels.csv')
    alpha_power_df[['subject_id', 'class', 'Age', 'mean_rel_alpha_occipital_channels', 'n_epochs']].to_csv(output_path_occ, index=False)
    print(f"Saved mean alpha power (occipital channels) to {output_path_occ}")

    print("\nSummary:")
    print(alpha_power_df.groupby('class').agg({
        'mean_rel_alpha_all_channels': ['mean', 'std'],
        'mean_rel_alpha_occipital_channels': ['mean', 'std'],
        'n_epochs': ['mean', 'min', 'max'],
    }))
else:
    print('No alpha power results to save.')


## 8. Visualizations: Alpha Power vs Age

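We group the per-subject relative alpha values by exact `Age` value and class (EC vs EO). For each group we compute the mean and a t-distribution 95% confidence interval, and plot:

- scatter points for the per-age group means
- an interpolating spline through those means
- the shaded 95% CI, interpolated the same way (ages with a single subject have a zero-width interval)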

This is done for both **all channels** and the **occipital ROI**.


In [None]:
# Plot: Mean relative alpha per class vs age (All Channels)
# Groups alpha_power_df by exact Age and class, computes a t-based 95% CI per group,
# and draws the group means with an interpolating spline and CI band per class.
from math import ceil, floor

if len(alpha_power_df) > 0:
    df = alpha_power_df[alpha_power_df['Age'].notna()].copy()
    if len(df) == 0:
        print('No data with age information for plotting')
    else:
        df['Age'] = pd.to_numeric(df['Age'], errors='coerce')
        df = df.dropna(subset=['Age', 'mean_rel_alpha_all_channels', 'class']).copy()
        df['mean_rel_alpha_all_channels'] = pd.to_numeric(df['mean_rel_alpha_all_channels'], errors='coerce')
        df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=['Age', 'mean_rel_alpha_all_channels', 'class'])

        grouped = df.groupby(['Age', 'class']).agg(
            MeanAlpha=('mean_rel_alpha_all_channels', 'mean'),
            StdAlpha=('mean_rel_alpha_all_channels', 'std'),
            N=('mean_rel_alpha_all_channels', 'count'),
        ).reset_index()

        def _ci(row):
            """Return (lower, upper) of a t-based 95% CI; zero-width when N < 2."""
            n = float(row['N'])
            mean = float(row['MeanAlpha'])
            std = float(row['StdAlpha']) if pd.notna(row['StdAlpha']) else 0.0
            if n < 2:
                return (mean, mean)
            se = std / np.sqrt(n)
            try:
                t = float(stats.t.ppf(0.975, df=int(n-1)))
            except Exception:
                t = 1.96
            return (mean - t*se, mean + t*se)

        ci_pairs = grouped.apply(_ci, axis=1)
        ci_df = pd.DataFrame(ci_pairs.tolist(), columns=['Lower', 'Upper'], index=grouped.index)
        grouped = pd.concat([grouped, ci_df], axis=1)

        plt.figure(figsize=(12, 5))
        palette = {'EC': 'tab:red', 'EO': 'tab:blue'}

        for cls in ['EC', 'EO']:
            g = grouped[grouped['class'] == cls].sort_values('Age')
            if len(g) == 0:
                continue

            grouped_clean = g.replace([np.inf, -np.inf], np.nan).dropna(subset=['Age', 'MeanAlpha', 'Lower', 'Upper'])
            x = grouped_clean['Age'].values
            y = grouped_clean['MeanAlpha'].values
            y_lower = grouped_clean['Lower'].values
            y_upper = grouped_clean['Upper'].values

            # ensure strictly increasing x for spline
            order = np.argsort(x)
            x, y, y_lower, y_upper = x[order], y[order], y_lower[order], y_upper[order]
            x_unique, unique_idx = np.unique(x, return_index=True)
            x, y, y_lower, y_upper = x_unique, y[unique_idx], y_lower[unique_idx], y_upper[unique_idx]

            color = palette.get(cls, 'gray')
            plt.scatter(x, y, color=color, s=20, alpha=0.75, label=f"{cls} mean")

            if len(x) < 3:
                plt.plot(x, y, color=color, linewidth=2, label=f"{cls} line")
                continue

            x_smooth = np.linspace(float(x.min()), float(x.max()), 1000)
            k = int(min(5, max(1, len(x)-1)))
            # make_interp_spline's default (not-a-knot) path accepts k=2 or odd k only;
            # an even k >= 4 (e.g. exactly 5 ages) raises ValueError, so use the next odd degree.
            if k > 2 and k % 2 == 0:
                k -= 1

            try:
                spline_mean = make_interp_spline(x, y, k=k)
                spline_lower = make_interp_spline(x, y_lower, k=k)
                spline_upper = make_interp_spline(x, y_upper, k=k)

                y_smooth = spline_mean(x_smooth)
                y_lower_smooth = spline_lower(x_smooth)
                y_upper_smooth = spline_upper(x_smooth)

                plt.plot(x_smooth, y_smooth, color=color, linewidth=2, label=f"{cls} smoothed")
                plt.fill_between(x_smooth, y_lower_smooth, y_upper_smooth, color=color, alpha=0.18)
            except Exception as e:
                print(f"Spline failed for {cls} (k={k}): {e}; drawing straight line instead.")
                plt.plot(x, y, color=color, linewidth=2, label=f"{cls} line")

        plt.xlabel('Age')
        plt.ylabel('Mean Relative Alpha (8–13 / 1–40 Hz)')
        plt.title('Mean Relative Alpha vs. Age (All Channels) with 95% CI')
        plt.grid(True, alpha=0.3)

        # Decade tick marks spanning the observed age range (best effort).
        try:
            lo = int(floor(float(grouped['Age'].min())/10)*10)
            hi = int(ceil(float(grouped['Age'].max())/10)*10)
            plt.xticks(np.arange(lo, hi+1, 10))
        except Exception:
            pass

        # De-duplicate legend entries while preserving order.
        handles, labels = plt.gca().get_legend_handles_labels()
        seen = set()
        uniq_h, uniq_l = [], []
        for h, l in zip(handles, labels):
            if l in seen:
                continue
            seen.add(l)
            uniq_h.append(h)
            uniq_l.append(l)
        plt.legend(uniq_h, uniq_l, fontsize=9)

        plt.tight_layout()

        output_path = os.path.join(OUTPUT_DIR, 'mean_relative_alpha_all_channels_vs_age_spline.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Plot saved to {output_path}")
else:
    print('No alpha power data available')


In [None]:
# Plot: Mean relative alpha per class vs age (Occipital ROI)
# Groups alpha_power_df by exact Age and class, computes a t-based 95% CI per group,
# and draws the group means with an interpolating spline and CI band per class.
from math import ceil, floor

if len(alpha_power_df) > 0:
    df = alpha_power_df[alpha_power_df['Age'].notna()].copy()
    if len(df) == 0:
        print('No data with age information for plotting')
    else:
        df['Age'] = pd.to_numeric(df['Age'], errors='coerce')
        df = df.dropna(subset=['Age', 'mean_rel_alpha_occipital_channels', 'class']).copy()
        df['mean_rel_alpha_occipital_channels'] = pd.to_numeric(df['mean_rel_alpha_occipital_channels'], errors='coerce')
        df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=['Age', 'mean_rel_alpha_occipital_channels', 'class'])

        grouped = df.groupby(['Age', 'class']).agg(
            MeanAlpha=('mean_rel_alpha_occipital_channels', 'mean'),
            StdAlpha=('mean_rel_alpha_occipital_channels', 'std'),
            N=('mean_rel_alpha_occipital_channels', 'count'),
        ).reset_index()

        def _ci(row):
            """Return (lower, upper) of a t-based 95% CI; zero-width when N < 2."""
            n = float(row['N'])
            mean = float(row['MeanAlpha'])
            std = float(row['StdAlpha']) if pd.notna(row['StdAlpha']) else 0.0
            if n < 2:
                return (mean, mean)
            se = std / np.sqrt(n)
            try:
                t = float(stats.t.ppf(0.975, df=int(n-1)))
            except Exception:
                t = 1.96
            return (mean - t*se, mean + t*se)

        ci_pairs = grouped.apply(_ci, axis=1)
        ci_df = pd.DataFrame(ci_pairs.tolist(), columns=['Lower', 'Upper'], index=grouped.index)
        grouped = pd.concat([grouped, ci_df], axis=1)

        plt.figure(figsize=(12, 5))
        palette = {'EC': 'tab:red', 'EO': 'tab:blue'}

        for cls in ['EC', 'EO']:
            g = grouped[grouped['class'] == cls].sort_values('Age')
            if len(g) == 0:
                continue

            grouped_clean = g.replace([np.inf, -np.inf], np.nan).dropna(subset=['Age', 'MeanAlpha', 'Lower', 'Upper'])
            x = grouped_clean['Age'].values
            y = grouped_clean['MeanAlpha'].values
            y_lower = grouped_clean['Lower'].values
            y_upper = grouped_clean['Upper'].values

            # ensure strictly increasing x for spline
            order = np.argsort(x)
            x, y, y_lower, y_upper = x[order], y[order], y_lower[order], y_upper[order]
            x_unique, unique_idx = np.unique(x, return_index=True)
            x, y, y_lower, y_upper = x_unique, y[unique_idx], y_lower[unique_idx], y_upper[unique_idx]

            color = palette.get(cls, 'gray')
            plt.scatter(x, y, color=color, s=20, alpha=0.75, label=f"{cls} mean")

            if len(x) < 3:
                plt.plot(x, y, color=color, linewidth=2, label=f"{cls} line")
                continue

            x_smooth = np.linspace(float(x.min()), float(x.max()), 1000)
            k = int(min(5, max(1, len(x)-1)))
            # make_interp_spline's default (not-a-knot) path accepts k=2 or odd k only;
            # an even k >= 4 (e.g. exactly 5 ages) raises ValueError, so use the next odd degree.
            if k > 2 and k % 2 == 0:
                k -= 1

            try:
                spline_mean = make_interp_spline(x, y, k=k)
                spline_lower = make_interp_spline(x, y_lower, k=k)
                spline_upper = make_interp_spline(x, y_upper, k=k)

                y_smooth = spline_mean(x_smooth)
                y_lower_smooth = spline_lower(x_smooth)
                y_upper_smooth = spline_upper(x_smooth)

                plt.plot(x_smooth, y_smooth, color=color, linewidth=2, label=f"{cls} smoothed")
                plt.fill_between(x_smooth, y_lower_smooth, y_upper_smooth, color=color, alpha=0.18)
            except Exception as e:
                print(f"Spline failed for {cls} (k={k}): {e}; drawing straight line instead.")
                plt.plot(x, y, color=color, linewidth=2, label=f"{cls} line")

        plt.xlabel('Age')
        plt.ylabel('Mean Relative Alpha (8–13 / 1–40 Hz)')
        plt.title('Mean Relative Alpha vs. Age (Occipital ROI) with 95% CI')
        plt.grid(True, alpha=0.3)

        # Decade tick marks spanning the observed age range (best effort).
        try:
            lo = int(floor(float(grouped['Age'].min())/10)*10)
            hi = int(ceil(float(grouped['Age'].max())/10)*10)
            plt.xticks(np.arange(lo, hi+1, 10))
        except Exception:
            pass

        # De-duplicate legend entries while preserving order.
        handles, labels = plt.gca().get_legend_handles_labels()
        seen = set()
        uniq_h, uniq_l = [], []
        for h, l in zip(handles, labels):
            if l in seen:
                continue
            seen.add(l)
            uniq_h.append(h)
            uniq_l.append(l)
        plt.legend(uniq_h, uniq_l, fontsize=9)

        plt.tight_layout()

        output_path = os.path.join(OUTPUT_DIR, 'mean_relative_alpha_occipital_channels_vs_age_spline.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Plot saved to {output_path}")
else:
    print('No alpha power data available')


## 9. EC vs EO Alpha Power Comparison

This section checks, per subject, whether the mean relative alpha (8–13 / 1–40 Hz) is higher in the model-labeled EC epochs than in the model-labeled EO epochs. The check is done for all channels and for the occipital ROI.

Note: these are **predicted** labels (high-confidence epochs), not ground-truth EC/EO annotations.


In [None]:
# Compare EC and EO relative alpha per subject
# NOTE: This compares the model's *predicted* EC vs predicted EO epochs (high-confidence per subject),
# not ground-truth EC/EO labels.
#
# Percentages are computed only over subjects that have BOTH an EC and an EO value for the
# channel set in question. Otherwise a subject missing one class would be counted as
# "EC not higher", because comparisons against NaN evaluate to False.

if len(alpha_power_df) > 0:
    value_cols = ['mean_rel_alpha_all_channels', 'mean_rel_alpha_occipital_channels']
    # Pivot on subject_id only. Putting 'Age' in the index would silently drop subjects
    # whose Age is missing, because the underlying groupby discards NaN keys.
    wide = alpha_power_df.pivot_table(
        index='subject_id',
        columns='class',
        values=value_cols,
        aggfunc='first',
    )
    wide.columns = [f"{val}_{cls}" for val, cls in wide.columns]
    ec_eo_comparison = wide.reset_index()
    ages = alpha_power_df.groupby('subject_id')['Age'].first()  # first non-null Age per subject
    ec_eo_comparison.insert(1, 'Age', ec_eo_comparison['subject_id'].map(ages))

    required_cols = [
        'mean_rel_alpha_all_channels_EC',
        'mean_rel_alpha_all_channels_EO',
        'mean_rel_alpha_occipital_channels_EC',
        'mean_rel_alpha_occipital_channels_EO',
    ]
    missing = [c for c in required_cols if c not in ec_eo_comparison.columns]
    if missing:
        print('Missing required EC/EO columns:', missing)
    else:
        all_ec = ec_eo_comparison['mean_rel_alpha_all_channels_EC']
        all_eo = ec_eo_comparison['mean_rel_alpha_all_channels_EO']
        occ_ec = ec_eo_comparison['mean_rel_alpha_occipital_channels_EC']
        occ_eo = ec_eo_comparison['mean_rel_alpha_occipital_channels_EO']

        ec_eo_comparison['EC_higher_than_EO_all'] = all_ec > all_eo
        ec_eo_comparison['EC_higher_than_EO_occ'] = occ_ec > occ_eo

        # Subjects with both classes present, per channel set.
        both_all = all_ec.notna() & all_eo.notna()
        both_occ = occ_ec.notna() & occ_eo.notna()
        n_both_all = int(both_all.sum())
        n_both_occ = int(both_occ.sum())

        ec_higher_count_all = int((all_ec > all_eo)[both_all].sum())
        ec_higher_count_occ = int((occ_ec > occ_eo)[both_occ].sum())
        eo_higher_count_occ = int((occ_eo > occ_ec)[both_occ].sum())

        percentage_ec_higher_all = (ec_higher_count_all / n_both_all) * 100 if n_both_all > 0 else 0
        percentage_ec_higher_occ = (ec_higher_count_occ / n_both_occ) * 100 if n_both_occ > 0 else 0
        percentage_eo_higher_occ = (eo_higher_count_occ / n_both_occ) * 100 if n_both_occ > 0 else 0

        print(f"Subjects in comparison table: {len(ec_eo_comparison)}")
        print("All channels (predicted):")
        print(f"  Subjects with both predicted EC and predicted EO data: {n_both_all}")
        print(f"  Subjects with EC alpha > EO alpha: {ec_higher_count_all}/{n_both_all} ({percentage_ec_higher_all:.1f}%)")
        print("Occipital channels (predicted):")
        print(f"  Subjects with both predicted EC and predicted EO data: {n_both_occ}")
        print(f"  Subjects with EC alpha > EO alpha: {ec_higher_count_occ}/{n_both_occ} ({percentage_ec_higher_occ:.1f}%)")
        print(f"  Subjects with EO alpha > EC alpha: {eo_higher_count_occ}/{n_both_occ} ({percentage_eo_higher_occ:.1f}%, ties excluded)")

        diff = occ_eo - occ_ec
        ec_eo_comparison['EO_minus_EC_occ'] = diff
        print("Top 10 subjects where EO exceeds EC most (occipital):")
        print(ec_eo_comparison.sort_values('EO_minus_EC_occ', ascending=False).head(10)[['subject_id', 'Age', 'EO_minus_EC_occ']].to_string(index=False))

        plt.figure(figsize=(10, 6))
        categories = ['All Channels', 'Occipital Channels']
        percentages = [percentage_ec_higher_all, percentage_ec_higher_occ]
        colors = sns.color_palette("viridis", len(categories))

        bars = plt.bar(categories, percentages, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

        for bar, pct in zip(bars, percentages):
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 1,
                    f'{pct:.1f}%', ha='center', va='bottom', fontsize=12, fontweight='bold')

        plt.title('Percentage of Subjects with Higher Mean Alpha Power in EC vs EO (predicted)',
                  fontsize=14, fontweight='bold')
        plt.ylabel('Percentage (%)', fontsize=12)
        plt.ylim(0, 100)
        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()

        output_path = os.path.join(OUTPUT_DIR, 'percentage_ec_higher_alpha.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Plot saved to {output_path}")

        comparison_path = os.path.join(OUTPUT_DIR, 'ec_eo_comparison.csv')
        ec_eo_comparison.to_csv(comparison_path, index=False)
        print(f"Comparison results saved to {comparison_path}")
else:
    print("No relative alpha data available for comparison")


## 10. Summary

Analysis complete! All outputs have been saved to the run-specific output directory printed at the top (default: `New_EEG/outputs/analysis_script/<RUN_FOLDER>/`).