# [`S2FFT`](https://github.com/astro-informatics/s2fft) - __Wigner transform__ Interactive Tutorial
---



This tutorial demonstrates how to call the Wigner transform APIs within `S2FFT`. Specifically, we will be working with the forward and inverse Wigner transforms (see [McEwen *et al*](https://arxiv.org/pdf/1508.03101.pdf)), i.e.

$
\begin{equation}
f^{\ell}_{m n} = \int_{\text{SO}(3)} \text{d} \Omega(\alpha, \beta, \gamma) f(\alpha, \beta, \gamma) D^{\ell}_{mn}(\alpha, \beta, \gamma) \qquad \text{and} \qquad 
f(\alpha, \beta, \gamma) = \sum_{\ell=0}^{\infty} \frac{2\ell +1}{8\pi^2} \sum_{m=-\ell}^{\ell} \sum_{n=-\ell}^{\ell} f^{\ell}_{mn} D^{\ell*}_{mn}(\alpha, \beta, \gamma)
\end{equation}
$

respectively, where $(\alpha, \beta, \gamma)$ are Euler angles in the $zyz$-convention, the sum over $\ell$ is truncated at an upper bandlimit $L$ such that $f^{\ell}_{mn} = 0 \: \forall \: \ell \geq L$, and $\text{d}\Omega(\alpha, \beta, \gamma) = \sin\beta \, \text{d}\alpha \, \text{d}\beta \, \text{d}\gamma$ is the Haar measure on the special orthogonal group SO(3). We additionally impose an azimuthal bandlimit $N$, such that $f^{\ell}_{mn} = 0$ for $|n| \geq N$. To demonstrate how to apply ``S2FFT`` transforms we first need an input signal that is correctly sampled on the rotation group. Sadly, no particularly appealing signals come to hand, so we will work with a random signal, generated from random Wigner coefficients.
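As a quick sanity check of the $8\pi^2$ normalization, the sketch below (plain NumPy, not part of ``S2FFT``) integrates the Haar measure numerically to recover the volume of SO(3). It then pushes a constant signal $f = 1$ through the $\ell = 0$ term of both transforms, using $D^0_{00} = 1$. The grid sizes and the midpoint rule are illustrative choices, not the quadrature ``S2FFT`` itself uses.

```python
import numpy as np


def haar_volume(n_alpha: int = 64, n_beta: int = 512, n_gamma: int = 64) -> float:
    """Approximate the integral of dOmega = sin(beta) dalpha dbeta dgamma over SO(3)."""
    # alpha and gamma are periodic on [0, 2*pi); for a constant integrand
    # a uniform rectangle rule gives exactly 2*pi in each direction.
    alpha_integral = np.sum(np.full(n_alpha, 2 * np.pi / n_alpha))
    gamma_integral = np.sum(np.full(n_gamma, 2 * np.pi / n_gamma))
    # beta runs over [0, pi]; use the midpoint rule for the sin(beta) weight.
    d_beta = np.pi / n_beta
    beta_mid = (np.arange(n_beta) + 0.5) * d_beta
    beta_integral = np.sum(np.sin(beta_mid)) * d_beta
    return float(alpha_integral * beta_integral * gamma_integral)


volume = haar_volume()
print(volume / np.pi**2)  # close to 8

# For f = 1 only the ell = 0 coefficient is non-zero and D^0_{00} = 1,
# so the forward transform gives f^0_{00} = volume of SO(3).
f0_00 = volume
# Inverse transform keeping only the ell = 0 term: (2*0 + 1) / (8*pi^2) * f^0_{00}
f_reconstructed = (2 * 0 + 1) / (8 * np.pi**2) * f0_00
print(f_reconstructed)  # close to 1
```

The $8\pi^2$ in the inverse transform is exactly the volume of SO(3) under this measure, which is why the constant signal is recovered.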

In [None]:
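import jax
jax.config.update("jax_enable_x64", True)  # double precision, so round-trip errors can reach float64 machine precision

import numpy as np
from s2fft.transforms import wigner
from s2fft.recursions import price_mcewen
from s2fft.utils.signal_generator import generate_flmn

L = 128          # harmonic bandlimit: f^l_{mn} = 0 for l >= L
N = 3            # azimuthal bandlimit: f^l_{mn} = 0 for |n| >= N
reality = True   # the signal is real-valued
rng = np.random.default_rng(0)
flmn = generate_flmn(rng, L, N, reality=reality)  # random Wigner coefficients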

Note that if `reality` is True, ``S2FFT`` will enforce the Hermitian symmetry $f^{\ell}_{mn} = (-1)^{m+n}f^{\ell*}_{-m,-n}$ of a real-valued signal, which leads to a 2-fold reduction in both memory overhead and compute time.
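The symmetry can be illustrated with a small NumPy sketch. Here `random_flmn` and `hermitian_project` are hypothetical helpers (this is not how ``S2FFT``'s `generate_flmn` is implemented), and the padded layout `flmn[n + N - 1, ell, m + L - 1]` is an assumption made for illustration. Averaging each coefficient with its symmetric partner projects random coefficients onto those of a real signal.

```python
import numpy as np


def random_flmn(rng: np.random.Generator, L: int, N: int) -> np.ndarray:
    """Random complex coefficients, zero wherever |m| > ell or |n| > ell.

    The layout flmn[n + N - 1, ell, m + L - 1] is assumed for this sketch.
    """
    n = np.arange(-(N - 1), N)[:, None, None]
    ell = np.arange(L)[None, :, None]
    m = np.arange(-(L - 1), L)[None, None, :]
    mask = (np.abs(m) <= ell) & (np.abs(n) <= ell)
    shape = (2 * N - 1, L, 2 * L - 1)
    data = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    return np.where(mask, data, 0.0)


def hermitian_project(flmn: np.ndarray, L: int, N: int) -> np.ndarray:
    """Project onto f^l_{mn} = (-1)^(m+n) conj(f^l_{-m,-n}), i.e. a real signal."""
    n = np.arange(-(N - 1), N)[:, None, None]
    m = np.arange(-(L - 1), L)[None, None, :]
    sign = (-1.0) ** (m + n)  # (-1)^(m+n): the parentheses matter
    # Reversing the n and m axes maps entry (n, m) to entry (-n, -m).
    partner = sign * np.conj(flmn[::-1, :, ::-1])
    return 0.5 * (flmn + partner)


rng = np.random.default_rng(0)
L, N = 8, 3
flmn_real = hermitian_project(random_flmn(rng, L, N), L, N)
```

Since the $(n, m) \to (-n, -m)$ pairing is one-to-one, only about half of the coefficients carry independent information, which is the source of the 2-fold saving.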

### Computing the inverse Wigner transform
---
Let's JIT compile a JAX function to compute the inverse Wigner transform of this random signal. First we will run a fast precompute to generate a list of arrays whose memory overhead scales as $\mathcal{O}(NL^2)$. This is the same order as the memory of the input signal itself, so the overhead is negligible in all but the most extreme cases, e.g. very large bandlimits. Further note that these values need only be computed once (separately for the forward and inverse transforms), after which they may be reused indefinitely.
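To put $\mathcal{O}(NL^2)$ in perspective for the values used here, the sketch below counts samples on the signal grid, assuming McEwen-Wiaux sampling on SO(3) uses $2L-1$ samples in $\alpha$, $L$ in $\beta$ and $2N-1$ in $\gamma$, with 16 bytes per complex128 value. It does not compute the exact size of the ``S2FFT`` precomputes, which depends on library internals; `so3_mw_grid_size` and `megabytes` are hypothetical helpers.

```python
def so3_mw_grid_size(L: int, N: int) -> int:
    """Samples on an MW-style SO(3) grid, assuming (2N-1) x L x (2L-1)
    points in (gamma, beta, alpha)."""
    return (2 * N - 1) * L * (2 * L - 1)


def megabytes(n_entries: int, bytes_per_entry: int = 16) -> float:
    """Memory for n_entries values; 16 bytes per complex128 entry."""
    return n_entries * bytes_per_entry / 1e6


n_samples = so3_mw_grid_size(128, 3)
print(n_samples)             # 163200
print(megabytes(n_samples))  # 2.6112

# Doubling L roughly quadruples the size, consistent with O(N L^2).
print(so3_mw_grid_size(256, 3) / n_samples)
```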

In [None]:
inverse_precomputes = price_mcewen.generate_precomputes_wigner_jax(L, N, forward=False, reality=reality)

With these precomputes in hand, we can now call the inverse transform to map the coefficients into pixel space.

In [None]:
f = wigner.inverse_jax(flmn, L, N, reality=reality, precomps=inverse_precomputes)

### Computing the Wigner transform
---
Let's JIT compile a JAX function to get us back to the random Wigner coefficients. Again, we'll begin by generating some precomputes that need only be computed once.

In [None]:
forward_precomputes = price_mcewen.generate_precomputes_wigner_jax(L, N, forward=True, reality=reality)

Now let's go ahead and compute the Wigner coefficients $f^{\ell}_{mn}$ from $f(\alpha, \beta, \gamma)$ by applying the forward Wigner transform.

In [None]:
flmn_test = wigner.forward_jax(f, L, N, reality=reality, precomps=forward_precomputes)

``S2FFT`` adopts a redundant indexing system to ensure that arrays are of fixed shape, which JAX's JIT compilation strictly requires. Specifically, we store $f^{\ell}_{mn}$ as a 3-dimensional array, with indices $0 \leq \ell < L$, $-L < m < L$, and $-N < n < N$. As $D^{\ell}_{mn}$ is only defined for $|m|, |n| \leq \ell$, entries with $|m| > \ell$ or $|n| > \ell$ are strictly 0, so each slice of fixed $n$ is triangular. For $n = 0$ and $L = 3$ we have

\begin{equation}
    \text{Format for } n = 0 \ (L = 3):
        \begin{bmatrix}
            0 & 0 & flmn_{(0,0)} & 0 & 0 \\
            0 & flmn_{(1,-1)} & flmn_{(1,0)} & flmn_{(1,1)} & 0 \\
            flmn_{(2,-2)} & flmn_{(2,-1)} & flmn_{(2,0)} & flmn_{(2,1)} & flmn_{(2,2)}
        \end{bmatrix}
\end{equation}
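The padded layout can be reproduced with a boolean mask. This is a NumPy sketch assuming the axis order `[n + N - 1, ell, m + L - 1]`, which is an illustrative assumption rather than a documented ``S2FFT`` guarantee; `flmn_valid_mask` is a hypothetical helper. For slices with $n \neq 0$, the rows with $\ell < |n|$ are also empty.

```python
import numpy as np


def flmn_valid_mask(L: int, N: int) -> np.ndarray:
    """True where f^l_{mn} can be non-zero in the padded (n, ell, m) array.

    The axis order flmn[n + N - 1, ell, m + L - 1] is assumed for this sketch.
    """
    n = np.arange(-(N - 1), N)[:, None, None]
    ell = np.arange(L)[None, :, None]
    m = np.arange(-(L - 1), L)[None, None, :]
    return (np.abs(m) <= ell) & (np.abs(n) <= ell)


small = flmn_valid_mask(3, 2)
print(small[1].astype(int))  # n = 0 slice: the triangle shown above
print(small[2].astype(int))  # n = 1 slice: the ell = 0 row is also empty

full = flmn_valid_mask(128, 3)
print(full.sum(), full.size)  # 81910 163200
```

For the bandlimits used in this tutorial, roughly half of the stored entries are therefore padding, which is the price paid for fixed array shapes.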

Let's check the round-trip error, which should be close to machine precision, since the McEwen-Wiaux sampling theorem (selected by default) provides theoretically exact transforms for band-limited signals.

In [None]:
print(f"Mean absolute error = {np.nanmean(np.abs(flmn_test - flmn))}")
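Because the mean above is taken over the whole padded array, entries that are zero in both arrays pull it down. A slightly stricter summary, sketched below with a hypothetical helper `roundtrip_errors`, also reports the maximum absolute error and the relative Euclidean-norm error, and can optionally be restricted to a boolean mask of valid entries.

```python
import numpy as np


def roundtrip_errors(reference, estimate, mask=None) -> dict:
    """Mean/max absolute error and relative norm error, optionally over a mask."""
    reference = np.asarray(reference)
    estimate = np.asarray(estimate)
    if mask is not None:
        # Keep only the selected entries so padding does not dilute the mean.
        reference = reference[mask]
        estimate = estimate[mask]
    diff = np.abs(estimate - reference)
    return {
        "mean_abs": float(np.mean(diff)),
        "max_abs": float(np.max(diff)),
        "rel_norm": float(np.linalg.norm(estimate - reference) / np.linalg.norm(reference)),
    }


# Toy example: four values, one of them perturbed by 1e-3.
ref = np.ones(4)
est = ref + np.array([0.0, 0.0, 0.0, 1e-3])
print(roundtrip_errors(ref, est))
```

In the notebook this could be called as `roundtrip_errors(flmn, flmn_test)`; `np.asarray` converts the JAX output to a NumPy array.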