## 2. Learning Equivariant Angular Distributions with the Spherical Harmonics

Next, we introduce the spherical harmonics.

In [None]:
# Imports
from typing import Iterable, Tuple, Sequence
import ase
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
import plotly.subplots
import ase.io
import ase.visualize

$$
\def\gR{\mathcal{R}}
\def\gS{\mathcal{S}}
\def\r{\vec{\mathbf{r}}}
\def\R{{\mathbf{R}}}
\def\T{{\mathbf{T}}}
\def\pTi{p_\Theta^{(i)}}
$$

### Background

Let $S^2 = \{(x, y, z) \in \mathbb{R}^3 \ | \ x^2 + y^2 + z^2 = 1\}$ be the unit sphere in $\mathbb{R}^3$.
We can also describe each point on $S^2$ by the usual angular coordinates $(\theta, \phi)$, where $\theta \in [0, \pi]$ is the polar angle and $\phi \in [0, 2\pi)$ is the azimuthal angle:
$$
x = \sin \theta \cos \phi, \quad y = \sin \theta \sin \phi, \quad z = \cos \theta
$$

The spherical harmonics $Y^l_m: S^2 \to \mathbb{R}$ are a set of functions indexed by degree $l \in \mathbb{N}$ and order $-l \leq m \leq l$. The cell below shows a 3D plot of the spherical harmonics for a given degree $l$, created using the `e3nn` library:

In [None]:
def spherical_harmonics_as_signals(l: int) -> Iterable[e3nn.SphericalSignal]:
    """Yields the spherical harmonics of degree l as a sequence of e3nn.SphericalSignal objects for each m such that -l <= m <= l."""
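    res = (50, 49)  # (res_beta, res_alpha) of the plotting grid
    for m in range(-l, l + 1):
        # One-hot coefficient vector selecting Y^l_m among the 2l + 1 components of degree l.
        coeffs = e3nn.IrrepsArray(e3nn.s2_irreps(l)[-1], jnp.asarray([1. if md == m else 0. for md in range(-l, l + 1)]))
        # Evaluate the corresponding function on a grid over the sphere.
        yield e3nn.to_s2grid(coeffs, *res, quadrature="soft", p_val=1, p_arg=-1)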

def plot_spherical_harmonics(l: int) -> None:
    """Plots the spherical harmonics of degree l on a single row of subplots with one column for each m such that -l <= m <= l."""
    fig = plotly.subplots.make_subplots(rows=1, cols=2*l + 1, specs=[[{'type': 'surface'} for _ in range(2*l + 1)]], subplot_titles=[f"m = {m}" for m in range(-l, l + 1)])
    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=2.25, y=2.25, z=2.25)
    )

    for index, sig in enumerate(spherical_harmonics_as_signals(l), start=1):
        fig.add_trace(go.Surface(sig.plotly_surface(scale_radius_by_amplitude=True), cmax=2, cmin=-2, showscale=(index == 2*l + 1)), row=1, col=index)
        fig.layout[f"scene{index}"].camera = camera

    fig.update_layout(title=f"Spherical Harmonics of Degree l = {l}", title_x=0.5)
    fig.show()

plot_spherical_harmonics(l=2)

The spherical harmonics have the following properties which are crucial for our model:

**Orthonormality**: The spherical harmonics form an *orthonormal basis* for square-integrable functions on the sphere. Concretely, let $f: S^2 \to \mathbb{R}$ be any such function. Then there exist unique coefficients $c^l \in \mathbb{R}^{2l + 1}$ for each $l \in \mathbb{N}$ such that:
$$
    f(\theta, \phi) = \sum_{l = 0}^\infty {c^l}^T Y^l(\theta, \phi)
$$
where $Y^l(\theta, \phi) = [Y^l_{-l}(\theta, \phi), \ldots, Y^l_{l}(\theta, \phi)] \in \mathbb{R}^{2l + 1}$.
We call $c^l$ the spherical harmonic coefficients of $f$. By orthonormality, they are obtained by projecting $f$ onto the spherical harmonics:
$$
    c^l_m = \int_{0}^{2\pi} \int_{0}^{\pi} f(\theta, \phi) \, Y^l_m(\theta, \phi) \sin \theta \ d\theta \ d\phi
$$
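To make the projection concrete, here is a minimal pure-JAX sketch (independent of `e3nn`) that writes out the orthonormal real spherical harmonics of degrees $l = 0$ and $l = 1$ by hand, approximates the integrals with a midpoint quadrature rule, checks that the Gram matrix of these functions is close to the identity, and projects the function $f = 2 + 3z$ onto them. The helper names (`real_sh_l01`, `sphere_grid`, `project`) and the grid sizes are illustrative assumptions.

```python
import jax.numpy as jnp

def real_sh_l01(theta, phi):
    """Orthonormal real spherical harmonics for l = 0 and l = 1.

    The last axis is ordered [Y^0_0, Y^1_{-1}, Y^1_0, Y^1_1], using the standard
    real convention Y^1_{-1} ~ y, Y^1_0 ~ z, Y^1_1 ~ x.
    """
    x = jnp.sin(theta) * jnp.cos(phi)
    y = jnp.sin(theta) * jnp.sin(phi)
    z = jnp.cos(theta)
    y00 = jnp.full_like(theta, 1.0 / jnp.sqrt(4.0 * jnp.pi))
    c1 = jnp.sqrt(3.0 / (4.0 * jnp.pi))
    return jnp.stack([y00, c1 * y, c1 * z, c1 * x], axis=-1)

def sphere_grid(n_theta=200, n_phi=400):
    """Midpoint grid in theta, uniform grid in phi, with weights sin(theta) dtheta dphi."""
    dtheta, dphi = jnp.pi / n_theta, 2.0 * jnp.pi / n_phi
    theta = (jnp.arange(n_theta) + 0.5) * dtheta
    phi = jnp.arange(n_phi) * dphi
    theta, phi = jnp.meshgrid(theta, phi, indexing="ij")
    return theta, phi, jnp.sin(theta) * dtheta * dphi

def project(f_values, Y, weights):
    """Approximates c_k = integral over S^2 of f * Y_k by a weighted sum over the grid."""
    return jnp.einsum("ab,abk,ab->k", f_values, Y, weights)

theta, phi, w = sphere_grid()
Y = real_sh_l01(theta, phi)                         # shape (200, 400, 4)
gram = jnp.einsum("abi,abj,ab->ij", Y, Y, w)        # approximately the 4x4 identity
coeffs = project(2.0 + 3.0 * jnp.cos(theta), Y, w)  # f = 2 + 3z
```

Analytically, $c^0_0 = 4\sqrt{\pi}$ and $c^1_0 = \sqrt{12\pi}$ while the other two coefficients vanish, and the quadrature result should match these up to discretization error. `e3nn.from_s2grid`, used later in this notebook, computes such coefficients for all degrees up to a cutoff.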

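**$O(3)$-Equivariance**: Let $\R$ denote an arbitrary rotation of the sphere,
mapping each point $(\theta, \phi)$ on $S^2$ to the point
$\R \cdot (\theta, \phi)$, also on $S^2$. Then the spherical harmonics $Y^l(\theta, \phi)$ transform as an equivariant feature of degree $l$:
$$
    Y^l(\R(\theta, \phi)) = D^l(\R)^T Y^l(\theta, \phi)
$$
where $D^l(\R) \in \mathbb{R}^{(2l + 1) \times (2l + 1)}$ is the (real, orthogonal) Wigner D-matrix of $\R$ for degree $l$.
Under inversion, they have parity $p = (-1)^l$: evaluating $Y^l$ at the antipodal point $(\pi - \theta, \phi + \pi)$ gives $(-1)^l Y^l(\theta, \phi)$. Since rotations and inversion together generate $O(3)$, the $Y^l$ are equivariant under all of $O(3)$.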
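For $l = 1$ the transformation law can be checked by hand. With the standard real convention ($Y^1_{-1} \propto y$, $Y^1_0 \propto z$, $Y^1_1 \propto x$), we have $Y^1(\hat{r}) = \sqrt{3/(4\pi)}\, P \hat{r}$ for a fixed permutation matrix $P$, so $Y^1(\R \hat{r}) = (P \R P^T)\, Y^1(\hat{r})$, i.e. $D^1(\R)^T = P \R P^T$ in this convention. The JAX sketch below checks this numerically, together with the parity $(-1)^l$ for one odd ($l = 1$) and one even ($l = 2$) harmonic. The function names and the particular rotation are illustrative assumptions; libraries such as `e3nn` may order or normalize the components differently.

```python
import jax.numpy as jnp

# Permutation taking (x, y, z) to (y, z, x), i.e. to the order m = -1, 0, 1.
P = jnp.array([[0.0, 1.0, 0.0],
               [0.0, 0.0, 1.0],
               [1.0, 0.0, 0.0]])

def Y1(r_hat):
    """Orthonormal real spherical harmonics of degree 1 at a unit vector r_hat."""
    return jnp.sqrt(3.0 / (4.0 * jnp.pi)) * (P @ r_hat)

def Y2_xy(r_hat):
    """The degree-2 real spherical harmonic Y^2_{-2}, proportional to x * y."""
    x, y, _ = r_hat
    return jnp.sqrt(15.0 / (4.0 * jnp.pi)) * x * y

def rotation(alpha, beta):
    """Rotation by alpha about the z axis followed by beta about the y axis."""
    ca, sa, cb, sb = jnp.cos(alpha), jnp.sin(alpha), jnp.cos(beta), jnp.sin(beta)
    rz = jnp.array([[ca, -sa, 0.0], [sa, ca, 0.0], [0.0, 0.0, 1.0]])
    ry = jnp.array([[cb, 0.0, sb], [0.0, 1.0, 0.0], [-sb, 0.0, cb]])
    return ry @ rz

R = rotation(0.3, 1.1)
r_hat = jnp.array([1.0, 2.0, 2.0]) / 3.0  # a unit vector
D1_T = P @ R @ P.T                        # D^1(R)^T in this convention

print(jnp.allclose(Y1(R @ r_hat), D1_T @ Y1(r_hat), atol=1e-6))  # rotation law: True
print(jnp.allclose(Y1(-r_hat), -Y1(r_hat)))                      # parity -1 for l = 1: True
print(jnp.allclose(Y2_xy(-r_hat), Y2_xy(r_hat)))                 # parity +1 for l = 2: True
```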

To use this machinery to represent probability distributions over $\mathbb{R}^3$, we work in spherical coordinates $(r, \theta, \phi)$, where $r \geq 0$ is the distance from the origin and $(\theta, \phi)$ are the angular coordinates described above.
Any probability distribution $\pTi$ must satisfy the following normalization and non-negativity constraints:
$$
\begin{aligned}
\int_{\mathbb{R}^3} \pTi(r, \theta, \phi) \ dV &= 1 \\
\pTi(r, \theta, \phi) &\geq 0
\end{aligned}
$$
where $dV = r^2 \sin \theta \ dr \ d\theta \ d\phi$ is the volume element in spherical coordinates.
These constraints are hard to enforce directly in a neural network. It is easier to predict unconstrained logits $f_\Theta^{(i)}(r, \theta, \phi)$ instead and apply a (continuous) softmax, which satisfies both constraints by construction:
$$
\pTi(r, \theta, \phi) = \frac{\exp\left(f_\Theta^{(i)}(r, \theta, \phi)\right)}{\int_{\mathbb{R}^3}\exp\left(f_\Theta^{(i)}(r', \theta', \phi')\right) \ dV'}
$$
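On a discretized grid, the integral in the denominator becomes a weighted sum. The sketch below assumes a midpoint grid over $(r, \theta, \phi)$ with volume weights $r^2 \sin\theta \, \Delta r \, \Delta\theta \, \Delta\phi$; the radial cutoff `r_max` (not specified in the text), the grid sizes, and the toy logits are illustrative choices, not the model's. Computing the normalizer with `logsumexp` keeps the softmax numerically stable when the logits are large.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def spherical_volume_grid(r_max=5.0, n_r=64, n_theta=32, n_phi=64):
    """Midpoint grid over (r, theta, phi) and volume weights dV = r^2 sin(theta) dr dtheta dphi."""
    dr, dtheta, dphi = r_max / n_r, jnp.pi / n_theta, 2.0 * jnp.pi / n_phi
    r = (jnp.arange(n_r) + 0.5) * dr
    theta = (jnp.arange(n_theta) + 0.5) * dtheta
    phi = jnp.arange(n_phi) * dphi
    r, theta, phi = jnp.meshgrid(r, theta, phi, indexing="ij")
    dV = r**2 * jnp.sin(theta) * dr * dtheta * dphi
    return r, theta, phi, dV

def softmax_density(logits, dV):
    """p = exp(f) / sum(exp(f) * dV), with the normalizer computed in log space."""
    log_z = logsumexp(logits, b=dV)  # log of sum_k dV_k * exp(f_k)
    return jnp.exp(logits - log_z)

r, theta, phi, dV = spherical_volume_grid()
logits = -2.0 * r + jnp.cos(theta)  # toy logits; any real values are allowed
p = softmax_density(logits, dV)     # non-negative everywhere
total_mass = jnp.sum(p * dV)        # equals 1 up to floating-point error
```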
 
Now, we discuss how to model the logits $f_\Theta^{(i)}(r, \theta, \phi)$ using spherical harmonics.
For ease of notation we drop the subscript $\Theta$ and superscript $(i)$ in the remainder of this section.
We discretize the radial component $r$ into a fixed number of bins (here, 64). For each value of $r$, we obtain a function on the sphere $S^2$:
$$
f(r, \theta, \phi; \gS) = \sum_{l = 0}^{l_\text{max}} c^l(r; \gS)^T Y^l(\theta, \phi)
$$
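If the coefficients for all radial bins are stacked into an array of shape `(n_r, (l_max + 1)^2)` and the spherical harmonics are evaluated once on an angular grid, the logits for every bin follow from a single contraction. A sketch with $l_\text{max} = 1$, assuming hand-written orthonormal real harmonics and random coefficients as a hypothetical stand-in for the network's predictions:

```python
import jax
import jax.numpy as jnp

def real_sh_l01(theta, phi):
    """Orthonormal real spherical harmonics for l <= 1, ordered [Y^0_0, Y^1_{-1}, Y^1_0, Y^1_1]."""
    x, y, z = jnp.sin(theta) * jnp.cos(phi), jnp.sin(theta) * jnp.sin(phi), jnp.cos(theta)
    c1 = jnp.sqrt(3.0 / (4.0 * jnp.pi))
    y00 = jnp.full_like(z, 1.0 / jnp.sqrt(4.0 * jnp.pi))
    return jnp.stack([y00, c1 * y, c1 * z, c1 * x], axis=-1)

n_r, n_theta, n_phi = 64, 32, 64
theta = (jnp.arange(n_theta) + 0.5) * jnp.pi / n_theta
phi = jnp.arange(n_phi) * 2.0 * jnp.pi / n_phi
theta, phi = jnp.meshgrid(theta, phi, indexing="ij")
Y = real_sh_l01(theta, phi)  # (n_theta, n_phi, 4)

# Hypothetical stand-in for the predicted coefficients c^l(r; S), one row per radial bin.
coeffs = jax.random.normal(jax.random.PRNGKey(0), (n_r, 4))

# f(r_k, theta, phi) = sum_l c^l(r_k)^T Y^l(theta, phi), for every bin k at once.
logits = jnp.einsum("rk,abk->rab", coeffs, Y)  # (n_r, n_theta, n_phi)
```

The resulting array has the same layout as the toy logits used in a discretized softmax, so it can be normalized with the volume weights of the same grid.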
where the coefficients $c^l$ are now a function of both $r$ and the point cloud $\gS$.
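To guarantee $E(3)$-equivariance, these coefficients must transform in the same way as the spherical harmonics, that is, as an irreducible representation of $O(3)$ of degree $l$ and parity $(-1)^l$:
$$
\begin{aligned}
c^l(r; \R\gS) &= D^l(\R)^T c^l(r; \gS) && \text{for rotations } \R, \\
c^l(r; -\gS) &= (-1)^l c^l(r; \gS) && \text{for the inversion } \gS \mapsto -\gS.
\end{aligned}
$$
Our model predicts these coefficients $c^l$ using an $E(3)$-equivariant message-passing neural network.
Combining these transformation laws with those of the spherical harmonics and the orthogonality of the real Wigner D-matrices, $D^l(\R) D^l(\R)^T = I$, we are guaranteed that $f$ is equivariant under rotations:
$$
\begin{aligned}
f(\R(r, \theta, \phi); \R\gS) &= \sum_{l = 0}^{l_\text{max}} c^l(r; \R\gS)^T Y^l(\R(\theta, \phi)) \\
&= \sum_{l = 0}^{l_\text{max}} c^l(r; \gS)^T D^l(\R) D^l(\R)^T Y^l(\theta, \phi) \\
&= \sum_{l = 0}^{l_\text{max}} c^l(r; \gS)^T Y^l(\theta, \phi) \\
&= f(r, \theta, \phi; \gS)
\end{aligned}
$$
Under inversion, both $c^l$ and $Y^l$ pick up a factor of $(-1)^l$, so each term is multiplied by $(-1)^{2l} = 1$ and $f$ is again unchanged. Hence $f$ is $O(3)$-equivariant.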

To guarantee translational equivariance, we use only relative position vectors to predict the coefficients $c$. This is a common practice in many equivariant neural networks ([Geiger and Smidt](https://arxiv.org/abs/2207.09453)). Thus, our predicted logits $f$ (and hence our predicted probability distribution $\pTi$) are $E(3)$-equivariant.
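The argument above can be checked end to end on a toy example. The sketch below is *not* the message-passing network used in this work; it is a hypothetical hand-built predictor that computes coefficients for $l \leq 1$ from a point cloud $\gS$ using only vectors relative to the centroid, $c = \sum_j e^{-|\Delta_j|}\,[Y^0(\hat{\Delta}_j), Y^1(\hat{\Delta}_j)]$ with $\Delta_j = \mathbf{x}_j - \bar{\mathbf{x}}$. Because the weights depend only on distances and $Y^1$ obeys the rotation law, these coefficients transform as required, so the logits should satisfy $f(\R\hat{r}; \R\gS + \mathbf{t}) = f(\hat{r}; \gS)$ for any rotation $\R$ and translation $\mathbf{t}$, and $f(-\hat{r}; -\gS) = f(\hat{r}; \gS)$. We fix a single radial bin for brevity.

```python
import jax
import jax.numpy as jnp

def Y01(v):
    """Orthonormal real spherical harmonics of degrees 0 and 1 of the direction of v.

    Returns [Y^0_0, Y^1_{-1}, Y^1_0, Y^1_1], with Y^1 proportional to (y, z, x).
    """
    x, y, z = v / jnp.linalg.norm(v)
    c1 = jnp.sqrt(3.0 / (4.0 * jnp.pi))
    return jnp.stack([jnp.ones_like(x) / jnp.sqrt(4.0 * jnp.pi), c1 * y, c1 * z, c1 * x])

def toy_coefficients(points):
    """Hypothetical stand-in for the network: c = sum_j w_j [Y^0, Y^1](rel_j)."""
    rel = points - points.mean(axis=0)         # relative vectors: translation invariant
    w = jnp.exp(-jnp.linalg.norm(rel, axis=1))  # depends only on distances: rotation invariant
    return w @ jax.vmap(Y01)(rel)               # shape (4,) = [c^0, c^1]

def logits(r_hat, points):
    """f(r_hat; S) = sum_{l <= 1} c^l(S)^T Y^l(r_hat), for a single radial bin."""
    return toy_coefficients(points) @ Y01(r_hat)

def rotation(alpha, beta):
    """Rotation by alpha about the z axis followed by beta about the y axis."""
    ca, sa, cb, sb = jnp.cos(alpha), jnp.sin(alpha), jnp.cos(beta), jnp.sin(beta)
    rz = jnp.array([[ca, -sa, 0.0], [sa, ca, 0.0], [0.0, 0.0, 1.0]])
    ry = jnp.array([[cb, 0.0, sb], [0.0, 1.0, 0.0], [-sb, 0.0, cb]])
    return ry @ rz

points = jax.random.normal(jax.random.PRNGKey(0), (5, 3))  # a toy point cloud S
r_hat = jnp.array([1.0, 2.0, 2.0]) / 3.0                   # a query direction
R, t = rotation(0.3, 1.1), jnp.array([1.0, -2.0, 0.5])

original = logits(r_hat, points)
rotated_and_translated = logits(R @ r_hat, points @ R.T + t)  # f(R r_hat; R S + t)
inverted = logits(-r_hat, -points)                            # f(-r_hat; -S)
```

All three values should agree up to floating-point error, mirroring the derivation: the rotation cancels through $D^1(\R) D^1(\R)^T = I$, the translation never reaches the coefficients, and the two parity factors cancel.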

### Example

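We demonstrate this equivariance under rotations for the following example of the logits of an unnormalized probability distribution (the distribution itself is proportional to $\exp f$, since $f$ takes negative values):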
$$
f(\theta, \phi) = \sin \theta + \cos \phi
$$
Below, we define some helper functions and use `e3nn` to sample $f$ on a grid over the sphere:


In [None]:
def to_spherical_coordinates(coords: Sequence[float]) -> Tuple[float, float]:
    """Returns the spherical coordinates of a point in 3D space."""
    x, y, z = coords
    return jnp.arccos(z / jnp.sqrt(x ** 2 + y ** 2 + z ** 2)), jnp.arctan2(y, x)

def plot_signal(signal: e3nn.SphericalSignal) -> go.Figure:
    """Plots a SphericalSignal on a sphere."""
    fig = go.Figure([
        go.Surface(
            signal.plotly_surface(scale_radius_by_amplitude=True),
            cmin=-1,
            cmax=2
    )])
    fig.update_layout(
        scene_camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=-1.5, y=1.5, z=1.5)
        ),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)"
    )
    fig.update_traces(showscale=False)
    return fig

def f(theta: float, phi: float) -> float:
    """Returns the value of the function at the given spherical coordinates."""
    return jnp.sin(theta) + jnp.cos(phi)

In [None]:
# Grid resolution (res_beta, res_alpha) on the sphere.
resolution = (100, 359)
f_signal = e3nn.SphericalSignal.zeros(*resolution, quadrature="soft")
# Convert each grid point to (theta, phi) and evaluate f there.
thetas, phis = jax.vmap(jax.vmap(to_spherical_coordinates))(f_signal.grid_vectors)
f_signal.grid_values = jax.vmap(jax.vmap(f))(thetas, phis)

In [None]:
fig = plot_signal(f_signal)
fig.show()

To compute the spherical harmonic coefficients of our signal, we must choose a cutoff $l_\text{max}$ on the degree. Below, we visualize the signal reconstructed from its coefficients for $l_\text{max} = 0, 1, \ldots, 9$:

In [None]:
# Reconstruct the signal from its coefficients for each cutoff lmax = 0, ..., 9.
for lmax in range(10):
    f_coeffs = e3nn.from_s2grid(f_signal, e3nn.s2_irreps(lmax))
    f_signal_reconstructed = e3nn.to_s2grid(f_coeffs, *resolution, quadrature="soft", p_val=1, p_arg=-1)
    fig = plot_signal(f_signal_reconstructed)
    fig.show()

We now fix $l_\text{max} = 10$ and compute the spherical harmonic coefficients of our signal:

In [None]:
lmax = 10
f_coeffs = e3nn.from_s2grid(f_signal, e3nn.s2_irreps(lmax))

Now, let's see how the function transforms when we rotate these coefficients by $\frac{\pi}{2}$ radians around the $z$ axis, using the Wigner D matrices corresponding to this rotation:

In [None]:
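R = e3nn.matrix_z(0.5 * jnp.pi)  # rotation matrix for pi/2 about the z axis
rotated_f_coeffs = f_coeffs.transform_by_matrix(R)  # applies the Wigner D matrix of each degree l
rotated_f_signal = e3nn.to_s2grid(rotated_f_coeffs, *resolution, quadrature="soft", p_val=1, p_arg=-1)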

In [None]:
fig = plot_signal(rotated_f_signal)
fig.show()

Comparing this plot with the plot of the original signal, we see that the function has been rotated by $\frac{\pi}{2}$ radians about the $z$ axis, as expected.
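As a sanity check on what the rotated plot should show: an active rotation $\R$ maps a function $f$ to $f_\R(\hat{r}) = f(\R^{-1}\hat{r})$. For a rotation by $\frac{\pi}{2}$ about the $z$ axis, $\R^{-1}$ leaves $\theta$ unchanged and shifts $\phi$ by $-\frac{\pi}{2}$, so
$$
f_\R(\theta, \phi) = \sin\theta + \cos\left(\phi - \tfrac{\pi}{2}\right) = \sin\theta + \sin\phi.
$$
The pure-JAX snippet below checks this identity at random points. It assumes that `transform_by_matrix` performs this active rotation, which we have not verified here; the helper `to_spherical` is illustrative.

```python
import jax
import jax.numpy as jnp

def f(theta, phi):
    return jnp.sin(theta) + jnp.cos(phi)

def to_spherical(v):
    """(theta, phi) of a 3D vector, with phi in (-pi, pi]."""
    return jnp.arccos(v[2] / jnp.linalg.norm(v)), jnp.arctan2(v[1], v[0])

# Rotation by pi/2 about the z axis.
c, s = jnp.cos(jnp.pi / 2), jnp.sin(jnp.pi / 2)
R = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

points = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
theta, phi = jax.vmap(to_spherical)(points)
# f_R(p) = f(R^{-1} p), and R^{-1} = R^T for a rotation.
rotated_values = jax.vmap(lambda p: f(*to_spherical(R.T @ p)))(points)
expected_values = jnp.sin(theta) + jnp.sin(phi)
```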