In [1]:
import argparse
import json

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import torch

from mace_jax import modules


---
## Prepare data
---

In [2]:
import ase
import numpy as np
from mace.data.atomic_data import AtomicData
from mace.data.utils import config_from_atoms
from mace.tools import torch_geometric
from mace.tools.torch_geometric.batch import Batch

In [3]:
def batch_to_jax(batch: Batch):
    """Convert every tensor field in a PyG Batch to JAX arrays."""
    jax_dict = {}

    # .keys is already a list of attribute names
    for key in batch.keys:
        value = batch[key]

        if isinstance(value, torch.Tensor):
            arr = value.detach().cpu().numpy()     # to numpy
            jax_dict[key] = jnp.asarray(arr)       # to JAX
        else:
            jax_dict[key] = value                  # leave as-is

    return jax_dict

In [4]:
def get_statistics(filename = 'statistics.json'):

    from mace.tools.multihead_tools import (  # noqa: PLC0415
        AtomicNumberTable,
    )

    print(f"Reading statistics from `{filename}`")

    with open(filename) as f:
        statistics = json.load(f)

    statistics['atomic_numbers'] = AtomicNumberTable(statistics['atomic_numbers'])
    statistics['atomic_energies'] = [ statistics['atomic_energies'][str(i)] for i in statistics['atomic_numbers'].zs ]

    return statistics

statistics = get_statistics()

W1010 17:13:12.401000 204324 /home/pbenner/Env/mace-jax/.venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:118] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


Reading statistics from `statistics.json`


In [5]:
def get_model_args():

    from mace.tools import build_default_arg_parser, check_args  # noqa: PLC0415

    arguments = [
        "--name"              , "MACE_large_density",
        "--interaction_first" , "RealAgnosticDensityInteractionBlock",
        "--interaction"       , "RealAgnosticDensityResidualInteractionBlock",
        "--num_channels"      , "128",
        "--max_L"             , "2",
        "--max_ell"           , "3",
        "--num_interactions"  , "3",
        "--correlation"       , "3",
        "--num_radial_basis"  , "8",
        "--MLP_irreps"        , "16x0e",
        "--distance_transform", "Agnesi",
        "--pair_repulsion",
        "--only_cueq", "True"
    ]

    args    = build_default_arg_parser().parse_args(arguments)
    args, _ = check_args(args)

    return args

model_args = get_model_args()

In [6]:
# SrTiO3 perovskite unit cell
a0 = 3.905  # Angstrom
atoms = ase.Atoms(
    symbols=['Sr', 'Ti', 'O', 'O', 'O'],
    positions=a0 * np.array(
    [
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.0],
        [0.5, 0.0, 0.5],
        [0.0, 0.5, 0.5],
    ]),
    cell=a0 * np.identity(3),
    pbc=[True, True, True],
)

In [7]:
config = config_from_atoms(atoms)
config.pbc = [bool(x) for x in config.pbc]
x = AtomicData.from_config(config, z_table=statistics['atomic_numbers'], cutoff=2.0)
x = torch_geometric.batch.Batch.from_data_list([x])

---
## Prepare Torch model
---

In [8]:
from mace.tools.model_script_utils import configure_model as configure_model_torch

In [9]:
def get_model_torch(args: argparse.Namespace, statistics):

    from mace.data.utils import KeySpecification  # noqa: PLC0415
    from mace.tools.multihead_tools import (  # noqa: PLC0415
        prepare_default_head,
    )

    torch.set_default_dtype(torch.float32)

    args.mean = statistics['mean']
    args.std = statistics['std']
    args.compute_energy = True
    args.compute_dipole = False
    args.key_specification = KeySpecification.from_defaults()
    args.heads = prepare_default_head(args)

    model, _ = configure_model_torch(
        args,
        None,
        statistics['atomic_energies'],
        z_table = statistics['atomic_numbers'],
        heads = args.heads,
        model_foundation = None,
    )
    return model

In [10]:
model_torch = get_model_torch(model_args, statistics)



---
## Prepare JAX model
---

