diff --git a/examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py b/examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py index b0a49db7..79f458fd 100644 --- a/examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py +++ b/examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py @@ -10,9 +10,9 @@ import os import time +import numpy as np import torch from mace.calculators.foundations_models import mace_mp -from matbench_discovery.data import DataFiles, ase_atoms_from_zip import torch_sim as ts @@ -40,11 +40,24 @@ max_atoms_in_batch = 50 if os.getenv("CI") else 8_000 # --- Data Loading --- -n_structures_to_relax = 2 if os.getenv("CI") else 100 -print(f"Loading {n_structures_to_relax:,} structures...") -ase_atoms_list = ase_atoms_from_zip( - DataFiles.wbm_initial_atoms.path, limit=n_structures_to_relax -) +if not os.getenv("CI"): + n_structures_to_relax = 100 + print(f"Loading {n_structures_to_relax:,} structures...") + from matbench_discovery.data import DataFiles, ase_atoms_from_zip + + ase_atoms_list = ase_atoms_from_zip( + DataFiles.wbm_initial_atoms.path, limit=n_structures_to_relax + ) +else: + n_structures_to_relax = 2 + print(f"Loading {n_structures_to_relax:,} structures...") + from ase.build import bulk + + al_atoms = bulk("Al", "hcp", a=4.05) + al_atoms.positions += 0.1 * np.random.randn(*al_atoms.positions.shape) # noqa: NPY002 + fe_atoms = bulk("Fe", "bcc", a=2.86).repeat((2, 2, 2)) + fe_atoms.positions += 0.1 * np.random.randn(*fe_atoms.positions.shape) # noqa: NPY002 + ase_atoms_list = [al_atoms, fe_atoms] # --- Optimization Setup --- # Statistics tracking