In [7]:
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import mace_jax
from mace_jax.data import GraphNodes, GraphEdges, GraphGlobals
from mace_jax.modules import GeneralMACE

In [14]:
# Hyperparameters for the (untrained, randomly initialised) demo model below.
# These are illustrative values, not tuned settings.
MACE_OUTPUT_IRREPS = "128x0e + 32x1o + 32x2e + 32x3o + 32x4e + 32x5o"  # total dim = 128 + 32*(3+5+7+9+11) = 1248
MACE_HIDDEN_IRREPS = "128x0e + 128x1o + 128x2e"
MACE_READOUT_MLP_IRREPS = "128x0e + 128x1o + 128x2e"
MACE_R_MAX = 5  # radial cutoff; in the same length units as the node positions
MACE_NUM_INTERACTIONS = 2
MACE_NUM_BESSEL = 8  # number of Bessel functions in the radial basis
MACE_MAX_ELL = 3
MACE_NUM_SPECIES = 5  # species indices fed to the model are expected to be < 5 (assumption — TODO confirm)
# Placeholder normalisation constant. NOTE(review): MACE usually estimates this from the
# training data; the toy graph in this notebook has exactly one neighbour per atom,
# so 3 is not derived from anything here. TODO: compute from real data before training.
MACE_AVG_NUM_NEIGHBORS = 3


@hk.without_apply_rng
@hk.transform
def mace_fn(vectors, atom_type, senders, receivers):
    """Run a GeneralMACE model, configured by the MACE_* constants above, on one graph.

    Args:
        vectors: per-edge relative position vectors (this notebook builds them as
            positions[receivers] - positions[senders]).
        atom_type: per-node integer species indices (presumably < MACE_NUM_SPECIES).
        senders: per-edge sender node indices.
        receivers: per-edge receiver node indices.

    Returns:
        The GeneralMACE output. For the 2-node toy graph in this notebook the shape is
        (2, 2, 1248); the axis meaning is presumably (nodes, interactions, irreps dim) —
        TODO confirm against GeneralMACE.
    """
    return GeneralMACE(
        output_irreps=MACE_OUTPUT_IRREPS,
        r_max=MACE_R_MAX,
        num_interactions=MACE_NUM_INTERACTIONS,
        hidden_irreps=MACE_HIDDEN_IRREPS,
        readout_mlp_irreps=MACE_READOUT_MLP_IRREPS,
        avg_num_neighbors=MACE_AVG_NUM_NEIGHBORS,
        num_species=MACE_NUM_SPECIES,
        radial_basis=lambda x, x_max: e3nn.bessel(x, MACE_NUM_BESSEL, x_max),
        radial_envelope=e3nn.soft_envelope,
        max_ell=MACE_MAX_ELL,
    )(vectors, atom_type, senders, receivers)


# hk.without_apply_rng removes the rng argument, so the call is apply(params, *inputs).
mace_apply = jax.jit(mace_fn.apply)

In [8]:
# Toy input: two atoms joined by one bond, stored as a pair of directed edges
# (1 -> 0 and 0 -> 1), so each atom has exactly one neighbour.
toy_positions = jnp.asarray([[0.0, 0.0, 0.0], [1.0, 2.0, 0.0]])
toy_species = jnp.asarray([1, 4])  # presumably indices into mace_fn's num_species=5 embedding

g = jraph.GraphsTuple(
    # GraphNodes(positions, <second field>, species); the second field is not needed
    # for this forward pass and is left as None. NOTE(review): confirm field order in mace_jax.data.
    nodes=GraphNodes(toy_positions, None, toy_species),
    edges=GraphEdges(None),  # no per-edge data supplied
    globals=GraphGlobals(None, None, None, None),  # no per-graph data supplied
    senders=jnp.asarray([1, 0]),
    receivers=jnp.asarray([0, 1]),
    n_node=jnp.asarray([2]),  # one graph with 2 nodes
    n_edge=jnp.asarray([2]),  # ... and 2 directed edges
)



In [10]:
# Edge vectors point from sender to receiver: r[receiver] - r[sender], one row per edge.
node_positions = g.nodes.positions
vectors = node_positions[g.receivers] - node_positions[g.senders]
atom_types = g.nodes.species

# Fixed seed so parameter initialisation is reproducible on every Restart & Run All.
init_rng = jax.random.PRNGKey(0)
w = mace_fn.init(init_rng, vectors, atom_types, g.senders, g.receivers)  # Haiku parameter tree

In [17]:
# Forward pass through the jitted model; the first call traces and compiles it.
model_inputs = (vectors, atom_types, g.senders, g.receivers)
output = mace_apply(w, *model_inputs)

In [19]:
# Sanity check: 2 nodes, 2 interaction layers (assumed axis order — TODO confirm against
# GeneralMACE), and 1248 = dim of the output irreps 128x0e + 32x(1o+2e+3o+4e+5o).
assert output.shape == (2, 2, 1248), output.shape
output.shape

(2, 2, 1248)