## Torch → JAX foundation model initialisation
---
This notebook downloads a pre-trained Torch MACE foundation model, ports it to the
JAX implementation using the same helpers that power the `mace_torch2jax` CLI, and then
runs both versions on a shared example batch.


In [None]:
import json
from pathlib import Path

import ase
import jax
import jax.numpy as jnp
import numpy as np
import torch

jax.config.update('jax_enable_x64', True)

from mace.calculators import foundations_models
from mace.data.atomic_data import AtomicData
from mace.data.utils import config_from_atoms
from mace.tools import torch_geometric
from mace.tools.multihead_tools import AtomicNumberTable
from mace.tools.scripts_utils import extract_config_mace_model
from mace.tools.torch_geometric.batch import Batch

from mace_jax.cli.mace_torch2jax import convert_model

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
W1021 19:14:23.687000 607771 /home/pbenner/Env/mace-jax/.venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:118] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


---
### Build a Torch reference model
---

In [None]:
def load_foundation_model(source: str = 'mp', variant: str | None = None) -> torch.nn.Module:
    """Return a pretrained Torch MACE foundation model on CPU."""
    loader_kwargs = {'device': 'cpu'}
    source_lower = source.lower()
    if source_lower in {'mp', 'off', 'omol'}:
        loader = getattr(foundations_models, f'mace_{source_lower}')
        if variant is not None:
            loader_kwargs['model'] = variant
    elif source_lower == 'anicc':
        loader = foundations_models.mace_anicc
        if variant is not None:
            loader_kwargs['model_path'] = variant
    else:
        raise ValueError(
            "Unknown foundation source. Expected one of {'mp', 'off', 'anicc', 'omol'}."
        )

    try:
        model = loader(return_raw_model=True, **loader_kwargs)
    except Exception:
        calculator = loader(return_raw_model=False, **loader_kwargs)
        model = getattr(calculator, 'model', None)
        if model is None:
            models = getattr(calculator, 'models', None)
            if models:
                model = models[0]
        if model is None:
            raise

    return model.float().eval()

def extract_foundation_metadata(torch_model):
    """Collect the pieces of a Torch foundation model that later cells reuse.

    Args:
        torch_model: Model handed to ``extract_config_mace_model``.

    Returns:
        A ``(config, z_table, cutoff)`` tuple: the extracted config dict with
        an added ``'torch_model_class'`` entry holding the model's class name,
        an ``AtomicNumberTable`` built from ``config['atomic_numbers']`` cast
        to ints, and ``config['r_max']`` as a float.
    """
    model_config = extract_config_mace_model(torch_model)
    # Record which Torch class the config came from.
    model_config['torch_model_class'] = torch_model.__class__.__name__

    species = tuple(map(int, model_config['atomic_numbers']))
    return model_config, AtomicNumberTable(species), float(model_config['r_max'])

In [None]:
# Foundation family ('mp', 'off', 'anicc' or 'omol') and the variant of it to
# load; both are forwarded to load_foundation_model.
FOUNDATION_SOURCE = 'mp'
FOUNDATION_VARIANT = 'medium-mpa-0'

# Load the Torch reference model, then pull out the config dict, atomic-number
# table and cutoff that the later cells reuse.
torch_model = load_foundation_model(FOUNDATION_SOURCE, FOUNDATION_VARIANT)
config, z_table, cutoff = extract_foundation_metadata(torch_model)

Using Materials Project MACE for MACECalculator with /home/pbenner/Env/mace-jax/mace/mace/calculators/foundations_models/mace-mpa-0-medium.model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.


---
### Convert Torch model to JAX
---

In [None]:
# Port the Torch model to JAX. The three returned values are unpacked as the JAX
# module, its variables (passed to jax_model.apply in the cells below) and a
# template batch, which this notebook does not use afterwards.
jax_model, variables, template_batch = convert_model(torch_model, config)



---
### Prepare an evaluation batch
---

In [None]:
def batch_to_jax(batch: Batch) -> dict[str, jnp.ndarray]:
    """Copy a Torch batch into a plain dict usable by the JAX model.

    Every key listed in ``batch.keys`` is carried over. Torch tensors are
    detached, moved to the CPU and re-wrapped with ``jnp.asarray``; any other
    value is passed through unchanged.
    """

    def _to_jax(entry):
        # Only tensors need converting; other fields are kept as they are.
        if not isinstance(entry, torch.Tensor):
            return entry
        return jnp.asarray(entry.detach().cpu().numpy())

    return {name: _to_jax(batch[name]) for name in batch.keys}

# Cubic SrTiO3 perovskite: a single 5-atom unit cell whose fractional
# coordinates are scaled by the lattice constant a0 (ASE lengths are in Angstrom).
a0 = 3.905
atoms = ase.Atoms(
    symbols=['Sr', 'Ti', 'O', 'O', 'O'],
    positions=a0 * np.array([
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.0],
        [0.5, 0.0, 0.5],
        [0.0, 0.5, 0.5],
    ]),
    cell=a0 * np.identity(3),
    pbc=[True, True, True],
)
config_atoms = config_from_atoms(atoms)
# Replace the pbc entries with builtin bools before building the AtomicData.
config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
atomic_data = AtomicData.from_config(
    config_atoms,
    z_table=z_table,
    cutoff=cutoff,
    # Hand over the model's own head list, exactly as the JSON sanity-check
    # cell does, so both comparisons build their inputs the same way instead
    # of this one relying on AtomicData's default heads.
    heads=config.get('heads'),
)
# Wrap the single structure in a batch and mirror it for the JAX model.
batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
batch_jax = batch_to_jax(batch_torch)

---
### Compare Torch and JAX outputs
---

