In [None]:
from typing import Any, Mapping, Sequence, cast
import math
from collections import defaultdict

import numpy as np
from pymatgen.core import Structure

from smol.cofe import ClusterExpansion

# ---- phaseedge imports (modern-only) ----
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.jobs.train_ce import (
    featurize_structures,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.schemas.calc_spec import CalcSpec
from phaseedge.schemas.mixture import Mixture, composition_map_sig, sublattices_from_mixtures
from phaseedge.science.prototype_spec import PrototypeSpec
from phaseedge.storage.cetrainref_dataset import Dataset
from phaseedge.storage.store import lookup_total_energy_eV
from phaseedge.utils.keys import occ_key_for_structure


# ----------------- helpers: structures, sources, sublattices -----------------

def fmt_mev(x: float | str) -> str:
    if isinstance(x, str):
        return x
    return f"{1e3 * float(x):.3f}"

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect the mixtures attached to every 'composition' entry of a CE ``sources`` list.

    Entries whose ``type`` (case-insensitive) is not 'composition' are skipped. Each raw
    mixture is normalized to a plain-dict composition_map with int ``K`` and ``seed``
    before being passed to ``Mixture.from_dict``.

    Raises:
        RuntimeError: if no composition mixtures are present.
    """
    composition_sources = (
        src for src in sources if str(src.get("type", "")).lower() == "composition"
    )
    collected: list[Mixture] = [
        Mixture.from_dict(
            {
                "composition_map": dict(raw["composition_map"]),
                "K": int(raw["K"]),
                "seed": int(raw["seed"]),
            }
        )
        for src in composition_sources
        for raw in src.get("mixtures", [])
    ]
    if len(collected) == 0:
        raise RuntimeError("Expected 'composition' mixtures in sources (modern schema).")
    return collected


def _sublattices_for_doc(doc: Mapping[str, Any]) -> dict[str, tuple[str, ...]]:
    """
    Determine replaceable sublattices for per-site scaling:
      - composition CE: from this doc's mixtures
      - dopt sampling CE : from the base (composition) CE's mixtures
    """
    sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    if not sources:
        raise RuntimeError("CE document missing 'sources'.")
    s0 = sources[0]
    s_type = str(s0.get("type", "")).lower()

    if s_type == "composition":
        return sublattices_from_mixtures(_gather_composition_mixtures(sources))

    if s_type == "dopt_sampling_intent":
        base_ce_key = cast(str, s0["base_ce_key"])
        base_doc = lookup_ce_by_key(base_ce_key)
        if not base_doc:
            raise RuntimeError(f"Base CE not found for dopt_sampling_intent source: {base_ce_key}")
        return sublattices_from_mixtures(_gather_composition_mixtures(cast(Sequence[Mapping[str, Any]], base_doc["sources"])))

    raise RuntimeError(f"Unsupported source type in CE doc: {s_type!r}")

# ----------------- pretty printers ------------------------------------------

def _print_design_metrics(dm: Mapping[str, Any] | None) -> None:
    if not dm:
        print("Design matrix diagnostics: <not available>")
        return

    def _get(name: str, default: float | int | str = float("nan")) -> Any:
        return dm.get(name, default)

    try:
        n_samples = int(_get("n_samples", -1))
        n_features = int(_get("n_features", -1))
        rank = int(_get("rank", -1))

        sigma_max = float(_get("sigma_max", float("nan")))
        sigma_min = float(_get("sigma_min", float("nan")))
        kappa = float(_get("condition_number", float("inf")))
        logdet = float(_get("logdet_xtx", float("nan")))

        lev_mean = float(_get("leverage_mean", float("nan")))
        lev_max = float(_get("leverage_max", float("nan")))
        lev_p95 = float(_get("leverage_p95", float("nan")))

        weighting_applied = bool(_get("weighting_applied", False))
        standardization = str(_get("standardization", "none"))
        zero_var = int(_get("zero_variance_feature_count", 0))
    except Exception:
        print("Design matrix diagnostics: <malformed in CE document>")
        return

    print("Design matrix diagnostics (stored)")
    print("  Samples / Features     : %d / %d" % (n_samples, n_features))
    print("  Rank                   : %d" % rank)
    print("  Sigma max / min        : %.6g / %.6g" % (sigma_max, sigma_min))
    if math.isfinite(kappa):
        print("  Condition number (κ)   : %.6g" % kappa)
    else:
        print("  Condition number (κ)   : inf (rank-deficient or σ_min≈0)")
    print("  log det(XᵀX)           : %.6g" % logdet)
    print("  Leverage mean / p95/max: %.6g / %.6g / %.6g" % (lev_mean, lev_p95, lev_max))
    print("  Weighting applied      : %s" % weighting_applied)
    print("  Standardization        : %s" % standardization)
    print("  Zero-variance columns  : %d" % zero_var)


def _print_mixture_energy_ranges(
    mixture_rows: list[tuple[str, int, float, float, float, float]],
) -> None:
    """
    Print the training-set energy span for each composition group.

    Each row: (composition_key, n, E_min_prim, E_max_prim, dE_prim, dE_meV_per_site),
    with energies in eV per primitive cell and the last column in meV per site.
    Rule lines span the full table width.
    """
    # Each column is at least as wide as its header label (the E_min label is 15
    # characters), so header and data columns end at the same position.
    hdr = (
        f"{'Composition map':<42}  {'n':>3}  "
        f"{'E_min (eV/prim)':>15}  {'E_max (eV/prim)':>16}  {'ΔE (eV/prim)':>13}  {'ΔE (meV/site)':>15}"
    )
    width = max(72, len(hdr))
    rule = "-" * width
    print(rule)
    print("Mixture energy span (training set)".center(width))
    print(rule)
    print(hdr)
    print(rule)
    for key, n, e_min, e_max, de_prim, de_mev_site in mixture_rows:
        print(
            f"{key:<42}  {n:>3d}  "
            f"{e_min:>15.6f}  {e_max:>16.6f}  {de_prim:>13.6f}  {de_mev_site:>15.3f}"
        )
    print(rule)


def _print_summary(
    *,
    ce_key: str,
    n: int,
    stored_in: Mapping[str, Any],
    recomputed_in: Mapping[str, Any],
    stored_cv: Mapping[str, Any] | None,
    design_metrics: Mapping[str, Any] | None,
    mixture_rows: list[tuple[str, int, float, float, float, float]],
    atol: float,
    rtol: float,
    ok: bool,
) -> None:
    """
    Print the full validation report for one CE.

    Sections, in order:
      - a header with the key, sample count and tolerances;
      - stored vs recomputed in-sample metrics and their difference;
      - stored CV metrics, shown only when n/mae/rmse/max are all present;
      - design-matrix diagnostics;
      - the per-composition energy-span table, when rows are supplied;
      - the PASS/FAIL verdict.

    Per-site metric values are rendered in meV via ``fmt_mev``.
    """
    metrics = (
        ("MAE", "mae_per_site"),
        ("RMSE", "rmse_per_site"),
        ("Max|err|", "max_abs_per_site"),
    )
    # Convert all in-sample values up front so a bad record fails before any output.
    stored_vals = [float(stored_in[key]) for _, key in metrics]
    recomputed_vals = [float(recomputed_in[key]) for _, key in metrics]

    heavy_rule = "═" * 72
    light_rule = "-" * 72

    print(heavy_rule)
    print(" CE Validation (per-site metrics in meV) ".center(72, "═"))
    print(heavy_rule)
    print(f"CE key        : {ce_key}")
    print(f"samples (n)   : {n}")
    print(f"tolerances    : atol={atol:g}, rtol={rtol:g}")
    print(light_rule)
    print("In-sample")
    print("Metric         Stored    Recomputed   Δ (recomp − stored)")
    print("-------------  --------  ----------   -------------------")
    for (label, _), stored, recomputed in zip(metrics, stored_vals, recomputed_vals):
        delta = recomputed - stored
        print(f"{label:<15}{fmt_mev(stored):>8}   {fmt_mev(recomputed):>8}        {fmt_mev(delta):>8}")
    print(light_rule)

    cv_required = ("n",) + tuple(key for _, key in metrics)
    if stored_cv and all(key in stored_cv for key in cv_required):
        cv_n = int(stored_cv["n"])
        cv_vals = [float(stored_cv[key]) for _, key in metrics]
        print("Stored 5-fold CV")
        print("Metric         Value")
        print("-------------  --------")
        print(f"{'n':<15}{cv_n}")
        for (label, _), value in zip(metrics, cv_vals):
            print(f"{label:<15}{fmt_mev(value):>8}")
    else:
        print("Stored 5-fold CV: <not available>")
    print(light_rule)

    _print_design_metrics(design_metrics)

    if mixture_rows:
        _print_mixture_energy_ranges(mixture_rows)

    verdict = "✅  PASS (in-sample matches stored)" if ok else "❌  FAIL (in-sample mismatch)"
    print("Result         : " + verdict)
    print(heavy_rule)


# ----------------- main: in-sample validator + mixture spans -----------------

def _collect_training_data(
    train_refs: Sequence[Any],
    calc_spec_top: "CalcSpec",
    n_prims: int,
) -> "tuple[list[Structure], list[float], dict[str, list[float]]]":
    """
    Walk ``train_refs`` in order and gather Structures and per-primitive-cell energies.

    Each reference is checked in turn:
      - its calc_spec must equal the CE's;
      - the occ_key recomputed from its structure must match the recorded one;
      - a stored energy must exist;
      - that energy's max force must not exceed ``calc_spec.max_force_eV_per_A``.

    Returns:
        (structures, energies per prim, energies per prim grouped by composition_map_sig)

    Raises:
        RuntimeError: on the first failed check.
    """
    structures: list[Structure] = []
    y_true_per_prim: list[float] = []
    by_composition_map: dict[str, list[float]] = defaultdict(list)
    for i, r in enumerate(train_refs):
        composition_map = r.composition_map
        occ_key = r.occ_key
        calc_spec = r.calc_spec

        if calc_spec_top != calc_spec:
            raise RuntimeError(
                f"train_refs[{i}] calc_spec identity differs from CE doc: "
                f"{calc_spec} != {calc_spec_top}"
            )

        # Strong sanity: recompute occ_key from structure and compare
        rebuilt_key = occ_key_for_structure(r.structure)
        if rebuilt_key != occ_key:
            raise RuntimeError(
                f"train_refs[{i}] occ_key mismatch: expected {occ_key}, rebuilt {rebuilt_key}. "
                "This indicates a change in structure hashing."
            )

        energy_result = lookup_total_energy_eV(occ_key=occ_key, calc_spec=calc_spec)
        ref_desc = f"train_refs[{i}] composition_map={composition_map}, occ_key={occ_key[:12]}..."
        if energy_result is None:
            raise RuntimeError(f"Energy not found in store for {ref_desc}")
        # A stored energy above the force threshold exists but is not usable; say so explicitly.
        if energy_result.max_force_eV_per_A > calc_spec.max_force_eV_per_A:
            raise RuntimeError(
                f"Energy for {ref_desc} exceeds the force threshold: "
                f"max force {energy_result.max_force_eV_per_A} > {calc_spec.max_force_eV_per_A} eV/Å"
            )

        e_per_prim = energy_result.total_energy_eV / float(n_prims)
        structures.append(r.structure)
        y_true_per_prim.append(e_per_prim)
        by_composition_map[composition_map_sig(composition_map)].append(e_per_prim)

    return structures, y_true_per_prim, by_composition_map


def _mixture_span_rows(
    by_composition_map: Mapping[str, Sequence[float]],
    sites_per_prim: int,
) -> list[tuple[str, int, float, float, float, float]]:
    """
    Summarize each composition group's training-energy span.

    Returns rows (key, n, E_min, E_max, ΔE in eV/prim, ΔE in meV/site), sorted by E_min.
    """
    rows: list[tuple[str, int, float, float, float, float]] = []
    for key, vals in by_composition_map.items():
        e_min = min(vals)
        e_max = max(vals)
        dE_prim = e_max - e_min
        # Convert to meV per site for quick comparison with CE error bars
        dE_meV_per_site = 1e3 * (dE_prim / float(sites_per_prim))
        rows.append((key, len(vals), e_min, e_max, dE_prim, dE_meV_per_site))
    rows.sort(key=lambda row: row[2])
    return rows


def verify_ce_in_sample(
    ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> bool:
    """
    Validate a CE's stored *in-sample* statistics by recomputing them.

    Steps:
      - Take the training Structures from the CE dataset's ``train_refs``, in the
        recorded order.
      - Look up their total energies in the store.
      - Rebuild features with the CE's cluster subspace and predict with the stored
        ECIs.
      - Compare per-site MAE / RMSE / max|err| against ``doc["stats"]["in_sample"]``.

    When ``verbose`` is True a summary is printed. It includes per-composition-map
    training-energy spans (E_min, E_max, ΔE), grouped by ``composition_map_sig``.

    Args:
        ce_key: key of the CE document to validate.
        atol: absolute tolerance for each metric comparison (``math.isclose``).
        rtol: relative tolerance for each metric comparison (``math.isclose``).
        verbose: print the summary report.

    Returns:
        True if the sample count matches and every metric is within tolerance.

    Raises:
        RuntimeError: if the CE, its sublattices, or a training energy cannot be
            resolved, or a consistency check (calc_spec, occ_key, site counts) fails.
    """
    # ---- Load CE document (modern schema assumed) ----
    doc = lookup_ce_by_key(ce_key)
    if not doc:
        raise RuntimeError(f"No CE found for ce_key={ce_key}")

    sx, sy, sz = (int(x) for x in doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)
    # Targets are per primitive/conventional cell
    n_prims = sx * sy * sz
    if n_prims <= 0:
        raise RuntimeError(f"Invalid supercell_diag={supercell_diag}: cell count must be positive.")

    stats = cast(Mapping[str, Any], doc["stats"])
    in_sample_stored = cast(Mapping[str, Any], stats["in_sample"])
    stored_cv = cast(Mapping[str, Any] | None, stats.get("five_fold_cv"))

    # Design metrics (stored at training time)
    design_metrics = cast(Mapping[str, Any] | None, doc.get("design_metrics"))

    # Sanity check: the sublattices must name at least one species (e.g., {"Al","Mg"})
    sublattices = _sublattices_for_doc(doc)
    cation_symbols = {str(el) for allowed in sublattices.values() for el in allowed}
    if not cation_symbols:
        raise RuntimeError("Empty cation symbol set inferred from sublattices.")

    # ---- Gather training Structures and energies in the recorded order ----
    dataset = Dataset.from_key(doc["dataset_key"])
    calc_spec_top = CalcSpec.from_dict(doc["calc_spec"])
    structures, y_true_per_prim, by_composition_map = _collect_training_data(
        dataset.train_refs, calc_spec_top, n_prims
    )

    # ---- Rehydrate CE and build features ----
    ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], doc["payload"]))
    _, X = featurize_structures(
        subspace=ce.cluster_subspace,
        structures=structures,
        supercell_diag=supercell_diag,
    )

    # Per-site scaling (replaceable sites only)
    prototype_spec = PrototypeSpec.from_dict(doc["prototype_spec"])
    sites_per_supercell = _n_replace_sites_from_prototype(
        prototype_spec=prototype_spec, supercell_diag=supercell_diag,
    )
    if sites_per_supercell <= 0:
        raise RuntimeError(
            f"Prototype has no replaceable sites (sites_per_supercell={sites_per_supercell})."
        )
    if sites_per_supercell % n_prims != 0:
        raise RuntimeError("sites_per_supercell is not divisible by n_prims (prototype/sublattice mismatch).")
    sites_per_prim = sites_per_supercell // n_prims
    scale_site = 1.0 / float(sites_per_prim)

    # ---- In-sample prediction using stored ECIs ----
    coefs = np.asarray(ce.coefs, dtype=np.float64)
    y_pred_per_prim_ins = predict_from_features(X, coefs)

    in_sample_recomputed = compute_stats(
        (np.array(y_true_per_prim) * scale_site).tolist(),
        (y_pred_per_prim_ins * scale_site).tolist(),
    )

    mix_rows = _mixture_span_rows(by_composition_map, sites_per_prim)

    # ---- Compare + summarize ----
    metric_keys = ("mae_per_site", "rmse_per_site", "max_abs_per_site")
    ok = int(in_sample_recomputed["n"]) == int(in_sample_stored["n"]) and all(
        math.isclose(
            float(in_sample_recomputed[key]),
            float(in_sample_stored[key]),
            rel_tol=rtol,
            abs_tol=atol,
        )
        for key in metric_keys
    )

    if verbose:
        _print_summary(
            ce_key=ce_key,
            n=int(in_sample_recomputed["n"]),
            stored_in=in_sample_stored,
            recomputed_in=in_sample_recomputed,
            stored_cv=stored_cv,
            design_metrics=design_metrics,
            mixture_rows=mix_rows,
            atol=atol,
            rtol=rtol,
            ok=ok,
        )

    return ok




# ---- Validation run log -------------------------------------------------------
# One verify_ce_in_sample(...) call per stored CE key; only uncommented call(s) run.
# The trailing comment on each line names the system and carries free-form run notes
# (e.g. "r2SCAN200 0.05 DONE"). NOTE(review): the meaning of the numeric fields is not
# documented here; confirm with the author before relying on them.
# NOTE(review): in the r2SCAN500 group below, "MgIn2O4" and "MgSc2O4" each label two
# different CE keys; one of each pair may be mislabeled.
# verify_ce_in_sample("885ec0cf0822b8159e415a5c93b10ffc3b83366b78af019e9fd61b75bb37548f") # Ba2AlTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("f4acc18bc4cee1ee2996366370b238823f8b11c2d888c0d3929c831c2a21d92a") # Ba2GaNbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("16581e322e3cf577d72525cf3ff59ddd10a430b04df480b2b577b07433381d8f") # Ba2GaSbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("38c2a00d15799ca550488100025117815f27339d588d40ea5b7a51e565f6d8b9") # Ba2GaTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("70131368b2ee45d6739a0cacce9a00e4a3b866f593fcea346779d3d627a9c723") # Ba2InNbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("2b378090d2086650386d881397103711ac0bf11f21a5fbf6e33b8a8749396c4e") # Ba2NbScO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("d0e322d8342821f7f3d3118a806e3d3df6bfd4601dab81c36d34962693bed316") # Ba2ScTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("1efbc65a88d4519a616b768f9af26e783e805160ae1a26dafc22a88cd509383d") # Bi2MgSnO6 r2SCAN200 0.1/0.15 DONE - TERRIBLE
# verify_ce_in_sample("a2ca21a50c31bbd31aee3ba76cc3f99ef6f3662cd6f25c5b9b99a8df8e176a9a") # Bi2MgZrO6 r2SCAN200 0.1/0.15 DONE - TERRIBLE
# verify_ce_in_sample("c8aef15c53f9d66e1e846a6c95fa6284c30d8e1b3b817b52b671bf103669fb7d") # Ca2AlSbO6 r2SCAN200 MPA/0.15 DONE
# verify_ce_in_sample("960bdbfc5b58fc1a867c1b188f5ec9a1dc8660622477a3151dd0a9506aac95b8") # Ca2AlTaO6 r2SCAN200 MPA/0.1 DONE
# verify_ce_in_sample("f7be7538d32e9046704d4d9c105f32243b2ab80eb702634a6131b1f941fc0b32") # Ca2GaNbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("2b868a477ecb60c05f6c4a739c58b6f1dbd9458aed7811ec0a667609fac3c371") # Ca2GaSbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("8483ad3e8b00a7b5bc3b75926351fceb6f12ca51d0689705d9b845ce56bd331a") # Ca2GaTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("b2bf9270a4ebbfbc9077829c95f71b1bf81f171f0ce1819c5086aaa3bd0463f3") # Ca2InNbO6 r2SCAN200 0.06 DONE
# verify_ce_in_sample("649dde67894239d2e4479c92b62421d01bc2f955f430e3e5e7731398ed0f920f") # Ca2InTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("fc00458822c875bd298f061164f8d638df2c05bd19bd17d275559111f377384c") # Ca2MgWO6 r2SCAN200 0.05/0.15 DONE
# verify_ce_in_sample("a6076f656353ac0a2d129f850c11b13dee484dce453124e9d557a19749f9ec89") # Ca2WZnO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("40d9f342907cdf59c3e847d7cf14e4c68659c5340560a35a0f2eacba240d5363") # Ca2NbScO6 r2SCAN200 0.05/0.15 DONE
# verify_ce_in_sample("5a0be8138a1ae6b7abed4f69ebe754edf1d427c6d119f96e013fccbc24bfa9b2") # Ca2SbScO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("d0def1e30d2b3fcfff1b41f9e96cd0671c4398fd856919ad7cfc7d6ea547eee1") # Ca2ScTaO6 r2SCAN200 0.1/0.12 DONE
# verify_ce_in_sample("1a54ab271d3c7554e5105156fb2a57affa7731eb3fd4eb375827786b057c9c89") # La2MgSnO6 r2SCAN200 0.1/0.15
# verify_ce_in_sample("b49f1138023b0b5f1725ae64d08745760480d9fed99f9e3720bc1644dcb78bf9") # La2MgZrO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("f33ef5ca759b32288cb67077c787cea4cee07fdfe34ab858dd77c212c6d108bf") # La2SnZnO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("919d992c53291d615fe91c0ade0b8986b3b7bc72e4a273563015630f6d5e53ca") # La2ZnZrO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("eb3d2d3b4e1ad498694a85e2d91e0293e7d5caf3f5c11c6aff6e8ea07832c2b8") # Sr2AlSbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("4ecdd726f474369a0a054669795a9f0c3a3ab3bfcd02b94f0190250fc4c9bab4") # Sr2AlTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("2d226e9d616df91d9a6b4aee742c2dceb0a7d4f0ccf69270a1e62cf3e5b82fe5") # Sr2GaNbO6 r2SCAN200 0.1 DONE
# verify_ce_in_sample("a226eeee4b938b836c1b3ae58d6b7bbf83e094203c7960eb6b2bcb583e934390") # Sr2GaSbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("f376b6bbe33c6bc0e5cfe1cda414dbd138004cbf403142f4d2f659371573e67e") # Sr2GaTaO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("98af0a253f377ec090884dcd33a72cf1b0ca6d3d293d8d787e045330a83d3cdc") # Sr2InNbO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("6d4a63f806171365aab94b969d3d0838b719ac57b6e70d0f91dbe34ff9a3a791") # Sr2NbScO6 r2SCAN200 0.1 DONE
# verify_ce_in_sample("b741f6fe45546535fa1b815b232acdb922be1a3cd4807320f5ac2802ac7b511a") # Sr2SbScO6 r2SCAN200 0.05 DONE
# verify_ce_in_sample("046e3e63bc37dc15eb97eeae909526c0150aefda35d2626abedbc57173507cae") # Sr2ScTaO6 r2SCAN200 0.1 DONE


# verify_ce_in_sample("32a8d2615677477b4d5ce9cb1ae3f5723ec72d74540714a4c2ddcf79feb81852") # MgAl2O4 r2SCAN500 0.02
# verify_ce_in_sample("b9bd310b04a9cd7289d81dd6c4394e1224a0c80015f590fcf5037ea607de29b7") # MgGa2O4 r2SCAN500 0.02
# verify_ce_in_sample("09d408a248cf2a4dddef6d01443ffbb6fca7c0e259c82d296a2e58c52729e1e6") # MgIn2O4 r2SCAN500 0.02
# verify_ce_in_sample("e4226393f112a920fff6009e26f7ddba0515434ce73ec58320b85db663de1bc2") # MgSc2O4 r2SCAN500 0.02
# verify_ce_in_sample("09d62fbdb713ff78299fa32f4b455152da1f381cec828ca884d72dce100ad3d7") # ZnAl2O4 r2SCAN500 0.02
# verify_ce_in_sample("671932910d06459c7901ee75e5f015b44ac43dd7c7a47aace189ccec7b7470e7") # ZnGa2O4 r2SCAN500 0.02
# verify_ce_in_sample("becde6c06e4d340fa60b506b68e4cd4d2254560d28cbec299f12d1a77f9115e4") # MgIn2O4 r2SCAN500 0.02
# verify_ce_in_sample("aac22c66a0fd01be40fc2ae186c671a8df165b04105a15069e1e36f3f68835ff") # MgSc2O4 r2SCAN500 0.02

# verify_ce_in_sample("8b1f974bef9f8eda6580984fbccf2b348d66c04275825eed0f4eb70e8464e89f") # MgTiO3 r2SCAN500 0.05
# verify_ce_in_sample("60b97aa56369fcf2d82f4c031d810ec2ca0d3a3e9f0bdb45b375c2cd136b7e46") # ZnTiO3 r2SCAN500 0.1
# verify_ce_in_sample("52f89331a57177878669dc32066053aa50c1c460c530c53d65e97bda55d3d63e") # CaTiO3 r2SCAN500 0.1

# verify_ce_in_sample("b508d5221d6b224fb56f37884b872615e86b593aea4b84c01ef12c49cdcb2b7a") # 876 .012/.015
# Active run. As the cell's last bare expression, its bool result is echoed as the
# notebook cell output (the trailing "True" shown after the report).
verify_ce_in_sample("e679ea3128f01828d7f326fb603df1169f66defe645f4fd4538c3174eb437fb9") # 876 


════════════════════════════════════════════════════════════════════════
═══════════════ CE Validation (per-site metrics in meV) ════════════════
════════════════════════════════════════════════════════════════════════
CE key        : e679ea3128f01828d7f326fb603df1169f66defe645f4fd4538c3174eb437fb9
samples (n)   : 241
tolerances    : atol=1e-06, rtol=1e-08
------------------------------------------------------------------------
In-sample
Metric         Stored    Recomputed   Δ (recomp − stored)
-------------  --------  ----------   -------------------
MAE               0.016      0.016           0.000
RMSE              0.020      0.020           0.000
Max|err|          0.060      0.060           0.000
------------------------------------------------------------------------
Stored 5-fold CV
Metric         Value
-------------  --------
n              241
MAE               0.030
RMSE              0.042
Max|err|          0.212
---------------------------------------------------------------

True