In [None]:
from typing import Callable
import functools
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import numpy as np
import lovelyplots
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import plotly.graph_objects as go
plt.style.use('ipynb')

In [None]:
# Switch JAX to 64-bit types (it defaults to 32-bit); must run before any arrays are created.
jax.config.update("jax_enable_x64", True)

In [None]:
class RectangularSignal:
    """A signal sampled on a regular (theta, phi) grid over [0, 2*pi] x [0, 2*pi].

    Both axes are sampled with ``jnp.linspace(0, 2 * pi, res)``, so the first and
    last sample of each axis sit on the interval endpoints.

    Attributes:
        grid_values: array whose last two axes have shape ``(res_theta, res_phi)``.
        res_theta: number of samples along theta.
        res_phi: number of samples along phi.
    """

    def __init__(self, grid_values: jax.Array, res_theta: int, res_phi: int):
        self.grid_values = grid_values
        self.res_theta = res_theta
        self.res_phi = res_phi
        # Only the two trailing axes are constrained; leading axes are free.
        assert self.grid_values.shape == (*self.grid_values.shape[:-2], res_theta, res_phi)

    @staticmethod
    def from_function(
        f: Callable[[jax.Array, jax.Array], jax.Array],
        res_theta: int,
        res_phi: int,
        wrap_theta: bool = False,
    ) -> "RectangularSignal":
        """Create a signal by evaluating ``f(theta, phi)`` on the grid.

        Args:
            f: function of two scalar angles; it is evaluated through a nested
                ``jax.vmap``, so it receives 0-dimensional inputs.
            res_theta: number of theta samples.
            res_phi: number of phi samples.
            wrap_theta: if True, samples with ``theta >= pi`` are evaluated at
                ``2 * pi - theta`` instead, mirroring the function about ``theta = pi``.

        Returns:
            A ``RectangularSignal`` holding the sampled values.
        """
        if wrap_theta:
            def func(theta, phi):
                direct = f(theta, phi)
                mirrored = f(2 * jnp.pi - theta, phi)
                return jnp.where(theta < jnp.pi, direct, mirrored)
        else:
            func = f

        thetas, phis = jnp.meshgrid(
            RectangularSignal._thetas(res_theta),
            RectangularSignal._phis(res_phi),
            indexing="ij",
        )
        grid_values = jax.vmap(jax.vmap(func))(thetas, phis)
        return RectangularSignal(grid_values, res_theta, res_phi)

    def thetas(self):
        """Returns the theta sample points of this signal."""
        return RectangularSignal._thetas(self.res_theta)

    def phis(self):
        """Returns the phi sample points of this signal."""
        return RectangularSignal._phis(self.res_phi)

    @staticmethod
    def _thetas(res_theta: int):
        """Returns the theta values of the grid."""
        return jnp.linspace(0, 2 * jnp.pi, res_theta)

    @staticmethod
    def _phis(res_phi: int):
        """Returns the phi values of the grid."""
        return jnp.linspace(0, 2 * jnp.pi, res_phi)

    def integrate(
        self,
        area_element: str
    ) -> jax.Array:
        """Computes the integral of the signal over a rectangular/spherical region.

        Args:
            area_element: ``"rectangular"`` or ``"spherical"``.

        Raises:
            ValueError: if ``area_element`` is neither of the two options.
        """
        if area_element == "rectangular":
            return self.integrate_rectangular()
        elif area_element == "spherical":
            return self.integrate_spherical()
        else:
            raise ValueError(f"Unknown area element {area_element}")

    def integrate_rectangular(self) -> jax.Array:
        """Computes the integral of the signal over the whole rectangular grid."""
        return RectangularSignal._integrate(self.grid_values, self.thetas(), self.phis())

    def integrate_spherical(self) -> jax.Array:
        """Computes the integral with a ``sin(theta)`` weight over the first half of the theta samples.

        Only the first ``res_theta // 2`` theta rows are used, which covers
        roughly ``[0, pi)`` on this grid.
        """
        half = self.res_theta // 2
        thetas = self.thetas()[:half]
        grid_values = self.grid_values[..., :half, :]
        return RectangularSignal._integrate(grid_values * jnp.sin(thetas)[:, None], thetas, self.phis())

    @staticmethod
    def _integrate(grid_values: jax.Array, thetas: jax.Array, phis: jax.Array) -> jax.Array:
        """Integrates a 2D grid of samples with the trapezoidal rule along both axes.

        Args:
            grid_values: array of shape ``(len(thetas), len(phis))``.
            thetas: uniformly spaced theta sample points (at least two).
            phis: uniformly spaced phi sample points (at least two).

        Returns:
            Scalar array holding the integral.
        """
        assert grid_values.shape == (len(thetas), len(phis))

        # Integrate over theta axis first, with the trapezoidal rule.
        # Ideally we would use the symmetry around theta = pi to reduce the number of points to integrate over,
        # but I don't think it's worth the effort for now.
        dtheta = thetas[1] - thetas[0]
        theta_weights = jnp.concatenate([jnp.array([0.5]), jnp.ones(len(thetas) - 2), jnp.array([0.5])])
        integral = jnp.sum(grid_values * theta_weights[:, None], axis=0) * dtheta
        assert integral.shape == (len(phis),)

        # Integrate over phi axis next, with the trapezoidal rule.
        dphi = phis[1] - phis[0]
        phi_weights = jnp.concatenate([jnp.array([0.5]), jnp.ones(len(phis) - 2), jnp.array([0.5])])
        integral = jnp.sum(integral * phi_weights, axis=0) * dphi
        assert integral.shape == ()
        return integral

    def __mul__(self, other):
        """Pointwise multiplication of two signals sampled on the same grid."""
        assert isinstance(other, RectangularSignal)
        assert self.res_theta == other.res_theta
        assert self.res_phi == other.res_phi
        return RectangularSignal(self.grid_values * other.grid_values, self.res_theta, self.res_phi)

