## Torch → JAX foundation model initialisation
---
This notebook downloads a pre-trained Torch MACE foundation model, ports it to the
JAX implementation using the same helpers that power the `mace_torch2jax` CLI, and then
runs both versions on a shared example batch.


In [1]:
import json
import os
from pathlib import Path

import ase
import jax
import jax.numpy as jnp
import numpy as np
import torch

jax.config.update('jax_enable_x64', True)

from mace.calculators import foundations_models
from mace.data.atomic_data import AtomicData
from mace.data.utils import config_from_atoms
from mace.tools import torch_geometric
from mace.tools.multihead_tools import AtomicNumberTable
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.torch_geometric.batch import Batch

from mace_jax.cli.mace_torch2jax import convert_model

DEVICE = 'cpu'  # Device used for the Torch model; set to 'cuda' to run Torch on a GPU
if DEVICE == 'cuda':
    # Determinism configuration for Torch on CUDA.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # Ask XLA for deterministic GPU operations in JAX. The flag is appended rather than
    # assigned so that XLA flags already present in the environment are preserved.
    # NOTE(review): presumably this has to be set before the first JAX backend use
    # (the jax.devices() call below) to take effect -- confirm.
    xla_deterministic_flag = "--xla_gpu_deterministic_ops=true"
    existing_xla_flags = os.environ.get("XLA_FLAGS", "")
    if xla_deterministic_flag not in existing_xla_flags.split():
        os.environ["XLA_FLAGS"] = f"{existing_xla_flags} {xla_deterministic_flag}".strip()

# Report what each framework can see; DEVICE above only selects the Torch device.
print(f"JAX devices: {jax.devices()}")
print(f"Torch CUDA available: {torch.cuda.is_available()}")

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


JAX devices: [CudaDevice(id=0)]
Torch CUDA available: True


---
### Build a Torch reference model
---

In [2]:
def load_foundation_model(source: str = 'mp', variant: str | None = None, device: str = 'cpu') -> torch.nn.Module:
    """Return a pretrained Torch MACE foundation model on the specified device."""
    loader_kwargs = {'device': device}
    source_lower = source.lower()
    if source_lower in {'mp', 'off', 'omol'}:
        loader = getattr(foundations_models, f'mace_{source_lower}')
        if variant is not None:
            loader_kwargs['model'] = variant
    elif source_lower == 'anicc':
        loader = foundations_models.mace_anicc
        if variant is not None:
            loader_kwargs['model_path'] = variant
    else:
        raise ValueError(
            "Unknown foundation source. Expected one of {'mp', 'off', 'anicc', 'omol'}."
        )

    try:
        model = loader(return_raw_model=True, **loader_kwargs)
    except Exception:
        calculator = loader(return_raw_model=False, **loader_kwargs)
        model = getattr(calculator, 'model', None)
        if model is None:
            models = getattr(calculator, 'models', None)
            if models:
                model = models[0]
        if model is None:
            raise

    return model.float().eval()

def extract_foundation_metadata(torch_model):
    """Collect the conversion inputs derived from a Torch foundation model.

    Returns a ``(config, z_table, cutoff)`` tuple where ``config`` is the result of
    ``extract_config_mace_model`` with an added ``'torch_model_class'`` entry (the
    model's class name), ``z_table`` is an ``AtomicNumberTable`` built from
    ``config['atomic_numbers']`` and ``cutoff`` is ``config['r_max']`` as a float.
    """
    model_config = extract_config_mace_model(torch_model)
    model_config['torch_model_class'] = torch_model.__class__.__name__
    species = tuple(map(int, model_config['atomic_numbers']))
    return model_config, AtomicNumberTable(species), float(model_config['r_max'])

In [3]:
# Which pretrained foundation model to load: the family and the variant selector
# that load_foundation_model forwards to the MACE loader.
FOUNDATION_SOURCE = 'mp'
FOUNDATION_VARIANT = 'medium-mpa-0'

# Load the Torch reference model, then derive what the later cells consume:
# the model config, the atomic-number table and the cutoff (config['r_max']).
torch_model = load_foundation_model(FOUNDATION_SOURCE, FOUNDATION_VARIANT, device=DEVICE)
config, z_table, cutoff = extract_foundation_metadata(torch_model)
# DEVICE was already forwarded to the loader above; this move makes the placement explicit.
torch_model = torch_model.to(torch.device(DEVICE))
print(f"Torch model loaded on device: {next(torch_model.parameters()).device}")

Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument
Using Materials Project MACE for MACECalculator with /home/abhijeet/.cache/mace/macempa0mediummodel
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Torch model loaded on device: cpu


