# In [1]:  (cell marker left over from the notebook export)
"""
e_3_student_scoring_b.ipynb
───────────────────────────────────────────────────────────────────────────────
LOOCV student scoring with per-fold hyperparameters reused from teacher models
(no tuning) and percent-based augmentation sizes. Student embedding is selectable;
labels are read from a configurable teacher label source.

This script:
1) Loads seed data and resources
   - Reads 96 seeds (texts + 14 targets).
   - Loads or computes seed sentence embeddings for the selected student embedding.
   - Discovers augmentation methods and their M_max from tuned teacher outputs.
   - Preloads synthetic text embeddings for the selected student embedding and
     per-fold tuned labels from the label source (e.g., e5_base + chain_ERCcv_lr).

2) Builds student regressors from teacher hyperparameters (no tuning)
   - Students: chain_ERCcv_lr, local_lasso, local_rf, global_rf, chain_ERCcv_rf.
   - Extracts per-fold HPs from teacher pickles (step_g_2a) for the same
     student embedding; raises an error if missing (no silent defaults).

3) Evaluates the baseline (seeds-only)
   - For each LOOCV fold and student, loads that fold’s teacher estimator
     (if present) and predicts the held-out seed; otherwise rebuilds the
     same pipeline from the teacher’s stored HPs and predicts.
   - Writes per-pct baseline files for naming symmetry.

4) Evaluates the augmented “full” variant (seeds + synthetics)
   - Percent mode: P ∈ {10, 20, 50, 100, 200, 400}; K = round(P% of 96).
   - For each fold, takes the first K items from that fold’s M_max labeled set
     (labels from the label source), fits the student with the per-fold HPs,
     and predicts the held-out seed.
   - Idempotent: computes only missing folds and preserves already complete CSVs.
   - Legacy reuse for chain_ERCcv_lr at {100,200,400} applies only when the
     student embedding equals the label-source embedding (e.g., e5_base).

5) Produces unified summary tables (no PRIMARY/APPENDIX split)
   - Aggregates per-fold median RRMSE into a wide summary with columns:
     baseline, full, baseline_vs_full.
   - Writes a comparison table with ΔRRMSE and relative % change.

6) Performs statistical analyses
   - Per-config Wilcoxon signed-rank tests (one-sided, alternative="less",
     zero_method="pratt") with Holm correction grouped by (regressor, method).
   - Paired Cliff’s delta per configuration.
   - Pooled Wilcoxon and Cliff’s delta across students per (method, pct).
   - Hierarchical bootstrap (folds × domains micro-bootstrap) for Δ median with 95% CI.

7) Visualizes performance
   - RRMSE vs %K and ΔRRMSE vs %K for the configured PCTS_FOR_PLOTS.

Idempotency & disk reuse:
- Reuses seed embeddings if the row count matches the seeds and, when a .meta.json
  sidecar exists, its max_seq_len matches; otherwise rebuilds.
- Reuses synthetic embeddings (cache); builds them if missing.
- Skips FULL computations for (student, method, pct) whose per-fold CSV is complete.
- Optional: skip recomputing baseline if all baseline CSVs are already complete.
- Optional: summarization-only mode to build tables/plots from existing CSVs.

Inputs:
- data/activity_scores.csv
- data/activities.csv
- outputs/b_frozen/results/{student_embedding}_vectors.npy
- outputs/e_1_synth_augmentation/g_final_n{M}_{method}.csv
- outputs/e_2_teacher_labeling/g2f_labels_fold{ii}_n{M}_{method}__{label_embedding}__{label_model}.csv
- models/teacher/teacher_fold{ii}_{student_embedding}__{student}.pkl
- outputs/e_2_teacher_labeling/cache/synth_embeds/*__{student_embedding}.npy and *__index.csv
- (legacy reuse, only if student_embedding==label_embedding==e5_base and student==chain_ERCcv_lr)
  outputs/e_3_student_scoring/results/
    rrmse_perfold_e5_base__chain_ERCcv_lr__{method}__n96_sps1__full.csv
    rrmse_perfold_e5_base__chain_ERCcv_lr__{method}__n192_sps2__full.csv
    rrmse_perfold_e5_base__chain_ERCcv_lr__{method}__n384_sps4__full.csv

Outputs:
- outputs/e_3_student_scoring/results/
    rrmse_perfold_{student_embedding}__{regressor}__{method}__pct{P}_K{K}__Mmax{Mmax}__baseline.csv
    rrmse_perfold_{student_embedding}__{regressor}__{method}__pct{P}_K{K}__Mmax{Mmax}__full.csv
- outputs/e_3_student_scoring/hp_perfold/
    {regressor}/{method}/n{Mmax}/pct{P}_K{K}/fold{ii}_best.json
- outputs/e_3_student_scoring/tables/
    summary_median_rrmse.csv
    summary_median_rrmse_with_delta.csv
    wilcoxon_holm_vs_baseline.csv
    cliffs_delta_vs_baseline.csv
    wilcoxon_pooled_vs_baseline.csv
    cliffs_delta_pooled_vs_baseline.csv
    bootstrap_delta_ci.csv
- outputs/e_3_student_scoring/plots/
    combined_rrmse_vs_pct__{method}.png
    combined_delta_vs_pct__{method}.png
    combined_rrmse_vs_pct__{method}__{embedding|all}.png
    combined_delta_vs_pct__{method}__{embedding|all}.png
- outputs/e_3_student_scoring/
    cache/, run.log, run_config.json
"""

# ────────────────────────────────────────────────────────────────────────────
#  Imports
# ────────────────────────────────────────────────────────────────────────────

from __future__ import annotations

from itertools import combinations
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import json
import logging
import os
import pickle
import re
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scikit_posthocs as sp
from scipy.stats import friedmanchisquare, wilcoxon
from sentence_transformers import SentenceTransformer
from statsmodels.stats.multitest import multipletests
import torch
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso, LinearRegression
from sklearn.model_selection import KFold
from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import Pipeline
from sklearn.utils import check_random_state

# Silence every Python warning for the whole process.
# NOTE(review): this is a blanket filter, so deprecation and numerical
# warnings from the libraries used below are hidden as well.
warnings.filterwarnings("ignore")
# Only applied when TOKENIZERS_PARALLELISM is not already set in the environment.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# ────────────────────────────────────────────────────────────────────────────
#  Paths 
# ────────────────────────────────────────────────────────────────────────────
def project_root(marker: str = "LICENSE") -> Path:
    """Return the nearest directory at or above the CWD that contains `marker`.

    The search starts at the resolved current working directory and walks up
    through its parents. When no directory holds a regular file named
    `marker`, the resolved current working directory itself is returned.
    """
    start = Path.cwd().resolve()
    search_path = [start, *start.parents]
    hit = next((folder for folder in search_path if (folder / marker).is_file()), None)
    if hit is not None:
        return hit
    return Path.cwd().resolve()

# Repository root, located via the marker-file search in project_root().
ROOT = project_root()
DATA_DIR = ROOT / "data"

# Inputs produced by earlier pipeline stages (read by the helpers below).
G1_DIR  = ROOT / "outputs" / "e_1_synth_augmentation"
G2_DIR  = ROOT / "outputs" / "e_2_teacher_labeling"
SEED_VECTORS_DIR = ROOT / "outputs" / "b_frozen" / "results"

# Output tree of this script.
G3A_DIR = ROOT / "outputs" / "e_3_student_scoring"
RES_DIR     = G3A_DIR / "results"
TABLES_DIR  = G3A_DIR / "tables"
CACHE_DIR   = G3A_DIR / "cache"
HP_PERFOLD  = G3A_DIR / "hp_perfold"
FIGS_DIR    = G3A_DIR / "plots"

# Create the output directories up front (import-time side effect).
for p in (G3A_DIR, RES_DIR, TABLES_DIR, CACHE_DIR, HP_PERFOLD, FIGS_DIR):
    p.mkdir(parents=True, exist_ok=True)

# Aliases pointing at the same directories as TABLES_DIR / FIGS_DIR.
OUT_TABLES = TABLES_DIR
PLOTS_DIR = FIGS_DIR
OUT_PLOTS  = FIGS_DIR

LOG_FILE = G3A_DIR / "run.log"
# Drop any handlers already attached to the root logger before it is
# reconfigured below; iterate over a copy because the list is mutated.
for h in list(logging.root.handlers):
    logging.root.removeHandler(h)

# ────────────────────────────────────────────────────────────────────────────
#  Logging 
# ────────────────────────────────────────────────────────────────────────────
# Root logger: append to run.log and echo to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[
        logging.FileHandler(str(LOG_FILE), mode="a", encoding="utf-8"),
        # use the uncaptured stdout to bypass notebook/pytest capturing
        logging.StreamHandler(getattr(sys, "__stdout__", sys.stdout)),
    ],
    force=True,
)
log = logging.getLogger(__name__)
# File-only logger: writes to the same run.log with the same format but does
# not propagate to the root logger, so its records never reach the console.
_say_logger = logging.getLogger("sayfile")
# Guard so that re-running this code in one process does not stack handlers.
if not _say_logger.handlers:
    _say_logger.setLevel(logging.INFO)
    _say_logger.propagate = False  # do NOT bubble to root (prevents console echo)
    _fh = logging.FileHandler(str(LOG_FILE), mode="a", encoding="utf-8")
    _fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s: %(message)s",
                                       datefmt="%Y-%m-%d %H:%M:%S"))
    _say_logger.addHandler(_fh)

# Timestamp-based identifier of this run (local time, second resolution).
RUN_ID: str = time.strftime("%Y%m%d-%H%M%S")
# ────────────────────────────────────────────────────────────────────────────
#  Config
# ────────────────────────────────────────────────────────────────────────────
# NOTE(review): REVIEW_MODE is not read anywhere in the visible part of the
# file; confirm where (or whether) it is consumed.
REVIEW_MODE: bool = True

# Number of regression targets (matches the 14 entries of TARGET_COLS below).
N_TARGETS = 14

# Optional augmentation-method filter; None is presumably "use every method
# discovered on disk" — confirm against the code that consumes METHODS.
METHODS: Optional[List[str]] = None

# Augmentation sizes, expressed as percentages.
PCT_LIST: List[int]       = [10, 20, 50, 100, 200, 400]
PCTS_FOR_PLOTS: List[int] = [10, 20, 50, 100, 200, 400]
PRIMARY_PCTS = PCT_LIST  # alias: same list object as PCT_LIST

SAVE_PRED_PERFOLD: bool = True
LOG_EVERY: int          = 16
MAX_FOLDS: int          = 0  # 0 ⇒ all 96

# Which student regressors to run
STUDENTS: List[str] = ["chain_ERCcv_lr", "local_lasso", "local_rf", "global_rf", "chain_ERCcv_rf"]

# Target column names: "domain1" … "domain14".
TARGET_COLS = [f"domain{i}" for i in range(1, 15)]
# Column-name prefix; presumably used for per-domain RRMSE columns in the
# result CSVs — confirm against the writer code.
_DOM_PREFIX = "rrmse_domain"
# Bootstrap replicate count (see the "boot_B" entry written by write_run_config).
BOOT_B: int = 5000

# ────────────────────────────────────────────────────────────────────────────
#  Device, cores, seeding
# ────────────────────────────────────────────────────────────────────────────
DETERMINISTIC = True
# Worker/thread budget: the CPU count, capped at 6 (6 is also the fallback
# when os.cpu_count() returns None).
N_JOBS: int = min(os.cpu_count() or 6, 6)
# Thread-count hints for native math libraries; existing values are kept.
# NOTE(review): numpy and torch are imported before these variables are set,
# so libraries that read them at import time may ignore them — confirm.
os.environ.setdefault("OMP_NUM_THREADS",        str(N_JOBS))
os.environ.setdefault("OPENBLAS_NUM_THREADS",   str(N_JOBS))
os.environ.setdefault("MKL_NUM_THREADS",        str(N_JOBS))
os.environ.setdefault("VECLIB_MAXIMUM_THREADS", str(N_JOBS))
os.environ.setdefault("NUMEXPR_NUM_THREADS",    str(N_JOBS))
try:
    torch.set_num_threads(N_JOBS)
except Exception:
    # Best effort: a failure to set the torch thread count is ignored.
    pass

# Device preference: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
    else "cpu"
)
DEVICE_STR = device.type

# Global seed, applied to numpy, `random` and torch.
SEED = 42
np.random.seed(SEED)
try:
    import random
    random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        # Repeats the torch.manual_seed(SEED) call made a few lines above.
        torch.manual_seed(SEED)
    if DETERMINISTIC:
        torch.use_deterministic_algorithms(True)
except Exception:
    # NOTE(review): any failure in the seeding block (including enabling
    # deterministic algorithms) is silently ignored here.
    pass

# ────────────────────────────────────────────────────────────────────────────
#  Embedding registry and selection
# ────────────────────────────────────────────────────────────────────────────
# Registry: short embedding key → model repo id handed to SentenceTransformer
# by get_encoder(). Commented-out entries are currently disabled.
EMBEDDING_SPECS: Dict[str, str] = {
    "e5_base":          "embaas/sentence-transformers-multilingual-e5-base",
    #"e5_large":         "embaas/sentence-transformers-multilingual-e5-large",
    #"simcse_xlmr_base": "sentence-transformers/paraphrase-xlm-r-multilingual-v1",
    #"sbert_bert":       "jegormeister/bert-base-dutch-cased-snli",
}

# Embedding keys selected for the student models (de-duplicated further below).
STUDENT_EMBEDDINGS: List[str] = ["e5_base"]
def _normalize_embeddings(x):
    """Coerce an embedding selection into a list of non-empty, stripped names.

    A string is split on commas; for a list or tuple every item is converted
    with ``str``. Blank entries are dropped. Any other input yields ``[]``.
    """
    if isinstance(x, str):
        pieces = x.split(",")
    elif isinstance(x, (list, tuple)):
        pieces = map(str, x)
    else:
        return []
    trimmed = (piece.strip() for piece in pieces)
    return [name for name in trimmed if name]