In [None]:
def from_lm_index(lm_index: int) -> tuple:
    """Splits a flat spherical-harmonic index into its ``(l, m)`` pair.

    Inverse of ``to_lm_index``: ``lm_index == l * (l + 1) + m``.
    """
    degree = jnp.floor(jnp.sqrt(lm_index)).astype(jnp.int32)
    offset = degree * (degree + 1)
    order = lm_index - offset
    return degree, order


def to_lm_index(l: int, m: int) -> int:
    """Returns the flat index of the ``(l, m)`` harmonic: ``l * (l + 1) + m``."""
    degree_offset = l * (l + 1)
    return degree_offset + m


def sh_phi(l: int, m: int, phi: float) -> float:
    r"""Phi dependence of the real spherical harmonic with degree ``l`` and order ``m``.

    Args:
        l: degree; must be a concrete non-negative integer because it sizes ``jnp.arange``.
        m: order, with ``-l <= m <= l``.
        phi: azimuthal angle, as a Python scalar or a 0-dimensional array.

    Returns:
        0-dimensional array equal to ``sqrt(2) * sin(|m| * phi)`` for ``m < 0``,
        ``1`` for ``m == 0`` and ``sqrt(2) * cos(m * phi)`` for ``m > 0``.
    """
    # Accept plain Python scalars as well as arrays (the annotation promises ``float``).
    phi = jnp.asarray(phi)
    assert phi.ndim == 0
    phi = phi[..., None]  # [..., 1]
    ms = jnp.arange(1, l + 1)  # [1, 2, 3, ..., l]
    cos = jnp.cos(ms * phi)  # [..., m]

    ms = jnp.arange(l, 0, -1)  # [l, l-1, l-2, ..., 1]
    sin = jnp.sin(ms * phi)  # [..., m]

    # Components are laid out for m = -l, ..., 0, ..., l, so order m sits at index l + m.
    all_orders = jnp.concatenate(
        [
            jnp.sqrt(2) * sin,
            jnp.ones_like(phi),
            jnp.sqrt(2) * cos,
        ],
        axis=-1,
    )
    return all_orders[l + m]


def sh_theta(l: int, m: int, theta: float, lmax: int = 5) -> float:
    r"""Theta dependence of the spherical harmonic with degree ``l`` and order ``m``.

    Args:
        l: degree of the harmonic.
        m: order of the harmonic; only ``|m|`` is used.
        theta: polar angle, as a Python scalar or a 0-dimensional array.
        lmax: maximum degree of the Legendre table to evaluate. It is raised to
            ``l`` automatically when ``l`` is a concrete integer larger than ``lmax``.

    Returns:
        The normalized associated Legendre value for ``(l, |m|)`` at ``cos(theta)``.
    """
    theta = jnp.asarray(theta)
    assert theta.ndim == 0
    # JAX clamps out-of-range indices instead of raising, so a table that is too
    # small for ``l`` would silently return the values of a lower degree.
    if isinstance(l, (int, np.integer)) and l > lmax:
        lmax = int(l)
    cos_theta = jnp.cos(theta)
    legendres = e3nn.legendre.legendre(lmax, cos_theta, phase=1.0, is_normalized=True)  # [l, m, ...]
    sh_theta_comp = legendres[l, jnp.abs(m)]
    return sh_theta_comp

