In [None]:
from typing import Callable, Dict, Iterator, Optional, Tuple

import jraph
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import nequip_jax
import numpy as np
import optax
import haiku as hk
import chex
from clu import parameter_overview
import ml_collections

# Visualization of molecules
import io
import ase.io
from contextlib import redirect_stdout
import py3Dmol

# Visualization of training
import matplotlib.pyplot as plt
import seaborn as sns

from absl import logging

# Set absl's verbosity to DEBUG so the `logging.info` messages emitted while
# building the datasets are displayed in the notebook.
logging.set_verbosity(logging.DEBUG)

In [None]:
# All tunable settings of the notebook live in a single ConfigDict.
config = ml_collections.ConfigDict(
    dict(
        # Dataset, batching and graph-construction settings.
        molecule="uracil",
        batch_size=8,
        num_train_steps=10000,
        num_test_steps=100,
        train_fraction=0.001,
        test_fraction=0.001,
        nn_cutoff=1.5,
        # NequIP-specific hyperparameters.
        # Not super important for this tutorial!
        nequip_max_ell=3,
        nequip_normalization_factor=100.,
        nequip_n_radial_basis=8,
        nequip_num_channels=16,
        nequip_num_interactions=3,
    )
)

We have to transform each molecule into a graph, represented as a `jraph.GraphsTuple` object. We do this by connecting every pair of atoms that lie within a cutoff distance of each other (`config.nn_cutoff` = 1.5 Å here) with an edge.

In [None]:
def create_graphstuples(
    molecules: Dict[str, np.ndarray],
    num_steps: int,
    batch_size: int,
    nn_cutoff: float,
    rng: chex.PRNGKey,
) -> Iterator[jraph.GraphsTuple]:
    """Creates batches of graphs from a molecule dictionary.

    The edge list is computed once, from the positions of the first sample
    (`molecules["R"][0]`), and shared by every graph: two distinct atoms are
    connected when their distance is at most `nn_cutoff`.

    Args:
        molecules: Dictionary with keys "R" (3-D array whose first two axes are
            samples and atoms), "F" (indexed by the same sample index as "R";
            presumably the same shape -- verify against the data file) and "z"
            (one atomic number per atom; only 1, 6, 7, 8 and 9 are supported).
            The dictionary is not modified.
        num_steps: Number of batches to yield.
        batch_size: Number of graphs in every yielded batch.
        nn_cutoff: Maximum distance between two atoms joined by an edge.
        rng: PRNG key used to permute the samples after each batch.

    Yields:
        Exactly `num_steps` batched `jraph.GraphsTuple`s, each made of
        `batch_size` graphs.

    Raises:
        ValueError: If `molecules["R"]` contains no samples.
        KeyError: If "z" contains an atomic number other than 1, 6, 7, 8 or 9.
    """
    all_positions = molecules["R"]
    all_forces = molecules["F"]
    num_samples, num_atoms, _ = all_positions.shape
    if num_samples == 0:
        raise ValueError("Cannot create graphs from an empty set of molecules.")

    # Compute the distance matrix of the first sample to select the edges.
    reference_positions = all_positions[0]
    distance_matrix = jnp.linalg.norm(
        reference_positions[None, :, :] - reference_positions[:, None, :], axis=-1
    )

    # Avoid self-edges.
    valid_edges = (distance_matrix > 0) & (distance_matrix <= nn_cutoff)
    senders, receivers = np.nonzero(valid_edges)
    n_edge = jnp.asarray([valid_edges.sum()])
    n_node = jnp.asarray([num_atoms])

    # Encode species as integer indices (consumed by an embedding layer later).
    species_encoder = {
        1: 0,
        6: 1,
        7: 2,
        8: 3,
        9: 4,
    }
    species = jnp.asarray([species_encoder[int(s)] for s in molecules["z"]])

    graphs = []
    for step in range(num_steps * batch_size):
        # Wrap around the dataset. This index must NOT be used to detect batch
        # boundaries: after wrapping it no longer counts the graphs collected
        # so far, which used to produce batches of the wrong size (or no batch
        # at all when num_samples < batch_size).
        sample_index = step % num_samples
        rng, step_rng = jax.random.split(rng)

        # Create a graph from the molecule data.
        graph = jraph.GraphsTuple(
            nodes={
                "positions": e3nn.IrrepsArray("1o", all_positions[sample_index]),
                "forces": e3nn.IrrepsArray("1o", all_forces[sample_index]),
                "species": species,
            },
            edges=None,
            globals=None,
            senders=senders,
            receivers=receivers,
            n_node=n_node,
            n_edge=n_edge,
        )
        graphs.append(graph)

        if len(graphs) == batch_size:
            yield jraph.batch(graphs)

            # Reset the list of graphs.
            # Permute the molecules to create a new batch. Only local
            # references are rebound, so the caller's dictionary is untouched.
            graphs = []
            permutation = jax.random.permutation(step_rng, num_samples)
            all_positions = all_positions[permutation]
            all_forces = all_forces[permutation]


