---
## Torch â†’ JAX model initialisation
---
This notebook shows how to take an existing Torch MACE model, convert it
to the JAX implementation using the same helpers that power the
`mace_torch2jax` CLI, and then run both versions on the same example batch.

In [1]:
import json
from pathlib import Path

import ase
import jax
import jax.numpy as jnp
import numpy as np
import torch

jax.config.update('jax_enable_x64', True)

from mace.data.atomic_data import AtomicData
from mace.data.utils import KeySpecification, config_from_atoms
from mace.tools import build_default_arg_parser, check_args, torch_geometric
from mace.tools.model_script_utils import configure_model as configure_model_torch
from mace.tools.multihead_tools import AtomicNumberTable, prepare_default_head
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.torch_geometric.batch import Batch

from mace_jax.cli.mace_torch2jax import convert_model


  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
W1019 11:01:32.640000 1154533 /home/pbenner/Env/mace-jax/.venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:118] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


---
### Build a Torch reference model
---

In [2]:
def load_statistics(path: Path = Path('statistics.json')):
    data = json.loads(path.read_text())
    stats = dict(data)
    stats['atomic_numbers'] = AtomicNumberTable(stats['atomic_numbers'])
    stats['atomic_energies'] = [
        stats['atomic_energies'][str(z)] for z in stats['atomic_numbers'].zs
    ]
    return stats


def build_example_args():
    arguments = [
        '--name', 'MACE_large_density',
        '--interaction_first', 'RealAgnosticDensityInteractionBlock',
        '--interaction', 'RealAgnosticDensityResidualInteractionBlock',
        '--num_channels', '128',
        '--max_L', '2',
        '--max_ell', '3',
        '--num_interactions', '3',
        '--correlation', '3',
        '--num_radial_basis', '8',
        '--MLP_irreps', '16x0e',
        '--distance_transform', 'Agnesi',
        '--pair_repulsion',
        '--only_cueq', 'True',
    ]
    parser = build_default_arg_parser()
    args = parser.parse_args(arguments)
    args, _ = check_args(args)
    return args


def build_torch_model(args, statistics):
    args.mean = statistics['mean']
    args.std = statistics['std']
    args.compute_energy = True
    args.compute_dipole = False
    args.key_specification = KeySpecification.from_defaults()
    args.heads = prepare_default_head(args)

    torch.set_default_dtype(torch.float32)
    model, _ = configure_model_torch(
        args,
        None,
        statistics['atomic_energies'],
        z_table=statistics['atomic_numbers'],
        heads=args.heads,
        model_foundation=None,
    )
    model.eval()
    return model


In [3]:
statistics = load_statistics(Path('statistics.json'))
torch_args = build_example_args()
torch_model = build_torch_model(torch_args, statistics)




---
### Convert Torch model to JAX
---

In [4]:
config = extract_config_mace_model(torch_model)
config['torch_model_class'] = torch_model.__class__.__name__
jax_model, variables, template_batch = convert_model(torch_model, config)




---
### Prepare an evaluation batch
---

In [5]:
def batch_to_jax(batch: Batch) -> dict[str, jnp.ndarray]:
    converted = {}
    for key in batch.keys:
        value = batch[key]
        if isinstance(value, torch.Tensor):
            converted[key] = jnp.asarray(value.detach().cpu().numpy())
        else:
            converted[key] = value
    return converted

# Example SrTiO3 structure as before
a0 = 3.905
atoms = ase.Atoms(
    symbols=['Sr', 'Ti', 'O', 'O', 'O'],
    positions=a0 * np.array([
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.0],
        [0.5, 0.0, 0.5],
        [0.0, 0.5, 0.5],
    ]),
    cell=a0 * np.identity(3),
    pbc=[True, True, True],
)
config_atoms = config_from_atoms(atoms)
config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
atomic_data = AtomicData.from_config(
    config_atoms,
    z_table=statistics['atomic_numbers'],
    cutoff=torch_args.r_max,
)
batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
batch_jax = batch_to_jax(batch_torch)


---
### Compare Torch and JAX outputs
---

In [6]:
jax_out = jax_model.apply(variables, batch_jax, compute_stress=True)
torch_out = torch_model(batch_torch, compute_stress=True)

print('JAX energy:', jax_out['energy'])
print('Torch energy:', torch_out['energy'])
print('JAX forces:', jax_out['forces'])
print('Torch forces:', torch_out['forces'])
print('JAX stress:', jax_out['stress'])
print('Torch stress:', torch_out['stress'])


JAX energy: [-34.23169481]
Torch energy: tensor([-34.3235], grad_fn=<AddBackward0>)
JAX forces: [[-3.5766813e-16 -3.1762380e-16  2.2593418e-16]
 [-6.0715322e-18  8.6447653e-18 -3.7816972e-16]
 [ 1.0043812e-15  3.8329596e-16  2.6496546e-16]
 [-1.8431437e-16  1.8252544e-16  3.2289574e-16]
 [-4.4876060e-16 -2.5798932e-16 -4.5882081e-16]]
Torch forces: tensor([[-2.6441e-08, -2.2097e-08, -5.8884e-08],
        [ 8.6002e-09, -6.4191e-09, -7.6870e-09],
        [-1.4959e-08, -6.3694e-08,  9.9535e-08],
        [ 5.6098e-08,  7.1595e-08, -3.0195e-08],
        [-2.2999e-08,  1.9916e-08, -7.3269e-09]])
JAX stress: [[[-1.6334608e-02 -2.8270712e-11  4.1750032e-12]
  [-2.8270712e-11 -1.6334604e-02 -4.9264953e-11]
  [ 4.1750032e-12 -4.9264953e-11 -1.6334604e-02]]]
Torch stress: tensor([[[-2.6785e-03, -9.2789e-10, -1.9385e-09],
         [-9.2789e-10, -2.6785e-03,  1.0321e-08],
         [-1.9385e-09,  1.0321e-08, -2.6785e-03]]])