def spherical_harmonic(l: int, m: int) -> Callable[[jax.Array, jax.Array], jax.Array]:
    r"""Builds the spherical harmonic ``Y_{l,m}`` as a function of ``(theta, phi)``.

    Args:
        l: l value
        m: m value

    Returns:
        A function ``Y_lm(theta, phi)`` that multiplies the theta-dependent and
        phi-dependent factors. Both arguments must have the same shape.
    """
    def Y_lm(theta, phi):
        assert theta.shape == phi.shape
        return sh_theta(l, m, theta) * sh_phi(l, m, phi)
    return Y_lm

In [None]:
# Spot check: compare sh_theta at theta = pi / 2 against hand-written expressions
# for (l, m) = (0, 0), (1, -1), (1, 0), (1, 1). The two arrays are displayed side by side, not asserted.
theta = jnp.asarray(jnp.pi / 2)
sh_theta_comp_expected = jnp.asarray([
    jnp.sqrt(1 / (4 * jnp.pi)),
    jnp.sqrt(3 / (8 * jnp.pi)) * jnp.sin(theta),
    jnp.sqrt(3 / (4 * jnp.pi)) * jnp.cos(theta),
    jnp.sqrt(3 / (8 * jnp.pi)) * jnp.sin(theta),
])
sh_theta_comps = jnp.asarray([sh_theta(0, 0, theta), sh_theta(1, -1, theta), sh_theta(1, 0, theta), sh_theta(1, 1, theta)])
sh_theta_comp_expected, sh_theta_comps

Check that the spherical harmonics match the expected values up to l = 2.

In [None]:
def allclose(a: jax.Array, b: jax.Array) -> bool:
    """Returns whether ``a`` and ``b`` agree elementwise within an absolute tolerance of 1e-5."""
    absolute_tolerance = 1e-5
    return jnp.allclose(a, b, atol=absolute_tolerance)

# Closed-form expressions for the harmonics with l <= 2, keyed by (l, m).
# Each entry is a function of (theta, phi); negative m uses sin(|m| * phi), positive m uses cos(m * phi).
expected_Y = {
    (0, 0): lambda theta, phi: jnp.sqrt(1/(4*jnp.pi)),
    (1, -1): lambda theta, phi: jnp.sin(theta) * jnp.sin(phi) * jnp.sqrt(3/(4*jnp.pi)),
    (1, 0): lambda theta, phi: jnp.cos(theta) * jnp.sqrt(3/(4*jnp.pi)),
    (1, 1): lambda theta, phi: jnp.sin(theta) * jnp.cos(phi) * jnp.sqrt(3/(4*jnp.pi)),
    (2, -2): lambda theta, phi: jnp.sin(theta)**2 * jnp.sin(2*phi) * jnp.sqrt(15/(16*jnp.pi)),
    (2, -1): lambda theta, phi: jnp.sin(theta) * jnp.cos(theta) * jnp.sin(phi) * jnp.sqrt(15/(4*jnp.pi)),
    (2, 0): lambda theta, phi: (3*jnp.cos(theta)**2 - 1) * jnp.sqrt(5/(16*jnp.pi)),
    (2, 1): lambda theta, phi: jnp.sin(theta) * jnp.cos(theta) * jnp.cos(phi) * jnp.sqrt(15/(4*jnp.pi)),
    (2, 2): lambda theta, phi: jnp.sin(theta)**2 * jnp.cos(2*phi) * jnp.sqrt(15/(16*jnp.pi)),
}

# Compare the computed harmonics against the closed-form expressions on a 100 x 100 grid.
failed = []
for l in range(3):
    for m in range(-l, l+1):
        # Keep the signal objects around: the grid coordinates live on them, not on the raw arrays.
        expected_signal = RectangularSignal.from_function(expected_Y[(l, m)], res_theta=100, res_phi=100, wrap_theta=True)
        actual_signal = RectangularSignal.from_function(spherical_harmonic(l, m), res_theta=100, res_phi=100, wrap_theta=True)
        expected = expected_signal.grid_values
        actual = actual_signal.grid_values
        if not allclose(actual, expected):
            abs_diff = jnp.abs(actual - expected)
            max_diff = abs_diff.max()
            print(f"l={l}, m={m} failed, diff={max_diff}")
            failed_theta, failed_phi = jnp.unravel_index(jnp.argmax(abs_diff), actual.shape)
            print(f"Failed at theta={actual_signal.thetas()[failed_theta]}, phi={actual_signal.phis()[failed_phi]}")
            failed.append((l, m))
if not failed:
    print("All tests passed")