# Normalize & dedupe
# Order-preserving de-duplication: `_seen.add(e)` returns None, so the `or`
# expression is falsy the first time a key is seen and truthy afterwards.
# NOTE(review): despite the "normalize" intent, `_normalize_embeddings`
# (defined above) is not called here; only de-duplication is performed.
_seen = set()
STUDENT_EMBEDDINGS = [e for e in STUDENT_EMBEDDINGS if not (e in _seen or _seen.add(e))]
# The first selected embedding is the "current" one.
# NOTE(review): raises IndexError if STUDENT_EMBEDDINGS is empty.
STUDENT_EMB = STUDENT_EMBEDDINGS[0]

# Label source (teacher that produced the labels for synthetics)
LABEL_EMB: str   = "e5_base"
LABEL_MODEL: str = "local_lasso"
# Fail fast when the label-source embedding is not in the registry.
if LABEL_EMB not in EMBEDDING_SPECS:
    raise ValueError(f"Unknown LABEL_SOURCE_EMBEDDING='{LABEL_EMB}'. Allowed: {list(EMBEDDING_SPECS)}")

# Maximum sequence length assigned to every encoder created by get_encoder().
MAX_SEQ_LEN: int = 512

# Optional extra pass: train chain_ERCcv_lr student on labels from local_lasso teacher (same embedding)
# Optional extra passes. Each entry names a student regressor, the label
# source (embedding + teacher model) and a tag for the resulting labels.
# Two jobs are configured: the chain_ERCcv_lr and the local_lasso student,
# both using labels from the e5_base / local_lasso teacher.
# NOTE(review): how "labels_tag" is used is not visible in this part of the
# file — confirm against the consumer.
EXTRA_STUDENT_LABEL_JOBS = [
    {
        "student": "chain_ERCcv_lr",
        "label_emb": "e5_base",
        "label_model": "local_lasso",
        "labels_tag": "labels_local_lasso",
    },
        {
        "student": "local_lasso",
        "label_emb": "e5_base",
        "label_model": "local_lasso",
        "labels_tag": "labels_local_lasso",
    }
]

# ────────────────────────────────────────────────────────────────────────────
# Per-embedding tables/plots
# ────────────────────────────────────────────────────────────────────────────
def _auto_flag(val: str, default_auto: bool) -> bool:
    """Interpret a textual on/off switch, falling back to `default_auto`.

    Matching is case-insensitive and ignores surrounding whitespace.
    "1", "true", "yes", "y" give True; "0", "false", "no", "n" give False;
    anything else (including an empty or falsy `val`) returns `default_auto`.
    """
    token = (val or "").strip().lower()
    verdicts = {
        "1": True, "true": True, "yes": True, "y": True,
        "0": False, "false": False, "no": False, "n": False,
    }
    return verdicts.get(token, default_auto)

# True when more than one student embedding is selected.
MULTI_EMBED = len(STUDENT_EMBEDDINGS) > 1
# "auto" is not one of the explicit on/off tokens recognised by _auto_flag,
# so both flags resolve to the default: enabled only for a single embedding.
WRITE_PER_EMBED_TABLES: bool = _auto_flag("auto", default_auto=not MULTI_EMBED)
WRITE_PER_EMBED_PLOTS:  bool = _auto_flag("auto", default_auto=not MULTI_EMBED)

if not WRITE_PER_EMBED_TABLES:
    log.info("[guard] Per-embedding tables disabled (combined report will handle tables).")
if not WRITE_PER_EMBED_PLOTS:
    log.info("[guard] Per-embedding plots disabled (combined report will draw plots).")

def write_run_config():
    """Write `run_config.json` with the resolved settings of this run.

    The snapshot (run id, device, seed, selections, a few environment
    switches and the main directories) is stored under G3A_DIR. This is
    best effort: any failure is logged as a warning and not re-raised.
    """
    try:
        directories = {
            "root": str(ROOT),
            "g1_dir": str(G1_DIR),
            "g2_dir": str(G2_DIR),
            "g3a_dir": str(G3A_DIR),
            "results_dir": str(RES_DIR),
            "tables_dir": str(TABLES_DIR),
            "figs_dir": str(FIGS_DIR),
        }
        # An empty/whitespace-only METHODS variable is recorded as null.
        methods_env = os.getenv("METHODS", "").strip()
        snapshot = {
            "run_id": RUN_ID,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "device": DEVICE_STR,
            "seed": SEED,
            "n_jobs": N_JOBS,
            "student_embeddings": STUDENT_EMBEDDINGS,
            "current_student_embedding": STUDENT_EMB,
            "students": STUDENTS,
            "methods_filter_env": methods_env or None,
            "pct_list": PCT_LIST,
            "pcts_for_plots": PCTS_FOR_PLOTS,
            "label_source_embedding": LABEL_EMB,
            "label_source_model": LABEL_MODEL,
            "max_folds": MAX_FOLDS,
            "boot_B": BOOT_B,
            "write_per_embed_tables": WRITE_PER_EMBED_TABLES,
            "write_per_embed_plots": WRITE_PER_EMBED_PLOTS,
            "skip_baseline_if_present": os.getenv("SKIP_BASELINE_IF_PRESENT", "0"),
            "do_refits": os.getenv("DO_REFITS", "1"),
            "paths": directories,
        }
        target = G3A_DIR / "run_config.json"
        target.write_text(json.dumps(snapshot, indent=2), encoding="utf-8")
        log.info("run_config.json written.")
    except Exception as err:
        log.warning(f"Could not write run_config.json: {err}")

# ────────────────────────────────────────────────────────────────────────────
#  Sentence encoders
# ────────────────────────────────────────────────────────────────────────────
# Process-wide cache of loaded encoders, keyed by embedding key.
_ENCODERS: Dict[str, SentenceTransformer] = {}

def get_encoder(emb_key: str) -> SentenceTransformer:
    """Return the SentenceTransformer for `emb_key`, loading it on first use.

    The repo id is looked up in EMBEDDING_SPECS (KeyError for an unknown
    key). A newly loaded model is placed on DEVICE_STR, gets
    ``max_seq_length = MAX_SEQ_LEN`` and is cached in `_ENCODERS`.
    """
    try:
        return _ENCODERS[emb_key]
    except KeyError:
        pass  # not loaded yet — fall through and load it

    repo = EMBEDDING_SPECS[emb_key]
    log.info(f"Loading SentenceTransformer [{emb_key}]: {repo} → device={DEVICE_STR}")
    encoder = SentenceTransformer(repo, device=DEVICE_STR)
    encoder.max_seq_length = MAX_SEQ_LEN
    _ENCODERS[emb_key] = encoder
    return encoder

def encode_texts(emb_key: str, texts: List[str], batch_size: int = 64) -> np.ndarray:
    """Embed `texts` with the encoder registered under `emb_key`.

    Embeddings are requested as numpy arrays, without a progress bar and
    without normalisation; the result is returned as float32 (no copy is
    made when it already has that dtype).
    """
    encoder = get_encoder(emb_key)
    encode_options = dict(
        batch_size=batch_size,
        convert_to_numpy=True,
        show_progress_bar=False,
        normalize_embeddings=False,
    )
    vectors = encoder.encode(texts, **encode_options)
    return vectors.astype(np.float32, copy=False)

# ────────────────────────────────────────────────────────────────────────────
#  Student models
# ────────────────────────────────────────────────────────────────────────────
class RegressorChainCV(BaseEstimator, RegressorMixin):
    """Regressor chain whose link features are out-of-fold predictions.

    Targets are processed in `order`. While fitting, the features for the
    k-th target in the chain are the original inputs plus the out-of-fold
    predictions of the k previous targets. At prediction time each model's
    own prediction is appended instead.
    """

    def __init__(self, base_estimator, order=None, cv_splits=5, random_state=SEED):
        # Constructor arguments are stored untouched so that sklearn's
        # get_params()/clone() keep working.
        self.base_estimator = base_estimator
        self.order = order
        self.cv_splits = cv_splits
        self.random_state = random_state
        self.chain_models_ = []
        self.n_targets_ = None

    def fit(self, X, Y):
        """Fit one model per target along the chain and return ``self``."""
        rng = check_random_state(self.random_state)
        n_samples, self.n_targets_ = Y.shape
        if self.order is None:
            # Default: natural target order (stored on the instance).
            self.order = np.arange(self.n_targets_)
        splitter = KFold(n_splits=self.cv_splits, shuffle=True, random_state=rng)

        # Stage 1: out-of-fold predictions per target, growing the feature
        # matrix by one column after every target.
        X_aug = np.ascontiguousarray(X, dtype=np.float32)
        oof_columns = []
        for target in self.order:
            y_col = Y[:, target]
            oof_pred = np.zeros(n_samples, dtype=np.float32)
            for train_idx, valid_idx in splitter.split(X_aug):
                fold_model = clone(self.base_estimator)
                fold_model.fit(X_aug[train_idx], y_col[train_idx])
                oof_pred[valid_idx] = fold_model.predict(X_aug[valid_idx])
            column = oof_pred.reshape(-1, 1)
            oof_columns.append(column)
            X_aug = np.hstack([X_aug, column])

        # Stage 2: final models trained on all rows, using the original X
        # plus the out-of-fold columns of the preceding targets.
        self.chain_models_ = []
        for position, target in enumerate(self.order):
            if position:
                features = np.hstack([X, np.hstack(oof_columns[:position])])
            else:
                features = X
            link_model = clone(self.base_estimator)
            link_model.fit(features, Y[:, target])
            self.chain_models_.append(link_model)
        return self

    def predict(self, X):
        """Predict all targets, feeding each prediction to the next link."""
        X_aug = np.ascontiguousarray(X, dtype=np.float32)
        out = np.zeros((X.shape[0], self.n_targets_), dtype=np.float32)
        for position, target in enumerate(self.order):
            column = self.chain_models_[position].predict(X_aug).reshape(-1, 1)
            out[:, target] = column[:, 0]
            X_aug = np.hstack([X_aug, column])
        return out

class EnsembleRegressorChainsCV(BaseEstimator, RegressorMixin):
    """Average of several RegressorChainCV models with shuffled target orders."""

    def __init__(self, base_estimator, n_chains=5, cv_splits=5, random_state=SEED):
        self.base_estimator = base_estimator
        self.n_chains = n_chains
        self.cv_splits = cv_splits
        self.random_state = random_state
        self.ensemble_ = None
        self.n_targets_ = None

    def fit(self, X, Y):
        """Fit `n_chains` chains, each with its own target order and seed."""
        rng = check_random_state(self.random_state)
        self.n_targets_ = Y.shape[1]
        self.ensemble_ = []
        for _chain_no in range(self.n_chains):
            # Draw order: first the target permutation, then the chain seed.
            permutation = np.arange(self.n_targets_)
            rng.shuffle(permutation)
            chain_seed = rng.randint(0, 1_000_000)
            member = RegressorChainCV(
                base_estimator=self.base_estimator,
                order=permutation,
                cv_splits=self.cv_splits,
                random_state=chain_seed,
            )
            member.fit(X, Y)
            self.ensemble_.append((permutation, member))
        return self

    def predict(self, X):
        """Return the element-wise mean of the member chains' predictions."""
        member_preds = []
        for _order, member in self.ensemble_:
            member_preds.append(member.predict(X))
        return np.mean(member_preds, axis=0)

def _pca_n_components(value):
    """Coerce a stored PCA ``n_components`` into a form scikit-learn accepts.

    Integral values >= 1 become ``int`` (PCA rejects ``20.0``), other numbers
    stay ``float`` (variance fraction), ``None``/``"None"`` become ``None``
    and ``"mle"`` is passed through unchanged.
    """
    if value is None or value == "None":
        return None
    if value == "mle":
        return value
    number = float(value)
    if number >= 1 and number.is_integer():
        return int(number)
    return number


def _rf_from_hp(hp, prefix, default_n_estimators=None):
    """Build a RandomForestRegressor from the ``{prefix}*`` entries of `hp`.

    ``n_estimators`` is mandatory (KeyError) unless `default_n_estimators`
    is given; ``max_depth`` treats ``None``/``"None"`` as unlimited.
    """
    if default_n_estimators is None:
        n_estimators = hp[f"{prefix}n_estimators"]
    else:
        n_estimators = hp.get(f"{prefix}n_estimators", default_n_estimators)
    max_depth = hp.get(f"{prefix}max_depth", None)
    return RandomForestRegressor(
        n_estimators=int(n_estimators),
        max_depth=None if max_depth in (None, "None") else int(max_depth),
        min_samples_leaf=int(hp.get(f"{prefix}min_samples_leaf", 1)),
        random_state=SEED, n_jobs=1
    )


