In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Callable
from absl import logging
import optax
import tqdm
import jax
import jraph
import jax.numpy as jnp
import e3nn_jax as e3nn
import flax.linen as nn
import matplotlib.pyplot as plt
import lovelyplots
plt.style.use('ipynb')

logging.set_verbosity(logging.DEBUG)

import sys
sys.path.append('..')

from src.data.qm9 import QM9Dataset
from src import tensor_products

In [3]:
# Load QM9 with the dipole moment ('mu') as the regression target and build batched iterators.
DATA_ROOT = '../data'
SPLIT_SIZES = {'train': 110000, 'val': 10000}
BATCH_SIZE = 32

ds = QM9Dataset(
    root=DATA_ROOT,
    target_property='mu',
    cutoff=5.0,
    add_self_edges=True,
    splits=SPLIT_SIZES,
    seed=0,
)
datasets = ds.get_datasets(batch_size=BATCH_SIZE)

INFO:absl:Target property mu: Dipole moment (D)


In [4]:
# Models
class MLP(nn.Module):
    """Plain multi-layer perceptron.

    Applies ``num_layers - 1`` hidden stages of Dense(hidden_dims) -> LayerNorm -> SiLU,
    followed by a final linear Dense(output_dims) layer with no activation.
    With ``num_layers == 1`` it reduces to a single Dense layer.
    """

    output_dims: int
    hidden_dims: int = 32
    num_layers: int = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hidden = x
        num_hidden_stages = self.num_layers - 1
        for _stage in range(num_hidden_stages):
            # Keep Dense before LayerNorm so submodule creation order is unchanged.
            hidden = nn.Dense(features=self.hidden_dims)(hidden)
            hidden = nn.silu(nn.LayerNorm()(hidden))
        return nn.Dense(features=self.output_dims)(hidden)


class SimpleNetwork(nn.Module):
    """Message-passing network over e3nn irreps producing one output row per graph.

    Each hop:
      1. gathers sender node features onto the edges,
      2. combines them with the spherical harmonics of the edge vectors via ``tensor_product_fn``,
      3. keeps irreps up to ``lmax`` and applies an e3nn linear layer,
      4. rescales every irrep by a scalar predicted (by an MLP) from the edge length,
      5. averages the resulting messages onto the receiver nodes.

    The readout averages the scalar (0e) node features over each graph's nodes and maps
    them through an MLP to ``output_dims`` values.

    Attributes:
        sh_lmax: maximum degree of the spherical harmonics of the edge vectors.
        lmax: maximum degree kept after each tensor product.
        init_node_features: number of scalar (0e) channels in the atomic-number embedding.
        max_atomic_number: largest atomic number (inclusive) that may appear in
            ``graphs.nodes['numbers']``; the embedding table has ``max_atomic_number + 1`` rows.
        num_hops: number of message-passing rounds.
        output_dims: size of the per-graph output.
        tensor_product_fn: zero-argument factory returning a module that is called on
            (edge spherical harmonics, sender node features).
    """

    sh_lmax: int
    lmax: int
    init_node_features: int
    max_atomic_number: int
    num_hops: int
    output_dims: int
    tensor_product_fn: Callable[[], nn.Module]

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jnp.ndarray:
        # Node features start as embedded atomic numbers. Size the table as
        # max_atomic_number + 1 so every number in [0, max_atomic_number] is a valid row:
        # JAX lookups do not raise on out-of-range indices, they silently return bad values.
        atomic_numbers = graphs.nodes['numbers']
        node_features = nn.Embed(
            num_embeddings=self.max_atomic_number + 1,
            features=self.init_node_features,
        )(atomic_numbers)
        node_features = e3nn.IrrepsArray(f"{self.init_node_features}x0e", node_features)

        # Precompute the spherical harmonics and lengths of the relative (edge) vectors;
        # they are the same for every hop.
        relative_vectors = graphs.edges["relative_vectors"]
        relative_vectors_sh = e3nn.spherical_harmonics(
            e3nn.s2_irreps(lmax=self.sh_lmax), relative_vectors, normalize=True,
            normalization='norm',
        )
        relative_vectors_norm = jnp.linalg.norm(relative_vectors, axis=-1, keepdims=True)

        for _ in range(self.num_hops):
            # Tensor product of the relative vectors and the neighbouring (sender) node features.
            node_features_broadcasted = node_features[graphs.senders]
            tp = self.tensor_product_fn()(
                relative_vectors_sh, node_features_broadcasted
            )
            tp = tp.filter(lmax=self.lmax)

            # Apply a linear transformation to the tensor product.
            tp = e3nn.flax.Linear(tp.irreps)(tp)

            # Multiply each irrep by a learned scalar that depends on the edge length.
            scalars = MLP(output_dims=tp.irreps.num_irreps)(relative_vectors_norm)
            scalars = e3nn.IrrepsArray(f"{scalars.shape[-1]}x0e", scalars)
            node_features_broadcasted = jax.vmap(lambda sc, feat: sc * feat)(scalars, tp)

            # Aggregate the edge messages back onto the receiver nodes.
            node_features = e3nn.scatter_mean(
                node_features_broadcasted, dst=graphs.receivers, output_size=node_features.shape[0]
            )

        # Global readout: per-graph mean of the scalar node features.
        graph_globals = e3nn.scatter_mean(node_features.filter("0e"), nel=graphs.n_node)
        return MLP(output_dims=self.output_dims)(graph_globals.array)