In [None]:
# Side-by-side images of the computed and the closed-form Y_{1,-1} on a coarse 10 x 10 grid.
fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 5))

panels = [
    (spherical_harmonic(1, -1), r"Computed $Y_{1,-1}$"),
    (expected_Y[(1, -1)], r"Expected $Y_{1,-1}$"),
]
for ax, (fn, title) in zip(axs, panels):
    signal = RectangularSignal.from_function(fn, 10, 10, True)
    # Transpose so that theta runs along the horizontal axis.
    ax.imshow(signal.grid_values.T, cmap='RdBu', vmin=-1, vmax=1)
    ax.set_xlabel(r"$\theta$")
    ax.set_ylabel(r"$\phi$")
    ax.set_title(title)

# Both panels share the same color limits, so one colorbar serves both.
fig.colorbar(mappable=axs[0].images[0], ax=axs, orientation='horizontal', label='Value')
plt.show()

# Spherical harmonics in the basis of 2D Fourier functions

$Y_{lm}(\theta, \phi) = \frac{1}{2\pi} \sum_{u,v} ({y^{lm}_{uv}})^* \exp{i (u \theta + v \phi)}$

so:

$
{y^{lm}_{uv}} = \frac{1}{2\pi} \int_0^{2\pi} \int_0^{2\pi} Y_{lm}(\theta, \phi) \exp{i (u \theta + v \phi)}  d\phi d\theta
$


In [None]:
def fourier_function(u: int, v: int) -> Callable[[float, float], float]:
    """Returns the 2D Fourier function ``exp(i (u theta + v phi)) / (2 pi)`` for frequencies ``(u, v)``."""
    def fourier(theta: float, phi: float) -> float:
        phase = u * theta + v * phi
        return jnp.exp(1j * phase) / (2 * jnp.pi)
    return fourier


@functools.lru_cache(maxsize=None)
def create_spherical_harmonic_signal(l: int, m: int, **grid_kwargs):
    """Samples Y^{l,m} on the grid described by ``grid_kwargs``; results are memoized per argument set."""
    harmonic = spherical_harmonic(l, m)
    return RectangularSignal.from_function(harmonic, **grid_kwargs)


@functools.lru_cache(maxsize=None)
def create_fourier_signal(u: int, v: int, **grid_kwargs):
    """Samples the Fourier function for ``(u, v)`` on the grid described by ``grid_kwargs``; results are memoized."""
    fourier = fourier_function(u, v)
    return RectangularSignal.from_function(fourier, **grid_kwargs)


def to_u_index(u: int, lmax: int) -> int:
    """Maps a frequency ``u`` in ``[-lmax, lmax]`` to its array index in ``[0, 2 * lmax]``."""
    return lmax + u

def to_v_index(v: int, lmax: int) -> int:
    """Maps a frequency ``v`` in ``[-lmax, lmax]`` to its array index in ``[0, 2 * lmax]``."""
    return lmax + v


@functools.lru_cache(maxsize=None)
def compute_y(l: int, m: int, u: int, v: int, **grid_kwargs):
    """Computes y^{l,m}_{u, v}: the rectangular integral of the theta-wrapped Y^{l,m} times the (u, v) Fourier function."""
    harmonic_signal = create_spherical_harmonic_signal(l, m, **grid_kwargs, wrap_theta=True)
    fourier_signal = create_fourier_signal(u, v, **grid_kwargs, wrap_theta=False)
    product = harmonic_signal * fourier_signal
    return product.integrate(area_element="rectangular")


def compute_y_grid(lmax: int, **grid_kwargs):
    """Tabulates y^{l,m}_{u, v} for all ``l <= lmax`` and ``-lmax <= u, v <= lmax``.

    Returns:
        Complex array of shape ``((lmax + 1) ** 2, 2 * lmax + 1, 2 * lmax + 1)``,
        indexed by ``[to_lm_index(l, m), to_u_index(u, lmax), to_v_index(v, lmax)]``.
    """
    num_lm = (lmax + 1) ** 2
    num_freqs = 2 * lmax + 1
    frequencies = range(-lmax, lmax + 1)

    y_grid = np.zeros((num_lm, num_freqs, num_freqs), dtype=jnp.complex64)
    for l in range(lmax + 1):
        for m in range(-l, l + 1):
            lm_index = to_lm_index(l, m)
            for u in frequencies:
                u_index = to_u_index(u, lmax)
                for v in frequencies:
                    v_index = to_v_index(v, lmax)
                    y_grid[lm_index, u_index, v_index] = compute_y(l, m, u, v, **grid_kwargs)

    assert y_grid.shape == ((lmax + 1) ** 2, 2 * lmax + 1, 2 * lmax + 1)
    return y_grid



