In [None]:
import itertools
from typing import Any, Dict, Optional, Sequence, Tuple

import chex
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go

# Compact array printing: 3 decimals, no scientific notation for small values.
jnp.set_printoptions(precision=3, suppress=True)

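We store the coefficients of each vector spherical harmonic (VSH) $Y_{j, l, m_j}$ in a dictionary keyed by $(j, l)$; each entry holds the $2j + 1$ coefficients indexed by $m_j$.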

In [None]:
def vsh_iterator(jmax: int):
    """Yields every VSH label (j, l) with 0 <= j <= jmax, in increasing j then l.

    For j >= 1 the label l runs over j - 1, j, j + 1. For j = 0 only l = 1 is
    produced (the coupling 1 x l contains j = 0 only when l = 1).
    """
    for j in range(jmax + 1):
        candidate_ls = (1,) if j == 0 else (j - 1, j, j + 1)
        for l in candidate_ls:
            yield j, l


def get_vsh_irrep(j: int, l: int, parity: int) -> e3nn.Irrep:
    """Returns the O(3) irrep that labels the VSH Y_{j, l}.

    The irrep has degree j and parity (-1)**l when parity == 1, or (-1)**(l + 1)
    when parity == -1. Any other parity raises ValueError.
    """
    if parity not in (1, -1):
        raise ValueError(f"Invalid parity {parity}.")
    exponent = l if parity == 1 else l + 1
    return e3nn.Irrep(j, (-1) ** exponent)


def get_vsh_irreps(jmax: int, parity: int) -> e3nn.Irreps:
    """Returns one irrep per VSH with j <= jmax, listed in `vsh_iterator` order."""
    return e3nn.Irreps(
        [get_vsh_irrep(j, l, parity) for j, l in vsh_iterator(jmax)]
    )


def get_change_of_basis_matrices(
    jmax: int, parity: int
) -> Dict[Tuple[int, int], e3nn.IrrepsArray]:
    """Returns the change-of-basis tensor for every VSH (j, l) with j <= jmax.

    Each entry is the reduced tensor product basis of ``1 x l`` (the vector index,
    ``Irrep(1, parity)``, times the degree-l scalar harmonic irrep) restricted to
    the single output irrep ``get_vsh_irrep(j, l, parity)``.

    Args:
        jmax: Largest j to include.
        parity: 1 or -1; the parity of the vector index.

    Returns:
        A dict mapping (j, l) to the IrrepsArray returned by
        ``e3nn.reduced_tensor_product_basis``.

    Raises:
        ValueError: if ``parity`` is not 1 or -1.
    """
    if parity not in (1, -1):
        raise ValueError(f"Invalid parity {parity}.")

    vector_irrep = e3nn.Irrep(1, parity)
    return {
        (j, l): e3nn.reduced_tensor_product_basis(
            "ij",
            i=vector_irrep,
            j=e3nn.Irrep(l, (-1) ** l),
            keep_ir=get_vsh_irrep(j, l, parity),
        )
        for j, l in vsh_iterator(jmax)
    }


class VSHCoeffs(dict):
    """Coefficients of vector spherical harmonics (VSH), keyed by (j, l).

    Each value is an ``e3nn.IrrepsArray`` holding exactly one copy of the irrep
    ``get_vsh_irrep(j, l, parity)``, i.e. the 2j + 1 coefficients of Y_{j, l, m_j}
    over m_j. ``parity`` (1 or -1) is shared by every entry, and ``__setitem__``
    validates each entry against it.
    """

    def __init__(self, parity, *args, **kwargs):
        if parity not in [1, -1]:
            raise ValueError(f"Invalid parity {parity}.")
        super().__init__()
        self.parity = parity
        # dict.__init__ does not go through the overridden __setitem__, so insert
        # the initial items one by one so they are validated as well.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def __setitem__(self, key: Tuple[int, int], value: e3nn.IrrepsArray) -> None:
        """Stores ``value`` for VSH ``key = (j, l)`` after checking its irreps."""
        j, l = key
        assert (
            value.irreps.num_irreps == 1
        ), f"Invalid count {value.irreps.num_irreps} for VSH {j, l}."
        mul, ir = value.irreps[0]
        # 1 x l only contains |l - 1| <= j <= l + 1, so (j, l) = (0, 0) is rejected.
        assert abs(l - 1) <= j <= l + 1, f"Invalid j={j} for VSH {j, l}."
        assert mul == 1, f"Invalid multiplicity {mul} for VSH {j, l}."
        assert ir == get_vsh_irrep(
            j, l, self.parity
        ), f"Invalid irrep {ir} for VSH {j, l}."
        super().__setitem__(key, value)

    def get_jmax(self) -> int:
        """Returns the largest j among the stored VSH.

        Raises:
            ValueError: if no coefficients are stored.
        """
        if not self:
            raise ValueError("Cannot determine jmax of an empty VSHCoeffs.")
        return max(j for j, _ in self.keys())

    def to_irreps_array(self) -> e3nn.IrrepsArray:
        """Concatenates all stored coefficients, in insertion order, into one IrrepsArray."""
        return e3nn.concatenate(list(self.values()))

    def __repr__(self):
        lines = [f"VSHCoeffs(parity={self.parity})"]
        for key, value in self.items():
            lines.append(f" {key}: {value}")
        return "\n".join(lines)

    @classmethod
    def zeros(cls, jmax: int, parity: int) -> "VSHCoeffs":
        """Creates all-zero coefficients for every VSH with j <= jmax."""
        coeffs = cls(parity=parity)
        for j, l in vsh_iterator(jmax):
            ir = get_vsh_irrep(j, l, parity)
            coeffs[(j, l)] = e3nn.zeros(ir)
        return coeffs

    @classmethod
    def normal(cls, jmax: int, parity: int, key: chex.PRNGKey) -> "VSHCoeffs":
        """Creates random coefficients (via ``e3nn.normal``) for every VSH with j <= jmax."""
        coeffs = cls(parity=parity)
        for j, l in vsh_iterator(jmax):
            # Split before sampling so no key is both consumed and split further.
            key, subkey = jax.random.split(key)
            ir = get_vsh_irrep(j, l, parity)
            coeffs[(j, l)] = e3nn.normal(ir, subkey)
        return coeffs

    def to_vector_coeffs(self) -> e3nn.IrrepsArray:
        """Converts the VSH coefficients to scalar-harmonic coefficients per vector component.

        For each stored (j, l), the change-of-basis tensor from
        ``get_change_of_basis_matrices`` is contracted with the m_j coefficients. The
        result is wrapped with the degree-l irrep of ``e3nn.s2_irreps(l)``, and its
        leading axis indexes the vector components. All results are concatenated.

        Raises:
            ValueError: if a coefficient array does not match its change-of-basis tensor.
        """
        rtps = get_change_of_basis_matrices(jmax=self.get_jmax(), parity=self.parity)
        all_vector_coeffs = []
        for (j, l), coeff in self.items():
            rtp = rtps[(j, l)]
            if rtp.array.shape[-1] != coeff.array.shape[-1]:
                raise ValueError(
                    f"Invalid shape {coeff.shape} for coefficients with j={j}, l={l}."
                )
            # Contract the last (output irrep, m_j) axis of the basis with the coefficients.
            vector_coeffs = jnp.einsum("ijk,k->ij", rtp.array, coeff.array)
            vector_coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], vector_coeffs)
            all_vector_coeffs.append(vector_coeffs)
        return e3nn.concatenate(all_vector_coeffs)

    def to_vector_signal(
        self, res_beta: int = 90, res_alpha: int = 89, quadrature="soft"
    ) -> e3nn.SphericalSignal:
        """Samples the vector field described by these coefficients on a sphere grid.

        Coefficients with equal irreps are regrouped and summed, then passed to
        ``e3nn.to_s2grid`` with ``p_val=1, p_arg=-1``.
        """
        vector_coeffs = self.to_vector_coeffs()
        vector_coeffs = e3nn.sum(vector_coeffs.regroup(), axis=-1)
        vector_sig = e3nn.to_s2grid(
            vector_coeffs,
            res_beta=res_beta,
            res_alpha=res_alpha,
            quadrature=quadrature,
            p_val=1,
            p_arg=-1,
        )
        return vector_sig

    def filter(self, keep: Sequence[e3nn.Irreps]) -> "VSHCoeffs":
        """Returns a new VSHCoeffs with only the entries whose irrep is in ``keep``.

        ``keep`` can be anything accepted by ``e3nn.Irreps`` (e.g. the string "4e").
        """
        keep = e3nn.Irreps(keep)
        new_coeffs = VSHCoeffs(parity=self.parity)
        for (j, l), coeff in self.items():
            _, coeff_ir = coeff.irreps[0]
            if coeff_ir in keep:
                new_coeffs[(j, l)] = coeff
        return new_coeffs

In [None]:
# Sanity check: an Irrep can be tested for membership in an Irreps (used by VSHCoeffs.filter).
x = e3nn.Irrep("1e")
keep = e3nn.Irreps([x])
print(keep)
e3nn.Irrep("1e") in keep

In [None]:
def irreps_array_to_coeffs_dict(irreps_array: e3nn.IrrepsArray) -> VSHCoeffs:
    """Converts an IrrepsArray laid out as ``get_vsh_irreps(jmax, parity)`` to VSHCoeffs.

    jmax is taken from ``irreps_array.irreps.lmax``. The parity is detected by
    comparing the irreps with the VSH layout of both parities. Zero chunks (``None``)
    keep the zero coefficients created by ``VSHCoeffs.zeros``.

    Raises:
        ValueError: if the irreps match neither VSH layout.
    """
    jmax = irreps_array.irreps.lmax
    detected_parity = next(
        (p for p in (1, -1) if get_vsh_irreps(jmax, p) == irreps_array.irreps),
        None,
    )
    if detected_parity is None:
        raise ValueError(f"Invalid irreps {irreps_array.irreps} for VSH.")

    coeffs_dict = VSHCoeffs.zeros(jmax, detected_parity)
    # zip() lines up because get_vsh_irreps and VSHCoeffs.zeros both follow
    # vsh_iterator order, with multiplicity 1 per entry.
    for (j, l), (_, ir), chunk in zip(
        list(coeffs_dict.keys()), irreps_array.irreps, irreps_array.chunks
    ):
        if chunk is None:
            continue
        # A chunk has shape (..., mul, ir.dim); drop only the multiplicity axis so
        # any leading batch axes are preserved.
        coeffs_dict[(j, l)] = e3nn.IrrepsArray(ir, chunk[..., 0, :])

    return coeffs_dict

In [None]:
# VSH irreps up to j = 3 for pseudo-vector (parity=1) and vector (parity=-1) fields.
get_vsh_irreps(3, parity=1), get_vsh_irreps(3, parity=-1)

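The change-of-basis matrices below construct $Y_{j, l, m_j}$ by decomposing the tensor product $1 \otimes l$ (the vector index times the degree-$l$ scalar harmonic) into the irreps $j = l - 1$, $l$ and $l + 1$. When $l = 0$, only $j = 1$ appears.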

In [None]:
# Change-of-basis tensors for every (j, l) with j <= 2, for parity=1.
get_change_of_basis_matrices(jmax=2, parity=1)

In [None]:
def vector_spherical_harmonics(
    j: int, l: int, mj: int, parity: int = -1
) -> e3nn.SphericalSignal:
    """Returns the (pseudo-)vector spherical harmonic Y_{j, l, mj} sampled on a grid.

    Args:
        j: Total angular momentum; must satisfy |l - 1| <= j <= l + 1.
        l: Degree of the underlying scalar spherical harmonic; must be >= 0.
        mj: Projection of j, in [-j, j].
        parity: -1 (default) or 1, forwarded to VSHCoeffs.

    Returns:
        The signal from ``VSHCoeffs.to_vector_signal()`` on its default grid.

    Raises:
        ValueError: for an invalid (j, l) pair, an out-of-range mj, or an invalid parity.
    """
    # 1 x l only contains |l - 1| <= j <= l + 1; in particular (j, l) = (0, 0) is
    # not a VSH and would otherwise fail later with a confusing KeyError.
    if l < 0 or not abs(l - 1) <= j <= l + 1:
        raise ValueError(f"Invalid j={j} for l={l}.")

    if mj not in range(-j, j + 1):
        raise ValueError(f"Invalid mj={mj} for j={j}.")

    coeffs = e3nn.IrrepsArray(
        get_vsh_irrep(j, l, parity),
        jnp.asarray([1.0 if i == mj else 0.0 for i in range(-j, j + 1)]),
    )
    coeffs_dict = VSHCoeffs(parity=parity)
    coeffs_dict[(j, l)] = coeffs
    return coeffs_dict.to_vector_signal()

In [None]:
# A single VSH, Y_{j=1, l=1} with m_j coefficients [1, 0, 0], expressed per vector component.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 1)] = e3nn.IrrepsArray("1e", jnp.asarray([1.0, 0.0, 0.0]))
print(coeffs_dict1.to_vector_coeffs())

# Random coefficients for every VSH up to j = 3.
coeffs_dict1 = VSHCoeffs.normal(jmax=3, parity=-1, key=jax.random.PRNGKey(0))
print(coeffs_dict1.to_vector_coeffs())

In [None]:
def plot_vector_signal(
    sig: e3nn.SphericalSignal,
    scale_vec: float = 0.1,
    title: Optional[str] = None,
    sizeref: float = 5,
):
    """Plots a vector spherical signal as a field of cones on the sphere.

    Args:
        sig: Signal whose grid values have shape (3, res_beta, res_alpha).
        scale_vec: Factor applied to the vectors before plotting.
        title: Optional figure title.
        sizeref: Plotly cone ``sizeref`` (used with ``sizemode="absolute"``).

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: if the grid values do not have a single leading axis of size 3.
    """
    if sig.grid_values.shape[:-2] != (3,):
        raise ValueError(
            "Expected grid values of shape (3, res_beta, res_alpha), "
            f"got {sig.grid_values.shape}."
        )

    # grid_vectors is (res_beta, res_alpha, 3); move xyz first so it matches values.
    grid = sig.grid_vectors.transpose((2, 0, 1)).reshape((3, -1))
    values = sig.grid_values.reshape((3, -1))

    fig = go.Figure()
    fig.add_trace(
        go.Cone(
            x=grid[0, :],
            y=grid[1, :],
            z=grid[2, :],
            u=scale_vec * values[0, :],
            v=scale_vec * values[1, :],
            w=scale_vec * values[2, :],
            colorscale="Viridis",
            sizemode="absolute",
            sizeref=sizeref,
            showscale=True,
            hoverinfo="skip",
        )
    )
    if title is not None:
        fig.update_layout(title=title)
    return fig

In [None]:
# Plot a single VSH (vector parity, the default of vector_spherical_harmonics).
j, l, mj = 2, 1, 0
plot_vector_signal(
    vector_spherical_harmonics(j, l, mj), title=f"VSH j={j}, l={l}, mj={mj}"
)

## Reconstruction

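Check that the VSH act as a vector analogue of `e3nn.to_s2grid()` / `e3nn.from_s2grid()`. Sampling random VSH coefficients on the sphere and projecting the resulting signal back onto the VSH should recover the original coefficients.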

In [None]:
def wrap_fn_for_vector_signal(fn):
    """Lifts a per-point function over the trailing (res_beta, res_alpha) grid axes.

    ``jax.vmap`` is applied twice over the last axis, with outputs also placed on the
    last axes. ``fn`` therefore sees only the remaining leading axes, e.g. the vector
    axis of size 3.
    """
    lifted = fn
    for _ in range(2):
        lifted = jax.vmap(lifted, in_axes=-1, out_axes=-1)
    return lifted


def get_vsh_coeffs_at_mj(
    sig: e3nn.SphericalSignal, j_out: int, l_out: int, mj_out: int
) -> jnp.ndarray:
    """Returns the component of Y_{j_out, l_out, mj_out} in the vector signal ``sig``.

    The harmonic (vector parity, as in ``vector_spherical_harmonics``) is sampled on
    the same grid (res_beta, res_alpha, quadrature) as ``sig``. The pointwise dot
    product is integrated over the sphere and divided by 4*pi.

    Returns:
        A scalar array.

    Raises:
        ValueError: for an invalid (j_out, l_out) pair or an out-of-range mj_out.
    """
    if l_out < 0 or not abs(l_out - 1) <= j_out <= l_out + 1:
        raise ValueError(f"Invalid j={j_out} for l={l_out}.")
    if mj_out not in range(-j_out, j_out + 1):
        raise ValueError(f"Invalid mj={mj_out} for j={j_out}.")

    # Build the harmonic on sig's own grid so the pointwise product lines up even
    # when sig was not sampled at the default resolution/quadrature.
    one_hot = jnp.asarray(
        [1.0 if mj == mj_out else 0.0 for mj in range(-j_out, j_out + 1)]
    )
    vsh_coeffs = VSHCoeffs(parity=-1)
    vsh_coeffs[(j_out, l_out)] = e3nn.IrrepsArray(
        get_vsh_irrep(j_out, l_out, -1), one_hot
    )
    vsh_signal = vsh_coeffs.to_vector_signal(
        res_beta=sig.res_beta, res_alpha=sig.res_alpha, quadrature=sig.quadrature
    )

    dot_product = sig.replace_values(
        wrap_fn_for_vector_signal(jnp.dot)(sig.grid_values, vsh_signal.grid_values)
    )
    return dot_product.integrate().array[0] / (4 * jnp.pi)


def get_vsh_coeffs_at_j(
    sig: e3nn.SphericalSignal,
    j_out: int,
    l_out: int,
    parity_out: int,
) -> e3nn.IrrepsArray:
    """Projects ``sig`` onto Y_{j_out, l_out, mj} for every mj in [-j_out, j_out].

    The 2 * j_out + 1 components, ordered by increasing mj, are returned as an
    IrrepsArray with irrep ``get_vsh_irrep(j_out, l_out, parity_out)``.
    """
    components = []
    for mj in range(-j_out, j_out + 1):
        components.append(get_vsh_coeffs_at_mj(sig, j_out, l_out, mj))
    irrep = get_vsh_irrep(j_out, l_out, parity_out)
    return e3nn.IrrepsArray(irrep, jnp.stack(components))


def get_vsh_coeffs(sig: e3nn.SphericalSignal, lmax: int, parity: int) -> VSHCoeffs:
    """Projects a vector signal onto every VSH Y_{j, l, m_j} with j <= lmax.

    ``lmax`` bounds j (it is passed to ``vsh_iterator`` as jmax). For each j, l runs
    over j - 1, j, j + 1 (only l = 1 when j = 0), and m_j runs over [-j, j].

    Args:
        sig: Signal whose grid values have a vector axis of size 3 at position -3.
        lmax: Largest j to project onto.
        parity: Parity (1 or -1) used to label the returned coefficients.

    Returns:
        VSHCoeffs with one entry per (j, l) from ``vsh_iterator(lmax)``.

    Raises:
        ValueError: if the signal does not have a vector axis of size 3.
    """
    if sig.grid_values.ndim < 3 or sig.shape[-3] != 3:
        raise ValueError(f"Invalid shape {sig.shape} for signal.")

    result = VSHCoeffs(parity=parity)
    for j_out, l_out in vsh_iterator(lmax):
        result[(j_out, l_out)] = get_vsh_coeffs_at_j(sig, j_out, l_out, parity)
    return result


def get_ssh_coeffs(sig: e3nn.SphericalSignal, lmax: int) -> e3nn.IrrepsArray:
    """Projects a scalar signal onto the scalar spherical harmonics of degree <= lmax.

    This is a thin wrapper around ``e3nn.from_s2grid`` with target irreps
    ``e3nn.s2_irreps(lmax)``.
    """
    target_irreps = e3nn.s2_irreps(lmax)
    return e3nn.from_s2grid(sig, irreps=target_irreps)

In [None]:
# Reconstruction example: random VSH coefficients -> vector signal -> projected back.
lmax = 3
coeffs_dict1 = VSHCoeffs.normal(jmax=lmax, parity=-1, key=jax.random.PRNGKey(0))
sig1 = coeffs_dict1.to_vector_signal()
reconstructed_coeffs_dict1 = get_vsh_coeffs(sig1, lmax, parity=-1)
print(reconstructed_coeffs_dict1)
print(coeffs_dict1)

# Both dicts are built from vsh_iterator(lmax), so they must cover the same VSH.
assert set(reconstructed_coeffs_dict1.keys()) == set(coeffs_dict1.keys())
for j, l in reconstructed_coeffs_dict1.keys():
    assert reconstructed_coeffs_dict1[(j, l)].irreps == coeffs_dict1[(j, l)].irreps
    assert jnp.allclose(
        reconstructed_coeffs_dict1[(j, l)].array, coeffs_dict1[(j, l)].array
    )

## Cross Product

In [None]:
def cross_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs, coeffs_dict2: VSHCoeffs, output_lmax: int
) -> VSHCoeffs:
    """Returns the VSH coefficients of the pointwise cross product of two vector fields.

    Both inputs are sampled with ``to_vector_signal()`` (default grid), crossed point
    by point, and projected back onto the VSH with j <= output_lmax. The output
    parity is the product of the input parities.
    """
    signals = [coeffs_dict1.to_vector_signal(), coeffs_dict2.to_vector_signal()]
    pointwise_cross = wrap_fn_for_vector_signal(jnp.cross)
    crossed = signals[0].replace_values(
        pointwise_cross(signals[0].grid_values, signals[1].grid_values)
    )
    output_parity = coeffs_dict1.parity * coeffs_dict2.parity
    return get_vsh_coeffs(crossed, lmax=output_lmax, parity=output_parity)

In [None]:
# Visualizing the cross product of two VSH
# Signal 1: Y_{j=1, l=2} with vector parity (-1) and m_j coefficients [1, 0, 0].
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))
plot_vector_signal(coeffs_dict1.to_vector_signal(), title="Signal 1")

In [None]:
# Signal 2: the same VSH as signal 1 but with m_j coefficients [0, 1, 0].
coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 0.0]))
plot_vector_signal(coeffs_dict2.to_vector_signal(), title="Signal 2")

In [None]:
# Cross product of signal 1 and signal 2, projected onto the VSH with j <= 2.
cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=2
)
print(cross_product_dict)
plot_vector_signal(cross_product_dict.to_vector_signal(), title="Cross Product")

In [None]:
# Display the cross-product coefficients.
cross_product_dict

In [None]:
# Swap the operands: since a x b = -(b x a), this should give the negated coefficients.
flipped_cross_product = cross_gaunt_tensor_product(
    coeffs_dict2, coeffs_dict1, output_lmax=2
)
plot_vector_signal(
    flipped_cross_product.to_vector_signal(), title="Flipped Cross Product"
)

In [None]:
# Display the flipped cross-product coefficients.
flipped_cross_product

In [None]:
# PVSH x VSH
# Pseudo-vector (parity 1) Y_{2,2} times vector (parity -1) Y_{2,3} with random
# coefficients; only the output entries with irrep 4e are kept.
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

In [None]:
# VSH x VSH
# Two vector (parity -1) Y_{2,3} with random coefficients; only the output entries
# with irrep 4e are kept.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(2, 3)] = e3nn.normal("2e", jax.random.PRNGKey(1))

(cross_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4).filter("4e"))

## Check YuQing's selection rules

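YuQing's selection rules for the coupling $(j_1, l_1) \times (j_2, l_2) \to (j_3, l_3)$:
1. $|l_i - 1| \le j_i \le l_i + 1$ for all $i$ (the usual selection rule for $l_i \otimes 1$)
2. $|l_1 - l_2| \le l_3 \le l_1 + l_2$ (the usual selection rule for $l_1 \otimes l_2$)
3. $|j_1 - j_2| \le j_3 \le j_1 + j_2$ (the usual selection rule for $j_1 \otimes j_2$)
4. $l_1 + l_2 + l_3$ is even
5. There is no choice of distinct $a, b, c$ with $l_a = j_a$ and $(l_b, j_b) = (l_c, j_c)$

The first rule is guaranteed by our construction, since `vsh_iterator` only produces valid $(j, l)$ pairs. The cell below checks all five rules on the non-zero outputs.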

In [None]:
# Cross product of two random pseudo-vector (parity 1) Y_{2,2}, used to check the
# selection rules below.
coeffs_dict1 = VSHCoeffs(parity=1)
coeffs_dict1[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(0))

coeffs_dict2 = VSHCoeffs(parity=1)
coeffs_dict2[(2, 2)] = e3nn.normal("2e", jax.random.PRNGKey(1))

cross_product_dict = cross_gaunt_tensor_product(
    coeffs_dict1, coeffs_dict2, output_lmax=4
)
cross_product_dict

In [None]:
# Check YuQing's selection rules for every output (j3, l3) with non-negligible coefficients.
for (j3, l3), coeffs in cross_product_dict.items():
    # Treat tiny blocks as zero (presumably numerical noise from the grid projection).
    if jnp.allclose(coeffs.array, 0, atol=1e-3):
        continue

    print(f"Checking (j3, l3) = ({j3}, {l3})")
    for j2, l2 in coeffs_dict2.keys():
        for j1, l1 in coeffs_dict1.keys():
            # Rule 1: |l_i - 1| <= j_i <= l_i + 1.
            assert abs(l1 - 1) <= j1 <= l1 + 1, (l1, j1)
            assert abs(l2 - 1) <= j2 <= l2 + 1, (l2, j2)
            assert abs(l3 - 1) <= j3 <= l3 + 1, (l3, j3)

            # Rule 2: triangle inequality on l.
            assert abs(l1 - l2) <= l3 <= l1 + l2, (l1, l2, l3)

            # Rule 3: triangle inequality on j.
            assert abs(j1 - j2) <= j3 <= j1 + j2, (j1, j2, j3)

            # Rule 4: l1 + l2 + l3 is even.
            assert (l1 + l2 + l3) % 2 == 0, (l1, l2, l3)

            # Rule 5: no distinct a, b, c with l_a == j_a and (l_b, j_b) == (l_c, j_c).
            ls = [l1, l2, l3]
            js = [j1, j2, j3]
            for a, b, c in itertools.permutations(range(3)):
                assert not ((ls[a], ls[b], js[b]) == (js[a], ls[c], js[c]))

## Dot Product

In [None]:
# Dot product of vector fields
def dot_gaunt_tensor_product(
    coeffs_dict1: VSHCoeffs, coeffs_dict2: VSHCoeffs, output_lmax: int
) -> e3nn.IrrepsArray:
    """Returns the scalar-harmonic coefficients of the pointwise dot product of two fields.

    Both inputs are sampled with ``to_vector_signal()`` and dotted point by point. The
    resulting scalar signal is projected onto the scalar harmonics up to output_lmax.
    """
    first = coeffs_dict1.to_vector_signal()
    second = coeffs_dict2.to_vector_signal()
    pointwise_dot = wrap_fn_for_vector_signal(jnp.dot)
    dotted = first.replace_values(pointwise_dot(first.grid_values, second.grid_values))
    return get_ssh_coeffs(dotted, lmax=output_lmax)

In [None]:
# Dot product in both orders; the dot product is symmetric, so both results should agree.
coeffs_dict1 = VSHCoeffs(parity=-1)
coeffs_dict1[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([1.0, 0.0, 0.0]))

coeffs_dict2 = VSHCoeffs(parity=-1)
coeffs_dict2[(1, 2)] = e3nn.IrrepsArray("1o", jnp.asarray([0.0, 1.0, 1.0]))

(
    dot_gaunt_tensor_product(coeffs_dict1, coeffs_dict2, output_lmax=4),
    dot_gaunt_tensor_product(coeffs_dict2, coeffs_dict1, output_lmax=4),
)

## Full Vector Gaunt Tensor Product

- Map irreps1 to VSHCoeffs1 with a Linear layer, and likewise map irreps2 to VSHCoeffs2.

- Take the cross product of the two sets of VSH coefficients to get VSHCoeffs3.

- Finally, map VSHCoeffs3 to irreps3 with another Linear layer.

In [None]:
import haiku as hk


@hk.without_apply_rng
@hk.transform
def full_gaunt_tensor_product(
    coeffs1: e3nn.IrrepsArray,
    coeffs2: e3nn.IrrepsArray,
    parity1: int,
    parity2: int,
    output_lmax: int,
) -> e3nn.IrrepsArray:
    """Learned vector Gaunt tensor product built on the VSH cross product.

    Wrapped by Haiku into an (init, apply) pair; apply takes no rng.

    Each input is linearly mapped onto the VSH irreps of its parity, with jmax equal to
    the input's lmax, and converted to VSHCoeffs. The two are crossed with
    ``cross_gaunt_tensor_product`` up to ``output_lmax``. The result is mixed by a final
    Linear layer onto ``get_vsh_irreps(output_lmax, parity1 * parity2)``.

    Args:
        coeffs1: First input features.
        coeffs2: Second input features.
        parity1: VSH parity (1 or -1) that coeffs1 is mapped onto.
        parity2: VSH parity (1 or -1) that coeffs2 is mapped onto.
        output_lmax: Largest j of the cross-product VSH.

    Returns:
        IrrepsArray of the linearly mixed cross-product coefficients.
    """
    # force_irreps_out=True keeps the output irreps exactly equal to the VSH irreps,
    # which irreps_array_to_coeffs_dict requires for parity detection.
    # NOTE: Haiku names these Linear modules by creation order; do not reorder.
    coeffs1 = e3nn.haiku.Linear(
        get_vsh_irreps(coeffs1.irreps.lmax, parity=parity1), force_irreps_out=True
    )(coeffs1)
    coeffs2 = e3nn.haiku.Linear(
        get_vsh_irreps(coeffs2.irreps.lmax, parity=parity2), force_irreps_out=True
    )(coeffs2)

    coeffs_dict1 = irreps_array_to_coeffs_dict(coeffs1)
    coeffs_dict2 = irreps_array_to_coeffs_dict(coeffs2)
    cross_product_dict = cross_gaunt_tensor_product(
        coeffs_dict1, coeffs_dict2, output_lmax=output_lmax
    )

    cross_product_coeffs = cross_product_dict.to_irreps_array()
    cross_product_coeffs = e3nn.haiku.Linear(
        get_vsh_irreps(output_lmax, parity=parity1 * parity2)
    )(cross_product_coeffs)
    return cross_product_coeffs


# Random scalar-harmonic inputs up to l = 3, then parameter initialisation.
coeffs1 = e3nn.normal(e3nn.s2_irreps(3), jax.random.PRNGKey(0))
coeffs2 = e3nn.normal(e3nn.s2_irreps(3), jax.random.PRNGKey(1))

params = full_gaunt_tensor_product.init(
    jax.random.PRNGKey(0),
    coeffs1,
    coeffs2,
    parity1=-1,
    parity2=-1,
    output_lmax=4,
)
params

In [None]:
# parity1, parity2 and output_lmax determine the irreps (i.e. array shapes), so they
# must be static arguments under jit.
jax.jit(
    full_gaunt_tensor_product.apply,
    static_argnames=["output_lmax", "parity1", "parity2"],
)(
    params,
    coeffs1,
    coeffs2,
    parity1=-1,
    parity2=-1,
    output_lmax=4,
)

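## Example with something chiral

The standard tetris example: classify tetromino-like shapes, including two mirror-image (chiral) shapes. These can only be told apart by an odd scalar (`0o`) output.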

In [None]:
import haiku as hk
import optax
import chex
from clu import parameter_overview

In [None]:
def get_tetris_datasets(rng: chex.PRNGKey):
    """Yields an endless stream of randomly rotated and translated tetris batches.

    Every batch contains the same 8 shapes of 4 points each, with labels of irreps
    ``"0o + 6x0e"``. For each batch, a fresh random rotation and translation per
    shape is applied to the *original* positions, so transforms do not accumulate
    from one batch to the next.

    Args:
        rng: PRNG key used to draw the random rotations and translations.

    Yields:
        A dict with ``"positions"`` (IrrepsArray "1o", shape (8, 4, 3)) and
        ``"labels"`` (IrrepsArray "0o + 6x0e", shape (8, 7)).
    """
    base_positions = jnp.asarray(
        [
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
            [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
            [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
            [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
        ],
        dtype=jnp.float32,
    )
    base_positions = e3nn.IrrepsArray("1o", base_positions)
    num_shapes = base_positions.shape[0]

    # Since chiral shapes are the mirror of one another we need an *odd* scalar to distinguish them
    labels = jnp.asarray(
        [
            [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
            [1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=jnp.int32,
    )
    labels = e3nn.IrrepsArray("0o + 6x0e", labels)

    while True:
        # Apply a random rotation to the original positions.
        rotation_rng, rng = jax.random.split(rng)
        random_rotations = e3nn.rand_matrix(rotation_rng, shape=(num_shapes,))
        positions = jax.vmap(lambda pos, rot: pos.transform_by_matrix(rot))(
            base_positions, random_rotations
        )

        # Apply a random translation to the positions.
        translation_rng, rng = jax.random.split(rng)
        random_translations = e3nn.normal(
            "1o", translation_rng, leading_shape=(num_shapes,)
        )
        positions = positions + random_translations[:, None, :]

        yield {
            "positions": positions,
            "labels": labels,
        }


# Infinite iterator: each next(dataset) yields a new randomly rotated/translated batch.
rng = jax.random.PRNGKey(0)
dataset = get_tetris_datasets(rng)

In [None]:
class GNNLayer(hk.Module):
    """One message-passing update for a single node (GNN vmaps it over nodes and graphs).

    Each neighbour message is the tensor product of the embedded relative position with
    the neighbour's features, multiplied by a radial MLP of the distance. Messages are
    averaged, concatenated with the node's own features, and mixed by a Linear layer
    onto ``output_irreps``.
    """

    def __init__(
        self,
        radial_embedding_dims: int,
        radial_embedding_layers: int,
        output_irreps: e3nn.Irreps,
    ):
        super().__init__()
        # Width of each hidden layer of the radial MLP.
        self.radial_embedding_dims = radial_embedding_dims
        # Total number of radial MLP layers (hidden layers plus the output layer).
        self.radial_embedding_layers = radial_embedding_layers
        self.output_irreps = output_irreps

    def __call__(
        self,
        node_features: e3nn.IrrepsArray,
        distances: jnp.ndarray,
        relative_positions_embedded: e3nn.IrrepsArray,
        neighbor_features: e3nn.IrrepsArray,
    ) -> e3nn.IrrepsArray:
        """Updates one node's features from all of its neighbours.

        Args:
            node_features: Features of this node.
            distances: Distance to each neighbour (leading axis indexes neighbours).
            relative_positions_embedded: Spherical-harmonic embedding of each relative
                position (leading axis indexes neighbours).
            neighbor_features: Features of each neighbour (leading axis indexes neighbours).

        Returns:
            Updated node features with irreps ``output_irreps``.
        """
        def convolve_with_neighbours(
            distances, relative_positions_embedded, neighbor_features
        ):
            # Message from a single neighbour.
            product = e3nn.tensor_product(
                relative_positions_embedded, neighbor_features
            )
            # The radial MLP has one output per irrep of the product.
            radial_mlp = e3nn.haiku.MultiLayerPerceptron(
                [self.radial_embedding_dims] * (self.radial_embedding_layers - 1)
                + [product.irreps.num_irreps],
                act=jax.nn.swish,
            )
            radial = radial_mlp(distances)
            return radial * product

        # Map over the neighbour axis; the MLP parameters are shared across neighbours.
        convolved_features = hk.vmap(convolve_with_neighbours, split_rng=False)(
            distances, relative_positions_embedded, neighbor_features
        )
        # Average the messages over neighbours.
        convolved_features = e3nn.mean(convolved_features, axis=-2)
        node_features = e3nn.concatenate([node_features, convolved_features])
        node_features = e3nn.haiku.Linear(self.output_irreps)(node_features)
        return node_features


class GNN(hk.Module):
    """Equivariant GNN on fully connected point clouds (every node sees every node).

    Nodes start from constant scalar features, go through ``num_layers`` GNNLayer
    updates, and are then mean-pooled and linearly mapped onto ``output_irreps``.
    """

    def __init__(
        self,
        num_layers: int,
        lmax: int,
        initial_embedding_dims: int,
        radial_embedding_dims: int,
        radial_embedding_layers: int,
        hidden_irreps: e3nn.Irreps,
        output_irreps: e3nn.Irreps,
    ):
        super().__init__()
        # Max degree of the spherical-harmonic embedding of relative positions.
        self.lmax = lmax
        self.num_layers = num_layers
        self.radial_embedding_dims = radial_embedding_dims
        self.radial_embedding_layers = radial_embedding_layers
        # Number of constant 0e features each node starts with.
        self.initial_embedding_dims = initial_embedding_dims
        self.hidden_irreps = hidden_irreps
        self.output_irreps = output_irreps

    def __call__(
        self, positions: e3nn.IrrepsArray
    ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
        """Runs the GNN on a batch of point clouds.

        Args:
            positions: Node positions of shape (num_graphs, num_nodes, 3).

        Returns:
            ``(global_features, node_features)``. global_features has irreps
            ``output_irreps`` and one row per graph; node_features has irreps
            ``hidden_irreps`` (when num_layers > 0) with one row per node.
        """
        num_graphs, num_nodes, _ = positions.shape
        assert positions.shape == (num_graphs, num_nodes, 3)
        # relative_positions[g, i, k] = positions[g, k] - positions[g, i].
        relative_positions = jax.vmap(lambda pos: pos[None, :, :] - pos[:, None, :])(
            positions
        )
        assert relative_positions.shape == (num_graphs, num_nodes, num_nodes, 3)

        distances = e3nn.norm(relative_positions)
        assert distances.shape == (num_graphs, num_nodes, num_nodes, 1), distances.shape

        relative_positions_embedded = e3nn.spherical_harmonics(
            e3nn.s2_irreps(self.lmax), relative_positions, normalize=True
        )
        assert relative_positions_embedded.shape == (
            num_graphs,
            num_nodes,
            num_nodes,
            (self.lmax + 1) ** 2,
        )

        node_features = e3nn.ones(
            f"{self.initial_embedding_dims}x0e", leading_shape=(num_graphs, num_nodes)
        )
        assert node_features.irreps.is_scalar()
        assert node_features.shape == (
            num_graphs,
            num_nodes,
            self.initial_embedding_dims,
        )

        for _ in range(self.num_layers):
            layer = GNNLayer(
                self.radial_embedding_dims,
                self.radial_embedding_layers,
                self.hidden_irreps,
            )
            # Inner vmap: per node, with all node features shared as neighbours (None axis).
            layer = hk.vmap(
                layer, split_rng=False, in_axes=(0, 0, 0, None)
            )  # node axis
            layer = hk.vmap(layer, split_rng=False)  # graph axis
            node_features = layer(
                node_features, distances, relative_positions_embedded, node_features
            )

        # Mean-pool over nodes, then project onto the requested output irreps.
        global_features = e3nn.mean(node_features, axis=-2)
        global_features = e3nn.haiku.Linear(self.output_irreps, force_irreps_out=True)(
            global_features
        )
        return global_features, node_features

In [None]:
@hk.without_apply_rng
@hk.transform
def model(data: Any) -> Any:
    """Runs the tetris GNN on ``data["positions"]``.

    Returns the GNN's ``(global_features, node_features)`` pair.
    """
    gnn_config = dict(
        num_layers=4,
        lmax=2,
        initial_embedding_dims=10,
        radial_embedding_dims=5,
        radial_embedding_layers=2,
        hidden_irreps=e3nn.s2_irreps(2) * 2,
        # One odd scalar for chirality plus six even scalars, matching the labels.
        output_irreps="0o + 6x0e",
    )
    gnn = GNN(**gnn_config)
    return gnn(data["positions"])

In [None]:
def train_on_dataset(
    model,
    dataset,
    num_training_steps: int,
    learning_rate: float = 1e-3,
    log_every: int = 100,
):
    """Trains ``model`` with Adam on batches drawn from ``dataset``.

    Args:
        model: A transformed Haiku model (``init``/``apply``, no apply rng). Its apply
            returns a pair whose first element is compared with ``data["labels"]``.
        dataset: Iterator of batches with ``"positions"`` and ``"labels"``. One extra
            batch is consumed to initialise the parameters.
        num_training_steps: Exact number of optimisation steps to run.
        learning_rate: Adam learning rate.
        log_every: Print the loss every this many steps.

    Returns:
        The trained parameters.
    """
    params = model.init(jax.random.PRNGKey(0), next(dataset))
    print(parameter_overview.get_parameter_overview(params))

    tx = optax.adam(learning_rate)
    opt_state = tx.init(params)
    apply_fn = jax.jit(model.apply)

    def loss_fn(params, data):
        """Batch mean of the squared norm of (global embedding - labels)."""
        global_embedding, _ = apply_fn(params, data)
        return e3nn.norm(
            (global_embedding - data["labels"]), squared=True, per_irrep=False
        ).array.mean()

    @jax.jit
    def train_step(params, opt_state, data):
        """One Adam update; returns the new params, optimiser state and loss."""
        loss_value, grads = jax.value_and_grad(loss_fn)(params, data)
        updates, opt_state = tx.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    # islice bounds the (possibly infinite) dataset to exactly num_training_steps batches.
    for step, data in enumerate(itertools.islice(dataset, num_training_steps)):
        params, opt_state, loss_value = train_step(params, opt_state, data)
        if step % log_every == 0:
            print(f"Step {step}: loss={loss_value}")

    return params


# Train the tetris GNN on the randomly transformed dataset.
params = train_on_dataset(model, dataset, num_training_steps=1000)

In [None]:
# Predict on a new batch; preds[0] holds the per-shape global outputs ("0o + 6x0e").
preds = model.apply(params, next(dataset))
preds[0]