# Spherical Laguerre transforms by exploiting kernel precomputes

Let's start by importing some packages.

In [None]:
import numpy as np

from jax import jit, device_put
import jax.numpy as jnp
# NOTE(review): recent JAX releases have dropped the `jax.config` submodule and
# `jax.lib.xla_bridge`; `jax.config.update(...)` and `jax.default_backend()` are
# the documented replacements -- confirm against the JAX version pinned here.
from jax.config import config
from jax.lib import xla_bridge
# Show which backend JAX picked, so the timings below can be read in context.
print(xla_bridge.get_backend().platform)

# Switch JAX to 64-bit mode; by default JAX uses 32-bit types.
config.update("jax_enable_x64", True)

# NOTE(review): the wildcard imports below are presumably where the transform
# functions and helpers used by the later cells come from -- confirm.
from s2ball.construct.legendre_constructor import load_legendre_matrix
from s2ball.transform.laguerre import *
from s2ball.utils import *

### Generate a random complex bandlimited field

In [None]:
# Transform size: L and P are kept equal here; tau is handed to the Laguerre
# sampling and to the round-trip transforms at the end of this cell.
L = 8
P = L
tau = 1.0

# Seeded generator, so the random coefficients are the same on every run.
rng = np.random.default_rng(193412341234)

# Legendre kernels: the matrix as loaded, plus a copy passed through
# device_put for the JAX transforms.
legendre_forward = load_legendre_matrix(L, forward=True)
legendre_forward_jax = device_put(legendre_forward)
legendre_inverse = load_legendre_matrix(L, forward=False)
legendre_inverse_jax = device_put(legendre_inverse)

# Laguerre polynomial kernels, handled the same way.
lag_poly_f = laguerre_sampling.polynomials(P, tau, forward=True)
lag_poly_f_jax = device_put(lag_poly_f)
lag_poly_i = laguerre_sampling.polynomials(P, tau, forward=False)
lag_poly_i_jax = device_put(lag_poly_i)

# Draw random coefficients, then round-trip them (inverse -> forward ->
# inverse) so the reference pair (f, flmp) used in the cells below has itself
# been produced by the transforms.
flmp = generate_flmp(rng, L, P)
f = inverse(flmp, L, P, tau, legendre_inverse, lag_poly_i)
flmp = forward(f, L, P, tau, legendre_forward, lag_poly_f)
f = inverse(flmp, L, P, tau, legendre_inverse, lag_poly_i)

# Forward transform

Shape: $(P, L, 2L-1) \rightarrow (P,L, 2L-1)$ triangularly oversampled spherical Laguerre coefficients.

FLAG implementation takes ~ 0.2s -- [figure 6b](https://arxiv.org/pdf/1205.0792.pdf)

### NumPy CPU implementation

In [None]:
# Run the forward transform once and keep the result for the error check in a
# later cell, then time repeated calls with the same inputs.
flmp_numpy = forward_transform(f, legendre_forward, lag_poly_f)
%timeit forward_transform(f, legendre_forward, lag_poly_f)

### JAX GPU implementation

In [None]:
f_jax = device_put(f)
forward_jit = jit(forward_transform_jax)
flmp_jax = forward_jit(f_jax, legendre_forward_jax, lag_poly_f_jax).block_until_ready()

%timeit forward_jit(f_jax, legendre_forward_jax, lag_poly_f_jax)

### Evaluate transform error

In [None]:
# Convert the JAX result to a NumPy array before comparing.
flmp_jax = np.array(flmp_jax)

# Mean absolute difference from the reference coefficients; nanmean skips NaNs.
numpy_forward_error = np.nanmean(np.abs(flmp_numpy - flmp))
jax_forward_error = np.nanmean(np.abs(flmp_jax - flmp))

print("Numpy: Forward mean absolute error = {}".format(numpy_forward_error))
print("JAX: Forward mean absolute error = {}".format(jax_forward_error))

# Inverse transform 
Shape: $(P, L, 2L-1) \rightarrow (P, L, 2L-1)$ 

FLAG implementation takes ~ 0.2s -- [figure 6b](https://arxiv.org/pdf/1205.0792.pdf)

### NumPy CPU implementation

In [None]:
# Run the inverse transform once on the NumPy forward-transform output and keep
# the result for the error check in a later cell, then time repeated calls.
f_numpy = inverse_transform(flmp_numpy, legendre_inverse, lag_poly_i)
%timeit inverse_transform(flmp_numpy, legendre_inverse, lag_poly_i)

### JAX GPU implementation

In [None]:
flmp_jax = device_put(flmp_jax)
inverse_jit = jit(inverse_transform_jax)
f_jax = inverse_jit(flmp_jax, legendre_inverse_jax, lag_poly_i_jax).block_until_ready()

%timeit inverse_jit(flmp_jax, legendre_inverse_jax, lag_poly_i_jax)

### Evaluate transform error

In [None]:
# Convert the JAX result to a NumPy array before comparing.
f_jax = np.array(f_jax)

# Mean absolute difference from the reference field f; nanmean skips NaNs.
# The labels say "Inverse" because this cell checks the inverse transform
# (they previously repeated the forward-transform labels).
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))