In [None]:
from collections import defaultdict
from itertools import permutations, product
from typing import Any, Mapping, Sequence, cast
from pathlib import Path
import json
import shutil
import sys
import importlib.metadata as im
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor

from phaseedge.export.garnet_ce import garnet_ce
from phaseedge.export.spinel_ce import spinel_ce
from phaseedge.export.double_perovskite_ce import double_perovskite_ce
from phaseedge.jobs.store_ce_model import lookup_ce_by_key
from phaseedge.schemas.calc_spec import CalcSpec
from phaseedge.science.prototype_spec import PrototypeSpec
from phaseedge.storage.cetrainref_dataset import Dataset
from phaseedge.storage.store import lookup_total_energy_eV


def compute_composition_map(structure: Structure) -> dict[str, dict[str, int]]:
    """
    Tally element counts per sublattice over the active-cation sites.

    Only sites whose 'role' site property equals "active_cation" contribute.
    Outer keys are the stringified 'sublattice' labels; each inner dict maps
    an element symbol to the number of such sites carrying it. Both levels
    are emitted in sorted key order so the result is deterministic.

    Example output
    --------------
    {
        "A": {"Mg": 16},
        "B": {"Al": 32},
    }

    Raises
    ------
    ValueError
        If the 'role' or 'sublattice' site property is missing, if either
        has a length different from the number of sites, if an active-cation
        site has a ``None`` sublattice label, or if such a site is disordered.
    """
    props = structure.site_properties

    if "role" not in props:
        raise ValueError("Structure is missing 'role' site property.")
    if "sublattice" not in props:
        raise ValueError("Structure is missing 'sublattice' site property.")

    roles = props["role"]
    sublattices = props["sublattice"]

    n_sites = len(structure)
    if not (len(roles) == n_sites and len(sublattices) == n_sites):
        raise ValueError(
            "Site property lengths must match number of sites: "
            f"len(roles)={len(roles)}, len(sublattice)={len(sublattices)}, "
            f"n_sites={n_sites}."
        )

    # tallies[sublattice][element] = number of active-cation sites
    tallies: dict[str, dict[str, int]] = {}

    # Lengths were verified above, so zipping walks every site exactly once.
    for i, (site, role, sublattice) in enumerate(zip(structure, roles, sublattices)):
        if role != "active_cation":
            continue

        if sublattice is None:
            raise ValueError(f"Site {i} has role='active_cation' but no sublattice label.")

        if not site.is_ordered:
            raise ValueError(
                f"Site {i} on sublattice {sublattice!r} is disordered; "
                "compute_composition_map requires ordered sites."
            )

        symbol = site.specie.symbol  # type: ignore[attr-defined]
        per_element = tallies.setdefault(str(sublattice), {})
        per_element[symbol] = per_element.get(symbol, 0) + 1

    # Rebuild as plain dicts with keys inserted in sorted order.
    ordered: dict[str, dict[str, int]] = {}
    for sub in sorted(tallies):
        ordered[sub] = dict(sorted(tallies[sub].items()))
    return ordered


def strip_site_property(struct: Structure, name: str) -> Structure:
    """
    Return `struct` without the site property `name`.

    When the property is present, a copy is made and the property is removed
    from the copy, leaving `struct` untouched. When it is absent, `struct`
    itself is returned (no copy is made).
    """
    if name in struct.site_properties:
        stripped = struct.copy()
        stripped.remove_site_property(name)
        return stripped
    return struct


