In [40]:
from typing import Tuple, Dict, Any
import haiku as hk
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import optax
import chex
from clu import parameter_overview

In [41]:
def get_tetris_datasets(rng: chex.PRNGKey):
    positions = jnp.asarray(
        [
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
            [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
            [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
            [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
        ],
        dtype=jnp.float32,
    )
    positions = e3nn.IrrepsArray("1o", positions)

    # Since chiral shapes are the mirror of one another we need an *odd* scalar to distinguish them
    labels = jnp.asarray(
        [
            [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
            [1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=jnp.int32,
    )
    labels = e3nn.IrrepsArray("0o + 6x0e", labels)

    while True:
        # Apply a random rotation to the positions.
        rotation_rng, rng = jax.random.split(rng)
        random_rotations = e3nn.rand_matrix(rotation_rng, shape=(positions.shape[0],))
        positions = jax.vmap(lambda pos, rot: pos.transform_by_matrix(rot))(
            positions, random_rotations
        )

        # Apply a random translation to the positions.
        translation_rng, rng = jax.random.split(rng)
        random_translations = e3nn.normal(
            "1o", translation_rng, leading_shape=(positions.shape[0],)
        )
        positions = positions + random_translations[:, None, :]

        yield {
            "positions": positions,
            "labels": labels,
        }


rng = jax.random.PRNGKey(0)
dataset = get_tetris_datasets(rng)

In [42]:
class GNNLayer(hk.Module):
    def __init__(
        self,
        radial_embedding_dims: int,
        radial_embedding_layers: int,
        output_lmax: int,
    ):
        super().__init__()
        self.radial_embedding_dims = radial_embedding_dims
        self.radial_embedding_layers = radial_embedding_layers
        self.output_lmax = output_lmax

    def __call__(
        self,
        node_features: e3nn.IrrepsArray,
        distances: jnp.ndarray,
        relative_positions_embedded: e3nn.IrrepsArray,
        neighbor_features: e3nn.IrrepsArray,
    ) -> e3nn.IrrepsArray:
        def convolve_with_neighbours(
            distances, relative_positions_embedded, neighbor_features
        ):
            product = e3nn.tensor_product(
                relative_positions_embedded, neighbor_features
            )
            radial_mlp = e3nn.haiku.MultiLayerPerceptron(
                [self.radial_embedding_dims] * (self.radial_embedding_layers - 1)
                + [product.irreps.num_irreps],
                act=jax.nn.swish,
            )
            radial = radial_mlp(distances)
            return radial * product

        convolved_features = hk.vmap(convolve_with_neighbours, split_rng=False)(
            distances, relative_positions_embedded, neighbor_features
        )
        convolved_features = e3nn.mean(convolved_features, axis=-2)
        node_features = e3nn.concatenate([node_features, convolved_features])
        node_features = node_features.filter(lmax=self.output_lmax)
        node_features = e3nn.haiku.Linear(
            node_features.irreps + f"{node_features.irreps.num_irreps}x0e"
        )(node_features)
        node_features = e3nn.gate(node_features)
        node_features = node_features.regroup()
        return node_features


class GNN(hk.Module):
    def __init__(
        self,
        num_layers: int,
        lmax: int,
        initial_embedding_dims: int,
        radial_embedding_dims: int,
        radial_embedding_layers: int,
        hidden_lmax: int,
        output_irreps: e3nn.Irreps,
    ):
        super().__init__()
        self.lmax = lmax
        self.num_layers = num_layers
        self.radial_embedding_dims = radial_embedding_dims
        self.radial_embedding_layers = radial_embedding_layers
        self.initial_embedding_dims = initial_embedding_dims
        self.hidden_lmax = hidden_lmax
        self.output_irreps = output_irreps

    def __call__(self, positions: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        num_graphs, num_nodes, _ = positions.shape
        assert positions.shape == (num_graphs, num_nodes, 3)
        relative_positions = jax.vmap(lambda pos: pos[None, :, :] - pos[:, None, :])(
            positions
        )
        assert relative_positions.shape == (num_graphs, num_nodes, num_nodes, 3)

        distances = e3nn.norm(relative_positions)
        assert distances.shape == (num_graphs, num_nodes, num_nodes, 1), distances.shape

        relative_positions_embedded = e3nn.spherical_harmonics(
            e3nn.s2_irreps(self.lmax), relative_positions, normalize=True
        )
        assert relative_positions_embedded.shape == (
            num_graphs,
            num_nodes,
            num_nodes,
            (self.lmax + 1) ** 2,
        )

        node_features = e3nn.ones(
            f"{self.initial_embedding_dims}x0e", leading_shape=(num_graphs, num_nodes)
        )
        assert node_features.irreps.is_scalar()
        assert node_features.shape == (
            num_graphs,
            num_nodes,
            self.initial_embedding_dims,
        )

        for _ in range(self.num_layers):
            layer = GNNLayer(
                self.radial_embedding_dims,
                self.radial_embedding_layers,
                self.hidden_lmax,
            )
            layer = hk.vmap(
                layer, split_rng=False, in_axes=(0, 0, 0, None)
            )  # node axis
            layer = hk.vmap(layer, split_rng=False)  # graph axis
            print("before", node_features.irreps)

            node_features = layer(
                node_features, distances, relative_positions_embedded, node_features
            )
            print("after", node_features.irreps)

        global_features = e3nn.mean(node_features, axis=-2)
        global_features = e3nn.haiku.Linear(self.output_irreps, force_irreps_out=True)(
            global_features
        )
        # global_features = e3nn.scalar_activation(global_features)
        return global_features, node_features

In [43]:
@hk.without_apply_rng
@hk.transform
def model(
    data: Dict[str, e3nn.IrrepsArray]
) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
    gnn = GNN(
        num_layers=3,
        lmax=2,
        radial_embedding_dims=5,
        radial_embedding_layers=2,
        initial_embedding_dims=2,
        hidden_lmax=2,
        output_irreps="0o + 6x0e",
    )
    return gnn(data["positions"])

In [44]:
def train_on_dataset(model, dataset, num_training_steps: int):

    params = model.init(jax.random.PRNGKey(0), next(dataset))
    print(parameter_overview.get_parameter_overview(params))

    tx = optax.adam(1e-3)
    opt_state = tx.init(params)
    apply_fn = jax.jit(model.apply)

    def loss_fn(params, data):
        global_embedding, _ = apply_fn(params, data)
        return e3nn.norm(
            (global_embedding - data["labels"]), squared=True, per_irrep=False
        ).array.mean()

    @jax.jit
    def train_step(params, opt_state, data):
        loss_value, grads = jax.value_and_grad(loss_fn)(params, data)
        updates, opt_state = tx.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for step, data in enumerate(dataset):
        params, opt_state, loss_value = train_step(params, opt_state, data)
        if step % 100 == 0:
            print(f"Step {step}: loss={loss_value}")

        if step > num_training_steps:
            break

    return params


params = train_on_dataset(model, dataset, num_training_steps=1000)

before 2x0e
after 8x0e+2x1o+2x2e
before 8x0e+2x1o+2x2e
after 40x0e+16x1o+4x1e+16x2e+4x2o
before 40x0e+16x1o+4x1e+16x2e+4x2o
after 232x0e+8x0o+112x1o+48x1e+112x2e+48x2o
+---------------------------------------------------+------------+--------+----------+-------+
| Name                                              | Shape      | Size   | Mean     | Std   |
+---------------------------------------------------+------------+--------+----------+-------+
| gnn/gnn_layer/linear/w[0,0] 4x0e,4x0e             | (4, 4)     | 16     | 0.0553   | 0.539 |
| gnn/gnn_layer/linear/w[0,3] 4x0e,8x0e             | (4, 8)     | 32     | -0.0265  | 1.27  |
| gnn/gnn_layer/linear/w[1,1] 2x1o,2x1o             | (2, 2)     | 4      | 0.213    | 0.898 |
| gnn/gnn_layer/linear/w[2,2] 2x2e,2x2e             | (2, 2)     | 4      | -0.577   | 0.631 |
| gnn/gnn_layer/multi_layer_perceptron/linear_0/w   | (1, 5)     | 5      | -0.253   | 1.58  |
| gnn/gnn_layer/multi_layer_perceptron/linear_1/w   | (5, 6)     | 30   

In [46]:
preds, _ = model.apply(params, next(dataset))
preds.array.round(2)

before 2x0e
after 8x0e+2x1o+2x2e
before 8x0e+2x1o+2x2e
after 40x0e+16x1o+4x1e+16x2e+4x2o
before 40x0e+16x1o+4x1e+16x2e+4x2o
after 232x0e+8x0o+112x1o+48x1e+112x2e+48x2o


Array([[-1.        , -0.05      , -0.02      ,  0.07      , -0.        ,
         0.12      , -0.        ],
       [ 1.        , -0.05      , -0.02      ,  0.07      , -0.        ,
         0.12      , -0.        ],
       [ 0.        ,  0.78999996, -0.06      ,  0.16      , -0.03      ,
         0.12      ,  0.05      ],
       [ 0.        , -0.        ,  0.98999995,  0.        ,  0.        ,
        -0.01      ,  0.01      ],
       [ 0.        ,  0.22999999,  0.06      ,  0.82      ,  0.04      ,
        -0.13      , -0.07      ],
       [ 0.        , -0.05      ,  0.03      ,  0.08      ,  0.78      ,
        -0.04      ,  0.26      ],
       [ 0.        ,  0.07      ,  0.02      , -0.06      , -0.04      ,
         0.74      ,  0.08      ],
       [ 0.        ,  0.05      , -0.02      , -0.11      ,  0.29      ,
         0.16      ,  0.61      ]], dtype=float32)