In [9]:
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import mace_jax
from mace_jax.data import GraphNodes, GraphEdges, GraphGlobals
from mace_jax.modules import GeneralMACE

In [2]:
@hk.without_apply_rng
@hk.transform
def mace_fn(vectors, atom_type, senders, receivers):
    return GeneralMACE(
        output_irreps="128x0e + 32x1o + 32x2e + 32x3o + 32x4e + 32x5o",
        r_max=5,
        num_interactions=2,
        hidden_irreps="128x0e + 128x1o + 128x2e",
        readout_mlp_irreps="128x0e + 128x1o + 128x2e",
        avg_num_neighbors=3, # idk
        num_species=5,
        radial_basis=lambda x, x_max: e3nn.bessel(x, 8, x_max),
        radial_envelope=e3nn.soft_envelope,
        max_ell=3
    )(vectors, atom_type, senders, receivers)

mace_apply = jax.jit(mace_fn.apply)

In [3]:
g = jraph.GraphsTuple(
    nodes=GraphNodes(jnp.asarray([[0.0, 0, 0], [1.0, 2, 0]]), None, jnp.asarray([1, 4])),
    edges=GraphEdges(None),
    globals=GraphGlobals(None, None, None, None),
    receivers=jnp.asarray([0, 1]),
    senders=jnp.asarray([1, 0]),
    n_node=jnp.asarray([2]),
    n_edge=jnp.asarray([2])
)

In [4]:
vectors = g.nodes.positions[g.receivers] - g.nodes.positions[g.senders]
atom_types = g.nodes.species

w = mace_fn.init(jax.random.PRNGKey(0), vectors, atom_types, g.senders, g.receivers)

In [5]:
output = mace_apply(w, vectors, atom_types, g.senders, g.receivers)

In [6]:
output.shape

(2, 2, 1248)

In [13]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.debug("test")

DEBUG:root:test


In [14]:
from clu import parameter_overview
parameter_overview.log_parameter_overview(w)

INFO:absl:
+------------------------------------------------------------------------------------------------------+---------------+---------+-----------+-------+
| Name                                                                                                 | Shape         | Size    | Mean      | Std   |
+------------------------------------------------------------------------------------------------------+---------------+---------+-----------+-------+
| general_mace/layer_0/equivariant_product_basis_block/linear/w[0,0] 128x0e,128x0e                     | (128, 128)    | 16,384  | -0.00984  | 1.01  |
| general_mace/layer_0/equivariant_product_basis_block/linear/w[1,1] 128x1o,128x1o                     | (128, 128)    | 16,384  | -0.0188   | 1.01  |
| general_mace/layer_0/equivariant_product_basis_block/linear/w[2,2] 128x2e,128x2e                     | (128, 128)    | 16,384  | -0.0159   | 1.01  |
| general_mace/layer_0/equivariant_product_basis_block/~/symmetric_contraction/w1_0

In [19]:
arr = jnp.asarray([1, 2, 3, 4, 5])
jnp.cumsum(arr)

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(_cumulative_reduction) in 0.0003020763397216797 sec
DEBUG:jax._src.interpreters.pxla:Compiling _cumulative_reduction (11869827888) for with global shapes and types (ShapedArray(int32[5]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.lib.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.dispatch:Finished XLA compilation of jit(_cumulative_reduction) in 0.012703895568847656 sec


Array([ 1,  3,  6, 10, 15], dtype=int32)