---
### Convert Torch model to JAX
---

In [4]:
# Port the Torch model to JAX with the converter imported from mace_jax.cli.mace_torch2jax.
# jax_model and variables are what the later cells pass to jax_model.apply.
# NOTE(review): template_batch is not used by any cell shown in this notebook.
jax_model, variables, template_batch = convert_model(torch_model, config)

---
### Prepare an evaluation batch
---

In [5]:
def batch_to_jax(batch: Batch) -> dict[str, jnp.ndarray]:
    """Mirror a Torch batch as a plain dict with one entry per name in ``batch.keys``.

    Torch tensors are detached, copied to host memory and wrapped with
    ``jnp.asarray``; every other value is passed through unchanged.
    """

    def to_jax(value):
        if not isinstance(value, torch.Tensor):
            return value
        return jnp.asarray(value.detach().cpu().numpy())

    return {name: to_jax(batch[name]) for name in batch.keys}

# Example structure: cubic SrTiO3 in a periodic cell of edge a0. The positions are
# fractional coordinates scaled by a0: Sr at the corner, Ti at the body centre and
# O at the three face centres.
a0 = 3.905
atoms = ase.Atoms(
    symbols=['Sr', 'Ti', 'O', 'O', 'O'],
    positions=a0 * np.array([
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.0],
        [0.5, 0.0, 0.5],
        [0.0, 0.5, 0.5],
    ]),
    cell=a0 * np.identity(3),
    pbc=[True, True, True],
)
config_atoms = config_from_atoms(atoms)
# Replace the pbc flags with plain Python bools.
# NOTE(review): presumably downstream code does not accept numpy bools -- confirm.
config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
# Pass the model's head list, as the JSON sanity-check cell further down does, so both
# evaluation paths build their AtomicData the same way.
atomic_data = AtomicData.from_config(
    config_atoms,
    z_table=z_table,
    cutoff=cutoff,
    heads=config.get('heads'),
)
batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
batch_torch = batch_torch.to(torch.device(DEVICE))
print(f"Batch moved to device: {batch_torch.positions.device}")

# JAX-side copy of the same batch for jax_model.apply.
batch_jax = batch_to_jax(batch_torch)

Batch moved to device: cpu


---
### Compare Torch and JAX outputs
---

In [None]:
# Evaluate both implementations on the shared batch, requesting stress as well.
jax_out = jax_model.apply(variables, batch_jax, compute_stress=True)
torch_out = torch_model(batch_torch, compute_stress=True)

# Print each quantity as a JAX/Torch pair so the values can be compared by eye.
comparison_rows = (
    ('JAX energy:', jax_out['energy']),
    ('Torch energy:', torch_out['energy']),
    ('JAX forces:', jax_out['forces']),
    ('Torch forces:', torch_out['forces']),
    ('JAX stress:', jax_out['stress']),
    ('Torch stress:', torch_out['stress']),
)
for label, value in comparison_rows:
    print(label, value)

---
### Foundation model energy sanity checks (JSON-loaded structures)
---

In [None]:
def atoms_from_serialised(record):
    """Build a fully periodic ``ase.Atoms`` from a serialised structure record.

    Reads ``record['structure']['lattice']`` as the cell and, for every entry of
    ``record['structure']['sites']``, its ``'species'`` as the symbol and ``'xyz'``
    as the position. Structures with fewer than five atoms are returned as a
    2x2x2 repetition of the cell.
    """
    structure = record['structure']
    cell = np.asarray(structure['lattice'], dtype=float)
    species = [site['species'] for site in structure['sites']]
    positions = np.asarray([site['xyz'] for site in structure['sites']], dtype=float)
    unit_cell = ase.Atoms(symbols=species, positions=positions, cell=cell, pbc=[True, True, True])
    # NOTE(review): the reason for the five-atom threshold is not visible here -- confirm.
    return unit_cell.repeat((2, 2, 2)) if len(unit_cell) < 5 else unit_cell

# Load the serialised structures; the path is relative to the working directory.
records_path = Path('material_structures.json')
with records_path.open(encoding='utf-8') as fh:
    structures = json.load(fh)

