In [1]:
import argparse
import json

import torch

---
## Prepare data
---

In [2]:
import ase
import numpy as np
from mace.data.atomic_data import AtomicData
from mace.data.utils import config_from_atoms
from mace.tools import torch_geometric
from torch_geometric.data import Batch


In [3]:
def get_statistics(filename='statistics.json'):
    """Load a statistics JSON file and normalise two of its entries.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded dict, modified in place so that ``'atomic_numbers'`` is an
        ``AtomicNumberTable`` and ``'atomic_energies'`` is a list with one
        value per entry of ``AtomicNumberTable.zs`` (in that order).
    """
    from mace.tools.multihead_tools import (  # noqa: PLC0415
        AtomicNumberTable,
    )

    print(f"Reading statistics from `{filename}`")

    with open(filename, 'r') as handle:
        stats = json.load(handle)

    energies_by_z = stats['atomic_energies']
    z_table = AtomicNumberTable(stats['atomic_numbers'])
    stats['atomic_numbers'] = z_table
    # JSON object keys are always strings, hence the str(z) lookup.
    stats['atomic_energies'] = [energies_by_z[str(z)] for z in z_table.zs]

    return stats


statistics = get_statistics()

W0927 11:27:41.857000 1252156 /home/pbenner/Env/mace-jax/.venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:118] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


Reading statistics from `statistics.json`


In [4]:
# Toy four-atom structure (H, H, Ne, O). The cell is the 3x3 identity matrix,
# and it is periodic along the first two cell vectors only.
positions = np.array([
    [0.0, 0.0, 0.0],
    [0.5, 0.0, 0.0],
    [0.0, 0.4, 0.0],
    [0.0, 0.3, 0.3],
])
atoms = ase.Atoms(
    symbols=['H', 'H', 'Ne', 'O'],
    positions=positions,
    cell=np.identity(3),
    pbc=[True, True, False],
)

In [5]:
# Convert the ASE structure into a MACE configuration, then into a one-graph batch `x`.
config = config_from_atoms(atoms)
# Coerce each pbc entry to a plain Python bool
# (NOTE(review): presumably config_from_atoms yields numpy bools — confirm).
config.pbc = list(map(bool, config.pbc))
# NOTE(review): cutoff=2.0 is hard-coded and is not derived from the r_max the
# JAX model is configured with below — confirm this mismatch is intended.
atomic_data = AtomicData.from_config(config, z_table=statistics['atomic_numbers'], cutoff=2.0)
x = torch_geometric.batch.Batch.from_data_list([atomic_data])

---
## Prepare JAX model
---

In [6]:
def configure_model_jax(
    args,
    atomic_energies,
    z_table=None,
    model_foundation=None,
    head_configs=None,
):
    """Construct a mace_jax ``ScaleShiftMACE`` from parsed MACE CLI arguments.

    Both supported ``args.model`` values build a ``modules.ScaleShiftMACE``.
    They differ only in the per-head shift and the allowed first interaction:

    * ``"MACE"``: ``atomic_inter_shift`` is ``0.0`` for every head. In addition,
      ``args.interaction_first`` is overwritten *in place* with
      ``"RealAgnosticInteractionBlock"`` unless it is already one of
      ``"RealAgnosticInteractionBlock"`` or
      ``"RealAgnosticDensityInteractionBlock"``.
    * ``"ScaleShiftMACE"``: ``atomic_inter_shift`` is ``args.mean``.

    Args:
        args: Namespace of MACE options (``argparse``-style attributes).
        atomic_energies: Forwarded unchanged as ``atomic_energies``.
        z_table: Atomic-number table. It is required despite the ``None``
            default, because ``len(z_table)`` and ``z_table.zs`` are read.
        model_foundation: Unused; kept for signature compatibility.
        head_configs: Unused; kept for signature compatibility.

    Returns:
        The constructed ``modules.ScaleShiftMACE`` instance.

    Raises:
        ValueError: If ``args.model`` is neither ``"MACE"`` nor
            ``"ScaleShiftMACE"``.
    """
    import ast  # noqa: PLC0415

    from e3nn_jax import Irreps  # noqa: PLC0415

    from mace_jax import modules  # noqa: PLC0415

    # Keyword arguments shared by every supported model type.
    model_config = dict(
        r_max=args.r_max,
        num_bessel=args.num_radial_basis,
        num_polynomial_cutoff=args.num_cutoff_basis,
        max_ell=args.max_ell,
        interaction_cls=modules.interaction_classes[args.interaction],
        num_interactions=args.num_interactions,
        num_elements=len(z_table),
        hidden_irreps=Irreps(args.hidden_irreps),
        edge_irreps=Irreps(args.edge_irreps) if args.edge_irreps else None,
        atomic_energies=atomic_energies,
        apply_cutoff=args.apply_cutoff,
        avg_num_neighbors=args.avg_num_neighbors,
        atomic_numbers=z_table.zs,
        use_reduced_cg=args.use_reduced_cg,
        use_so3=args.use_so3,
        cueq_config=None,
    )

    if args.model == "MACE":
        # Plain MACE only accepts the real-agnostic first-layer interactions;
        # anything else falls back to the non-density variant.
        if args.interaction_first not in [
            "RealAgnosticInteractionBlock",
            "RealAgnosticDensityInteractionBlock",
        ]:
            args.interaction_first = "RealAgnosticInteractionBlock"
        # No energy shift: one zero per head.
        atomic_inter_shift = [0.0] * len(args.heads)
    elif args.model == "ScaleShiftMACE":
        atomic_inter_shift = args.mean
    else:
        # Previously this fell through and returned None. That only failed later,
        # far from the cause, when the caller tried to use the model.
        raise ValueError(
            f"Unsupported model type {args.model!r}; "
            "expected 'MACE' or 'ScaleShiftMACE'"
        )

    return modules.ScaleShiftMACE(
        **model_config,
        pair_repulsion=args.pair_repulsion,
        distance_transform=args.distance_transform,
        correlation=args.correlation,
        gate=modules.gate_dict[args.gate],
        interaction_cls_first=modules.interaction_classes[args.interaction_first],
        MLP_irreps=Irreps(args.MLP_irreps),
        atomic_inter_scale=args.std,
        atomic_inter_shift=atomic_inter_shift,
        radial_MLP=ast.literal_eval(args.radial_MLP),
        radial_type=args.radial_type,
        heads=args.heads,
        embedding_specs=args.embedding_specs,
        use_embedding_readout=args.use_embedding_readout,
        use_last_readout_only=args.use_last_readout_only,
        use_agnostic_product=args.use_agnostic_product,
    )


In [7]:
def _get_model_jax(args: argparse.Namespace, statistics):
    """Finish populating ``args`` from ``statistics`` and build the JAX model.

    Side effects:
        * Sets torch's global default dtype to ``torch.float64``.
        * Mutates ``args`` in place. It sets ``mean``/``std`` from
          ``statistics``, fixes ``compute_energy``/``compute_dipole``, sets
          ``key_specification`` to the defaults, and resolves ``heads``.

    Args:
        args: Parsed MACE options.
        statistics: Dict providing ``'mean'``, ``'std'``,
            ``'atomic_energies'`` and ``'atomic_numbers'``.

    Returns:
        Whatever ``configure_model_jax`` builds for these arguments.
    """
    import ast  # noqa: PLC0415

    from mace.data.utils import KeySpecification  # noqa: PLC0415
    from mace.tools.multihead_tools import (  # noqa: PLC0415
        prepare_default_head,
    )

    torch.set_default_dtype(torch.float64)

    # Copy the normalisation constants straight from the statistics file.
    for stat_name in ('mean', 'std'):
        setattr(args, stat_name, statistics[stat_name])

    args.compute_energy = True
    args.compute_dipole = False
    args.key_specification = KeySpecification.from_defaults()

    # A user-supplied --heads value is presumably a Python-literal string;
    # without one, the default head configuration comes from prepare_default_head.
    args.heads = (
        prepare_default_head(args)
        if args.heads is None
        else ast.literal_eval(args.heads)
    )

    return configure_model_jax(
        args,
        statistics['atomic_energies'],
        z_table=statistics['atomic_numbers'],
        model_foundation=None,
    )

def get_model_jax(statistics):
    """Build the JAX MACE model from a fixed command-line configuration.

    The options below go through MACE's own argument parser and
    ``check_args``. Options not listed here keep the parser's defaults, which
    ``check_args`` may then post-process.

    Args:
        statistics: Dict consumed by ``_get_model_jax``.

    Returns:
        The model built by ``_get_model_jax``.
    """
    from mace.tools import build_default_arg_parser, check_args  # noqa: PLC0415

    cli_options = {
        "--name": "MACE_large_density",
        "--interaction_first": "RealAgnosticDensityInteractionBlock",
        "--interaction": "RealAgnosticDensityResidualInteractionBlock",
        "--num_channels": "128",
        "--max_L": "2",
        "--max_ell": "3",
        "--num_interactions": "3",
        "--correlation": "3",
        "--num_radial_basis": "8",
        "--MLP_irreps": "16x0e",
        "--distance_transform": "Agnesi",
    }
    cli_flags = ["--pair_repulsion"]

    # Flatten {"--opt": "value"} pairs into ["--opt", "value", ...], then append bare flags.
    argv = [token for option in cli_options.items() for token in option]
    argv.extend(cli_flags)

    parser = build_default_arg_parser()
    parsed_args, _ = check_args(parser.parse_args(argv))

    return _get_model_jax(parsed_args, statistics)


In [None]:
import haiku as hk
import jax

def forward_fn(x):
    """Build the JAX MACE model and apply it to ``x``.

    The model is constructed inside this function because Haiku modules must be
    instantiated within the function passed to ``hk.transform``.
    """
    return get_model_jax(statistics)(x)


transformed = hk.transform(forward_fn)
rng = jax.random.PRNGKey(42)
# NOTE(review): `x` is the torch-backed batch built earlier; confirm that the
# mace_jax model accepts it as-is rather than needing a JAX-array conversion.
params = transformed.init(rng, x)