# Mean Alpha Power: EC vs EO (Top-K per class per subject)

This notebook reproduces the **mean absolute** and **mean relative** alpha power analysis. Instead of selecting a single class/run (e.g. the longest EC run), it selects, for every subject:

- **Top-K EC epochs per subject**: highest `prob_ec`
- **Top-K EO epochs per subject**: lowest `prob_ec`

Key rules:
- Default `K = 60` (configurable)
- A subject is **excluded** unless it has **K distinct epochs for both EC and EO**
- Results are cached so heavy computations are not repeated on reruns


In [None]:
import os
import re
import platform
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import mne
from mne.time_frequency import psd_array_welch
from scipy.interpolate import make_interp_spline
from scipy import stats

# Notebook-wide plotting defaults: seaborn 'whitegrid' style and a 12x6 default figure size.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)


In [None]:
# -------------------- Path configuration (New_EEG adaptation) --------------------
# Settings read through os.getenv can be overridden with an environment variable
# of the same name; the literal below is only the fallback.

# Pick which classifier run you want to analyze.
# This should match a run folder under: <LABELING_ROOT>/preprocessed_setfiles/<RUN_FOLDER>/label_predictions.csv
RUN_FOLDER = os.getenv(
    'RUN_FOLDER',
    'old_dataset__fooof__allch__cv2__time_align_conditions__one_main_fooof__mainfooof_all_epochs__pen_l2'
)

# Where Label_with_EC_EO_Classifier.ipynb wrote its predictions
LABELING_ROOT = os.getenv('LABELING_ROOT', r'G:\ChristianMusaeus\labeling')

# Metadata (age/sex/etc). Provide a path to metadata_time_filtered.csv.
# Set to "" to disable metadata-dependent plots.
METADATA_CSV = os.getenv('METADATA_CSV', r'G:\ChristianMusaeus\metadata_time_filtered.csv')

# -------------------- Selection + computation configuration --------------------
TOP_K_PER_CLASS = int(os.getenv('TOP_K_PER_CLASS', '60'))
# Boolean flags are compared case-insensitively so that e.g. 'FALSE' / 'TRUE'
# behave exactly like 'false' / 'true'.
# EXCLUDE_IF_MISSING_K is on unless explicitly set to 0/false.
EXCLUDE_IF_MISSING_K = os.getenv('EXCLUDE_IF_MISSING_K', '1').strip().lower() not in {'0', 'false'}
# FORCE_RECOMPUTE is off unless explicitly set to 1/true.
FORCE_RECOMPUTE = os.getenv('FORCE_RECOMPUTE', '0').strip().lower() in {'1', 'true'}

# Alpha bands / PSD settings (match Project-main defaults)
ABS_FMIN, ABS_FMAX = 8.0, 13.0   # alpha band limits in Hz
REL_FMIN, REL_FMAX = 1.0, 40.0   # reference band (Hz) used as denominator for relative power
N_FFT = int(os.getenv('N_FFT', '200'))

# ROI for occipital plots
OCCIPITAL_ROI = ['O1', 'O2']

# -------------------- Helpers --------------------

def guess_project_root() -> Path:
    """Locate the project root by walking up from the current working directory.

    The resolved working directory and up to seven of its ancestors are
    inspected, nearest first. The first directory that contains a ``.git``
    entry or a ``New_EEG`` entry is returned. If none qualifies, the resolved
    working directory itself is returned.
    """
    start = Path.cwd().resolve()
    for candidate in (start, *start.parents)[:8]:
        if any((candidate / marker).exists() for marker in ('.git', 'New_EEG')):
            return candidate
    return start

def is_wsl() -> bool:
    """Report whether the kernel release string contains 'microsoft'.

    The check is case-insensitive and is used as a heuristic for running under
    WSL. Any error while querying the platform yields ``False``.
    """
    try:
        kernel_release = platform.uname().release
        return kernel_release.lower().find('microsoft') != -1
    except Exception:
        return False

def resolve_windows_path(p: str, *, force_wsl: "bool | None" = None) -> Path:
    """Return ``p`` as a ``Path``, mapping Windows drive paths to ``/mnt/<drive>/...`` on WSL.

    Best-effort conversion: a path of the form ``X:\\rest`` or ``X:/rest`` is
    rewritten to ``/mnt/x/rest`` (drive letter lower-cased, backslashes turned
    into forward slashes). Anything else is wrapped in ``Path`` unchanged.

    Args:
        p: Path-like value; it is converted with ``str()`` first.
        force_wsl: ``True``/``False`` forces the WSL conversion on/off.
            ``None`` (default) auto-detects via ``is_wsl()``.

    Returns:
        The converted (or untouched) path.
    """
    s = str(p)
    on_wsl = is_wsl() if force_wsl is None else bool(force_wsl)
    if on_wsl:
        # The separator class must accept BOTH '\' and '/'. The previous
        # pattern `[\/]` only matched '/', so native Windows paths such as
        # r'G:\data' were never converted.
        m = re.match(r'^([A-Za-z]):[\\/](.*)$', s)
        if m:
            drive = m.group(1).lower()
            rest = m.group(2).replace('\\', '/')
            return Path(f'/mnt/{drive}/{rest}')
    return Path(s)

# -------------------- Derived paths --------------------
# All outputs of this notebook live under <project_root>/New_EEG/outputs.
project_root = guess_project_root()
new_eeg_root = project_root / 'New_EEG'
outputs_root = new_eeg_root / 'outputs'

