## Torch → JAX foundation model initialisation
---
This notebook downloads a pre-trained Torch MACE foundation model, ports it to the
JAX implementation using the same helpers that power the `mace_torch2jax` CLI, and then
runs both versions on a shared example batch.


In [1]:
import json
from pathlib import Path

import ase
import jax
import jax.numpy as jnp
import numpy as np
import torch

jax.config.update('jax_enable_x64', True)

from mace.calculators import foundations_models
from mace.data.atomic_data import AtomicData
from mace.data.utils import config_from_atoms
from mace.tools import torch_geometric
from mace.tools.multihead_tools import AtomicNumberTable
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.torch_geometric.batch import Batch

from mace_jax.cli.mace_torch2jax import convert_model
from mace_jax.tools.device import (
    configure_torch_runtime,
    get_torch_device,
    runtime_device_summary,
)

# Resolve the Torch device and hand it to the runtime configuration helper,
# requesting deterministic behaviour through the `deterministic=True` flag.
torch_device = configure_torch_runtime(get_torch_device(), deterministic=True)

# Report the chosen device together with the helper's runtime summary.
print(f'Selected Torch device: {torch_device}')
print(f'Device summary: {runtime_device_summary(torch_device)}')


  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


Selected Torch device: cpu
Device summary: {'jax_devices': ['cpu:0'], 'torch_device': 'cpu', 'torch_cuda_available': False}


---
### Build a Torch reference model
---

In [2]:
def load_foundation_model(source: str = 'mp', variant: str | None = None, device: str = 'cpu') -> torch.nn.Module:
    """Return a pretrained Torch MACE foundation model on the specified device."""
    loader_kwargs = {'device': device}
    source_lower = source.lower()
    if source_lower in {'mp', 'off', 'omol'}:
        loader = getattr(foundations_models, f'mace_{source_lower}')
        if variant is not None:
            loader_kwargs['model'] = variant
    elif source_lower == 'anicc':
        loader = foundations_models.mace_anicc
        if variant is not None:
            loader_kwargs['model_path'] = variant
    else:
        raise ValueError(
            "Unknown foundation source. Expected one of {'mp', 'off', 'anicc', 'omol'}."
        )

    try:
        model = loader(return_raw_model=True, **loader_kwargs)
    except Exception:
        calculator = loader(return_raw_model=False, **loader_kwargs)
        model = getattr(calculator, 'model', None)
        if model is None:
            models = getattr(calculator, 'models', None)
            if models:
                model = models[0]
        if model is None:
            raise

    return model.float().eval()

def extract_foundation_metadata(torch_model):
    """Collect the pieces of a Torch foundation model needed to build inputs.

    Returns a 3-tuple ``(config, z_table, cutoff)``:
      * ``config`` is the dict produced by ``extract_config_mace_model``, with
        an added ``'torch_model_class'`` entry holding the model's class name;
      * ``z_table`` is an ``AtomicNumberTable`` built from
        ``config['atomic_numbers']`` converted to ints;
      * ``cutoff`` is ``config['r_max']`` as a float.
    """
    model_config = extract_config_mace_model(torch_model)
    model_config['torch_model_class'] = torch_model.__class__.__name__
    species = tuple(map(int, model_config['atomic_numbers']))
    return model_config, AtomicNumberTable(species), float(model_config['r_max'])

In [3]:
# Which pretrained foundation model to load: the loader family and the
# variant name within that family.
FOUNDATION_SOURCE = 'mp'
FOUNDATION_VARIANT = 'medium-mpa-0'

# The model is loaded with load_foundation_model's default device, the
# metadata is read from it, and only then is it moved to `torch_device`.
# NOTE(review): keep this order unless extract_foundation_metadata is known
# to work on a model that already lives on a non-CPU device.
torch_model = load_foundation_model(FOUNDATION_SOURCE, FOUNDATION_VARIANT)
config, z_table, cutoff = extract_foundation_metadata(torch_model)
torch_model = torch_model.to(torch_device)
# Confirm the move by reporting the device of the first parameter.
print(f"Torch model loaded on device: {next(torch_model.parameters()).device}")


Using Materials Project MACE for MACECalculator with /home/philipp/Source/mace-jax/mace/mace/calculators/foundations_models/mace-mpa-0-medium.model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Torch model loaded on device: cpu


---
### Convert Torch model to JAX
---

In [4]:
# Port the Torch model to JAX using the extracted config. convert_model
# returns three values, unpacked here as the JAX model, its variables and a
# template batch.
# NOTE(review): template_batch is not used by the cells visible in this
# notebook — confirm whether it is still needed.
jax_model, variables, template_batch = convert_model(torch_model, config)

---
### Prepare an evaluation batch
---

