## 2. Representing Angular Distributions with the Spherical Harmonics

Next, we introduce the spherical harmonics. Then, we show how to use them to represent equivariant angular distributions.

In [None]:
# Imports
from typing import Iterable, Tuple
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
import plotly.subplots
import optax
import tqdm

$$
\def\gR{\mathcal{R}}
\def\gS{\mathcal{S}}
\def\r{\vec{\mathbf{r}}}
\def\R{{\mathbf{R}}}
\def\T{{\mathbf{T}}}
\def\pT{p_\Theta}
\def\fT{f_\Theta}
\def\p{p}
$$

### Spherical Harmonics

Let $S^2 = \{(x, y, z) \in \mathbb{R}^3 \ | \ x^2 + y^2 + z^2 = 1\}$ be the unit sphere in $\mathbb{R}^3$.
We can also describe each point on $S^2$ by the usual spherical coordinates $(\theta, \phi)$, where $\theta \in [0, \pi]$ is the polar angle and $\phi \in [0, 2\pi)$ is the azimuthal angle:
$$
x = \sin \theta \cos \phi \quad y = \sin \theta \sin \phi \quad z = \cos \theta
$$
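As a quick check of these formulas, here is a small NumPy sketch that converts between $(\theta, \phi)$ and Cartesian coordinates and back. The helper names `to_cartesian` and `to_spherical` are our own and are not part of e3nn.

```python
from typing import Tuple

import numpy as np


def to_cartesian(theta: np.ndarray, phi: np.ndarray) -> np.ndarray:
    """Maps polar angle theta and azimuthal angle phi to points (x, y, z) on S^2."""
    return np.stack(
        [np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)],
        axis=-1,
    )


def to_spherical(points: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Inverse map: returns (theta, phi) with theta in [0, pi] and phi in [0, 2 pi)."""
    x, y, z = points[..., 0], points[..., 1], points[..., 2]
    theta = np.arccos(np.clip(z, -1.0, 1.0))
    phi = np.mod(np.arctan2(y, x), 2 * np.pi)
    return theta, phi


# Round trip on a few random angles, away from the poles where phi is undefined.
rng = np.random.default_rng(0)
theta = rng.uniform(0.1, np.pi - 0.1, size=5)
phi = rng.uniform(0.0, 2 * np.pi, size=5)
points = to_cartesian(theta, phi)
theta_back, phi_back = to_spherical(points)
```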

The spherical harmonics $Y^l_m: S^2 \to \mathbb{R}$ are a set of functions indexed by a degree $l \in \{0, 1, 2, \ldots\}$ and an order $m$ with $-l \leq m \leq l$. The cell below plots the spherical harmonics of a given degree $l$ in 3D, for $l = 0, 1, 2$:

In [None]:
# Plot spherical harmonics.

def spherical_harmonics_as_signals(l: int) -> Iterable[e3nn.SphericalSignal]:
    """Yields the spherical harmonics of degree l as a sequence of e3nn.SphericalSignal objects for each m such that -l <= m <= l."""
    res = (50, 49)
    for m in range(-l, l + 1):
        coeffs = e3nn.IrrepsArray(
            e3nn.s2_irreps(l)[-1],
            jnp.asarray([1.0 if md == m else 0.0 for md in range(-l, l + 1)]),
        )
        yield e3nn.to_s2grid(coeffs, *res, quadrature="soft", p_val=1, p_arg=-1)


def plot_spherical_harmonics(l: int) -> None:
    """Plots the spherical harmonics of degree l on a single row of subplots with one column for each m such that -l <= m <= l."""
    fig = plotly.subplots.make_subplots(
        rows=1,
        cols=2 * l + 1,
        specs=[[{"type": "surface"} for _ in range(2 * l + 1)]],
        subplot_titles=[f"m = {m}" for m in range(-l, l + 1)],
    )
    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.5, y=1.5, z=1.5),
    )
    axis = dict(
        title="",
        showticklabels=False,
        showgrid=False,
        zeroline=False,
        backgroundcolor="rgba(255,255,255,255)",
        range=[-2.5, 2.5],
    )

    for index, sig in enumerate(spherical_harmonics_as_signals(l), start=1):
        fig.add_trace(
            go.Surface(
                sig.plotly_surface(scale_radius_by_amplitude=True),
                cmax=2,
                cmin=-2,
                showscale=(index == 2 * l + 1),
            ),
            row=1,
            col=index,
        )

        # Style the 3D scene of this subplot, including its camera.
        fig.layout[f"scene{index}"] = dict(
            xaxis=axis,
            yaxis=axis,
            zaxis=axis,
            bgcolor="rgba(255,255,255,255)",
            aspectmode="cube",
            camera=camera,
        )

    fig.update_layout(
        title=f"Spherical Harmonics of Degree l = {l}",
        title_x=0.5,
    )
    fig.show()


plot_spherical_harmonics(l=0)
plot_spherical_harmonics(l=1)
plot_spherical_harmonics(l=2)

If these plots reminded you of the s, p and d atomic orbitals, you're on the right track! The spherical harmonics are indeed the angular part of the atomic orbitals.

As $l$ increases, the angular frequency of the spherical harmonics increases, which means they vary more rapidly over the surface of the sphere. This is analogous to the Fourier series, where higher frequencies correspond to faster oscillations. Compare the spherical harmonics for $l = 1$ and $l = 2$ in the cell above!

Note that for each $l$, there are $2l + 1$ spherical harmonics $Y^l_m$ for $-l \leq m \leq l$.
Let $Y^l(\theta, \phi) = [Y^l_m(\theta, \phi)]_{m = -l}^{m = l}$.

When written in Cartesian coordinates on the unit sphere, the spherical harmonics are given (up to normalization) by:
- For $l = 0$, we have $Y^0(x, y, z) = [1]$.
- For $l = 1$, we have $Y^1(x, y, z) = [x, y, z]$.
- For $l = 2$, we have $Y^2(x, y, z) = [xy, yz, 2z^2 - x^2 - y^2, zx, x^2 - y^2]$.

When collected together like this, the spherical harmonics indeed transform as an irrep of $O(3)$:

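**$O(3)$-Equivariance**:  $Y^l(\theta, \phi)$ transforms as an irrep of $O(3)$ of type $(l, (-1)^l)$. That is, for every rotation $\R$, $Y^l(\R \cdot (\theta, \phi)) = D^l(\R)\, Y^l(\theta, \phi)$, where $D^l(\R)$ is the $(2l + 1) \times (2l + 1)$ Wigner D-matrix of $\R$; under the inversion $(x, y, z) \mapsto (-x, -y, -z)$, $Y^l$ is multiplied by $(-1)^l$.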
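To make this concrete, the NumPy sketch below (independent of e3nn) recovers $D^1(\R)$ and $D^2(\R)$ for a random rotation by least squares, using the Cartesian forms listed above. For $l = 2$ we assume the relative normalization $[2\sqrt{3}\,xy,\ 2\sqrt{3}\,yz,\ 2z^2 - x^2 - y^2,\ 2\sqrt{3}\,zx,\ \sqrt{3}(x^2 - y^2)]$, which makes the five components orthonormal on $S^2$ up to one common constant, so that $D^2(\R)$ should come out orthogonal. The helper names are our own.

```python
from typing import Callable

import numpy as np


def y1(p: np.ndarray) -> np.ndarray:
    """Degree-1 harmonics in Cartesian form (up to a constant): [x, y, z]."""
    return p


def y2(p: np.ndarray) -> np.ndarray:
    """Degree-2 harmonics in Cartesian form, orthonormal up to a common constant (assumed scaling)."""
    x, y, z = p[..., 0], p[..., 1], p[..., 2]
    s3 = np.sqrt(3.0)
    return np.stack(
        [2 * s3 * x * y, 2 * s3 * y * z, 2 * z**2 - x**2 - y**2, 2 * s3 * z * x, s3 * (x**2 - y**2)],
        axis=-1,
    )


def random_rotation(rng: np.random.Generator) -> np.ndarray:
    """Samples a rotation matrix (orthogonal, det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def fit_wigner_d(Y: Callable[[np.ndarray], np.ndarray], R: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Least-squares solve for D such that Y(R p) = D Y(p) at all sample points p."""
    A = Y(points)          # rows are Y(p)^T
    B = Y(points @ R.T)    # rows are Y(R p)^T
    # Y(R p)^T = Y(p)^T D^T, i.e. B = A @ D.T.
    D_T, *_ = np.linalg.lstsq(A, B, rcond=None)
    return D_T.T


rng = np.random.default_rng(0)
points = rng.normal(size=(200, 3))
points /= np.linalg.norm(points, axis=-1, keepdims=True)
R = random_rotation(rng)

D1 = fit_wigner_d(y1, R, points)
D2 = fit_wigner_d(y2, R, points)

# Fresh points: the fitted D^2 must also predict Y^2(R p) at points it was not fitted on.
fresh = rng.normal(size=(50, 3))
fresh /= np.linalg.norm(fresh, axis=-1, keepdims=True)
```

For $l = 1$ in the ordering $[x, y, z]$, the fitted matrix is the rotation matrix itself, and the parity factor $(-1)^l$ shows up as $Y^1(-p) = -Y^1(p)$ and $Y^2(-p) = Y^2(p)$.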

To evaluate the spherical harmonics at any point on $S^2$, e3nn provides the following function:

In [None]:
# Points on the unit sphere where we want to evaluate the spherical harmonics.
s2_points = jnp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

# These points already lie on the unit sphere, so normalize=True makes no difference here.
# In general, with normalize=False, e3nn evaluates Y^l at the unnormalized vector, so the output scales as r^l.
e3nn.spherical_harmonics("2e", s2_points, normalize=True)
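The $r^l$ scaling mentioned in the comment above follows from the Cartesian forms listed earlier: each component of $Y^l$ is a homogeneous polynomial of degree $l$, so evaluating it at a vector of length $r$ multiplies it by $r^l$. Below is a small NumPy check for $l = 2$ with the unnormalized polynomials from that list; it is our own illustration, not e3nn's implementation.

```python
import numpy as np


def y2_polynomials(p: np.ndarray) -> np.ndarray:
    """Unnormalized degree-2 polynomials [xy, yz, 2z^2 - x^2 - y^2, zx, x^2 - y^2]."""
    x, y, z = p[..., 0], p[..., 1], p[..., 2]
    return np.stack([x * y, y * z, 2 * z**2 - x**2 - y**2, z * x, x**2 - y**2], axis=-1)


rng = np.random.default_rng(0)
unit_vectors = rng.normal(size=(10, 3))
unit_vectors /= np.linalg.norm(unit_vectors, axis=-1, keepdims=True)

r = 2.5
scaled = y2_polynomials(r * unit_vectors)  # evaluate at vectors of length r
```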

The next property allows us to convert uniquely between functions on the sphere $S^2$ and collections of coefficients that transform as irreducible representations of $O(3)$.

**Orthonormality**: The spherical harmonics form an *orthonormal basis* for functions on the sphere. Concretely, let $f: S^2 \to \mathbb{R}$ be any (square-integrable) function on the sphere. Then, there exist *unique* coefficients $c^l \in \mathbb{R}^{2l + 1}$ for each $l \in \{0, 1, 2, \ldots\}$ such that:
$$
\begin{aligned}
    f(\theta, \phi) &= \sum_{l = 0}^\infty (c^l)^T Y^l(\theta, \phi) \\
    \implies c^l_m &= \int_{S^2} f(\theta, \phi) Y^l_m(\theta, \phi) d\Omega
\end{aligned}
$$
where $d\Omega = \sin \theta d\theta d\phi$ is the surface area element on the sphere.

We call these coefficients $c^l$ the spherical harmonic coefficients of $f$,
as they are obtained by projecting $f$ onto the spherical harmonics.
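Both the orthonormality relation and the projection formula can be checked numerically. The NumPy sketch below (independent of e3nn) uses the standard orthonormal real spherical harmonics for $l \le 2$, integrates over $S^2$ with a Gauss-Legendre rule in $\cos\theta$ times a uniform rule in $\phi$ (exact for these low-degree polynomials), and then projects $f(x, y, z) = x^2$ onto them. The quadrature helper and function names are our own choices.

```python
import numpy as np


def real_sph_harm_upto_2(p: np.ndarray) -> np.ndarray:
    """Orthonormal real spherical harmonics for l = 0, 1, 2 (9 functions) at unit vectors p."""
    x, y, z = p[..., 0], p[..., 1], p[..., 2]
    c0 = np.sqrt(1 / (4 * np.pi))
    c1 = np.sqrt(3 / (4 * np.pi))
    c2 = 0.5 * np.sqrt(15 / np.pi)
    c20 = 0.25 * np.sqrt(5 / np.pi)
    return np.stack(
        [
            c0 * np.ones_like(x),                   # l = 0
            c1 * y, c1 * z, c1 * x,                 # l = 1
            c2 * x * y, c2 * y * z,                 # l = 2, m = -2, -1
            c20 * (2 * z**2 - x**2 - y**2),         # l = 2, m = 0
            c2 * z * x, 0.5 * c2 * (x**2 - y**2),   # l = 2, m = 1, 2
        ],
        axis=-1,
    )


def sphere_quadrature(n_theta: int = 16, n_phi: int = 32):
    """Points and weights with sum(weights * g(points)) approximating the integral of g over S^2.

    Gauss-Legendre nodes in u = cos(theta) (so d Omega = du d phi) and a uniform rule in phi.
    """
    u, w_u = np.polynomial.legendre.leggauss(n_theta)
    phi = 2 * np.pi * np.arange(n_phi) / n_phi
    u_grid, phi_grid = np.meshgrid(u, phi, indexing="ij")
    s = np.sqrt(1 - u_grid**2)
    points = np.stack([s * np.cos(phi_grid), s * np.sin(phi_grid), u_grid], axis=-1).reshape(-1, 3)
    weights = np.repeat(w_u, n_phi) * (2 * np.pi / n_phi)
    return points, weights


points, weights = sphere_quadrature()
Y = real_sph_harm_upto_2(points)       # shape (num_points, 9)
gram = Y.T @ (weights[:, None] * Y)    # entries approximate the integral of Y_i * Y_j

# Project f(x, y, z) = x^2 onto the harmonics: c_i = integral of f * Y_i.
f = points[:, 0] ** 2
coeffs = Y.T @ (weights * f)
reconstruction = Y @ coeffs            # x^2 lies in the span of l = 0 and l = 2
```

The Gram matrix comes out as the identity, the degree-1 coefficients of $x^2$ vanish, and the truncated expansion reproduces $x^2$ exactly because $x^2 = \tfrac{1}{3} + (x^2 - \tfrac{1}{3})$ splits into a degree-0 and a degree-2 part.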

These coefficients $c^l$ transform as an irrep of $O(3)$ of type $(l, (-1)^l)$, exactly as the corresponding spherical harmonics $Y^l$ do. What does this mean?

Consider the 'rotated' function:
$$
f_\R(\theta, \phi) = f(\R^{-1} \cdot (\theta, \phi))
$$
where $\R \cdot (\theta, \phi)$ denotes the action of the rotation $\R$ on the point represented by spherical coordinates $(\theta, \phi)$.

What are the spherical harmonic coefficients of $f_\R$? They are given by:
$$
c^l(f_\R) = D^l(\R) c^l(f)
$$
To see this, we expand $f_\R$ using the $O(3)$-equivariance of the spherical harmonics and the orthogonality of the (real) Wigner D-matrices, $D^l(\R)^T = D^l(\R)^{-1}$:
$$
\begin{aligned}
f_\R(\theta, \phi) &= f(\R^{-1} \cdot (\theta, \phi)) \\
&= \sum_{l = 0}^\infty c^l(f)^T Y^l(\R^{-1} \cdot (\theta, \phi)) \\
&= \sum_{l = 0}^\infty c^l(f)^T D^l(\R)^{-1} Y^l(\theta, \phi) \\
&= \sum_{l = 0}^\infty c^l(f)^T D^l(\R)^T Y^l(\theta, \phi) \\
&= \sum_{l = 0}^\infty \left(D^l(\R) c^l(f)\right)^T Y^l(\theta, \phi)
\end{aligned}
$$
By the uniqueness of the spherical harmonic coefficients, $c^l(f_\R) = D^l(\R) c^l(f)$. The parity can be checked with the same argument, using the inversion in place of the rotation.
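Here is a numerical check of $c^l(f_\R) = D^l(\R)\, c^l(f)$ for $l = 1$, where, with $Y^1 \propto [x, y, z]$ in that order, the Wigner D-matrix is the rotation matrix itself, $D^1(\R) = \R$. The test function $f(\r) = \exp(\mathbf{v} \cdot \r)$, the vector $\mathbf{v}$ and the rotation angles are arbitrary choices of ours, and the integrals use a Gauss-Legendre (in $\cos\theta$) by uniform (in $\phi$) quadrature; this is a plain NumPy sketch that does not use e3nn.

```python
import numpy as np


def sphere_quadrature(n_theta: int = 40, n_phi: int = 80):
    """Points and weights for integrating over S^2 (Gauss-Legendre in cos(theta), uniform in phi)."""
    u, w_u = np.polynomial.legendre.leggauss(n_theta)
    phi = 2 * np.pi * np.arange(n_phi) / n_phi
    u_grid, phi_grid = np.meshgrid(u, phi, indexing="ij")
    s = np.sqrt(1 - u_grid**2)
    points = np.stack([s * np.cos(phi_grid), s * np.sin(phi_grid), u_grid], axis=-1).reshape(-1, 3)
    weights = np.repeat(w_u, n_phi) * (2 * np.pi / n_phi)
    return points, weights


def degree1_coeffs(f, points: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """c^1(f): integral of f times the orthonormal harmonics sqrt(3 / (4 pi)) * [x, y, z]."""
    return np.sqrt(3 / (4 * np.pi)) * (points.T @ (weights * f(points)))


def rot_x(angle: float) -> np.ndarray:
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_z(angle: float) -> np.ndarray:
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


v = np.array([0.3, -1.2, 0.8])
R = rot_x(0.7) @ rot_z(1.9)


def f(r: np.ndarray) -> np.ndarray:
    return np.exp(r @ v)


def f_R(r: np.ndarray) -> np.ndarray:
    # f_R(r) = f(R^{-1} r); for row vectors, (R^{-1} r)^T = r^T R because R^{-1} = R^T.
    return f(r @ R)


points, weights = sphere_quadrature()
c1_f = degree1_coeffs(f, points, weights)
c1_f_R = degree1_coeffs(f_R, points, weights)
```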

### Representing Functions on the Sphere

We demonstrate this $O(3)$-equivariance property for the following function, which represents a point cloud of $6$ points on the unit sphere:
$$
f(x, y, z) = \sum_{i = 1}^6 \delta\big((x, y, z) - x_i\big)
$$
where $\delta$ is the Dirac delta on $S^2$ and $\{x_i\}_{i = 1}^6$ are the vertices of a regular hexagon inscribed in the unit circle in the $xy$-plane (the equator of $S^2$).
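Since $\int_{S^2} \delta\big((x, y, z) - x_i\big)\, Y^l_m(x, y, z)\, d\Omega = Y^l_m(x_i)$, the coefficients of $f$ are simply $c^l = \sum_{i=1}^6 Y^l(x_i)$. The NumPy sketch below evaluates them for $l \le 2$, assuming the standard orthonormal real harmonics (the e3nn helper `e3nn.s2_dirac` used later computes this kind of expansion, but its normalization may differ from the one assumed here). The degree-1 coefficients vanish because the hexagon's vertices sum to zero, and at $l = 2$ only the component that is symmetric about the $z$-axis survives.

```python
import numpy as np

# Vertices of a regular hexagon on the equator of S^2 (the same points as `hexagon_points` below).
angles = np.arange(6) * np.pi / 3
hexagon = np.stack([np.cos(angles), np.sin(angles), np.zeros(6)], axis=-1)
x, y, z = hexagon[:, 0], hexagon[:, 1], hexagon[:, 2]

# c^l = sum_i Y^l(x_i), with orthonormal real harmonics (m = -2, ..., 2 ordering for l = 2).
c0 = 6 * np.sqrt(1 / (4 * np.pi))
c1 = np.sqrt(3 / (4 * np.pi)) * np.array([x.sum(), y.sum(), z.sum()])
c2_const = 0.5 * np.sqrt(15 / np.pi)
c2 = np.array(
    [
        c2_const * np.sum(x * y),
        c2_const * np.sum(y * z),
        0.25 * np.sqrt(5 / np.pi) * np.sum(2 * z**2 - x**2 - y**2),
        c2_const * np.sum(z * x),
        0.5 * c2_const * np.sum(x**2 - y**2),
    ]
)
```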

In [None]:
# Plotting utilities.
# Defines a resolution of the grid on the sphere.
grid_kwargs = dict(
    res_beta=100,
    res_alpha=179,
    quadrature="soft",
    p_val=1,
    p_arg=-1,
)


def plot_signal(signal: e3nn.SphericalSignal, scale_radius_by_amplitude: bool = True) -> go.Figure:
    """Plots a SphericalSignal on a sphere."""
    fig = go.Figure(
        [
            go.Surface(
                signal.plotly_surface(scale_radius_by_amplitude=scale_radius_by_amplitude),
                showscale=True,
            )
        ]
    )
    fig.update_layout(
        scene_camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=-1.5, y=1.5, z=1.5),
        ),
        # paper_bgcolor="rgba(0,0,0,0)",
        # plot_bgcolor="rgba(0,0,0,0)"
    )
    return fig

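Using e3nn, we can compute the spherical harmonic coefficients of $f$ up to degree $6$ (`lmax=6`) with `e3nn.s2_dirac`, and evaluate the resulting function on a grid over the sphere with `e3nn.to_s2grid`: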

In [None]:
hexagon_points = jnp.array(
    [
        [1, 0, 0],
        [0.5, jnp.sqrt(3) / 2, 0],
        [-0.5, jnp.sqrt(3) / 2, 0],
        [-1, 0, 0],
        [-0.5, -jnp.sqrt(3) / 2, 0],
        [0.5, -jnp.sqrt(3) / 2, 0],
    ]
)
f_coeffs = e3nn.sum(e3nn.s2_dirac(hexagon_points, lmax=6), axis=0)
f_signal = e3nn.to_s2grid(f_coeffs, **grid_kwargs)

fig = plot_signal(f_signal)
fig.update_layout(title="Visualizing Function f")
fig.show()

As an aside, e3nn also allows the creation of a signal directly from a Cartesian function:

In [None]:
def cartesian_f(coords: jax.Array) -> float:
    """Computes the function at the given Cartesian coordinates."""
    x, y, z = coords
    return x**3 - y**2 + z


cartesian_f_signal = e3nn.SphericalSignal.from_function(cartesian_f, **grid_kwargs)

Below, we visualize the reconstruction of $f$ from its coefficients truncated at degree $L$ (`lmax` in the code), for $L = 0, 1, \ldots, 6$:

In [None]:
# Reconstruct f from its coefficients truncated at degree lmax, for lmax = 0, ..., 6.
for lmax in range(7):
    f_coeffs_upto_lmax = f_coeffs.filter(lmax=lmax)
    f_signal_reconstructed = e3nn.to_s2grid(f_coeffs_upto_lmax, **grid_kwargs)
    fig = plot_signal(f_signal_reconstructed)
    fig.update_layout(title=f"Reconstructed f up to L = {lmax}")
    fig.show()

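Note that the reconstruction cannot resolve the individual points on the sphere until $L = 6$. The hexagon is unchanged by a rotation of $60^\circ$ about the $z$-axis, so each degree-$l$ component of $f$ is unchanged as well; for $l < 6$, the only such components are symmetric under every rotation about the $z$-axis, so the truncated reconstruction is a ring around the equator rather than six separate peaks. We call this the 'angular frequency bottleneck'. Later, we will see one approach to fix this.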
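The $L = 6$ threshold can be reproduced without e3nn. By the addition theorem for spherical harmonics, the degree-$l$ part of $f$ is $\sum_i \frac{2l + 1}{4\pi} P_l(\r \cdot x_i)$, where $P_l$ is the Legendre polynomial of degree $l$. The NumPy sketch below evaluates the reconstruction truncated at $L$ along the equator and checks that it is constant in $\phi$ for $L \le 5$, while for $L = 6$ it peaks at the hexagon's vertices. The function name `truncated_reconstruction` is our own.

```python
import numpy as np
from numpy.polynomial import legendre

angles = np.arange(6) * np.pi / 3
hexagon = np.stack([np.cos(angles), np.sin(angles), np.zeros(6)], axis=-1)


def truncated_reconstruction(points: np.ndarray, centers: np.ndarray, lmax: int) -> np.ndarray:
    """Dirac deltas at `centers`, truncated to degrees l <= lmax, evaluated at unit vectors `points`."""
    # Legendre-series coefficients (2l + 1) / (4 pi) for l = 0, ..., lmax (addition theorem).
    series = (2 * np.arange(lmax + 1) + 1) / (4 * np.pi)
    cosines = points @ centers.T  # shape (num_points, num_centers)
    return legendre.legval(cosines, series).sum(axis=-1)


# Sample the equator (1 degree spacing), where the hexagon's vertices lie.
phis = np.linspace(0.0, 2 * np.pi, 360, endpoint=False)
equator = np.stack([np.cos(phis), np.sin(phis), np.zeros_like(phis)], axis=-1)

# Peak-to-peak variation of each truncated reconstruction around the equator.
spread = {}
for lmax in range(8):
    vals = truncated_reconstruction(equator, hexagon, lmax)
    spread[lmax] = vals.max() - vals.min()

vals6 = truncated_reconstruction(equator, hexagon, 6)
```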

Now, let's see how the function transforms when we rotate these coefficients by $\frac{\pi}{2}$ radians around the $x$ axis, using the Wigner D-matrices corresponding to this rotation:

In [None]:
R = e3nn.matrix_x(0.5 * jnp.pi)
rotated_f_coeffs = f_coeffs.transform_by_matrix(R)
rotated_f_signal = e3nn.to_s2grid(rotated_f_coeffs, **grid_kwargs)

In [None]:
fig = plot_signal(rotated_f_signal)
fig.update_layout(title="Visualizing Rotated Function R f")
fig.show()

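The plot shows the hexagon rotated by $\frac{\pi}{2}$ radians (90 degrees) about the $x$-axis, so that it now lies in the $xz$-plane, as expected. Rotating the coefficients with the Wigner D-matrices rotates the function they represent, which demonstrates the $O(3)$-equivariance property of the spherical harmonic coefficients.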
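The Legendre-series view from the addition theorem also explains this picture. Rotating the function, $f_\R(\r) = f(\R^{-1}\r)$, simply moves each Dirac delta from $x_i$ to $\R x_i$, and truncating at any cutoff $L$ commutes with the rotation. In the NumPy sketch below (independent of e3nn), we assume the standard right-handed rotation matrix for $\frac{\pi}{2}$ about the $x$-axis; the rotated hexagon then lies in the $xz$-plane, and the truncated reconstruction of the rotated points equals the original reconstruction evaluated at $\R^{-1}\r$.

```python
import numpy as np
from numpy.polynomial import legendre

angles = np.arange(6) * np.pi / 3
hexagon = np.stack([np.cos(angles), np.sin(angles), np.zeros(6)], axis=-1)


def truncated_reconstruction(points: np.ndarray, centers: np.ndarray, lmax: int) -> np.ndarray:
    """Dirac deltas at `centers`, truncated to degrees l <= lmax (via the addition theorem)."""
    series = (2 * np.arange(lmax + 1) + 1) / (4 * np.pi)
    return legendre.legval(points @ centers.T, series).sum(axis=-1)


# Rotation by pi/2 about the x-axis (standard right-handed convention, assumed here).
c, s = np.cos(np.pi / 2), np.sin(np.pi / 2)
R = np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])
rotated_hexagon = hexagon @ R.T  # rows are R x_i

rng = np.random.default_rng(0)
points = rng.normal(size=(100, 3))
points /= np.linalg.norm(points, axis=-1, keepdims=True)

# Compare the rotated point cloud's reconstruction with f_{<=L}(R^{-1} r); rows of points @ R are R^{-1} r.
max_errors = [
    np.abs(
        truncated_reconstruction(points, rotated_hexagon, lmax)
        - truncated_reconstruction(points @ R, hexagon, lmax)
    ).max()
    for lmax in range(7)
]
```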

### Representing Probability Distributions on the Sphere

We have just discussed how to represent functions on the sphere using spherical harmonic coefficients.
Now, let's move on to representing probability distributions on the sphere.

Any probability distribution $p$ must satisfy the following normalization and non-negativity constraints:
$$
\begin{aligned}
\int_{S^2} p(\theta, \phi) \ d\Omega &= 1 \\
p(\theta, \phi) &\geq 0
\end{aligned}
$$
where $d\Omega = \sin \theta \, d\theta \, d\phi$ is the area element in spherical coordinates.
These constraints are hard to incorporate directly into a neural network. It is easier to predict the unnormalized logits $f(\theta, \phi)$ instead, and take the softmax:
$$
p(\theta, \phi) = \frac{\exp{f(\theta, \phi)}}{\int_{S^2}\exp{f(\theta', \phi')} \ d\Omega'}
$$
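A NumPy sketch of this normalization, independent of e3nn: on a quadrature grid with weights $w_j$, $\log \int_{S^2} \exp f \, d\Omega \approx \operatorname{logsumexp}_j (f_j + \log w_j)$, which avoids overflowing $\exp$ for large logits. As a test we use the logits $f(\r) = \kappa\, \mu \cdot \r$ (a von Mises-Fisher distribution, our choice of example), whose normalizer is known in closed form: $\int_{S^2} e^{\kappa\, \mu \cdot \r}\, d\Omega = 4\pi \sinh(\kappa)/\kappa$.

```python
import numpy as np


def sphere_quadrature(n_theta: int = 64, n_phi: int = 128):
    """Points and weights for integrating over S^2 (Gauss-Legendre in cos(theta), uniform in phi)."""
    u, w_u = np.polynomial.legendre.leggauss(n_theta)
    phi = 2 * np.pi * np.arange(n_phi) / n_phi
    u_grid, phi_grid = np.meshgrid(u, phi, indexing="ij")
    s = np.sqrt(1 - u_grid**2)
    points = np.stack([s * np.cos(phi_grid), s * np.sin(phi_grid), u_grid], axis=-1).reshape(-1, 3)
    weights = np.repeat(w_u, n_phi) * (2 * np.pi / n_phi)
    return points, weights


def log_partition(logits: np.ndarray, weights: np.ndarray) -> float:
    """log of the integral of exp(logits) over S^2, computed with a stable log-sum-exp."""
    a = logits + np.log(weights)
    m = a.max()
    return m + np.log(np.sum(np.exp(a - m)))


points, weights = sphere_quadrature()
kappa, mu = 10.0, np.array([0.0, 0.6, 0.8])  # concentration and unit mean direction
logits = kappa * (points @ mu)               # f(r) = kappa * mu . r
log_Z = log_partition(logits, weights)
density = np.exp(logits - log_Z)             # p = exp(f) / Z on the grid
```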

e3nn has utilities for applying the exponential and integrating over $S^2$:

In [None]:
p_signal = f_signal.apply(jnp.exp)
p_signal.integrate()

How can we model the logits $f(\theta, \phi)$?
The first idea is to represent the logits in the basis of spherical harmonics, up to some cutoff $L$:
$$
f(\theta, \phi) = \sum_{l = 0}^{L} (c^l)^T Y^l(\theta, \phi)
$$
where some model (e.g., an equivariant neural network) predicts the spherical harmonic coefficients $c^l$.

This is not so different from the ideas of [DeepONet](https://arxiv.org/abs/1910.03193) and [Fourier Neural Operators](https://arxiv.org/abs/2010.08895), where the function is represented as a linear combination of basis functions, and the coefficients are predicted by a neural network.

**Bypassing the Angular Frequency Bottleneck**:

However, as we have seen above, we may need a large $L$ to represent the function accurately due to the 'angular frequency bottleneck'.
An approach that we have found to work better is to predict *multiple channels* of spherical harmonic coefficients, and then combine them non-linearly via an activation on $S^2$:
$$
f(\theta, \phi) = \log \sum_{\text{ch}} \exp\left(\sum_{l = 0}^{L} (c^l_{\text{ch}})^T Y^l(\theta, \phi)\right)
$$
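Because $\exp f = \sum_{\text{ch}} \exp f_{\text{ch}}$, the resulting distribution is a mixture of the per-channel distributions, which is why a few low-degree channels can place mass at several far-apart points. Note that a single degree-1 logit $\mathbf{a} \cdot \r$ has exactly one maximum on the sphere, at $\r = \mathbf{a} / \lVert \mathbf{a} \rVert$. The NumPy sketch below is our own toy example: two degree-1 channels, $f_{\pm}(\r) = \pm \kappa z$, each peaking at one pole, whose log-sum-exp has modes at both poles.

```python
import numpy as np


def combine_channels(channel_logits: np.ndarray) -> np.ndarray:
    """log sum_ch exp(f_ch), computed stably over the last (channel) axis."""
    m = channel_logits.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(channel_logits - m).sum(axis=-1, keepdims=True)))[..., 0]


kappa = 10.0
# Sample a meridian from the north pole (theta = 0) to the south pole (theta = pi), 1 degree apart.
thetas = np.linspace(0.0, np.pi, 181)
z = np.cos(thetas)
channel_logits = np.stack([kappa * z, -kappa * z], axis=-1)  # two degree-1 channels
f = combine_channels(channel_logits)
```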

Again, each channel of coefficients $c^l_{\text{ch}}$ must transform as an irrep of $O(3)$ of type $(l, (-1)^l)$.

### Example: Learning Equivariant Angular Distributions

Next, we show an example of learning a probability distribution on the sphere with this approach, using the same hexagon point cloud as before.

In particular, our parameters are the coefficients $c^l_{\text{ch}}$ for $l = 0, 1, \ldots, L$ and each channel $\text{ch}$. They are trained to represent the probability distribution of the hexagon point cloud, which is the uniform discrete distribution over its $6$ points. Let $p_c(\theta, \phi)$ denote the probability distribution induced by the coefficients $c$, as defined above.
Then, we minimize the KL divergence from the target distribution $p$ to $p_c$, which, up to a term that does not depend on $c$, is the same as maximizing the average log-likelihood of the points:
$$
\arg\min_c KL(p \ || \ p_c) = \arg\max_c \frac{1}{6} \sum_{i = 1}^6 \log p_c(x_i)
$$
where $x_i$ are the coordinates of the hexagon point cloud, as defined above.
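As a worked example of this objective (our own, not part of the training code below), consider a single channel with degree-1 logits $f(\r) = \mathbf{a} \cdot \r$. Then $\log p_c(x_i) = \mathbf{a} \cdot x_i - \log\left(4\pi \sinh\lVert\mathbf{a}\rVert / \lVert\mathbf{a}\rVert\right)$, and since the hexagon's vertices average to zero, the loss $-\frac{1}{6}\sum_i \log p_c(x_i)$ equals $\log\left(4\pi \sinh\lVert\mathbf{a}\rVert / \lVert\mathbf{a}\rVert\right) \ge \log 4\pi$. So the best such model is the uniform distribution, and more capacity (a higher $L$ or more channels) is needed. The NumPy sketch below checks this for a few random $\mathbf{a}$.

```python
import numpy as np

angles = np.arange(6) * np.pi / 3
hexagon = np.stack([np.cos(angles), np.sin(angles), np.zeros(6)], axis=-1)


def log_partition_degree1(a: np.ndarray) -> float:
    """log of the integral of exp(a . r) over S^2, i.e. log(4 pi sinh(k) / k) with k = |a|."""
    k = np.linalg.norm(a)
    if k < 1e-8:
        return np.log(4 * np.pi)  # limit k -> 0: the uniform distribution
    # log(sinh(k)) = k + log(1 - exp(-2k)) - log(2), which stays finite for large k.
    return np.log(4 * np.pi) + k + np.log1p(-np.exp(-2 * k)) - np.log(2.0) - np.log(k)


def negative_log_likelihood(a: np.ndarray) -> float:
    """The loss -(1/6) sum_i log p_c(x_i) for the single-channel degree-1 model."""
    return -np.mean(hexagon @ a) + log_partition_degree1(a)


rng = np.random.default_rng(0)
losses = [negative_log_likelihood(rng.normal(size=3) * scale) for scale in (0.1, 1.0, 5.0)]
```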

As we saw before, we needed coefficients up to at least $L = 6$ to represent the hexagon point cloud accurately when we had only one channel of coefficients. Let's see if we can do better with multiple channels:

In [None]:
# Helpers for converting the coefficients to a probability distribution.
def coeffs_to_probability_distribution(
    coeffs: e3nn.IrrepsArray,
    res_beta: int = 180,
    res_alpha: int = 179,
    quadrature: str = "soft",
) -> Tuple[e3nn.SphericalSignal, jax.Array]:
    """Converts multi-channel coefficients to a normalized probability distribution on S2, and also returns log Z."""
    num_channels, _ = coeffs.shape

    # Convert to signal on S2.
    prob_signal = e3nn.to_s2grid(
        coeffs,
        res_beta=res_beta,
        res_alpha=res_alpha,
        quadrature=quadrature
    )
    assert prob_signal.shape == (
        num_channels,
        res_beta,
        res_alpha,
    )

    # Exponentiate the values.
    prob_signal = prob_signal.replace_values(jnp.exp(prob_signal.grid_values))

    # Sum over the channels.
    prob_signal = prob_signal.replace_values(jnp.sum(prob_signal.grid_values, axis=-3))
    # Normalize by the partition function Z, the integral of sum_ch exp(f_ch) over S2.
    Z = prob_signal.integrate().array.sum()
    log_Z = jnp.log(Z)
    prob_signal /= Z
    assert prob_signal.shape == (
        res_beta,
        res_alpha,
    )
    return prob_signal, log_Z


def log_prob(position: e3nn.IrrepsArray, coeffs: e3nn.IrrepsArray) -> jax.Array:
    """Computes the normalized log-probability log p_c(position) under the distribution defined by coeffs."""
    # We have to compute the log partition function,
    # because the distribution is not normalized.
    prob_signal, log_Z = coeffs_to_probability_distribution(
        coeffs
    )
    num_channels, _ = coeffs.shape
    assert prob_signal.shape == (
        prob_signal.res_beta,
        prob_signal.res_alpha,
    ), prob_signal.shape
    assert log_Z.shape == (), log_Z.shape

    # Compute the logits for each channel.
    vals = e3nn.to_s2point(coeffs, position)
    vals = vals.array.squeeze(-1)
    assert vals.shape == (num_channels,), vals.shape

    # Then, apply the same log-sum-exp transformation over channels.
    logits = jax.scipy.special.logsumexp(vals, axis=-1)
    assert logits.shape == (), logits.shape

    return logits - log_Z

In [None]:
@jax.jit
def train_step(
    coeffs: e3nn.IrrepsArray, positions: e3nn.IrrepsArray, opt_state: optax.OptState,
) -> Tuple[e3nn.IrrepsArray, optax.OptState, jax.Array]:
    """Performs a single optimization step."""
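    def loss_fn(coeffs: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        # Negative mean log-likelihood of the positions under p_c.
        log_prob_fn = lambda position: log_prob(position, coeffs)
        loss = -jax.vmap(log_prob_fn)(positions).mean()
        # Wrap the scalar loss as a "0e" IrrepsArray for e3nn.grad.
        return e3nn.IrrepsArray("0e", loss)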

    # Compute gradients.
    grad_fn = e3nn.grad(loss_fn)
    grads = grad_fn(coeffs)

    # Compute updates.
    updates, opt_state = tx.update(grads, opt_state)
    coeffs = optax.apply_updates(coeffs, updates)
    return coeffs, opt_state, loss_fn(coeffs).array

# Hyperparameters.
lmax = 2
num_channels = 6

# Initialize the coefficients.
key = jax.random.PRNGKey(0)
coeffs = e3nn.normal(e3nn.s2_irreps(lmax), key, leading_shape=(num_channels,))
hexagon_points_with_irreps = e3nn.IrrepsArray("1o", hexagon_points)

# Initialize the optimizer.
tx = optax.adam(1e-2)
opt_state = tx.init(coeffs)

# Run the training loop.
with tqdm.trange(1000) as pbar:
    for step in pbar:
        coeffs, opt_state, loss = train_step(coeffs, hexagon_points_with_irreps, opt_state)
        pbar.set_postfix({"loss": loss.item()})

Let's plot the predicted probability distribution on the sphere:

In [None]:
prob_signal, _ = coeffs_to_probability_distribution(coeffs)
fig = plot_signal(prob_signal)
fig.add_trace(
    go.Scatter3d(
        x=hexagon_points[:, 0],
        y=hexagon_points[:, 1],
        z=hexagon_points[:, 2],
        mode="markers",
        marker=dict(size=10, color="green"),
    )
)
fig.update_layout(title="Learned Probability Distribution")
fig.show()

We are able to represent the hexagon point cloud with six channels of coefficients up to $L = 2$, instead of requiring coefficients up to $L = 6$ with a single channel.

A question that an interested reader might ask is: "What is the benefit of requiring a smaller $L$, since we now need multiple channels?"
Here, we simply parametrized the coefficients directly. In practice, these coefficients would have to be predicted by an equivariant neural network. The tensor products needed to predict a higher-degree spherical harmonic signal are much more expensive (scaling at least as $O(L^3)$; see [our paper](https://openreview.net/pdf?id=0HHidbjwcf) for the exact calculations). It is much cheaper to predict multiple channels of lower-degree spherical harmonics.
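For reference, one channel of coefficients up to degree $L$ has $\sum_{l=0}^{L} (2l + 1) = (L + 1)^2$ entries. The small computation below compares the two settings used in this example; the counts are similar, which suggests that the savings come from avoiding high-degree tensor products, as described above, rather than from storing fewer coefficients. The helper name is our own.

```python
def num_coefficients(lmax: int, num_channels: int = 1) -> int:
    """Number of spherical harmonic coefficients for degrees 0..lmax, summed over channels."""
    return num_channels * sum(2 * l + 1 for l in range(lmax + 1))


single_channel_L6 = num_coefficients(6)                # one channel, L = 6
six_channels_L2 = num_coefficients(2, num_channels=6)  # six channels, L = 2
```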

This is a simple example, but the same principle applies to more complex functions on the sphere. In Symphony, we generalized this trick to probability distributions over $\mathbb{R}^3$ as well.