def build_with_hp(student_key: str, hp: Dict[str, float | int]) -> Pipeline:
    """Build the unfitted student pipeline (PCA + regressor) from `hp`.

    Args:
        student_key: One of "chain_ERCcv_lr", "local_lasso", "local_rf",
            "global_rf", "chain_ERCcv_rf".
        hp: Hyperparameters keyed by their pipeline parameter names
            (e.g. "pca__n_components", "reg__estimator__alpha").

    Returns:
        A two-step Pipeline: "pca" followed by "chain" or "reg".

    Raises:
        ValueError: If `student_key` is unknown.
        KeyError: If a mandatory hyperparameter is missing from `hp`.
    """
    if student_key in ("chain_ERCcv_lr", "chain_ERCcv_rf"):
        if student_key == "chain_ERCcv_lr":
            base = LinearRegression()
        else:
            # The chain RF is the only variant with a default tree count.
            base = _rf_from_hp(hp, "chain__base_rf__", default_n_estimators=100)
        head = ("chain", EnsembleRegressorChainsCV(
            base_estimator=base,
            n_chains=int(hp["chain__n_chains"]),
            cv_splits=int(hp["chain__cv_splits"]),
            random_state=SEED
        ))
    elif student_key == "local_lasso":
        head = ("reg", MultiOutputRegressor(Lasso(alpha=float(hp["reg__estimator__alpha"]),
                                                  random_state=SEED, max_iter=10_000)))
    elif student_key == "local_rf":
        head = ("reg", MultiOutputRegressor(_rf_from_hp(hp, "reg__estimator__")))
    elif student_key == "global_rf":
        head = ("reg", _rf_from_hp(hp, "reg__"))
    else:
        raise ValueError(f"Unknown student_key: {student_key}")

    # A plain float() here would turn an integer component count into e.g.
    # 20.0, which PCA refuses; _pca_n_components keeps counts as ints.
    pca = PCA(random_state=SEED, n_components=_pca_n_components(hp["pca__n_components"]))
    return Pipeline([("pca", pca), head])

# ────────────────────────────────────────────────────────────────────────────
#  Load per-fold teacher pickle & extract HPs
# ────────────────────────────────────────────────────────────────────────────
def teacher_pickle_path(fold_idx: int, student_key: str, emb_key: str) -> Path:
    """Return the expected path of the fitted teacher pipeline for one fold.

    Layout: <ROOT>/models/teacher/teacher_fold{NN}_{emb_key}__{student_key}.pkl
    with the fold index zero-padded to two digits. The file may not exist.
    """
    filename = f"teacher_fold{fold_idx:02d}_{emb_key}__{student_key}.pkl"
    return ROOT.joinpath("models", "teacher", filename)

def _hp_json_path(fold_idx: int, student_key: str, emb_key: str) -> Path:
    """Return the expected path of the per-fold hyperparameter JSON.

    Layout: <G2_DIR>/teacher/hp_fold{NN}_{emb_key}__{student_key}.json with
    the fold index zero-padded to two digits. The file may not exist.
    """
    filename = f"hp_fold{fold_idx:02d}_{emb_key}__{student_key}.json"
    return G2_DIR.joinpath("teacher", filename)

def hp_from_teacher(est: Pipeline, student_key: str) -> Dict[str, float | int]:
    """Read the hyperparameters of a teacher pipeline for `student_key`.

    Args:
        est: Teacher pipeline exposing ``named_steps`` ("pca" plus "chain"
            or "reg", depending on the student).
        student_key: One of "chain_ERCcv_lr", "local_lasso", "local_rf",
            "global_rf", "chain_ERCcv_rf".

    Returns:
        Dict keyed by pipeline parameter names. ``pca__n_components`` keeps
        its meaning: an ``int`` component count stays ``int``, a variance
        fraction stays ``float``, ``None`` and strings are passed through.

    Raises:
        ValueError: If `student_key` is unknown.
    """
    hp: Dict[str, float | int] = {}
    steps = est.named_steps

    if "pca" in steps:
        raw = getattr(steps["pca"], "n_components", 0.8)
        if raw is None or isinstance(raw, str):
            # None (keep all components) and string modes cannot go through float().
            hp["pca__n_components"] = raw
        elif isinstance(raw, (int, np.integer)):
            # PCA treats ints (component count) and floats (variance
            # fraction) differently, so the int must not become a float.
            hp["pca__n_components"] = int(raw)
        else:
            hp["pca__n_components"] = float(raw)

    if student_key == "chain_ERCcv_lr":
        ch = steps["chain"]
        hp["chain__n_chains"] = int(getattr(ch, "n_chains", 5))
        hp["chain__cv_splits"] = int(getattr(ch, "cv_splits", 5))

    elif student_key == "local_lasso":
        reg = steps["reg"].estimator
        hp["reg__estimator__alpha"] = float(getattr(reg, "alpha", 0.01))

    elif student_key == "local_rf":
        reg = steps["reg"].estimator
        hp["reg__estimator__n_estimators"] = int(getattr(reg, "n_estimators", 100))
        hp["reg__estimator__max_depth"] = getattr(reg, "max_depth", None)
        hp["reg__estimator__min_samples_leaf"] = int(getattr(reg, "min_samples_leaf", 1))

    elif student_key == "global_rf":
        reg = steps["reg"]
        hp["reg__n_estimators"] = int(getattr(reg, "n_estimators", 100))
        hp["reg__max_depth"] = getattr(reg, "max_depth", None)
        hp["reg__min_samples_leaf"] = int(getattr(reg, "min_samples_leaf", 1))

    elif student_key == "chain_ERCcv_rf":
        ch = steps["chain"]
        hp["chain__n_chains"] = int(getattr(ch, "n_chains", 5))
        hp["chain__cv_splits"] = int(getattr(ch, "cv_splits", 5))

        base_rf = getattr(ch, "base_estimator", None)
        if base_rf is None and hasattr(ch, "base_estimator_"):
            base_rf = ch.base_estimator_

        # NOTE(review): when the base estimator is not a random forest, no
        # chain__base_rf__* keys are emitted and the caller's defaults apply.
        if isinstance(base_rf, RandomForestRegressor):
            hp["chain__base_rf__n_estimators"] = int(getattr(base_rf, "n_estimators", 100))
            hp["chain__base_rf__max_depth"] = getattr(base_rf, "max_depth", None)
            hp["chain__base_rf__min_samples_leaf"] = int(getattr(base_rf, "min_samples_leaf", 1))
    else:
        raise ValueError(student_key)

    return hp

def load_teacher_fold_estimator(fold_idx: int, student_key: str, emb_key: str, retries: int = 3, sleep_s: float = 1.5) -> Optional[Pipeline]:
    """Unpickle the fitted teacher pipeline for a fold, or return None.

    Up to `retries` attempts are made. An OSError, or an error whose message
    contains "timed out", is retried after a pause of ``sleep_s * attempt``
    seconds; any other failure stops immediately. None is returned when the
    file does not exist or could not be loaded.

    NOTE: pickle.load executes code embedded in the file; only load teacher
    pickles from a trusted location.
    """
    path = teacher_pickle_path(fold_idx, student_key, emb_key)
    attempt = 1
    while attempt <= retries and path.exists():
        try:
            with path.open("rb") as handle:
                return pickle.load(handle)
        except Exception as err:
            transient = "timed out" in str(err).lower() or isinstance(err, OSError)
            if not transient:
                log.warning(f"Could not load teacher pickle ({path.name}): {err}")
                return None
            log.warning(f"Could not load teacher pickle ({path.name}) [attempt {attempt}/{retries}]: {err}")
            if attempt >= retries:
                return None
            time.sleep(sleep_s * attempt)
        attempt += 1
    return None

def hp_from_teacher_or_json(fold_idx: int, student_key: str, emb_key: str, est: Optional[Pipeline]) -> Dict[str, float | int]:
    """Return the per-fold hyperparameters for a student.

    Sources, in order of preference:
      1. `est`, the loaded teacher pipeline (via hp_from_teacher);
      2. the per-fold ``hp_*.json`` file, either ``{"best_params": {...}}``
         or a flat hyperparameter dict.

    Raises:
        RuntimeError: If the JSON exists but cannot be read/parsed or is not
            a JSON object, or if neither source is available (there are no
            hard-coded defaults).
    """
    # 1) Teacher pickle provided
    if est is not None:
        return hp_from_teacher(est, student_key)

    # 2) JSON fallback
    hp_json = _hp_json_path(fold_idx, student_key, emb_key)
    if hp_json.exists():
        try:
            payload = json.loads(hp_json.read_text(encoding="utf-8"))
        except Exception as e:
            raise RuntimeError(f"HP JSON unreadable: {hp_json.name} | {e}") from e
        if not isinstance(payload, dict):
            # A file that exists but has the wrong shape must not be
            # reported as "missing" by the final error below.
            raise RuntimeError(
                f"HP JSON has unexpected structure: {hp_json.name} | "
                f"expected an object, got {type(payload).__name__}"
            )
        # Some older JSONs store the whole payload; accept either form.
        best = payload.get("best_params")
        return best if isinstance(best, dict) else payload

    # 3) Hard stop
    raise RuntimeError(
        f"Missing per-fold HPs for fold={fold_idx}, student={student_key}, emb={emb_key}. "
        f"Expected either teacher pickle ({teacher_pickle_path(fold_idx, student_key, emb_key).name}) "
        f"or HP JSON ({_hp_json_path(fold_idx, student_key, emb_key).name})."
    )

# ────────────────────────────────────────────────────────────────────────────
#  Data loading (seeds + synthetics)
# ────────────────────────────────────────────────────────────────────────────
def load_seeds() -> pd.DataFrame:
    """Load the seed activities as one row per seed.

    Scores are pivoted from long format (activity_id, domain_id, score) to
    one ``domain{N}`` column per integer domain id, joined with the activity
    question text and sorted by seed id. Exactly 96 rows are expected.

    Returns:
        DataFrame with columns ``seed_id``, ``text`` and TARGET_COLS.
    """
    score_rows = pd.read_csv(DATA_DIR / "activity_scores.csv")
    activities = pd.read_csv(DATA_DIR / "activities.csv")

    def _domain_label(col):
        # Only integer column labels (the domain ids) get the prefix.
        if isinstance(col, (int, np.integer)):
            return f"domain{col}"
        return col

    seeds = (
        score_rows
        .pivot(index="activity_id", columns="domain_id", values="score")
        .reset_index()
        .rename(columns=_domain_label)
        .merge(activities[["activity_id", "question"]], on="activity_id", how="left")
        .rename(columns={"activity_id": "seed_id", "question": "text"})
        .sort_values("seed_id")
        .reset_index(drop=True)
    )
    assert len(seeds) == 96, f"Expected 96 seeds, got {len(seeds)}"
    return seeds[["seed_id", "text", *TARGET_COLS]]

def ensure_seed_vectors(seeds_df: pd.DataFrame, emb_key: str) -> np.ndarray:
    """Return float32 seed embeddings for `emb_key`, reusing stored vectors.

    Stored vectors are reused when their row count equals ``len(seeds_df)``
    and, if a ``.meta.json`` sidecar exists, its ``max_seq_len`` equals
    MAX_SEQ_LEN. Otherwise (or when the stored file cannot be used) the
    seeds are re-encoded and the vectors plus a sidecar are written. In both
    cases a copy is saved to CACHE_DIR.
    """
    vec_path  = SEED_VECTORS_DIR / f"{emb_key}_vectors.npy"
    meta_path = vec_path.with_suffix(".meta.json")
    cache_copy = CACHE_DIR / f"X_seed_{emb_key}.npy"

    # Reuse path; the meta sidecar is optional.
    if vec_path.exists():
        try:
            stored = np.load(vec_path)
            rows_match = stored.shape[0] == len(seeds_df)
            seq_len_match = True  # no sidecar ⇒ assume compatible
            if meta_path.exists():
                sidecar = json.loads(meta_path.read_text(encoding="utf-8"))
                seq_len_match = int(sidecar.get("max_seq_len", MAX_SEQ_LEN)) == int(MAX_SEQ_LEN)
            if rows_match and seq_len_match:
                np.save(cache_copy, stored)
                log.info("✔ Reusing seed vectors → %s", vec_path.relative_to(ROOT))
                return stored.astype(np.float32, copy=False)
            log.warning("[%s] Rebuilding seed vectors (rows_ok=%s, max_seq_len_ok=%s).",
                        emb_key, rows_match, seq_len_match)
        except Exception as err:
            log.warning("[%s] Existing vectors unusable (%s); rebuilding.", emb_key, err)

    # Rebuild path: encode the seed texts and persist them.
    SEED_VECTORS_DIR.mkdir(parents=True, exist_ok=True)
    log.info("Computing %s embeddings for seeds …", emb_key)
    fresh = encode_texts(emb_key, seeds_df["text"].astype(str).tolist(), batch_size=64)
    fresh32 = fresh.astype(np.float32, copy=False)
    np.save(vec_path, fresh32)
    np.save(cache_copy, fresh32)

    # The sidecar is optional: failing to write it is deliberately ignored.
    try:
        sidecar = {
            "embedding": emb_key,
            "repo": EMBEDDING_SPECS[emb_key],
            "max_seq_len": int(MAX_SEQ_LEN),
            "n_rows": int(fresh.shape[0]),
            "n_dim": int(fresh.shape[1]),
        }
        meta_path.write_text(json.dumps(sidecar, indent=2), encoding="utf-8")
    except Exception:
        pass

    log.info("✔ Seed vectors saved → %s", vec_path.relative_to(ROOT))
    return fresh32


def base_from_method_and_M(method: str, M: int) -> str:
    """Return the file-name stem ``n{M}_{method}`` shared by label and synth files."""
    return "n{}_{}".format(M, method)

def g2_label_file(
    fold_idx: int,
    method: str,
    M: int,
    label_emb: str,
    label_model: str,
    strict_model: bool = False
) -> Path:
    """Resolve the teacher-label CSV for one fold.

    The preferred file is the one for (`label_emb`, `label_model`). If it
    is absent and `strict_model` is False, the alphabetically first file
    for the same fold, stem and embedding but any model is used, with a
    warning. When nothing matches, the preferred (missing) path is returned
    so the caller can fail on it.
    """
    stem = base_from_method_and_M(method, M)
    shared_prefix = f"g2f_labels_fold{fold_idx:02d}_{stem}__{label_emb}__"
    preferred = G2_DIR / f"{shared_prefix}{label_model}.csv"
    if preferred.exists() or strict_model:
        return preferred  # in strict mode the caller handles a missing file

    # Fallback keeps the label-source embedding fixed, relaxing only the model.
    alternatives = sorted(G2_DIR.glob(f"{shared_prefix}*.csv"))
    if not alternatives:
        return preferred  # will raise later if missing

    fallback = alternatives[0]
    log.warning(
        "Falling back to %s for fold=%02d, method=%s, M=%d (label_model=%s missing).",
        fallback.name, fold_idx, method, M, label_model
    )
    return fallback