# LABELING_ROOT may point at the 'labeling' folder itself; in that case the
# per-run folders are expected one level deeper, in 'preprocessed_setfiles'.
labeling_root = resolve_windows_path(LABELING_ROOT)
if labeling_root.name.lower() == 'labeling':
    labeling_root = labeling_root / 'preprocessed_setfiles'

# The predictions file is a hard requirement: fail early if it is missing.
run_dir = labeling_root / RUN_FOLDER
label_predictions_csv = run_dir / 'label_predictions.csv'
if not label_predictions_csv.exists():
    raise FileNotFoundError(f'Missing: {label_predictions_csv}')

# The output folder is keyed by run folder AND K, so different K values do not
# share caches or figures. The directory is created as a side effect.
analysis_root = outputs_root / 'analysis' / 'mean_alpha_power_EC_EO' / f"{RUN_FOLDER}__topk{TOP_K_PER_CLASS}"
analysis_root.mkdir(parents=True, exist_ok=True)

# Cache files: the selected epochs and the per-subject alpha metrics.
selected_cache = analysis_root / f"selected_epochs_topk{TOP_K_PER_CLASS}.csv.gz"
metrics_cache = analysis_root / f"alpha_metrics_topk{TOP_K_PER_CLASS}.csv.gz"

# Metadata is optional (empty METADATA_CSV disables it), but a configured path
# that does not exist is treated as an error.
metadata_path = None
if str(METADATA_CSV).strip():
    metadata_path = resolve_windows_path(METADATA_CSV)
    if not metadata_path.exists():
        raise FileNotFoundError(f'METADATA_CSV does not exist: {metadata_path}')

print('Run dir:', run_dir)
print('Label predictions:', label_predictions_csv)
print('Analysis outputs:', analysis_root)
print('Selected cache:', selected_cache)
print('Metrics cache:', metrics_cache)
print('Metadata:', metadata_path if metadata_path else '<not set>')

In [None]:
def _to_str(x) -> str:
    """Return ``x`` as a whitespace-stripped string; missing values become ``''``."""
    if pd.isna(x):
        return ''
    return str(x).strip()

def compute_prob_ec(df: pd.DataFrame) -> pd.Series:
    """Return the per-row probability of the EC class as a Series aligned with ``df``.

    Resolution order:
      1. An explicit ``prob_ec`` column is used as-is (coerced to numeric).
      2. Otherwise it is inferred from ``Label`` + ``Probability``, where
         ``Probability`` is the probability of the *predicted* label:
         rows with ``Label == 1`` keep ``Probability``, all other labelled
         rows get ``1 - Probability``.

    Rows whose ``Label`` is missing or non-numeric yield NaN instead of being
    silently treated as "not EC". Unparseable values also become NaN.

    Raises:
        KeyError: if neither ``prob_ec`` nor both ``Label`` and ``Probability``
            are present.
    """
    # Prefer explicit prob_ec if present
    if 'prob_ec' in df.columns:
        return pd.to_numeric(df['prob_ec'], errors='coerce')

    # Otherwise infer from (Label, Probability)
    if 'Label' in df.columns and 'Probability' in df.columns:
        label = pd.to_numeric(df['Label'], errors='coerce')
        p = pd.to_numeric(df['Probability'], errors='coerce')
        # Series.where keeps the caller's index, so the return type matches the
        # annotation (np.where would return a bare ndarray).
        prob_ec = p.where(label == 1, 1.0 - p)
        # An unknown label gives no basis for choosing p or 1 - p.
        return prob_ec.where(label.notna()).rename('prob_ec')

    raise KeyError('Need either prob_ec, or (Label + Probability) columns to compute prob_ec')

def load_epochs_any(path_str: str):
    """Load an epochs file, choosing the MNE reader from the file extension.

    ``.set`` files (case-insensitive) go through ``mne.io.read_epochs_eeglab``;
    every other extension is handed to ``mne.read_epochs``. The path is first
    passed through ``resolve_windows_path``.
    """
    resolved = resolve_windows_path(path_str)
    target = str(resolved)
    reader = mne.io.read_epochs_eeglab if resolved.suffix.lower() == '.set' else mne.read_epochs
    return reader(target, verbose='ERROR')

def canonical_channel_name(ch_name: str) -> str:
    """Normalize a channel label for name matching.

    Applied in order, case-insensitively: strip surrounding whitespace, drop a
    leading ``EEG`` followed by whitespace, drop a trailing ``-REF``, then
    remove all remaining whitespace. Letter case is left untouched.
    """
    cleaned = str(ch_name).strip()
    for pattern in (r'^EEG\s+', r'-REF$', r'\s+'):
        cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)
    return cleaned

def get_occipital_picks(epochs, roi=None) -> np.ndarray:
    """Return the positions in ``epochs.ch_names`` of channels that belong to ``roi``.

    Names are compared after ``canonical_channel_name`` normalization and
    upper-casing. ``roi`` falls back to ``OCCIPITAL_ROI`` when empty or None.
    The result is an integer array, possibly empty.
    """
    wanted = {name.upper() for name in (roi or OCCIPITAL_ROI)}
    matches = []
    for position, raw_name in enumerate(epochs.ch_names):
        if canonical_channel_name(raw_name).upper() in wanted:
            matches.append(position)
    return np.asarray(matches, dtype=int)

