# Ball wavelet transform

Let's start by importing some packages.

In [None]:
import numpy as np

import jax
from jax import config, device_put, jit

# Report the platform JAX will execute on. `jax.default_backend()` is the
# public replacement for the deprecated `jax.lib.xla_bridge.get_backend()`.
print(jax.default_backend())

# Switch JAX to 64-bit precision. `config` is imported from the top-level
# `jax` package because the `jax.config` submodule import path was removed
# in recent JAX releases.
config.update("jax_enable_x64", True)

from s2ball.transform import laguerre, ball_wavelet
# NOTE(review): the wildcard imports presumably supply `j_max`,
# `angular_bandlimit` and `generate_flmp` used later in this notebook;
# replace them with explicit imports once the owning modules are confirmed.
from s2ball.wavelets.helper_functions import *
from s2ball.wavelets import tiling
from s2ball.construct import wavelet_constructor
from s2ball.utils import *

### Generate a random complex bandlimited field
Here we generate random spherical-Laguerre coefficients `flmp`, which we then convert into a bandlimited signal `f` on the ball $\mathbb{B}^3=\mathbb{R}^+\times\mathbb{S}^2$. We also precompute some matrices, which are cached and passed to their associated functions at run time.

In [None]:
# --- Transform parameters ---------------------------------------------------
# NOTE(review): L and P are used as the two bandlimits (P is tied to L) and
# `lam` is passed twice wherever two scaling parameters are expected --
# confirm the exact meanings against the s2ball documentation.
L = 32
P = L
N = 3
tau = 1.0
lam = 2.0
rng = np.random.default_rng(193412341234)

# Upper bounds of the (jp, jl) wavelet indices iterated over below.
Jl = j_max(L, lam)
Jp = j_max(P, lam)

# --- Wavelet and scaling filters --------------------------------------------
# `2l + 1` for l = 0..L-1, shared by both normalisation factors.
degree_weights = 2 * np.arange(L) + 1

wav_lmp_i, scal_lmp = tiling.compute_wav_lmp(L, N, P, lam, lam)
scal_lmp *= np.sqrt((4 * np.pi) / degree_weights)

# Filters handed to the forward transform: the conjugate of the filters
# handed to the inverse transform, rescaled per degree by `factor`.
factor = 8 * np.pi**2 / degree_weights
wav_lmp_f = tiling.construct_wav_lmp(L, P, lam, lam)
for jp, jl in ((a, b) for a in range(Jp + 1) for b in range(Jl + 1)):
    # Only the first `scale_bandlimit` degrees are used at this scale.
    scale_bandlimit = angular_bandlimit(jl, lam)
    wav_lmp_f[jp][jl] = np.einsum(
        "pln,l->pln",
        np.conj(wav_lmp_i[jp][jl]),
        factor[:scale_bandlimit],
    )

# --- Precomputed kernels -----------------------------------------------------
legendre_kernels = wavelet_constructor.wavelet_legendre_kernels(L)
lag_polys = wavelet_constructor.scaling_laguerre_kernels(P, tau)

# NOTE(review): the kernels passed to the *forward* wavelet transform are built
# with forward=False (and the inverse ones with forward=True). Presumably the
# flag refers to the underlying Wigner/Laguerre transform direction rather than
# the wavelet transform direction -- confirm before "fixing".
wig_kernels_forward = wavelet_constructor.wavelet_wigner_kernels(
    L, N, lam, forward=False
)
wav_lag_polys_forward = wavelet_constructor.wavelet_laguerre_kernels(
    P, lam, tau, forward=False
)

wig_kernels_inverse = wavelet_constructor.wavelet_wigner_kernels(
    L, N, lam, forward=True
)
wav_lag_polys_inverse = wavelet_constructor.wavelet_laguerre_kernels(
    P, lam, tau, forward=True
)

# --- Device copies for the JAX implementation --------------------------------
legendre_kernels_jax = device_put(legendre_kernels)
lag_polys_jax = device_put(lag_polys)

