# Wigner Laguerre transforms by exploiting kernel precomputes

Let's start by importing some packages.

In [None]:
import os
# An empty CUDA_VISIBLE_DEVICES is set before jax is imported; this presumably
# hides all GPUs so the notebook runs on CPU. The commented-out line below is
# the alternative that exposes GPU index 2 instead.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import numpy as np

from jax import jit, device_put
import jax.numpy as jnp
from jax.config import config
from jax.lib import xla_bridge
# Report which platform the JAX backend actually selected.
print(xla_bridge.get_backend().platform)

# Turn on JAX's 64-bit mode (jax_enable_x64) for the rest of the notebook.
config.update("jax_enable_x64", True)

# Project helpers: kernel loading, the transforms under test, and sampling utilities.
from baller.construct.wigner_constructor import load_wigner_matrix
from baller.transform.wigner_laguerre import *
from baller.sampling import laguerre_sampling
from baller.utils import *

### Generate a random complex bandlimited field

In [None]:
# Transform sizes and the tau value handed to the Laguerre precompute.
L, P, N = 32, 32, 3
tau = 1.0
rng = np.random.default_rng(193412341234)

# Forward-direction kernels: host arrays plus copies placed via device_put.
wigner_forward = load_wigner_matrix(L, N, forward=True)
lag_poly_f = laguerre_sampling.polynomials(P, tau, forward=True)
wigner_forward_jax = device_put(wigner_forward)
lag_poly_f_jax = device_put(lag_poly_f)

# Inverse-direction kernels, prepared the same way.
wigner_inverse = load_wigner_matrix(L, N, forward=False)
lag_poly_i = laguerre_sampling.polynomials(P, tau, forward=False)
wigner_inverse_jax = device_put(wigner_inverse)
lag_poly_i_jax = device_put(lag_poly_i)

# Draw random coefficients and round-trip them (inverse -> forward -> inverse);
# presumably this leaves `flmnp` and `f` as a consistent reference pair for the
# error checks later in the notebook.
f = inverse(generate_flmnp(rng, L, N, P), L, N, P, tau, wigner_inverse, lag_poly_i)
flmnp = forward(f, L, N, P, tau, wigner_forward, lag_poly_f)
f = inverse(flmnp, L, N, P, tau, wigner_inverse, lag_poly_i)

# Forward transform

Shape: $(P, 2N-1, L, 2L-1) \rightarrow (P, 2N-1, L, 2L-1)$ triangularly oversampled spherical Laguerre coefficients.

The FLAG implementation takes roughly $(2N-1) \times 0.2$ s. This figure is extrapolated from [figure 6b](https://arxiv.org/pdf/1205.0792.pdf) by assuming that the cost of the harmonic transform scales linearly with N (which may well not be the case).

### NumPy CPU implementation

In [None]:
flmnp_numpy = forward_transform(f, wigner_forward, lag_poly_f, L, N)
%timeit forward_transform(f, wigner_forward, lag_poly_f, L, N)

### JAX GPU implementation

In [None]:
f_jax = device_put(f)
forward_jit = jit(forward_transform_jax, static_argnums=(3,4))
flmnp_jax = forward_jit(f_jax, wigner_forward_jax, lag_poly_f_jax, L, N).block_until_ready()

%timeit forward_jit(f_jax, wigner_forward_jax, lag_poly_f_jax, L, N)

### Evaluate transform error

In [None]:
# Convert the JAX result to a NumPy array before comparing.
flmnp_jax = np.array(flmnp_jax)

# Mean absolute deviation from the reference coefficients; nanmean ignores NaN entries.
forward_error_numpy = np.nanmean(np.abs(flmnp_numpy - flmnp))
print("Numpy: Forward mean absolute error = {}".format(forward_error_numpy))
forward_error_jax = np.nanmean(np.abs(flmnp_jax - flmnp))
print("JAX: Forward mean absolute error = {}".format(forward_error_jax))

# Inverse transform 
Shape: $(P, 2N-1, L, 2L-1) \rightarrow (P, 2N-1, L, 2L-1)$ 

The FLAG implementation takes roughly $(2N-1) \times 0.2$ s. This figure is extrapolated from [figure 6b](https://arxiv.org/pdf/1205.0792.pdf) by assuming that the cost of the harmonic transform scales linearly with N (which may well not be the case).

### NumPy CPU implementation

In [None]:
f_numpy = inverse_transform(flmnp_numpy, wigner_inverse, lag_poly_i, L)
%timeit inverse_transform(flmnp_numpy, wigner_inverse, lag_poly_i, L)

### JAX GPU implementation

In [None]:
flmnp_jax = device_put(flmnp_jax)
inverse_jit = jit(inverse_transform_jax, static_argnums=(3))
f_jax = inverse_jit(flmnp_jax, wigner_inverse_jax, lag_poly_i_jax, L).block_until_ready()

%timeit inverse_jit(flmnp_jax, wigner_inverse_jax, lag_poly_i_jax, L)

### Evaluate transform error

In [None]:
# Convert the JAX result to a NumPy array before comparing.
f_jax = np.array(f_jax)
# Mean absolute deviation of each reconstructed field from the reference `f`;
# nanmean ignores NaN entries. These are inverse-transform errors, so the
# messages are labelled "Inverse" (they previously said "Forward").
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))