def select_topk_epochs_per_subject(df_sub: pd.DataFrame, k: int) -> pd.DataFrame:
    """Pick the ``k`` most EO-like and the ``k`` most EC-like epochs of one subject.

    Args:
        df_sub: Rows of a single subject with at least ``epoch_idx`` and
            ``prob_ec`` columns.
        k: Number of epochs to select per class.

    Returns:
        A frame with columns ``epoch_idx`` (int), ``prob_ec``, ``class``:
        ``k`` rows labelled ``'EO'`` (lowest ``prob_ec``) followed by ``k``
        rows labelled ``'EC'`` (highest ``prob_ec``). An empty frame is
        returned when ``k`` is not positive, when fewer than ``2*k`` usable
        rows exist, or when the two selections share an epoch index.

    Rows with a missing, non-numeric or non-integral ``epoch_idx``, or a
    missing or non-numeric ``prob_ec``, are ignored.
    """
    if k <= 0:
        return pd.DataFrame()

    d = df_sub.copy()
    d['epoch_idx'] = pd.to_numeric(d['epoch_idx'], errors='coerce')
    d['prob_ec'] = pd.to_numeric(d['prob_ec'], errors='coerce')
    d = d.dropna(subset=['epoch_idx', 'prob_ec'])
    # Fractional epoch ids cannot address an epoch; drop them rather than
    # truncating silently or failing in the integer cast.
    d = d[d['epoch_idx'] % 1 == 0].copy()
    d['epoch_idx'] = d['epoch_idx'].astype(int)

    # epoch_idx is the secondary key so that ties in prob_ec are resolved the
    # same way regardless of input row order.
    d = d.sort_values(['prob_ec', 'epoch_idx'], ascending=True, kind='mergesort')
    if len(d) < 2*k:
        return pd.DataFrame()

    eo = d.head(k).copy()
    ec = d.tail(k).copy()

    # Duplicate epoch ids in the input could put the same epoch in both classes.
    if set(eo['epoch_idx']).intersection(set(ec['epoch_idx'])):
        return pd.DataFrame()

    eo['class'] = 'EO'
    ec['class'] = 'EC'
    out = pd.concat([eo, ec], ignore_index=True)
    return out[['epoch_idx','prob_ec','class']]

def abs_alpha_uV2_perHz(selected_data: np.ndarray, sfreq: float, fmin: float, fmax: float, n_fft: int) -> float:
    """Return the mean Welch PSD inside ``[fmin, fmax]``, scaled by 1e12.

    The PSD from ``psd_array_welch`` is averaged over the frequency axis, then
    over the first axis, and finally NaN-averaged over what remains.
    NOTE(review): the 1e12 factor and the function name suggest the input is in
    volts and the output in µV²/Hz; confirm against the data loader.
    """
    welch_kwargs = dict(
        sfreq=float(sfreq),
        fmin=float(fmin),
        fmax=float(fmax),
        n_fft=int(n_fft),
        verbose=False,
    )
    spectra = psd_array_welch(selected_data, **welch_kwargs)[0]
    # presumably (epochs, channels, freqs) -> (channels,); TODO confirm input layout
    per_channel = spectra.mean(axis=-1).mean(axis=0)
    return float(np.nanmean(per_channel)) * 1e12

def relative_alpha(selected_data: np.ndarray, sfreq: float) -> float:
    """Return alpha-band power divided by reference-band power.

    Band limits come from the module-level ``ABS_FMIN``/``ABS_FMAX`` (alpha) and
    ``REL_FMIN``/``REL_FMAX`` (reference); the FFT length from ``N_FFT``. For
    each band the PSD bins are summed along the frequency axis, then
    NaN-averaged over all remaining axes; the result is the ratio of the two
    averages.

    The Welch PSD is computed once over the union of both bands and each band
    is taken with an inclusive frequency mask, instead of running Welch twice
    on identical data.

    Returns:
        The ratio, or NaN when the reference-band power is zero or NaN.

    Raises:
        ValueError: if a band contains no frequency bin at the current
            resolution.
    """
    span_lo = min(float(ABS_FMIN), float(REL_FMIN))
    span_hi = max(float(ABS_FMAX), float(REL_FMAX))
    psds, freqs = psd_array_welch(
        selected_data,
        sfreq=float(sfreq),
        fmin=span_lo,
        fmax=span_hi,
        n_fft=int(N_FFT),
        verbose=False,
    )

    # Inclusive limits, matching how psd_array_welch itself applies fmin/fmax.
    alpha_mask = (freqs >= float(ABS_FMIN)) & (freqs <= float(ABS_FMAX))
    total_mask = (freqs >= float(REL_FMIN)) & (freqs <= float(REL_FMAX))
    if not alpha_mask.any():
        raise ValueError(f'No frequencies found between fmin={ABS_FMIN} and fmax={ABS_FMAX}')
    if not total_mask.any():
        raise ValueError(f'No frequencies found between fmin={REL_FMIN} and fmax={REL_FMAX}')

    alpha_power = float(np.nanmean(psds[..., alpha_mask].sum(axis=-1)))
    total_power = float(np.nanmean(psds[..., total_mask].sum(axis=-1)))

    if total_power == 0 or np.isnan(total_power):
        return float('nan')
    return alpha_power / total_power

