In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from ase.build import bulk

# Seeded generator shared by every stochastic step in the next cell
# (ensemble creation, evaluation, sampling, Wang-Landau) for reproducibility.
rng = np.random.default_rng(42)

# Cubic conventional rocksalt MgO cell, lattice constant a = 4.2 Å.
Mg_conv_cell = bulk("MgO", crystalstructure="rocksalt", a=4.2, cubic=True)

# Substitution setup: sites of `replace_element` are populated with `new_elements`.
# NOTE(review): `ratio` is passed straight through to the tc helpers; presumably
# the composition fraction of the new elements — confirm in tc.dataset.
replace_element = "Mg"
new_elements = ("Mg", "Fe")
ratio = 0.5

# Cache file for the fitted ensembles, e.g. "MgFeO_ensembles.json.gz".
filename = "".join(new_elements) + "O_ensembles.json.gz"

In [None]:
import torch
from mace.calculators import mace_mp
from monty.serialization import dumpfn, loadfn

import tc.dataset
import tc.testing
import tc.wang_landau

# Load the cluster-expansion ensembles from the on-disk cache, or build them
# from MACE energies (and write the cache) when the file does not exist yet.
ensemble_stats = None  # MACE-comparison results; only computed on a fresh build

try:
    ensembles = loadfn(filename)
except FileNotFoundError:
    print("Creating new ensembles...")
    # Building and evaluating the ensembles draws from `rng`, whereas the cached
    # path draws nothing before sampling. Snapshot the generator state so the
    # sampling below starts from the same state on both paths; otherwise
    # Restart & Run All gives different samples depending on the cache.
    rng_state = rng.bit_generator.state
    # Fall back to CPU so the cell still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    calc = mace_mp(model="large", device=device, default_dtype="float64")
    ensembles = tc.dataset.make_ce_ensembles_from_mace(
        conv_cell=Mg_conv_cell, rng=rng, calc=calc, ratio=ratio,
        replace_element=replace_element, new_elements=new_elements, bin_counts=200,
    )
    dumpfn(ensembles, filename, indent=2)
    # Keep every ensemble's comparison result instead of overwriting one variable.
    ensemble_stats = [
        tc.testing.evaluate_ensemble_vs_mace(
            ensemble=ensemble, calc=calc, conv_cell=Mg_conv_cell, rng=rng,
            replace_element=replace_element, new_elements=new_elements,
        )
        for ensemble in ensembles
    ]
    rng.bit_generator.state = rng_state
else:
    print("Ensembles already exist, skipping creation.")

# NOTE(review): assumes `ensembles` holds exactly three ensembles, in the order
# the names suggest — confirm against tc.dataset.make_ce_ensembles_from_mace.
ensemble_4, ensemble_6, ensemble_8 = ensembles
samples = tc.testing.sample_configs_fast(ensemble_8, rng, n_samples=10_000, ratio=ratio)
sampler = tc.wang_landau.run_wang_landau(ensemble=ensemble_8, samples=samples, rng=rng, ratio=ratio)

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


Ensembles already exist, skipping creation.


  0%|          | 0/10000 [00:00<?, ?it/s]

CE energies: mean = -77617.00 meV, std =  1404.58 meV, min = -83467.04 meV, max = -71943.38 meV
Energy window : [-84.640, -70.594] eV (100 bins, 0.1405 eV each)


Sampling 1 chain(s) from a cell with 4096 sites:  55%|█████▍    | 2250822/4096000 [03:13<02:48, 10930.39it/s]