def get_datasets(
    config: ml_collections.ConfigDict,
    rng: chex.PRNGKey,
) -> Dict[str, Iterator[jraph.GraphsTuple]]:
    """Creates train and test datasets for a given molecule.

    Reads `data/md17_{config.molecule}.npz`, takes the first
    `train_fraction` of the samples for training and the last `test_fraction`
    for testing, and wraps each split in a `create_graphstuples` iterator.

    Args:
        config: Must provide `molecule`, `batch_size`, `nn_cutoff`,
            `train_fraction`, `test_fraction`, `num_train_steps` and
            `num_test_steps`.
        rng: PRNG key, split into one key per dataset.

    Returns:
        A dictionary with keys "train" and "test", each mapping to an iterator
        of batched graphs.

    Raises:
        ValueError: If either split would contain no samples.
    """
    # Load the molecule data eagerly and close the archive afterwards.
    with np.load(f"data/md17_{config.molecule}.npz") as archive:
        positions = archive["R"]
        species = archive["z"]
        forces = archive["F"]

    num_molecules = positions.shape[0]
    num_train_samples = int(num_molecules * config.train_fraction)
    num_test_samples = int(num_molecules * config.test_fraction)

    logging.info("Number of training samples: %d", num_train_samples)
    logging.info("Number of test samples: %d", num_test_samples)

    # A split that rounds down to zero samples must be rejected explicitly:
    # `x[-0:]` would silently select the *whole* dataset for testing.
    if num_train_samples < 1 or num_test_samples < 1:
        raise ValueError(
            "train_fraction and test_fraction must each select at least one "
            f"of the {num_molecules} available samples."
        )

    # Only the per-sample arrays are split. "z" holds one entry per atom, so
    # slicing it by sample count would drop atoms whenever a split has fewer
    # samples than the molecule has atoms.
    train_molecules = {
        "R": positions[:num_train_samples],
        "z": species,
        "F": forces[:num_train_samples],
    }
    test_molecules = {
        "R": positions[-num_test_samples:],
        "z": species,
        "F": forces[-num_test_samples:],
    }

    train_rng, test_rng = jax.random.split(rng)
    datasets = {
        "train": create_graphstuples(
            train_molecules,
            config.num_train_steps,
            config.batch_size,
            config.nn_cutoff,
            train_rng,
        ),
        "test": create_graphstuples(
            test_molecules,
            config.num_test_steps,
            config.batch_size,
            config.nn_cutoff,
            test_rng,
        ),
    }
    return datasets

