In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"


In [2]:
import ase.neighborlist
import e3nn as e3nn_torch
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import torch
from mace import modules as torch_modules

from mace_jax.modules import GeneralMACE

In [4]:
@hk.without_apply_rng
@hk.transform
def jax_model(
    vectors: jnp.ndarray,  # [n_edges, 3]
    node_specie: jnp.ndarray,  # [n_nodes, #scalar_features]
    senders: jnp.ndarray,  # [n_edges]
    receivers: jnp.ndarray,  # [n_edges]
):
    e3nn.config("path_normalization", "path")
    e3nn.config("gradient_normalization", "path")
    return GeneralMACE(
        r_max=2.0,
        radial_basis=lambda r, r_max: e3nn.bessel(r, 8, r_max),
        radial_envelope=lambda r, r_max: e3nn.poly_envelope(5 - 1, 2, r_max)(r),
        max_ell=3,
        num_interactions=2,
        num_species=1,
        hidden_irreps="11x0e+11x1o",
        readout_mlp_irreps="16x0e",
        avg_num_neighbors=3.0,
        correlation=2,
        output_irreps="0e",
    )(vectors, node_specie, senders, receivers).array[:, :, 0]


torch_model = torch_modules.MACE(
    r_max=2.0,
    num_bessel=8,
    num_polynomial_cutoff=5,
    max_ell=3,
    interaction_cls_first=torch_modules.RealAgnosticInteractionBlock,
    interaction_cls=torch_modules.RealAgnosticResidualInteractionBlock,
    num_interactions=2,
    num_elements=1,
    hidden_irreps=e3nn_torch.o3.Irreps("11x0e+11x1o"),
    MLP_irreps="16x0e",
    avg_num_neighbors=3.0,
    correlation=2,
    atomic_energies=torch.zeros(1),
    atomic_numbers=[],
    gate=torch.nn.SiLU(),
)


  torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),


In [5]:
def linear_torch_to_jax(linear):
    return {
        f"w[{ins.i_in},{ins.i_out}] {linear.irreps_in[ins.i_in]},{linear.irreps_out[ins.i_out]}": jnp.asarray(
            w.data
        )
        for i, ins, w in linear.weight_views(yield_instruction=True)
    }


def skip_tp_torch_to_jax(tp):
    return {
        f"w[{ins.i_in1},{ins.i_out}] {tp.irreps_in1[ins.i_in1]},{tp.irreps_out[ins.i_out]}": jnp.moveaxis(
            jnp.asarray(w.data), 1, 0
        )
        for i, ins, w in tp.weight_views(yield_instruction=True)
    }