def discover_methods_and_M_from_g2(label_emb: str) -> Dict[str, List[int]]:
    """Discover augmentation methods and their sizes from fold-00 label files.

    Scans G2_DIR for ``g2f_labels_fold00_n{M}_{method}__{label_emb}__*.csv``.

    Args:
        label_emb: Label-source embedding key; matched literally.

    Returns:
        Mapping ``method -> sorted list of M`` (empty if nothing matches).
    """
    # Escape the key so that regex metacharacters in it are matched literally;
    # unescaped, a key such as "e5+base" would silently match no file.
    name_re = re.compile(
        rf"g2f_labels_fold00_n(\d+)_([A-Za-z0-9_]+)__{re.escape(label_emb)}__.*\.csv$"
    )
    sizes_by_method: Dict[str, set] = {}
    for path in sorted(G2_DIR.glob(f"g2f_labels_fold00_n*_*__{label_emb}__*.csv")):
        hit = name_re.match(path.name)
        if hit is None:
            continue
        sizes_by_method.setdefault(hit.group(2), set()).add(int(hit.group(1)))
    return {method: sorted(sizes) for method, sizes in sizes_by_method.items()}

def synth_cache_paths_from_g2c(g1_path: Path, emb_key: str) -> Tuple[Path, Path]:
    """Locate the cached synthetic embeddings and their index for a source CSV.

    Two naming schemes are tried in order, with and without the ``g_final_``
    prefix; the first pair where both the ``.npy`` and the ``__index.csv``
    exist is returned.

    Raises:
        FileNotFoundError: If neither pair exists; the message lists the
            names tried and similarly named files found in the cache.
    """
    stem = g1_path.stem
    parsed = re.match(r"g_final_(n\d+_[A-Za-z0-9_]+)$", stem)
    if parsed:
        stem = parsed.group(1)

    embed_dir = G2_DIR / "cache" / "synth_embeds"
    pairs = [
        (embed_dir / f"{prefix}{stem}__{emb_key}.npy", embed_dir / f"{prefix}{stem}__index.csv")
        for prefix in ("g_final_", "")
    ]
    for vectors_file, index_file in pairs:
        if vectors_file.exists() and index_file.exists():
            return vectors_file, index_file

    nearby = sorted(entry.name for entry in embed_dir.glob(f"*{stem}*"))
    attempted = [v.name for v, _ in pairs]
    attempted.extend(i.name for _, i in pairs)
    tried = " ; ".join(attempted)
    raise FileNotFoundError(
        f"Missing synth cache for base='{stem}' and embedding='{emb_key}'. Tried: {tried}. "
        f"Found near-matches in cache: {nearby}"
    )

def g1_source_csv(method: str, M: int) -> Path:
    """Return the synthetic-text source CSV for (`method`, `M`).

    Raises:
        FileNotFoundError: If ``g_final_n{M}_{method}.csv`` is not in G1_DIR.
    """
    source = G1_DIR / f"g_final_n{M}_{method}.csv"
    if source.exists():
        return source
    raise FileNotFoundError(f"Missing Script-A source: {source}")

def ensure_synth_cache(method: str, M: int, emb_key: str) -> Tuple[Path, Path]:
    """
    Return the (npy, index.csv) cache pair for Script-A synthetics, building it if needed.

    When both cache files already exist they are returned untouched. Otherwise
    the source CSV's ``text`` column is encoded with `encode_texts`, stored as
    float32 in the ``.npy`` file, and the texts are written to the index CSV.

    Raises:
        FileNotFoundError: propagated from `g1_source_csv` when the source is missing.
        ValueError: if the source CSV has no ``text`` column.
    """
    src_csv = g1_source_csv(method, M)
    stem = src_csv.stem
    hit = re.match(r"g_final_(n\d+_[A-Za-z0-9_]+)$", stem)
    base = stem if hit is None else hit.group(1)

    cache_dir = G2_DIR / "cache" / "synth_embeds"
    npy = cache_dir / f"g_final_{base}__{emb_key}.npy"
    idx = cache_dir / f"g_final_{base}__index.csv"

    cache_ready = npy.exists() and idx.exists()
    if not cache_ready:
        frame = pd.read_csv(src_csv)
        if "text" not in frame.columns:
            raise ValueError(f"{src_csv.name}: missing 'text' column")
        texts = frame["text"].astype(str).tolist()

        log.info(f"[cache] building synth embeds for method={method}, M={M}, emb={emb_key} …")
        embeds = encode_texts(emb_key, texts, batch_size=64).astype(np.float32, copy=False)

        cache_dir.mkdir(parents=True, exist_ok=True)
        np.save(npy, embeds)
        pd.DataFrame({"text": texts}).to_csv(idx, index=False)
        log.info(f"[cache] ✔ saved → {npy.relative_to(ROOT)} ; {idx.relative_to(ROOT)}")

    return npy, idx


