In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# (except those excluded with %aimport) before each cell executes, so edits
# to source files on disk are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import numpy as np
from ase.build import bulk

# ---------------------------------------------------------------------------
# Notebook configuration: every tunable value read by the cells below.
# ---------------------------------------------------------------------------

# Seeded generator handed to the snapshot / sampling / Wang-Landau calls below.
rng = np.random.default_rng(123)

# Parent lattice: conventional cubic cell of rocksalt MgO, lattice constant a=4.2.
conv_cell = bulk("MgO", crystalstructure="rocksalt", a=4.2, cubic=True)

# Substitution settings: sites holding `replace_element` may be occupied by any
# of `new_elements`. `ratio` is forwarded to the dataset and sampling helpers.
# NOTE(review): which species `ratio` refers to (Mg or Fe fraction) is not
# visible here; confirm against tc.dataset.make_snapshots.
replace_element = "Mg"
new_elements = ("Mg", "Fe")
ratio = 0.5

# Supercell: the same repeat count along each of the three cell axes.
supercell_size = 4
supercell_diag = (supercell_size,) * 3

# Endpoints of the temperature grid (`temperatures_K`) built in the
# thermodynamics cell.
T_low, T_high = 100, 1000

# Number of snapshots requested from make_snapshots; also passed as
# `thin_target` to the Wang-Landau run that feeds retraining.
snapshot_counts = 500

# When True, the ensemble-building cell runs extra comparisons against MACE.
debug_mode = False

In [None]:
from monty.serialization import loadfn
import tc.dataset
import tc.testing
from mace.calculators import mace_mp
import tc.wang_landau
from pymatgen.io.ase import AseAtomsAdaptor

# Load a previously serialized (ensemble, endpoint_energies) pair if one is on
# disk; otherwise build the ensemble in two training rounds (random snapshots
# first, then Wang-Landau-sampled configurations).
#
# NOTE(review): nothing in this cell writes the file that is loaded below.
# Presumably tc.dataset.create_canonical_ensemble serializes it; confirm,
# otherwise the cache is never populated.
# NOTE(review): the literal "4" in the filename looks like it mirrors
# supercell_size; confirm before changing supercell_size, or a stale file may
# be picked up.


def _train_ensemble(structures):
    """Call create_canonical_ensemble on `structures` with the notebook settings."""
    return tc.dataset.create_canonical_ensemble(
        conv_cell, calc, replace_element, new_elements, supercell_size, ratio,
        endpoint_energies, supercell_diag, structures, reuse_site_map=True,
    )


def _evaluate_vs_mace(candidate):
    """Run tc.testing.evaluate_ensemble_vs_mace for `candidate` at the configured ratio."""
    return tc.testing.evaluate_ensemble_vs_mace(
        ensemble=candidate, calc=calc, conv_cell=conv_cell, rng=rng,
        endpoint_energies=endpoint_energies, replace_element=replace_element,
        new_elements=new_elements, comps=(ratio,),
    )


def _wang_landau_pass(candidate, start_configs):
    """Run Wang-Landau on `candidate` with thin_target=snapshot_counts."""
    return tc.wang_landau.run_wang_landau(
        candidate, start_configs, rng,
        thin_target=snapshot_counts, ratio=ratio, n_samples_per_site=5_000,
        num_bins=200, window_width_factor=(50, 50), progress=True,
    )


def _compare_sampler_vs_mace(wl_sampler, candidate):
    """Run tc.testing.compare_sampler_and_mace for a sampler/ensemble pair."""
    return tc.testing.compare_sampler_and_mace(
        wl_sampler, candidate, calc,
        endpoint_eVpc=tuple(endpoint_energies), cation_elements=new_elements,
    )


try:
    ensemble, endpoint_energies = loadfn(f"{''.join(new_elements)}O_ensemble4_{round(1000*ratio)}.json.gz")
except FileNotFoundError:
    print(f"Creating random snapshot ensemble with {snapshot_counts} snapshots...")
    calc = mace_mp(model="large", device="cuda", default_dtype="float32")
    endpoint_energies = tc.dataset.calculate_endpoint_energies(
        conv_cell, calc, replace_element, new_elements
    )

    # Round 1: train on randomly generated snapshots.
    snapshots = tc.dataset.make_snapshots(
        conv_cell, supercell_diag, rng, replace_element, new_elements,
        snapshot_counts, ratio=ratio,
    )
    ensemble = _train_ensemble(snapshots)
    if debug_mode:
        stats = _evaluate_vs_mace(ensemble)

    # Wang-Landau pass, used to sample configurations uniformly across the
    # range of energies.
    random_samples = tc.testing.sample_configs_fast(
        ensemble, rng, n_samples=10_000, ratio=ratio
    )
    sampler, mu, min_E, max_E, bin_size = _wang_landau_pass(ensemble, random_samples)
    if debug_mode:
        _compare_sampler_vs_mace(sampler, ensemble)

    # Round 2: convert the sampled occupancies back to ASE atoms and retrain.
    wl_occupancies = sampler.samples.get_trace_value("occupancy")
    snapshots = [
        AseAtomsAdaptor.get_atoms(ensemble.processor.structure_from_occupancy(occupancy))
        for occupancy in wl_occupancies
    ]
    ensemble = _train_ensemble(snapshots)
    if debug_mode:
        # Re-check the retrained ensemble, including a fresh Wang-Landau pass.
        stats = _evaluate_vs_mace(ensemble)
        sampler, mu, min_E, max_E, bin_size = _wang_landau_pass(ensemble, random_samples)
        _compare_sampler_vs_mace(sampler, ensemble)
else:
    print("Ensemble already exist, skipping creation.")

In [None]:
import tc.wang_landau
import tc.testing

# Temperature grid spanning [T_low, T_high] with 10,000 points.
temperatures_K = np.linspace(T_low, T_high, 10_000)

# Starting configurations for the Wang-Landau run on the final ensemble.
samples = tc.testing.sample_configs_fast(ensemble, rng, n_samples=10_000, ratio=ratio)

# Long Wang-Landau run: n_samples_per_site is 500_000 here, versus 5_000 in
# the ensemble-building cell.
sampler, mu, min_E, max_E, bin_size = tc.wang_landau.run_wang_landau(
    ensemble=ensemble,
    samples=samples,
    rng=rng,
    ratio=ratio,
    n_samples_per_site=500_000,
    num_bins=200,
    window_width_factor=(50, 50),
    progress=True,
)

# Thermodynamic quantities on the temperature grid, then the diagnostic plots.
# NOTE(review): `Cv` is presumably the heat capacity; confirm against
# tc.wang_landau.compute_thermodynamics.
Cv = tc.wang_landau.compute_thermodynamics(sampler, temperatures_K)
tc.wang_landau.generate_wl_plots(mu, min_E, max_E, bin_size, sampler, temperatures_K, Cv)