In [None]:
# Grid resolution shared by the cells below, and the initial maximum degree.
grid_kwargs = {"res_theta": 100, "res_phi": 100}
lmax = 2

In [None]:
# Integrate the (u, v) = (1, 0) Fourier function against the constant signal 1,
# once with the spherical area element and once with the rectangular one.
u, v = 1, 0
F_signal = RectangularSignal.from_function(fourier_function(u, v), **grid_kwargs, wrap_theta=False)
Y_signal = RectangularSignal.from_function(lambda theta, phi: 1, **grid_kwargs, wrap_theta=True)
YF_signal = Y_signal * F_signal
YF_signal.integrate(area_element="spherical"), YF_signal.integrate(area_element="rectangular")

In [None]:
# Check normalization of the Fourier functions: each one times its complex conjugate
# should integrate to 1 over the rectangle, for every (u, v) in [-lmax, lmax]^2.
failed = []
for u in range(-lmax, lmax + 1):
    for v in range(-lmax, lmax + 1):
        F_signal = RectangularSignal.from_function(fourier_function(u, v), **grid_kwargs, wrap_theta=False)
        # The lambda is evaluated immediately inside from_function, so capturing the loop variables is safe here.
        F_signal_conj = RectangularSignal.from_function(lambda theta, phi: jnp.conj(fourier_function(u, v)(theta, phi)), **grid_kwargs, wrap_theta=False)
        integral = (F_signal * F_signal_conj).integrate(area_element="rectangular")
        if not jnp.allclose(integral, 1):
            print(f"Failed: u={u}, v={v}, integral={integral}")
            failed.append((u, v))
if not failed:
    print("All tests passed")

In [None]:
# Raise the maximum degree and tabulate y^{l,m}_{u,v} for it (rebinds the notebook-level lmax).
lmax = 4
y_grid = compute_y_grid(lmax, **grid_kwargs)

In [None]:
# Visualize the Fourier coefficients y^{l,m}_{u,v} of one harmonic: magnitude, real part, imaginary part.
l, m = 3, 2
assert jnp.abs(m) <= l
y_lm = y_grid[to_lm_index(l, m)]

uv_tick_positions = range(2 * lmax + 1)
uv_tick_labels = range(-lmax, lmax + 1)

# The magnitude is non-negative, so it gets a sequential colormap anchored at zero.
plt.imshow(jnp.abs(y_lm).T, cmap="Grays", vmin=0)
plt.title(r"$|y_{{u, v}}^{{%d,%d}}|$" % (l, m))
plt.xticks(uv_tick_positions, uv_tick_labels)
plt.yticks(uv_tick_positions, uv_tick_labels)
plt.xlabel("u")
plt.ylabel("v")
plt.colorbar()
plt.show()

# Real and imaginary parts are signed, so they get a diverging colormap centered on zero.
for part_name, part_values in (("Re", jnp.real(y_lm)), ("Im", jnp.imag(y_lm))):
    # Use the largest magnitude for symmetric limits. TwoSlopeNorm requires
    # vmin < vcenter < vmax, so fall back to 1.0 when the part is identically zero.
    color_limit = float(jnp.abs(part_values).max()) or 1.0
    norm = mcolors.TwoSlopeNorm(vmin=-color_limit, vcenter=0, vmax=color_limit)
    plt.imshow(part_values.T, cmap="RdBu", norm=norm)
    plt.title(r"$\mathrm{{%s}}(y_{{u, v}}^{{%d,%d}})$" % (part_name, l, m))
    plt.xticks(uv_tick_positions, uv_tick_labels)
    plt.yticks(uv_tick_positions, uv_tick_labels)
    plt.xlabel("u")
    plt.ylabel("v")
    plt.colorbar()
    plt.show()

# 2D Fourier functions in the basis of spherical harmonics


This means an extra factor of $\sin(\theta)$ in the integral.

$\frac{1}{2\pi} \exp{i (u \theta + v \phi)} = \sum_{l,m} {z^{lm}_{uv}} Y_{lm}(\theta, \phi) $

so:

$
{z^{lm}_{uv}} = \frac{1}{2\pi} \int_0^{\pi} \int_0^{2\pi} Y_{lm}(\theta, \phi) \exp{i (u \theta + v \phi)} \sin \theta d\phi d\theta
$