# ────────────────────────────────────────────────────────────────────────────
#  Metrics 
# ────────────────────────────────────────────────────────────────────────────
def rmse(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Root-mean-square error between `a` and `b`, reduced along axis 0."""
    squared_err = (a - b) ** 2
    return np.sqrt(squared_err.mean(axis=0))

def rrmse_vs_dummy(y_true: np.ndarray, y_pred: np.ndarray, y_dummy: np.ndarray) -> np.ndarray:
    """
    Relative RMSE: RMSE of `y_pred` divided by RMSE of `y_dummy`, both against `y_true`.

    The denominator is floored at 1e-12 so a perfect dummy does not divide by zero.
    """
    model_err = rmse(y_true, y_pred)
    dummy_err = rmse(y_true, y_dummy)
    safe_dummy_err = np.maximum(dummy_err, 1e-12)
    return model_err / safe_dummy_err

def _fmt_dt(sec: float) -> str:
    """Format a duration in seconds as ``XhYYmSS.Ss``, ``YmSS.Ss`` or ``S.SSs``."""
    minutes, secs = divmod(sec, 60.0)
    hours, minutes = divmod(minutes, 60.0)
    if hours >= 1:
        return f"{int(hours)}h{int(minutes):02d}m{secs:04.1f}s"
    if minutes >= 1:
        return f"{int(minutes)}m{secs:04.1f}s"
    return f"{secs:0.2f}s"

def paired_cliffs_delta(full: np.ndarray, base: np.ndarray):
    """
    Sign-based paired effect size of `full` versus `base` (lower values are better).

    Returns:
        ``(delta, n_pos, n_neg, n_zero, magnitude)`` where ``n_pos`` counts pairs
        with ``base > full``, ``n_neg`` pairs with ``base < full``, ``n_zero``
        ties, ``delta = (n_pos - n_neg) / max(1, n_pos + n_neg)`` and
        ``magnitude`` labels ``|delta|`` using the cut-offs 0.147 / 0.33 / 0.474.
    """
    gain = base - full  # > 0 where the full variant has the lower error
    wins = int(np.count_nonzero(gain > 0))
    losses = int(np.count_nonzero(gain < 0))
    ties = int(np.count_nonzero(gain == 0))

    delta = (wins - losses) / max(1, wins + losses)

    size = abs(delta)
    if size < 0.147:
        magnitude = "negligible"
    elif size < 0.33:
        magnitude = "small"
    elif size < 0.474:
        magnitude = "medium"
    else:
        magnitude = "large"
    return float(delta), wins, losses, ties, magnitude

def pct_to_K(pct: int, n_seeds: int = 96) -> int:
    """Convert a percentage of the seed count into an item count, never below 1."""
    k = int(round((pct / 100.0) * n_seeds))
    return k if k >= 1 else 1

def _dom_cols(df: pd.DataFrame) -> List[str]:
    """Column names of `df` that start with ``_DOM_PREFIX``, in column order."""
    return list(filter(lambda name: name.startswith(_DOM_PREFIX), df.columns))

def _global_median_from_file(p: Path) -> float:
    """
    Median over every value in the ``_DOM_PREFIX*`` columns of the CSV at `p`.

    Raises:
        RuntimeError: if the file has no such columns.
    """
    frame = pd.read_csv(p)
    dom = _dom_cols(frame)
    if len(dom) == 0:
        raise RuntimeError(f"{p.name}: no '{_DOM_PREFIX}*' columns present.")
    flat = frame[dom].to_numpy(dtype=np.float32).ravel()
    return float(np.median(flat))

def _flat_arrays(p_base: Path, p_full: Path) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load paired per-domain RRMSE values from a baseline CSV and a full CSV.

    The ``_DOM_PREFIX*`` columns of the baseline file define which columns are
    read from both files. When both files carry a ``fold`` column with unique
    values, rows are restricted to the folds present in both and ordered by
    fold, so that element *i* of each returned array refers to the same
    (fold, domain) cell. Otherwise rows are paired by position, as stored.

    Returns:
        ``(xb, xf)``: flattened float32 arrays for baseline and full.

    Raises:
        RuntimeError: if the baseline has no domain columns, or the full file
            lacks some of the baseline's domain columns.
    """
    db = pd.read_csv(p_base)
    df = pd.read_csv(p_full)
    cols = _dom_cols(db)
    if not cols:
        # Blame the file that is actually at fault (the baseline).
        raise RuntimeError(f"{p_base.name}: no '{_DOM_PREFIX}*' columns present.")
    missing = sorted(set(cols) - set(df.columns))
    if missing:
        raise RuntimeError(f"{p_full.name}: missing '{_DOM_PREFIX}*' columns: {missing}")

    # Pair rows by fold rather than by file position, so a partially written
    # or differently ordered file cannot silently mis-pair the samples.
    can_align = (
        "fold" in db.columns and "fold" in df.columns
        and db["fold"].is_unique and df["fold"].is_unique
    )
    if can_align:
        common = sorted(set(db["fold"]) & set(df["fold"]))
        db = db.set_index("fold").loc[common]
        df = df.set_index("fold").loc[common]

    xb = db[cols].to_numpy(dtype=np.float32).ravel()
    xf = df[cols].to_numpy(dtype=np.float32).ravel()
    return xb, xf

def section(title):
    """Print `title` framed above and below by a '═' rule of the same length."""
    rule = "═" * len(title)
    print("", rule, title, rule, sep="\n")

def _save_and_show(fig, path: str):
    """Write `fig` to `path` (300 dpi, tight bounding box), display it, and report the path."""
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.show()
    print("Plot saved →", path)

def aligned_ranks(mat):
    """
    Row-wise ranks (1 = smallest, lower is better) of median-centred rows.

    Each row has its own median subtracted, then receives ordinal ranks
    1..k from a double argsort.
    """
    def _ordinal(row):
        order = np.argsort(row)
        return np.argsort(order) + 1

    centred = mat - np.median(mat, axis=1, keepdims=True)
    return np.apply_along_axis(_ordinal, 1, centred)

def friedman_aligned(mat):
    """
    Friedman χ² over the columns of `mat` plus the Iman–Davenport F statistic.

    Returns:
        ``(chi2, F)`` with ``F = (n - 1) * chi2 / (n * (k - 1) - chi2)`` for
        ``n`` rows and ``k`` columns.
    """
    n_blocks = mat.shape[0]
    k = mat.shape[1]
    columns = [mat[:, j] for j in range(k)]
    chi2 = friedmanchisquare(*columns)[0]
    numerator = (n_blocks - 1) * chi2
    denominator = n_blocks * (k - 1) - chi2
    return chi2, numerator / denominator

def wilcoxon_matrix(mat, labels):
    """
    Symmetric table of pairwise two-sided Wilcoxon p-values between columns of `mat`.

    Uses ``zero_method="zsplit"``. A pair whose differences are all (close to)
    zero gets p = 1.0 without calling the test. The diagonal is 1.0 and the
    table is rounded to 4 decimals.
    """
    n = len(labels)
    table = pd.DataFrame(np.ones((n, n)), index=labels, columns=labels)
    for i, j in combinations(range(n), 2):
        gap = mat[:, i] - mat[:, j]
        if np.allclose(gap, 0):
            p = 1.0
        else:
            p = wilcoxon(gap, zero_method="zsplit")[1]
        table.iat[i, j] = p
        table.iat[j, i] = p
    return table.round(4)

def holm_correct_and_effects(raw_p, data, labels):
    """
    Holm–Bonferroni-adjust a pairwise p-value table and compute Cliff's Δ per pair.

    Args:
        raw_p: Square DataFrame of raw pairwise p-values (only the upper
            triangle is read).
        data: 2-D array whose columns correspond to `labels`.
        labels: Column labels, in the same order as the columns of `data`.

    Returns:
        ``(adj_df, delta_df)``: the Holm-adjusted p-values (symmetric, diagonal
        1.0, rounded to 4 decimals) and the antisymmetric Cliff's Δ table
        (diagonal 0.0, rounded to 3 decimals).
    """
    n = len(labels)
    idx = list(combinations(range(n), 2))

    adj_df = raw_p.copy()
    if idx:
        # Skip the correction when there is no pair at all (fewer than two
        # labels), rather than handing multipletests an empty list.
        pvals = [raw_p.iat[i, j] for i, j in idx]
        _, p_adj, _, _ = multipletests(pvals, method="holm")
        for (i, j), p in zip(idx, p_adj):
            adj_df.iat[i, j] = adj_df.iat[j, i] = p
    adj_df[np.eye(n, dtype=bool)] = 1.0

    def cliffs_delta(x, y):
        # Share of (x_i, y_j) pairs with x_i > y_j minus share with x_i < y_j.
        diffs = np.subtract.outer(x, y)
        return (np.sum(diffs > 0) - np.sum(diffs < 0)) / (len(x) * len(y))

    # A column compared with itself has Δ = 0, so start from zeros; the
    # off-diagonal cells are filled antisymmetrically below.
    delta_df = pd.DataFrame(np.zeros((n, n)), index=labels, columns=labels)
    for (i, j) in idx:
        d_ij = cliffs_delta(data[:, i], data[:, j])
        delta_df.iat[i, j] = d_ij
        delta_df.iat[j, i] = -d_ij

    return adj_df.round(4), delta_df.round(3)

def conover_posthoc(ranks, labels, fname_tag):
    """
    Run the Conover post-hoc test for Friedman data with Holm adjustment.

    The resulting p-value table is relabelled with `labels`, written to
    ``TABLES_DIR/{fname_tag}_conover_p.csv``, printed (4 decimals) and returned.
    """
    p_table = sp.posthoc_conover_friedman(ranks, p_adjust="holm")
    p_table.index = labels
    p_table.columns = labels

    target = TABLES_DIR / f"{fname_tag}_conover_p.csv"
    p_table.to_csv(target)

    print("\nConover–Iman post-hoc p-values (Holm-adjusted):")
    print(p_table.round(4).to_string())
    print("  ↳ saved →", target)
    return p_table

def run_friedman(mat, block_name, col_labels, fname_tag):
    """
    Print and save a non-parametric comparison of the columns of `mat`.

    Always writes the per-column medians. Then:
      * exactly two rows: skip Friedman and emit only the pairwise report;
      * exactly two columns: print a single paired Wilcoxon p-value;
      * otherwise: print Friedman statistics, run a post-hoc test (Conover
        when there are fewer than 10 rows, Nemenyi otherwise) and emit the
        pairwise report.

    The pairwise report is: raw Wilcoxon p-values, Holm-adjusted p-values and
    Cliff's Δ, each printed and saved under ``TABLES_DIR`` with `fname_tag`
    as the file-name prefix.

    Args:
        mat: 2-D array, rows = blocks, columns = the items being compared.
        block_name: Human-readable name of the row unit, used in messages.
        col_labels: One label per column of `mat`.
        fname_tag: Prefix for every CSV written.
    """
    k = len(col_labels)
    nblocks = mat.shape[0]

    def _print_and_save(table, heading, csv_name, digits):
        # Heading, rounded table on screen, table on disk, confirmation line.
        print(heading)
        print(table.round(digits).to_string())
        target = TABLES_DIR / csv_name
        table.to_csv(target)
        print("  ↳ saved →", target)

    def _pairwise_report():
        # Raw Wilcoxon → Holm adjustment → Cliff's Δ, in that order.
        raw = wilcoxon_matrix(mat, col_labels)
        _print_and_save(raw, "\nWilcoxon pairwise p-values:",
                        f"{fname_tag}_wilcoxon_raw_p.csv", 4)
        adj, delta = holm_correct_and_effects(raw, mat, col_labels)
        _print_and_save(adj, "\nHolm–Bonferroni adjusted p-values:",
                        f"{fname_tag}_wilcoxon_holm_p.csv", 4)
        _print_and_save(delta, "\nCliff's Δ effect sizes:",
                        f"{fname_tag}_cliffs_delta.csv", 3)

    # Medians: the CSV keeps the original column order, the printout is sorted.
    medians = pd.Series(np.median(mat, axis=0), index=col_labels)
    med_path = TABLES_DIR / f"{fname_tag}_median.csv"
    medians.to_csv(med_path, header=["median_rrmse"])
    unit = block_name[:-1] if block_name.endswith('s') else block_name
    print(f"\nMedian RRMSE per {unit} (sorted low→high):")
    print(medians.sort_values().round(3).to_string())
    print("  ↳ saved →", med_path)

    if nblocks == 2:
        print(f"\nOnly two {block_name} → skipping Friedman/post-hoc.")
        _pairwise_report()
        return

    if k == 2:
        p = wilcoxon(mat[:, 0], mat[:, 1], zero_method="zsplit")[1]
        print(f"\nPaired Wilcoxon ({col_labels[0]} vs {col_labels[1]}): p = {p:.5g}")
        return

    # Friedman statistics on the row-wise ranks and on the raw matrix.
    ranks = aligned_ranks(mat)
    chi2_a, Ff_a = friedman_aligned(ranks)
    raw_columns = [mat[:, i] for i in range(k)]
    chi2_o, p_o = friedmanchisquare(*raw_columns)
    Ff_o = ((nblocks - 1) * chi2_o) / (nblocks * (k - 1) - chi2_o)

    print(f"\n*Aligned-Friedman* (blocks = {block_name})")
    print(f"  χ²_F = {chi2_a:.3f}    F_F = {Ff_a:.3f}")
    print(f"\n*Original-Friedman* (blocks = {block_name})")
    print(f"  χ²_F = {chi2_o:.3f}    p = {p_o:.3g}    F_F = {Ff_o:.3f}")

    # Post-hoc test: Conover below 10 blocks, Nemenyi otherwise.
    if nblocks < 10:
        conover_posthoc(ranks, col_labels, fname_tag)
    else:
        nemenyi = sp.posthoc_nemenyi_friedman(ranks)
        nemenyi.index = col_labels
        nemenyi.columns = col_labels
        nem_path = TABLES_DIR / f"{fname_tag}_nemenyi_p.csv"
        nemenyi.to_csv(nem_path)
        print("\nNemenyi p-values (aligned post-hoc):")
        print(nemenyi.round(4).to_string())
        print("  ↳ saved →", nem_path)

    _pairwise_report()

def vector_per_target(rrmse_array):
    """
    Reduce a 2-D array to its column-wise median; return anything else unchanged.

    Objects without an ``ndim`` attribute are treated as already 1-D.
    """
    if getattr(rrmse_array, "ndim", 1) != 2:
        return rrmse_array
    return np.median(rrmse_array, axis=0)

def matrix_per_target_compare_models(data_dict, model_list, embedding_list):
    """
    Build a (targets × models) matrix, aggregating each model across embeddings.

    For every model, the per-target vectors of all embeddings present in
    `data_dict` are concatenated, reshaped to rows of ``N_TARGETS`` values and
    reduced by the median over rows; the resulting vectors become the columns.
    """
    columns = []
    for model in model_list:
        per_embedding = [
            vector_per_target(data_dict[emb][model])
            for emb in embedding_list
            if emb in data_dict
        ]
        stacked = np.concatenate(per_embedding, axis=0).reshape(-1, N_TARGETS)
        columns.append(np.median(stacked, axis=0))
    return np.column_stack(columns)

# --- PATCH: show full statistics right under each CD diagram -----------------

def cd_plot(matrix, labels, title, fname):
    """
    Draw a critical-difference diagram for the columns of `matrix` and print the statistics.

    Ranks come from `aligned_ranks`; pairwise p-values come from the Nemenyi
    post-hoc test and are forced to a ``len(labels)``-square, symmetric table
    with NaNs and the diagonal set to 1.0. The figure is saved under
    ``PLOTS_DIR/fname`` and `run_friedman` is then run on the same matrix,
    using the stem of `fname` as the table-file prefix.

    Args:
        matrix: 2-D array, columns = methods being compared.
        labels: One label per column.
        title: Plot title.
        fname: File name of the figure inside ``PLOTS_DIR``.
    """
    if matrix.shape[1] < 2:
        print(f"⚠  Skipping CD-plot '{title}' (need ≥2 methods, got {matrix.shape[1]})")
        return

    n = len(labels)
    ranks = aligned_ranks(matrix)

    # Post-hoc p-values; accept either a DataFrame or a bare array.
    pvals_raw = sp.posthoc_nemenyi_friedman(ranks)
    if not isinstance(pvals_raw, pd.DataFrame):
        pvals = pd.DataFrame(pvals_raw, index=range(n), columns=range(n))
    else:
        pvals = pvals_raw.copy()

    # Defensive shape fix: trim, and as a last resort fall back to all-ones
    # p-values (no significance bars drawn).
    if pvals.shape != (n, n):
        pvals = pvals.iloc[:n, :n]
        if pvals.shape != (n, n):
            pvals = pd.DataFrame(np.ones((n, n)), index=range(n), columns=range(n))

    pvals.index = labels
    pvals.columns = labels
    pvals = pvals.astype(float).fillna(1.0)

    # Symmetrise and reset the diagonal on a plain ndarray, then wrap it.
    # Mutating through DataFrame.values is unreliable: it may be a copy, and
    # it is read-only when pandas Copy-on-Write is active.
    sym = np.minimum(pvals.values, pvals.values.T)
    np.fill_diagonal(sym, 1.0)
    pvals = pd.DataFrame(sym, index=labels, columns=labels)

    fig, ax = plt.subplots(figsize=(8, 3), dpi=120)
    sp.critical_difference_diagram(pd.Series(ranks.mean(0), index=labels), pvals, ax=ax)
    ax.set_title(title, pad=10)
    _save_and_show(fig, PLOTS_DIR / fname)

    # Full report printed under the plot.
    tag = Path(fname).stem
    run_friedman(matrix, block_name="folds", col_labels=labels, fname_tag=tag)


def cd_plot_dual(matrix1, labels1, matrix2, labels2, title1, title2, fname):
    """
    Draw two critical-difference diagrams side by side and print both reports.

    Each panel is built like a single CD plot: ranks from `aligned_ranks`,
    Nemenyi p-values coerced to a symmetric ``len(labels)``-square table with
    NaNs and the diagonal set to 1.0. The figure is saved under
    ``PLOTS_DIR/fname``; `run_friedman` is then run per panel with the
    table-file prefixes ``<stem>__left`` and ``<stem>__right``.

    Nothing is drawn when either matrix has fewer than two columns.
    """
    if matrix1.shape[1] < 2 or matrix2.shape[1] < 2:
        print("⚠  Skipping dual CD-plot (need ≥2 methods for both)")
        return

    def _aligned_pvals(M, lbls):
        n = len(lbls)
        r = aligned_ranks(M)
        raw = sp.posthoc_nemenyi_friedman(r)
        if not isinstance(raw, pd.DataFrame):
            P = pd.DataFrame(raw, index=range(n), columns=range(n))
        else:
            P = raw.copy()
        # Trim to the expected shape; fall back to all-ones if still wrong.
        if P.shape != (n, n):
            P = P.iloc[:n, :n]
            if P.shape != (n, n):
                P = pd.DataFrame(np.ones((n, n)), index=range(n), columns=range(n))
        P.index = lbls
        P.columns = lbls
        P = P.astype(float).fillna(1.0)
        # Symmetrise and reset the diagonal on a plain ndarray, then wrap it.
        # Mutating through DataFrame.values is unreliable: it may be a copy,
        # and it is read-only when pandas Copy-on-Write is active.
        sym = np.minimum(P.values, P.values.T)
        np.fill_diagonal(sym, 1.0)
        return r, pd.DataFrame(sym, index=lbls, columns=lbls)

    ranks1, pvals1 = _aligned_pvals(matrix1, labels1)
    ranks2, pvals2 = _aligned_pvals(matrix2, labels2)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 3), dpi=120)
    sp.critical_difference_diagram(pd.Series(ranks1.mean(0), index=labels1), pvals1, ax=ax1)
    ax1.set_title(title1, pad=10)
    sp.critical_difference_diagram(pd.Series(ranks2.mean(0), index=labels2), pvals2, ax=ax2)
    ax2.set_title(title2, pad=10)
    plt.tight_layout()
    _save_and_show(fig, PLOTS_DIR / fname)

    base_tag = Path(fname).stem
    section(f"Full statistics — LEFT panel: {title1}")
    run_friedman(matrix1, block_name="folds", col_labels=labels1, fname_tag=f"{base_tag}__left")

    section(f"Full statistics — RIGHT panel: {title2}")
    run_friedman(matrix2, block_name="folds", col_labels=labels2, fname_tag=f"{base_tag}__right")


# ────────────────────────────────────────────────────────────────────────────
#  Result paths + legacy reuse for chain_ERCcv_lr (reuse applies only when student and label-source embeddings are both e5_base)
# ───────────────────────────────────────────────────────────────────────────

def result_path(student_emb: str, regressor: str, method: str,
                pct: int, K: int, Mmax: int, variant: str,
                labels_tag: Optional[str] = None) -> Path:
    """
    Return the per-fold results CSV path under RES_DIR for one configuration.

    The file name encodes embedding, regressor, method, pct, K, Mmax and the
    variant. A non-empty `labels_tag` (e.g. 'labels_local_lasso') is inserted
    as an extra ``__{labels_tag}`` segment just before ``__{variant}`` so that
    results from alternative label sources do not collide.
    """
    tag = "" if not labels_tag else f"__{labels_tag}"
    return RES_DIR / (
        f"rrmse_perfold_{student_emb}__{regressor}__{method}__pct{pct}_K{K}__Mmax{Mmax}{tag}__{variant}.csv"
    )

