In [32]:
import os
from functools import partial
from typing import Callable

from ase.atoms import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.io import read, write
from ase.optimize import BFGS
from fairchem.core import FAIRChemCalculator
from fairchem.core.units.mlip_unit import load_predict_unit

In [33]:
def indice_from_tags(atoms: Atoms, tags: list[int]):
    """Return the indices of the atoms whose ``tag`` is one of ``tags``.

    Args:
        atoms: Structure to scan; iterated atom by atom.
        tags: Tag values to select.

    Returns:
        A list of integer indices into ``atoms``, in ascending order.
    """
    selected = []
    for position, atom in enumerate(atoms):
        if atom.tag in tags:
            selected.append(position)
    return selected


def constrain_atoms(atoms_list: list[Atoms], index_fn: Callable) -> list[Atoms]:
    """Make every structure periodic and fix the atoms chosen by ``index_fn``.

    Each ``Atoms`` object is modified in place and also returned in a new list.
    After the changes, the potential energy the structure reported on entry is
    re-attached through a fresh ``SinglePointCalculator``. Only the energy is
    kept; any other results held by the original calculator are not carried over.

    Args:
        atoms_list: Structures to constrain. Each must already be able to report
            a potential energy (it is read before anything else is changed).
        index_fn: Called with each structure (after ``pbc`` is set) and must
            return the indices of the atoms to hold fixed.

    Returns:
        The same ``Atoms`` objects, in input order.
    """

    def _fix_selected(structure: Atoms) -> Atoms:
        # Capture the energy before its calculator is swapped out below.
        stored_energy = structure.get_potential_energy()
        structure.pbc = True
        structure.set_constraint(FixAtoms(indices=index_fn(structure)))
        # Created after the pbc/constraint changes, so it is built from the
        # modified structure.
        structure.calc = SinglePointCalculator(atoms=structure, energy=stored_energy)
        return structure

    return [_fix_selected(structure) for structure in atoms_list]

def batch_relax(
    atoms_list: list[Atoms], calc: Calculator, fmax: float = 0.05, steps: int = 100
) -> list[Atoms]:
    """Relax each structure with BFGS driven by ``calc``, keeping its original energy.

    Each ``Atoms`` object is relaxed in place. Its positions are updated by the
    optimizer, and then its calculator is replaced by a ``SinglePointCalculator``
    carrying the potential energy the structure reported *before* relaxation.

    Args:
        atoms_list: Structures to relax. Each must already be able to report a
            potential energy, which is read before relaxation starts.
        calc: Calculator used to drive every relaxation. The same instance is
            attached to each structure in turn.
        fmax: Force convergence threshold passed to ``BFGS.run``.
        steps: Maximum number of optimizer steps per structure.

    Returns:
        The same ``Atoms`` objects, in input order, after relaxation.
    """
    relaxed_atoms_list = []
    for atoms in atoms_list:
        # NOTE(review): this pre-relaxation energy is what ends up stored on the
        # relaxed structure; the energy predicted by `calc` is discarded.
        # Confirm this is the intended label.
        reference_energy = atoms.get_potential_energy()
        atoms.calc = calc
        opt = BFGS(atoms, logfile=None)
        opt.run(fmax=fmax, steps=steps)
        # Bind the single-point calculator to its own name. Reassigning `calc`
        # here would make every later structure relax with the previous
        # structure's single-point calculator instead of the caller's model.
        atoms.calc = SinglePointCalculator(atoms=atoms, energy=reference_energy)
        relaxed_atoms_list.append(atoms)
    return relaxed_atoms_list

In [34]:
# Configuration: input data, number of structures, model checkpoint, output file.
DATASET_PATH = "ideal_oc20_val_ood_cat_oh.extxyz"
N_STRUCTURES = 1  # only the first N structures are read; increase for a full run
# Set UMA_CHECKPOINT_PATH to point at a local checkpoint instead of the default below.
CHECKPOINT_PATH = os.environ.get(
    "UMA_CHECKPOINT_PATH",
    "/Users/averyhill/github/oasis/data/checkpoints/uma-s-1p1.pt",
)
OUTPUT_PATH = "atoms.extxyz"

atoms_list = read(DATASET_PATH, index=f":{N_STRUCTURES}")
# Fix every atom tagged 0 (presumably the subsurface slab atoms in the OC20
# tagging scheme; TODO confirm).
index_fn = partial(indice_from_tags, tags=[0])
slab_list = constrain_atoms(atoms_list=atoms_list, index_fn=index_fn)
predict_unit = load_predict_unit(path=CHECKPOINT_PATH, device="cpu")
calc = FAIRChemCalculator(predict_unit=predict_unit, task_name="oc20")
# Relax the constrained structures returned by constrain_atoms.
relaxed_atoms_list = batch_relax(atoms_list=slab_list, calc=calc)
write(OUTPUT_PATH, relaxed_atoms_list)

In [35]:
# Energy stored on the first relaxed structure. batch_relax re-attaches the
# energy the structure had before relaxation, so this is that pre-relaxation value.
first_relaxed = relaxed_atoms_list[0]
first_relaxed.get_potential_energy()

np.float64(1.4057107500000257)