# Evaluate every structure with both the Torch model and its JAX port.
# NOTE: the loop rebinds atoms, config_atoms, atomic_data, batch_torch and batch_jax,
# so after this cell those names refer to the last structure in the file.
results = []
for record in structures:
    atoms = atoms_from_serialised(record)
    config_atoms = config_from_atoms(atoms)
    config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
    atomic_data = AtomicData.from_config(
        config_atoms,
        z_table=z_table,
        cutoff=cutoff,
        heads=config.get('heads'),
    )
    batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
    batch_torch = batch_torch.to(torch.device(DEVICE))
    batch_jax = batch_to_jax(batch_torch)

    torch_output = torch_model(batch_torch, compute_stress=False)
    jax_output = jax_model.apply(variables, batch_jax, compute_stress=False)

    # Store both energies as plain Python floats so the entries have a uniform type.
    results.append(
        {
            'material': record['label'],
            'mp_id': record['mp_id'],
            'torch_energy': float(torch_output['energy'].detach().cpu().numpy()[0]),
            'jax_energy': float(np.asarray(jax_output['energy'])[0]),
        }
    )

print(f"Foundation model: {torch_model.__class__.__name__} ({FOUNDATION_VARIANT})")
for entry in results:
    print(f"Material: {entry['material']} (MP ID: {entry['mp_id']})")
    print(f"  Torch ({torch_model.__class__.__name__}) energy: {entry['torch_energy']:.6f} eV")
    print(f"  JAX ({jax_model.__class__.__name__}) energy:   {entry['jax_energy']:.6f} eV")
    print()

Foundation model: ScaleShiftMACE (medium-mpa-0)
Material: Al2O3 (MP ID: mp-1143)
  Torch (ScaleShiftMACE) energy: -74.799065 eV
  JAX (ScaleShiftMACE) energy:   -74.799072 eV

Material: MgSiO3 (MP ID: mp-3470)
  Torch (ScaleShiftMACE) energy: -287.229797 eV
  JAX (ScaleShiftMACE) energy:   -287.229817 eV

Material: Li2O (MP ID: mp-1960)
  Torch (ScaleShiftMACE) energy: -113.934143 eV
  JAX (ScaleShiftMACE) energy:   -113.934143 eV



In [None]:
# JIT compile the JAX model for better performance.
# compute_stress is a Python flag rather than array data, so it is declared static:
# jit then specialises on its value instead of tracing it as an abstract boolean.
# NOTE(review): the recorded run of this cell failed with TracerBoolConversionError, and
# the trace points into mace_jax/modules/utils.py (prepare_graph,
# get_edge_vectors_and_lengths). If the error persists with the static flag, the
# data-dependent Python branch has to be fixed in that module -- confirm.
jit_jax_model = jax.jit(jax_model.apply, static_argnames=('compute_stress',))

# Test JIT compiled model
jit_jax_output = jit_jax_model(variables, batch_jax, compute_stress=False)
print('JIT JAX energy:', jit_jax_output['energy'])
print('JIT JAX forces:', jit_jax_output['forces'])

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function apply at /home/abhijeet/.local/share/mamba/envs/macejax_env/lib/python3.12/site-packages/flax/linen/module.py:2088 for jit. This value became a tracer due to JAX operations on these lines:

  operation a[35m:bool[24][39m = lt b 0:i32[]
    from line /home/abhijeet/Desktop/Home/Code/mace-jax-dev/mace-jax/mace_jax/modules/utils.py:255:21 (prepare_graph)

  operation a[35m:i32[24][39m = add b 1:i32[]
    from line /home/abhijeet/Desktop/Home/Code/mace-jax-dev/mace-jax/mace_jax/modules/utils.py:255:21 (prepare_graph)

  operation a[35m:f64[1,3,3][39m = broadcast_in_dim[
  broadcast_dimensions=()
  shape=(1, 3, 3)
  sharding=None
] 0.0:f64[]
    from line /home/abhijeet/Desktop/Home/Code/mace-jax-dev/mace-jax/mace_jax/modules/utils.py:301:19 (prepare_graph)

  operation a[35m:bool[2512][39m = lt b 0:i32[]
    from line /home/abhijeet/Desktop/Home/Code/mace-jax-dev/mace-jax/mace_jax/modules/utils.py:213:14 (get_edge_vectors_and_lengths)

  operation a[35m:i32[2512][39m = add b 24:i32[]
    from line /home/abhijeet/Desktop/Home/Code/mace-jax-dev/mace-jax/mace_jax/modules/utils.py:213:14 (get_edge_vectors_and_lengths)

(Additional originating lines are not shown.)
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError