In [None]:
import os

import numpy as np
from emmet.core.disorder import DisorderedTaskDoc
from phaseedge.sampling.train_ce_driver import run_train_ce
from phaseedge.sampling.wl_block_driver import run_wl_block
from phaseedge.schemas.mixture import sublattices_from_composition_maps
from phaseedge.schemas.wl_sampler_spec import WLSamplerSpec
from phaseedge.science.prototype_spec import PrototypeSpec

from smol.cofe import ClusterExpansion
from smol.moca.ensemble import Ensemble

def _load_disordered_docs(dir_path: str) -> list:
    """Patch each task's metadata file in place, then parse it with DisorderedTaskDoc.

    Every entry of *dir_path* must be a directory containing
    ``disordered_task_doc_metadata.json``.  Entries are visited in sorted order:
    ``os.listdir`` returns them in arbitrary order, and the first document is
    later used as the reference doc and as the WL starting composition, so an
    unsorted walk would make results depend on filesystem ordering.

    Args:
        dir_path: Directory whose immediate entries are task directories.

    Returns:
        The parsed documents, in sorted subdirectory-name order.

    Raises:
        ValueError: if *dir_path* has no entries, or one of its entries is not
            a directory.
    """
    dir_list = sorted(os.listdir(dir_path))
    if not dir_list:
        raise ValueError(f"No task directories found in {dir_path}")

    docs = []
    for i, subdir in enumerate(dir_list):
        full_path = os.path.join(dir_path, subdir)
        if not os.path.isdir(full_path):
            raise ValueError(f"{full_path} is not a directory")

        # Change the instance of PLACEHOLDER_MgAl2O4 in disordered_task_doc_metadata.json to mp-2029216
        meta_file = os.path.join(full_path, "disordered_task_doc_metadata.json")
        with open(meta_file, "r", encoding="utf-8") as f:
            meta_data = f.read()
        patched = meta_data.replace("PLACEHOLDER_MgAl2O4", "mp-2029216")
        if patched != meta_data:
            # Only rewrite when something changed, so reruns leave files untouched.
            with open(meta_file, "w", encoding="utf-8") as f:
                f.write(patched)

        doc, _ = DisorderedTaskDoc.from_directory(full_path)
        print(f"{i+1}/{len(dir_list)} Processed {full_path}")
        docs.append(doc)
    return docs


def _check_docs_consistent(docs: list) -> None:
    """Raise ValueError unless every doc matches ``docs[0]`` on the shared fields.

    Fields are compared in the order ordered_task_id, supercell_diag,
    prototype, prototype_params, versions; the first mismatch found raises.
    """
    checks = (
        ("ordered_task_id", "Ordered task IDs do not match"),
        ("supercell_diag", "Supercell diagonals do not match"),
        ("prototype", "Prototypes do not match"),
        ("prototype_params", "Prototype parameters do not match"),
        ("versions", "Versions do not match"),
    )
    first_doc = docs[0]
    for doc in docs[1:]:
        for attr, message in checks:
            if getattr(doc, attr) != getattr(first_doc, attr):
                raise ValueError(message)


def _make_wl_spec(bin_width: float, initial_comp_map: dict):
    """Build the WLSamplerSpec used for both bin-width tuning and production blocks."""
    return WLSamplerSpec(
        # NOTE(review): empty ce_key; presumably unused because the ensemble is
        # passed to run_wl_block directly — confirm.
        ce_key="",
        bin_width=bin_width,
        steps=1_000_000,
        initial_comp_map=initial_comp_map,
        step_type="swap",
        check_period=5000,
        update_period=1,
        seed=0,
        samples_per_bin=0,
        collect_cation_stats=False,
        production_mode=False,
        reject_cross_sublattice_swaps=False,
    )


def postprocess_disordered_material(dir_path: str, *, max_tuning_rounds: int = 20) -> None:
    """Fit a cluster expansion to a directory of disordered tasks and run Wang-Landau on it.

    Steps:
      1. Parse every task subdirectory of *dir_path* (rewriting the
         ``PLACEHOLDER_MgAl2O4`` string in each ``disordered_task_doc_metadata.json``
         in place) and check that all docs share ordered_task_id, supercell_diag,
         prototype, prototype_params and versions.
      2. Train a cluster expansion with ``run_train_ce`` on the docs' reference
         structures and their energies divided by ``prod(supercell_diag)``.
      3. Tune the WL bin width, starting from 0.1 and halving or doubling it,
         until a fresh trial block's state reports between 50 and 200 bin indices.
      4. Continue WL blocks from that state until ``mod_factor`` is <= 1e-7.

    Progress is printed to stdout throughout.

    Args:
        dir_path: Directory whose immediate entries are task directories.
        max_tuning_rounds: Maximum number of bin-width trial blocks to run
            before giving up (each trial is a 1,000,000-step WL block).

    Raises:
        ValueError: if *max_tuning_rounds* < 1, *dir_path* is empty or holds a
            non-directory entry, or the docs disagree on a shared field.
        RuntimeError: if no tried bin width yields 50-200 bins within
            *max_tuning_rounds* trials.
    """
    if max_tuning_rounds < 1:
        raise ValueError("max_tuning_rounds must be at least 1")

    docs = _load_disordered_docs(dir_path)

    # Ensure that the task docs have the same ordered_task_id, supercell_diag, prototype, prototype_params, and versions
    _check_docs_consistent(docs)
    first_doc = docs[0]
    supercell_diag = first_doc.supercell_diag

    # Train a cluster expansion model on the collected documents
    structures_pm = [doc.reference_structure for doc in docs]
    # NOTE(review): prod(supercell_diag) is presumably the number of primitive
    # cells in the supercell, giving energy per primitive cell — confirm.
    n_prims = int(np.prod(supercell_diag))
    y_cell = [doc.output.energy / float(n_prims) for doc in docs]
    prototype_spec = PrototypeSpec(prototype=first_doc.prototype, params=first_doc.prototype_params)
    # Replace A with Es and B with Fm in the keys of composition maps; other keys are kept.
    site_relabel = {"A": "Es", "B": "Fm"}
    composition_maps = [
        {site_relabel.get(site, site): comp for site, comp in doc.composition_map.items()}
        for doc in docs
    ]
    sublattices = sublattices_from_composition_maps(composition_maps)
    basis_spec = {"basis": "sinusoid", "cutoffs": {2: 10, 3: 8, 4: 5}}
    # NOTE(review): l1_ratio is presumably ignored for ridge — confirm in run_train_ce.
    regularization = {"type": "ridge", "alpha": 1e-3, "l1_ratio": 0.5}
    weighting = {"scheme": "balance_by_comp", "alpha": 1.0}
    cv_seed = 42
    ce_train_output = run_train_ce(
        structures_pm=structures_pm,
        y_cell=y_cell,
        prototype_spec=prototype_spec,
        supercell_diag=supercell_diag,
        sublattices=sublattices,
        basis_spec=basis_spec,
        regularization=regularization,
        weighting=weighting,
        cv_seed=cv_seed,
    )
    print("Cluster expansion training completed successfully.")
    print(ce_train_output["stats"])
    print(ce_train_output["design_metrics"])

    ce = ClusterExpansion.from_dict(ce_train_output["payload"])
    ensemble = Ensemble.from_cluster_expansion(ce, supercell_matrix=np.diag(supercell_diag))

    def run_block(wl_spec, tip):
        # One WL block on the shared ensemble; tip=None starts a fresh run.
        return run_wl_block(
            spec=wl_spec,
            ensemble=ensemble,
            tip=tip,
            prototype_spec=prototype_spec,
            supercell_diag=supercell_diag,
        )

    # Run WL sampling using the trained model
    ## First we determine the appropriate sampling window. The number of trials
    ## is bounded: without a cap, halving/doubling could oscillate forever.
    bin_width = 0.1
    for _ in range(max_tuning_rounds):
        wl_spec = _make_wl_spec(bin_width, composition_maps[0])
        wl_block = run_block(wl_spec, None)
        num_bins = len(wl_block["state"]["bin_indices"])
        if num_bins < 50:
            bin_width /= 2.0
            print(f"Adjusting bin width to {bin_width} (num_bins={num_bins})")
        elif num_bins > 200:
            bin_width *= 2.0
            print(f"Adjusting bin width to {bin_width} (num_bins={num_bins})")
        else:
            print(f"Selected bin width: {bin_width} (num_bins={num_bins})")
            break
    else:
        raise RuntimeError(
            f"Could not find a bin width giving 50-200 bins within "
            f"{max_tuning_rounds} trials (last num_bins={num_bins})"
        )

    ## Then we run the sampling until we reach the desired convergence threshold
    print(f"Current mod factor: {wl_block['state']['mod_factor']}")
    while wl_block["state"]["mod_factor"] > 1e-7:
        wl_block = run_block(wl_spec, wl_block)
        print(f"Current mod factor: {wl_block['state']['mod_factor']}")

    print("Postprocessing completed successfully.")

# Guarded so importing this module does not launch the (long) CE + WL run;
# __name__ is "__main__" in a Jupyter kernel, so the notebook still runs it.
if __name__ == "__main__":
    postprocess_disordered_material("/scratch/cbu/test/MgAl2O4_20meVpA")

1/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/d9ecb1a758ad94e97b82adf685282e0dbc4f932213f0dc993ede79f7099caf05
2/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/6252a5440500ddf1f8b349c25ca041072c10eb4ee82f0083101fbe76f3e8a453
3/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/28e31d2e493de2c868549cf9089ad98d8433d3c8c9996174f5f1715159abf852
4/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/196b974583f8c03a5622b442074001e62c8cfd308720dd7826f41a031cc3e985
5/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/eb65beee4362d8f0a900f03d0f92697bd50783e0c3f57d09df8e77c919b7415d
6/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/39cf9ea2b00119ff40c64bd6215fe06553447ce3923c4c8723670cdaacc693ff
7/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/18dbd2e22e01e5b12b301c2c173651954985cbceb99ee1e2af21d96339ace0c0
8/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/8b0b040363c3aa26b297693de3b83476e73915f31e4d075902232febcbcc4e7c
9/500 Processed /scratch/cbu/test/MgAl2O4_20meVpA/d578858c8c0487651a02d93bb1819c