wigner_kernels_forward_jax = device_put(wig_kernels_forward)
wav_lag_polys_forward_jax = device_put(wav_lag_polys_forward)

wigner_kernels_inverse_jax = device_put(wig_kernels_inverse)
wav_lag_polys_inverse_jax = device_put(wav_lag_polys_inverse)

wav_lmp_f_jax = device_put(wav_lmp_f)
wav_lmp_i_jax = device_put(wav_lmp_i)
scal_lmp_jax = device_put(scal_lmp)

# --- Test signal --------------------------------------------------------------
# Random coefficients are pushed through one inverse/forward round trip first,
# so `flmp` and `f` are a pair produced by the transform itself.
flmp = laguerre.forward(
    laguerre.inverse(generate_flmp(rng, L, P), L, P, tau), L, P, tau
)
f = laguerre.inverse(flmp, L, P, tau)

# Forward transform

### NumPy CPU implementation

In [None]:
f_wav_numpy, f_scal_numpy = ball_wavelet.forward_transform(f, wav_lmp_f, scal_lmp, L, N, P, lam, lam, legendre_kernels, lag_polys, wig_kernels_forward, wav_lag_polys_forward)
%timeit ball_wavelet.forward_transform(f, wav_lmp_f, scal_lmp, L, N, P, lam, lam, legendre_kernels, lag_polys, wig_kernels_forward, wav_lag_polys_forward)

### JAX GPU implementation

In [None]:
f_jax = device_put(f)
forward_jit = jit(ball_wavelet.forward_transform_jax, static_argnums=(3, 4, 5, 6, 7))
f_wav_jax, f_scal_jax = forward_jit(f_jax, wav_lmp_f_jax, scal_lmp_jax, L, N, P, lam, lam, legendre_kernels_jax, lag_polys_jax, wigner_kernels_forward_jax, wav_lag_polys_forward_jax)

%timeit forward_jit(f_jax, wav_lmp_f_jax, scal_lmp_jax, L, N, P, lam, lam, legendre_kernels_jax, lag_polys_jax, wigner_kernels_forward_jax, wav_lag_polys_forward_jax)

# Inverse transform 

### NumPy CPU implementation

In [None]:
f_numpy = ball_wavelet.inverse_transform(f_wav_numpy, f_scal_numpy, wav_lmp_i, scal_lmp, L, N, P, lam, lam, legendre_kernels, lag_polys, wig_kernels_inverse, wav_lag_polys_inverse)
%timeit ball_wavelet.inverse_transform(f_wav_numpy, f_scal_numpy, wav_lmp_i, scal_lmp, L, N, P, lam, lam, legendre_kernels, lag_polys, wig_kernels_inverse, wav_lag_polys_inverse)

### JAX GPU implementation

In [None]:
inverse_jit = jit(ball_wavelet.inverse_transform_jax, static_argnums=(4, 5, 6, 7, 8))
f_jax = inverse_jit(f_wav_jax, f_scal_jax, wav_lmp_i_jax, scal_lmp_jax, L, N, P, lam, lam, legendre_kernels_jax, lag_polys_jax, wigner_kernels_inverse_jax, wav_lag_polys_inverse_jax)

%timeit inverse_jit(f_wav_jax, f_scal_jax, wav_lmp_i_jax, scal_lmp_jax, L, N, P, lam, lam, legendre_kernels_jax, lag_polys_jax, wigner_kernels_inverse_jax, wav_lag_polys_inverse_jax)

### Evaluate transform error

In [None]:
# Copy the JAX reconstruction back into a NumPy array for comparison.
f_jax = np.array(f_jax)

# True round-trip errors: each reconstruction against the original signal `f`.
# Assumes both inverse transforms return arrays on the same grid as `f` --
# TODO confirm.
print("NumPy round-trip mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("JAX round-trip mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))

# Agreement between the two implementations. On its own this is not a
# round-trip error: it would also be ~0 if both were wrong in the same way.
print("NumPy vs JAX mean absolute difference = {}".format(np.nanmean(np.abs(f_numpy - f_jax))))