def _match_sites_by_frac_coords(
    disordered: Structure,
    default: Structure,
    frac_tol: float,
) -> list[int]:
    """
    Map each site of `disordered` to a distinct site of `default`.

    Sites are paired greedily in index order: a disordered site takes the
    first not-yet-used default site whose fractional coordinates agree within
    `frac_tol` on every axis. Along periodic axes the difference is taken
    modulo one lattice vector, so 0.9999999 (or 1.0) matches 0.0.

    Returns a list `m` with `m[i_dis] == j_def`.

    Raises
    ------
    ValueError
        If some disordered site has no unused default site within tolerance.
    """
    # Older pymatgen lattices may not expose `pbc`; treat those as periodic
    # along all three axes.
    pbc = getattr(getattr(default, "lattice", None), "pbc", (True, True, True))

    used_default: set[int] = set()
    dis_to_def: list[int] = []

    for i_dis, site_dis in enumerate(disordered):
        fc_dis = site_dis.frac_coords
        matched_j: int | None = None

        for j_def, site_def in enumerate(default):
            if j_def in used_default:
                continue

            fc_def = site_def.frac_coords
            deltas: list[float] = []
            for axis in range(3):
                delta = abs(float(fc_dis[axis] - fc_def[axis]))
                if pbc[axis]:
                    # Distance to the nearest periodic image. The raw
                    # difference stays in the min() so that anything accepted
                    # without wrapping is still accepted bit-for-bit.
                    wrapped = delta % 1.0
                    delta = min(delta, wrapped, 1.0 - wrapped)
                deltas.append(delta)

            if max(deltas) <= frac_tol:
                matched_j = j_def
                break

        if matched_j is None:
            raise ValueError(
                "Could not find matching site in default structure for "
                f"disordered site {i_dis} with fractional coords "
                f"{fc_dis} (tol={frac_tol})."
            )

        dis_to_def.append(matched_j)
        used_default.add(matched_j)

    return dis_to_def


def ensure_structure_properties(
    disordered: Structure,
    default: Structure,
    *,
    frac_tol: float = 1e-6,
) -> Structure:
    """
    Ensure that `disordered` has the same per-site properties as `default`,
    allowing for scrambled site ordering as long as a one-to-one mapping
    can be made via fractional coordinates.

    Steps
    -----
    1. Check that both structures have the same number of sites.
    2. Build a bijection disordered_index -> default_index by matching
       fractional coordinates within `frac_tol`. Along periodic lattice
       axes the comparison is made modulo 1, so a coordinate stored as
       1 - eps matches a reference coordinate of 0.
    3. For each site property present on `default`:
       - If missing on `disordered`, copy it over using the mapping.
       - If present on `disordered`, require that all values match the
         mapped default values exactly; otherwise raise.

    Any additional site_properties on `disordered` that do not exist on
    `default` are left untouched.

    Parameters
    ----------
    disordered : Structure
        Structure to annotate; it is not mutated.
    default : Structure
        Reference structure carrying the canonical site properties.
    frac_tol : float
        Maximum per-axis fractional-coordinate difference for two sites to
        be considered the same.

    Returns
    -------
    Structure
        A copy of `disordered` with site_properties guaranteed to be
        consistent with `default` under the coordinate mapping.

    Raises
    ------
    ValueError
        On a site-count mismatch, an unmatched site, a site property of the
        wrong length, or a property value that disagrees with `default`.
    """
    n_dis = len(disordered)
    n_def = len(default)

    if n_dis != n_def:
        raise ValueError(
            "Structure site count mismatch: "
            f"disordered has {n_dis} sites, default has {n_def} sites."
        )

    n_sites = n_dis

    dis_to_def = _match_sites_by_frac_coords(disordered, default, frac_tol)

    if len(set(dis_to_def)) != n_sites:
        # Cannot happen when every disordered site found a distinct partner,
        # but kept as a sanity check on the mapping.
        raise ValueError(
            "Failed to construct a one-to-one mapping between disordered "
            "and default sites; some default sites were not matched."
        )

    # Work on a copy so we don't mutate the original disordered structure.
    result = disordered.copy()

    default_props = default.site_properties
    result_props = result.site_properties

    for name, default_vals in default_props.items():
        if len(default_vals) != n_sites:
            raise ValueError(
                f"Default site property '{name}' has length "
                f"{len(default_vals)}, expected {n_sites}."
            )

        if name not in result_props:
            # Property missing on disordered: copy values according to mapping.
            mapped_vals: list[Any] = [default_vals[j_def] for j_def in dis_to_def]
            result.add_site_property(name, mapped_vals)
            continue

        # Property exists on both: require exact match under mapping.
        existing_vals = result_props[name]
        if len(existing_vals) != n_sites:
            raise ValueError(
                f"Disordered site property '{name}' has length "
                f"{len(existing_vals)}, expected {n_sites}."
            )

        for i_dis, j_def in enumerate(dis_to_def):
            expected = default_vals[j_def]
            actual = existing_vals[i_dis]
            if actual != expected:
                raise ValueError(
                    f"Site property '{name}' mismatch at disordered site "
                    f"{i_dis}: existing={actual!r}, "
                    f"default-mapped={expected!r}."
                )

    return result


