In [1]:
from typing import Any, Mapping, Sequence, cast
from collections import defaultdict
import math
import numpy as np
from numpy.typing import NDArray

from ase.atoms import Atoms
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

# ---- phaseedge imports ----
from phaseedge.science.prototypes import PrototypeName
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.jobs.decide_relax import lookup_ff_task
from phaseedge.science.prototypes import make_prototype
from phaseedge.science.random_configs import make_one_snapshot
from phaseedge.schemas.mixture import sublattices_from_mixtures, Mixture
from phaseedge.jobs.train_ce import (
    featurize_structures,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.utils.keys import (
    compute_set_id,
    rng_for_index,
    occ_key_for_atoms,
)
from smol.cofe import ClusterExpansion


# ----------------- DB helpers -----------------

def _lookup_total_energy_eV(
    *, set_id: str, occ_key: str, model: str, relax_cell: bool, dtype: str
) -> float:
    """
    Return the total (supercell) energy in eV recorded for one FF task.

    The task document is located with ``lookup_ff_task`` using all five keys,
    and the energy is read from ``doc["output"]["energy"]``.

    Raises:
        RuntimeError: if no task document is found, or if the document has no
            readable numeric ``output.energy`` value.
    """
    doc = lookup_ff_task(
        set_id=set_id, occ_key=occ_key, model=model, relax_cell=relax_cell, dtype=dtype
    )
    if doc is None:
        raise RuntimeError(f"No FF task output found for set_id={set_id} occ_key={occ_key}")
    try:
        return float(doc["output"]["energy"])
    except (KeyError, TypeError, ValueError) as exc:
        # KeyError: key absent; TypeError: a node is None / not subscriptable or the
        # value is not numeric; ValueError: a string that does not parse as a float.
        raise RuntimeError(
            f"FF task output missing energy for set_id={set_id} occ_key={occ_key}"
        ) from exc


# ----------------- Composition source parsing -----------------

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect every mixture listed under the ``composition`` sources.

    Sources whose ``type`` (case-insensitive) is not ``"composition"`` are
    skipped. Each mixture entry is normalized to a plain ``dict`` composition
    map and ``int`` K / seed before ``Mixture.from_dict``; order is preserved.
    """
    composition_sources = (
        src for src in sources if str(src.get("type", "")).lower() == "composition"
    )
    return [
        Mixture.from_dict(
            {
                "composition_map": dict(entry["composition_map"]),
                "K": int(entry["K"]),
                "seed": int(entry["seed"]),
            }
        )
        for src in composition_sources
        for entry in src.get("mixtures", [])
    ]


# ----------------- Snapshot reconstruction -----------------

def _build_occ_to_structure_for_set(
    *,
    set_id: str,
    mixture: Mixture,
    conv_cell: Atoms,
    supercell_diag: tuple[int, int, int],
) -> dict[str, Structure]:
    """
    Regenerate the ``mixture.K`` random snapshots of one set and index them.

    Snapshot ``i`` is drawn with ``rng_for_index(set_id, i)``, so identical
    inputs reproduce identical snapshots. Each ASE snapshot is keyed by
    ``occ_key_for_atoms`` and stored as an unrelaxed pymatgen ``Structure``;
    if two snapshots share an occ_key, the later one overwrites the earlier.
    """
    index: dict[str, Structure] = {}
    for i in range(mixture.K):
        atoms: Atoms = make_one_snapshot(
            conv_cell=conv_cell,
            supercell_diag=supercell_diag,
            composition_map=mixture.composition_map,
            rng=rng_for_index(set_id, i),
        )
        occ_key = occ_key_for_atoms(atoms)
        index[occ_key] = cast(Structure, AseAtomsAdaptor.get_structure(atoms))
    return index


# ----------------- Pretty summary -----------------

def _fmt_mev(x: float | str) -> str:
    if isinstance(x, str):
        return x
    return f"{1e3 * float(x):.3f}"

def _print_summary(
    *,
    ce_key: str,
    n: int,
    stored_in: Mapping[str, Any],
    recomputed_in: Mapping[str, Any],
    stored_cv: Mapping[str, Any] | None,
    atol: float,
    rtol: float,
    ok: bool,
) -> None:
    """
    Print a fixed-width report comparing stored vs. recomputed in-sample stats.

    Metrics are read from the ``mae_per_site``, ``rmse_per_site`` and
    ``max_abs_per_site`` keys and rendered in meV via ``_fmt_mev``. The stored
    five-fold CV section is shown only when ``stored_cv`` is non-empty and has
    ``n`` plus all three metric keys. Writes to stdout; returns nothing.
    """
    metric_keys = ("mae_per_site", "rmse_per_site", "max_abs_per_site")
    labels = ("MAE", "RMSE", "Max|err|")
    thick_rule = "═" * 72
    thin_rule = "-" * 72

    stored_vals = [float(stored_in[key]) for key in metric_keys]
    recomputed_vals = [float(recomputed_in[key]) for key in metric_keys]

    print(thick_rule)
    print(" CE Validation (per-site metrics in meV) ".center(72, "═"))
    print(thick_rule)
    print(f"CE key        : {ce_key}")
    print(f"samples (n)   : {n}")
    print(f"tolerances    : atol={atol:g}, rtol={rtol:g}")
    print(thin_rule)
    print("In-sample")
    print("Metric         Stored    Recomputed   Δ (recomp − stored)")
    print("-------------  --------  ----------   -------------------")
    for label, stored_val, recomputed_val in zip(labels, stored_vals, recomputed_vals):
        # Label column is 15 wide; values are right-aligned in 8-char fields.
        print(
            f"{label:<15}{_fmt_mev(stored_val):>8}   {_fmt_mev(recomputed_val):>8}"
            f"        {_fmt_mev(recomputed_val - stored_val):>8}"
        )
    print(thin_rule)

    cv_keys = ("n",) + metric_keys
    if not stored_cv or not all(key in stored_cv for key in cv_keys):
        print("Stored 5-fold CV: <not available>")
        print(thin_rule)
    else:
        cv_n = int(stored_cv["n"])
        cv_vals = [float(stored_cv[key]) for key in metric_keys]
        print("Stored 5-fold CV")
        print("Metric         Value")
        print("-------------  --------")
        print(f"{'n':<15}{cv_n}")
        for label, value in zip(labels, cv_vals):
            print(f"{label:<15}{_fmt_mev(value):>8}")
        print(thin_rule)

    verdict = "✅  PASS (in-sample matches stored)" if ok else "❌  FAIL (in-sample mismatch)"
    print("Result         : " + verdict)
    print(thick_rule)


# ----------------- Main verifier (in-sample only) -----------------

def verify_ce_in_sample(
    ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> bool:
    """
    Recompute and validate **in-sample** per-site stats (composition CE only).

    Steps:
    1. Regenerate the training snapshots deterministically from the CE's
       ``composition`` mixtures.
    2. Fetch each snapshot's FF energy via ``_lookup_total_energy_eV``.
    3. Re-featurize the structures with the stored cluster subspace and predict
       with the stored ECIs.
    4. Compare MAE / RMSE / max-abs per-site errors with ``stats["in_sample"]``.

    Args:
        ce_key: Key of the CE document to verify.
        atol: Absolute tolerance passed to ``math.isclose`` for each metric.
        rtol: Relative tolerance passed to ``math.isclose`` for each metric.
        verbose: Print a summary table when True.

    Returns:
        True when the sample count matches exactly and all three metrics agree
        within tolerance. PASS/FAIL is based solely on the top-level
        'in_sample' stats; stored CV stats are only displayed.

    Raises:
        RuntimeError: raised in any of these cases:
            - the CE document is missing or has no composition mixtures;
            - two mixtures map to the same set_id;
            - the mixture-derived set_ids disagree with ``train_refs``;
            - a snapshot or FF energy cannot be reconstructed;
            - the replaceable-site count is not divisible by the number of
              primitive cells.
    """
    # ---- Load document and assume modern schema ----
    doc = lookup_ce_by_key(ce_key)
    if not doc:
        raise RuntimeError(f"No CE found for ce_key={ce_key}")

    stats_stored = cast(Mapping[str, Any], doc["stats"])
    in_sample_stored = cast(Mapping[str, Any], stats_stored["in_sample"])
    stored_cv = cast(Mapping[str, Any] | None, stats_stored.get("five_fold_cv"))

    sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    mixtures = _gather_composition_mixtures(sources)
    if not mixtures:
        raise RuntimeError("Expected 'composition' sources in modern schema.")

    prototype = PrototypeName(doc["prototype"])
    proto_params = cast(Mapping[str, Any], doc["prototype_params"])
    sx, sy, sz = (int(x) for x in doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)
    model = cast(str, doc["model"])
    relax_cell = bool(doc["relax_cell"])
    dtype = cast(str, doc["dtype"])
    train_refs = cast(Sequence[Mapping[str, Any]], doc["train_refs"])

    # ---- Build index of occ_keys per set_id from train_refs (strict) ----
    refs_by_sid: dict[str, list[str]] = defaultdict(list)
    for r in train_refs:
        refs_by_sid[str(r["set_id"])].append(str(r["occ_key"]))

    # ---- Derive each mixture's set_id once; reused by the checks and the rebuild ----
    mix_sids: list[tuple[Mixture, str]] = [
        (
            mix,
            compute_set_id(
                prototype=prototype,
                prototype_params=proto_params,
                supercell_diag=supercell_diag,
                composition_map=mix.composition_map,
                seed=mix.seed,
            ),
        )
        for mix in mixtures
    ]
    computed_sids = {sid for _, sid in mix_sids}

    if len(computed_sids) != len(mix_sids):
        # Two mixtures sharing a set_id would replay that set's train_refs twice and
        # only fail later with an opaque count mismatch; report the cause directly.
        sid_counts: dict[str, int] = defaultdict(int)
        for _, sid in mix_sids:
            sid_counts[sid] += 1
        duplicated = sorted(sid for sid, count in sid_counts.items() if count > 1)
        raise RuntimeError(f"Multiple composition mixtures map to the same set_id: {duplicated}")

    # Strict coverage check: computed set_ids from mixtures must equal refs_by_sid keys
    ref_sids = set(refs_by_sid)
    if ref_sids != computed_sids:
        missing_in_refs = sorted(computed_sids - ref_sids)
        missing_in_mixes = sorted(ref_sids - computed_sids)
        raise RuntimeError(
            "Mismatch between mixture-derived set_ids and train_refs set_ids.\n"
            f"  Missing in train_refs: {missing_in_refs}\n"
            f"  Missing in mixtures  : {missing_in_mixes}"
        )

    # ---- Prepare prototype conventional cell (ASE Atoms) ----
    conv_cell: Atoms = make_prototype(prototype, **dict(proto_params))

    # ---- Reconstruct training structures & energies in deterministic order ----
    structures: list[Structure] = []
    energies_supercell: list[float] = []
    # Iterate mixtures → stable SID order; then occ_keys in recorded order
    for mix, sid in mix_sids:
        occ_to_struct = _build_occ_to_structure_for_set(
            set_id=sid,
            mixture=mix,
            conv_cell=conv_cell,
            supercell_diag=supercell_diag,
        )
        for occ_key in refs_by_sid[sid]:
            s = occ_to_struct.get(occ_key)
            if s is None:
                raise RuntimeError(f"Could not reconstruct snapshot for set_id={sid} occ_key={occ_key}")
            structures.append(s)
            energies_supercell.append(
                _lookup_total_energy_eV(set_id=sid, occ_key=occ_key, model=model, relax_cell=relax_cell, dtype=dtype)
            )

    if len(structures) != len(train_refs):
        raise RuntimeError("Mismatch in number of reconstructed structures vs. train_refs.")

    # ---- Rehydrate CE and build features ----
    ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], doc["payload"]))
    subspace = ce.cluster_subspace

    _, X = featurize_structures(
        subspace=subspace,
        structures=cast(Sequence[Structure], structures),
        supercell_diag=supercell_diag,
    )

    # Targets are per primitive/conventional cell
    n_prims = int(np.prod(np.asarray(supercell_diag, dtype=int)))
    y_true_per_prim: NDArray[np.float64] = (
        np.asarray(energies_supercell, dtype=np.float64) / float(n_prims)
    ).astype(np.float64, copy=False)

    # Per-site scaling: convert per-prim quantities to per-replaceable-site
    sites_per_supercell = _n_replace_sites_from_prototype(
        prototype=prototype,
        prototype_params=proto_params,
        supercell_diag=supercell_diag,
        sublattices=sublattices_from_mixtures(mixtures),
    )
    if sites_per_supercell % n_prims != 0:
        raise RuntimeError("sites_per_supercell is not divisible by n_prims.")
    sites_per_prim = sites_per_supercell // n_prims
    scale_site = 1.0 / float(sites_per_prim)

    # ---- In-sample prediction using stored ECIs ----
    coefs = np.asarray(ce.coefs, dtype=np.float64)
    y_pred_per_prim_ins = predict_from_features(X, coefs)

    in_sample_recomputed = compute_stats(
        (y_true_per_prim * scale_site).tolist(),
        (y_pred_per_prim_ins * scale_site).tolist(),
    )

    # ---- Compare + summarize ----
    ok = (
        int(in_sample_recomputed["n"]) == int(in_sample_stored["n"])
        and math.isclose(float(in_sample_recomputed["mae_per_site"]),  float(in_sample_stored["mae_per_site"]),  rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["rmse_per_site"]), float(in_sample_stored["rmse_per_site"]), rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["max_abs_per_site"]), float(in_sample_stored["max_abs_per_site"]), rel_tol=rtol, abs_tol=atol)
    )

    if verbose:
        _print_summary(
            ce_key=ce_key,
            n=int(in_sample_recomputed["n"]),
            stored_in=in_sample_stored,
            recomputed_in=in_sample_recomputed,
            stored_cv=stored_cv,
            atol=atol, rtol=rtol, ok=ok,
        )

    return ok


# ---- Example usage: verify one stored CE; the returned bool is the cell's displayed value.
verify_ce_in_sample(ce_key="781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737")


════════════════════════════════════════════════════════════════════════
═══════════════ CE Validation (per-site metrics in meV) ════════════════
════════════════════════════════════════════════════════════════════════
CE key        : 781041790a68cc664c64278fed25132d984b3bfbc802d207e41d8626818ff737
samples (n)   : 102
tolerances    : atol=1e-06, rtol=1e-08
------------------------------------------------------------------------
In-sample
Metric         Stored    Recomputed   Δ (recomp − stored)
-------------  --------  ----------   -------------------
MAE               0.545      0.545           0.000
RMSE              0.691      0.691           0.000
Max|err|          1.818      1.818           0.000
------------------------------------------------------------------------
Stored 5-fold CV
Metric         Value
-------------  --------
n              102
MAE               1.312
RMSE              1.724
Max|err|          5.777
---------------------------------------------------------------

True

In [20]:
from typing import Any, Mapping, Sequence, cast
from collections import defaultdict
import math
import numpy as np
from numpy.typing import NDArray

from ase.atoms import Atoms
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

# ---- phaseedge imports ----
from phaseedge.science.prototypes import PrototypeName
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.jobs.decide_relax import lookup_ff_task
from phaseedge.science.prototypes import make_prototype
from phaseedge.science.random_configs import make_one_snapshot
from phaseedge.schemas.mixture import sublattices_from_mixtures, Mixture
from phaseedge.jobs.train_ce import (
    featurize_structures,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.utils.keys import (
    compute_set_id,
    rng_for_index,
    occ_key_for_atoms,
)
from phaseedge.utils.rehydrators import rehydrate_ensemble_by_ce_key
from smol.cofe import ClusterExpansion


# ----------------- DB helpers -----------------

def _lookup_total_energy_eV(
    *, set_id: str, occ_key: str, model: str, relax_cell: bool, dtype: str
) -> float:
    """
    Return the total (supercell) energy in eV recorded for one FF task.

    The task document is located with ``lookup_ff_task`` using all five keys,
    and the energy is read from ``doc["output"]["energy"]``.

    Raises:
        RuntimeError: if no task document is found, or if the document has no
            readable numeric ``output.energy`` value.
    """
    doc = lookup_ff_task(
        set_id=set_id, occ_key=occ_key, model=model, relax_cell=relax_cell, dtype=dtype
    )
    if doc is None:
        raise RuntimeError(f"No FF task output found for set_id={set_id} occ_key={occ_key}")
    try:
        return float(doc["output"]["energy"])
    except (KeyError, TypeError, ValueError) as exc:
        # KeyError: key absent; TypeError: a node is None / not subscriptable or the
        # value is not numeric; ValueError: a string that does not parse as a float.
        raise RuntimeError(
            f"FF task output missing energy for set_id={set_id} occ_key={occ_key}"
        ) from exc


# ----------------- Source parsing -----------------

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect every mixture listed under the ``composition`` sources.

    Sources whose ``type`` (case-insensitive) is not ``"composition"`` are
    skipped. Each mixture entry is normalized to a plain ``dict`` composition
    map and ``int`` K / seed before ``Mixture.from_dict``; order is preserved.
    """
    composition_sources = (
        src for src in sources if str(src.get("type", "")).lower() == "composition"
    )
    return [
        Mixture.from_dict(
            {
                "composition_map": dict(entry["composition_map"]),
                "K": int(entry["K"]),
                "seed": int(entry["seed"]),
            }
        )
        for src in composition_sources
        for entry in src.get("mixtures", [])
    ]


def _source_kind(sources: Sequence[Mapping[str, Any]]) -> str:
    """
    Return 'composition' or 'wl_refined_intent' based on modern source types.
    Raises if unsupported or ambiguous.
    """
    types = [str(s.get("type", "")).lower() for s in sources]
    if any(t == "wl_refined_intent" for t in types):
        return "wl_refined_intent"
    if any(t == "composition" for t in types):
        return "composition"
    raise RuntimeError(f"Unsupported sources: {types}")


def _get_base_ce_key_for_refined(sources: Sequence[Mapping[str, Any]]) -> str:
    for s in sources:
        if str(s.get("type", "")).lower() == "wl_refined_intent":
            b = s.get("base_ce_key")
            if not isinstance(b, str) or not b:
                raise RuntimeError("wl_refined_intent source missing valid 'base_ce_key'.")
            return b
    raise RuntimeError("wl_refined_intent source not found.")


# ----------------- Snapshot reconstruction (composition CEs) -----------------

def _build_occ_to_structure_for_set(
    *,
    set_id: str,
    mixture: Mixture,
    conv_cell: Atoms,
    supercell_diag: tuple[int, int, int],
) -> dict[str, Structure]:
    """
    Regenerate the ``mixture.K`` random snapshots of one set and index them.

    Snapshot ``i`` is drawn with ``rng_for_index(set_id, i)``, so identical
    inputs reproduce identical snapshots. Each ASE snapshot is keyed by
    ``occ_key_for_atoms`` and stored as an unrelaxed pymatgen ``Structure``;
    if two snapshots share an occ_key, the later one overwrites the earlier.
    """
    index: dict[str, Structure] = {}
    for i in range(mixture.K):
        atoms: Atoms = make_one_snapshot(
            conv_cell=conv_cell,
            supercell_diag=supercell_diag,
            composition_map=mixture.composition_map,
            rng=rng_for_index(set_id, i),
        )
        occ_key = occ_key_for_atoms(atoms)
        index[occ_key] = cast(Structure, AseAtomsAdaptor.get_structure(atoms))
    return index


# ----------------- Structure extraction (refined WL CEs) -----------------

def _structure_from_ff_doc(
    doc: Mapping[str, Any],
    *,
    base_ce_key: str | None = None,
) -> Structure:
    """
    Extract the structure described by an FF task document.

    Lookup order (first hit wins):
      1) doc['structure'] as a pymatgen ``Structure`` dict (``@class == "Structure"``)
      2) doc['output']['structure'] in the same format
      3) doc['occ'], a non-empty list of integers, decoded with the processor of
         the ensemble rehydrated from ``base_ce_key``

    Args:
        doc: FF task document.
        base_ce_key: CE key whose ensemble decodes ``occ``; needed only for case 3.

    Returns:
        A pymatgen ``Structure``. If the processor returns ASE ``Atoms``, they are
        converted with ``AseAtomsAdaptor``.

    Raises:
        RuntimeError: raised in any of these cases:
            - ``occ`` must be decoded but ``base_ce_key`` is not given;
            - the processor returns an unexpected type;
            - no usable structure or occ is present.
    """
    s_top = doc.get("structure")
    if isinstance(s_top, Mapping) and s_top.get("@class") == "Structure":
        return Structure.from_dict(cast(Mapping[str, Any], s_top))

    out = doc.get("output", {})
    if isinstance(out, Mapping):
        s_out = out.get("structure")
        if isinstance(s_out, Mapping) and s_out.get("@class") == "Structure":
            return Structure.from_dict(cast(Mapping[str, Any], s_out))

    occ = doc.get("occ")
    # Require a non-empty list: an empty one would satisfy all(...) vacuously.
    if isinstance(occ, list) and occ and all(isinstance(x, (int, np.integer)) for x in occ):
        if not base_ce_key:
            raise RuntimeError("FF doc contains 'occ' but no base_ce_key was provided for reconstruction.")
        # NOTE(review): the ensemble is rehydrated on every call; callers that loop
        # over many docs may want to cache it.
        ens = rehydrate_ensemble_by_ce_key(base_ce_key)
        occ_arr = np.asarray([int(x) for x in occ], dtype=np.int32)
        st = ens.processor.structure_from_occupancy(occ_arr)
        if isinstance(st, Structure):
            return st
        if isinstance(st, Atoms):
            return cast(Structure, AseAtomsAdaptor.get_structure(st))
        raise RuntimeError("Unknown structure type returned by ensemble processor.")

    raise RuntimeError("Could not extract structure from FF doc (no structure/occ present).")


# ----------------- Pretty summary -----------------

def _fmt_mev(x: float | str) -> str:
    if isinstance(x, str):
        return x
    return f"{1e3 * float(x):.3f}"

def _print_summary(
    *,
    ce_key: str,
    n: int,
    stored_in: Mapping[str, Any],
    recomputed_in: Mapping[str, Any],
    stored_cv: Mapping[str, Any] | None,
    atol: float,
    rtol: float,
    ok: bool,
) -> None:
    """
    Print a fixed-width report comparing stored vs. recomputed in-sample stats.

    Metrics are read from the ``mae_per_site``, ``rmse_per_site`` and
    ``max_abs_per_site`` keys and rendered in meV via ``_fmt_mev``. The stored
    five-fold CV section is shown only when ``stored_cv`` is non-empty and has
    ``n`` plus all three metric keys. Writes to stdout; returns nothing.
    """
    metric_keys = ("mae_per_site", "rmse_per_site", "max_abs_per_site")
    labels = ("MAE", "RMSE", "Max|err|")
    thick_rule = "═" * 72
    thin_rule = "-" * 72

    stored_vals = [float(stored_in[key]) for key in metric_keys]
    recomputed_vals = [float(recomputed_in[key]) for key in metric_keys]

    print(thick_rule)
    print(" CE Validation (per-site metrics in meV) ".center(72, "═"))
    print(thick_rule)
    print(f"CE key        : {ce_key}")
    print(f"samples (n)   : {n}")
    print(f"tolerances    : atol={atol:g}, rtol={rtol:g}")
    print(thin_rule)
    print("In-sample")
    print("Metric         Stored    Recomputed   Δ (recomp − stored)")
    print("-------------  --------  ----------   -------------------")
    for label, stored_val, recomputed_val in zip(labels, stored_vals, recomputed_vals):
        # Label column is 15 wide; values are right-aligned in 8-char fields.
        print(
            f"{label:<15}{_fmt_mev(stored_val):>8}   {_fmt_mev(recomputed_val):>8}"
            f"        {_fmt_mev(recomputed_val - stored_val):>8}"
        )
    print(thin_rule)

    cv_keys = ("n",) + metric_keys
    if not stored_cv or not all(key in stored_cv for key in cv_keys):
        print("Stored 5-fold CV: <not available>")
        print(thin_rule)
    else:
        cv_n = int(stored_cv["n"])
        cv_vals = [float(stored_cv[key]) for key in metric_keys]
        print("Stored 5-fold CV")
        print("Metric         Value")
        print("-------------  --------")
        print(f"{'n':<15}{cv_n}")
        for label, value in zip(labels, cv_vals):
            print(f"{label:<15}{_fmt_mev(value):>8}")
        print(thin_rule)

    verdict = "✅  PASS (in-sample matches stored)" if ok else "❌  FAIL (in-sample mismatch)"
    print("Result         : " + verdict)
    print(thick_rule)


# ----------------- Main verifier (in-sample only) -----------------

def verify_ce_in_sample(
    ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> bool:
    """
    Recompute and validate **in-sample** per-site stats for:
      - composition CEs (ensure_ce_from_mixtures)
      - refined WL CEs (ensure_ce_from_refined_wl)

    Training structures are rebuilt as follows:
      - composition CEs: snapshots are regenerated deterministically per set_id.
      - refined WL CEs: structures come from each FF task document (a stored
        structure, or ``occ`` decoded with the base CE's ensemble).

    In both cases energies come from the FF task documents. The stored ECIs
    are applied to the re-featurized structures, and the MAE / RMSE / max-abs
    per-site errors are compared with ``stats["in_sample"]``.

    Args:
        ce_key: Key of the CE document to verify.
        atol: Absolute tolerance passed to ``math.isclose`` for each metric.
        rtol: Relative tolerance passed to ``math.isclose`` for each metric.
        verbose: Print a summary table when True.

    Returns:
        True when the sample count matches exactly and all three metrics agree
        within tolerance. Modern schema only. PASS/FAIL is based solely on
        top-level 'in_sample' stats.

    Raises:
        RuntimeError: raised in any of these cases:
            - the CE (or its base CE) is missing or lacks the expected sources;
            - set_ids disagree with ``train_refs``;
            - a structure or FF energy cannot be recovered;
            - the site counts are inconsistent.
    """
    # ---- Load document and assume modern schema ----
    doc = lookup_ce_by_key(ce_key)
    if not doc:
        raise RuntimeError(f"No CE found for ce_key={ce_key}")

    stats_stored = cast(Mapping[str, Any], doc["stats"])
    in_sample_stored = cast(Mapping[str, Any], stats_stored["in_sample"])
    stored_cv = cast(Mapping[str, Any] | None, stats_stored.get("five_fold_cv"))

    sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    kind = _source_kind(sources)

    prototype = PrototypeName(doc["prototype"])
    proto_params = cast(Mapping[str, Any], doc["prototype_params"])
    sx, sy, sz = (int(x) for x in doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)
    model = cast(str, doc["model"])
    relax_cell = bool(doc["relax_cell"])
    dtype = cast(str, doc["dtype"])
    train_refs = cast(Sequence[Mapping[str, Any]], doc["train_refs"])

    # ---- Determine sublattices (needed for per-site scaling) ----
    # mixtures / base_ce_key are gathered once here and reused below.
    mixtures: list[Mixture] = []
    base_ce_key: str | None = None
    if kind == "composition":
        mixtures = _gather_composition_mixtures(sources)
        if not mixtures:
            raise RuntimeError("Expected 'composition' mixtures in sources.")
        sublattices = sublattices_from_mixtures(mixtures)
    elif kind == "wl_refined_intent":
        base_ce_key = _get_base_ce_key_for_refined(sources)
        base_doc = lookup_ce_by_key(base_ce_key)
        if not base_doc:
            raise RuntimeError(f"Base CE not found for base_ce_key={base_ce_key}")
        base_mixtures = _gather_composition_mixtures(cast(Sequence[Mapping[str, Any]], base_doc["sources"]))
        if not base_mixtures:
            raise RuntimeError("Base CE sources did not contain 'composition' mixtures.")
        sublattices = sublattices_from_mixtures(base_mixtures)
    else:
        raise RuntimeError(f"Unhandled source kind: {kind}")

    # ---- Reconstruct training structures & energies in train_refs order ----
    structures: list[Structure] = []
    energies_supercell: list[float] = []

    if kind == "composition":
        # (mixture, set_id) pairs: unlike a dict keyed by Mixture, this neither
        # requires Mixture to be hashable nor silently collapses equal mixtures.
        mix_sids: list[tuple[Mixture, str]] = [
            (
                mix,
                compute_set_id(
                    prototype=prototype,
                    prototype_params=proto_params,
                    supercell_diag=supercell_diag,
                    composition_map=mix.composition_map,
                    seed=mix.seed,
                ),
            )
            for mix in mixtures
        ]
        computed_sids = {sid for _, sid in mix_sids}

        # Strict coverage check
        refs_by_sid: dict[str, list[str]] = defaultdict(list)
        for r in train_refs:
            refs_by_sid[str(r["set_id"])].append(str(r["occ_key"]))
        ref_sids = set(refs_by_sid)
        if ref_sids != computed_sids:
            missing_in_refs = sorted(computed_sids - ref_sids)
            missing_in_mixes = sorted(ref_sids - computed_sids)
            raise RuntimeError(
                "Mismatch between mixture-derived set_ids and train_refs set_ids.\n"
                f"  Missing in train_refs: {missing_in_refs}\n"
                f"  Missing in mixtures  : {missing_in_mixes}"
            )

        # Prototype conventional cell (ASE Atoms), only needed on this path
        conv_cell: Atoms = make_prototype(prototype, **dict(proto_params))

        # Cache structures per set
        occ_struct_cache: dict[str, dict[str, Structure]] = {}
        for mix, sid in mix_sids:
            occ_struct_cache[sid] = _build_occ_to_structure_for_set(
                set_id=sid,
                mixture=mix,
                conv_cell=conv_cell,
                supercell_diag=supercell_diag,
            )

        # Iterate train_refs in their stored order
        for r in train_refs:
            sid = str(r["set_id"])
            occ_key = str(r["occ_key"])
            s = occ_struct_cache[sid].get(occ_key)
            if s is None:
                raise RuntimeError(f"Could not reconstruct snapshot for set_id={sid} occ_key={occ_key}")
            structures.append(s)
            energies_supercell.append(
                _lookup_total_energy_eV(set_id=sid, occ_key=occ_key, model=model, relax_cell=relax_cell, dtype=dtype)
            )

    else:  # wl_refined_intent
        # Iterate train_refs in stored order; structures sourced from FF doc (structure or occ via base CE)
        for r in train_refs:
            sid = str(r["set_id"])
            occ_key = str(r["occ_key"])
            ff_doc = lookup_ff_task(set_id=sid, occ_key=occ_key, model=model, relax_cell=relax_cell, dtype=dtype)
            if ff_doc is None:
                raise RuntimeError(f"Missing FF task for set_id={sid} occ_key={occ_key}")
            structures.append(_structure_from_ff_doc(ff_doc, base_ce_key=base_ce_key))
            # Same error contract as _lookup_total_energy_eV, without a second lookup.
            try:
                energy = float(ff_doc["output"]["energy"])
            except (KeyError, TypeError, ValueError) as exc:
                raise RuntimeError(
                    f"FF task output missing energy for set_id={sid} occ_key={occ_key}"
                ) from exc
            energies_supercell.append(energy)

    if len(structures) != len(train_refs):
        raise RuntimeError("Mismatch in number of reconstructed structures vs. train_refs.")

    # ---- Rehydrate CE and build features ----
    ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], doc["payload"]))
    subspace = ce.cluster_subspace

    _, X = featurize_structures(
        subspace=subspace,
        structures=cast(Sequence[Structure], structures),
        supercell_diag=supercell_diag,
    )

    # Targets are per primitive/conventional cell
    n_prims = int(np.prod(np.asarray(supercell_diag, dtype=int)))
    y_true_per_prim: NDArray[np.float64] = (
        np.asarray(energies_supercell, dtype=np.float64) / float(n_prims)
    ).astype(np.float64, copy=False)

    # Per-site scaling: convert per-prim quantities to per-replaceable-site
    sites_per_supercell = _n_replace_sites_from_prototype(
        prototype=prototype,
        prototype_params=proto_params,
        supercell_diag=supercell_diag,
        sublattices=sublattices,
    )
    if sites_per_supercell % n_prims != 0:
        raise RuntimeError("sites_per_supercell is not divisible by n_prims.")
    sites_per_prim = sites_per_supercell // n_prims
    scale_site = 1.0 / float(sites_per_prim)

    # ---- In-sample prediction using stored ECIs ----
    coefs = np.asarray(ce.coefs, dtype=np.float64)
    y_pred_per_prim_ins = predict_from_features(X, coefs)

    in_sample_recomputed = compute_stats(
        (y_true_per_prim * scale_site).tolist(),
        (y_pred_per_prim_ins * scale_site).tolist(),
    )

    # ---- Compare + summarize ----
    ok = (
        int(in_sample_recomputed["n"]) == int(in_sample_stored["n"])
        and math.isclose(float(in_sample_recomputed["mae_per_site"]),  float(in_sample_stored["mae_per_site"]),  rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["rmse_per_site"]), float(in_sample_stored["rmse_per_site"]), rel_tol=rtol, abs_tol=atol)
        and math.isclose(float(in_sample_recomputed["max_abs_per_site"]), float(in_sample_stored["max_abs_per_site"]), rel_tol=rtol, abs_tol=atol)
    )

    if verbose:
        _print_summary(
            ce_key=ce_key,
            n=int(in_sample_recomputed["n"]),
            stored_in=in_sample_stored,
            recomputed_in=in_sample_recomputed,
            stored_cv=stored_cv,
            atol=atol, rtol=rtol, ok=ok,
        )

    return ok


