In [None]:
from typing import Any, Mapping, Sequence, cast
import numpy as np

from ase.atoms import Atoms
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from smol.cofe import ClusterExpansion

# ---- phaseedge imports (modern-only) ----
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.science.prototypes import PrototypeName
from phaseedge.jobs.train_ce import (
    featurize_structures,
    lookup_train_refs_by_key,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.schemas.mixture import Mixture, sublattices_from_mixtures
from phaseedge.storage.store import lookup_total_energy_eV
from phaseedge.utils.keys import occ_key_for_structure


# ----------------- small helpers -----------------

def fmt_mev(x: float | str) -> str:
    if isinstance(x, str):
        return x
    return f"{1e3 * float(x):.3f}"

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """Build a ``Mixture`` for every mixture entry found under 'composition' sources.

    Sources whose ``type`` (case-insensitive) is not "composition" are ignored.
    Each mixture entry is normalized to ``composition_map`` (dict), ``K`` (int)
    and ``seed`` (int) before being passed to ``Mixture.from_dict``.

    Raises:
        RuntimeError: if no mixture was collected at all.
    """
    composition_sources = (
        src for src in sources if str(src.get("type", "")).lower() == "composition"
    )
    collected: list[Mixture] = [
        Mixture.from_dict(
            {
                "composition_map": dict(entry["composition_map"]),
                "K": int(entry["K"]),
                "seed": int(entry["seed"]),
            }
        )
        for src in composition_sources
        for entry in src.get("mixtures", [])
    ]
    if len(collected) == 0:
        raise RuntimeError("Expected 'composition' mixtures in sources (modern schema).")
    return collected


def _sublattices_for_doc(doc: Mapping[str, Any]) -> dict[str, tuple[str, ...]]:
    """
    Replaceable sublattices for per-site scaling:
      - composition CE: from this doc's mixtures
      - wl_refined CE : from the base (composition) CE's mixtures
    """
    sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    if not sources:
        raise RuntimeError("CE document missing 'sources'.")
    s0 = sources[0]
    s_type = str(s0.get("type", "")).lower()

    if s_type == "composition":
        return sublattices_from_mixtures(_gather_composition_mixtures(sources))

    if s_type.startswith("wl_refined"):
        base_ce_key = cast(str, s0["base_ce_key"])
        base_doc = lookup_ce_by_key(base_ce_key)
        if not base_doc:
            raise RuntimeError(f"Base CE not found for wl_refined source: {base_ce_key}")
        return sublattices_from_mixtures(
            _gather_composition_mixtures(cast(Sequence[Mapping[str, Any]], base_doc["sources"]))
        )

    raise RuntimeError(f"Unsupported source type in CE doc: {s_type!r}")


def _fmt_f(x: float) -> str:
    return f"{float(x):.6g}"


def _grab(dm: Mapping[str, Any] | None, key: str, default: Any) -> Any:
    return default if dm is None else dm.get(key, default)


def _print_two_col_table(title: str, col_a_name: str, col_b_name: str, rows: list[tuple[str, str, str]]) -> None:
    """Pretty side-by-side ASCII table."""
    w_name = max(len(r[0]) for r in rows + [("Metric", "", "")])
    w_a = max(len(col_a_name), max(len(r[1]) for r in rows))
    w_b = max(len(col_b_name), max(len(r[2]) for r in rows))
    total = 3 + w_name + 3 + w_a + 3 + w_b + 3

    print(title.center(total, "─"))
    print(f" {'Metric'.ljust(w_name)} │ {col_a_name.ljust(w_a)} │ {col_b_name.ljust(w_b)} ")
    print("─" * total)
    for name, a, b in rows:
        print(f" {name.ljust(w_name)} │ {a.ljust(w_a)} │ {b.ljust(w_b)} ")
    print("")


# ----------------- main comparison -----------------

def compare_comp_vs_refined_on_refined_dataset(
    refined_ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> dict[str, Any]:
    """
    Compare base composition CE vs. refined WL CE *on the refined dataset*.

    - Input: a **refined WL CE key**.
    - We evaluate the **base (composition) CE** out-of-sample on the refined CE's training set.
    - For the refined CE, we **do NOT recompute**; we display the stored 5-fold CV stats.
    - Side-by-side printout includes per-site performance and stored design metrics.
    - Returns a dict with both metric bundles.

    Args:
        refined_ce_key: Key of a CE document whose first source type starts with
            "wl_refined"; its ``base_ce_key`` identifies the base composition CE.
        atol: Absolute tolerance used only for the printed verdict: a difference in
            per-site MAE/RMSE must exceed ``max(atol, rtol * |reference|)`` to count.
        rtol: Relative tolerance for the same verdict comparison.
        verbose: When True, print the header, both tables and the verdict.
            The returned dict is the same either way.

    Returns:
        A dict with keys ``refined_ce_key``, ``base_ce_key``, ``n_prims``,
        ``sites_per_prim``, ``base_on_refined_stats_per_site`` (freshly computed),
        ``refined_cv_stats_per_site`` (as stored on the refined doc) and
        ``design_metrics`` (``{"base": ..., "refined": ...}``, either may be None).

    Raises:
        RuntimeError: if a CE document is missing, the given CE is not WL-refined,
            the two CEs disagree on prototype / prototype_params / supercell_diag,
            the per-site scaling is ill-defined (zero primitive cells or zero
            replaceable sites), the training references are missing or
            inconsistent, or an energy is absent from the store.
    """
    # --- Load refined CE doc ---
    ref_doc = lookup_ce_by_key(refined_ce_key)
    if not ref_doc:
        raise RuntimeError(f"No CE found for refined_ce_key={refined_ce_key}")

    sources = cast(Sequence[Mapping[str, Any]], ref_doc["sources"])
    if not sources:
        raise RuntimeError("Refined CE doc has no sources.")
    s0 = sources[0]
    s_type = str(s0.get("type", "")).lower()
    if not s_type.startswith("wl_refined"):
        raise RuntimeError(f"Provided CE is not a refined WL CE (source.type={s_type!r}).")

    base_ce_key = cast(str, s0["base_ce_key"])
    base_doc = lookup_ce_by_key(base_ce_key)
    if not base_doc:
        raise RuntimeError(f"Base (composition) CE not found: {base_ce_key}")

    # --- Basic identity checks (should match in the modern pipeline) ---
    r_proto = PrototypeName(ref_doc["prototype"])
    b_proto = PrototypeName(base_doc["prototype"])
    if r_proto != b_proto:
        raise RuntimeError("Prototype mismatch between refined CE and base composition CE.")
    r_pp = cast(Mapping[str, Any], ref_doc["prototype_params"])
    b_pp = cast(Mapping[str, Any], base_doc["prototype_params"])
    if dict(r_pp) != dict(b_pp):
        raise RuntimeError("prototype_params mismatch between refined CE and base composition CE.")

    xs, ys, zs = ref_doc["supercell_diag"]
    r_sc = (int(xs), int(ys), int(zs))
    b_xs, b_ys, b_zs = base_doc["supercell_diag"]
    b_sc = (int(b_xs), int(b_ys), int(b_zs))
    if r_sc != b_sc:
        raise RuntimeError("supercell_diag mismatch between refined CE and base composition CE.")

    # --- Per-site scaling (replaceable sites / primitive) ---
    sublattices = _sublattices_for_doc(base_doc)  # composition-ground truth
    n_prims = int(np.prod(np.asarray(r_sc, dtype=int)))
    if n_prims <= 0:
        # Guard the modulo/division below; a zero entry in supercell_diag would
        # otherwise surface as an uninformative ZeroDivisionError.
        raise RuntimeError(f"supercell_diag {r_sc} yields a non-positive number of primitive cells.")
    n_sites_const = _n_replace_sites_from_prototype(
        prototype=r_proto,
        prototype_params=r_pp,
        supercell_diag=r_sc,
        sublattices=sublattices,
    )
    if n_sites_const % n_prims != 0:
        raise RuntimeError("sites_per_supercell not divisible by n_prims (prototype/sublattice mismatch).")
    sites_per_prim = n_sites_const // n_prims
    if sites_per_prim <= 0:
        # Zero replaceable sites passes the divisibility check but makes the
        # per-site scale undefined.
        raise RuntimeError("No replaceable sites per primitive cell; per-site scaling is undefined.")
    scale_site = 1.0 / float(sites_per_prim)

    # --- Build refined dataset from refined CE train_refs (structures + energies) ---
    train_refs = lookup_train_refs_by_key(ref_doc["dataset_key"])
    if not train_refs or not all("structure" in r for r in train_refs):
        raise RuntimeError("Refined CE doc missing 'train_refs' with embedded 'structure'.")

    model_r = cast(str, ref_doc["model"])
    relax_cell_r = bool(ref_doc["relax_cell"])
    dtype_r = cast(str, ref_doc["dtype"])

    structures: list[Structure] = []
    energies_eV_super: list[float] = []

    for i, r in enumerate(train_refs):
        pmg_struct = Structure.from_dict(r["structure"])
        # sanity: occ_key should round-trip from structure
        ok2 = occ_key_for_structure(pmg_struct)
        ok_expected = cast(str, r["occ_key"])
        if ok2 != ok_expected:
            raise RuntimeError(
                f"train_refs[{i}] occ_key mismatch: expected {ok_expected}, rebuilt {ok2}."
            )

        E = lookup_total_energy_eV(
            set_id=cast(str, r["set_id"]),
            occ_key=ok_expected,
            model=model_r,
            relax_cell=relax_cell_r,
            dtype=dtype_r,
        )
        if E is None:
            raise RuntimeError(f"Energy not found in store for train_refs[{i}] set_id={r['set_id']}, occ_key={ok_expected[:12]}...")

        structures.append(pmg_struct)
        energies_eV_super.append(E)

    # Supercell totals -> per-primitive-cell energies (already float64, no extra cast needed).
    y_true_per_prim = np.asarray(energies_eV_super, dtype=np.float64) / float(n_prims)

    # --- Evaluate BASE CE on the refined dataset ---
    base_ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], base_doc["payload"]))
    _, X_base = featurize_structures(
        subspace=base_ce.cluster_subspace,
        structures=structures,
        supercell_diag=r_sc,
    )
    coefs_base = np.asarray(getattr(base_ce, "coefs"), dtype=np.float64)
    y_pred_per_prim_base = predict_from_features(X_base, coefs_base)

    base_on_refined_stats = compute_stats(
        (y_true_per_prim * scale_site).tolist(),
        (y_pred_per_prim_base * scale_site).tolist(),
    )  # per-site

    # --- Pull refined CE stored 5-fold CV stats (per-site) ---
    ref_stats = cast(Mapping[str, Any], ref_doc["stats"])
    ref_cv = cast(Mapping[str, Any], ref_stats["five_fold_cv"])
    # also harvest design metrics (stored at training time)
    dm_base = cast(Mapping[str, Any] | None, base_doc.get("design_metrics"))
    dm_ref = cast(Mapping[str, Any] | None, ref_doc.get("design_metrics"))

    # --- Pretty side-by-side printout ---
    if verbose:
        print("═" * 86)
        print(" CE Comparison on Refined Dataset ".center(86, "═"))
        print("═" * 86)
        print(f"Refined CE key     : {refined_ce_key}")
        print(f"Base (composition) : {base_ce_key}")
        print(f"Prototype          : {r_proto}")
        print(f"Supercell diag     : {r_sc}  (n_prims={n_prims}, sites/prim={sites_per_prim})")
        print("")

        # Performance (per-site, meV)
        perf_rows: list[tuple[str, str, str]] = [
            ("n",                str(int(base_on_refined_stats["n"])), str(int(ref_cv["n"]))),
            ("MAE (meV/site)",   fmt_mev(base_on_refined_stats["mae_per_site"]), fmt_mev(float(ref_cv["mae_per_site"]))),
            ("RMSE (meV/site)",  fmt_mev(base_on_refined_stats["rmse_per_site"]), fmt_mev(float(ref_cv["rmse_per_site"]))),
            ("Max|err| (meV/site)", fmt_mev(base_on_refined_stats["max_abs_per_site"]), fmt_mev(float(ref_cv["max_abs_per_site"]))),
        ]
        _print_two_col_table(
            title="Performance (per-site, meV) — Base CE evaluated on refined dataset  vs  Refined CE stored 5-fold CV",
            col_a_name="Base CE (OOS on refined set)",
            col_b_name="Refined CE (5-fold CV, stored)",
            rows=perf_rows,
        )

        # Design metrics side-by-side (stored); missing docs/keys fall back to
        # the placeholder defaults passed to _grab below.
        def _dm_rows(dm_a: Mapping[str, Any] | None, dm_b: Mapping[str, Any] | None) -> list[tuple[str, str, str]]:
            return [
                ("n_samples",                 str(int(_grab(dm_a, "n_samples", -1))),              str(int(_grab(dm_b, "n_samples", -1)))),
                ("n_features",                str(int(_grab(dm_a, "n_features", -1))),             str(int(_grab(dm_b, "n_features", -1)))),
                ("rank",                      str(int(_grab(dm_a, "rank", -1))),                   str(int(_grab(dm_b, "rank", -1)))),
                ("sigma_max",                 _fmt_f(float(_grab(dm_a, "sigma_max", float("nan")))), _fmt_f(float(_grab(dm_b, "sigma_max", float("nan"))))),
                ("sigma_min",                 _fmt_f(float(_grab(dm_a, "sigma_min", float("nan")))), _fmt_f(float(_grab(dm_b, "sigma_min", float("nan"))))),
                ("condition_number (κ)",      _fmt_f(float(_grab(dm_a, "condition_number", float("inf")))), _fmt_f(float(_grab(dm_b, "condition_number", float("inf"))))),
                ("logdet(XᵀX)",               _fmt_f(float(_grab(dm_a, "logdet_xtx", float("nan")))), _fmt_f(float(_grab(dm_b, "logdet_xtx", float("nan"))))),
                ("leverage_mean",             _fmt_f(float(_grab(dm_a, "leverage_mean", float("nan")))), _fmt_f(float(_grab(dm_b, "leverage_mean", float("nan"))))),
                ("leverage_p95",              _fmt_f(float(_grab(dm_a, "leverage_p95", float("nan")))), _fmt_f(float(_grab(dm_b, "leverage_p95", float("nan"))))),
                ("leverage_max",              _fmt_f(float(_grab(dm_a, "leverage_max", float("nan")))), _fmt_f(float(_grab(dm_b, "leverage_max", float("nan"))))),
                ("weighting_applied",         str(bool(_grab(dm_a, "weighting_applied", False))),    str(bool(_grab(dm_b, "weighting_applied", False)))),
                ("standardization",           str(_grab(dm_a, "standardization", "none")),           str(_grab(dm_b, "standardization", "none"))),
                ("zero_variance_feature_count", str(int(_grab(dm_a, "zero_variance_feature_count", 0))), str(int(_grab(dm_b, "zero_variance_feature_count", 0)))),
            ]

        _print_two_col_table(
            title="Design metrics (stored at training time)",
            col_a_name="Base CE (composition)",
            col_b_name="Refined CE (WL refined)",
            rows=_dm_rows(dm_base, dm_ref),
        )

        # Quick verdict relative to refined CV (optional)
        mae_base = float(base_on_refined_stats["mae_per_site"])
        mae_ref = float(ref_cv["mae_per_site"])
        rmse_base = float(base_on_refined_stats["rmse_per_site"])
        rmse_ref = float(ref_cv["rmse_per_site"])

        # Both MAE and RMSE must move in the same direction beyond tolerance;
        # mixed or within-tolerance results stay "comparable".
        verdict = "≈ comparable"
        if (mae_base - mae_ref) > max(atol, rtol * abs(mae_ref)) and (rmse_base - rmse_ref) > max(atol, rtol * abs(rmse_ref)):
            verdict = "worse than refined (higher MAE & RMSE)"
        elif (mae_ref - mae_base) > max(atol, rtol * abs(mae_base)) and (rmse_ref - rmse_base) > max(atol, rtol * abs(rmse_base)):
            verdict = "better than refined (lower MAE & RMSE)"
        print(f"Verdict        : Base CE on refined set is {verdict}.")
        print("═" * 86)

    return {
        "refined_ce_key": refined_ce_key,
        "base_ce_key": base_ce_key,
        "n_prims": n_prims,
        "sites_per_prim": sites_per_prim,
        "base_on_refined_stats_per_site": base_on_refined_stats,
        "refined_cv_stats_per_site": ref_cv,
        "design_metrics": {
            "base": dm_base,
            "refined": dm_ref,
        },
    }


