# __Wigner transform__ Interactive Tutorial

---



This tutorial demonstrates how to call the Wigner transform APIs within `S2FFT`. Specifically, we will work with the forward and inverse Wigner transforms (see [McEwen *et al*](https://arxiv.org/pdf/1508.03101.pdf)). To apply `S2FFT` transforms we first need an input signal that is correctly sampled on the rotation group; sadly, no particularly appealing signals come to hand, so we will work with a random signal.

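```python
import numpy as np
import s2fft

L = 128          # harmonic band-limit
N = 3            # directional band-limit: coefficients are non-zero only for |n| < N
reality = True   # the signal is real-valued, so conjugate symmetry can be exploited
rng = np.random.default_rng(0)
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
```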
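
For a real signal (`reality=True`) the Wigner coefficients satisfy the conjugate symmetry $f^{\ell}_{-m,-n} = (-1)^{m+n}\, f^{\ell\,*}_{mn}$, and they vanish unless $|m| \le \ell$ and $|n| \le \min(\ell, N-1)$. The NumPy sketch below builds such an array by drawing the $n \ge 0$ half at random and filling the rest by symmetry. It is an illustration only: the helper `random_real_flmn` is hypothetical, it is not how `generate_flmn` is implemented, and the layout `flmn[n + N - 1, el, m + L - 1]` with shape `(2N-1, L, 2L-1)` is an assumption about the `S2FFT` storage convention.

```python
import numpy as np

def random_real_flmn(rng, L, N):
    """Hypothetical helper (not s2fft's generate_flmn): random Wigner coefficients
    of a real signal, stored with the assumed layout flmn[n + N - 1, el, m + L - 1]."""
    flmn = np.zeros((2 * N - 1, L, 2 * L - 1), dtype=np.complex128)
    for el in range(L):
        # Coefficients vanish unless |m| <= el and |n| <= min(el, N - 1).
        for n in range(0, min(el, N - 1) + 1):
            for m in range(-el, el + 1):
                if n == 0 and m < 0:
                    continue  # filled from its m > 0 partner via the symmetry below
                if n == 0 and m == 0:
                    val = rng.normal() + 0j  # self-conjugate entry must be real
                else:
                    val = rng.normal() + 1j * rng.normal()
                flmn[n + N - 1, el, m + L - 1] = val
                # Reality condition: f^el_{-m,-n} = (-1)^(m+n) conj(f^el_{m,n})
                flmn[-n + N - 1, el, -m + L - 1] = (-1) ** (m + n) * np.conj(val)
    return flmn

flmn_toy = random_real_flmn(np.random.default_rng(0), L=8, N=3)
print(flmn_toy.shape)  # (5, 8, 15)
```

Drawing only half of the coefficients and mirroring the rest mirrors the reason `reality=True` saves work: the $n < 0$ half carries no independent information.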

### Computing the inverse Wigner transform

---
Let's run the JAX implementation of the inverse Wigner transform to synthesise a signal on the rotation group from these random Wigner coefficients.

```python
f = s2fft.wigner.inverse_jax(flmn, L, N, reality=reality)
```
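
To see what `f` holds, here is a NumPy sketch of the McEwen-Wiaux (MW) sample positions on the rotation group, parameterised by Euler angles $(\alpha, \beta, \gamma)$. The formulas and the $(\gamma, \beta, \alpha)$ axis ordering are assumptions based on the MW sampling theorem of McEwen *et al*, not read from the `S2FFT` source; you can compare the resulting shape with `f.shape` from the cell above.

```python
import numpy as np

def mw_so3_angles(L, N):
    """Sketch of MW sample positions on SO(3) (assumed convention).

    alpha_j = 2 pi j / (2L-1),     j = 0, ..., 2L-2
    beta_k  = pi (2k+1) / (2L-1),  k = 0, ..., L-1   (beta_{L-1} = pi, the south pole)
    gamma_n = 2 pi n / (2N-1),     n = 0, ..., 2N-2
    """
    alphas = 2 * np.pi * np.arange(2 * L - 1) / (2 * L - 1)
    betas = np.pi * (2 * np.arange(L) + 1) / (2 * L - 1)
    gammas = 2 * np.pi * np.arange(2 * N - 1) / (2 * N - 1)
    return alphas, betas, gammas

alphas, betas, gammas = mw_so3_angles(128, 3)  # same L and N as in the tutorial
# Assumed axis ordering of the signal array: (gamma, beta, alpha)
grid_shape = (gammas.size, betas.size, alphas.size)
print(grid_shape)                 # (5, 128, 255)
print(int(np.prod(grid_shape)))   # 163200
```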

If you plan to apply this transform many times (e.g. while training a model), we recommend precomputing and storing some small arrays that are reused on every call. To do this, compute them once and pass them in through the `precomps` keyword argument.

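```python
# Precomputes must match the transform settings (inverse direction, same reality flag)
precomps = s2fft.generate_precomputes_wigner_jax(L, N, forward=False, reality=reality)
f = s2fft.wigner.inverse_jax(flmn, L, N, reality=reality, precomps=precomps)
```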
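
The precompute-and-reuse pattern can be illustrated without `S2FFT` at all. The toy below is *not* a Wigner transform and *not* how `S2FFT` organises its precomputes; it is a hypothetical one-dimensional Fourier transform along the $\alpha$ direction, on the $2L-1$ equispaced samples of the MW grid, whose phase matrix plays the role of `precomps`: built once, then passed into every call. The names `toy_precompute`, `toy_inverse` and `toy_forward` are illustrative only.

```python
import numpy as np

def toy_precompute(L):
    """Hypothetical stand-in for a precompute routine (not an S2FFT function).

    Builds the phase matrix exp(i m alpha_j) for m = -(L-1), ..., L-1 and the
    2L-1 equispaced samples alpha_j = 2 pi j / (2L-1); indexed [j, m + L - 1].
    """
    alphas = 2 * np.pi * np.arange(2 * L - 1) / (2 * L - 1)
    ms = np.arange(-(L - 1), L)
    return np.exp(1j * np.outer(alphas, ms))

def toy_inverse(fm, L, precomps=None):
    # Rebuild the matrix only when the caller has not supplied it.
    if precomps is None:
        precomps = toy_precompute(L)
    return precomps @ fm

def toy_forward(f, L, precomps=None):
    if precomps is None:
        precomps = toy_precompute(L)
    # Columns are orthogonal with squared norm 2L-1, so this undoes toy_inverse.
    return precomps.conj().T @ f / (2 * L - 1)

L_toy = 128
rng_toy = np.random.default_rng(0)
pre_toy = toy_precompute(L_toy)  # computed once ...
errors = []
for _ in range(10):              # ... and reused on every call inside the loop
    fm = rng_toy.normal(size=2 * L_toy - 1) + 1j * rng_toy.normal(size=2 * L_toy - 1)
    f_toy = toy_inverse(fm, L_toy, precomps=pre_toy)
    fm_back = toy_forward(f_toy, L_toy, precomps=pre_toy)
    errors.append(np.max(np.abs(fm_back - fm)))
print(f"Max round-trip error over 10 calls = {max(errors):.2e}")
```

Distinct variable names (`L_toy`, `f_toy`, `pre_toy`) are used so that running the sketch in this notebook does not overwrite `L`, `f` or `precomps`.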

### Computing the forward Wigner transform

---
Let's run the forward JAX transform to get back to the random Wigner coefficients.

```python
flmn_test = s2fft.wigner.forward_jax(f, L, N, reality=reality)
```

Again, if you plan to apply this transform many times, we recommend precomputing these arrays once (this time with `forward=True`) and passing them in through the `precomps` keyword argument.

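```python
# Precomputes must match the transform settings (forward direction, same reality flag)
precomps = s2fft.generate_precomputes_wigner_jax(L, N, forward=True, reality=reality)
flmn_pre = s2fft.wigner.forward_jax(f, L, N, reality=reality, precomps=precomps)
```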

Let's check the round-trip error, which should be close to machine precision for the McEwen-Wiaux (MW) sampling theorem, which is selected by default.

```python
print(f"Mean absolute error = {np.nanmean(np.abs(flmn_test - flmn))}")
print(f"Mean absolute error (with precomputes) = {np.nanmean(np.abs(flmn_pre - flmn))}")
```
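
The mean absolute error is one way to summarise the round trip; the maximum absolute error and a relative $\ell_2$ error are also common. The helper `roundtrip_errors` below is a hypothetical NumPy sketch, not part of `S2FFT`. Like `np.nanmean`, its NaN-aware reductions skip NaN entries rather than propagating them.

```python
import numpy as np

def roundtrip_errors(estimate, truth):
    """Hypothetical helper (not part of S2FFT): round-trip error summaries."""
    diff = np.abs(np.asarray(estimate) - np.asarray(truth))
    return {
        "mean_abs": float(np.nanmean(diff)),
        "max_abs": float(np.nanmax(diff)),
        # Relative L2 error: ||estimate - truth|| / ||truth||
        "rel_l2": float(np.sqrt(np.nansum(diff**2) / np.nansum(np.abs(truth) ** 2))),
    }

truth_demo = np.array([3.0, 4.0])
estimate_demo = np.array([3.0, 6.5])
print(roundtrip_errors(estimate_demo, truth_demo))
# {'mean_abs': 1.25, 'max_abs': 2.5, 'rel_l2': 0.5}
```

In this notebook it could be called as `roundtrip_errors(flmn_test, flmn)`.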