def _strip_kfs3_prefix(dir_name: str, anchor: str = "/kfs3/scratch/cbu/") -> Path:
    """
    Drop whatever precedes the first occurrence of `anchor` in `dir_name`.

    Example
    -------
    x3107c0s13b0n0.head.cm.kestrel.hpc.nrel.gov:/kfs3/scratch/cbu/block_.../launcher_...
    -> /kfs3/scratch/cbu/block_.../launcher_...

    Parameters
    ----------
    dir_name : str
        Raw dir_name from the energy_result.
    anchor : str
        Substring marking where the filesystem path begins.

    Returns
    -------
    Path
        The part of `dir_name` from the anchor onwards (anchor included).

    Raises
    ------
    RuntimeError
        If `anchor` does not occur in `dir_name`.
    """
    if anchor not in dir_name:
        raise RuntimeError(
            f"dir_name does not contain expected anchor {anchor!r}: {dir_name!r}"
        )
    start = dir_name.index(anchor)
    return Path(dir_name[start:])


def organize_results(ce_key: str, formula: str, convergence: float) -> None:
    """
    Validate all training-set results of a CE, then copy them into a new tree.

    Every check runs before anything is written, so a validation failure
    does not leave a half-populated target directory behind.

    Steps
    -----
    1. Load the CE document and its dataset, and rebuild the ordered
       prototype supercell whose site properties serve as the reference.
    2. Validation pass, for each train_ref:
       - Fetch the energy_result for (occ_key, calc_spec).
       - Require max_force_eV_per_A <= ref.calc_spec.max_force_eV_per_A.
       - Strip dir_name to /kfs3/scratch/cbu/... and require that directory
         exists.
       - Attach the prototype site properties to ref.structure, drop the
         "spacegroup_kinds" property and compute the composition map.
       If *any* check fails, raise (fail loudly).
    3. Create /scratch/cbu/mp/{prototype}/{formula}_{N}meVpA/, where
       N = round(1000 * convergence). Refuse to proceed if it already exists.
    4. Copy pass, for each validated entry:
       - Copy the calculation directory to {target_root}/{occ_key}/.
       - Write disordered_task_doc_metadata.json there with the keys
         ordered_task_id, prototype, prototype_params, composition_map,
         reference_structure, supercell_diag and versions.

    Parameters
    ----------
    ce_key : str
        Cluster-expansion key whose training-set results should be organized.
    formula : str
        Formula label; used in the target directory name and in the
        placeholder ordered_task_id ("PLACEHOLDER_{formula}").
    convergence : float
        Used only to label the target directory (presumably a force
        threshold in eV/Å, given the "meVpA" suffix). The convergence check
        itself uses each train_ref's calc_spec.max_force_eV_per_A.

    Raises
    ------
    RuntimeError
        If the CE is not found, an energy_result is missing or unconverged,
        a source directory is missing, or a target directory already exists.
    ValueError
        If a reference structure cannot be reconciled with the prototype
        site properties.
    """
    # ---- Load CE document ----
    doc = lookup_ce_by_key(ce_key)
    if not doc:
        raise RuntimeError(f"No CE found for ce_key={ce_key}")

    calc_spec_top = CalcSpec.from_dict(cast(Mapping[str, Any], doc["calc_spec"]))
    dataset_key = cast(str, doc["dataset_key"])
    dataset = Dataset.from_key(dataset_key)

    # supercell_diag: stored in the CE doc; ensure it's a 3-tuple of ints
    sx, sy, sz = (int(x) for x in doc["supercell_diag"])
    supercell_diag = (sx, sy, sz)

    # prototype + params: pull from prototype_spec dict in the CE doc
    prototype_spec_data = cast(Mapping[str, Any], doc.get("prototype_spec", {}))
    prototype = str(prototype_spec_data.get("prototype", ""))

    raw_params = prototype_spec_data.get("params", {})  # type: ignore[assignment]
    prototype_params: dict[str, float] = {
        str(k): float(v) for k, v in raw_params.items()
    }
    if "spinel" in prototype and "u" not in prototype_params:
        # Special case: add default u parameter for spinels if missing
        prototype_params["u"] = 0.36

    # Build the canonical ordered prototype supercell with metadata
    proto_spec = PrototypeSpec(
        prototype=prototype,
        params=prototype_params,
    )
    primitive_cell = proto_spec.primitive_cell  # ASE Atoms with metadata
    sc = primitive_cell.repeat(supercell_diag)
    default_structure = AseAtomsAdaptor.get_structure(sc)

    train_refs: Sequence[Any] = dataset.train_refs

    print(f"CE key      : {ce_key}")
    print(f"dataset_key : {dataset_key}")
    print(f"n train_refs: {len(train_refs)}")
    print(f"supercell_diag: {supercell_diag}")
    print(f"prototype      : {prototype}")
    print(f"prototype_params (numeric only): {prototype_params}")

    # ---- First pass: validate everything and collect what the copy needs ----
    # Each entry: (occ_key, source dir, annotated structure, composition map)
    collected: list[tuple[str, Path, Structure, dict[str, dict[str, int]]]] = []

    for i, ref in enumerate(train_refs):
        composition_map = ref.composition_map
        occ_key = ref.occ_key
        calc_spec = ref.calc_spec

        # Sanity: warn if calc_spec differs from CE-level spec, but still use ref.calc_spec
        if calc_spec != calc_spec_top:
            print(
                f"[{i:4d}] WARNING: calc_spec mismatch between CE doc and train_ref; "
                "using ref.calc_spec for lookup."
            )

        energy_result = lookup_total_energy_eV(
            occ_key=occ_key,
            calc_spec=calc_spec,
        )
        if energy_result is None:
            raise RuntimeError(
                f"[{i}] Missing energy_result for occ_key={occ_key}, "
                f"composition_map={composition_map}"
            )

        # Convergence check
        if energy_result.max_force_eV_per_A > calc_spec.max_force_eV_per_A:
            raise RuntimeError(
                f"[{i}] Unconverged energy_result for occ_key={occ_key}, "
                f"composition_map={composition_map}. "
                f"max|F|={energy_result.max_force_eV_per_A:.3f} eV/Å "
                f"> threshold={calc_spec.max_force_eV_per_A:.3f} eV/Å"
            )

        # Strip the host prefix to get /kfs3/scratch/cbu/...
        raw_dir_name = energy_result.dir_name
        stripped_path = _strip_kfs3_prefix(raw_dir_name)

        # Require that the directory exists
        if not stripped_path.is_dir():
            raise RuntimeError(
                f"[{i}] Directory does not exist on disk: {stripped_path} "
                f"(occ_key={occ_key}, composition_map={composition_map})"
            )

        # Ensure the structure has the canonical site_properties from the
        # prototype. Done here, before any copying, so a structure that
        # cannot be reconciled aborts the run while the disk is untouched.
        reference_structure = ensure_structure_properties(
            disordered=ref.structure,
            default=default_structure,
        )
        reference_structure = strip_site_property(reference_structure, "spacegroup_kinds")
        site_composition_map = compute_composition_map(reference_structure)

        collected.append(
            (occ_key, stripped_path, reference_structure, site_composition_map)
        )

    # Version metadata is identical for every entry; resolve it once, and
    # before creating directories so a missing package fails early.
    versions = {
        "python_version": sys.version.split()[0],
        "smol_version": im.version("smol"),
        "phase_edge_version": "1.0.1",
        "pymatgen_version": im.version("pymatgen"),
        "ase_version": im.version("ase"),
        "numpy_version": im.version("numpy"),
        "pre_relax_model": "TensorNet-MatPES-r2SCAN-v2025.1-PES",
    }

    print("All energy_results are converged and all directories exist. Proceeding to copy.")

    # ---- Prepare target root: /scratch/cbu/mp/{prototype}/{formula}_{N}meVpA ----
    # round() rather than int(): 1000 * convergence can land just below the
    # intended integer in floating point, and truncation would then produce
    # a label that is one too small.
    threshold_label = int(round(1000 * convergence))
    target_root = Path("/scratch/cbu/mp") / prototype / f"{formula}_{threshold_label}meVpA"
    if target_root.exists():
        raise RuntimeError(
            f"Target directory already exists: {target_root}. "
            "Refusing to overwrite; please choose a different name or clean up manually."
        )

    target_root.mkdir(parents=True, exist_ok=False)
    print(f"Created target root: {target_root}")

    # ---- Second pass: copy into {target_root}/{occ_key}/ and write JSON ----
    for i, (occ_key, src_dir, reference_structure, site_composition_map) in enumerate(
        collected
    ):
        dest_dir = target_root / occ_key

        if dest_dir.exists():
            raise RuntimeError(
                f"[{i}] Destination directory already exists: {dest_dir}. "
                "This should not happen; aborting to avoid clobbering."
            )

        print(f"[{i:4d}] Copying {src_dir} -> {dest_dir}")
        shutil.copytree(src_dir, dest_dir, dirs_exist_ok=False)

        payload = {
            "ordered_task_id": f"PLACEHOLDER_{formula}",
            "prototype": prototype,
            "prototype_params": prototype_params,
            "composition_map": site_composition_map,
            "reference_structure": reference_structure.as_dict(),
            "supercell_diag": list(supercell_diag),
            "versions": versions,
        }

        json_path = dest_dir / "disordered_task_doc_metadata.json"
        with json_path.open("w", encoding="utf-8") as fh:
            json.dump(payload, fh, indent=2, sort_keys=True)
        print(f"[{i:4d}] Wrote metadata JSON: {json_path}")

    print("Done organizing results.")
    print(f"All data copied under: {target_root}")

