In [1]:
from typing import Any, Mapping, Sequence, cast
import numpy as np

from pymatgen.core import Structure
from smol.cofe import ClusterExpansion

# ---- phaseedge imports (modern-only) ----
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.jobs.train_ce import (
    featurize_structures,
    lookup_train_refs_by_key,
    predict_from_features,
    compute_stats,
    _n_replace_sites_from_prototype,
)
from phaseedge.schemas.mixture import Mixture, sublattices_from_mixtures
from phaseedge.storage.store import lookup_total_energy_eV
from phaseedge.utils.keys import occ_key_for_structure


# ----------------- small helpers -----------------

def fmt_mev(x: float | str) -> str:
    if isinstance(x, str):
        return x
    return f"{1e3 * float(x):.3f}"


def _gather_composition_mixtures(sources: Sequence[Mapping[str, Any]]) -> list[Mixture]:
    """
    Collect every mixture listed under the "composition"-typed sources.

    Sources whose ``type`` (compared case-insensitively) is not "composition"
    are skipped. Each mixture entry is rebuilt through ``Mixture.from_dict``
    from its ``composition_map``, ``K`` and ``seed`` fields.

    Raises:
        RuntimeError: if no mixture was collected at all.
    """
    collected: list[Mixture] = []
    for source in sources:
        is_composition = str(source.get("type", "")).lower() == "composition"
        if is_composition:
            collected.extend(
                Mixture.from_dict(
                    {
                        "composition_map": dict(raw["composition_map"]),
                        "K": int(raw["K"]),
                        "seed": int(raw["seed"]),
                    }
                )
                for raw in source.get("mixtures", [])
            )
    if collected:
        return collected
    raise RuntimeError("Expected 'composition' mixtures in sources (modern schema).")


def _sublattices_for_doc(doc: Mapping[str, Any]) -> dict[str, tuple[str, ...]]:
    """
    Resolve the replaceable sublattices used for per-site scaling.

    The ``type`` of the FIRST entry in ``doc["sources"]`` decides where the
    mixtures are taken from:
      - "composition"          : the mixtures in this doc's own sources
      - "dopt_sampling_intent" : the mixtures of the base CE referenced by
                                 that source's ``base_ce_key``

    Raises:
        RuntimeError: if the sources list is empty, the referenced base CE
            cannot be looked up, or the source type is neither of the above.
    """
    doc_sources = cast(Sequence[Mapping[str, Any]], doc["sources"])
    if not doc_sources:
        raise RuntimeError("CE document missing 'sources'.")

    head = doc_sources[0]
    kind = str(head.get("type", "")).lower()
    if kind not in ("composition", "dopt_sampling_intent"):
        raise RuntimeError(f"Unsupported source type in CE doc: {kind!r}")

    if kind == "composition":
        mixture_sources = doc_sources
    else:
        # Refined docs carry no mixtures themselves; follow the link to the base CE.
        parent_key = cast(str, head["base_ce_key"])
        parent_doc = lookup_ce_by_key(parent_key)
        if not parent_doc:
            raise RuntimeError(f"Base CE not found for dopt_sampling_intent source: {parent_key}")
        mixture_sources = cast(Sequence[Mapping[str, Any]], parent_doc["sources"])

    return sublattices_from_mixtures(_gather_composition_mixtures(mixture_sources))


def _fmt_f(x: float) -> str:
    """Format a number compactly with six significant digits (``.6g``)."""
    as_float = float(x)
    return format(as_float, ".6g")


def _grab(dm: Mapping[str, Any] | None, key: str, default: Any) -> Any:
    return default if dm is None else dm.get(key, default)