In [None]:
# Check normalization of the spherical harmonics: Y_{l,m}^2 integrated with the
# spherical area element should be 1 (within 1% relative tolerance) for every l <= lmax.
failed = []
for l in range(0, lmax + 1):
    for m in range(-l, l + 1):
        Y_signal = RectangularSignal.from_function(spherical_harmonic(l, m), **grid_kwargs, wrap_theta=False)
        integral = (Y_signal * Y_signal).integrate(area_element="spherical")
        if not jnp.allclose(integral, 1, rtol=1e-2):
            print(f"Failed: l={l}, m={m}, integral={integral}")
            failed.append((l, m))
if not failed:
    print("All tests passed")

In [None]:
# Single worked example: integrate the closed-form Y_{1,1} (same expression as expected_Y[(1, 1)])
# against the (u, v) = (1, 1) Fourier function with the spherical area element.
Y_signal = RectangularSignal.from_function(
    lambda theta, phi: jnp.sin(theta) * jnp.cos(phi) * jnp.sqrt(3/(4*jnp.pi)),
    **grid_kwargs, wrap_theta=False)
F_signal = RectangularSignal.from_function(
    fourier_function(1, 1),
    **grid_kwargs, wrap_theta=False)
(Y_signal * F_signal).integrate(area_element="spherical")


In [None]:

@functools.lru_cache(maxsize=None)
def compute_z(l: int, m: int, u: int, v: int, **grid_kwargs):
    """Computes z^{l,m}_{u, v}: the spherical integral of Y^{l,m} times the (u, v) Fourier function."""
    harmonic_signal = create_spherical_harmonic_signal(l, m, **grid_kwargs, wrap_theta=False)
    fourier_signal = create_fourier_signal(u, v, **grid_kwargs, wrap_theta=False)
    product = harmonic_signal * fourier_signal
    return product.integrate(area_element="spherical")


def compute_z_grid(lmax: int, **grid_kwargs):
    """Tabulates z^{l,m}_{u, v} for all ``l <= lmax`` and ``-lmax <= u, v <= lmax``.

    Returns:
        Complex array of shape ``((lmax + 1) ** 2, 2 * lmax + 1, 2 * lmax + 1)``,
        indexed by ``[to_lm_index(l, m), to_u_index(u, lmax), to_v_index(v, lmax)]``.
    """
    num_lm = (lmax + 1) ** 2
    num_freqs = 2 * lmax + 1
    frequencies = range(-lmax, lmax + 1)

    z_grid = np.zeros((num_lm, num_freqs, num_freqs), dtype=jnp.complex64)
    for l in range(lmax + 1):
        for m in range(-l, l + 1):
            lm_index = to_lm_index(l, m)
            for u in frequencies:
                u_index = to_u_index(u, lmax)
                for v in frequencies:
                    v_index = to_v_index(v, lmax)
                    z_grid[lm_index, u_index, v_index] = compute_z(l, m, u, v, **grid_kwargs)

    assert z_grid.shape == ((lmax + 1) ** 2, 2 * lmax + 1, 2 * lmax + 1)
    return z_grid

In [None]:
# Rebind lmax to 6 and tabulate both coefficient tables for it; every later cell uses these values.
lmax = 6
z_grid = compute_z_grid(lmax, **grid_kwargs)
y_grid = compute_y_grid(lmax, **grid_kwargs)

z_grid.shape, y_grid.shape

In [None]:
# Sum over all (u, v) of z^{l,m}_{u,v} times the conjugate of y^{l,m}_{u,v} for a single (l, m).
l, m = 2, 1
(z_grid[to_lm_index(l, m)] * y_grid[to_lm_index(l, m)].conj()).sum()

In [None]:
# Visualize the coefficients z^{l,m}_{u,v} of one harmonic: magnitude, real part, imaginary part.
l, m = 2, 1
assert jnp.abs(m) <= l
z_lm = z_grid[to_lm_index(l, m)]

uv_tick_positions = range(2 * lmax + 1)
uv_tick_labels = range(-lmax, lmax + 1)

# The magnitude is non-negative, so it gets a sequential colormap anchored at zero.
plt.imshow(jnp.abs(z_lm).T, cmap="Grays", vmin=0)
plt.title(r"$|z_{{u, v}}^{{%d,%d}}|$" % (l, m))
plt.xticks(uv_tick_positions, uv_tick_labels)
plt.yticks(uv_tick_positions, uv_tick_labels)
plt.xlabel("u")
plt.ylabel("v")
plt.colorbar()
plt.show()