# for A, B in product(["Mg", "Zn"], ["Al", "Ga", "In", "Sc"]):
#     convergence=0.02
#     formula = f"{A}{B}2O4"
#     ce_key = spinel_ce(a_cation=A, b_cation=B, convergence=convergence)
#     if ce_key is None:
#         raise RuntimeError(f"No CE found for {formula} with convergence={convergence}")
#     organize_results(ce_key, formula=formula, convergence=convergence)

# for A in ["Y"]:
#     for B, C in permutations(("Al", "Ga", "Sc"), 2):
#         convergence=0.02
#         formula = f"{A}3{B}2{C}3O4"
#         ce_key = garnet_ce(a_cation=A, b_cation=B, c_cation=C, convergence=convergence)
#         if ce_key is None:
#             raise RuntimeError(f"No CE found for {formula} with convergence={convergence}")
#         organize_results(ce_key, formula=formula, convergence=convergence)

# skip_list = [("Lu", "Al", "Sc"), ("La", "Al", "Sc"), ("Lu", "Sc", "Al"), ("Lu", "Ga", "Sc")]
# for A in ["La", "Lu"]:
#     for B in ["Al", "Ga", "Sc"]:
#         for C in ["Al", "Ga", "Sc"]:
#             if B != C and (A, B, C) not in skip_list:
#                 garnet_ce(a_cation=A, b_cation=B, c_cation=C, convergence=0.02, launch=True)