def _print_two_col_table(title: str, col_a_name: str, col_b_name: str, rows: list[tuple[str, str, str]]) -> None:
    """
    Print a side-by-side ASCII table to stdout.

    Output layout: the title centred in a rule of "─", a header line
    ("Metric" plus the two column names), a separator rule, one line per
    row, and a trailing blank line.

    Args:
        title: Text centred in the top rule (printed as-is when it is wider
            than the table).
        col_a_name: Header of the first value column.
        col_b_name: Header of the second value column.
        rows: ``(metric_name, value_a, value_b)`` string triples. May be
            empty, in which case only the title, header and rule are printed.
    """
    # Treat the header as the first table line so that every column width is
    # well-defined even when `rows` is empty.
    table = [("Metric", col_a_name, col_b_name), *rows]
    w_name, w_a, w_b = (max(len(cells[col]) for cells in table) for col in range(3))

    # A rendered line is " name │ a │ b ": one space of outer padding on each
    # side plus two three-character " │ " separators. Using that exact width
    # keeps the rules flush with the rows.
    total = 1 + w_name + 3 + w_a + 3 + w_b + 1

    def render(name: str, a: str, b: str) -> str:
        return f" {name.ljust(w_name)} │ {a.ljust(w_a)} │ {b.ljust(w_b)} "

    print(title.center(total, "─"))
    print(render(*table[0]))
    print("─" * total)
    for name, a, b in table[1:]:
        print(render(name, a, b))
    print("")


def _cmp(a: float, b: float, *, atol: float, rtol: float) -> int:
    """
    Three-way comparison of ``a`` against ``b`` with a tolerance band.

    The band half-width is ``max(atol, rtol * abs(b))``.

    Returns:
        -1 when ``a`` lies below the band around ``b``,
         1 when ``a`` lies above it,
         0 otherwise (the two values are indistinguishable within tolerance).
    """
    tol = max(atol, rtol * abs(b))
    is_below = a < b - tol
    is_above = a > b + tol
    # "below" is tested first, mirroring a plain if/if/return chain.
    return -1 if is_below else (1 if is_above else 0)


def _better_text(
    label_left: str,
    label_right: str,
    *,
    mae_left: float,
    rmse_left: float,
    mae_right: float,
    rmse_right: float,
    atol: float,
    rtol: float,
) -> str:
    """
    Build a one-line verdict comparing two (MAE, RMSE) pairs.

    Each metric is compared left-vs-right with ``_cmp`` under the given
    tolerances. The outcomes are:
      - both metrics lower on one side  -> "Better: <that side's label> ..."
      - both metrics within tolerance   -> "Comparable within tolerance ..."
      - anything else (the metrics disagree, or only one of them separates
        the two sides)                  -> "No clear winner: ..." listing
        each metric that does differ and which label is lower on it.

    Differences are reported as absolute values formatted by ``fmt_mev``.
    """
    mae_sign = _cmp(mae_left, mae_right, atol=atol, rtol=rtol)
    rmse_sign = _cmp(rmse_left, rmse_right, atol=atol, rtol=rtol)

    gap_mae = fmt_mev(abs(mae_left - mae_right))
    gap_rmse = fmt_mev(abs(rmse_left - rmse_right))

    if mae_sign == 0 and rmse_sign == 0:
        return (
            "Comparable within tolerance "
            f"(ΔMAE≈{gap_mae} meV/site, ΔRMSE≈{gap_rmse} meV/site)."
        )

    if mae_sign == rmse_sign:
        # Both metrics point the same way: -1 means the left side is lower.
        winner = label_left if mae_sign == -1 else label_right
        return (
            f"Better: {winner} "
            f"(MAE ↓ by {gap_mae} meV/site, RMSE ↓ by {gap_rmse} meV/site)."
        )

    # Remaining cases: report every metric that separates the two sides.
    lower_label = {-1: label_left, 1: label_right}
    per_metric = (("MAE", mae_sign, gap_mae), ("RMSE", rmse_sign, gap_rmse))
    notes = [
        f"{lower_label[sign]} has lower {metric} by {gap} meV/site"
        for metric, sign, gap in per_metric
        if sign != 0
    ]
    return "No clear winner: " + "; ".join(notes) + "."


