## 2. Learning Equivariant Angular Distributions with the Spherical Harmonics

Next, we introduce the spherical harmonics. Then, we show how to use them to represent equivariant angular distributions.

In [None]:
# Imports
from typing import Iterable, Tuple, Sequence
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
import plotly.subplots
import ase.io
import ase.visualize

$$
\def\gR{\mathcal{R}}
\def\gS{\mathcal{S}}
\def\r{\vec{\mathbf{r}}}
\def\R{{\mathbf{R}}}
\def\T{{\mathbf{T}}}
\def\pT{p_\Theta}
\def\f{f_\Theta}
\def\p{p}
$$

### Spherical Harmonics

Let $S^2 = \{(x, y, z) \in \mathbb{R}^3 \ | \ x^2 + y^2 + z^2 = 1\}$ be the unit sphere in $\mathbb{R}^3$.
We can also describe each point on $S^2$ by the usual angular coordinates $(\theta, \phi)$, where $\theta \in [0, \pi]$ is the polar angle and $\phi \in [0, 2\pi)$ is the azimuthal angle:
$$
x = \sin \theta \cos \phi \quad y = \sin \theta \sin \phi \quad z = \cos \theta
$$
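As a quick check of these formulas, here is a small sketch that maps $(\theta, \phi)$ to $(x, y, z)$ and back. The helper names `from_spherical_coordinates` and `unit_vector_to_spherical` are hypothetical (they are not part of e3nn). Note that `jnp.arctan2` returns $\phi \in (-\pi, \pi]$, which describes the same points as $[0, 2\pi)$.

```python
import jax.numpy as jnp


def from_spherical_coordinates(theta: jnp.ndarray, phi: jnp.ndarray) -> jnp.ndarray:
    """Unit vectors (x, y, z) for polar angles theta and azimuthal angles phi."""
    return jnp.stack(
        [jnp.sin(theta) * jnp.cos(phi), jnp.sin(theta) * jnp.sin(phi), jnp.cos(theta)],
        axis=-1,
    )


def unit_vector_to_spherical(xyz: jnp.ndarray):
    """Inverse map for unit vectors: theta in [0, pi], phi in (-pi, pi]."""
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    # Clip guards against |z| slightly above 1 due to round-off.
    return jnp.arccos(jnp.clip(z, -1.0, 1.0)), jnp.arctan2(y, x)


theta = jnp.array([0.3, 1.2, 2.5])
phi = jnp.array([0.1, 2.0, -1.5])
xyz = from_spherical_coordinates(theta, phi)
theta_back, phi_back = unit_vector_to_spherical(xyz)
print(jnp.linalg.norm(xyz, axis=-1))  # every point lies on the unit sphere
```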

The (real) spherical harmonics $Y^l_m: S^2 \to \mathbb{R}$ are a family of functions indexed by a degree $l \in \mathbb{N}$ and an order $m$ with $-l \leq m \leq l$. The cell below shows 3D plots of the spherical harmonics of a given degree $l$:

In [None]:
def spherical_harmonics_as_signals(l: int) -> Iterable[e3nn.SphericalSignal]:
    """Yields the spherical harmonics of degree l as a sequence of e3nn.SphericalSignal objects for each m such that -l <= m <= l."""
    res = (50, 49)
    for m in range(-l, l + 1):
        coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], jnp.asarray([1. if md == m else 0. for md in range(-l, l + 1)]))
        yield e3nn.to_s2grid(coeffs, *res, quadrature="soft", p_val=1, p_arg=-1)


def plot_spherical_harmonics(l: int) -> None:
    """Plots the spherical harmonics of degree l on a single row of subplots with one column for each m such that -l <= m <= l."""
    fig = plotly.subplots.make_subplots(rows=1, cols=2*l + 1, specs=[[{'type': 'surface'} for _ in range(2*l + 1)]], subplot_titles=[f"m = {m}" for m in range(-l, l + 1)])
    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=2.25, y=2.25, z=2.25)
    )

    for index, sig in enumerate(spherical_harmonics_as_signals(l), start=1):
        fig.add_trace(go.Surface(sig.plotly_surface(scale_radius_by_amplitude=True), cmax=2, cmin=-2, showscale=(index == 2*l + 1)), row=1, col=index)
        fig.layout[f"scene{index}"].camera = camera

    fig.update_layout(title=f"Spherical Harmonics of Degree l = {l}", title_x=0.5)
    fig.show()

plot_spherical_harmonics(l=1)
plot_spherical_harmonics(l=2)

Note that for each $l$, there are $2l + 1$ spherical harmonics $Y^l_m$, one for each $-l \leq m \leq l$.
An irrep of type $(l, p)$ also has dimension $2l + 1$, and this is no coincidence: the spherical harmonics of degree $l$ can be used to build irreps!

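**$E(3)$-Equivariance**: Let $Y^l(\theta, \phi) = [Y^l_m(\theta, \phi)]_{m = -l}^{m = l}$. Then $Y^l$ transforms as an irrep of type $(l, (-1)^l)$ of $O(3)$. For any rotation $\R$,
$$
Y^l(\R \cdot (\theta, \phi)) = D^l(\R) \, Y^l(\theta, \phi),
$$
where $D^l(\R)$ is the $(2l + 1) \times (2l + 1)$ Wigner D-matrix of $\R$, and under the inversion $\r \mapsto -\r$, $Y^l$ is multiplied by $(-1)^l$.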



To evaluate the spherical harmonics at any point on $S^2$, e3nn provides the following function:

In [None]:
# Points on the unit sphere where we want to evaluate the spherical harmonics.
s2_points = jnp.asarray([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])

# Normalization doesn't make a difference here because the points are all on the unit sphere
# In general, e3nn will scale the radial component r as r^l when computing Y^l_m.
e3nn.spherical_harmonics("2e", s2_points, normalize=True)


**Orthonormality**: The spherical harmonics form an *orthonormal basis* for square-integrable functions on the sphere. Concretely, let $f: S^2 \to \mathbb{R}$ be any such function. Then there exist unique coefficients $c^l \in \mathbb{R}^{2l + 1}$ for each $l \in \mathbb{N}$ such that:
$$\begin{aligned}
    f(\theta, \phi) = \sum_{l = 0}^\infty (c^l)^T Y^l(\theta, \phi)
\end{aligned}$$
where $Y^l(\theta, \phi) = [Y^l_{-l}(\theta, \phi), \ldots, Y^l_{l}(\theta, \phi)] \in \mathbb{R}^{2l + 1}$ and the series converges in the mean-square sense.
We call the $c^l$ the spherical harmonic coefficients of $f$. They are obtained by projecting $f$ onto the spherical harmonics:
$$
c^l_m = \int_{S^2} f(\theta, \phi) \, Y^l_m(\theta, \phi) \, \sin \theta \, d\theta \, d\phi
$$

As $l$ increases, the spherical harmonics oscillate faster: they vary more rapidly over the surface of the sphere. This is analogous to the Fourier series, where higher frequencies correspond to faster oscillations. Compare the plots for $l = 1$ and $l = 2$ above!

### Rotational Equivariance

We demonstrate this equivariance property on the following example function, which could play the role of unnormalized logits (it takes negative values, so it is not itself a probability distribution):
$$
f(\theta, \phi) = \sin \theta + \cos \phi
$$
To compute its spherical harmonic coefficients with e3nn, we first sample $f$ on a grid over the sphere:


In [None]:
def to_spherical_coordinates(coords: Sequence[float]) -> Tuple[float, float]:
    """Returns the spherical coordinates of a point in 3D space."""
    x, y, z = coords
    return jnp.arccos(z / jnp.sqrt(x ** 2 + y ** 2 + z ** 2)), jnp.arctan2(y, x)

def plot_signal(signal: e3nn.SphericalSignal) -> go.Figure:
    """Plots a SphericalSignal on a sphere."""
    fig = go.Figure([
        go.Surface(
            signal.plotly_surface(scale_radius_by_amplitude=True),
            cmin=-1,
            cmax=2
    )])
    fig.update_layout(
        scene_camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=-1.5, y=1.5, z=1.5)
        ),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)"
    )
    fig.update_traces(showscale=False)
    return fig

def f(theta: float, phi: float) -> float:
    """Returns the value of the function at the given spherical coordinates."""
    return jnp.sin(theta) + jnp.cos(phi)

In [None]:
# Resolution (res_beta, res_alpha) of the grid on the sphere.
resolution = (100, 359)
f_signal = e3nn.SphericalSignal.zeros(*resolution, quadrature="soft")
# Convert every grid point to spherical coordinates and evaluate f there.
thetas, phis = jax.vmap(jax.vmap(to_spherical_coordinates))(f_signal.grid_vectors)
f_signal = f_signal.replace_values(jax.vmap(jax.vmap(f))(thetas, phis))

In [None]:
fig = plot_signal(f_signal)
fig.show()

We can now compute the spherical harmonic coefficients of our signal. This requires choosing a cutoff $l_\text{max}$ on the degree $l$. Below, we visualize the reconstruction of $f$ from its coefficients for $l_\text{max} = 0, \ldots, 4$:

In [None]:
# Reconstruct the signal from its coefficients up to degree lmax.
for lmax in range(5):
    f_coeffs = e3nn.from_s2grid(f_signal, e3nn.s2_irreps(lmax))
    f_signal_reconstructed = e3nn.to_s2grid(f_coeffs, *resolution, quadrature="soft", p_val=1, p_arg=-1)
    fig = plot_signal(f_signal_reconstructed)
    fig.show()

For the rotation experiment, we compute the coefficients with a higher cutoff, $l_\text{max} = 10$:

In [None]:
lmax = 10
f_coeffs = e3nn.from_s2grid(f_signal, e3nn.s2_irreps(lmax))

Now, let's rotate the function by $\frac{\pi}{2}$ radians around the $z$ axis. We do this by applying the Wigner D-matrices of this rotation to the coefficients, one block per irrep:

In [None]:
R = e3nn.matrix_z(0.5 * jnp.pi)
rotated_f_coeffs = f_coeffs.transform_by_matrix(R)
rotated_f_signal = e3nn.to_s2grid(rotated_f_coeffs, *resolution, quadrature="soft", p_val=1, p_arg=-1)

In [None]:
plot_signal(rotated_f_signal)

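The plot shows the original function rotated by $\frac{\pi}{2}$ radians about the $z$ axis, as expected: rotating the coefficients with the Wigner D-matrices rotates the function they represent.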

### Representing Probability Distributions

Let $F$ be a frame of reference, that is, a coordinate system. For example, $F$ could be the standard frame of reference given by the three basis vectors $(e_1, e_2, e_3)$ of $\mathbb{R}^3$.

We want to represent a probability distribution $p$ over directions, expressed in the frame of reference $F$.
This representation should be equivariant to rotations of the frame of reference $F$:
$$
p(\theta, \phi; F) = p(\R \cdot(\theta, \phi); \R F)
$$
So, if we rotate the frame of reference $F$ by $\R$, then the probability distribution $p$ should also rotate by $\R$. The notation $\R \cdot(\theta, \phi)$ denotes the result of rotating the point on $S^2$ represented by spherical coordinates $(\theta, \phi)$ by $\R$.

Any probability distribution $p$ must satisfy the following normalization and non-negativity constraints:
$$
\begin{aligned}
\int_{S^2} p(\theta, \phi; F) \, \sin \theta \, d\theta \, d\phi &= 1 \\
p(\theta, \phi; F) &\geq 0
\end{aligned}
$$
where $\sin \theta \, d\theta \, d\phi$ is the area element in spherical coordinates.
These constraints are hard to enforce directly on the output of a neural network. It is easier to predict unnormalized logits $f(\theta, \phi; F)$ instead, and take a continuous softmax over the sphere:
$$
p(\theta, \phi; F) = \frac{\exp f(\theta, \phi; F)}{\int_{S^2} \exp f(\theta', \phi'; F) \, \sin \theta' \, d\theta' \, d\phi'}
$$
By construction, this $p$ is positive and integrates to $1$.
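As a sanity check of this formula, here is a minimal sketch of the continuous softmax on a discrete grid. It uses a hypothetical quadrature helper, `sphere_quadrature` (Gauss-Legendre nodes in $\cos \theta$ via NumPy's `leggauss`, times a uniform grid in $\phi$), whose weights play the role of the area element $\sin \theta \, d\theta \, d\phi$. Subtracting the maximum logit before exponentiating avoids overflow and cancels in the ratio.

```python
import numpy as np
import jax.numpy as jnp


def sphere_quadrature(n_theta: int = 32, n_phi: int = 64):
    """Hypothetical helper: Gauss-Legendre nodes in cos(theta) times uniform phi.

    Returns theta, phi and weights w such that sum(w * g) approximates the
    integral of g over S^2 with area element sin(theta) dtheta dphi.
    """
    u, w_u = np.polynomial.legendre.leggauss(n_theta)  # u = cos(theta)
    phi = 2 * np.pi * np.arange(n_phi) / n_phi
    theta, phi = np.meshgrid(np.arccos(u), phi, indexing="ij")
    weights = np.outer(w_u, np.full(n_phi, 2 * np.pi / n_phi))
    return jnp.asarray(theta), jnp.asarray(phi), jnp.asarray(weights)


def continuous_softmax(logits: jnp.ndarray, weights: jnp.ndarray) -> jnp.ndarray:
    """p = exp(f) / integral of exp(f), evaluated on the quadrature grid."""
    shifted = jnp.exp(logits - jnp.max(logits))  # max shift for numerical stability
    return shifted / jnp.sum(weights * shifted)


theta, phi, weights = sphere_quadrature()
logits = jnp.sin(theta) + jnp.cos(phi)  # the example f from above, used as logits
p = continuous_softmax(logits, weights)
print(jnp.sum(weights), jnp.sum(weights * p))  # area of S^2, and total probability
```

Adding a constant to the logits leaves $p$ unchanged, which is why the max shift is harmless.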


How can we model the logits $f(\theta, \phi; F)$?
A natural idea is to expand the logits in the basis of spherical harmonics, truncated at degree $l_\text{max}$:
$$
f(\theta, \phi; F) = \sum_{l = 0}^{l_\text{max}} c^l_\Theta(F)^T Y^l(\theta, \phi)
$$
where the coefficients $c^l_\Theta$ are predicted by a neural network.

This is not so different from the ideas of [DeepONet](https://arxiv.org/abs/1910.03193) and [Fourier Neural Operators](https://arxiv.org/abs/2010.08895), where the function is represented as a linear combination of basis functions, and the coefficients are predicted by a neural network.

To guarantee equivariance, the coefficients $c^l_\Theta$ must transform as an irrep of type $(l, (-1)^l)$ of $O(3)$, just like $Y^l$. For a rotation $\R$, this means
$$
c^l_\Theta(\R F) = D^l(\R) \, c^l_\Theta(F),
$$
and under inversion, $c^l_\Theta$ is multiplied by $(-1)^l$.
Using the equivariance of the spherical harmonics and the orthogonality of the (real) Wigner D-matrices, $D^l(\R)^T D^l(\R) = I$, we are guaranteed that $f$ has the required property:
$$
\begin{aligned}
f(\R \cdot (\theta, \phi); \R F) &= \sum_{l = 0}^{l_\text{max}} c^l_\Theta(\R F)^T Y^l(\R \cdot (\theta, \phi)) \\
&= \sum_{l = 0}^{l_\text{max}} c^l_\Theta(F)^T D^l(\R)^T D^l(\R) \, Y^l(\theta, \phi) \\
&= \sum_{l = 0}^{l_\text{max}} c^l_\Theta(F)^T Y^l(\theta, \phi) \\
&= f(\theta, \phi; F)
\end{aligned}
$$
Under inversion, the two factors of $(-1)^l$ cancel in the same way. Finally, the normalizing integral in the softmax does not change when we rotate (or invert) the integration variable, so the probability distribution $p$ is $O(3)$-equivariant as well.

Because these coefficients must transform as irreps of $O(3)$, the network that predicts them from the frame $F$ must itself be $E(3)$-equivariant. We can build such a network with e3nn.

Next, we show an example.

### Example: Learning Equivariant Angular Distributions

In [None]:

import flax.linen as nn
import optax

# Truncation degree of the logits, number of channels, and the grid used
# for normalization and sampling.
LMAX = 2
NUM_CHANNELS = 2
RES_BETA, RES_ALPHA, QUADRATURE = 100, 99, "soft"


class AngularPredictor(nn.Module):
    """Predicts the spherical harmonic coefficients of the logits from a frame."""

    irreps_out: e3nn.Irreps = e3nn.s2_irreps(LMAX)  # 0e + 1o + 2e for LMAX = 2
    num_channels: int = NUM_CHANNELS

    @nn.compact
    def __call__(self, frame: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        # `frame` holds the frame vectors, e.g. with irreps "3x1o".
        # The vectors (odd) and their tensor square (0e, 1e, 2e) together contain
        # every irrep of irreps_out when LMAX = 2, so each output can be reached.
        features = e3nn.concatenate([frame, e3nn.tensor_square(frame)])
        # One equivariant linear map per channel: shape (num_channels, irreps_out.dim).
        return e3nn.stack(
            [e3nn.flax.Linear(self.irreps_out)(features) for _ in range(self.num_channels)]
        )


predictor = AngularPredictor()


def coeffs_to_probability_distribution(
    coeffs: e3nn.IrrepsArray, res_beta: int, res_alpha: int, quadrature: str
) -> e3nn.SphericalSignal:
    """Converts the coefficients of the logits to a normalized distribution on the grid."""
    num_channels = coeffs.shape[-2]

    prob_signal = e3nn.to_s2grid(
        coeffs, res_beta, res_alpha, quadrature=quadrature, p_val=1, p_arg=-1
    )
    assert prob_signal.shape == (num_channels, res_beta, res_alpha)

    # Subtract the maximum logit for numerical stability; it cancels after normalization.
    prob_signal = prob_signal.replace_values(
        jnp.exp(prob_signal.grid_values - jnp.max(prob_signal.grid_values))
    )

    # Sum the exponentiated logits over the channels, then normalize.
    prob_signal = prob_signal.replace_values(jnp.sum(prob_signal.grid_values, axis=-3))
    prob_signal = prob_signal.replace_values(
        prob_signal.grid_values / prob_signal.integrate().array.sum()
    )
    assert prob_signal.shape == (res_beta, res_alpha)
    return prob_signal


def log_partition_function(
    coeffs: e3nn.IrrepsArray, res_beta: int, res_alpha: int, quadrature: str
) -> jax.Array:
    """Log of the integral over the sphere of sum_c exp(f_c), computed on the grid."""
    signal = e3nn.to_s2grid(
        coeffs, res_beta, res_alpha, quadrature=quadrature, p_val=1, p_arg=-1
    )
    max_logit = jnp.max(signal.grid_values)
    signal = signal.replace_values(jnp.exp(signal.grid_values - max_logit))
    # Add back the maximum that was subtracted for numerical stability.
    return max_logit + jnp.log(signal.integrate().array.sum())


def log_prob(position: e3nn.IrrepsArray, coeffs: e3nn.IrrepsArray) -> jax.Array:
    """Computes the log-probability of the direction of `position` (a "1o" vector)."""
    # Only the direction matters, so normalize the position.
    normalized_position = position / jnp.linalg.norm(position.array)
    assert normalized_position.shape == (3,), normalized_position.shape

    # Logits of each channel at this direction, combined over channels.
    vals = e3nn.to_s2point(coeffs, normalized_position).array.squeeze(-1)
    assert vals.shape == (coeffs.shape[-2],), vals.shape
    logits = jax.nn.logsumexp(vals, axis=-1)

    # The distribution is not normalized, so subtract the log partition function.
    log_Z = log_partition_function(coeffs, RES_BETA, RES_ALPHA, QUADRATURE)
    return logits - log_Z


def kl_divergence_loss(
    params: optax.Params, positions: e3nn.IrrepsArray, frame: e3nn.IrrepsArray
) -> jax.Array:
    """Average negative log-likelihood of the observed positions.

    Up to the (constant) entropy of the data, this equals the KL divergence
    between the data distribution and the predicted distribution.
    """
    coeffs = predictor.apply(params, frame)
    return -jax.vmap(log_prob, in_axes=(0, None))(positions, coeffs).mean()


def sample(
    key: jax.Array,
    coeffs: e3nn.IrrepsArray,
    radius: float,
    inverse_temperature: float = 1.0,
) -> jax.Array:
    """Samples from the learned distribution using the discretized grid."""
    # Scaling the logits by the inverse temperature sharpens (> 1) or flattens (< 1)
    # the distribution.
    prob_signal = coeffs_to_probability_distribution(
        coeffs * inverse_temperature, RES_BETA, RES_ALPHA, QUADRATURE
    )

    # Sample a grid point according to the discretized distribution.
    beta_index, alpha_index = prob_signal.sample(key)
    sample = prob_signal.grid_vectors[beta_index, alpha_index]
    assert sample.shape == (3,), sample.shape

    # Scale the unit direction by the radius.
    return sample * radius


# Example: initialize the predictor for the standard frame (e1, e2, e3).
frame = e3nn.IrrepsArray("3x1o", jnp.eye(3).reshape(-1))
params = predictor.init(jax.random.PRNGKey(0), frame)
coeffs = predictor.apply(params, frame)