# for A in ["Mg", "Sr", "Ba"]:
#     for B in ["Al", "Ga", "Sc"]:
#         for C in ["Si", "Ge"]:
#             garnet_ce(a_cation=A, b_cation=B, c_cation=C, convergence=0.02, launch=True)

# for A in ["Ca"]:
#     for Bp in ["Al", "Ga", "Sc", "In", "Y", "Lu"]:
#         for Bpp in ["Nb", "Ta", "Sb"]:
#             double_perovskite_ce(a_cation=A, bprime_cation=Bp, bdblprime_cation=Bpp, convergence=0.02, launch=True)

2025-12-20 10:36:21,143 INFO Added a workflow. id_map: {-36: 26399}
A=Ca Bp=Al Bpp=Nb ce_key=80ca1204f06085b37dca0e97f95937e55235da96027c37b06199d1d95a4faa0a
2025-12-20 10:36:21,962 INFO Added a workflow. id_map: {-37: 26400}
A=Ca Bp=Al Bpp=Ta ce_key=d7eef8216a60c1e0dc6a0a270a1267299b6e2205a9cd1b422a05b6f69dc3cbf6
2025-12-20 10:36:22,770 INFO Added a workflow. id_map: {-38: 26401}
A=Ca Bp=Al Bpp=Sb ce_key=4fdaa02830728d2b96f9812ba88a37d5cff4543ac54724b7a22326fcd9ab0dc0
2025-12-20 10:36:23,564 INFO Added a workflow. id_map: {-39: 26402}
A=Ca Bp=Ga Bpp=Nb ce_key=cc98fcc8c660c7019eacacd18bbd0e81ae9a5882cf8adaf52cc672d8b7cdc45f
2025-12-20 10:36:24,318 INFO Added a workflow. id_map: {-40: 26406}
A=Ca Bp=Ga Bpp=Ta ce_key=bbb8e087028c868a7d59f26471402d7f014c144a828d1b21013ded206c7a51bb
2025-12-20 10:36:25,123 INFO Added a workflow. id_map: {-41: 26434}
A=Ca Bp=Ga Bpp=Sb ce_key=3e9947484f3b306cbf1cbec494d5cb01a2e38a7176edf6c6f51fedc4f4841674
2025-12-20 10:36:25,911 INFO Added a workflow. id_ma