def ensure_chainlr_from_legacy_pct(student_emb: str, method: str, pct: int, K: int,
                                   N_SEEDS: int, Mmax: int, labels_tag: Optional[str] = None) -> bool:
    """
    Materialise the chain_ERCcv_lr 'full' results file from a legacy artifact, if allowed.

    Reuse applies only when no `labels_tag` is given, both the student
    embedding and LABEL_EMB are ``"e5_base"``, `pct` is one of 100/200/400 and
    the corresponding legacy file exists in RES_DIR. In that case the first
    `N_SEEDS` legacy rows are rewritten in the current column layout at the
    path given by `result_path`, unless that file already has at least
    `N_SEEDS` data rows.

    Returns:
        True when the new-style file is present (already complete or just
        written); False when reuse does not apply and the caller must compute.
    """
    # Only reuse legacy artifacts for the DEFAULT label source (no tag).
    if labels_tag:
        return False
    if not (student_emb == "e5_base" and LABEL_EMB == "e5_base"):
        return False

    legacy_tags = {100: "n96_sps1", 200: "n192_sps2", 400: "n384_sps4"}
    tag = legacy_tags.get(pct)
    if tag is None:
        return False

    legacy_full = RES_DIR / f"rrmse_perfold_e5_base__chain_ERCcv_lr__{method}__{tag}__full.csv"
    if not legacy_full.exists():
        log.info(f"[reuse/chain_lr] Missing legacy {legacy_full.name} → recompute instead.")
        return False

    new_full = result_path(student_emb, "chain_ERCcv_lr", method, pct, K, Mmax, "full", labels_tag=None)
    if new_full.exists():
        # Cheap completeness check: data rows = line count minus the header.
        try:
            with open(new_full, "r", encoding="utf-8") as fh:
                n_rows = sum(1 for _ in fh) - 1
        except (OSError, UnicodeDecodeError) as exc:
            # Best effort: an unreadable file is simply rewritten below.
            log.warning("[reuse/chain_lr] Could not inspect %s (%s); rewriting.", new_full.name, exc)
        else:
            if n_rows >= N_SEEDS:
                return True

    df_old = pd.read_csv(legacy_full).iloc[:N_SEEDS].copy()
    keep_cols = [c for c in df_old.columns if c.startswith("rrmse_domain")]
    if "median_rrmse_fold" in df_old.columns:
        keep_cols = ["median_rrmse_fold"] + keep_cols

    # Size the fallback fold index from the rows actually read, so a legacy
    # file shorter than N_SEEDS does not cause a length mismatch below.
    if "fold" in df_old.columns:
        fold_col = df_old["fold"].astype(int).values
    else:
        fold_col = np.arange(len(df_old), dtype=int)

    df_new = pd.DataFrame({
        "fold": fold_col,
        "method": method,
        "pct": pct,
        "K": K,
        "M_max": Mmax,
        "student": "S1",
        "embedding": student_emb,
        "regressor": "chain_ERCcv_lr",
        "variant": "full",
    })
    for c in keep_cols:
        df_new[c] = df_old[c].values

    new_full.parent.mkdir(parents=True, exist_ok=True)
    df_new.to_csv(new_full, index=False)
    log.info(f"[reuse/chain_lr] Wrote {new_full.name} from legacy {legacy_full.name}")
    return True

def run_for_embedding(emb_key: str):
    """
    Run the compute pipeline once with `emb_key` as the student embedding.

    The module-level STUDENT_EMB is switched to `emb_key` for the duration of
    the run and restored afterwards, even if the run raises.

    Raises:
        ValueError: if `emb_key` is not a key of EMBEDDING_SPECS.
    """
    global STUDENT_EMB
    if emb_key not in EMBEDDING_SPECS:
        raise ValueError(f"Unknown embedding '{emb_key}'. Allowed: {list(EMBEDDING_SPECS)}")

    saved_emb = STUDENT_EMB
    STUDENT_EMB = emb_key
    try:
        log.info(f"=== BEGIN run() for STUDENT_EMBEDDING={emb_key} ===")
        _run_core(labels_tag=None)
        log.info(f"=== END   run() for STUDENT_EMBEDDING={emb_key} ===")
    finally:
        STUDENT_EMB = saved_emb

# ────────────────────────────────────────────────────────────────────────────
#  Combined reports 
# ────────────────────────────────────────────────────────────────────────────

def _matrix_for_method_pct(emb: str, method: str, pct: int, Mmax: int, regressors: List[str]) -> Tuple[np.ndarray, List[str]]:
    """
    Build a (n_folds, n_regressors) matrix from the FULL-variant result files.

    Each cell is the fold's median over domain-wise RRMSE: the stored
    ``median_rrmse_fold`` column when present, otherwise the row median of the
    ``rrmse_domain*`` columns. Regressors whose file is missing or has no
    usable columns are dropped, and only folds present for ALL remaining
    regressors are kept (inner join on 'fold').

    Returns:
        ``(mat, regs)`` where `regs` is the alphabetically sorted list of
        regressors that contributed a column. Returns ``(np.empty((0, 0)), [])``
        when fewer than two regressors or fewer than two common folds remain.
    """
    # The writer clips K to the size of the synthetic pool before naming the
    # files, so apply the same clipping here or the clipped files are never
    # found. NOTE(review): assumes the pool holds Mmax rows — confirm.
    K = min(pct_to_K(pct, 96), Mmax)
    fold_sets = []
    per_reg = {}

    for reg in regressors:
        # Use the shared path builder so reader and writer cannot drift apart.
        p_full = result_path(emb, reg, method, pct, K, Mmax, "full")
        if not p_full.exists():
            continue
        df = pd.read_csv(p_full)
        if "median_rrmse_fold" in df.columns:
            z = df[["fold", "median_rrmse_fold"]].rename(columns={"median_rrmse_fold": "med"})
        else:
            cols = [c for c in df.columns if c.startswith("rrmse_domain")]
            if not cols:
                continue
            z = df[["fold", *cols]].copy()
            z["med"] = z[cols].median(axis=1)
            z = z[["fold", "med"]]
        per_reg[reg] = z
        fold_sets.append(set(z["fold"].astype(int).tolist()))

    regs = sorted(per_reg)
    if len(regs) < 2:
        return np.empty((0, 0)), []

    common_folds = sorted(set.intersection(*fold_sets)) if fold_sets else []
    if len(common_folds) < 2:
        return np.empty((0, 0)), []

    mats = []
    for reg in regs:
        z = per_reg[reg]
        z = z[z["fold"].isin(common_folds)].sort_values("fold")
        mats.append(z["med"].to_numpy(dtype=np.float32))

    mat = np.vstack(mats).T  # rows = folds, cols = regs
    return mat, regs

def say(msg: str):
    """Print `msg` to the original stdout (flushed) and log it once via `_say_logger`."""
    console = getattr(sys, "__stdout__", sys.stdout)
    print(msg, file=console, flush=True)
    _say_logger.info(msg)


def _bootstrap_delta_ci_from_files(p_base: Path, p_full: Path, B: int = BOOT_B, seed: int = SEED):
    """
    Bootstrap the Δ median (baseline − full) with a 95% percentile interval.

    Both CSVs are restricted to the folds they share. Each of the `B`
    replicates resamples folds and domain columns with replacement (the same
    indices for both files) and records the difference of the two overall
    medians.

    Returns:
        ``(median of replicates, 2.5th percentile, 97.5th percentile)``, or
        three NaNs when the baseline has no ``_DOM_PREFIX*`` columns or the
        files share no fold.
    """
    base_df = pd.read_csv(p_base)
    full_df = pd.read_csv(p_full)

    dom = [c for c in base_df.columns if c.startswith(_DOM_PREFIX)]
    if not dom:
        return np.nan, np.nan, np.nan
    shared = sorted(set(base_df['fold']) & set(full_df['fold']))
    if not shared:
        return np.nan, np.nan, np.nan

    xb = base_df.set_index('fold')[dom].loc[shared].to_numpy(dtype=np.float32)
    xf = full_df.set_index('fold')[dom].loc[shared].to_numpy(dtype=np.float32)
    n_f, n_d = xb.shape

    rng = np.random.default_rng(seed)
    deltas = np.empty(B, dtype=np.float32)
    for b in range(B):
        # Draw folds first, then domains, so the random stream is consumed
        # in a fixed order.
        fold_pick = rng.integers(0, n_f, size=n_f, endpoint=False)
        dom_pick = rng.integers(0, n_d, size=n_d, endpoint=False)
        cell = np.ix_(fold_pick, dom_pick)
        deltas[b] = np.median(xb[cell]) - np.median(xf[cell])

    lo, hi = np.percentile(deltas, [2.5, 97.5])
    return float(np.median(deltas)), float(lo), float(hi)