[NequIP](https://www.nature.com/articles/s41467-022-29939-5) is an E(3)-equivariant graph neural network model that can create node embeddings for graphs. Each node embedding is a set of features consisting of scalars ("0e" objects), vectors ("1o" objects) and so on. The maximum order of the features is controlled by `config.nequip_max_ell`.

In [None]:
class NequIP(hk.Module):
    """Haiku module that stacks NequIP interaction layers over a batch of graphs.

    Species are embedded as scalar ("0e") features, refined by
    `num_interactions` rounds of a NequIP layer followed by an equivariant
    linear layer, and finally projected to `output_irreps`.
    """

    def __init__(
        self,
        num_species: int,
        r_max: float,
        avg_num_neighbors: float,
        max_ell: int,
        init_embedding_dims: int,
        hidden_irreps: str,
        output_irreps: str,
        num_interactions: int,
        even_activation: Callable[[jnp.ndarray], jnp.ndarray],
        odd_activation: Callable[[jnp.ndarray], jnp.ndarray],
        mlp_activation: Callable[[jnp.ndarray], jnp.ndarray],
        mlp_n_hidden: int,
        mlp_n_layers: int,
        n_radial_basis: int,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # Geometry and species settings.
        self.num_species = num_species
        self.r_max = r_max
        self.avg_num_neighbors = avg_num_neighbors
        # Stored for reference; not read by __call__.
        self.max_ell = max_ell

        # Feature sizes and depth.
        self.init_embedding_dims = init_embedding_dims
        self.hidden_irreps = hidden_irreps
        self.output_irreps = output_irreps
        self.num_interactions = num_interactions

        # Non-linearities and radial MLP settings.
        self.even_activation = even_activation
        self.odd_activation = odd_activation
        self.mlp_activation = mlp_activation
        self.mlp_n_hidden = mlp_n_hidden
        self.mlp_n_layers = mlp_n_layers
        self.n_radial_basis = n_radial_basis

    def __call__(
        self,
        graphs: jraph.GraphsTuple,
    ) -> e3nn.IrrepsArray:
        """Returns per-node features with irreps `output_irreps`."""
        senders, receivers = graphs.senders, graphs.receivers

        # Edge vectors in units of the cutoff radius; using differences of
        # positions makes the model translation invariant.
        positions = graphs.nodes["positions"]
        edge_vectors = (positions[receivers] - positions[senders]) / self.r_max

        # Embed species (H, C, N, O, F) as scalar features.
        species = graphs.nodes["species"]
        node_feats = hk.Embed(self.num_species, self.init_embedding_dims)(species)
        node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats)

        # Every interaction layer is built with the same hyperparameters.
        layer_kwargs = dict(
            avg_num_neighbors=self.avg_num_neighbors,
            num_species=self.num_species,
            output_irreps=self.hidden_irreps,
            even_activation=self.even_activation,
            odd_activation=self.odd_activation,
            mlp_activation=self.mlp_activation,
            mlp_n_hidden=self.mlp_n_hidden,
            mlp_n_layers=self.mlp_n_layers,
            n_radial_basis=self.n_radial_basis,
        )

        # Iteratively refine the node embeddings.
        for _ in range(self.num_interactions):
            interaction = nequip_jax.NEQUIPESCNLayerHaiku(**layer_kwargs)
            node_feats = interaction(
                edge_vectors, node_feats, species, senders, receivers
            )
            mixing = e3nn.haiku.Linear(
                irreps_out=self.hidden_irreps, force_irreps_out=True
            )
            node_feats = mixing(node_feats)

        # Final linear readout.
        readout = e3nn.haiku.Linear(
            irreps_out=self.output_irreps, force_irreps_out=True
        )
        return readout(node_feats)

In [None]:
@hk.without_apply_rng
@hk.transform
def model(graphs: jraph.GraphsTuple) -> e3nn.IrrepsArray:
    """Builds NequIP from the notebook `config` and applies it to `graphs`.

    The output irreps are "0e", i.e. one scalar per node.
    """
    # `nequip_num_channels` copies of every irrep up to `nequip_max_ell`.
    hidden_irreps = config.nequip_num_channels * e3nn.s2_irreps(config.nequip_max_ell)

    nequip = NequIP(
        # Five supported species: H, C, N, O, F.
        num_species=5,
        r_max=config.nn_cutoff,
        avg_num_neighbors=config.nequip_normalization_factor,
        max_ell=config.nequip_max_ell,
        init_embedding_dims=config.nequip_num_channels,
        hidden_irreps=hidden_irreps,
        output_irreps="0e",
        num_interactions=config.nequip_num_interactions,
        even_activation=jax.nn.swish,
        odd_activation=jax.nn.tanh,
        mlp_activation=jax.nn.swish,
        mlp_n_hidden=64,
        mlp_n_layers=2,
        n_radial_basis=config.nequip_n_radial_basis,
    )
    return nequip(graphs)

We initialize the parameters of NequIP.
We will use [Adam](https://arxiv.org/abs/1412.6980), a variant of stochastic gradient descent, to optimize the parameters of NequIP so as to minimize the l2 loss between the predicted forces and the true forces.

In [None]:
# Seed the notebook's PRNG and build the data iterators.
rng = jax.random.PRNGKey(0)
rng, dataset_rng = jax.random.split(rng)
datasets = get_datasets(config, dataset_rng)

# Create the optimizer (it does not depend on the parameters yet).
tx = optax.adam(learning_rate=1e-3)

# Initialize the model on one training batch, which is consumed from the
# training iterator, then set up the optimizer state for those parameters.
rng, init_rng = jax.random.split(rng)
example_graphs = next(iter(datasets["train"]))
params = model.init(init_rng, example_graphs)
opt_state = tx.init(params)

parameter_overview.log_parameter_overview(params)

In [None]:
def get_predicted_forces(graphs: jraph.GraphsTuple, params: optax.Params):
    """Computes the predicted forces for a given set of positions and graphs.

    The model's scalar node outputs are summed into a total energy, and the
    forces are obtained as the negative gradient of that energy with respect
    to the atom positions (F = -dE/dR).

    Args:
        graphs: Graphs whose nodes provide "positions" and "species".
        params: Parameters of `model`.

    Returns:
        The predicted forces, with the same structure as
        `graphs.nodes["positions"]`.
    """

    def get_predicted_energies(positions: e3nn.IrrepsArray):
        # Rebuild the graphs around `positions` so the energy is
        # differentiated with respect to this argument.
        updated_graphs = graphs._replace(
            nodes={
                "positions": positions,
                "species": graphs.nodes["species"],
            }
        )
        return model.apply(params, updated_graphs).array.sum()

    # Forces point down the energy gradient, hence the minus sign. Without it
    # the summed network output would play the role of a negated energy.
    return -jax.grad(get_predicted_energies)(graphs.nodes["positions"])


@jax.jit
def train_step(
    graphs: jraph.GraphsTuple, params: optax.Params, opt_state: optax.OptState
) -> Tuple[optax.Params, optax.OptState, jnp.ndarray]:
    """Performs one optimization step on a batch of graphs.

    Uses the module-level optimizer `tx`.

    Args:
        graphs: Batch of graphs whose nodes provide "positions", "species" and
            the target "forces".
        params: Current model parameters.
        opt_state: Current optimizer state.

    Returns:
        The updated parameters, the updated optimizer state and the scalar
        loss evaluated before the update.
    """

    def loss_fn(params: optax.Params) -> jnp.ndarray:
        predicted_forces = get_predicted_forces(graphs, params).array
        true_forces = graphs.nodes["forces"].array
        # Sum the l2 loss over the last axis, then average what remains.
        return (
            optax.l2_loss(predicted_forces, true_forces).sum(axis=-1).mean()
        )

    loss, grad = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [None]:
# Loss history, sampled every 100 optimization steps for plotting.
steps, losses = [], []

# Iterates over the training batches until the iterator is exhausted.
for step, graphs in enumerate(datasets["train"]):
    params, opt_state, loss = train_step(graphs, params, opt_state)
    if step % 100 != 0:
        continue
    print(f"Step {step}: Loss {loss:.2f}")
    steps.append(step)
    losses.append(loss)

We plot the performance of our model as a function of training steps. Note that we are training on a very small subset of data. In practice, we would need to train on a much larger subset for longer, but this is beyond the scope of this tutorial.

In [None]:
# Convert the recorded history to arrays for plotting.
steps = np.asarray(steps)
losses = np.asarray(losses)

# Training loss against the optimization step at which it was recorded.
sns.set_style("darkgrid")
ax = sns.lineplot(x=steps, y=losses, label="loss")
ax.set_xlabel("steps")
ax.set_ylabel("loss")
plt.show()

In [None]:
def visualize_predictions(
    graph: jraph.GraphsTuple,
    predicted_forces: e3nn.IrrepsArray,
    forces_scale: float = 0.05,
):
    """Visualizes the true forces (green) and predicted forces (orange) on the atoms of a molecule.

    Args:
        graph: A single molecule graph whose nodes provide "positions",
            "species" (indices 0-4 standing for H, C, N, O, F) and "forces".
        predicted_forces: Predicted forces, one per atom. Must expose an
            `.array` attribute after scaling, as `e3nn.IrrepsArray` does.
        forces_scale: Factor applied to both sets of forces before drawing
            them as arrows starting at the atom positions.

    Raises:
        ValueError: If positions, species, true forces and predicted forces do
            not all have the same length.
    """

    positions = graph.nodes["positions"].array.tolist()
    species = graph.nodes["species"].tolist()
    true_forces = (forces_scale * graph.nodes["forces"]).array.tolist()
    predicted_forces = (forces_scale * predicted_forces).array.tolist()

    if not len(positions) == len(species) == len(true_forces) == len(predicted_forces):
        raise ValueError(
            "The number of positions, species, true forces and predicted "
            "forces must be the same."
        )

    species_decoder = {
        0: "H",
        1: "C",
        2: "N",
        3: "O",
        4: "F",
    }
    molecule = ase.Atoms(
        symbols=[species_decoder[s] for s in species],
        positions=positions,
    )

    # ase writes to stdout when the filename is "-"; capture it as a string.
    with io.StringIO() as buf, redirect_stdout(buf):
        ase.io.write("-", molecule, format="xyz")
        xyz = buf.getvalue()

    xyzview = py3Dmol.view(width=400, height=400)
    xyzview.addModel(xyz, "xyz")
    for position, true_force, predicted_force in zip(
        positions, true_forces, predicted_forces
    ):
        # Draw the true force first, then the predicted one, for every atom.
        for force, color in ((true_force, "green"), (predicted_force, "orange")):
            xyzview.addArrow(
                {
                    "start": {"x": position[0], "y": position[1], "z": position[2]},
                    "end": {
                        "x": position[0] + force[0],
                        "y": position[1] + force[1],
                        "z": position[2] + force[2],
                    },
                    "radius": 0.05,
                    # 3Dmol's option is spelled "radiusRatio"; the previous
                    # key "radiusRadio" was not a recognized option.
                    "radiusRatio": 1.0,
                    "mid": 1.0,
                    "color": color,
                }
            )

    xyzview.setStyle({"stick": {"radius": 0.1}, "sphere": {"radius": 0.3}})
    xyzview.setBackgroundColor("0xeeeeee")
    xyzview.zoomTo()
    xyzview.show()

We take a graph from the training set and visualize the predicted forces.
The true forces are shown in green, and the predicted forces are shown in orange.

In [None]:
# The training loop iterates `datasets["train"]` until it is exhausted, so
# calling `next` on it here would raise StopIteration. Build a fresh training
# iterator instead and take its first batch.
rng, visualization_rng = jax.random.split(rng)
graphs = next(get_datasets(config, visualization_rng)["train"])
graph = jraph.unbatch(graphs)[0]
predictions = get_predicted_forces(graph, params)

visualize_predictions(graph, predictions)

If we rotate the atom positions by an arbitrary rotation matrix, the predicted forces should also rotate. We can check this visually:


In [None]:
# Rotation built from three angles.
rotation_matrix = e3nn.angles_to_matrix(1., 0.5, 0.)

# Rotate positions and true forces with the same matrix; species are unchanged.
rotated_nodes = {
    "positions": graph.nodes["positions"].transform_by_matrix(rotation_matrix),
    "species": graph.nodes["species"],
    "forces": graph.nodes["forces"].transform_by_matrix(rotation_matrix),
}
rotated_graph = graph._replace(nodes=rotated_nodes)

# Predict on the rotated molecule and draw the result.
rotated_predictions = get_predicted_forces(rotated_graph, params)

visualize_predictions(rotated_graph, rotated_predictions)