In [4]:
import sys
sys.path.append('..')

import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
from mace_jax.modules import MACE

from datatypes import NodesInfo

In [5]:
@hk.without_apply_rng
@hk.transform
def mace_fn(vectors, atom_type, senders, receivers):
    """Build a MACE network and run it on one set of edges.

    Because of ``hk.transform`` the function is used through ``mace_fn.init`` /
    ``mace_fn.apply``. Because of ``hk.without_apply_rng``, ``apply`` takes no
    RNG key.

    Args:
        vectors: per-edge relative position vectors. In this notebook they are
            ``positions[receivers] - positions[senders]``.
        atom_type: per-node integer species indices. Presumably each must be
            < ``num_species`` (5); TODO confirm how mace_jax handles larger values.
        senders: per-edge source node indices.
        receivers: per-edge destination node indices.

    Returns:
        Whatever ``MACE.__call__`` returns. For the toy graph in this notebook
        its shape is (2, 2, 1248).
    """

    def bessel_basis(x, x_max):
        # e3nn Bessel radial basis with 8 functions, parameterised by x_max.
        return e3nn.bessel(x, 8, x_max)

    mace_kwargs = dict(
        r_max=5,
        num_interactions=2,
        num_species=5,
        max_ell=3,
        hidden_irreps="128x0e + 128x1o + 128x2e",
        readout_mlp_irreps="128x0e + 128x1o + 128x2e",
        output_irreps="128x0e + 32x1o + 32x2e + 32x3o + 32x4e + 32x5o",
        # TODO: placeholder normaliser, chosen without justification. The toy
        # graph in this notebook has 1 neighbour per node; derive this value
        # from the real dataset.
        avg_num_neighbors=3,
        radial_basis=bessel_basis,
        radial_envelope=e3nn.soft_envelope,
    )
    model = MACE(**mace_kwargs)
    return model(vectors, atom_type, senders, receivers)


# JIT-compiled pure apply: called as (params, vectors, atom_type, senders, receivers).
mace_apply = jax.jit(mace_fn.apply)

In [6]:
# Toy two-atom graph with a single directed edge each way: 1 -> 0 and 0 -> 1.
g = jraph.GraphsTuple(
    nodes=NodesInfo(
        jnp.asarray([[0.0, 0.0, 0.0], [1.0, 2.0, 0.0]]),  # positions
        jnp.asarray([1, 4]),  # atomic_numbers (used as species indices)
    ),
    senders=jnp.asarray([1, 0]),
    receivers=jnp.asarray([0, 1]),
    n_node=jnp.asarray([2]),
    n_edge=jnp.asarray([2]),
    edges=None,
    globals=None,
)

In [8]:
# Per-node species indices, passed to MACE as `atom_type`.
atom_types = g.nodes.atomic_numbers
# One displacement row per edge: receiver position minus sender position.
vectors = (
    g.nodes.positions[g.receivers]
    - g.nodes.positions[g.senders]
)

# Fixed PRNG key so that re-running the notebook reproduces the same parameters.
w = mace_fn.init(
    jax.random.PRNGKey(0),
    vectors,
    atom_types,
    g.senders,
    g.receivers,
)

In [9]:
# Forward pass; hk.without_apply_rng means no RNG key is passed here.
output = mace_apply(
    w,           # parameters from mace_fn.init
    vectors,     # per-edge displacement vectors
    atom_types,  # per-node species indices
    g.senders,
    g.receivers,
)

In [10]:
# Expected (2, 2, 1248):
#   - first axis: the 2 nodes of the toy graph (n_node=[2]);
#   - middle axis: presumably one slice per interaction (num_interactions=2);
#     TODO confirm against mace_jax;
#   - last axis: dimension of output_irreps, 128*1 + 32*(3 + 5 + 7 + 9 + 11) = 1248.
output.shape

(2, 2, 1248)