# [`S2FFT`](https://github.com/astro-informatics/s2fft) - __Spherical harmonic transform__ Interactive Tutorial
---



This tutorial demonstrates how to call the spherical harmonic transform APIs within `S2FFT`. Specifically, we will be working with the forward and inverse spin-$s$ spherical harmonic transforms (see [McEwen & Wiaux](https://arxiv.org/pdf/1110.6298.pdf)), i.e. 

$
\begin{equation}
{}_sf_{\ell m} = \int_{\mathbb{S}^2} \text{d}\Omega (\theta, \varphi) {}_sf(\theta, \varphi) {}_sY^*_{\ell m}(\theta, \varphi) \qquad \text{and} \qquad 
{}_sf(\theta, \varphi) = \sum_{\ell=0}^{\infty} \sum_{m=-\ell}^{\ell} {}_sf_{\ell m} {}_sY_{\ell m}(\theta, \varphi)
\end{equation}
$

respectively. Here the infinite sum is truncated at an upper bandlimit $L$ such that ${}_sf_{\ell m} = 0 \: \forall \: \ell \geq L$, and $\text{d}\Omega(\theta, \varphi) = \sin\theta \, \text{d}\theta \, \text{d}\varphi$ is the usual rotation-invariant (Haar) measure on $\mathbb{S}^2$. To demonstrate how to apply ``S2FFT`` transforms, we first need an input signal that is correctly sampled on the sphere. One such signal is an image of the Galactic plane captured by ESA's [Gaia satellite](https://sci.esa.int/web/gaia)!
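With the degrees restricted to $0 \leq \ell < L$ (the same range used by the array indexing later in this notebook), the inverse sum becomes finite. Counting its terms shows how many coefficients a bandlimited signal carries. There are at most $L^2$ of them, and exactly $L^2$ for $s = 0$, because spin-$s$ harmonics vanish for $\ell < |s|$:

\begin{equation}
{}_sf(\theta, \varphi) = \sum_{\ell=0}^{L-1} \sum_{m=-\ell}^{\ell} {}_sf_{\ell m} \, {}_sY_{\ell m}(\theta, \varphi), \qquad \sum_{\ell=0}^{L-1} (2\ell + 1) = L^2 .
\end{equation}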

In [None]:
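import numpy as np
from plotting_functions import plot_sphere          # plotting helper used by this notebook
from s2fft.transforms import spherical               # forward/inverse spherical harmonic transforms
from s2fft.recursions import price_mcewen            # precompute generation for the Price-McEwen recursion
from s2fft.sampling import s2_samples as samples     # sampling utilities

L = 1000                                             # harmonic bandlimit
sampling = "mw"                                      # McEwen & Wiaux sampling scheme
f = np.load('data/Gaia_EDR3_flux.npy')               # Gaia EDR3 flux map on the sphere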
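"Correctly sampled" here means sampled on the grid named by `sampling`. The sketch below shows what that grid looks like for `"mw"`. It uses the sample positions from McEwen & Wiaux, $\theta_t = \pi(2t+1)/(2L-1)$ for $t = 0, \dots, L-1$ and $\varphi_p = 2\pi p/(2L-1)$ for $p = 0, \dots, 2L-2$. A map on this grid therefore has shape $(L, 2L-1)$. `mw_grid` is a hypothetical helper written for illustration; it is not part of the S2FFT API.

```python
import numpy as np

def mw_grid(L: int):
    """Hypothetical helper: McEwen & Wiaux (MW) sample positions for bandlimit L."""
    t = np.arange(L)                              # colatitude (theta) indices 0..L-1
    p = np.arange(2 * L - 1)                      # longitude (phi) indices 0..2L-2
    thetas = np.pi * (2 * t + 1) / (2 * L - 1)    # last ring sits exactly on the south pole
    phis = 2 * np.pi * p / (2 * L - 1)            # equispaced in longitude
    return thetas, phis

thetas, phis = mw_grid(4)
print(len(thetas), len(phis))   # 4 7, i.e. an MW map for L = 4 has shape (4, 7)
```

For the tutorial's bandlimit this implies a map of shape `(1000, 1999)`. Comparing `f.shape` against `(L, 2 * L - 1)` is a cheap sanity check before transforming.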

Note that when the transforms below are called with `reality=True`, ``S2FFT`` enforces the Hermitian symmetry ${}_sf_{\ell m} = (-1)^{s+m}\,{}_{-s}f^*_{\ell, -m}$. This roughly halves both the memory overhead and the compute time for real signals with spin $s=0$.
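For a real spin-0 signal, the symmetry reduces to $f_{\ell,-m} = (-1)^m f^*_{\ell m}$. Only the $m \geq 0$ half therefore needs to be stored, and the negative orders can be rebuilt on demand. The sketch below illustrates this with a random, symmetric set of coefficients in the $(L, 2L-1)$ layout used by this notebook. `expand_real_flm` is a hypothetical helper, and we make no claim that this is how ``S2FFT`` handles `reality=True` internally.

```python
import numpy as np

def expand_real_flm(flm_half: np.ndarray) -> np.ndarray:
    """Hypothetical helper: rebuild the full (L, 2L-1) spin-0 coefficient array from its
    m >= 0 half (shape (L, L)), assuming a real signal so f_{l,-m} = (-1)^m conj(f_{l,m})."""
    L = flm_half.shape[0]
    flm = np.zeros((L, 2 * L - 1), dtype=complex)
    flm[:, L - 1:] = flm_half                          # columns for m = 0, ..., L-1
    for m in range(1, L):
        flm[:, L - 1 - m] = (-1) ** m * np.conj(flm_half[:, m])   # negative orders from symmetry
    return flm

rng = np.random.default_rng(0)
L_demo = 4
half = rng.normal(size=(L_demo, L_demo)) + 1j * rng.normal(size=(L_demo, L_demo))
half[:, 0] = half[:, 0].real                           # m = 0 coefficients are real for real signals
half[np.arange(L_demo)[:, None] < np.arange(L_demo)[None, :]] = 0   # no coefficients with m > l
full = expand_real_flm(half)
print(full.shape)                                      # (4, 7)
```

The stored half has shape $(L, L)$ rather than $(L, 2L-1)$. That is where the factor of roughly two in memory comes from.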

Now, let's take a look at the data on the sphere using [`PyVista`](https://docs.pyvista.org/index.html). Try moving the camera inside the sphere to see the sky as it would appear from Earth!

In [None]:
plot_sphere(f, L, sampling)

### Computing the spherical harmonic transform
---

Let's JIT compile a JAX function to compute the spherical harmonic transform of this observational map. First we run a fast precompute step that generates a list of arrays whose memory overhead scales as $\mathcal{O}(L^2)$. This is the same scaling as the input image itself, so the overhead is negligible in all but the most extreme cases, e.g. very large bandlimits. Furthermore, these values need only be computed once (separately for the forward and inverse transforms), after which they can be reused indefinitely.
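To put the $\mathcal{O}(L^2)$ scaling in concrete terms, the arithmetic below computes the size of a real, float64, MW-sampled map of shape $(L, 2L-1)$ at a few bandlimits. It does not estimate the size of the precomputes themselves, since their exact constants depend on ``S2FFT`` internals. It only shows the scale the text compares them to. `mw_map_bytes` is a hypothetical helper.

```python
def mw_map_bytes(L: int, itemsize: int = 8) -> int:
    """Hypothetical helper: bytes needed for a real MW-sampled map of shape (L, 2L - 1)."""
    return L * (2 * L - 1) * itemsize

for L_demo in (250, 500, 1000, 2000):
    print(L_demo, f"{mw_map_bytes(L_demo) / 1e6:.1f} MB")
```

The map comes to about 1, 4, 16 and 64 MB respectively. Doubling $L$ quadruples the memory, as expected for $\mathcal{O}(L^2)$ growth.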

In [None]:
forward_precomputes = price_mcewen.generate_precomputes_jax(L, forward=True)

Now let's go ahead and compute the spherical harmonic coefficients ${}_sf_{\ell m}$ from ${}_sf(\theta, \varphi)$ by applying the forward spherical harmonic transform.

In [None]:
flm = spherical.forward_jax(f, L, spin=0, reality=True, precomps=forward_precomputes)

``S2FFT`` adopts a redundant indexing system to ensure that arrays have fixed shapes, which is a strict requirement of many JAX APIs. Specifically, we store ${}_sf_{\ell m}$ as a 2-dimensional array of shape $(L, 2L-1)$, with indices $0 \leq \ell < L$ and $-L < m < L$. Order $m$ is stored in column $m + L - 1$. Coefficients only exist for $|m| \leq \ell$, so the entries with $|m| > \ell$ are zero and the non-zero entries form a triangle. For example, for $L = 3$ we have 

\begin{equation}
    \text{ 2D data format}:
        \begin{bmatrix}
            0 & 0 & flm_{(0,0)} & 0 & 0 \\
            0 & flm_{(1,-1)} & flm_{(1,0)} & flm_{(1,1)} & 0 \\
            flm_{(2,-2)} & flm_{(2,-1)} & flm_{(2,0)} & flm_{(2,1)} & flm_{(2,2)}
        \end{bmatrix}
\end{equation}
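The layout above can be reproduced with a boolean mask of the entries that can hold a coefficient, i.e. $|m| \leq \ell$, together with the column rule $m + L - 1$. `col` is a hypothetical helper for illustration only.

```python
import numpy as np

L_demo = 3

def col(m: int, L: int) -> int:
    """Hypothetical helper: column holding order m in the (L, 2L - 1) coefficient array."""
    return m + L - 1

ells = np.arange(L_demo)[:, None]                 # degree l along rows
ms = np.arange(-(L_demo - 1), L_demo)[None, :]    # order m along columns, -(L-1)..(L-1)
valid = np.abs(ms) <= ells                        # True where a coefficient can be non-zero

print(valid.astype(int))                          # same triangle as the matrix above
print(valid.sum(), valid.size)                    # 9 15
print(col(-1, L_demo))                            # 1, the column of flm_{(1,-1)}
```

In the notebook, the coefficient $(\ell, m)$ would then be read as `flm[ell, col(m, L)]`. For $L = 1000$, only $L^2 / (L(2L-1)) \approx 0.5$ of the array holds coefficients, which is the price paid for fixed, JAX-friendly shapes.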

### Computing the inverse spherical harmonic transform
---

Let's JIT compile a JAX function to get us back to the observational map, or at least to a bandlimited version of it. Again, we begin by generating some precomputes that need only be computed once,

In [None]:
inverse_precomputes = price_mcewen.generate_precomputes_jax(L, forward=False)

from which we can readily call a function to map back into pixel space:

In [None]:
f_test = spherical.inverse_jax(flm, L, spin=0, reality=True, precomps=inverse_precomputes)

and check the associated error before taking another look at the night sky:

In [None]:
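# Mean absolute error between the bandlimited reconstruction and the original map (NaN pixels ignored)
print(f"Mean absolute error = {np.nanmean(np.abs(np.real(f_test) - f))}")
plot_sphere(np.real(f_test), L, sampling)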
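The Gaia map is presumably not exactly bandlimited at $L$, so the round trip cannot be exact. A single mean absolute error is also hard to interpret without knowing the map's scale. The use of `np.nanmean` above also suggests the map may contain NaN pixels. The sketch below is a hypothetical helper, `error_summary`, that reports NaN-aware absolute and relative errors. It is demonstrated on synthetic arrays so it runs on its own.

```python
import numpy as np

def error_summary(f_true: np.ndarray, f_rec: np.ndarray) -> dict:
    """Hypothetical helper: NaN-aware error statistics between a map and its reconstruction."""
    diff = np.abs(np.real(f_rec) - f_true)
    ok = np.isfinite(diff)                        # skip pixels where either map is NaN/inf
    d, ref = diff[ok], np.abs(f_true[ok])
    return {
        "mean_abs": float(d.mean()),
        "max_abs": float(d.max()),
        "rel_l2": float(np.sqrt(np.sum(d ** 2) / np.sum(ref ** 2))),   # relative L2 error
    }

rng = np.random.default_rng(1)
truth = rng.normal(size=(4, 7))                   # stand-in for an MW map with L = 4
recon = truth + 1e-3 * rng.normal(size=truth.shape)
recon[0, 0] = np.nan                              # simulate a missing pixel
print(error_summary(truth, recon))
```

In the notebook, the equivalent call would be `error_summary(f, f_test)`.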