In [5]:
def batch_to_jax(batch: Batch) -> dict[str, jnp.ndarray]:
    """Turn a Torch batch into a plain dict keyed by the batch's field names.

    Torch tensors are detached, copied to host memory and wrapped as JAX
    arrays; any other value is passed through unchanged.
    """

    def _as_jax(item):
        # Non-tensor entries are kept exactly as stored in the batch.
        if not isinstance(item, torch.Tensor):
            return item
        return jnp.asarray(item.detach().cpu().numpy())

    return {name: _as_jax(batch[name]) for name in batch.keys}

# Example structure: five-atom SrTiO3 cell. Positions are written as
# fractions of the lattice constant `a0` and scaled into a cubic cell.
a0 = 3.905
atoms = ase.Atoms(
    symbols=['Sr', 'Ti', 'O', 'O', 'O'],
    positions=a0 * np.array([
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.0],
        [0.5, 0.0, 0.5],
        [0.0, 0.5, 0.5],
    ]),
    cell=a0 * np.identity(3),
    pbc=[True, True, True],
)
config_atoms = config_from_atoms(atoms)
# Replace the pbc entries with plain Python bools before building the data object.
config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
atomic_data = AtomicData.from_config(
    config_atoms,
    z_table=z_table,
    cutoff=cutoff,
    # Pass the model's own head list (None when the config has no 'heads'
    # entry), matching the JSON sanity-check cell, so a multi-head model
    # resolves the head index against its real heads.
    heads=config.get('heads'),
)
# Wrap the single structure in a batch and move it to the model's device.
batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
batch_torch = batch_torch.to(torch_device)
print(f"Batch moved to device: {batch_torch.positions.device}")

# JAX-side copy of the same batch, used for the JAX model below.
batch_jax = batch_to_jax(batch_torch)


Batch moved to device: cpu


---
### Compare Torch and JAX outputs
---

In [6]:
# Evaluate both implementations on the same structure, stress included.
jax_out = jax_model.apply(variables, batch_jax, compute_stress=True)
torch_out = torch_model(batch_torch, compute_stress=True)

# Print each quantity as a JAX/Torch pair so the values can be compared by eye.
for quantity in ('energy', 'forces', 'stress'):
    print(f'JAX {quantity}:', jax_out[quantity])
    print(f'Torch {quantity}:', torch_out[quantity])

JAX energy: [-40.09476294]
Torch energy: tensor([-40.0948], grad_fn=<AddBackward0>)
JAX forces: [[-4.1306920e-16 -5.4752210e-17 -4.2836828e-16]
 [-2.6232306e-16 -4.1221279e-16  1.3268161e-15]
 [ 1.1535911e-16 -7.8452869e-16 -5.8373445e-16]
 [ 5.2041704e-18  3.9725168e-16  1.5768636e-15]
 [ 7.1730747e-16  9.8944290e-16 -2.0007867e-15]]
Torch forces: tensor([[ 1.2655e-07,  1.4831e-07,  1.4075e-07],
        [-2.3149e-07,  1.0955e-07, -5.5951e-07],
        [-1.6717e-07, -1.1898e-07,  4.4797e-07],
        [ 1.4715e-07,  7.8697e-07,  1.6834e-07],
        [ 2.3330e-07, -6.5658e-07, -1.4051e-07]])
JAX stress: [[[-3.3666383e-02 -1.4657818e-09 -2.0612556e-09]
  [-1.4657818e-09 -3.3666413e-02  1.2978275e-09]
  [-2.0612556e-09  1.2978275e-09 -3.3666506e-02]]]
Torch stress: tensor([[[-3.3666e-02,  1.2668e-08, -2.3222e-08],
         [ 1.2668e-08, -3.3666e-02,  6.4436e-09],
         [-2.3222e-08,  6.4436e-09, -3.3666e-02]]])


---
### Foundation model energy sanity checks (JSON-loaded structures)
---

In [7]:
def atoms_from_serialised(record):
    """Build a fully periodic ``ase.Atoms`` from one serialised structure record.

    Reads ``record['structure']['lattice']`` for the cell and, for every entry
    in ``record['structure']['sites']``, its ``'species'`` and ``'xyz'``
    values. Structures with fewer than five atoms are returned as a 2x2x2
    repetition of the cell rather than as the original cell.
    """
    structure = record['structure']
    cell = np.asarray(structure['lattice'], dtype=float)
    sites = structure['sites']
    species = [site['species'] for site in sites]
    cartesian = np.asarray([site['xyz'] for site in sites], dtype=float)
    crystal = ase.Atoms(symbols=species, positions=cartesian, cell=cell, pbc=[True, True, True])
    # Large enough already: hand back the cell as read.
    if len(crystal) >= 5:
        return crystal
    return crystal.repeat((2, 2, 2))

# The structure records are read from a path relative to the current working
# directory, so the JSON file must sit next to wherever the kernel runs.
records_path = Path('material_structures.json')
with records_path.open() as fh:
    structures = json.load(fh)