def compute_ci_bounds(mean: float, std: float, n: int) -> tuple[float, float]:
    """Return the (lower, upper) bounds of a two-sided 95% t-interval for a mean.

    With fewer than two samples, or a non-finite ``std``, the interval
    collapses to ``(mean, mean)``. If the t quantile cannot be computed, the
    normal approximation 1.96 is used instead.
    """
    if n < 2:
        return mean, mean
    if not np.isfinite(std):
        return mean, mean

    se = std / np.sqrt(n)
    try:
        t = float(stats.t.ppf(0.975, df=int(n-1)))
    except Exception:
        t = 1.96
    lower = mean - t*se
    upper = mean + t*se
    return lower, upper

def prepare_age_stats(df: pd.DataFrame, value_col: str, age_col: str = 'age') -> pd.DataFrame:
    """Aggregate ``value_col`` per (age, class) and attach 95% CI bounds.

    Rows with a missing, non-numeric or infinite age or value, or a missing
    ``class``, are discarded before aggregation.

    Args:
        df: Frame with ``age_col``, ``class`` and ``value_col`` columns.
        value_col: Name of the metric column to summarize.
        age_col: Name of the age column.

    Returns:
        One row per (age, class) with columns ``age_col``, ``class``, ``Mean``,
        ``Std``, ``N``, ``Lower``, ``Upper``. When no row survives filtering,
        an empty frame with these columns is returned instead of raising.
    """
    out_cols = [age_col, 'class', 'Mean', 'Std', 'N', 'Lower', 'Upper']

    d = df.copy()
    d = d.dropna(subset=[age_col, 'class', value_col])
    d[age_col] = pd.to_numeric(d[age_col], errors='coerce')
    d[value_col] = pd.to_numeric(d[value_col], errors='coerce')
    d = d.replace([np.inf, -np.inf], np.nan).dropna(subset=[age_col, value_col])

    # e.g. a metric that is NaN for every subject: nothing to aggregate. The
    # row-wise path below used to fail here because apply() on an empty frame
    # does not return a Series.
    if d.empty:
        return pd.DataFrame(columns=out_cols)

    stats_df = d.groupby([age_col, 'class']).agg(
        Mean=(value_col, 'mean'),
        Std=(value_col, 'std'),
        N=(value_col, 'count'),
    ).reset_index()

    # Std is NaN for single-sample groups; compute_ci_bounds then returns (mean, mean).
    bounds = [
        compute_ci_bounds(float(m), float(s) if pd.notna(s) else float('nan'), int(n))
        for m, s, n in zip(stats_df['Mean'], stats_df['Std'], stats_df['N'])
    ]
    stats_df['Lower'] = [lower for lower, _ in bounds]
    stats_df['Upper'] = [upper for _, upper in bounds]
    return stats_df[out_cols]

def plot_age_curve(stats_df: pd.DataFrame, title: str, ylabel: str, out_path: Path, age_col: str = 'age'):
    """Plot per-age means for EC and EO with a smoothed curve and CI band, then save it.

    Args:
        stats_df: Frame with columns ``age_col``, ``class``, ``Mean``,
            ``Lower`` and ``Upper`` (the columns read below).
        title: Figure title.
        ylabel: Y-axis label.
        out_path: Destination of the saved figure (300 dpi).
        age_col: Name of the age column in ``stats_df``.

    Side effects: creates a new figure, saves it to ``out_path``, shows it and
    prints the saved path.
    """
    plt.figure(figsize=(12, 5))
    palette = {'EC': 'tab:red', 'EO': 'tab:blue'}

    for cls in ['EC', 'EO']:
        g = stats_df[stats_df['class'] == cls].sort_values(age_col)
        if len(g) == 0:
            # Nothing to draw for this class.
            continue

        x = g[age_col].to_numpy(dtype=float)
        y = g['Mean'].to_numpy(dtype=float)
        lo = g['Lower'].to_numpy(dtype=float)
        hi = g['Upper'].to_numpy(dtype=float)

        # Ensure strictly increasing x for spline: sort, then keep the first
        # entry of every repeated age.
        order = np.argsort(x)
        x, y, lo, hi = x[order], y[order], lo[order], hi[order]
        x_unique, idx = np.unique(x, return_index=True)
        x, y, lo, hi = x_unique, y[idx], lo[idx], hi[idx]

        color = palette.get(cls, 'gray')
        plt.scatter(x, y, color=color, s=20, alpha=0.8, label=f'{cls} mean')

        if len(x) >= 3:
            x_smooth = np.linspace(float(x.min()), float(x.max()), 1000)
            # Spline degree: at most 5, and never more than (number of points - 1).
            k = int(min(5, max(1, len(x)-1)))
            try:
                spline_mean = make_interp_spline(x, y, k=k)
                spline_lo = make_interp_spline(x, lo, k=k)
                spline_hi = make_interp_spline(x, hi, k=k)
                y_s = spline_mean(x_smooth)
                lo_s = spline_lo(x_smooth)
                hi_s = spline_hi(x_smooth)
                plt.plot(x_smooth, y_s, color=color, linewidth=2, label=f'{cls} smoothed')
                # The band interpolates the CI bounds themselves with the same spline degree.
                plt.fill_between(x_smooth, lo_s, hi_s, color=color, alpha=0.18)
            except Exception:
                # Best effort: if spline construction/evaluation fails, fall
                # back to a plain polyline through the means (no CI band).
                plt.plot(x, y, color=color, linewidth=2, label=f'{cls} line')
        else:
            # Too few points for smoothing: connect the means directly.
            plt.plot(x, y, color=color, linewidth=2, label=f'{cls} line')

    plt.title(title)
    plt.xlabel('Age')
    plt.ylabel(ylabel)
    plt.grid(True, alpha=0.3)

    # De-duplicate legend entries by label, keeping the first handle of each.
    handles, labels = plt.gca().get_legend_handles_labels()
    seen = set()
    uniq_h, uniq_l = [], []
    for h, l in zip(handles, labels):
        if l in seen:
            continue
        seen.add(l)
        uniq_h.append(h)
        uniq_l.append(l)
    plt.legend(uniq_h, uniq_l, fontsize=9)

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    print('Saved:', out_path)