import functools

LEARNING_RATE = 0.01

# Model
model = SimpleNetwork(
    sh_lmax=2,
    lmax=2,
    init_node_features=16,
    max_atomic_number=12,
    num_hops=3,
    output_dims=1,
    tensor_product_fn=functools.partial(tensor_products.TensorProductNaive),
)

# Parameters are initialised from one training batch; note this consumes a batch
# from the training iterator.
init_rng = jax.random.PRNGKey(0)
init_batch = next(datasets['train'])
params = model.init(init_rng, graphs=init_batch)

# Optimizer
tx = optax.adam(learning_rate=LEARNING_RATE)
opt_state = tx.init(params)

INFO:absl:Creating train dataset.
INFO:absl:Split train: Padding computed as {'n_node': 1152, 'n_edge': 2432, 'n_graph': 33}


In [5]:
def loss_fn(params, graphs):
    """Mean-squared error over the real (non-padding) graphs of a batch.

    Args:
        params: model parameters, as returned by ``model.init``.
        graphs: a padded ``jraph.GraphsTuple`` whose ``globals`` hold the regression targets,
            with the same shape as the model's predictions.

    Returns:
        ``(loss, preds)``: the scalar masked MSE and the raw per-graph predictions.
        ``preds`` keeps the padded batch shape (padding graphs included).
    """
    preds = model.apply(params, graphs)
    labels = graphs.globals
    assert preds.shape == labels.shape, (preds.shape, labels.shape)

    # Batches are padded with extra graphs (the dataset log reports n_graph = batch_size + 1).
    # Their targets are padding values, not labels, so they must not contribute to the loss.
    # NOTE(review): get_graph_padding_mask assumes padding was added with
    # jraph.pad_with_graphs (trailing padding graphs) -- confirm in src/data/qm9.py.
    graph_mask = jraph.get_graph_padding_mask(graphs)
    graph_mask = graph_mask.reshape(graph_mask.shape + (1,) * (preds.ndim - 1))
    graph_mask = jnp.broadcast_to(graph_mask, preds.shape)

    squared_errors = jnp.where(graph_mask, (preds - labels) ** 2, 0.0)
    # Guard against an all-padding batch (avoid 0 / 0).
    num_valid = jnp.maximum(jnp.sum(graph_mask), 1)
    loss = jnp.sum(squared_errors) / num_valid
    return loss, preds

@jax.jit
def update_fn(params, opt_state, graphs):
    """Run one jitted optimisation step with the notebook-level optimizer ``tx``.

    Differentiates ``loss_fn`` with respect to ``params`` (carrying the predictions as aux
    output), applies the optimizer update, and returns
    ``(new_params, new_opt_state, loss, preds)``.
    """
    (loss, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, graphs)
    updates, new_opt_state = tx.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss, preds



In [6]:
# Train: one optimizer step per training batch, recording the loss periodically for plotting.
LOG_EVERY = 10  # record the loss every LOG_EVERY steps
losses = []
steps = []
num_steps = 100
print("Training...", flush=True)
with tqdm.tqdm(range(num_steps)) as bar:
    for step in bar:
        graphs = next(datasets['train'])
        params, opt_state, loss, preds = update_fn(params, opt_state, graphs)
        bar.set_postfix(loss=f"{loss:.2f}")

        should_log = step % LOG_EVERY == 0
        if should_log:
            losses.append(loss)
            steps.append(step)

Training...


  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Training loss at the recorded steps.
ax = plt.gca()
ax.plot(steps, losses)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
plt.show()