# Real and imaginary parts are signed, so they get a diverging colormap centered on zero.
for part_name, part_values in (("Re", jnp.real(z_lm)), ("Im", jnp.imag(z_lm))):
    # Use the largest magnitude for symmetric limits. TwoSlopeNorm requires
    # vmin < vcenter < vmax, so fall back to 1.0 when the part is identically zero.
    color_limit = float(jnp.abs(part_values).max()) or 1.0
    norm = mcolors.TwoSlopeNorm(vmin=-color_limit, vcenter=0, vmax=color_limit)
    plt.imshow(part_values.T, cmap="RdBu", norm=norm)
    plt.title(r"$\mathrm{{%s}}(z_{{u, v}}^{{%d,%d}})$" % (part_name, l, m))
    plt.xticks(uv_tick_positions, uv_tick_labels)
    plt.yticks(uv_tick_positions, uv_tick_labels)
    plt.xlabel("u")
    plt.ylabel("v")
    plt.colorbar()
    plt.show()

In [None]:
# Transform SH coefficients to 2D Fourier coefficients
# Two random coefficient vectors are drawn with fixed PRNG keys (1 and 2).
x1_lm = e3nn.normal(e3nn.s2_irreps(lmax), key=jax.random.PRNGKey(1)).array
# Zero every coefficient from index to_lm_index(2, -2) onwards, i.e. keep only l <= 1.
x1_lm = x1_lm.at[to_lm_index(2, -2):].set(0)
# Contract the (l, m) axis against y_grid to obtain one coefficient per (u, v).
x1_uv = jnp.einsum("a,auv->uv", x1_lm, y_grid)
assert x1_uv.shape == (2 * lmax + 1, 2 * lmax + 1)

x2_lm = e3nn.normal(e3nn.s2_irreps(lmax), key=jax.random.PRNGKey(2)).array
x2_lm = x2_lm.at[to_lm_index(2, -2):].set(0)
x2_uv = jnp.einsum("a,auv->uv", x2_lm, y_grid)
assert x2_uv.shape == (2 * lmax + 1, 2 * lmax + 1)

In [None]:
# Check that we can convert back to SH coefficients
# The conjugated (u, v) coefficients are contracted against z_grid to recover one value per (l, m).
x1_lm_reconstructed = jnp.einsum("uv,auv->a", x1_uv.conj(), z_grid)
# NOTE(review): this bare expression is not the last line of the cell, so it is evaluated but never displayed.
x1_lm_reconstructed[:5], x1_lm[:5]
print(x1_lm[0], x1_lm_reconstructed[0])
# The cell output is the round-trip agreement within an absolute tolerance of 5e-3.
jnp.allclose(x1_lm, x1_lm_reconstructed, atol=5e-3)

In [None]:
# 1D convolution with FFT
x1 = jnp.asarray([1, 2, 3, 5])
x2 = jnp.asarray([6, 8, 4, 9])

# Zero-pad both inputs to length len(x1) + len(x2) - 1 so that the circular
# convolution computed through the FFT equals the full linear convolution.
x1_padded = jnp.pad(x1, (0, len(x2) - 1))
x2_padded = jnp.pad(x2, (0, len(x1) - 1))
x1_x2_expected = jnp.convolve(x1, x2, mode='full')
x1_x2 = jnp.fft.ifft(
    jnp.fft.fft(x1_padded) * jnp.fft.fft(x2_padded)
)
x1_x2.real.round(2), x1_x2_expected

In [None]:
# 2D convolution with FFT
x1 = jnp.asarray([[1, 2], [3, 5]])
x2 = jnp.asarray([[6, 8], [4, 9]])

# Zero-pad each axis to size 2 + 2 - 1 = 3 so that the circular convolution
# computed through the FFT equals the full linear convolution.
x1_padded = jnp.pad(x1, ((0, 1), (0, 1)))
x2_padded = jnp.pad(x2, ((0, 1), (0, 1)))
print(x1_padded.shape, x2_padded.shape)
# The reference convolves the original (unpadded) arrays; its 'full' output is already 3 x 3.
x1_x2_expected = jax.scipy.signal.convolve2d(x1, x2, mode='full')
# Use the 2D transforms: the 1D fft/ifft would only transform the last axis.
x1_x2 = jnp.fft.ifft2(jnp.fft.fft2(x1_padded) * jnp.fft.fft2(x2_padded))
x1_x2.real.round(2), x1_x2_expected