def _load_structures_and_energies(
    *,
    dataset_key: str,
    model: str,
    relax_cell: bool,
) -> tuple[list[Structure], list[float]]:
    """
    Fetch the structures of a dataset together with their stored total energies.

    The dataset's train refs are looked up by ``dataset_key``; for each ref the
    occupancy key is rebuilt from the structure and checked against the stored
    one before the energy is looked up for the given ``model`` / ``relax_cell``.

    Returns:
        ``(structures, energies)`` as two parallel lists in train-ref order.
        Energies are the values returned by ``lookup_total_energy_eV``.

    Raises:
        RuntimeError: if the dataset has no train refs, an occupancy key does
            not round-trip, or an energy is missing from the store.
    """
    refs = lookup_train_refs_by_key(dataset_key)
    if not refs:
        raise RuntimeError(f"Dataset {dataset_key} missing 'train_refs'.")

    loaded_structures: list[Structure] = []
    total_energies: list[float] = []

    for idx, ref in enumerate(refs):
        structure = ref["structure"]

        # Integrity check: the stored key must match one rebuilt from the structure.
        rebuilt_key = occ_key_for_structure(structure)
        stored_key = cast(str, ref["occ_key"])
        if rebuilt_key != stored_key:
            raise RuntimeError(
                f"train_refs[{idx}] occ_key mismatch: expected {stored_key}, rebuilt {rebuilt_key}."
            )

        set_id = ref["set_id"]
        energy = lookup_total_energy_eV(
            set_id=cast(str, set_id),
            occ_key=stored_key,
            model=model,
            relax_cell=relax_cell,
        )
        if energy is None:
            raise RuntimeError(
                f"Energy not found in store for train_refs[{idx}] set_id={set_id}, occ_key={stored_key[:12]}..."
            )

        loaded_structures.append(structure)
        total_energies.append(energy)

    return loaded_structures, total_energies


def _predict_per_prim(
    ce: ClusterExpansion,
    *,
    structures: list[Structure],
    supercell_diag: tuple[int, int, int],
) -> np.ndarray:
    """
    Predict energies for ``structures`` with a fitted cluster expansion.

    The structures are featurized against the CE's own cluster subspace and
    the resulting feature matrix is combined with the CE coefficients (as
    float64) through ``predict_from_features``.
    """
    _unused, feature_matrix = featurize_structures(
        subspace=ce.cluster_subspace,
        structures=structures,
        supercell_diag=supercell_diag,
    )
    weights = np.asarray(ce.coefs, dtype=np.float64)
    return predict_from_features(feature_matrix, weights)


# ----------------- main comparison -----------------

def _supercell_diag_from_doc(doc: Mapping[str, Any]) -> tuple[int, int, int]:
    """Read ``doc["supercell_diag"]`` as a tuple of three ints."""
    nx, ny, nz = doc["supercell_diag"]
    return (int(nx), int(ny), int(nz))


