## E(3)-Equivariant Message-Passing Networks

Here, we briefly introduce E(3)-equivariant message-passing networks.

A message-passing network is a neural network that operates on a graph, computing node-level and graph-level representations.

The nodes of the graph are associated with feature vectors.
For example, in a molecular graph, each node can represent an atom and the initial feature vector can represent the atom type embedded as a vector. This is the setting we consider in this example.

The message-passing network operates in iterations.
At each iteration of message-passing, each node aggregates messages from its neighbors and updates its own feature vector. 
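To make a single iteration concrete, here is a minimal sketch in plain JAX. The toy graph, the sum aggregation, and the additive update rule are illustrative assumptions, not the layer implemented later in this example.

```python
import jax
import jax.numpy as jnp

# Toy graph with 3 nodes and directed edges 0->1, 1->2, 2->1 (illustrative values).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 1])
node_features = jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])


def message_passing_step(node_features, senders, receivers):
    # Each edge carries the sender's feature vector as its message.
    messages = node_features[senders]
    # Sum the incoming messages at each receiver node.
    aggregated = jax.ops.segment_sum(
        messages, receivers, num_segments=node_features.shape[0]
    )
    # Hypothetical update rule: add the aggregated messages to the old features.
    return node_features + aggregated


updated = message_passing_step(node_features, senders, receivers)
print(updated)
```

Node 0 has no incoming edges, so its features stay unchanged, while node 1 receives messages from both node 0 and node 2.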

See these [Distill](https://distill.pub/2021/gnn-intro/) [articles](https://distill.pub/2021/understanding-gnns/) for an interactive introduction to graph neural networks!


Here, we will implement a simpler version of [NequIP](https://www.nature.com/articles/s41467-022-29939-5). The operation of the network is shown in the figure below. If this doesn't make sense to you, don't worry! We will explain the details in the following sections.

![image.png](images/algorithm.png)

In [11]:
# Imports
from typing import Callable, Tuple, Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import jraph
import e3nn_jax as e3nn

First, let's discuss [Flax](https://github.com/google/flax), the JAX framework that we will use to implement the network.

Flax allows us to define neural networks in JAX, but in a way similar to that of PyTorch.

Let's see how we can define a multi-layer perceptron (MLP) in Flax.

In [12]:
class MLP(nn.Module):
    """A simple multi-layer perceptron."""

    # These are attributes.
    output_dims: int
    hidden_dims: int
    num_layers: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for _ in range(self.num_layers - 1):
            x = nn.Dense(features=self.hidden_dims)(x)
            x = nn.LayerNorm()(x)
            x = nn.silu(x)
        x = nn.Dense(features=self.output_dims)(x)
        return x

`@nn.compact` is a decorator that allows us to define submodules (e.g. `nn.Dense` and `nn.LayerNorm` above) directly within the `__call__` method.

Another thing that you may notice is the automatic shape inference: we didn't need to define the input dimensions of 'x' above. 

We can create a model as follows:

In [13]:
model = MLP(output_dims=5, hidden_dims=8, num_layers=2)

JAX follows a functional programming style, which means that we need to explicitly pass in the parameters to the model when calling it. This is different from PyTorch, where the parameters are stored in the model.

Fortunately, Flax provides a convenient way to handle this. We can use the 'init' method to initialize the model and the 'apply' method to call it.

JAX is also explicit about randomness. We need to pass in a PRNGKey when initializing the model (and whenever a call itself needs randomness, e.g. for dropout). A PRNGKey plays the role of a seed for the random number generator.
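As a small sketch of what explicit randomness means in practice, the snippet below (independent of the model, with arbitrary shapes) shows that reusing a key reproduces the same numbers, while splitting the key gives a fresh stream.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same random numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields new keys; the subkey gives a different stream.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(jnp.array_equal(a, b))  # same key, same samples
print(jnp.array_equal(a, c))  # different key, different samples
```

This is why a key is split before being handed to initialization: the original key is not reused for anything else.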

In [14]:
x = jnp.ones((3, 4))
rng = jax.random.PRNGKey(0)

# Initialize the model.
rng, init_rng = jax.random.split(rng)
params = model.init(init_rng, x)

# Flax provides a convenient way to print the model structure,
# inputs and outputs.
print(nn.tabulate(model, init_rng)(x))

# The forward pass is just obtained by calling model.apply.
y = model.apply(params, x)


                                  MLP Summary                                   
┏━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path        ┃ module    ┃ inputs       ┃ outputs      ┃ params               ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│             │ MLP       │ float32[3,4] │ float32[3,5] │                      │
├─────────────┼───────────┼──────────────┼──────────────┼──────────────────────┤
│ Dense_0     │ Dense     │ float32[3,4] │ float32[3,8] │ bias: float32[8]     │
│             │           │              │              │ kernel: float32[4,8] │
│             │           │              │              │                      │
│             │           │              │              │ 40 (160 B)           │

We will not go too much into the details of Flax here, but you can find more information in the [Flax documentation](https://flax.readthedocs.io/en/latest/) with some cool [examples](https://flax.readthedocs.io/en/latest/examples/index.html).

Let's move on to the implementation of the E(3)-equivariant message-passing network.

Initially, each node gets a feature vector based on the atom type. The 'AtomEmbedding' class below implements this:

In [15]:
class AtomEmbedding(nn.Module):
    """Embeds atomic numbers into a learnable vector space."""

    embed_dims: int
    max_atomic_number: int

    @nn.compact
    def __call__(self, atomic_numbers: jnp.ndarray) -> jnp.ndarray:
        atom_embeddings = nn.Embed(
            num_embeddings=self.max_atomic_number, features=self.embed_dims
        )(atomic_numbers)
        return e3nn.IrrepsArray(f"{self.embed_dims}x0e", atom_embeddings)

Clearly, the feature vectors of each node are initially scalars, because the atom types (and hence, their embeddings) are invariant under rotations.

The next step is to convolve the feature vectors of the neighbors of each node to get a new feature vector for each node. This is done multiple times, so we define a single 'SimpleNetworkLayer' that we can use multiple times.

In [16]:
class SimpleNetworkLayer(nn.Module):
    """A layer of a simple E(3)-equivariant message passing network."""

    mlp_hidden_dims: int
    mlp_num_layers: int
    output_irreps: e3nn.Irreps

    @nn.compact
    def __call__(
        self,
        node_features: e3nn.IrrepsArray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        relative_vectors_sh: e3nn.IrrepsArray,
        relative_vectors_norm: jnp.ndarray,
    ) -> e3nn.IrrepsArray:
        
        # Compute the skip connection.
        node_features_skip = node_features

        # Tensor product of the relative vectors and the neighbouring node features.
        node_features_broadcasted = node_features[senders]
        node_features_broadcasted = e3nn.tensor_product(
            relative_vectors_sh, node_features_broadcasted
        )

        # Simply multiply each irrep by a learned scalar, based on the norm of the relative vector.
        scalars = MLP(
            output_dims=node_features_broadcasted.irreps.num_irreps,
            hidden_dims=self.mlp_hidden_dims,
            num_layers=self.mlp_num_layers,
        )(relative_vectors_norm)
        scalars = e3nn.IrrepsArray(f"{scalars.shape[-1]}x0e", scalars)
        node_features_broadcasted = jax.vmap(lambda scale, feature: scale * feature)(
            scalars, node_features_broadcasted
        )

        # Aggregate the node features back.
        node_features = e3nn.scatter_mean(
            node_features_broadcasted,
            dst=receivers,
            output_size=node_features.shape[0],
        )

        # Apply a non-linearity.
        # Note that applying an ordinary pointwise non-linearity to non-scalar irreps
        # would break equivariance; the gate only rescales them by activated scalars.
        gate_irreps = e3nn.Irreps(f"{node_features.irreps.num_irreps}x0e")
        node_features_expanded = e3nn.flax.Linear(node_features.irreps + gate_irreps)(
            node_features
        )
        node_features = e3nn.gate(node_features_expanded)

        # Add the skip connection.
        node_features = e3nn.concatenate([node_features, node_features_skip])

        # Apply a linear transformation to the output.
        node_features = e3nn.flax.Linear(self.output_irreps)(node_features)
        return node_features

The message-passing is implemented by keeping track of the sender and receiver node of each edge. This is a sparse representation of the adjacency matrix of the graph! The features of each neighbor are then combined, via a tensor product, with the spherical harmonics of the relative position between the sender and receiver nodes. This creates higher-order irreps!
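The link between the sender/receiver lists and the adjacency matrix can be sketched as follows. The edge list matches the example graph used at the end of this section; the feature values are made up for illustration.

```python
import jax.numpy as jnp

num_nodes = 3
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# Dense adjacency: adjacency[i, j] = 1 if there is an edge from node i to node j.
adjacency = jnp.zeros((num_nodes, num_nodes)).at[senders, receivers].set(1.0)

# Illustrative scalar feature per node.
features = jnp.array([[1.0], [10.0], [100.0]])

# Summing neighbor features with the dense matrix...
dense_sum = adjacency.T @ features
# ...matches gathering along senders and scattering onto receivers.
sparse_sum = jnp.zeros_like(features).at[receivers].add(features[senders])

print(adjacency)
print(dense_sum, sparse_sum)
```

The sparse form only touches existing edges, which matters for molecular graphs where most pairs of atoms are not connected.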

* Since we use $O(3)$-equivariant tensor products and non-linearities, the network is equivariant to *rotations* (and reflections).

* To account for *translation symmetry*, we only use relative positions of the neighbors. This guarantees that the network is invariant to translations.

* To account for *permutation symmetry* amongst the neighbors, we average the messages from the neighbors (a permutation-invariant aggregation) and only then apply a non-linearity.

Fortunately, most of these operations are already implemented in e3nn.
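The translation and rotation properties of the geometric inputs can be checked numerically. The sketch below uses plain JAX and the positions and edges of the example graph from the end of this section; the specific rotation and translation are arbitrary choices for illustration.

```python
import jax.numpy as jnp

positions = jnp.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])


def relative_vectors(positions):
    # Same convention as in the network: receiver position minus sender position.
    return positions[receivers] - positions[senders]


# A rotation by 90 degrees about the z-axis and an arbitrary translation.
rotation = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
translation = jnp.array([3.0, -2.0, 5.0])

vectors = relative_vectors(positions)
vectors_translated = relative_vectors(positions + translation)
vectors_rotated = relative_vectors(positions @ rotation.T)

# Translations cancel out in the difference.
print(jnp.allclose(vectors, vectors_translated))
# Rotating the positions rotates the relative vectors in the same way.
print(jnp.allclose(vectors_rotated, vectors @ rotation.T))
# The norms fed to the MLP are unchanged by the rotation.
print(jnp.allclose(jnp.linalg.norm(vectors, axis=-1),
                   jnp.linalg.norm(vectors_rotated, axis=-1)))
```

Because the norms are invariant and the relative vectors rotate consistently, everything downstream can be built from equivariant operations alone.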

In [21]:
def compute_features_of_relative_vectors(
    relative_vectors: jnp.ndarray, lmax: int
) -> Tuple[e3nn.IrrepsArray, jnp.ndarray]:
    """Compute the spherical harmonics of the relative vectors and their norms."""
    relative_vectors_sh = e3nn.spherical_harmonics(
        e3nn.s2_irreps(lmax=lmax),
        relative_vectors,
        normalize=True,
        normalization="norm",
    )
    relative_vectors_norm = jnp.linalg.norm(relative_vectors, axis=-1, keepdims=True)
    return relative_vectors_sh, relative_vectors_norm


class SimpleNetwork(nn.Module):
    """A simple E(3)-equivariant message passing network."""

    sh_lmax: int
    init_embed_dims: int
    max_atomic_number: int
    mlp_hidden_dims: int
    mlp_num_layers: int
    output_irreps_per_layer: Sequence[e3nn.Irreps]

    @nn.compact
    def __call__(
        self, graphs: jraph.GraphsTuple
    ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
        # Node features are initially the atomic numbers embedded.
        node_features = graphs.nodes["numbers"]
        node_features = AtomEmbedding(
            embed_dims=self.init_embed_dims,
            max_atomic_number=self.max_atomic_number,
        )(node_features)

        # Precompute the spherical harmonics of the relative vectors.
        positions = graphs.nodes["positions"]
        relative_vectors = positions[graphs.receivers] - positions[graphs.senders]
        (
            relative_vectors_sh,
            relative_vectors_norm,
        ) = compute_features_of_relative_vectors(
            relative_vectors,
            lmax=self.sh_lmax,
        )

        # Message passing.
        for output_irreps in self.output_irreps_per_layer:
            node_features = SimpleNetworkLayer(
                mlp_hidden_dims=self.mlp_hidden_dims,
                mlp_num_layers=self.mlp_num_layers,
                output_irreps=output_irreps,
            )(
                node_features,
                graphs.senders,
                graphs.receivers,
                relative_vectors_sh,
                relative_vectors_norm,
            )

        # Readout: return the per-node features and their mean over each graph.
        return node_features, e3nn.scatter_mean(node_features, nel=graphs.n_node)

Finally, we read out the graph-level representation by averaging the feature vectors of all nodes in each graph, which accounts for the permutation symmetry of the nodes.
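The readout above relies on `e3nn.scatter_mean` with `nel=graphs.n_node`. As a rough sketch of the same idea in plain JAX, the hypothetical helper `mean_readout` below averages node features per graph in a batch and checks that reordering nodes within a graph does not change the result. The feature values and batch layout are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Two graphs batched together: the first has 3 nodes, the second has 2 (toy values).
node_features = jnp.array([[1.0], [2.0], [3.0], [10.0], [20.0]])
n_node = jnp.array([3, 2])


def mean_readout(node_features, n_node):
    num_graphs = n_node.shape[0]
    # Assign each node the index of the graph it belongs to: [0, 0, 0, 1, 1].
    graph_index = jnp.repeat(jnp.arange(num_graphs), n_node)
    sums = jax.ops.segment_sum(node_features, graph_index, num_segments=num_graphs)
    # Divide by the number of nodes in each graph to get the mean.
    return sums / n_node[:, None]


readout = mean_readout(node_features, n_node)

# Permuting nodes within each graph leaves the readout unchanged.
permutation = jnp.array([2, 0, 1, 4, 3])
readout_permuted = mean_readout(node_features[permutation], n_node)

print(readout)
print(jnp.allclose(readout, readout_permuted))
```

Note that `jnp.repeat` with an array of repeat counts, as used here, runs eagerly; under `jax.jit` it would additionally need a static `total_repeat_length`.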

In conclusion, we have implemented a simple E(3)-equivariant message-passing network that operates on a molecular graph. The main difference from general message-passing networks is the use of equivariant features; otherwise, the operations are quite similar.

We end with some visualizations of $E(3)$-equivariant features from a randomly initialized $E(3)$-equivariant network!

In [19]:
graph = jraph.GraphsTuple(
    nodes=dict(
        numbers=jnp.array([1, 6, 8]),
        positions=jnp.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]),
    ),
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([1, 2, 0]),
    globals=jnp.array([0]),
    edges=None,
    n_node=jnp.array([3]),
    n_edge=jnp.array([3]),
)

In [22]:
model = SimpleNetwork(
    sh_lmax=3,
    init_embed_dims=8,
    max_atomic_number=10,
    mlp_hidden_dims=8,
    mlp_num_layers=2,
    output_irreps_per_layer=[
        e3nn.Irreps("8x0e + 8x1o"),
        e3nn.Irreps("8x0e + 8x1o"),
    ],
)

params = model.init(init_rng, graph)
node_features, global_features = model.apply(params, graph)
    

In [None]:
import pl