In [None]:
# Evaluate both implementations on the same structure, stress included.
jax_out = jax_model.apply(variables, batch_jax, compute_stress=True)
torch_out = torch_model(batch_torch, compute_stress=True)

# Print each quantity as a JAX/Torch pair so the values sit next to each other.
for quantity in ('energy', 'forces', 'stress'):
    print(f'JAX {quantity}:', jax_out[quantity])
    print(f'Torch {quantity}:', torch_out[quantity])

JAX energy: [-40.09476277]
Torch energy: tensor([-40.0948], grad_fn=<AddBackward0>)
JAX forces: [[ 1.0548549e-16  1.1785278e-16  4.2717566e-16]
 [-1.0904907e-15 -4.5937741e-16 -1.7226211e-15]
 [ 3.7947076e-16  7.8236029e-16  2.8536201e-16]
 [ 1.4137996e-16 -3.1225023e-17  9.4022012e-16]
 [ 3.6126392e-16 -6.8976942e-16  2.1076890e-16]]
Torch forces: tensor([[ 2.9168e-08, -7.2469e-08,  9.7556e-08],
        [ 1.0250e-06, -1.3690e-07,  1.7440e-06],
        [-9.5880e-07,  5.9395e-07,  2.0396e-07],
        [ 3.8487e-07, -8.6939e-07, -8.6892e-07],
        [-4.5973e-07,  4.4843e-07, -9.3516e-07]])
JAX stress: [[[-3.3666380e-02 -1.4734161e-09 -2.0536213e-09]
  [-1.4734161e-09 -3.3666413e-02  1.2520219e-09]
  [-2.0536213e-09  1.2520219e-09 -3.3666503e-02]]]
Torch stress: tensor([[[-3.3666e-02,  2.8892e-08, -2.8611e-08],
         [ 2.8892e-08, -3.3666e-02,  8.7222e-08],
         [-2.8611e-08,  8.7222e-08, -3.3666e-02]]])


---
### Foundation model energy sanity checks (JSON-loaded structures)
---

In [None]:
def atoms_from_serialised(record):
    """Build a periodic ``ase.Atoms`` object from one serialised structure record.

    Args:
        record: Mapping with a ``'structure'`` entry that holds ``'lattice'``
            (used as the cell) and ``'sites'``, each site providing
            ``'species'`` (used as the symbol) and ``'xyz'`` (used as the
            position).

    Returns:
        The structure with periodic boundaries on all three axes. Structures
        with fewer than 5 atoms are returned as a 2x2x2 repetition instead.
    """
    structure = record['structure']
    cell = np.asarray(structure['lattice'], dtype=float)

    sites = structure['sites']
    species = [site['species'] for site in sites]
    positions = np.asarray([site['xyz'] for site in sites], dtype=float)

    unit_cell = ase.Atoms(
        symbols=species,
        positions=positions,
        cell=cell,
        pbc=[True, True, True],
    )
    # Small cells are replicated; everything else is returned as loaded.
    return unit_cell.repeat((2, 2, 2)) if len(unit_cell) < 5 else unit_cell

# Structures to evaluate, read from a JSON file next to the notebook. JSON is
# UTF-8 by specification, so decode it explicitly rather than relying on the
# platform's default text encoding.
records_path = Path('material_structures.json')
with records_path.open(encoding='utf-8') as fh:
    structures = json.load(fh)

# One record per structure: label, MP id and the energy from each implementation.
results = []
for record in structures:
    atoms = atoms_from_serialised(record)
    config_atoms = config_from_atoms(atoms)
    # Replace the pbc entries with builtin bools before building the AtomicData.
    config_atoms.pbc = [bool(x) for x in config_atoms.pbc]
    atomic_data = AtomicData.from_config(
        config_atoms,
        z_table=z_table,
        cutoff=cutoff,
        heads=config.get('heads'),
    )
    batch_torch = torch_geometric.batch.Batch.from_data_list([atomic_data])
    batch_jax = batch_to_jax(batch_torch)

    # Only the energies are compared here, so stress is switched off.
    torch_output = torch_model(batch_torch, compute_stress=False)
    jax_output = jax_model.apply(variables, batch_jax, compute_stress=False)

    results.append(
        {
            'material': record['label'],
            'mp_id': record['mp_id'],
            # Store both energies as builtin floats so the records are uniform
            # and stay serialisable (a NumPy scalar is not accepted by json).
            'torch_energy': float(torch_output['energy'].detach().cpu().numpy()[0]),
            'jax_energy': float(np.asarray(jax_output['energy'])[0]),
        }
    )

# Class names are the same for every entry, so look them up once.
torch_label = torch_model.__class__.__name__
jax_label = jax_model.__class__.__name__

print(f"Foundation model: {torch_label} ({FOUNDATION_VARIANT})")
for entry in results:
    # One print per material; the trailing newline leaves a blank separator line.
    print(
        f"Material: {entry['material']} (MP ID: {entry['mp_id']})\n"
        f"  Torch ({torch_label}) energy: {entry['torch_energy']:.6f} eV\n"
        f"  JAX ({jax_label}) energy:   {entry['jax_energy']:.6f} eV\n"
    )

Foundation model: ScaleShiftMACE (medium-mpa-0)
Material: Al2O3 (MP ID: mp-1143)
  Torch (ScaleShiftMACE) energy: -74.799065 eV
  JAX (ScaleShiftMACE) energy:   -74.799073 eV

Material: MgSiO3 (MP ID: mp-3470)
  Torch (ScaleShiftMACE) energy: -287.229797 eV
  JAX (ScaleShiftMACE) energy:   -287.229821 eV

Material: Li2O (MP ID: mp-1960)
  Torch (ScaleShiftMACE) energy: -113.934151 eV
  JAX (ScaleShiftMACE) energy:   -113.934156 eV