## Load predictions and select top-K epochs per class

In [None]:
# Load predictions
pred_df = pd.read_csv(label_predictions_csv)

# Standardize key columns: subject ids become stripped strings ('' when missing)
if 'Test subject ID' in pred_df.columns:
    pred_df['subject_id'] = pred_df['Test subject ID'].apply(_to_str)
elif 'subject_id' in pred_df.columns:
    pred_df['subject_id'] = pred_df['subject_id'].apply(_to_str)
else:
    raise KeyError('Could not find subject id column (Test subject ID / subject_id)')

# Rows without a subject id would otherwise be pooled into one pseudo-subject ''
# and could contribute a bogus top-K selection.
_n_missing_subject = int((pred_df['subject_id'] == '').sum())
if _n_missing_subject:
    print(f'Dropping {_n_missing_subject} prediction rows without a subject id')
    pred_df = pred_df[pred_df['subject_id'] != ''].copy()

if 'Epoch number' in pred_df.columns:
    pred_df['epoch_idx'] = pred_df['Epoch number']
elif 'epoch_idx' not in pred_df.columns:
    raise KeyError('Could not find epoch index column (Epoch number / epoch_idx)')

pred_df['prob_ec'] = compute_prob_ec(pred_df)

# File path mapping (optional but preferred): first non-null file per subject
if 'file' in pred_df.columns:
    subject_file_map = pred_df.groupby('subject_id')['file'].first().to_dict()
else:
    subject_file_map = {}

# Select epochs (cached)
if selected_cache.exists() and not FORCE_RECOMPUTE:
    selected_df = pd.read_csv(selected_cache, compression='gzip')
    print(f'Loaded selected epochs cache: {selected_cache} | rows={len(selected_df)}')
else:
    rows = []
    excluded = 0
    for subject_id, df_sub in pred_df.groupby('subject_id'):
        # The selector copies its input, so the column slice can be passed directly.
        sel = select_topk_epochs_per_subject(df_sub[['epoch_idx', 'prob_ec']], TOP_K_PER_CLASS)
        if len(sel) == 0:
            excluded += 1
            continue
        sel = sel.copy()
        sel['subject_id'] = subject_id
        sel['file'] = subject_file_map.get(subject_id, '')
        rows.append(sel)

    selected_df = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=['subject_id','file','epoch_idx','prob_ec','class'])

    # Enforce exclusion rule: must have K distinct epochs for both classes per subject
    if EXCLUDE_IF_MISSING_K and len(selected_df) > 0:
        counts = (
            selected_df.groupby(['subject_id', 'class'])['epoch_idx']
            .nunique()
            .unstack('class')
            # Guarantee both class columns exist so the mask below is always a Series.
            .reindex(columns=['EC', 'EO'])
            .fillna(0)
        )
        ok = counts[(counts['EC'] >= TOP_K_PER_CLASS) & (counts['EO'] >= TOP_K_PER_CLASS)].index
        before = selected_df['subject_id'].nunique()
        selected_df = selected_df[selected_df['subject_id'].isin(ok)].copy()
        after = selected_df['subject_id'].nunique()
        print(f'Excluded subjects missing K per class: {before-after} (kept {after})')

    selected_df.to_csv(selected_cache, index=False, compression='gzip')
    print(f'Saved selected epochs cache: {selected_cache} | rows={len(selected_df)} | excluded (initial)={excluded}')

print('Selected subjects:', selected_df['subject_id'].nunique())
print('Rows per class:')
print(selected_df['class'].value_counts(dropna=False))


## Compute alpha metrics (absolute + relative; all channels + occipital ROI)

This step is cached to avoid recomputing PSDs on reruns.

In [None]:
# Compute alpha metrics per subject per class (cached)
_METRIC_COLUMNS = [
    'subject_id', 'class', 'file', 'n_epochs',
    'abs_all_uV2_perHz', 'rel_all', 'abs_occ_uV2_perHz', 'rel_occ',
]

if metrics_cache.exists() and not FORCE_RECOMPUTE:
    metrics_df = pd.read_csv(metrics_cache, compression='gzip')
    print(f'Loaded metrics cache: {metrics_cache} | rows={len(metrics_df)}')
else:
    records = []
    for subject_id, df_sub in selected_df.groupby('subject_id'):
        # _to_str maps NaN (an empty 'file' read back from the CSV cache) to '',
        # whereas str(NaN) would give the truthy string 'nan'.
        file_path = _to_str(df_sub['file'].iloc[0]) if 'file' in df_sub.columns else ''
        if not file_path and 'file' in pred_df.columns:
            # fallback: try any file entry in pred_df. Ids are compared as
            # strings because a cached selected_df may hold them as numbers.
            cand = pred_df.loc[pred_df['subject_id'] == _to_str(subject_id), 'file'].dropna()
            file_path = _to_str(cand.iloc[0]) if len(cand) else ''

        if not file_path:
            continue

        try:
            epochs = load_epochs_any(file_path)
        except Exception as e:
            print(f'Could not load epochs for {subject_id}: {e}')
            continue

        picks_occ = get_occipital_picks(epochs, OCCIPITAL_ROI)
        sfreq = float(epochs.info['sfreq'])

        # Read the data once per subject instead of once per class.
        all_data = epochs.get_data()
        n_available = all_data.shape[0]

        for cls in ['EC','EO']:
            ep_idx = pd.to_numeric(df_sub[df_sub['class'] == cls]['epoch_idx'], errors='coerce').dropna().astype(int).unique()
            if ep_idx.size != TOP_K_PER_CLASS:
                continue

            # NOTE(review): epoch_idx is used as a 0-based position into the
            # loaded epochs; confirm this matches the predictions file.
            # Negative values would wrap around silently, too-large values would raise.
            if ep_idx.min() < 0 or ep_idx.max() >= n_available:
                print(f'Skipping {subject_id} / {cls}: epoch index out of range (file has {n_available} epochs)')
                continue

            data = all_data[ep_idx]

            abs_all = abs_alpha_uV2_perHz(data, sfreq, ABS_FMIN, ABS_FMAX, N_FFT)
            rel_all = relative_alpha(data, sfreq)

            abs_occ = float('nan')
            rel_occ = float('nan')
            if picks_occ is not None and len(picks_occ) > 0:
                data_occ = data[:, picks_occ, :]
                abs_occ = abs_alpha_uV2_perHz(data_occ, sfreq, ABS_FMIN, ABS_FMAX, N_FFT)
                rel_occ = relative_alpha(data_occ, sfreq)

            records.append({
                'subject_id': subject_id,
                'class': cls,
                'file': file_path,
                'n_epochs': int(ep_idx.size),
                'abs_all_uV2_perHz': abs_all,
                'rel_all': rel_all,
                'abs_occ_uV2_perHz': abs_occ,
                'rel_occ': rel_occ,
            })

    # Explicit columns keep the schema stable even when no record was produced.
    metrics_df = pd.DataFrame(records, columns=_METRIC_COLUMNS)
    metrics_df.to_csv(metrics_cache, index=False, compression='gzip')
    print(f'Saved metrics cache: {metrics_cache} | rows={len(metrics_df)}')

# Merge metadata when available
if metadata_path is not None and len(metrics_df) > 0:
    meta = pd.read_csv(metadata_path)
    # Normalize expected columns
    if 'subject_id' not in meta.columns:
        # Try common alternatives
        for c in ['Subject_ID','SubjectID','Test subject ID']:
            if c in meta.columns:
                meta['subject_id'] = meta[c]
                break
    if 'subject_id' not in meta.columns:
        raise KeyError(
            'Could not find subject id column in metadata '
            '(subject_id / Subject_ID / SubjectID / Test subject ID)'
        )
    if 'age' not in meta.columns:
        if 'Age' in meta.columns:
            meta['age'] = meta['Age']
        elif 'Y' in meta.columns:
            meta['age'] = meta['Y']
    if 'sex' not in meta.columns:
        if 'Sex' in meta.columns:
            meta['sex'] = meta['Sex']

    # Join on stripped string ids on both sides.
    meta['subject_id'] = meta['subject_id'].apply(_to_str)
    metrics_df['subject_id'] = metrics_df['subject_id'].apply(_to_str)

    # Keep one metadata row per subject: duplicated ids would multiply the
    # metric rows in the left merge and inflate every per-age count.
    meta_cols = ['subject_id'] + [c for c in ['age','sex'] if c in meta.columns]
    meta_unique = meta[meta_cols].drop_duplicates('subject_id', keep='first')

    metrics_df = metrics_df.merge(meta_unique, on='subject_id', how='left')

print(metrics_df.head())
print('Subjects in metrics:', metrics_df['subject_id'].nunique() if len(metrics_df) else 0)


## Plots

All plots below are produced **for both classes** (EC and EO) using the cached `metrics_df`.

Note on units:
- Absolute alpha power is reported in **µV²/Hz**.
- Relative alpha power is unitless (alpha-band power / total 1–40 Hz power).


In [None]:
# Distribution plots (included subjects only)
if metadata_path is None or 'age' not in metrics_df.columns:
    print('Metadata not available; skipping age/sex distribution plots.')
else:
    # 'sex' is optional in the metadata; selecting it unconditionally would
    # raise KeyError before the age histogram is drawn.
    _dist_cols = ['subject_id', 'age'] + (['sex'] if 'sex' in metrics_df.columns else [])
    # One row per subject (metrics_df holds one row per subject and class).
    included = metrics_df[_dist_cols].drop_duplicates('subject_id').copy()
    included['age'] = pd.to_numeric(included['age'], errors='coerce')

    plt.figure(figsize=(10, 4))
    plt.hist(included['age'].dropna(), bins=20, color='gray', alpha=0.8)
    plt.title('Age distribution (included subjects)')
    plt.xlabel('Age')
    plt.ylabel('Count')
    plt.tight_layout()
    plt.show()

    if 'sex' in included.columns:
        plt.figure(figsize=(6, 4))
        # After the string cast, missing values show up as their own 'nan' bar.
        included['sex'] = included['sex'].astype(str)
        counts = included['sex'].value_counts(dropna=False)
        plt.bar(counts.index.astype(str), counts.values, color='steelblue', alpha=0.85)
        plt.title('Sex distribution (included subjects)')
        plt.xlabel('Sex')
        plt.ylabel('Count')
        plt.tight_layout()
        plt.show()


## Mean alpha power vs age (EC and EO together)

In [None]:
# Absolute and relative alpha vs age (all channels).
# Both figures need the merged 'age' column; without it a single notice is printed.
if metadata_path is not None and 'age' in metrics_df.columns:
    # Absolute alpha vs age
    abs_all_stats = prepare_age_stats(metrics_df, 'abs_all_uV2_perHz', age_col='age')
    plot_age_curve(
        abs_all_stats,
        age_col='age',
        out_path=analysis_root / 'abs_alpha_all_channels_vs_age.png',
        ylabel='Mean Absolute Alpha Power (µV²/Hz)',
        title='Mean Absolute Alpha Power vs. Age (All Channels) with 95% CI',
    )

    # Relative alpha vs age
    rel_all_stats = prepare_age_stats(metrics_df, 'rel_all', age_col='age')
    plot_age_curve(
        rel_all_stats,
        age_col='age',
        out_path=analysis_root / 'rel_alpha_all_channels_vs_age.png',
        ylabel='Mean Relative Alpha Power (alpha / 1–40 Hz)',
        title='Mean Relative Alpha Power vs. Age (All Channels) with 95% CI',
    )
else:
    print('Metadata not available; skipping age plots.')


## Mean alpha power vs age by sex (EC and EO together)

In [None]:
# Plots by sex (two panels: Female/Male), each showing EC + EO

def plot_by_sex(value_col: str, title: str, ylabel: str, out_name: str):
    """Plot ``value_col`` vs age in one panel per sex, each showing EC and EO.

    Reads the module-level ``metrics_df``. Panels are 'Female' and 'Male' when
    those values occur in the ``sex`` column; otherwise the first two distinct
    values (sorted as strings) are used. The figure is saved to
    ``analysis_root / out_name``.

    Nothing is plotted (a notice is printed) when metadata is unavailable or
    when no row has usable age, sex and value.

    Args:
        value_col: Metric column of ``metrics_df`` to plot.
        title: Figure super-title.
        ylabel: Shared y-axis label.
        out_name: File name of the saved figure.
    """
    if metadata_path is None or 'age' not in metrics_df.columns or 'sex' not in metrics_df.columns:
        print('Metadata not available; skipping sex plots.')
        return

    d = metrics_df.dropna(subset=['age','sex', value_col, 'class']).copy()
    d['age'] = pd.to_numeric(d['age'], errors='coerce')
    d[value_col] = pd.to_numeric(d[value_col], errors='coerce')
    d = d.replace([np.inf, -np.inf], np.nan).dropna(subset=['age', value_col])

    sexes = [s for s in ['Female','Male'] if s in set(d['sex'].astype(str))]
    if not sexes:
        sexes = sorted(d['sex'].astype(str).unique())[:2]
    if not sexes:
        # plt.subplots(1, 0) would raise; there is simply nothing to draw.
        print('No usable rows for sex plots:', value_col)
        return

    fig, axes = plt.subplots(1, len(sexes), figsize=(16, 5), sharey=True)
    if len(sexes) == 1:
        # With a single column, subplots returns a bare Axes rather than an array.
        axes = [axes]

    palette = {'EC': 'tab:red', 'EO': 'tab:blue'}

    for ax, sex in zip(axes, sexes):
        ds = d[d['sex'].astype(str) == str(sex)]
        stats_df = prepare_age_stats(ds, value_col, age_col='age')

        for cls in ['EC','EO']:
            g = stats_df[stats_df['class'] == cls].sort_values('age')
            if len(g) == 0:
                continue

            x = g['age'].to_numpy(dtype=float)
            y = g['Mean'].to_numpy(dtype=float)
            lo = g['Lower'].to_numpy(dtype=float)
            hi = g['Upper'].to_numpy(dtype=float)

            # Strictly increasing x is required by the spline.
            order = np.argsort(x)
            x, y, lo, hi = x[order], y[order], lo[order], hi[order]
            x_u, idx = np.unique(x, return_index=True)
            x, y, lo, hi = x_u, y[idx], lo[idx], hi[idx]

            color = palette.get(cls, 'gray')
            ax.scatter(x, y, color=color, s=18, alpha=0.8)

            if len(x) >= 3:
                xs = np.linspace(float(x.min()), float(x.max()), 800)
                # Spline degree: at most 5 and never more than (points - 1).
                k = int(min(5, max(1, len(x)-1)))
                try:
                    sm = make_interp_spline(x, y, k=k)
                    slo = make_interp_spline(x, lo, k=k)
                    shi = make_interp_spline(x, hi, k=k)
                    ax.plot(xs, sm(xs), color=color, linewidth=2, label=cls)
                    ax.fill_between(xs, slo(xs), shi(xs), color=color, alpha=0.18)
                except Exception:
                    # Best effort: fall back to a polyline if the spline fails.
                    ax.plot(x, y, color=color, linewidth=2, label=cls)
            else:
                ax.plot(x, y, color=color, linewidth=2, label=cls)

        ax.set_title(str(sex))
        ax.set_xlabel('Age')
        ax.grid(True, alpha=0.3)

    axes[0].set_ylabel(ylabel)

    # Build the legend from every panel (first handle per label wins), so a
    # class missing from the last panel still gets a legend entry.
    legend_entries = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    if legend_entries:
        fig.legend(list(legend_entries.values()), list(legend_entries.keys()), loc='upper right')

    fig.suptitle(title)
    fig.tight_layout()

    out_path = analysis_root / out_name
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    print('Saved:', out_path)