# Evaluate every record with both models and collect the energies side by side.
# NOTE: the loop rebinds the notebook-level names atoms, config_atoms,
# atomic_data, batch_torch and batch_jax; after it finishes they refer to the
# last record in the file, not to the structure built in the earlier cell.
results = []
for record in structures:
    atoms = atoms_from_serialised(record)
    config_atoms = config_from_atoms(atoms)
    # Replace the pbc entries with plain Python bools before building the data object.
    config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
    atomic_data = AtomicData.from_config(
        config_atoms,
        z_table=z_table,
        cutoff=cutoff,
        heads=config.get('heads'),
    )
    # One structure per batch, moved to the model's device and mirrored for JAX.
    batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
    batch_torch = batch_torch.to(torch_device)
    batch_jax = batch_to_jax(batch_torch)

    # Only energies are compared here, so stress is switched off for both models.
    torch_output = torch_model(batch_torch, compute_stress=False)
    jax_output = jax_model.apply(variables, batch_jax, compute_stress=False)

    # Index [0] picks the single structure in the batch. The Torch value stays
    # a NumPy scalar while the JAX value is converted to a Python float.
    results.append(
        {
            'material': record['label'],
            'mp_id': record['mp_id'],
            'torch_energy': torch_output['energy'].detach().cpu().numpy()[0],
            'jax_energy': float(np.asarray(jax_output['energy'])[0]),
        }
    )

# Report the paired energies per material. The energy is for the structure
# returned by atoms_from_serialised, which may be a repeated cell.
print(f"Foundation model: {torch_model.__class__.__name__} ({FOUNDATION_VARIANT})")
for entry in results:
    print(f"Material: {entry['material']} (MP ID: {entry['mp_id']})")
    print(f"  Torch ({torch_model.__class__.__name__}) energy: {entry['torch_energy']:.6f} eV")
    print(f"  JAX ({jax_model.__class__.__name__}) energy:   {entry['jax_energy']:.6f} eV")
    print()


Foundation model: ScaleShiftMACE (medium-mpa-0)
Material: Al2O3 (MP ID: mp-1143)
  Torch (ScaleShiftMACE) energy: -74.799065 eV
  JAX (ScaleShiftMACE) energy:   -74.799073 eV

Material: MgSiO3 (MP ID: mp-3470)
  Torch (ScaleShiftMACE) energy: -287.229797 eV
  JAX (ScaleShiftMACE) energy:   -287.229823 eV

Material: Li2O (MP ID: mp-1960)
  Torch (ScaleShiftMACE) energy: -113.934151 eV
  JAX (ScaleShiftMACE) energy:   -113.934157 eV



In [9]:
# Wrap the JAX model's apply function with jax.jit so repeated calls can reuse
# the compiled computation.
jit_jax_model = jax.jit(jax_model.apply)

# Smoke-test the compiled function. batch_jax is not defined in this cell: it
# is whatever the most recently executed cell left in the notebook namespace,
# so the printed values depend on execution order.
jit_jax_output = jit_jax_model(variables, batch_jax, compute_stress=False)
print('JIT JAX energy:', jit_jax_output['energy'])
print('JIT JAX forces:', jit_jax_output['forces'])

JIT JAX energy: [-113.9341567]
JIT JAX forces: [[ 3.93212531e-07 -1.18850799e-07  7.89958847e-07]
 [ 3.13900216e-07 -4.03335633e-07  6.41009734e-08]
 [ 1.31217206e-07 -5.39205018e-07  1.31559858e-07]
 [ 3.45689614e-07 -1.64081996e-07  5.38921903e-08]
 [ 3.92819516e-07 -3.31421603e-07  4.15970163e-07]
 [ 5.41347049e-07 -2.32153468e-07  8.29453882e-07]
 [ 2.87010636e-07  5.10786492e-07 -4.83557983e-07]
 [ 2.91991569e-07 -1.16606813e-07  6.12360722e-08]
 [ 4.19995217e-07 -6.42548343e-08  4.78216691e-07]
 [-9.23437948e-09  1.38274970e-06  3.08417896e-07]
 [ 3.73079530e-07  1.26730654e-07 -7.13185329e-07]
 [ 3.94275816e-07 -8.22094250e-08 -9.38564597e-07]
 [-1.27960641e-06 -1.18716400e-07 -4.70802064e-07]
 [ 8.99598689e-08 -3.58053484e-07  5.29210240e-08]
 [ 1.42866014e-07 -3.84033370e-07  4.80888389e-07]
 [-5.63067260e-07 -6.82715950e-07  3.04773749e-07]
 [ 3.41937096e-07 -3.65845324e-07 -7.06360822e-07]
 [ 1.09914275e-07 -4.04661563e-07 -9.27937663e-07]
 [-8.41715121e-07  1.25499434e-06  