def build_combined_reports_across_embeddings(include_alt_labels: bool = True):
    """
    Aggregate per-fold result CSVs into summary tables, paired tests and plots.

    For every embedding in STUDENT_EMBEDDINGS, each ``*__full.csv`` file in
    RES_DIR whose name matches the results naming scheme is paired with its
    baseline counterpart (same configuration, variant 'baseline'). Pairs whose
    baseline file is missing, whose pct is not in PCT_LIST, or whose method is
    not selected are skipped. For each pair the function collects global
    medians, a one-sided Wilcoxon p-value, the paired sign-based effect size
    and a bootstrap Δ-median interval, then writes CSVs to OUT_TABLES and
    line plots to PLOTS_DIR.

    Args:
        include_alt_labels: When False, files carrying a ``labels_*`` tag in
            their name are ignored, leaving only the default label source.

    Returns:
        None. If no result pair is found, a message is emitted and nothing is
        written.
    """
    avail = discover_methods_and_M_from_g2(LABEL_EMB)
    methods = sorted(avail.keys()) if (METHODS is None) else [m for m in sorted(avail.keys()) if m in METHODS]

    # File-name grammar of the per-fold result CSVs; the optional
    # '__labels_*' segment identifies an alternative label source.
    pat = re.compile(
        r"^rrmse_perfold_(?P<emb>.+?)__(?P<reg>.+?)__(?P<meth>.+?)__pct(?P<pct>\d+)_K(?P<K>\d+)__Mmax(?P<M>\d+)"
        r"(?:__(?P<labels>labels_[A-Za-z0-9_]+))?__(?P<var>baseline|full)\.csv$"
    )

    summary_rows, wil_rows, cliffs_rows = [], [], []
    pooled_pairs = {}             # key=(meth,pct,labels_source) -> {"base":[...], "full":[...]}
    pooled_pairs_by_embed = {}    # key=(emb,meth,pct,labels_source) -> {"base":[...], "full":[...]}
    bootstrap_rows = []

    for emb in STUDENT_EMBEDDINGS:
        files = sorted([p for p in RES_DIR.glob(f"rrmse_perfold_{emb}__*__full.csv") if pat.match(p.name)])
        for f_full in files:
            m = pat.match(f_full.name)
            if not m: 
                continue
            d = m.groupdict()
            reg, meth = d["reg"], d["meth"]
            pct, K, Mmax = int(d["pct"]), int(d["K"]), int(d["M"])
            labels_tag = (d.get("labels") or "").strip()
            labels_source = labels_tag if labels_tag else "default"
            if (not include_alt_labels) and labels_tag:
                continue
            if pct not in PCT_LIST or (methods and meth not in methods): 
                continue

            # Baseline counterpart of this full file (same config and label tag).
            p_full = f_full
            p_base = result_path(emb, reg, meth, pct, K, Mmax, "baseline",
                                 labels_tag=None if labels_source == "default" else labels_source)
            if not p_base.exists(): 
                continue

            # global medians
            g_base = _global_median_from_file(p_base)
            g_full = _global_median_from_file(p_full)
            summary_rows.append({
                "student":"S1","embedding":emb,"regressor":reg,"method":meth,
                "pct":pct,"K":K,"M_max":Mmax,"baseline":g_base,"full":g_full,
                "labels_source":labels_source,
            })

            # paired arrays (fold×domain flattened)
            xb, xf = _flat_arrays(p_base, p_full)
            pooled_pairs.setdefault((meth, pct, labels_source), {"base":[], "full":[]})
            pooled_pairs[(meth, pct, labels_source)]["base"].extend(xb.tolist())
            pooled_pairs[(meth, pct, labels_source)]["full"].extend(xf.tolist())

            pooled_pairs_by_embed.setdefault((emb, meth, pct, labels_source), {"base":[], "full":[]})
            pooled_pairs_by_embed[(emb, meth, pct, labels_source)]["base"].extend(xb.tolist())
            pooled_pairs_by_embed[(emb, meth, pct, labels_source)]["full"].extend(xf.tolist())

            # wilcoxon per-config (one-sided less)
            # Any failure of the test is recorded as p = 1.0 rather than raised.
            try:
                _, p_raw = wilcoxon(xf, xb, zero_method="pratt", alternative="less")
            except Exception:
                p_raw = 1.0
            dlt, npos, nneg, nzero, mag = paired_cliffs_delta(xf, xb)
            wil_rows.append({"student":"S1","embedding":emb,"regressor":reg,"method":meth,
                             "pct":pct,"K":K,"M_max":Mmax,"labels_source":labels_source,"p_raw":float(p_raw)})
            cliffs_rows.append({"student":"S1","embedding":emb,"regressor":reg,"method":meth,
                                "pct":pct,"K":K,"M_max":Mmax,"labels_source":labels_source,
                                "cliffs_delta_paired":float(dlt),"npos":int(npos),"nneg":int(nneg),
                                "nzero":int(nzero),"magnitude":mag})

            # bootstrap Δ median CI
            d_med, ci_lo, ci_hi = _bootstrap_delta_ci_from_files(p_base, p_full, B=BOOT_B, seed=SEED)
            bootstrap_rows.append({
                "embedding":emb,"regressor":reg,"method":meth,"pct":pct,"K":K,"M_max":Mmax,
                "labels_source":labels_source,"delta_median":d_med,"ci_lo":ci_lo,"ci_hi":ci_hi
            })

    if not summary_rows:
        say("[review] No result pairs found to summarise.")
        return

    # ---------- summaries ----------
    wide = pd.DataFrame(summary_rows).sort_values(
        ["labels_source","embedding","regressor","method","pct"]
    ).reset_index(drop=True)

    # Classify each row by comparing medians with a small tolerance:
    # "better" means the full variant has the lower median.
    tol = 1e-9
    def _cmp(row):
        b, f = row["baseline"], row["full"]
        if pd.isna(b) or pd.isna(f): return "n/a"
        if (b - f) > tol:  return "better"
        if (f - b) > tol:  return "worse"
        return "same"
    wide["baseline_vs_full"] = wide.apply(_cmp, axis=1)

    OUT_TABLES.mkdir(parents=True, exist_ok=True)
    wide.to_csv(OUT_TABLES / "summary_median_rrmse__ALL.csv", index=False)

    # Same table plus absolute and relative improvement; a zero baseline
    # yields NaN for the relative change instead of dividing by zero.
    comp = wide.copy()
    comp["delta"] = comp["baseline"] - comp["full"]
    comp["rel_change_pct"] = 100.0 * comp["delta"] / comp["baseline"].replace(0, np.nan)
    comp.to_csv(OUT_TABLES / "summary_median_rrmse_with_delta__ALL.csv", index=False)

    # legacy default filenames prefer 'default' source
    default_sub = wide[wide["labels_source"] == "default"]
    if not default_sub.empty:
        default_sub.to_csv(OUT_TABLES / "summary_median_rrmse.csv", index=False)
        comp[comp["labels_source"] == "default"].to_csv(
            OUT_TABLES / "summary_median_rrmse_with_delta.csv", index=False
        )
    else:
        wide.to_csv(OUT_TABLES / "summary_median_rrmse.csv", index=False)
        comp.to_csv(OUT_TABLES / "summary_median_rrmse_with_delta.csv", index=False)

    # per label-source CSVs (default / labels_local_lasso / …)
    for lbl in sorted(wide["labels_source"].unique()):
        sub = wide[wide["labels_source"] == lbl].copy()
        sub.to_csv(OUT_TABLES / f"summary_median_rrmse__{lbl}.csv", index=False)
        sc = comp[comp["labels_source"] == lbl].copy()
        sc.to_csv(OUT_TABLES / f"summary_median_rrmse_with_delta__{lbl}.csv", index=False)

    # per-embedding CSVs (compat: __e5_base copies)
    for emb in sorted(wide["embedding"].unique()):
        wide[wide["embedding"] == emb].to_csv(OUT_TABLES / f"summary_median_rrmse__{emb}.csv", index=False)
        comp[comp["embedding"] == emb].to_csv(OUT_TABLES / f"summary_median_rrmse_with_delta__{emb}.csv", index=False)

    # ---------- pooled tests (ALL + per-embed) ----------
    # Pooled over every embedding and regressor sharing (method, pct, label source).
    pooled_rows = []
    for (meth, pct, lbl), pair in pooled_pairs.items():
        xb = np.array(pair["base"], dtype=np.float32)
        xf = np.array(pair["full"], dtype=np.float32)
        if xb.size == 0 or xf.size == 0: 
            continue
        try:
            _, p_raw = wilcoxon(xf, xb, zero_method="pratt", alternative="less")
        except Exception:
            p_raw = 1.0
        dlt, npos, nneg, nzero, mag = paired_cliffs_delta(xf, xb)
        pooled_rows.append({"method":meth,"pct":pct,"labels_source":lbl,
                            "p_raw":float(p_raw),"cliffs_delta":float(dlt),
                            "npos":int(npos),"nneg":int(nneg),"nzero":int(nzero),"magnitude":mag})
    pooled_df = pd.DataFrame(pooled_rows).sort_values(["labels_source","method","pct"])
    if not pooled_df.empty:
        pooled_df.to_csv(OUT_TABLES / "wilcoxon_cliffs_pooled__ALL.csv", index=False)
        # compat split files
        pooled_df.rename(columns={"p_raw":"p_value"}).to_csv(OUT_TABLES / "wilcoxon_pooled_vs_baseline.csv", index=False)
        pooled_df[["method","pct","labels_source","cliffs_delta","magnitude","npos","nneg","nzero"]].to_csv(
            OUT_TABLES / "cliffs_delta_pooled_vs_baseline.csv", index=False
        )

    # per-embed pooled
    pooled_embed_rows = []
    for (emb, meth, pct, lbl), pair in pooled_pairs_by_embed.items():
        xb = np.array(pair["base"], dtype=np.float32)
        xf = np.array(pair["full"], dtype=np.float32)
        if xb.size == 0 or xf.size == 0:
            continue
        try:
            _, p_raw = wilcoxon(xf, xb, zero_method="pratt", alternative="less")
        except Exception:
            p_raw = 1.0
        dlt, npos, nneg, nzero, mag = paired_cliffs_delta(xf, xb)
        pooled_embed_rows.append({"embedding":emb,"method":meth,"pct":pct,"labels_source":lbl,
                                  "p_raw":float(p_raw),"cliffs_delta":float(dlt),
                                  "npos":int(npos),"nneg":int(nneg),"nzero":int(nzero),"magnitude":mag})
    pooled_e_df = pd.DataFrame(pooled_embed_rows).sort_values(["embedding","labels_source","method","pct"])
    if not pooled_e_df.empty:
        # compat: write __e5_base copies if present
        for emb in pooled_e_df["embedding"].unique():
            sub = pooled_e_df[pooled_e_df["embedding"]==emb].drop(columns=["embedding"])
            sub.rename(columns={"p_raw":"p_value"}).to_csv(OUT_TABLES / f"wilcoxon_pooled_vs_baseline__{emb}.csv", index=False)
            sub[["method","pct","labels_source","cliffs_delta","magnitude","npos","nneg","nzero"]].to_csv(
                OUT_TABLES / f"cliffs_delta_pooled_vs_baseline__{emb}.csv", index=False
            )

    # ---------- per-config wilcoxon (raw + Holm) & cliffs (compat names) ----------
    if wil_rows:
        wil_df = pd.DataFrame(wil_rows).sort_values(
            ["labels_source","embedding","regressor","method","pct"]
        )
        wil_df.to_csv(OUT_TABLES / "wilcoxon_vs_baseline__ALL.csv", index=False)

        # Holm within each (labels_source, embedding, regressor, method)
        # i.e. the correction family is the set of pct levels of one config.
        holm_chunks = []
        for key, grp in wil_df.groupby(["labels_source","embedding","regressor","method"], dropna=False):
            _, p_adj, _, _ = multipletests(grp["p_raw"].values, method="holm")
            g2 = grp.copy()
            g2["p_holm"] = p_adj
            holm_chunks.append(g2)
        holm_df = pd.concat(holm_chunks, axis=0).sort_values(
            ["labels_source","embedding","regressor","method","pct"]
        )
        holm_df.to_csv(OUT_TABLES / "wilcoxon_holm_vs_baseline.csv", index=False)
        for emb in holm_df["embedding"].unique():
            holm_df[holm_df["embedding"]==emb].to_csv(OUT_TABLES / f"wilcoxon_holm_vs_baseline__{emb}.csv", index=False)

    if cliffs_rows:
        cliffs_df = pd.DataFrame(cliffs_rows).sort_values(
            ["labels_source","embedding","regressor","method","pct"]
        )
        cliffs_df.to_csv(OUT_TABLES / "cliffs_delta_paired__ALL.csv", index=False)
        # compat names
        cliffs_df.to_csv(OUT_TABLES / "cliffs_delta_vs_baseline.csv", index=False)
        for emb in cliffs_df["embedding"].unique():
            cliffs_df[cliffs_df["embedding"]==emb].to_csv(OUT_TABLES / f"cliffs_delta_vs_baseline__{emb}.csv", index=False)

    # ---------- bootstrap Δ CI ----------
    if bootstrap_rows:
        pd.DataFrame(bootstrap_rows).sort_values(
            ["labels_source","embedding","regressor","method","pct"]
        ).to_csv(OUT_TABLES / "bootstrap_delta_ci.csv", index=False)

    # ---------- plots ----------
    # Figures are saved and closed without being displayed.
    # A) per label-source (you already write __{lbl}); keep that.
    for meth in sorted(wide['method'].unique()):
        for lbl in sorted(wide['labels_source'].unique()):
            sub = wide[(wide['method']==meth) & (wide['labels_source']==lbl)]
            if sub.empty: 
                continue
            # One point per pct: median over the embeddings/regressors in `sub`.
            grp = sub.groupby('pct', as_index=False)[['baseline', 'full']].median()
            fig, ax = plt.subplots(figsize=(6,4), dpi=120)
            ax.plot(grp['pct'], grp['baseline'], marker='o', label='baseline')
            ax.plot(grp['pct'], grp['full'],     marker='o', label='full')
            ax.set_xlabel('%K'); ax.set_ylabel('Global median RRMSE')
            ax.set_title(f'{meth} [{lbl}]'); ax.legend()
            fig.savefig(PLOTS_DIR / f"combined_rrmse_vs_pct__{meth}__{lbl}.png", bbox_inches="tight", dpi=300)
            plt.close(fig)

            # Δ plots per label-source
            grp['delta'] = grp['baseline'] - grp['full']
            fig, ax = plt.subplots(figsize=(6,4), dpi=120)
            ax.plot(grp['pct'], grp['delta'], marker='o')
            ax.set_xlabel('%K'); ax.set_ylabel('Δ median RRMSE (baseline - full)')
            ax.set_title(f'{meth} [{lbl}] Δ vs %K')
            fig.savefig(PLOTS_DIR / f"combined_delta_vs_pct__{meth}__{lbl}.png", bbox_inches="tight", dpi=300)
            plt.close(fig)

        # B) overlay across label sources (FULL)
        # Drawn only when at least two label sources exist for this method.
        piv = (wide[wide['method']==meth]
               .groupby(['labels_source','pct'], as_index=False)['full']
               .median())
        if not piv.empty and piv['labels_source'].nunique() >= 2:
            fig, ax = plt.subplots(figsize=(6,4), dpi=120)
            for lbl, chunk in piv.groupby('labels_source'):
                ax.plot(chunk['pct'], chunk['full'], marker='o', label=f'{lbl} (full)')
            ax.set_xlabel('%K'); ax.set_ylabel('Global median RRMSE')
            ax.set_title(f'{meth} — label sources (FULL)'); ax.legend()
            fig.savefig(PLOTS_DIR / f"combined_rrmse_vs_pct__{meth}__compare_labels.png", bbox_inches="tight", dpi=300)
            plt.close(fig)

        # C) compat “no-suffix” and per-embedding figures (use default label-source)
        dsub = wide[(wide['method']==meth) & (wide['labels_source']=='default')]
        if not dsub.empty:
            g = dsub.groupby('pct', as_index=False)[['baseline','full']].median()
            # no-suffix RRMSE
            fig, ax = plt.subplots(figsize=(6,4), dpi=120)
            ax.plot(g['pct'], g['baseline'], marker='o', label='baseline')
            ax.plot(g['pct'], g['full'],     marker='o', label='full')
            ax.set_xlabel('%K'); ax.set_ylabel('Global median RRMSE')
            ax.set_title(f'{meth}'); ax.legend()
            fig.savefig(PLOTS_DIR / f"combined_rrmse_vs_pct__{meth}.png", bbox_inches="tight", dpi=300)
            plt.close(fig)
            # no-suffix Δ
            g['delta'] = g['baseline'] - g['full']
            fig, ax = plt.subplots(figsize=(6,4), dpi=120)
            ax.plot(g['pct'], g['delta'], marker='o')
            ax.set_xlabel('%K'); ax.set_ylabel('Δ median RRMSE (baseline - full)')
            ax.set_title(f'{meth} Δ vs %K')
            fig.savefig(PLOTS_DIR / f"combined_delta_vs_pct__{meth}.png", bbox_inches="tight", dpi=300)
            plt.close(fig)

            # per-embedding copies (e.g., __e5_base)
            for emb in sorted(dsub['embedding'].unique()):
                g_emb = dsub[dsub['embedding']==emb].groupby('pct', as_index=False)[['baseline','full']].median()
                fig, ax = plt.subplots(figsize=(6,4), dpi=120)
                ax.plot(g_emb['pct'], g_emb['baseline'], marker='o', label='baseline')
                ax.plot(g_emb['pct'], g_emb['full'],     marker='o', label='full')
                ax.set_xlabel('%K'); ax.set_ylabel('Global median RRMSE')
                ax.set_title(f'{meth} [{emb}]'); ax.legend()
                fig.savefig(PLOTS_DIR / f"combined_rrmse_vs_pct__{meth}__{emb}.png", bbox_inches="tight", dpi=300)
                plt.close(fig)

                g_emb['delta'] = g_emb['baseline'] - g_emb['full']
                fig, ax = plt.subplots(figsize=(6,4), dpi=120)
                ax.plot(g_emb['pct'], g_emb['delta'], marker='o')
                ax.set_xlabel('%K'); ax.set_ylabel('Δ median RRMSE (baseline - full)')
                ax.set_title(f'{meth} [{emb}] Δ vs %K')
                fig.savefig(PLOTS_DIR / f"combined_delta_vs_pct__{meth}__{emb}.png", bbox_inches="tight", dpi=300)
                plt.close(fig)

    say("[review] Summaries written. (combined, per label-source, legacy-compat files)")


# ────────────────────────────────────────────────────────────────────────────
#  Core run (compute) with labels_tag support
# ────────────────────────────────────────────────────────────────────────────

def _csv_covers_folds(p: Path, n_folds: int) -> bool:
    """
    Return True when result CSV `p` already holds at least `n_folds` data rows.

    The row count is the number of text lines minus one header line. Any problem
    while probing the file (missing, unreadable, undecodable) is reported as
    "not complete" so the caller recomputes instead of crashing.
    """
    try:
        if not p.exists():
            return False
        # Context manager guarantees the handle is closed; a bare open() inside
        # a generator expression would leak one descriptor per probed file.
        with open(p, "r", encoding="utf-8") as fh:
            n_rows = sum(1 for _ in fh) - 1
        return max(0, n_rows) >= n_folds
    except Exception:
        # Deliberate best-effort probe: an unreadable file just means "recompute".
        return False