# ---- Example: the argument must be the key of a *refined WL* CE; the base
# composition CE is resolved from that document's first source.
# Alternate refined CE key, kept for reference:
# compare_comp_vs_refined_on_refined_dataset("76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba")
compare_comp_vs_refined_on_refined_dataset("6673ff4bac88871cc8b21fc78e2035b39eac75a31198d839a4daaa75c697286f")


══════════════════════════════════════════════════════════════════════════════════════
══════════════════════════ CE Comparison on Refined Dataset ══════════════════════════
══════════════════════════════════════════════════════════════════════════════════════
Refined CE key     : 76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba
Base (composition) : 781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737
Prototype          : rocksalt
Supercell diag     : (3, 3, 3)  (n_prims=27, sites/prim=4)

Performance (per-site, meV) — Base CE evaluated on refined dataset  vs  Refined CE stored 5-fold CV
 Metric              │ Base CE (OOS on refined set) │ Refined CE (5-fold CV, stored) 
─────────────────────────────────────────────────────────────────────────────────────────
 n                   │ 102                          │ 102                            
 MAE (meV/site)      │ 5.420                        │ 1.065                          
 RMSE (meV/site)     │ 9.472 

{'refined_ce_key': '76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba',
 'base_ce_key': '781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737',
 'n_prims': 27,
 'sites_per_prim': 4,
 'base_on_refined_stats_per_site': {'n': 102,
  'mae_per_site': 0.005420326681759167,
  'rmse_per_site': 0.009472223583991959,
  'max_abs_per_site': 0.04203708925292382},
 'refined_cv_stats_per_site': {'n': 102,
  'mae_per_site': 0.0010650202138921178,
  'rmse_per_site': 0.0013909209928694552,
  'max_abs_per_site': 0.003998633950208941},
 'design_metrics': {'base': {'n_samples': 102,
   'n_features': 52,
   'rank': 49,
   'sigma_max': 52.868765530427,
   'sigma_min': 0.0827531270857613,
   'condition_number': 638.8733259062996,
   'logdet_xtx': -60.31021897498089,
   'leverage_mean': 0.480392156862745,
   'leverage_max': 0.9314774345458762,
   'leverage_p95': 0.7124695355950623,
   'weighting_applied': True,
   'standardization': 'column_zscore',
   'zero_variance_feature_count': 