def _stored_cv_stats(doc: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the 5-fold CV stats stored on a CE doc (``doc["stats"]["five_fold_cv"]``)."""
    stats = cast(Mapping[str, Any], doc["stats"])
    return cast(Mapping[str, Any], stats["five_fold_cv"])


def _evaluate_ce_on_dataset(
    ce_doc: Mapping[str, Any],
    dataset_doc: Mapping[str, Any],
    *,
    supercell_diag: tuple[int, int, int],
    n_prims: int,
    scale_site: float,
) -> Any:
    """Score the CE stored in ``ce_doc`` on the dataset referenced by ``dataset_doc`` (per-site stats)."""
    # The dataset is loaded with the model / relax_cell settings of the doc that owns it.
    structures, energies_super = _load_structures_and_energies(
        dataset_key=cast(str, dataset_doc["dataset_key"]),
        model=cast(str, dataset_doc["model"]),
        relax_cell=bool(dataset_doc["relax_cell"]),
    )
    # Supercell totals -> per-primitive-cell energies.
    y_true_per_prim = np.asarray(energies_super, dtype=np.float64) / float(n_prims)

    ce = ClusterExpansion.from_dict(cast(Mapping[str, Any], ce_doc["payload"]))
    y_pred_per_prim = _predict_per_prim(
        ce,
        structures=structures,
        supercell_diag=supercell_diag,
    )
    # Both vectors are rescaled to per-site before computing the error stats.
    return compute_stats(
        (y_true_per_prim * scale_site).tolist(),
        (y_pred_per_prim * scale_site).tolist(),
    )


def _report_section(
    *,
    title: str,
    verdict_label: str,
    left_label: str,
    right_label: str,
    left_stats: Mapping[str, Any],
    right_stats: Mapping[str, Any],
    atol: float,
    rtol: float,
) -> None:
    """Print one left-vs-right performance table followed by its verdict line."""
    rows: list[tuple[str, str, str]] = [
        ("n", str(int(left_stats["n"])), str(int(right_stats["n"]))),
    ]
    rows.extend(
        (name, fmt_mev(float(left_stats[key])), fmt_mev(float(right_stats[key])))
        for name, key in (
            ("MAE (meV/site)", "mae_per_site"),
            ("RMSE (meV/site)", "rmse_per_site"),
            ("Max|err| (meV/site)", "max_abs_per_site"),
        )
    )
    _print_two_col_table(
        title=title,
        col_a_name=left_label,
        col_b_name=right_label,
        rows=rows,
    )
    verdict = _better_text(
        left_label, right_label,
        mae_left=float(left_stats["mae_per_site"]),
        rmse_left=float(left_stats["rmse_per_site"]),
        mae_right=float(right_stats["mae_per_site"]),
        rmse_right=float(right_stats["rmse_per_site"]),
        atol=atol, rtol=rtol,
    )
    print(f"{verdict_label:<15}: {verdict}")


def compare_comp_vs_refined_on_refined_dataset(
    refined_ce_key: str,
    *,
    atol: float = 1e-6,
    rtol: float = 1e-8,
    verbose: bool = True,
) -> dict[str, Any]:
    """
    Cross-compare a refined CE with the base composition CE it was derived from.

    Despite the name, both directions are evaluated: the base CE is scored on
    the refined CE's dataset, and the refined CE is scored on the base CE's
    dataset. Each result is set against the other model's stored 5-fold CV stats.

    Args:
        refined_ce_key: Key of a CE doc whose first source has type
            "dopt_sampling_intent"; that source's ``base_ce_key`` names the
            base composition CE.
        atol: Absolute tolerance used when deciding which side is better.
        rtol: Relative tolerance used when deciding which side is better.
        verbose: When True, print the report described below.

    What it prints (per-site, meV), refined is always the RIGHT column:
      1) Base CE evaluated on the *refined* dataset  (left)  vs  Refined CE stored 5-fold CV  (right)
      2) Base CE stored 5-fold CV                     (left)  vs  Refined CE evaluated on the *base* dataset (right)

    Returns (dict):
      - refined_ce_key, base_ce_key, n_prims, sites_per_prim
      - base_on_refined_stats_per_site
      - refined_cv_stats_per_site
      - refined_on_base_stats_per_site
      - base_cv_stats_per_site
      - design_metrics: {base, refined}  (each doc's "design_metrics", or None if absent)

    Raises:
        RuntimeError: if either CE doc cannot be found, the refined doc is not
            a dopt_sampling_intent CE, the two docs disagree on prototype,
            prototype_params or supercell_diag, the supercell is not positive,
            or the per-site scaling cannot be derived (no replaceable sites, or
            a site count not divisible by the number of primitive cells).
    """
    # --- Load refined CE doc ---
    ref_doc = lookup_ce_by_key(refined_ce_key)
    if not ref_doc:
        raise RuntimeError(f"No CE found for refined_ce_key={refined_ce_key}")

    sources = cast(Sequence[Mapping[str, Any]], ref_doc["sources"])
    if not sources:
        raise RuntimeError("Refined CE doc has no sources.")
    s0 = sources[0]
    s_type = str(s0.get("type", "")).lower()
    if s_type != "dopt_sampling_intent":
        raise RuntimeError(f"Provided CE is not a dopt_sampling_intent CE (source.type={s_type!r}).")

    base_ce_key = cast(str, s0["base_ce_key"])
    base_doc = lookup_ce_by_key(base_ce_key)
    if not base_doc:
        raise RuntimeError(f"Base (composition) CE not found: {base_ce_key}")

    # --- Basic identity checks (should match in the modern pipeline) ---
    r_proto = ref_doc["prototype"]
    b_proto = base_doc["prototype"]
    if r_proto != b_proto:
        raise RuntimeError("Prototype mismatch between refined CE and base composition CE.")
    r_pp = cast(Mapping[str, Any], ref_doc["prototype_params"])
    b_pp = cast(Mapping[str, Any], base_doc["prototype_params"])
    if dict(r_pp) != dict(b_pp):
        raise RuntimeError("prototype_params mismatch between refined CE and base composition CE.")

    r_sc = _supercell_diag_from_doc(ref_doc)
    b_sc = _supercell_diag_from_doc(base_doc)
    if r_sc != b_sc:
        raise RuntimeError("supercell_diag mismatch between refined CE and base composition CE.")
    if min(r_sc) < 1:
        # A zero or negative entry would make the per-primitive division below meaningless.
        raise RuntimeError(f"supercell_diag must have positive entries, got {r_sc}.")

    # --- Per-site scaling (replaceable sites / primitive) ---
    sublattices = _sublattices_for_doc(base_doc)  # composition-ground truth
    n_prims = int(np.prod(np.asarray(r_sc, dtype=int)))
    n_sites_const = _n_replace_sites_from_prototype(
        prototype=r_proto,
        prototype_params=r_pp,
        supercell_diag=r_sc,
        sublattices=sublattices,
    )
    if n_sites_const % n_prims != 0:
        raise RuntimeError("sites_per_supercell not divisible by n_prims (prototype/sublattice mismatch).")
    sites_per_prim = n_sites_const // n_prims
    if sites_per_prim <= 0:
        # Without this guard the scale factor below fails with a bare ZeroDivisionError.
        raise RuntimeError(
            f"No replaceable sites per primitive cell (sites_per_supercell={n_sites_const}); "
            "cannot compute per-site scaling."
        )
    scale_site = 1.0 / float(sites_per_prim)

    # --- Evaluate BASE CE on the refined dataset ---
    base_on_refined_stats = _evaluate_ce_on_dataset(
        base_doc,
        ref_doc,
        supercell_diag=r_sc,
        n_prims=n_prims,
        scale_site=scale_site,
    )  # per-site

    # --- Pull refined CE stored 5-fold CV stats (per-site) ---
    ref_cv = _stored_cv_stats(ref_doc)
    # also harvest design metrics (stored at training time)
    dm_base = cast(Mapping[str, Any] | None, base_doc.get("design_metrics"))
    dm_ref = cast(Mapping[str, Any] | None, ref_doc.get("design_metrics"))

    # --- Evaluate REFINED CE on the base dataset ---
    refined_on_base_stats = _evaluate_ce_on_dataset(
        ref_doc,
        base_doc,
        supercell_diag=b_sc,  # == r_sc
        n_prims=n_prims,
        scale_site=scale_site,
    )  # per-site

    # --- Base CE stored 5-fold CV (per-site) ---
    base_cv = _stored_cv_stats(base_doc)

    # --- Pretty side-by-side printout (refined is ALWAYS right column) ---
    if verbose:
        print("═" * 86)
        print(" CE Comparison ".center(86, "═"))
        print("═" * 86)
        print(f"Refined CE key     : {refined_ce_key}")
        print(f"Base (composition) : {base_ce_key}")
        print(f"Prototype          : {r_proto}")
        print(f"Supercell diag     : {r_sc}  (n_prims={n_prims}, sites/prim={sites_per_prim})")
        print("")

        # 1) Performance: Base CE on refined dataset (left) vs Refined CE stored CV (right)
        _report_section(
            title="Performance (per-site, meV) — Base CE evaluated on refined dataset  vs  Refined CE stored 5-fold CV",
            verdict_label="Verdict A",
            left_label="Base CE (OOS on refined set)",
            right_label="Refined CE (5-fold CV, stored)",
            left_stats=base_on_refined_stats,
            right_stats=ref_cv,
            atol=atol,
            rtol=rtol,
        )

        # 2) Performance: Base CE stored CV (left) vs Refined CE on base dataset (right)
        _report_section(
            title="Performance (per-site, meV) — Base CE stored 5-fold CV  vs  Refined CE evaluated on base dataset",
            verdict_label="Verdict B",
            left_label="Base CE (5-fold CV, stored)",
            right_label="Refined CE (OOS on base set)",
            left_stats=base_cv,
            right_stats=refined_on_base_stats,
            atol=atol,
            rtol=rtol,
        )
        print("═" * 86)

    return {
        "refined_ce_key": refined_ce_key,
        "base_ce_key": base_ce_key,
        "n_prims": n_prims,
        "sites_per_prim": sites_per_prim,
        "base_on_refined_stats_per_site": base_on_refined_stats,
        "refined_cv_stats_per_site": ref_cv,
        "refined_on_base_stats_per_site": refined_on_base_stats,
        "base_cv_stats_per_site": base_cv,
        "design_metrics": {
            "base": dm_base,
            "refined": dm_ref,
        },
    }


# ---- Example: run the comparison for one refined CE key.
# Being the last expression of the cell, the returned dict is also displayed.
# The trailing "# 500" presumably records the refined dataset size (the
# recorded output reports n=500 for it) — TODO confirm.
compare_comp_vs_refined_on_refined_dataset("0c5b7a6930102e3ee48ae6f28aec48bcba181402c2acde5ba77704c97db2b177") # 500


  from .autonotebook import tqdm as notebook_tqdm


══════════════════════════════════════════════════════════════════════════════════════
═══════════════════════════════════ CE Comparison ════════════════════════════════════
══════════════════════════════════════════════════════════════════════════════════════
Refined CE key     : 0c5b7a6930102e3ee48ae6f28aec48bcba181402c2acde5ba77704c97db2b177
Base (composition) : 11a657d3fc1372898856817e9576d01355520acdf466f33b572359197e6f7426
Prototype          : spinel16c_Q0Cl
Supercell diag     : (2, 2, 2)  (n_prims=8, sites/prim=10)

Performance (per-site, meV) — Base CE evaluated on refined dataset  vs  Refined CE stored 5-fold CV
 Metric              │ Base CE (OOS on refined set) │ Refined CE (5-fold CV, stored) 
─────────────────────────────────────────────────────────────────────────────────────────
 n                   │ 500                          │ 500                            
 MAE (meV/site)      │ 22.039                       │ 6.963                          
 RMSE (meV/site)     │ 

{'refined_ce_key': '0c5b7a6930102e3ee48ae6f28aec48bcba181402c2acde5ba77704c97db2b177',
 'base_ce_key': '11a657d3fc1372898856817e9576d01355520acdf466f33b572359197e6f7426',
 'n_prims': 8,
 'sites_per_prim': 10,
 'base_on_refined_stats_per_site': {'n': 500,
  'mae_per_site': 0.022038559109991338,
  'rmse_per_site': 0.029465411421590346,
  'max_abs_per_site': 0.0805451287307335},
 'refined_cv_stats_per_site': {'n': 500,
  'mae_per_site': 0.0069634943471185515,
  'rmse_per_site': 0.008796481862236702,
  'max_abs_per_site': 0.026566101216257465},
 'refined_on_base_stats_per_site': {'n': 102,
  'mae_per_site': 0.0062258414841970585,
  'rmse_per_site': 0.007712783391994033,
  'max_abs_per_site': 0.02041815321230711},
 'base_cv_stats_per_site': {'n': 102,
  'mae_per_site': 0.008026480751031517,
  'rmse_per_site': 0.010341332573649442,
  'max_abs_per_site': 0.03948092140349413},
 'design_metrics': {'base': {'n_samples': 102,
   'n_features': 1395,
   'rank': 43,
   'sigma_max': 277.6701464908878