In [1]:
from phaseedge.storage import store

# Quick look at the CE model collection: total document count, then the
# ce_key of up to ten stored models.
coll = store.db_rw()["ce_models"]  # read-write handle, although this cell only reads
n_models = coll.count_documents({})
print(f"count: {n_models}")

key_only = {"_id": 0, "ce_key": 1}  # projection: drop _id, keep just the key
for d in coll.find({}, key_only).limit(10):
    print(d)


count: 2
{'ce_key': '781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737'}
{'ce_key': '76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba'}


In [7]:
from typing import Any, Mapping, Sequence, cast
import math
import numpy as np
from numpy.typing import NDArray

from ase.atoms import Atoms
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

from smol.cofe import ClusterExpansion

# ---- phaseedge imports (modern-only) ----
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.science.prototypes import PrototypeName
from phaseedge.jobs.train_ce import (
    featurize_structures,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.schemas.mixture import Mixture, sublattices_from_mixtures
from phaseedge.utils.keys import occ_key_for_atoms
from phaseedge.utils.testing import lookup_total_energy_eV, fmt_mev


# ----------------- helpers: structures, sources, sublattices -----------------

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect a Mixture for every entry under 'mixtures' of each source whose
    'type' is "composition" (compared case-insensitively).

    Sources of any other type are skipped. Each mixture entry must provide
    'composition_map', 'K' and 'seed'; K and seed are coerced to int and the
    composition map is copied into a plain dict before Mixture.from_dict is
    called.

    Raises:
        RuntimeError: if no mixture was collected at all.
    """
    collected: list[Mixture] = [
        Mixture.from_dict(
            {
                "composition_map": dict(entry["composition_map"]),
                "K": int(entry["K"]),
                "seed": int(entry["seed"]),
            }
        )
        for source in sources
        if str(source.get("type", "")).lower() == "composition"
        for entry in source.get("mixtures", [])
    ]
    if collected:
        return collected
    raise RuntimeError("Expected 'composition' mixtures in sources (modern schema).")


def _sublattices_for_doc(doc: Mapping[str, Any]) -> dict[str, tuple[str, ...]]:
    """
    Determine replaceable sublattices for per-site scaling:
      - composition CE: from this doc's mixtures
      - wl_refined CE : from the base (composition) CE's mixtures
    """
    sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    if not sources:
        raise RuntimeError("CE document missing 'sources'.")
    s0 = sources[0]
    s_type = str(s0.get("type", "")).lower()

    if s_type == "composition":
        return sublattices_from_mixtures(_gather_composition_mixtures(sources))

    if s_type.startswith("wl_refined"):
        base_ce_key = cast(str, s0["base_ce_key"])
        base_doc = lookup_ce_by_key(base_ce_key)
        if not base_doc:
            raise RuntimeError(f"Base CE not found for wl_refined source: {base_ce_key}")
        return sublattices_from_mixtures(_gather_composition_mixtures(cast(Sequence[Mapping[str, Any]], base_doc["sources"])))

    raise RuntimeError(f"Unsupported source type in CE doc: {s_type!r}")


# ----------------- helpers: energy + summary -----------------

def _print_design_metrics(dm: Mapping[str, Any] | None) -> None:
    if not dm:
        print("Design matrix diagnostics: <not available>")
        return

    def _get(name: str, default: float | int | str = float("nan")) -> Any:
        return dm.get(name, default)

    try:
        n_samples = int(_get("n_samples", -1))
        n_features = int(_get("n_features", -1))
        rank = int(_get("rank", -1))

        sigma_max = float(_get("sigma_max", float("nan")))
        sigma_min = float(_get("sigma_min", float("nan")))
        kappa = float(_get("condition_number", float("inf")))
        logdet = float(_get("logdet_xtx", float("nan")))

        lev_mean = float(_get("leverage_mean", float("nan")))
        lev_max = float(_get("leverage_max", float("nan")))
        lev_p95 = float(_get("leverage_p95", float("nan")))

        weighting_applied = bool(_get("weighting_applied", False))
        standardization = str(_get("standardization", "none"))
        zero_var = int(_get("zero_variance_feature_count", 0))
    except Exception:
        print("Design matrix diagnostics: <malformed in CE document>")
        return

    print("Design matrix diagnostics (stored)")
    print("  Samples / Features     : %d / %d" % (n_samples, n_features))
    print("  Rank                   : %d" % rank)
    print("  Sigma max / min        : %.6g / %.6g" % (sigma_max, sigma_min))
    if math.isfinite(kappa):
        print("  Condition number (κ)   : %.6g" % kappa)
    else:
        print("  Condition number (κ)   : inf (rank-deficient or σ_min≈0)")
    print("  log det(XᵀX)           : %.6g" % logdet)
    print("  Leverage mean / p95/max: %.6g / %.6g / %.6g" % (lev_mean, lev_p95, lev_max))
    print("  Weighting applied      : %s" % weighting_applied)
    print("  Standardization        : %s" % standardization)
    print("  Zero-variance columns  : %d" % zero_var)


def _print_summary(
    *,
    ce_key: str,
    n: int,
    stored_in: Mapping[str, Any],
    recomputed_in: Mapping[str, Any],
    stored_cv: Mapping[str, Any] | None,
    design_metrics: Mapping[str, Any] | None,
    atol: float,
    rtol: float,
    ok: bool,
) -> None:
    """
    Print the validation report for one CE.

    The report has four parts: a banner with the key, sample count and
    tolerances; a stored-vs-recomputed table of the in-sample MAE, RMSE and
    max |error|; the stored 5-fold CV block (or a placeholder when
    ``stored_cv`` is missing or lacks one of its four fields); the stored
    design-matrix diagnostics; and the final PASS/FAIL line driven by ``ok``.

    ``stored_in`` and ``recomputed_in`` must both contain 'mae_per_site',
    'rmse_per_site' and 'max_abs_per_site'; a missing key raises KeyError
    before any output is produced.
    """
    metric_keys = ("mae_per_site", "rmse_per_site", "max_abs_per_site")
    # Coerce all six numbers first so a malformed mapping fails before printing.
    stored_mae, stored_rmse, stored_max = [float(stored_in[k]) for k in metric_keys]
    new_mae, new_rmse, new_max = [float(recomputed_in[k]) for k in metric_keys]

    heavy_rule = "═" * 72
    light_rule = "-" * 72

    # ---- banner ----
    print(heavy_rule)
    print(" CE Validation (per-site metrics in meV) ".center(72, "═"))
    print(heavy_rule)
    print(f"CE key        : {ce_key}")
    print(f"samples (n)   : {n}")
    print(f"tolerances    : atol={atol:g}, rtol={rtol:g}")
    print(light_rule)

    # ---- in-sample: stored vs recomputed ----
    print("In-sample")
    print("Metric         Stored    Recomputed   Δ (recomp − stored)")
    print("-------------  --------  ----------   -------------------")
    in_sample_rows = (
        ("MAE", stored_mae, new_mae),
        ("RMSE", stored_rmse, new_rmse),
        ("Max|err|", stored_max, new_max),
    )
    for label, stored_val, new_val in in_sample_rows:
        print(
            f"{label:<15}{fmt_mev(stored_val):>8}   {fmt_mev(new_val):>8}"
            f"        {fmt_mev(new_val - stored_val):>8}"
        )
    print(light_rule)

    # ---- stored cross-validation stats (shown only when complete) ----
    cv_fields = ("n", "mae_per_site", "rmse_per_site", "max_abs_per_site")
    if stored_cv and all(field in stored_cv for field in cv_fields):
        cv_count = int(stored_cv["n"])
        cv_mae = float(stored_cv["mae_per_site"])
        cv_rmse = float(stored_cv["rmse_per_site"])
        cv_max = float(stored_cv["max_abs_per_site"])
        print("Stored 5-fold CV")
        print("Metric         Value")
        print("-------------  --------")
        print(f"{'n':<15}{cv_count}")
        for label, value in (("MAE", cv_mae), ("RMSE", cv_rmse), ("Max|err|", cv_max)):
            print(f"{label:<15}{fmt_mev(value):>8}")
    else:
        print("Stored 5-fold CV: <not available>")
    print(light_rule)

    # ---- design-matrix diagnostics ----
    _print_design_metrics(design_metrics)
    print(light_rule)

    # ---- verdict ----
    if ok:
        verdict = "✅  PASS (in-sample matches stored)"
    else:
        verdict = "❌  FAIL (in-sample mismatch)"
    print("Result         : " + verdict)
    print(heavy_rule)


# ----------------- main: in-sample validator (uses stored Structures only) -----------------

def verify_ce_in_sample(
    ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> bool:
    """
    Validate CE *in-sample* stats using the serialized Structures stored in train_refs.

    The stored cluster expansion is rehydrated from the document's 'payload',
    the training structures are re-featurized, predictions are made with the
    stored coefficients, and the resulting per-site statistics are compared
    with the document's ``stats["in_sample"]`` entry. Reference energies are
    fetched with ``lookup_total_energy_eV`` for each (set_id, occ_key) under
    the document's engine identity (model, relax_cell, dtype).

    Modern-only invariants:
      - CE doc has 'train_refs', every item includes a 'structure' field.
      - For wl_refined CEs, 'sources[0].base_ce_key' exists and its mixtures
        define the replaceable sublattice set for per-site scaling.

    Args:
        ce_key: Key of the CE document to validate.
        atol: Absolute tolerance passed to ``math.isclose`` for each metric.
        rtol: Relative tolerance passed to ``math.isclose`` for each metric.
        verbose: When True, print the full validation report.

    Returns:
        True when the recomputed sample count equals the stored one and the
        recomputed MAE, RMSE and max |error| are all close to the stored
        values within the tolerances; False otherwise.

    Raises:
        RuntimeError: if the CE document is not found, has no train_refs, has
            a train_ref without 'structure', has a train_ref whose engine
            identity differs from the document's, has a train_ref whose
            recomputed occ_key differs from the recorded one, or if the
            supercell / replaceable-site counts cannot yield a valid
            per-site scale.
        KeyError: if a required field is absent from the document.
    """
    # ---- Load CE document (modern schema assumed) ----
    doc = lookup_ce_by_key(ce_key)
    if not doc:
        raise RuntimeError(f"No CE found for ce_key={ce_key}")

    prototype = PrototypeName(doc["prototype"])
    proto_params = cast(Mapping[str, Any], doc["prototype_params"])
    sx, sy, sz = (int(x) for x in doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)

    stats = cast(Mapping[str, Any], doc["stats"])
    in_sample_stored = cast(Mapping[str, Any], stats["in_sample"])
    stored_cv = cast(Mapping[str, Any] | None, stats.get("five_fold_cv"))

    # Design metrics (stored at training time)
    design_metrics = cast(Mapping[str, Any] | None, doc.get("design_metrics"))

    # Engine identity (must match train_refs entries)
    model_top = cast(str, doc["model"])
    relax_cell_top = bool(doc["relax_cell"])
    dtype_top = cast(str, doc["dtype"])

    # sublattices for per-site scaling
    sublattices = _sublattices_for_doc(doc)

    # ---- Gather training Structures and energies in the recorded order ----
    train_refs = cast(Sequence[Mapping[str, Any]], doc["train_refs"])
    if not train_refs:
        raise RuntimeError("CE document has no 'train_refs'.")
    if not all("structure" in r for r in train_refs):
        raise RuntimeError("All train_refs must include 'structure' (modern-only).")

    structures: list[Structure] = []
    energies_supercell: list[float] = []

    for i, r in enumerate(train_refs):
        set_id = cast(str, r["set_id"])
        occ_key = cast(str, r["occ_key"])
        # Per-reference engine fields are optional; absent ones inherit the
        # document-level value and therefore trivially match.
        model = cast(str, r.get("model", model_top))
        relax_cell = bool(r.get("relax_cell", relax_cell_top))
        dtype = cast(str, r.get("dtype", dtype_top))

        if (model, relax_cell, dtype) != (model_top, relax_cell_top, dtype_top):
            raise RuntimeError(
                f"train_refs[{i}] engine identity differs from CE doc: "
                f"({model},{relax_cell},{dtype}) != ({model_top},{relax_cell_top},{dtype_top})"
            )

        pmg_struct = Structure.from_dict(r["structure"])

        # Strong sanity: recompute occ_key from structure and compare
        atoms = AseAtomsAdaptor.get_atoms(pmg_struct)
        ok2 = occ_key_for_atoms(cast(Atoms, atoms))
        if ok2 != occ_key:
            raise RuntimeError(
                f"train_refs[{i}] occ_key mismatch: expected {occ_key}, rebuilt {ok2}. "
                "This indicates a change in structure hashing."
            )

        structures.append(pmg_struct)
        energies_supercell.append(
            lookup_total_energy_eV(
                set_id=set_id,
                occ_key=occ_key,
                model=model_top,
                relax_cell=relax_cell_top,
                dtype=dtype_top,
            )
        )

    # ---- Rehydrate CE and build features ----
    ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], doc["payload"]))
    subspace = ce.cluster_subspace

    _, X = featurize_structures(
        subspace=subspace,
        structures=structures,
        supercell_diag=supercell_diag,
    )

    # Targets are per primitive/conventional cell
    n_prims = int(np.prod(np.asarray(supercell_diag, dtype=int)))
    if n_prims <= 0:
        # Guards the divisions below against a zero or negative cell count.
        raise RuntimeError(
            f"supercell_diag {supercell_diag} does not describe a positive number of cells."
        )
    y_true_per_prim: NDArray[np.float64] = (
        np.asarray(energies_supercell, dtype=np.float64) / float(n_prims)
    ).astype(np.float64, copy=False)

    # Per-site scaling (replaceable sites only)
    sites_per_supercell = _n_replace_sites_from_prototype(
        prototype=prototype,
        prototype_params=proto_params,
        supercell_diag=supercell_diag,
        sublattices=sublattices,
    )
    if sites_per_supercell % n_prims != 0:
        raise RuntimeError("sites_per_supercell is not divisible by n_prims (prototype/sublattice mismatch).")
    sites_per_prim = sites_per_supercell // n_prims
    if sites_per_prim <= 0:
        # Zero replaceable sites would otherwise surface as ZeroDivisionError.
        raise RuntimeError(
            f"No replaceable sites per cell (sites_per_supercell={sites_per_supercell}); "
            "cannot compute per-site metrics."
        )
    scale_site = 1.0 / float(sites_per_prim)

    # ---- In-sample prediction using stored ECIs ----
    coefs = np.asarray(ce.coefs, dtype=np.float64)
    y_pred_per_prim_ins = predict_from_features(X, coefs)

    in_sample_recomputed = compute_stats(
        (y_true_per_prim * scale_site).tolist(),
        (y_pred_per_prim_ins * scale_site).tolist(),
    )

    # ---- Compare + summarize ----
    # The sample count must match exactly; the three error metrics only
    # within the caller-supplied tolerances.
    ok = (
        int(in_sample_recomputed["n"]) == int(in_sample_stored["n"])
        and math.isclose(float(in_sample_recomputed["mae_per_site"]),  float(in_sample_stored["mae_per_site"]),  rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["rmse_per_site"]), float(in_sample_stored["rmse_per_site"]), rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["max_abs_per_site"]), float(in_sample_stored["max_abs_per_site"]), rel_tol=rtol, abs_tol=atol)
    )

    if verbose:
        _print_summary(
            ce_key=ce_key,
            n=int(in_sample_recomputed["n"]),
            stored_in=in_sample_stored,
            recomputed_in=in_sample_recomputed,
            stored_cv=stored_cv,
            design_metrics=design_metrics,
            atol=atol,
            rtol=rtol,
            ok=ok,
        )

    return ok

# ---- Example usage in a notebook:
# Validate each CE and keep every verdict: a notebook cell displays only its
# last expression, so calling the validator twice in a row would silently drop
# the first result.
ce_keys_to_check = (
    "781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737",
    "76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba",
)
ce_validation = {ce_key: verify_ce_in_sample(ce_key) for ce_key in ce_keys_to_check}

# Displayed cell value: True only if every CE passed.
all(ce_validation.values())


════════════════════════════════════════════════════════════════════════
═══════════════ CE Validation (per-site metrics in meV) ════════════════
════════════════════════════════════════════════════════════════════════
CE key        : 781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737
samples (n)   : 102
tolerances    : atol=1e-06, rtol=1e-08
------------------------------------------------------------------------
In-sample
Metric         Stored    Recomputed   Δ (recomp − stored)
-------------  --------  ----------   -------------------
MAE               0.545      0.545           0.000
RMSE              0.691      0.691           0.000
Max|err|          1.818      1.818           0.000
------------------------------------------------------------------------
Stored 5-fold CV
Metric         Value
-------------  --------
n              102
MAE               1.312
RMSE              1.724
Max|err|          5.777
---------------------------------------------------------------

True