In [1]:
from typing import Any, Mapping, Sequence, cast
import math
from collections import defaultdict

import numpy as np
from pymatgen.core import Structure

from smol.cofe import ClusterExpansion

# ---- phaseedge imports (modern-only) ----
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.jobs.train_ce import (
    featurize_structures,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.schemas.calc_spec import CalcSpec
from phaseedge.schemas.mixture import Mixture, composition_map_sig, sublattices_from_mixtures
from phaseedge.science.prototype_spec import PrototypeSpec
from phaseedge.storage.cetrainref_dataset import Dataset
from phaseedge.storage.store import lookup_total_energy_eV
from phaseedge.utils.keys import occ_key_for_structure


# ----------------- helpers: structures, sources, sublattices -----------------

def fmt_mev(x: float | str) -> str:
    if isinstance(x, str):
        return x
    return f"{1e3 * float(x):.3f}"

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect every mixture listed under the 'composition' sources.

    Sources whose ``type`` (case-insensitive) is not ``"composition"`` are
    ignored. Each mixture entry is rebuilt through ``Mixture.from_dict`` from
    its ``composition_map``, ``K`` and ``seed`` fields, in source order.

    Raises:
        RuntimeError: if no mixture was found in any composition source.
    """
    collected: list[Mixture] = [
        Mixture.from_dict(
            {
                "composition_map": dict(entry["composition_map"]),
                "K": int(entry["K"]),
                "seed": int(entry["seed"]),
            }
        )
        for src in sources
        if str(src.get("type", "")).lower() == "composition"
        for entry in src.get("mixtures", [])
    ]
    if collected:
        return collected
    raise RuntimeError("Expected 'composition' mixtures in sources (modern schema).")


def _sublattices_for_doc(doc: Mapping[str, Any]) -> dict[str, tuple[str, ...]]:
    """
    Resolve the replaceable sublattices used for per-site scaling.

    The type of the document's first source decides where the mixtures come
    from:
      - "composition": the mixtures of this document's own sources
      - "dopt_sampling_intent": the mixtures of the base CE document
        referenced by ``base_ce_key``

    Raises:
        RuntimeError: if ``sources`` is empty, the base CE cannot be found,
            or the first source has any other type.
    """
    doc_sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    if not doc_sources:
        raise RuntimeError("CE document missing 'sources'.")

    head = doc_sources[0]
    kind = str(head.get("type", "")).lower()

    if kind == "composition":
        mixture_sources = doc_sources
    elif kind == "dopt_sampling_intent":
        parent_key = cast(str, head["base_ce_key"])
        parent_doc = lookup_ce_by_key(parent_key)
        if not parent_doc:
            raise RuntimeError(f"Base CE not found for dopt_sampling_intent source: {parent_key}")
        mixture_sources = cast(Sequence[Mapping[str, Any]], parent_doc["sources"])
    else:
        raise RuntimeError(f"Unsupported source type in CE doc: {kind!r}")

    return sublattices_from_mixtures(_gather_composition_mixtures(mixture_sources))

# ----------------- pretty printers ------------------------------------------

def _print_design_metrics(dm: Mapping[str, Any] | None) -> None:
    if not dm:
        print("Design matrix diagnostics: <not available>")
        return

    def _get(name: str, default: float | int | str = float("nan")) -> Any:
        return dm.get(name, default)

    try:
        n_samples = int(_get("n_samples", -1))
        n_features = int(_get("n_features", -1))
        rank = int(_get("rank", -1))

        sigma_max = float(_get("sigma_max", float("nan")))
        sigma_min = float(_get("sigma_min", float("nan")))
        kappa = float(_get("condition_number", float("inf")))
        logdet = float(_get("logdet_xtx", float("nan")))

        lev_mean = float(_get("leverage_mean", float("nan")))
        lev_max = float(_get("leverage_max", float("nan")))
        lev_p95 = float(_get("leverage_p95", float("nan")))

        weighting_applied = bool(_get("weighting_applied", False))
        standardization = str(_get("standardization", "none"))
        zero_var = int(_get("zero_variance_feature_count", 0))
    except Exception:
        print("Design matrix diagnostics: <malformed in CE document>")
        return

    print("Design matrix diagnostics (stored)")
    print("  Samples / Features     : %d / %d" % (n_samples, n_features))
    print("  Rank                   : %d" % rank)
    print("  Sigma max / min        : %.6g / %.6g" % (sigma_max, sigma_min))
    if math.isfinite(kappa):
        print("  Condition number (κ)   : %.6g" % kappa)
    else:
        print("  Condition number (κ)   : inf (rank-deficient or σ_min≈0)")
    print("  log det(XᵀX)           : %.6g" % logdet)
    print("  Leverage mean / p95/max: %.6g / %.6g / %.6g" % (lev_mean, lev_p95, lev_max))
    print("  Weighting applied      : %s" % weighting_applied)
    print("  Standardization        : %s" % standardization)
    print("  Zero-variance columns  : %d" % zero_var)


def _print_mixture_energy_ranges(
    mixture_rows: list[tuple[str, int, float, float, float, float]],
) -> None:
    """
    Print a table of per-mixture energy spans.

    Each row: (mixture_key, n, E_min_prim, E_max_prim, dE_prim, dE_meV_per_site)

    The table is self-consistent in width: the key column grows to fit the
    longest mixture key (minimum 42 characters), every numeric column is at
    least as wide as its header label, and the horizontal rules and centered
    title span the full header width so header, rows and rules line up.
    """
    # Widen the key column when a key would otherwise push its row out of line.
    key_w = max([42, *(len(row[0]) for row in mixture_rows)])

    # E_min uses width 15 because its label is 15 characters long; a narrower
    # field would shift the header by one column relative to the data rows.
    hdr = (
        f"{'Mixture (Es/Fm counts)':<{key_w}}  {'n':>3}  "
        f"{'E_min (eV/prim)':>15}  {'E_max (eV/prim)':>16}  {'ΔE (eV/prim)':>13}  {'ΔE (meV/site)':>15}"
    )
    width = len(hdr)
    rule = "-" * width

    print(rule)
    print("Mixture energy span (training set)".center(width))
    print(rule)
    print(hdr)
    print(rule)
    for key, n, e_min, e_max, de_prim, de_mev_site in mixture_rows:
        print(
            f"{key:<{key_w}}  {n:>3d}  "
            f"{e_min:>15.6f}  {e_max:>16.6f}  {de_prim:>13.6f}  {de_mev_site:>15.3f}"
        )
    print(rule)


def _print_summary(
    *,
    ce_key: str,
    n: int,
    stored_in: Mapping[str, Any],
    recomputed_in: Mapping[str, Any],
    stored_cv: Mapping[str, Any] | None,
    design_metrics: Mapping[str, Any] | None,
    mixture_rows: list[tuple[str, int, float, float, float, float]],
    atol: float,
    rtol: float,
    ok: bool,
) -> None:
    """
    Print the full validation report for one CE.

    Sections, in order: banner and run parameters, stored vs. recomputed
    in-sample metrics (with their difference), the stored 5-fold CV metrics
    when all of them are present, the design-matrix diagnostics, the
    per-mixture energy table (only when ``mixture_rows`` is non-empty) and
    the final PASS/FAIL line driven by ``ok``.

    Metric values are rendered through ``fmt_mev``.
    """
    heavy_rule = "═" * 72
    light_rule = "-" * 72
    # (row label, key in the stats mappings)
    metrics = (
        ("MAE", "mae_per_site"),
        ("RMSE", "rmse_per_site"),
        ("Max|err|", "max_abs_per_site"),
    )

    # Convert everything first so a bad stats mapping fails before any output.
    stored_vals = [float(stored_in[field]) for _, field in metrics]
    recomputed_vals = [float(recomputed_in[field]) for _, field in metrics]

    print(heavy_rule)
    print(" CE Validation (per-site metrics in meV) ".center(72, "═"))
    print(heavy_rule)
    print(f"CE key        : {ce_key}")
    print(f"samples (n)   : {n}")
    print(f"tolerances    : atol={atol:g}, rtol={rtol:g}")
    print(light_rule)
    print("In-sample")
    print("Metric         Stored    Recomputed   Δ (recomp − stored)")
    print("-------------  --------  ----------   -------------------")
    for (label, _), stored, recomputed in zip(metrics, stored_vals, recomputed_vals):
        print(
            f"{label:<15}{fmt_mev(stored):>8}   {fmt_mev(recomputed):>8}"
            f"        {fmt_mev(recomputed - stored):>8}"
        )
    print(light_rule)

    cv_fields = ("n", "mae_per_site", "rmse_per_site", "max_abs_per_site")
    cv_complete = bool(stored_cv) and all(name in stored_cv for name in cv_fields)
    if not cv_complete:
        print("Stored 5-fold CV: <not available>")
    else:
        cv_n = int(stored_cv["n"])
        cv_vals = [float(stored_cv[field]) for _, field in metrics]
        print("Stored 5-fold CV")
        print("Metric         Value")
        print("-------------  --------")
        print(f"{'n':<15}{cv_n}")
        for (label, _), value in zip(metrics, cv_vals):
            print(f"{label:<15}{fmt_mev(value):>8}")
    print(light_rule)

    _print_design_metrics(design_metrics)

    if mixture_rows:
        _print_mixture_energy_ranges(mixture_rows)

    if ok:
        verdict = "✅  PASS (in-sample matches stored)"
    else:
        verdict = "❌  FAIL (in-sample mismatch)"
    print("Result         : " + verdict)
    print(heavy_rule)


# ----------------- main: in-sample validator + mixture spans -----------------

def verify_ce_in_sample(
    ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> bool:
    """
    Recompute a stored CE's *in-sample* error statistics and compare them with
    the values saved in the CE document.

    Procedure (as implemented below):
      1. Load the CE document for ``ce_key`` and read ``supercell_diag``,
         ``stats["in_sample"]``, the optional ``stats["five_fold_cv"]`` and the
         optional ``design_metrics``.
      2. Load the training dataset via ``doc["dataset_key"]``; for every
         ``train_refs`` entry check that its ``calc_spec`` equals the document's
         and that ``occ_key_for_structure`` reproduces its ``occ_key``, then look
         up its total energy and divide by the number of cells in the supercell.
      3. Rebuild the ``ClusterExpansion`` from ``doc["payload"]``, featurize the
         training structures and predict with the stored coefficients.
      4. Scale true and predicted values by ``1 / sites_per_prim`` (replaceable
         sites only) and recompute the statistics with ``compute_stats``.
      5. Group the per-cell energies by ``composition_map_sig(composition_map)``
         and derive (n, E_min, E_max, ΔE, ΔE in meV per site) for each group;
         these rows are only printed, never compared.

    Args:
        ce_key: Key of the CE document to validate.
        atol: Absolute tolerance passed to ``math.isclose`` for each metric.
        rtol: Relative tolerance passed to ``math.isclose`` for each metric.
        verbose: When true, print the full report via ``_print_summary``.

    Returns:
        True when the recomputed sample count equals the stored one and the
        recomputed ``mae_per_site``, ``rmse_per_site`` and ``max_abs_per_site``
        all agree with the stored values within the tolerances; False otherwise.

    Raises:
        RuntimeError: if the CE document is not found, the sublattices yield no
            symbols, a training reference has a different ``calc_spec`` or a
            non-reproducible ``occ_key``, an energy is missing from the store,
            or the replaceable-site count is not divisible by the cell count.
            ``_sublattices_for_doc`` may raise RuntimeError as well.
        NOTE(review): required document fields are read with ``doc[...]``, so a
            document lacking one of them presumably raises KeyError — confirm
            against the type returned by ``lookup_ce_by_key``.
    """
    # ---- Load CE document (modern schema assumed) ----
    doc = lookup_ce_by_key(ce_key)
    if not doc:
        raise RuntimeError(f"No CE found for ce_key={ce_key}")

    sx, sy, sz = (int(x) for x in doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)

    stats = cast(Mapping[str, Any], doc["stats"])
    in_sample_stored = cast(Mapping[str, Any], stats["in_sample"])
    stored_cv = cast(Mapping[str, Any] | None, stats.get("five_fold_cv"))

    # Design metrics (stored at training time)
    design_metrics = cast(Mapping[str, Any] | None, doc.get("design_metrics"))

    # sublattices for per-site scaling
    sublattices = _sublattices_for_doc(doc)
    # Cation symbols we care about (e.g., {"Al","Mg"})
    # NOTE: the set is only used for the non-empty sanity check just below.
    cation_symbols: set[str] = set()
    for _, allowed in sublattices.items():
        for el in allowed:
            cation_symbols.add(str(el))
    if not cation_symbols:
        raise RuntimeError("Empty cation symbol set inferred from sublattices.")

    # ---- Gather training Structures and energies in the recorded order ----
    dataset = Dataset.from_key(doc["dataset_key"])

    # Targets are per primitive/conventional cell
    # n_prims is the product of the supercell diagonal, i.e. cells per supercell.
    n_prims = int(np.prod(np.asarray(supercell_diag, dtype=int)))
    calc_spec_top = CalcSpec.from_dict(doc["calc_spec"])

    structures: list[Structure] = []
    y_true_per_prim: list[float] = []
    # Per-cell energies grouped by composition signature (for the span table).
    by_composition_map: dict[str, list[float]] = defaultdict(list)
    for i, r in enumerate(dataset.train_refs):
        composition_map = r.composition_map
        occ_key = r.occ_key
        calc_spec = r.calc_spec

        # Every training reference must have been computed with the CE's calc spec.
        if calc_spec_top != calc_spec:
            raise RuntimeError(
                f"train_refs[{i}] calc_spec identity differs from CE doc: "
                f"{calc_spec} != {calc_spec_top}"
            )

        # Strong sanity: recompute occ_key from structure and compare
        ok2 = occ_key_for_structure(r.structure)
        if ok2 != occ_key:
            raise RuntimeError(
                f"train_refs[{i}] occ_key mismatch: expected {occ_key}, rebuilt {ok2}. "
                "This indicates a change in structure hashing."
            )

        structures.append(r.structure)
        energy = lookup_total_energy_eV(
            occ_key=occ_key,
            calc_spec=calc_spec,
        )
        if energy is None:
            raise RuntimeError(f"Energy not found in store for train_refs[{i}] composition_map={composition_map}, occ_key={occ_key[:12]}...")
        # Supercell total energy -> energy per cell.
        y_true_per_prim.append(energy  / float(n_prims))
        by_composition_map[composition_map_sig(composition_map)].append(energy  / float(n_prims))

    # ---- Rehydrate CE and build features ----
    ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], doc["payload"]))
    subspace = ce.cluster_subspace

    # Only the feature matrix is needed; the first return value is discarded.
    _, X = featurize_structures(
        subspace=subspace,
        structures=structures,
        supercell_diag=supercell_diag,
    )

    # Per-site scaling (replaceable sites only)
    prototype_spec = PrototypeSpec.from_dict(doc["prototype_spec"])
    sites_per_supercell = _n_replace_sites_from_prototype(
        prototype_spec=prototype_spec, supercell_diag=supercell_diag,
    )
    if sites_per_supercell % n_prims != 0:
        raise RuntimeError("sites_per_supercell is not divisible by n_prims (prototype/sublattice mismatch).")
    sites_per_prim = sites_per_supercell // n_prims
    # Multiplying a per-cell energy by this factor gives energy per replaceable site.
    scale_site = 1.0 / float(sites_per_prim)

    # ---- In-sample prediction using stored ECIs ----
    coefs = np.asarray(getattr(ce, "coefs"), dtype=np.float64)
    y_pred_per_prim_ins = predict_from_features(X, coefs)

    # Statistics are computed on per-site values, matching the stored *_per_site keys.
    in_sample_recomputed = compute_stats(
        (np.array(y_true_per_prim) * scale_site).tolist(),
        (y_pred_per_prim_ins * scale_site).tolist(),
    )

    # One row per composition signature, ordered by the group's lowest energy.
    mix_rows: list[tuple[str, int, float, float, float, float]] = []
    for key, vals in sorted(by_composition_map.items(), key=lambda kv: min(kv[1])):
        n_k = len(vals)
        e_min = min(vals)
        e_max = max(vals)
        dE_prim = e_max - e_min
        # Convert to meV per site for quick comparison with CE error bars
        dE_meV_per_site = 1e3 * (dE_prim / float(sites_per_prim))
        mix_rows.append((key, n_k, e_min, e_max, dE_prim, dE_meV_per_site))

    # ---- Compare + summarize ----
    # PASS requires an identical sample count and all three metrics within tolerance.
    ok = (
        int(in_sample_recomputed["n"]) == int(in_sample_stored["n"])
        and math.isclose(float(in_sample_recomputed["mae_per_site"]),  float(in_sample_stored["mae_per_site"]),  rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["rmse_per_site"]), float(in_sample_stored["rmse_per_site"]), rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["max_abs_per_site"]), float(in_sample_stored["max_abs_per_site"]), rel_tol=rtol, abs_tol=atol)
    )

    if verbose:
        _print_summary(
            ce_key=ce_key,
            n=int(in_sample_recomputed["n"]),
            stored_in=in_sample_stored,
            recomputed_in=in_sample_recomputed,
            stored_cv=stored_cv,
            design_metrics=design_metrics,
            mixture_rows=mix_rows,
            atol=atol,
            rtol=rtol,
            ok=ok,
        )

    return ok

# Ad-hoc invocations: exactly one call is active at a time; the others are kept
# commented out as a record of CE keys that have been checked. The trailing
# text on each line is the author's own label for that key.
# NOTE(review): the labels look like "<system> <settings> [<number>]" and some
# carry a FAILED marker — their exact meaning is not defined in this file; confirm.
# verify_ce_in_sample("ab038ffc02fd679d3675a23163a6b73aa715af816bde6f77c6bf274b956b0ff2") # MgAl2O4 r2SCAN
# verify_ce_in_sample("0bc859d8f855a49818267b926867ec9d4fe2f2e388fd9ad244d4d1d9039a8178") # MgAl2O4 GGA
# verify_ce_in_sample("23f1bc055d408250264901c651e1e0357b8262b33fd7b1e03c7cbb0293275c17") # Pyrochlore 8-6-4
# verify_ce_in_sample("e31258f876183ba3fcb127d5283d07cfa0b1c10c663b40a1c4a656ffd95c2b97") # Pyrochlore 10-7 FAILED
# verify_ce_in_sample("d76b1010a6883839644477cabe4c0a3688e202cf7b97b018bf974b57ef542621") # Pyrochlore 9-7 FAILED
# verify_ce_in_sample("9afceef2d84df4aaccf295ea3eafac7419667a123fd8056cba938c2cf93c42c6") # Pyrochlore 9-7-4 FAILED
# verify_ce_in_sample("67931814e380a63d0aee77118b7e52d31c256e41476acc9b01036c18f4a03a04") # Pyrochlore 10
# verify_ce_in_sample("cc4deb3c2ff88cfef40a9a8bb7954c1239dfb5310d4a00c0cdc9b5c556c8527d") # Pyrochlore 6
# verify_ce_in_sample("7838a93c61d6910d48287b9914f2796e149b9a5b0fe41936cf4d1f6ddee0b41c") # Pyrochlore 7
# verify_ce_in_sample("b2637b96d6e549e6c42c57af478a41b19e8ff6f9d351e48e72bf78b078c0399e") # Pyrochlore 8
# verify_ce_in_sample("0aa13727eb7507a7a614e1c6ef8e6aeefa3da2b035df0f6e1bccac24a5618c04") # Pyrochlore 12
# verify_ce_in_sample("eb3c79e08545ee9b036ba9598df0fe7fa3c76ea273a10c7b3c5461317b35de91") # Pyrochlore 10-6
# verify_ce_in_sample("4d942e9d4f16ad4aac8ae4d501f5bd3c556f8ab30ce20e89296a36cc5802c0dd") # Pyrochlore 10-6-4

# verify_ce_in_sample("acc60972da1381eeedb2e2d5687f92207e1288e81eb414c0eea66e9738400ff1") # Li2TiO3-MgO 8-6-4 12.3
# verify_ce_in_sample("9b3bf3f120b7ad1b7958902151b1a8754393c77ff171baf99cfeff0505408501") # Li2TiO3-MgO 8-5 7.1
# verify_ce_in_sample("76ae820785ee7bbe06ee7283ef2381b9176eeb446416e0a241db516f1abb90ae") # Li2TiO3-MgO 7-6 8.8
verify_ce_in_sample("1b7d43754b33c7ad91fd06bc4641bacd11d3ac69824d4ec0ce939db12714f629") # Li2TiO3-MgO 7-5 5.8
# verify_ce_in_sample("255ea975bb9380e94a602fe10731db5149cf13415445601191b70721d60050f3") # Li2TiO3-MgO 7-5-3 5.9
# verify_ce_in_sample("3e59d57a00f8fd23dd7998ce134c5139428f1057743f5c92114ea3827cdfe856") # Li2TiO3-MgO 7-4 7.8
# verify_ce_in_sample("1a4672ca1ccae95c76fe7432315e5f3f5bc3068e8107aa087eebea91aae05c00") # Li2TiO3-MgO 7-5-4 5.9
# verify_ce_in_sample("5ad5d40700ddba1e8d7b4d6ed0008f916a00eb20b338d1ebedb3681abf27b4dd") # Li2TiO3-MgO 7 11.3
# verify_ce_in_sample("524d5968ee54ae1228eecc896d1ab4d5ca691f1897e0d94eedac13ed33eea052") # Li2TiO3-MgO 6-5 7.4
# verify_ce_in_sample("a94e495f0a9374355ff9c54fc4c00753433f9de574c3cd05b28b9af19282df4e") # Li2TiO3-MgO 5-5 8.1

  from .autonotebook import tqdm as notebook_tqdm


════════════════════════════════════════════════════════════════════════
═══════════════ CE Validation (per-site metrics in meV) ════════════════
════════════════════════════════════════════════════════════════════════
CE key        : 1b7d43754b33c7ad91fd06bc4641bacd11d3ac69824d4ec0ce939db12714f629
samples (n)   : 181
tolerances    : atol=1e-06, rtol=1e-08
------------------------------------------------------------------------
In-sample
Metric         Stored    Recomputed   Δ (recomp − stored)
-------------  --------  ----------   -------------------
MAE               4.647      4.647           0.000
RMSE              6.172      6.172           0.000
Max|err|         25.486     25.486           0.000
------------------------------------------------------------------------
Stored 5-fold CV
Metric         Value
-------------  --------
n              181
MAE               5.814
RMSE              7.724
Max|err|         28.639
---------------------------------------------------------------

True