# ---- Example usage: verify one stored CE; the returned bool is the cell's displayed value.
verify_ce_in_sample(ce_key="76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba")


════════════════════════════════════════════════════════════════════════
═══════════════ CE Validation (per-site metrics in meV) ════════════════
════════════════════════════════════════════════════════════════════════
CE key        : 76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba
samples (n)   : 102
tolerances    : atol=1e-06, rtol=1e-08
------------------------------------------------------------------------
In-sample
Metric         Stored    Recomputed   Δ (recomp − stored)
-------------  --------  ----------   -------------------
MAE               0.478      0.478           0.000
RMSE              0.601      0.601           0.000
Max|err|          1.645      1.645           0.000
------------------------------------------------------------------------
Stored 5-fold CV
Metric         Value
-------------  --------
n              102
MAE               1.206
RMSE              1.577
Max|err|          4.511
---------------------------------------------------------------

True

In [None]:
from typing import Any, Mapping, Sequence, cast
from collections import defaultdict
import math
import numpy as np
from numpy.typing import NDArray

from ase.atoms import Atoms
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

# ---- phaseedge imports ----
from phaseedge.science.prototypes import PrototypeName
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.jobs.decide_relax import lookup_ff_task
from phaseedge.science.prototypes import make_prototype
from phaseedge.schemas.mixture import sublattices_from_mixtures, Mixture
from phaseedge.jobs.train_ce import (
    featurize_structures,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.utils.keys import (
    compute_set_id,
    rng_for_index,
    occ_key_for_atoms,
)
from phaseedge.utils.rehydrators import rehydrate_ensemble_by_ce_key
from smol.cofe import ClusterExpansion


# ----------------- helpers: sources & structures -----------------

def _source_kind(sources: Sequence[Mapping[str, Any]]) -> str:
    kinds = [str(s.get("type", "")).lower() for s in sources]
    if any(k == "wl_refined_intent" for k in kinds):
        return "wl_refined_intent"
    if any(k == "composition" for k in kinds):
        return "composition"
    raise RuntimeError(f"Unsupported sources: {kinds!r}")

def _get_base_ce_key_for_refined(sources: Sequence[Mapping[str, Any]]) -> str:
    for s in sources:
        if str(s.get("type", "")).lower() == "wl_refined_intent":
            b = s.get("base_ce_key")
            if not isinstance(b, str) or not b:
                raise RuntimeError("wl_refined_intent source missing valid 'base_ce_key'.")
            return b
    raise RuntimeError("wl_refined_intent source not found.")

def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect every mixture listed under the ``composition`` sources.

    Sources whose ``type`` (case-insensitive) is not ``"composition"`` are
    skipped. Each mixture entry is normalized to a plain ``dict`` composition
    map and ``int`` K / seed before ``Mixture.from_dict``; order is preserved.
    """
    composition_sources = (
        src for src in sources if str(src.get("type", "")).lower() == "composition"
    )
    return [
        Mixture.from_dict(
            {
                "composition_map": dict(entry["composition_map"]),
                "K": int(entry["K"]),
                "seed": int(entry["seed"]),
            }
        )
        for src in composition_sources
        for entry in src.get("mixtures", [])
    ]

def _structure_from_ff_doc(
    doc: Mapping[str, Any],
    *,
    base_ce_key: str,
) -> Structure:
    """
    Extract a pymatgen ``Structure`` from an FF task document.

    Sources are tried in priority order:
      1) ``doc['structure']`` as a serialized pymatgen dict (``@class == "Structure"``)
      2) ``doc['output']['structure']`` serialized the same way
      3) ``doc['occ']`` -- a non-empty list of integers, turned into a structure
         by the processor of the ensemble rehydrated from ``base_ce_key``

    Args:
        doc: FF task document (as returned by ``lookup_ff_task``).
        base_ce_key: CE key whose ensemble is used to rebuild the structure from
            ``occ``; only consulted when no serialized structure is present.

    Returns:
        The structure. If the processor yields ASE ``Atoms``, it is converted
        with ``AseAtomsAdaptor``.

    Raises:
        RuntimeError: if no usable source is present, or if ``occ`` is present
            but the processor returns a type other than Structure/Atoms.
    """
    # 1) Top-level serialized structure.
    s_top = doc.get("structure")
    if isinstance(s_top, Mapping) and s_top.get("@class") == "Structure":
        return Structure.from_dict(cast(Mapping[str, Any], s_top))

    # 2) Structure nested under the task output.
    out = doc.get("output", {})
    if isinstance(out, Mapping):
        s_out = out.get("structure")
        if isinstance(s_out, Mapping) and s_out.get("@class") == "Structure":
            return Structure.from_dict(cast(Mapping[str, Any], s_out))

    # 3) Rebuild from an integer occupancy vector via the base CE ensemble.
    #    An empty list is rejected: `all()` is vacuously True for it, and it
    #    cannot describe a structure (it would fail inside the processor).
    occ = doc.get("occ")
    if isinstance(occ, list) and occ and all(isinstance(x, (int, np.integer)) for x in occ):
        ens = rehydrate_ensemble_by_ce_key(base_ce_key)
        occ_arr = np.asarray([int(x) for x in occ], dtype=np.int32)
        st = ens.processor.structure_from_occupancy(occ_arr)
        if isinstance(st, Structure):
            return st
        if isinstance(st, Atoms):
            return cast(Structure, AseAtomsAdaptor.get_structure(st))
        # Occupancy *was* present; report the real problem rather than falling
        # through to the "no structure/occ present" message below.
        raise RuntimeError(
            "Could not extract structure from FF doc: ensemble processor returned "
            f"unsupported type {type(st).__name__!r} for occ (base_ce_key={base_ce_key})."
        )

    raise RuntimeError("Could not extract structure from FF doc (no structure/occ present).")

def _lookup_total_energy_eV(
    *, set_id: str, occ_key: str, model: str, relax_cell: bool, dtype: str
) -> float:
    """
    Fetch the FF task for one configuration and return ``output.energy`` as a float.

    The task is identified by ``set_id``/``occ_key`` together with the FF
    settings (``model``, ``relax_cell``, ``dtype``).

    Raises:
        RuntimeError: if no task is found, or if reading/converting
            ``doc['output']['energy']`` fails for any reason (the original
            exception is chained as the cause).
    """
    task_doc = lookup_ff_task(
        set_id=set_id,
        occ_key=occ_key,
        model=model,
        relax_cell=relax_cell,
        dtype=dtype,
    )
    if task_doc is not None:
        try:
            energy = float(task_doc["output"]["energy"])
        except Exception as err:
            raise RuntimeError(
                f"FF task output missing energy for set_id={set_id} occ_key={occ_key}"
            ) from err
        return energy
    raise RuntimeError(f"No FF task output found for set_id={set_id} occ_key={occ_key}")

def _fmt_mev(x: float | str) -> str:
    return x if isinstance(x, str) else f"{1e3 * float(x):.3f}"

def _print_comparison_summary(
    *,
    refined_ce_key: str,
    base_ce_key: str,
    n: int,
    random_on_refined: Mapping[str, Any],
    refined_cv: Mapping[str, Any] | None,
) -> None:
    """
    Print a 76-column report comparing random-CE errors with the refined CE's CV.

    Args:
        refined_ce_key: key of the refined CE (shown in the header).
        base_ce_key: key of the base/random CE (shown in the header).
        n: number of evaluated samples (shown in the header).
        random_on_refined: stats mapping with 'mae_per_site', 'rmse_per_site'
            and 'max_abs_per_site' for the random CE on the refined training set.
        refined_cv: stored 5-fold CV stats of the refined CE, or None. The CV
            section (with random - CV deltas) is printed only if this mapping has
            non-None values for 'n', 'mae_per_site', 'rmse_per_site' and
            'max_abs_per_site'; otherwise "<not available>" is printed.

    Metric values go through ``_fmt_mev`` (multiplied by 1e3). The header
    labels them meV, so the inputs are presumably eV/site.
    """
    width = 76
    heavy_rule = "═" * width
    light_rule = "-" * width

    r_mae = float(random_on_refined["mae_per_site"])
    r_rmse = float(random_on_refined["rmse_per_site"])
    r_max = float(random_on_refined["max_abs_per_site"])

    print(heavy_rule)
    print(" Random CE → Refined-set Evaluation (per-site metrics in meV) ".center(width, "═"))
    print(heavy_rule)
    # Pad labels to a common width so the colons line up.
    print(f"{'Refined CE key':<20} : {refined_ce_key}")
    print(f"{'Base (random) CE key':<20} : {base_ce_key}")
    print(f"{'samples (n)':<20} : {n}")
    print(light_rule)
    print("Random CE on refined training set")
    print("Metric         Value")
    print("-------------  --------")
    print(f"MAE            {_fmt_mev(r_mae):>8}")
    print(f"RMSE           {_fmt_mev(r_rmse):>8}")
    print(f"Max|err|       {_fmt_mev(r_max):>8}")
    print(light_rule)

    cv_keys = ("n", "mae_per_site", "rmse_per_site", "max_abs_per_site")
    # Stored stats may hold explicit None values. Treat those as unavailable
    # rather than crashing in float()/int() after all the scoring work is done.
    cv_available = refined_cv is not None and all(refined_cv.get(k) is not None for k in cv_keys)

    if cv_available:
        assert refined_cv is not None  # narrowing for type checkers
        cv_n = int(refined_cv["n"])
        cv_mae = float(refined_cv["mae_per_site"])
        cv_rmse = float(refined_cv["rmse_per_site"])
        cv_max = float(refined_cv["max_abs_per_site"])
        print("Stored 5-fold CV (from refined CE)")
        print("Metric         Value       Δ (random − CV)")
        print("-------------  --------    --------------")
        print(f"n              {cv_n}")
        print(f"MAE            {_fmt_mev(cv_mae):>8}    {_fmt_mev(r_mae - cv_mae):>8}")
        print(f"RMSE           {_fmt_mev(cv_rmse):>8}    {_fmt_mev(r_rmse - cv_rmse):>8}")
        print(f"Max|err|       {_fmt_mev(cv_max):>8}    {_fmt_mev(r_max - cv_max):>8}")
    else:
        print("Stored 5-fold CV (from refined CE): <not available>")
    print(light_rule)

    print(heavy_rule)


# ----------------- main entrypoint -----------------

def score_random_ce_on_refined_training(
    refined_ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> dict[str, Any]:
    """
    Evaluate the base/random CE on the training set of a refined WL CE.

    Given a **refined** CE key (ensure_ce_from_refined_wl), load the **base/random**
    CE it was built from. The random CE's ECIs are then used to predict the refined
    CE's training samples, and per-site error metrics are computed. If ``verbose``,
    a comparison against the refined CE's stored 5-fold CV is printed.

    Args:
        refined_ce_key: key of a CE whose sources contain a 'wl_refined_intent' entry.
        atol, rtol: currently unused by the computation; kept for signature
            compatibility.
        verbose: print the comparison table when True.

    Returns:
        dict with keys:
          - 'random_on_refined': ``compute_stats`` result for the random CE on the
            refined training set (per-site metrics)
          - 'refined_cv'       : ``stats['five_fold_cv']`` of the refined CE, or None
            if absent
          - 'n'                : number of refined training samples evaluated
          - 'base_ce_key'      : the random CE key
          - 'refined_ce_key'   : the input refined CE key

    Raises:
        RuntimeError: if either CE or any FF task is missing, if the CE is not a
            refined WL CE, if it has no training refs, if an FF task lacks an
            energy, or if the site count does not divide evenly per cell.
    """
    # ---- load refined CE and the base/random CE it points to ----
    ref_doc = lookup_ce_by_key(refined_ce_key)
    if not ref_doc:
        raise RuntimeError(f"No CE found for refined_ce_key={refined_ce_key}")

    sources_ref = cast(Sequence[Mapping[str, Any]], ref_doc["sources"])
    if _source_kind(sources_ref) != "wl_refined_intent":
        raise RuntimeError("The provided CE is not a refined WL CE (expected 'wl_refined_intent' source).")

    base_ce_key = _get_base_ce_key_for_refined(sources_ref)
    base_doc = lookup_ce_by_key(base_ce_key)
    if not base_doc:
        raise RuntimeError(f"Base CE not found for base_ce_key={base_ce_key}")

    # ---- geometry + FF metadata, read from the refined CE ----
    # NOTE(review): assumed identical across base/refined CEs; this is not checked here.
    prototype = PrototypeName(ref_doc["prototype"])
    proto_params = cast(Mapping[str, Any], ref_doc["prototype_params"])
    sx, sy, sz = (int(x) for x in ref_doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)
    model = cast(str, ref_doc["model"])
    relax_cell = bool(ref_doc["relax_cell"])
    dtype = cast(str, ref_doc["dtype"])
    train_refs = cast(Sequence[Mapping[str, Any]], ref_doc["train_refs"])
    if not train_refs:
        raise RuntimeError(f"Refined CE {refined_ce_key} has no train_refs to evaluate.")

    # Sublattices come from the base/random CE mixtures.
    base_mixtures = _gather_composition_mixtures(cast(Sequence[Mapping[str, Any]], base_doc["sources"]))
    if not base_mixtures:
        raise RuntimeError("Base CE sources did not contain 'composition' mixtures.")
    sublattices = sublattices_from_mixtures(base_mixtures)

    # ---- refined training structures + supercell energies (order = train_refs) ----
    structures: list[Structure] = []
    energies_supercell: list[float] = []
    for ref in train_refs:
        sid = str(ref["set_id"])
        ok = str(ref["occ_key"])
        ff_doc = lookup_ff_task(set_id=sid, occ_key=ok, model=model, relax_cell=relax_cell, dtype=dtype)
        if ff_doc is None:
            raise RuntimeError(f"Missing FF task for set_id={sid} occ_key={ok}")
        structures.append(_structure_from_ff_doc(ff_doc, base_ce_key=base_ce_key))
        # Same error contract as _lookup_total_energy_eV: a descriptive
        # RuntimeError instead of a bare KeyError/TypeError.
        try:
            energies_supercell.append(float(ff_doc["output"]["energy"]))
        except (KeyError, TypeError, ValueError) as exc:
            raise RuntimeError(
                f"FF task output missing energy for set_id={sid} occ_key={ok}"
            ) from exc

    # ---- featurize refined structures with the **base/random** CE subspace ----
    base_ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], base_doc["payload"]))
    _, X = featurize_structures(
        subspace=base_ce.cluster_subspace,
        structures=structures,
        supercell_diag=supercell_diag,
    )

    # Targets: supercell energy divided by the number of prototype cells in the
    # supercell (the product of supercell_diag).
    n_prims = int(np.prod(np.asarray(supercell_diag, dtype=int)))
    y_true_per_prim: NDArray[np.float64] = np.asarray(energies_supercell, dtype=np.float64) / float(n_prims)

    # Per-site scaling: number of replaceable sites per prototype cell.
    sites_per_supercell = _n_replace_sites_from_prototype(
        prototype=prototype,
        prototype_params=proto_params,
        supercell_diag=supercell_diag,
        sublattices=sublattices,
    )
    if sites_per_supercell % n_prims != 0:
        raise RuntimeError("sites_per_supercell is not divisible by n_prims.")
    sites_per_prim = sites_per_supercell // n_prims
    scale_site = 1.0 / float(sites_per_prim)

    # ---- predict with **base/random** CE ECIs and score per site ----
    coefs = np.asarray(base_ce.coefs, dtype=np.float64)
    y_pred_per_prim = np.asarray(predict_from_features(X, coefs), dtype=np.float64)

    random_on_refined = compute_stats(
        (y_true_per_prim * scale_site).tolist(),
        (y_pred_per_prim * scale_site).tolist(),
    )
    n_eval = int(random_on_refined["n"])

    # Missing 'stats' (or a non-mapping value) means no stored CV, not an error.
    stats = ref_doc.get("stats")
    refined_cv: Mapping[str, Any] | None = (
        stats.get("five_fold_cv") if isinstance(stats, Mapping) else None
    )

    if verbose:
        _print_comparison_summary(
            refined_ce_key=refined_ce_key,
            base_ce_key=base_ce_key,
            n=n_eval,
            random_on_refined=random_on_refined,
            refined_cv=refined_cv,
        )

    return {
        "random_on_refined": random_on_refined,
        "refined_cv": refined_cv,
        "n": n_eval,
        "base_ce_key": base_ce_key,
        "refined_ce_key": refined_ce_key,
    }


# ---- Example usage (guarded so importing this module does not hit the database;
# ---- still runs when executed as a script or as a notebook cell):
if __name__ == "__main__":
    res = score_random_ce_on_refined_training("76195c946ffef01b43a45fbccfb192c96d55011b69a54b0ccc2d6ff9f6f59bba")
