# Spherical-Laguerre transform

Let's start by importing some packages.

In [10]:
# Let's set the precision: enable 64-bit (double-precision) arrays in JAX.
# `jax.config.update` replaces the deprecated `from jax.config import config`
# import, which emitted a deprecation warning when this notebook was run.
import jax
jax.config.update("jax_enable_x64", True)

# Import math libraries.
import numpy as np
import jax.numpy as jnp

# Check which backend JAX is running on ("cpu", "gpu", ...).
# `jax.default_backend()` replaces the deprecated `jax.lib.xla_bridge` API.
print(jax.default_backend())

# Import the s2ball library.
import s2ball
from s2ball.transform import laguerre

cpu


  from jax.config import config


### Generate a random complex bandlimited field
Here we generate random spherical-Laguerre coefficients `flmp`, which we then convert into a bandlimited signal `f` on $\mathbb{B}^3=\mathbb{R}^+\times \mathbb{S}^2$. We also generate some precomputed values, which are cached and passed to the associated transforms at run time.

In [11]:
L = 32        # Harmonic bandlimit of the problem.
P = 32        # Radial bandlimit of the problem.

# Seeded generator so the random coefficients are reproducible across runs.
rng = np.random.default_rng(193412341234)

# Draw random spherical-Laguerre coefficients with s2ball, then push them
# through inverse -> forward so the coefficients used below are explicitly
# bandlimited. generate_flmp does not yet enforce the bandlimiting
# symmetries itself.
flmp = laguerre.forward(
    laguerre.inverse(s2ball.utils.generate_flmp(rng, L, P), L, P),
    L,
    P,
)
# Bandlimited signal on the ball corresponding to the coefficients above.
f = laguerre.inverse(flmp, L, P)


### Load/construct relevant associated Legendre matrices

Load the precomputed associated Legendre matrices, which are used to evaluate the spherical harmonic transform. If these matrices have already been computed, the load function will attempt to locate them inside the hidden `.matrices` directory. Note that you can specify a directory of your choice; `.matrices` is simply the default.

In [12]:
# Precompute (or load) the matrices used by the spherical-Laguerre transforms
# for the bandlimits L (harmonic) and P (radial) defined above.
matrices = s2ball.construct.matrix.generate_matrices(
    "spherical_laguerre",
    L,
    P=P,
)

# Forward transform

Shape: $(P, L, 2L-1) \rightarrow (P, L, 2L-1)$, mapping the sampled signal to triangularly oversampled spherical-Laguerre coefficients.

### NumPy CPU implementation

In [13]:
# Run the NumPy forward transform once and keep the coefficients for the error check below.
flmp_numpy = laguerre.forward_transform(f, matrices)
# Benchmark repeated calls of the same transform (results are discarded).
%timeit laguerre.forward_transform(f, matrices)

10.2 ms ± 18.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### JAX implementation (runs on GPU if available; the timings below were recorded on CPU)

In [14]:
flmp_jax = laguerre.forward_transform_jax(f, matrices)
%timeit laguerre.forward_transform_jax(f, matrices)

1.63 ms ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Evaluate transform error

In [15]:
# Mean absolute difference between each implementation's coefficients and the
# reference flmp. np.nanmean skips NaN entries (presumably unused slots in the
# coefficient array — TODO confirm).
print(f"Numpy: Forward mean absolute error = {np.nanmean(np.abs(flmp_numpy - flmp))}")
print(f"JAX: Forward mean absolute error = {np.nanmean(np.abs(flmp_jax - flmp))}")

Numpy: Forward mean absolute error = 2.833944977439441e-14
JAX: Forward mean absolute error = 2.8340976016381722e-14


# Inverse transform 
Shape: $(P, L, 2L-1) \rightarrow (P, L, 2L-1)$, mapping spherical-Laguerre coefficients back to the sampled signal.

### NumPy CPU implementation

In [16]:
# Run the NumPy inverse transform once on the NumPy coefficients; keep the signal for the error check below.
f_numpy = laguerre.inverse_transform(flmp_numpy, matrices)
# Benchmark repeated calls of the same transform (results are discarded).
%timeit laguerre.inverse_transform(flmp_numpy, matrices)

13.3 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### JAX implementation (runs on GPU if available; the timings below were recorded on CPU)

In [17]:
f_jax = laguerre.inverse_transform_jax(flmp_jax, matrices)
%timeit laguerre.inverse_transform_jax(flmp_jax, matrices)

1.69 ms ± 18 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Evaluate transform error

In [18]:
# Mean absolute difference between each implementation's reconstructed signal
# and the original signal f. These are *inverse* transform errors; the labels
# previously said "Forward" by mistake. np.nanmean skips any NaN entries
# (presumably padding — TODO confirm).
print("Numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))

Numpy: Forward mean absolute error = 2.758189784854311e-13
JAX: Forward mean absolute error = 2.7573983646235196e-13