In [11]:
def configure_model_jax(
    args,
    atomic_energies,
    z_table=None,
    model_foundation=None,
    head_configs=None,
):
    import ast  # noqa: PLC0415

    from e3nn_jax import Irreps  # noqa: PLC0415

    from mace_jax import modules  # noqa: PLC0415

    model_config = dict(
        r_max=args.r_max,
        num_bessel=args.num_radial_basis,
        num_polynomial_cutoff=args.num_cutoff_basis,
        max_ell=args.max_ell,
        interaction_cls=modules.interaction_classes[args.interaction],
        num_interactions=args.num_interactions,
        num_elements=len(z_table),
        hidden_irreps=Irreps(args.hidden_irreps),
        edge_irreps=Irreps(args.edge_irreps) if args.edge_irreps else None,
        atomic_energies=atomic_energies,
        apply_cutoff=args.apply_cutoff,
        avg_num_neighbors=args.avg_num_neighbors,
        atomic_numbers=tuple(int(z) for z in z_table.zs),
        use_reduced_cg=args.use_reduced_cg,
        use_so3=args.use_so3,
        cueq_config=None,
    )
    return modules.ScaleShiftMACE(
        **model_config,
        pair_repulsion=args.pair_repulsion,
        distance_transform=args.distance_transform,
        correlation=args.correlation,
        gate=modules.gate_dict[args.gate],
        interaction_cls_first=modules.interaction_classes[args.interaction_first],
        MLP_irreps=Irreps(args.MLP_irreps),
        atomic_inter_scale=args.std,
        atomic_inter_shift=args.mean,
        radial_MLP=ast.literal_eval(args.radial_MLP),
        radial_type=args.radial_type,
        heads=tuple(args.heads) if args.heads is not None else None,
        embedding_specs=args.embedding_specs,
        use_embedding_readout=args.use_embedding_readout,
        use_last_readout_only=args.use_last_readout_only,
        use_agnostic_product=args.use_agnostic_product,
    )

In [12]:
def get_model_jax(args: argparse.Namespace, statistics):

    from mace.data.utils import KeySpecification  # noqa: PLC0415
    from mace.tools.multihead_tools import (  # noqa: PLC0415
        prepare_default_head,
    )

    args.mean = statistics['mean']
    args.std = statistics['std']
    args.compute_energy = True
    args.compute_dipole = False
    args.key_specification = KeySpecification.from_defaults()
    args.heads = prepare_default_head(args)

    model = configure_model_jax(
        args,
        statistics['atomic_energies'],
        z_table = statistics['atomic_numbers'],
        model_foundation = None,
    )
    return model


In [13]:
batch_jax = batch_to_jax(x)

model_jax = get_model_jax(model_args, statistics)

print("Initializing Flax model, this may take a while...")
rng = jax.random.PRNGKey(42)
variables = model_jax.init(rng, batch_jax)
variables = modules.ScaleShiftMACE.import_from_torch(model_torch, variables)



Initializing Flax model, this may take a while...


In [14]:
result_jax = model_jax.apply(variables, batch_jax, compute_stress=True)
print(f"energy: {result_jax['energy']}\nforces: {result_jax['forces']}\nstress: {result_jax['stress']}")

energy: [-34.20229984]
forces: [[-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]
 [-0. -0. -0.]]
stress: [[[-0.00320603  0.          0.        ]
  [ 0.         -0.00320603  0.        ]
  [ 0.          0.         -0.00320603]]]


In [15]:
result_torch = model_torch(x, compute_stress=True)
print(f"energy: {result_torch['energy']}\nforces: {result_torch['forces']}\nstress: {result_torch['stress']}")

energy: tensor([-34.5958], grad_fn=<AddBackward0>)
forces: tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00],
        [ 3.7253e-09, -3.7253e-09, -0.0000e+00],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00],
        [-0.0000e+00,  3.7253e-09, -0.0000e+00],
        [-3.7253e-09, -0.0000e+00, -0.0000e+00]])
stress: tensor([[[-0.0027,  0.0000,  0.0000],
         [ 0.0000, -0.0027,  0.0000],
         [ 0.0000,  0.0000, -0.0027]]])