In [None]:
# Validate one organized calculation directory with pymatgen-io-validation
# and display the reasons it reports.
from pymatgen.io.validation import VaspValidator # pyright: ignore[reportPrivateImportUsage]
# NOTE(review): fast=False presumably runs every check instead of stopping at
# the first failure — confirm against the pymatgen-io-validation docs.
validation_doc = VaspValidator.from_directory("/scratch/cbu/mp/MgAl2O4_02eVpA_mp-3536/000ad105c405e04483c8e17cb6fe0fad817ab8ba975432ef4b8b86e152dfc10c", fast=False)
# Bare expression: the notebook renders it as the cell output.
validation_doc.reasons

['INPUT SETTINGS --> GGA_COMPAT: is False, but should be == True.']

In [None]:
from phaseedge.science.prototype_spec import PrototypeSpec

from pathlib import Path
from itertools import permutations, product, combinations
import json

from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core.structure import Structure


def generate_ordered_prototypes(output_path: str | Path) -> None:
    """
    Generate the full matrix of ordered prototypes for:
      - Double perovskite A2B'B''O6
      - Spinel AB2O4
      - Ilmenite ABO3
      - Pyrochlore A2B2O7
      - Garnet A3B2C3O12

    using the PrototypeSpec machinery (Es/Fm placeholders), then:
      * Replace Es/Fm with the actual cations for each category.
      * For double perovskites and garnets, put the inactive A cation
        into the prototype string (J0A) as requested.
      * Drop any site labels/properties from the resulting Structures.
      * Serialize each Structure via Structure.as_dict().
      * Write a single flat JSON list of all structures.

    Also prints a list of compositions per category and counts
    for sanity checking.

    Every category goes through the same build / print / collect routine
    (`add_category`); only the prototype family, lattice parameters and
    cation lists differ between categories.

    Args:
        output_path: Path to the JSON file to write. Missing parent
            directories are created.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    adaptor = AseAtomsAdaptor()

    def make_structure(
        prototype: str,
        params: dict[str, float],
        placeholder_map: dict[str, str],
    ) -> Structure:
        """
        Build a pymatgen Structure from a PrototypeSpec and replace
        placeholder species (Es, Fm, Md, ...) according to placeholder_map.

        We explicitly rebuild the Structure so that no site labels or
        site properties survive. The resulting Structure.sites have
        no 'label' field in as_dict().
        """
        spec = PrototypeSpec(prototype=prototype, params=params)
        atoms = spec.primitive_cell

        # Convert to a temporary Structure (may have labels from the placeholder symbols)
        tmp = adaptor.get_structure(atoms) # pyright: ignore[reportArgumentType]

        species: list[str] = []
        frac_coords: list[list[float]] = []

        for site in tmp.sites:
            symbol = site.species_string
            # Symbols that are not placeholders (e.g. the anion) pass through.
            species.append(placeholder_map.get(symbol, symbol))
            frac_coords.append(site.frac_coords.tolist())

        # Rebuild without any labels or site_properties
        return Structure(
            lattice=tmp.lattice,
            species=species,
            coords=frac_coords,
            coords_are_cartesian=False,
        )

    def fixed_a_entries(
        family: str,
        site_names: tuple[str, str],
        triples,
    ) -> list[tuple[str, dict[str, str], str]]:
        """
        Entries for families whose inactive A cation is baked into the
        prototype string (J0{A}); Es/Fm take the two remaining cations.
        """
        first, second = site_names
        return [
            (
                f"{family}_J0{a}_Q0O",
                {"Es": x, "Fm": y},
                f"A={a}, {first}={x}, {second}={y}",
            )
            for a, x, y in triples
        ]

    def placeholder_a_entries(
        family: str,
        pairs,
    ) -> list[tuple[str, dict[str, str], str]]:
        """Entries for families where A itself is a placeholder: Es -> A, Fm -> B."""
        return [
            (f"{family}_Q0O", {"Es": a, "Fm": b}, f"A={a}, B={b}")
            for a, b in pairs
        ]

    all_structures: list[Structure] = []

    def add_category(
        title: str,
        params: dict[str, float],
        entries: list[tuple[str, dict[str, str], str]],
    ) -> None:
        """Build every structure of one category, print its summary, collect it."""
        structs = [
            make_structure(prototype, params, placeholder_map)
            for prototype, placeholder_map, _ in entries
        ]
        print(title)
        for _, _, label in entries:
            print("  ", label)
        print(f"  Count: {len(structs)}")
        all_structures.extend(structs)

    # =========================
    # Double perovskite A2B'B''O6
    # =========================
    dp_params: dict[str, float] = {"a": 7.86}
    dp_A: tuple[str, ...] = ("Ca", "Sr", "Ba", "Pb")
    dp_sites = ("B'", "B''")

    # ---- Category 1 ----
    add_category(
        "Double perovskite Category 1 (A2B'B''O6):",
        dp_params,
        fixed_a_entries(
            "doubleperovskite",
            dp_sites,
            product(
                dp_A,
                ("Al", "Ga", "Sc", "In", "Y", "Lu"),
                ("Nb", "Ta", "Sb", "V"),
            ),
        ),
    )

    # ---- Category 2 ----
    add_category(
        "\nDouble perovskite Category 2 (A2B'B''O6):",
        dp_params,
        fixed_a_entries(
            "doubleperovskite",
            dp_sites,
            product(dp_A, ("Mg", "Ca", "Zn"), ("W", "Mo")),
        ),
    )

    # ---- Category 3 ----
    add_category(
        "\nDouble perovskite Category 3 (A2B'B''O6):",
        dp_params,
        fixed_a_entries(
            "doubleperovskite",
            dp_sites,
            product(
                ("La", "Bi", "Y", "Lu"),
                ("Mg", "Ca", "Zn"),
                ("Ti", "Sn", "Zr", "Hf", "Ce", "Ge"),
            ),
        ),
    )

    # ---- Category 4 (B'/B'': pick two, unordered) ----
    dp4_B: tuple[str, ...] = ("Ti", "Zr", "Hf", "Ce", "Sn")
    add_category(
        "\nDouble perovskite Category 4 (A2B'B''O6, B'/B'' pick two):",
        dp_params,
        fixed_a_entries(
            "doubleperovskite",
            dp_sites,
            (
                (A, Bp, Bpp)
                for A in dp_A
                for Bp, Bpp in combinations(dp4_B, 2)  # unordered pair
            ),
        ),
    )

    # =========================
    # Spinel AB2O4
    # =========================
    add_category(
        "\nSpinel AB2O4:",
        {"a": 8.08, "u": 0.36},
        placeholder_a_entries(
            "spinel",
            product(("Mg", "Zn"), ("Al", "Ga", "In", "Sc")),
        ),
    )

    # =========================
    # Ilmenite ABO3
    # =========================
    add_category(
        "\nIlmenite ABO3:",
        {"a": 5.1, "c": 14.0},
        placeholder_a_entries(
            "ilmenite",
            product(
                ("Mg", "Zn", "Ca", "Sr", "Ba", "Pb"),
                ("Ti", "Si", "Ge", "Sn", "Zr", "Hf", "Ce"),
            ),
        ),
    )

    # =========================
    # Pyrochlore A2B2O7
    # =========================
    pyrochlore_params: dict[str, float] = {"a": 10.5, "x": 0.045}

    # ---- Category 1 ----
    add_category(
        "\nPyrochlore Category 1 (A2B2O7):",
        pyrochlore_params,
        placeholder_a_entries(
            "pyrochlore",
            product(
                ("La", "Y", "Lu", "Bi"),
                ("Ti", "Zr", "Hf", "Sn", "Ge", "Ce"),
            ),
        ),
    )

    # ---- Category 2 ----
    add_category(
        "\nPyrochlore Category 2 (A2B2O7):",
        pyrochlore_params,
        placeholder_a_entries(
            "pyrochlore",
            product(("Ca", "Sr", "Ba", "Pb"), ("Nb", "Ta")),
        ),
    )

    # =========================
    # Garnet A3B2C3O12
    # =========================
    garnet_params: dict[str, float] = {"a": 12.0}
    garnet_sites = ("B", "C")

    # ---- Category 1 (Silicate-type) ----
    # A is the "inactive" dodecahedral site, so it goes in J0
    add_category(
        "\nGarnet Category 1 (A3B2C3O12, silicate-type):",
        garnet_params,
        fixed_a_entries(
            "garnet",
            garnet_sites,
            product(
                ("Ca", "Mg", "Sr", "Ba"),
                ("Al", "Ga", "Sc"),
                ("Si", "Ge"),
            ),
        ),
    )

    # ---- Category 2 (Rare Earth): B/C ordered (distinct sites) ----
    g2_BC: tuple[str, ...] = ("Al", "Ga", "Sc")
    add_category(
        "\nGarnet Category 2 (A3B2C3O12, rare earth; ordered B/C):",
        garnet_params,
        fixed_a_entries(
            "garnet",
            garnet_sites,
            (
                (A, B, C)
                for A in ("Y", "La", "Lu")
                # permutations gives (Al, Ga) and (Ga, Al), etc.
                for B, C in permutations(g2_BC, 2)
            ),
        ),
    )

    # =========================
    # Final summary + JSON dump
    # =========================
    print("\nTotal structures generated:", len(all_structures))

    structures_as_dicts: list[dict[str, object]] = [
        s.as_dict() for s in all_structures
    ]
    with path.open("w", encoding="utf-8") as f:
        json.dump(structures_as_dicts, f, indent=2)


# Run the generator: prints the per-category summaries and writes the flat
# JSON list of structures to the given path.
generate_ordered_prototypes("/scratch/cbu/ordered_prototypes.json")

Double perovskite Category 1 (A2B'B''O6):
   A=Ca, B'=Al, B''=Nb
   A=Ca, B'=Al, B''=Ta
   A=Ca, B'=Al, B''=Sb
   A=Ca, B'=Al, B''=V
   A=Ca, B'=Ga, B''=Nb
   A=Ca, B'=Ga, B''=Ta
   A=Ca, B'=Ga, B''=Sb
   A=Ca, B'=Ga, B''=V
   A=Ca, B'=Sc, B''=Nb
   A=Ca, B'=Sc, B''=Ta
   A=Ca, B'=Sc, B''=Sb
   A=Ca, B'=Sc, B''=V
   A=Ca, B'=In, B''=Nb
   A=Ca, B'=In, B''=Ta
   A=Ca, B'=In, B''=Sb
   A=Ca, B'=In, B''=V
   A=Ca, B'=Y, B''=Nb
   A=Ca, B'=Y, B''=Ta
   A=Ca, B'=Y, B''=Sb
   A=Ca, B'=Y, B''=V
   A=Ca, B'=Lu, B''=Nb
   A=Ca, B'=Lu, B''=Ta
   A=Ca, B'=Lu, B''=Sb
   A=Ca, B'=Lu, B''=V
   A=Sr, B'=Al, B''=Nb
   A=Sr, B'=Al, B''=Ta
   A=Sr, B'=Al, B''=Sb
   A=Sr, B'=Al, B''=V
   A=Sr, B'=Ga, B''=Nb
   A=Sr, B'=Ga, B''=Ta
   A=Sr, B'=Ga, B''=Sb
   A=Sr, B'=Ga, B''=V
   A=Sr, B'=Sc, B''=Nb
   A=Sr, B'=Sc, B''=Ta
   A=Sr, B'=Sc, B''=Sb
   A=Sr, B'=Sc, B''=V
   A=Sr, B'=In, B''=Nb
   A=Sr, B'=In, B''=Ta
   A=Sr, B'=In, B''=Sb
   A=Sr, B'=In, B''=V
   A=Sr, B'=Y, B''=Nb
   A=Sr, B'=Y, B''=Ta
   A=Sr,

In [1]:
from phaseedge.science.prototype_spec import PrototypeSpec
from pymatgen.io.ase import AseAtomsAdaptor

# Check which per-atom metadata a PrototypeSpec primitive cell carries, and
# which of it survives supercell repetition and ASE -> pymatgen conversion.
proto_spec = PrototypeSpec(
    prototype="spinel_Q0O",
    params={"a": 8.08, "u": 0.36},
)

primitive_cell = proto_spec.primitive_cell
# Names of the per-atom arrays stored on the ASE Atoms object.
print("Atoms arrays:", primitive_cell.arrays.keys())

# Repeat to a 2x2x2 supercell, then convert to a pymatgen Structure.
sc = primitive_cell.repeat((2, 2, 2))
default_structure = AseAtomsAdaptor.get_structure(sc)
# Site properties present on the converted Structure.
print("site_properties keys:", default_structure.site_properties.keys())

Atoms arrays: dict_keys(['numbers', 'positions', 'spacegroup_kinds', 'role', 'sublattice', 'wyckoff'])
site_properties keys: dict_keys(['role', 'wyckoff', 'sublattice', 'spacegroup_kinds'])