plot_by_sex('abs_all_uV2_perHz', 'Mean Absolute Alpha Power vs. Age by Sex (All Channels)', 'Mean Absolute Alpha Power (µV²/Hz)', 'abs_alpha_all_by_sex.png')
plot_by_sex('rel_all', 'Mean Relative Alpha Power vs. Age by Sex (All Channels)', 'Mean Relative Alpha Power (alpha / 1–40 Hz)', 'rel_alpha_all_by_sex.png')


## Occipital ROI plots (O1/O2)

In [None]:
# Occipital ROI (O1/O2): absolute + relative alpha vs age.
# Requires the merged 'age' column; otherwise a notice is printed and nothing is drawn.
if metadata_path is not None and 'age' in metrics_df.columns:
    # Absolute alpha, occipital channels only
    abs_occ_stats = prepare_age_stats(metrics_df, 'abs_occ_uV2_perHz', age_col='age')
    plot_age_curve(
        abs_occ_stats,
        age_col='age',
        out_path=analysis_root / 'abs_alpha_occipital_vs_age.png',
        ylabel='Mean Absolute Alpha Power (µV²/Hz)',
        title='Mean Absolute Alpha Power vs. Age (Occipital ROI) with 95% CI',
    )

    # Relative alpha, occipital channels only
    rel_occ_stats = prepare_age_stats(metrics_df, 'rel_occ', age_col='age')
    plot_age_curve(
        rel_occ_stats,
        age_col='age',
        out_path=analysis_root / 'rel_alpha_occipital_vs_age.png',
        ylabel='Mean Relative Alpha Power (alpha / 1–40 Hz)',
        title='Mean Relative Alpha Power vs. Age (Occipital ROI) with 95% CI',
    )
else:
    print('Metadata not available; skipping occipital age plots.')


## EC vs EO direct comparison (paired per subject)

In [None]:
# EC vs EO direct comparison (paired per subject)

def plot_ec_vs_eo(value_col: str, title: str, xlabel: str, ylabel: str, out_name: str):
    """Scatter each subject's EC value against its EO value with an identity line.

    Reads the module-level ``metrics_df``. Only subjects with a value for both
    classes are shown. The figure is saved to ``analysis_root / out_name`` and
    a one-line summary (N, share of subjects with EC > EO, mean difference) is
    printed. If no subject has both values, a notice is printed and nothing is
    plotted.

    Args:
        value_col: Metric column of ``metrics_df`` to compare.
        title: Figure title.
        xlabel: X-axis label (EO values).
        ylabel: Y-axis label (EC values).
        out_name: File name of the saved figure.
    """
    needed = {'subject_id', 'class', value_col}
    if len(metrics_df) == 0 or not needed.issubset(metrics_df.columns):
        print('No paired EC/EO data available for', value_col)
        return

    wide = metrics_df.pivot_table(index='subject_id', columns='class', values=value_col, aggfunc='mean')
    # reindex guarantees both columns exist; a class without any rows would
    # otherwise make dropna(subset=...) raise KeyError.
    wide = wide.reindex(columns=['EC', 'EO']).dropna(subset=['EC','EO'])
    if len(wide) == 0:
        print('No paired EC/EO data available for', value_col)
        return

    x = wide['EO'].to_numpy(dtype=float)
    y = wide['EC'].to_numpy(dtype=float)

    plt.figure(figsize=(6, 6))
    plt.scatter(x, y, s=18, alpha=0.75)
    # Identity line spanning the joint data range: points above it have EC > EO.
    lo = float(np.nanmin([x.min(), y.min()]))
    hi = float(np.nanmax([x.max(), y.max()]))
    plt.plot([lo, hi], [lo, hi], color='gray', linestyle='--', linewidth=1)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    out_path = analysis_root / out_name
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    print('Saved:', out_path)

    diff = y - x
    print(f'N subjects: {len(diff)} | % EC>EO: {100*np.mean(diff>0):.1f}% | mean(EC-EO)={np.mean(diff):.4g}')

plot_ec_vs_eo(
    'abs_all_uV2_perHz',
    title='EC vs EO: Absolute Alpha Power (All Channels)',
    xlabel='EO mean absolute alpha (µV²/Hz)',
    ylabel='EC mean absolute alpha (µV²/Hz)',
    out_name='compare_abs_all_ec_vs_eo.png',
)

plot_ec_vs_eo(
    'rel_all',
    title='EC vs EO: Relative Alpha Power (All Channels)',
    xlabel='EO mean relative alpha',
    ylabel='EC mean relative alpha',
    out_name='compare_rel_all_ec_vs_eo.png',
)