w = {
    "general_mace/~/linear_node_embedding_block": {
        "embeddings_0": (
            torch_model.node_embedding.linear.weight.detach()
            .numpy()
            .reshape((1, -1, 1, 1))
        )
    },
    "general_mace/layer_0/skip_tp_first": skip_tp_torch_to_jax(
        torch_model.interactions[0].skip_tp
    ),
    "general_mace/layer_1/skip_tp": skip_tp_torch_to_jax(
        torch_model.interactions[1].skip_tp
    ),
    "general_mace/layer_0/interaction_block/linear_up": linear_torch_to_jax(
        torch_model.interactions[0].linear_up
    ),
    "general_mace/layer_0/interaction_block/linear_down": linear_torch_to_jax(
        torch_model.interactions[0].linear
    ),
    "general_mace/layer_0/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_0": {
        "w": (
            torch_model.interactions[0].conv_tp_weights.layer0.weight.detach().numpy()
        )
    },
    "general_mace/layer_0/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_1": {
        "w": (
            torch_model.interactions[0].conv_tp_weights.layer1.weight.detach().numpy()
        )
    },
    "general_mace/layer_0/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_2": {
        "w": (
            torch_model.interactions[0].conv_tp_weights.layer2.weight.detach().numpy()
        )
    },
    "general_mace/layer_0/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_3": {
        "w": (
            torch_model.interactions[0].conv_tp_weights.layer3.weight.detach().numpy()
        )
    },
    "general_mace/layer_1/interaction_block/linear_up": linear_torch_to_jax(
        torch_model.interactions[1].linear_up
    ),
    "general_mace/layer_1/interaction_block/linear_down": linear_torch_to_jax(
        torch_model.interactions[1].linear
    ),
    "general_mace/layer_1/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_0": {
        "w": (
            torch_model.interactions[1].conv_tp_weights.layer0.weight.detach().numpy()
        )
    },
    "general_mace/layer_1/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_1": {
        "w": (
            torch_model.interactions[1].conv_tp_weights.layer1.weight.detach().numpy()
        )
    },
    "general_mace/layer_1/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_2": {
        "w": (
            torch_model.interactions[1].conv_tp_weights.layer2.weight.detach().numpy()
        )
    },
    "general_mace/layer_1/interaction_block/message_passing_convolution/multi_layer_perceptron/linear_3": {
        "w": (
            torch_model.interactions[1].conv_tp_weights.layer3.weight.detach().numpy()
        )
    },
    "general_mace/layer_0/equivariant_product_basis_block/~/symmetric_contraction": {
        "w2_0e": jnp.array(
            torch_model.products[0]
            .symmetric_contractions.contractions[0]
            .weights_max.detach()
            .numpy()
        ),
        "w2_1o": jnp.array(
            torch_model.products[0]
            .symmetric_contractions.contractions[1]
            .weights_max.detach()
            .numpy()
        ),
        "w1_0e": jnp.array(
            torch_model.products[0]
            .symmetric_contractions.contractions[0]
            .weights[0]
            .detach()
            .numpy()
        ),
        "w1_1o": jnp.array(
            torch_model.products[0]
            .symmetric_contractions.contractions[1]
            .weights[0]
            .detach()
            .numpy()
        ),
    },
    "general_mace/layer_0/equivariant_product_basis_block/linear": linear_torch_to_jax(
        torch_model.products[0].linear
    ),
    "general_mace/layer_1/equivariant_product_basis_block/~/symmetric_contraction": {
        "w2_0e": jnp.array(
            torch_model.products[1]
            .symmetric_contractions.contractions[0]
            .weights_max.detach()
            .numpy()
        ),
        "w1_0e": jnp.array(
            torch_model.products[1]
            .symmetric_contractions.contractions[0]
            .weights[0]
            .detach()
            .numpy()
        ),
    },
    "general_mace/layer_1/equivariant_product_basis_block/linear": linear_torch_to_jax(
        torch_model.products[1].linear
    ),
    "general_mace/layer_0/linear_readout_block/linear": linear_torch_to_jax(
        torch_model.readouts[0].linear
    ),
    "general_mace/layer_1/non_linear_readout_block/linear": linear_torch_to_jax(
        torch_model.readouts[1].linear_1
    ),
    "general_mace/layer_1/non_linear_readout_block/linear_1": linear_torch_to_jax(
        torch_model.readouts[1].linear_2
    ),
}


In [7]:
positions = np.array(
    [
        [0.0, 0.0, 0.0],
        [0.5, 0.0, 0.0],
        [0.0, 0.4, 0.0],
        [0.0, 0.3, 0.3],
    ]
)
node_specie = np.arange(4) % 1
cell = np.identity(3)

senders, receivers, receivers_unit_shifts = ase.neighborlist.primitive_neighbor_list(
    quantities="ijS",
    pbc=(True, True, False),
    cell=cell,
    positions=positions,
    cutoff=2.0,
)

print(f"n_nodes: {len(positions)}")
print(f"n_edges: {len(senders)}")


t_out = torch_model(
    {
        "positions": torch.tensor(positions, dtype=torch.float32),
        "edge_index": torch.tensor(np.stack([senders, receivers]), dtype=torch.long),
        "shifts": torch.tensor(receivers_unit_shifts, dtype=torch.float32),
        "node_attrs": torch.eye(1)[node_specie],
        "ptr": torch.tensor([0, len(positions)], dtype=torch.long),
        "batch": torch.zeros(len(positions), dtype=torch.long),
        "cell": torch.tensor(cell, dtype=torch.float32),
    }
)

t_out = t_out["contributions"][0, 1:].detach().numpy()


vectors = (positions[receivers] + receivers_unit_shifts @ cell) - positions[senders]
j_out = jnp.sum(jax_model.apply(w, vectors, node_specie, senders, receivers), axis=0)


n_nodes: 4
n_edges: 172
hardcoded normalization
hardcoded normalization
hardcoded normalization
hardcoded normalization


In [8]:
d = t_out - j_out

d / np.abs(t_out)

DeviceArray([-1.9792391e-05,  6.2816642e-04], dtype=float32)