In [None]:
from pathlib import Path

from ase.build import bulk
from ase.filters import FrechetCellFilter
from ase.io import Trajectory
from ase.optimize import LBFGSLineSearch

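Import the ASE-compatible calculators (or calculator factories such as `mace_mp`) for the different ML potentials. To include another model, uncomment its import here and its entry in `calculators` below.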

In [None]:
# import matgl
# from matgl.ext.ase import M3GNetCalculator
# from chgnet.model import CHGNetCalculator
from mace.calculators import mace_mp

# from orb_models.forcefield import pretrained
# from orb_models.forcefield.calculator import ORBCalculator
from sevenn.sevennet_calculator import SevenNetCalculator

In [None]:
# Device every model runs on, e.g. "cuda:0" for a GPU or "cpu".
device = "cpu"

# One ASE calculator per ML potential, keyed by a short model name.
# Models whose lines are commented out are skipped; enable one by uncommenting
# both its import above and its entry here.
calculators = {}
calculators["sevenn"] = SevenNetCalculator("7net-0", device=device)
calculators["mace"] = mace_mp(device=device, default_dtype="float64")
# calculators["chgnet"] = CHGNetCalculator(use_device=device)
# calculators["m3gnet"] = M3GNetCalculator(potential=matgl.load_model("M3GNet-MP-2021.2.8-PES"))
# calculators["orb"] = ORBCalculator(pretrained.orb_v2(device=device), device=device)

In [None]:
# Starting geometry: CeO2 in the fluorite structure with lattice constant a = 5.411 Å.
# With the default cubic=False, ASE builds the 3-atom primitive fcc cell rather than
# the 12-atom conventional cubic cell.
structure = bulk(
    name="CeO2",
    crystalstructure="fluorite",
    a=5.411,
)

In [None]:
# Relax atomic positions and the cell with every potential. Each model works on its own
# copy of `structure`, so all models start from the same geometry, `structure` stays
# untouched, and every entry of `optimized_structures` is an independent Atoms object
# that keeps its own calculator (needed for the energy checks further down).
FMAX = 0.001  # force convergence criterion in eV/Å (ASE units)
TRAJ_FILE = "opt.traj"  # scratch trajectory, deleted after each relaxation

optimized_structures = {}

for name, calculator in calculators.items():
    # Atoms.copy() does not carry a calculator over, so attach this model's one explicitly.
    atoms = structure.copy()
    atoms.calc = calculator

    # FrechetCellFilter lets the optimizer vary the cell together with the atomic positions.
    fcf = FrechetCellFilter(atoms)
    opt = LBFGSLineSearch(fcf)
    # The context manager closes the trajectory file, so it can be removed safely afterwards
    # (an open handle would keep the file locked on Windows).
    with Trajectory(TRAJ_FILE, "w", atoms) as traj:
        opt.attach(traj)
        opt.run(fmax=FMAX)
    Path(TRAJ_FILE).unlink()

    optimized_structures[name] = atoms

In [None]:
from numpy.testing import assert_allclose

# Check the relaxed energy of every model against a reference value and report its lattice parameter.
# NOTE(review): all models share the same reference values, and "lattice_param" equals the
# input a=5.411 — these look like placeholders; confirm before relying on the checks.
expected_values = {
    "sevenn": {"energy": -26.181656, "lattice_param": 5.411},
    "mace": {"energy": -26.181656, "lattice_param": 5.411},
    "orb": {"energy": -26.181656, "lattice_param": 5.411},
    #    "chgnet": {"energy": -26.181656, "lattice_param": 5.411},
    #    "m3gnet": {"energy": -26.181656, "lattice_param": 5.411},
}

# Use `atoms` (not `structure`) as the loop variable so the global starting structure is not rebound.
for name, atoms in optimized_structures.items():
    assert_allclose(
        atoms.get_potential_energy(),
        expected_values[name]["energy"],
        atol=1e-2,
        err_msg=f"{name}: relaxed energy",
    )
    # bulk("CeO2", "fluorite") gives the primitive fcc cell (cell[0, 0] == 0), whose volume
    # is a**3 / 4, so the conventional lattice parameter is (4 V)**(1/3).
    lattice_param = (4 * atoms.get_volume()) ** (1 / 3)
    print(f"{name}: a = {lattice_param:.4f} Å")
    print(atoms.cell)
    # TODO: enable once real reference lattice parameters are filled in above.
    # assert_allclose(lattice_param, expected_values[name]["lattice_param"], atol=1e-2)