In [None]:
# Perform convolution manually
# Accumulate in a NumPy buffer: updating a JAX array element by element inside
# a Python loop creates a new array on every iteration.
x1_uv_np = np.asarray(x1_uv)
x2_uv_np = np.asarray(x2_uv)
x1_x2_uv_accum = np.zeros((2 * lmax + 1, 2 * lmax + 1), dtype=np.complex64)
for u3 in range(-lmax, lmax + 1):
    for v3 in range(-lmax, lmax + 1):
        # Sum over every (u1, v1) in the table; terms whose partner (u2, v2)
        # falls outside [-lmax, lmax] have no coefficient and are skipped.
        for u1 in range(-lmax, lmax + 1):
            u2 = u3 - u1
            if not -lmax <= u2 <= lmax:
                continue
            for v1 in range(-lmax, lmax + 1):
                v2 = v3 - v1
                if not -lmax <= v2 <= lmax:
                    continue

                x1_x2_uv_accum[to_u_index(u3, lmax), to_v_index(v3, lmax)] += (
                    x1_uv_np[to_u_index(u1, lmax), to_v_index(v1, lmax)] * x2_uv_np[to_u_index(u2, lmax), to_v_index(v2, lmax)]
                )
x1_x2_uv_manual = jnp.asarray(x1_x2_uv_accum)

# Check that the convolution matches JAX.
x1_x2_uv_expected = jax.scipy.signal.convolve2d(x1_uv, x2_uv, mode='full')
print(x1_uv.shape, x2_uv.shape, x1_x2_uv_expected.shape)
print(x1_x2_uv_expected.round(2))
# The 'full' output covers output frequencies in [-2 * lmax, 2 * lmax]; the manual
# table only holds [-lmax, lmax], which is the central block of the full output.
x1_x2_uv_expected_center = x1_x2_uv_expected[lmax:3 * lmax + 1, lmax:3 * lmax + 1]
assert x1_x2_uv_expected_center.shape == x1_x2_uv_manual.shape
print(jnp.isclose(x1_x2_uv_manual, x1_x2_uv_expected_center, atol=1e-3).all())

# Convert back to SH coefficients
x1_x2_lm_manual = jnp.einsum("uv,auv->a", x1_x2_uv_manual.conj(), z_grid)
print(x1_x2_lm_manual.imag.min(), x1_x2_lm_manual.imag.max())

In [None]:
# 2D FFT
# Zero-pad each axis from 2 * lmax + 1 to 4 * lmax + 1 samples so that the
# circular convolution computed through the FFT equals the full linear convolution.
full_size = 4 * lmax + 1
uv_padding = ((0, 2 * lmax), (0, 2 * lmax))
x1_fft = jnp.fft.fft2(jnp.pad(x1_uv, uv_padding))
x2_fft = jnp.fft.fft2(jnp.pad(x2_uv, uv_padding))
assert x1_fft.shape == (full_size, full_size)
assert x2_fft.shape == (full_size, full_size)

# Product in Fourier space
x1_x2_fft = x1_fft * x2_fft
assert x1_x2_fft.shape == (full_size, full_size)

# Inverse 2D FFT
x1_x2_uv_full = jnp.fft.ifft2(x1_x2_fft)
assert x1_x2_uv_full.shape == (full_size, full_size)
# Compare against a direct convolution computed here, so this cell does not
# depend on the shape of a reference array left behind by another cell.
x1_x2_uv_reference = jax.scipy.signal.convolve2d(x1_uv, x2_uv, mode='full')
print(jnp.abs(x1_x2_uv_reference - x1_x2_uv_full).max())

# Keep the output frequencies in [-lmax, lmax]: the central block of the full result.
x1_x2_uv = x1_x2_uv_full[lmax:3 * lmax + 1, lmax:3 * lmax + 1]
assert x1_x2_uv.shape == (2 * lmax + 1, 2 * lmax + 1)

# Convert back to spherical harmonics
# NOTE(review): the product of two Fourier functions defined with a 1 / (2 * pi) factor
# carries an extra 1 / (2 * pi); no such factor is applied here - confirm the intended normalization.
x1_x2_lm = jnp.einsum("uv,auv->a", x1_x2_uv.conj(), z_grid)
assert x1_x2_lm.shape == ((lmax + 1) ** 2,)

# Hmm, the result is not real, but it should be
# x1_x2_lm
In [None]:
# Display the two input coefficient vectors next to the (rounded) coefficients of their product.
x1_lm, x2_lm, x1_x2_lm.round(2)

In [None]:
# Report the range of the imaginary parts, which are then discarded: only the real part is plotted.
print(x1_x2_lm.imag.max(), x1_x2_lm.imag.min())
# Presumably evaluates the coefficients on a 100 x 99 (beta, alpha) grid on the sphere - confirm against the e3nn_jax docs.
product = e3nn.to_s2grid(e3nn.IrrepsArray(e3nn.s2_irreps(lmax), x1_x2_lm.real), res_beta=100, res_alpha=99, quadrature="soft")
go.Figure([go.Surface(product.plotly_surface())])