def _fold_result_row(fold_idx, method, pct, K, Mmax, reg, variant: str, rr) -> dict:
    """
    Build one per-fold result record (one CSV row) for the given variant.

    `rr` holds the per-target RRMSE values of the held-out seed; key insertion
    order defines the CSV column order, so it must stay stable.
    """
    row = {
        "fold": int(fold_idx),
        "method": method,
        "pct": int(pct),
        "K": int(K),
        "M_max": int(Mmax),
        "student": "S1",
        "embedding": STUDENT_EMB,
        "regressor": reg,
        "variant": variant,
        "median_rrmse_fold": float(np.median(rr)),
    }
    # One column per target, numbered 1..14 (hard-coded target count;
    # assumes len(TARGET_COLS) == 14 — TODO confirm).
    for t_idx in range(1, 15):
        row[f"{_DOM_PREFIX}{t_idx}"] = float(rr[t_idx-1])
    return row


def _run_core(labels_tag: Optional[str] = None):
    """
    Compute pipeline body. Trains/evaluates students using the current global
    LABEL_EMB/LABEL_MODEL and the current global STUDENT_EMB/STUDENTS.

    For every (method, pct, student regressor) combination this runs LOOCV over
    the selected folds and writes two CSVs (variant "baseline" = seeds only,
    variant "full" = seeds + first K synthetics). A combination whose CSVs
    already contain a row for every fold is skipped; an incomplete one is
    recomputed for all folds and its CSV is overwritten.

    Parameters
    ----------
    labels_tag : optional tag forwarded to result_path() (and to the legacy
        reuse helper) to disambiguate results produced from alternative label
        sources. None means the default label source.

    Side effects
    ------------
    - Calls write_run_config().
    - Writes the per-fold hyperparameters used to
      HP_PERFOLD/<reg>/<method>/n<Mmax>/pct<pct>_K<K>/fold<NN>_best.json.
    - Writes the baseline/full result CSVs at the paths given by result_path().

    Raises
    ------
    FileNotFoundError : no labeled synthetic method is available for LABEL_EMB,
        or a per-fold label CSV is missing.
    RuntimeError : a label CSV lacks target columns or has fewer than K rows.
    """
    t0_all = time.time()
    write_run_config()

    # 0) Load seeds and vectors
    seeds_df = load_seeds()
    N_SEEDS = len(seeds_df)
    X_seed = ensure_seed_vectors(seeds_df, STUDENT_EMB)
    Y_seed = seeds_df[TARGET_COLS].to_numpy(dtype=np.float32)

    # 1) Discover available methods & pick M_max per method from e_2 labels
    avail = discover_methods_and_M_from_g2(LABEL_EMB)
    methods = sorted(avail.keys()) if (METHODS is None) else [m for m in sorted(avail.keys()) if m in METHODS]
    if not methods:
        raise FileNotFoundError(f"No labeled synthetics found in {G2_DIR} for embedding={LABEL_EMB}")

    # MAX_FOLDS <= 0 means "all folds"; otherwise only the first MAX_FOLDS seeds are held out.
    folds = list(range(N_SEEDS)) if MAX_FOLDS <= 0 else list(range(min(MAX_FOLDS, N_SEEDS)))
    say(f"[run] STUDENT_EMB={STUDENT_EMB} | LABEL_SRC={LABEL_EMB}+{LABEL_MODEL} | labels_tag={labels_tag or 'default'}")
    say(f"[run] methods={methods} | folds={len(folds)} | PCT_LIST={PCT_LIST}")

    for method in methods:
        Mmax = max(avail[method])
        # Ensure synth cache exists and load once per method for student embedding
        npy_path, _ = ensure_synth_cache(method, Mmax, STUDENT_EMB)
        X_synth = np.load(npy_path)  # shape (Mmax, d)
        n_synth_rows = X_synth.shape[0]

        for pct in PCT_LIST:
            K = pct_to_K(pct, N_SEEDS)
            if K > n_synth_rows:
                log.warning(f"[{method}] pct={pct} → K={K} exceeds Mmax={n_synth_rows}; clipping to Mmax")
                K = n_synth_rows

            for reg in STUDENTS:
                # Output files (idempotent)
                p_base = result_path(STUDENT_EMB, reg, method, pct, K, Mmax, "baseline", labels_tag=labels_tag)
                p_full = result_path(STUDENT_EMB, reg, method, pct, K, Mmax, "full",      labels_tag=labels_tag)

                # Skip if already complete
                base_done = _csv_covers_folds(p_base, len(folds))
                full_done = _csv_covers_folds(p_full, len(folds))

                # Optional legacy reuse for chain_lr (full only).
                # NOTE(review): the original comment says "only for default tag", but
                # labels_tag is not checked here — it is passed through, so any such
                # restriction must live inside ensure_chainlr_from_legacy_pct. Confirm.
                if (reg == "chain_ERCcv_lr") and (pct in {100,200,400}) and not full_done:
                    reused = ensure_chainlr_from_legacy_pct(STUDENT_EMB, method, pct, K, N_SEEDS, Mmax, labels_tag=labels_tag)
                    if reused:
                        full_done = True

                if base_done and full_done:
                    log.info(f"[skip] {p_base.name} & {p_full.name} complete.")
                    continue

                rows_base, rows_full = [], []
                t_start = time.time()

                for i, fold_idx in enumerate(folds):
                    if (i % LOG_EVERY) == 0:
                        log.info(f"[{method} | {reg} | pct={pct}] fold {i+1}/{len(folds)} …")

                    # Leave-one-out: train on every seed except the held-out one.
                    tr_idx = np.array([j for j in range(N_SEEDS) if j != fold_idx], dtype=int)
                    # Dummy predictor = per-target mean of the training seeds (RRMSE reference).
                    y_dummy = np.mean(Y_seed[tr_idx], axis=0, dtype=np.float32)

                    # Build model from HPs (teacher pickle or json)
                    teacher_est = load_teacher_fold_estimator(fold_idx, reg, STUDENT_EMB)
                    hp = hp_from_teacher_or_json(fold_idx, reg, STUDENT_EMB, teacher_est)
                    hp_out = HP_PERFOLD / reg / method / f"n{Mmax}" / f"pct{pct}_K{K}"
                    hp_out.mkdir(parents=True, exist_ok=True)
                    (hp_out / f"fold{fold_idx:02d}_best.json").write_text(json.dumps(hp, indent=2))

                    model_base = build_with_hp(reg, hp)

                    # Held-out seed, kept 2-D so predict() receives a (1, d) matrix.
                    x_test = X_seed[fold_idx:fold_idx+1]
                    y_true = Y_seed[fold_idx:fold_idx+1]

                    # ---------- Baseline (seeds-only) ----------
                    if not base_done:
                        model_base.fit(X_seed[tr_idx], Y_seed[tr_idx])
                        y_pred = model_base.predict(x_test)
                        rr = rrmse_vs_dummy(y_true, y_pred, y_dummy.reshape(1, -1)).flatten()
                        rows_base.append(
                            _fold_result_row(fold_idx, method, pct, K, Mmax, reg, "baseline", rr)
                        )

                    # ---------- Full (seeds + synthetics) ----------
                    if not full_done:
                        # Labels for this fold, method, Mmax, label source
                        lbl_csv = g2_label_file(
                            fold_idx, method, Mmax, LABEL_EMB, LABEL_MODEL,
                            strict_model=True
                        )
                        if not lbl_csv.exists():
                            raise FileNotFoundError(f"Missing labels: {lbl_csv}")
                        df_lbl = pd.read_csv(lbl_csv)
                        if any(c not in df_lbl.columns for c in TARGET_COLS):
                            raise RuntimeError(f"{lbl_csv.name}: missing target cols {TARGET_COLS}")
                        # Fail early with a clear message: with fewer than K label rows the
                        # X/Y slices below would have different lengths and only blow up
                        # (or misbehave) later inside fit().
                        if len(df_lbl) < K:
                            raise RuntimeError(
                                f"{lbl_csv.name}: only {len(df_lbl)} labeled rows, need K={K}"
                            )
                        # first K rows
                        Y_syn = df_lbl[TARGET_COLS].to_numpy(dtype=np.float32)[:K, :]
                        X_syn = X_synth[:K, :]

                        model_full = build_with_hp(reg, hp)
                        X_tr_full = np.vstack([X_seed[tr_idx], X_syn])
                        Y_tr_full = np.vstack([Y_seed[tr_idx], Y_syn])
                        model_full.fit(X_tr_full, Y_tr_full)

                        y_pred = model_full.predict(x_test)
                        rr = rrmse_vs_dummy(y_true, y_pred, y_dummy.reshape(1, -1)).flatten()
                        rows_full.append(
                            _fold_result_row(fold_idx, method, pct, K, Mmax, reg, "full", rr)
                        )

                # Write outputs
                if rows_base and (not base_done):
                    pd.DataFrame(rows_base).sort_values("fold").to_csv(p_base, index=False)
                    log.info(f"[write] {p_base.name} ({len(rows_base)} rows, { _fmt_dt(time.time()-t_start) })")
                if rows_full and (not full_done):
                    pd.DataFrame(rows_full).sort_values("fold").to_csv(p_full, index=False)
                    log.info(f"[write] {p_full.name} ({len(rows_full)} rows, { _fmt_dt(time.time()-t_start) })")

    say(f"[run] completed in { _fmt_dt(time.time()-t0_all) }")

# ────────────────────────────────────────────────────────────────────────────
#  Pipelines: compute & review
# ────────────────────────────────────────────────────────────────────────────

def _run_students_for_label_source(label_emb: str,
                                   label_model: str,
                                   labels_tag: Optional[str],
                                   restrict_students: Optional[List[str]] = None):
    """
    Run the student-scoring core with a temporarily swapped label source.

    The module globals LABEL_EMB / LABEL_MODEL are set to the given values (and
    STUDENTS is narrowed to `restrict_students`, when provided) for the duration
    of one _run_core() call, then restored to their previous values — also when
    the run raises or is skipped.

    Parameters
    ----------
    label_emb, label_model : label source to install for this run.
    labels_tag : forwarded to _run_core(); tags the outputs of this run.
    restrict_students : if given, only students present in both the current
        STUDENTS list and this list are run; if none match, a warning is
        logged and nothing is computed.
    """
    global LABEL_EMB, LABEL_MODEL, STUDENTS

    # Snapshot the globals we are about to override.
    saved_label_source = (LABEL_EMB, LABEL_MODEL)
    saved_students = list(STUDENTS)
    try:
        LABEL_EMB = label_emb
        LABEL_MODEL = label_model

        if restrict_students is not None:
            # Keep the original STUDENTS ordering; just filter it.
            selected = [name for name in saved_students if name in restrict_students]
            STUDENTS = selected
            if not selected:
                log.warning("No matching students in restrict_students=%s; skipping.", restrict_students)
                return

        log.info("[labels] Running students %s on labels %s+%s (tag=%s)", STUDENTS, LABEL_EMB, LABEL_MODEL, labels_tag)
        _run_core(labels_tag=labels_tag)
    finally:
        # Always put the globals back, including on the early return above.
        LABEL_EMB, LABEL_MODEL = saved_label_source
        STUDENTS = saved_students

def run_pipeline_compute():
    """
    Compute pipeline, in two passes:
      1) Default pass — run_for_embedding() once per entry of STUDENT_EMBEDDINGS.
      2) Extra pass   — one tagged run per entry of EXTRA_STUDENT_LABEL_JOBS,
                        restricted to that job's single student and using the
                        job's label source (label_emb + label_model).

    Each extra job is a mapping with required keys "label_emb", "label_model",
    "student" and an optional "labels_tag".

    NOTE(review): pass 2 does not select a student embedding itself; it runs
    with whatever STUDENT_EMB is in effect once pass 1 has finished — confirm
    this is the intended embedding for the extra jobs.
    """
    # Pass 1 (default label source)
    for student_embedding in STUDENT_EMBEDDINGS:
        run_for_embedding(student_embedding)

    # Pass 2 (alternative label sources)
    for extra in EXTRA_STUDENT_LABEL_JOBS:
        source_emb = extra["label_emb"]
        source_model = extra["label_model"]
        tag = extra.get("labels_tag")   # optional: None when the job has no tag
        only_students = [extra["student"]]
        _run_students_for_label_source(
            label_emb=source_emb,
            label_model=source_model,
            labels_tag=tag,
            restrict_students=only_students,
        )

def run_pipeline_review() -> None:
    """
    Review pipeline: no student training happens here.

    Delegates entirely to build_combined_reports_across_embeddings() with
    include_alt_labels=True, i.e. the report builder is asked to also consider
    results produced from alternative label sources (e.g. files tagged
    __labels_local_lasso) in addition to the default ones.
    """
    build_combined_reports_across_embeddings(include_alt_labels=True)

# ────────────────────────────────────────────────────────────────────────────
#  Entry point
# ────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Top-level boundary: pick review vs. compute from REVIEW_MODE, and report
    # any failure via the logger and stdout instead of letting it propagate.
    try:
        write_run_config()
        (run_pipeline_review if REVIEW_MODE else run_pipeline_compute)()
        print("Run completed.")
    except Exception as exc:
        log.exception("Fatal error in step e_3: %s", exc)
        print(f"Error: {exc}")


  from tqdm.autonotebook import tqdm, trange
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'
2025-10-17 01:54:47 INFO: run_config.json written.
[review] Summaries written. (combined, per label-source, legacy-compat files)
Run completed.
