# Wigner transforms by exploiting kernel precomputes

Let's start by importing some packages.

In [None]:
# NumPy plus the so3 package, whose so3.forward / so3.inverse are called below.
import numpy as np
import so3

# JAX entry points: jit compilation, device placement, and the NumPy-like API.
from jax import jit, device_put
import jax.numpy as jnp
from jax.config import config
from jax.lib import xla_bridge
# Print the platform of the backend JAX is using.
# NOTE(review): the `jax.config` submodule import and `jax.lib.xla_bridge` appear
# to be deprecated in recent JAX releases — confirm against the installed version.
print(xla_bridge.get_backend().platform)

# Switch on JAX's "jax_enable_x64" option before any arrays are created below.
config.update("jax_enable_x64", True)

# s2ball: kernel loader, plus star-imports that presumably supply the transform
# functions and helpers used in later cells (e.g. forward_transform,
# generate_flmn) — verify against the package.
from s2ball.construct.wigner_constructor import load_wigner_matrix
from s2ball.transform.wigner import *
from s2ball.utils import *

### Generate a random complex bandlimited field

In [None]:
# Both size parameters are set to 32 for this demonstration.
L, N = 32, 32

# Fixed seed so the random coefficients are the same on every run.
rng = np.random.default_rng(193412341234)
params = so3.create_parameter_dict(
    L=L,
    N=N,
    sampling_scheme_str="SO3_SAMPLING_MW",
)

# Random coefficients in 3D layout, and the flattened copy passed to so3.
flmn_3d = generate_flmn(rng, L, N)
flmn_1d = flmn_3d_to_1d(flmn_3d, L, N)

# Synthesize the signal with so3, then view the flat samples as a
# (2N-1, L, 2L-1) array for the kernel-based transforms.
f_1d = so3.inverse(flmn_1d, params)
f_3d = f_1d.reshape((2 * N - 1, L, 2 * L - 1))

### Load/construct relevant Wigner kernels

In [None]:
# Load the kernel for each direction (forward first, then inverse).
wigner_forward, wigner_inverse = (
    load_wigner_matrix(L=L, N=N, forward=is_forward) for is_forward in (True, False)
)

# Copies placed with device_put, used by the JAX transforms.
wigner_forward_jax, wigner_inverse_jax = (
    device_put(kernel) for kernel in (wigner_forward, wigner_inverse)
)

# Forward transform

Shape: $(2N-1, L, 2L-1) \rightarrow (2N-1, L, 2L-1)$ triangularly oversampled Wigner coefficients.

### SO3 CPU Cython implementation

In [None]:
# Forward transform of the flattened samples with so3. This rebinds flmn_1d,
# which previously held the flattened random coefficients.
flmn_1d = so3.forward(f_1d, params)
# Time the same call with IPython's %timeit magic.
%timeit so3.forward(f_1d, params)

### NumPy CPU implementation

In [None]:
# Forward transform of the 3D samples using the kernel as loaded (not the
# device_put copy); the result is compared with flmn_3d in the error cell.
flmn_numpy = forward_transform(f_3d, wigner_forward, L, N)
# Time the same call with IPython's %timeit magic.
%timeit forward_transform(f_3d, wigner_forward, L, N)

### JAX GPU implementation

In [None]:
f_jax = device_put(f_3d)
forward_jit = jit(forward_transform_jax, static_argnums=(2, 3))
flmn_jax = forward_jit(f_jax, wigner_forward_jax, L, N).block_until_ready()

%timeit forward_jit(f_jax, wigner_forward_jax, L, N)

### Evaluate transform error

In [None]:
# Copy the JAX result into a NumPy array before comparing.
flmn_jax = np.array(flmn_jax)

# Mean absolute deviation of each result from the original coefficients;
# nanmean skips any NaN entries.
for label, estimate in (("Numpy", flmn_numpy), ("JAX", flmn_jax)):
    print(f"{label}: Forward mean absolute error = {np.nanmean(np.abs(estimate - flmn_3d))}")

# Inverse transform 
Shape: $(2N-1, L, 2L-1) \rightarrow (2N-1, L, 2L-1)$ 

### SO3 CPU Cython implementation

In [None]:
# Inverse transform with so3; f_check is the array the NumPy and JAX inverse
# results are compared against in the error cell.
f_check = so3.inverse(flmn_1d, params)
# Time the same call with IPython's %timeit magic.
%timeit so3.inverse(flmn_1d, params)

### NumPy CPU implementation

In [None]:
# Inverse transform of the NumPy forward result using the kernel as loaded
# (not the device_put copy).
f_numpy = inverse_transform(flmn_numpy, wigner_inverse, L)
# Time the same call with IPython's %timeit magic.
%timeit inverse_transform(flmn_numpy, wigner_inverse, L)

### JAX GPU implementation

In [None]:
flmn_jax = device_put(flmn_jax)
inverse_jit = jit(inverse_transform_jax, static_argnums=(2))
f_jax = inverse_jit(flmn_jax, wigner_inverse_jax, L).block_until_ready()

%timeit inverse_jit(flmn_jax, wigner_inverse_jax, L)

### Evaluate transform error

In [None]:
# Copy the JAX result into a NumPy array and give the so3 output the same
# (2N-1, L, 2L-1) layout as the other two results.
f_jax = np.array(f_jax)
f_check = f_check.reshape(2*N-1, L, 2*L-1)
# These are inverse-transform errors, so label them as such (the messages
# previously said "Forward"). nanmean skips any NaN entries.
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f_check))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f_check))))