# Spin functional JAX-SHT exploiting precomputed kernels

Let's start by importing some packages.

In [None]:
import numpy as np
import pyssht as ssht 

from jax import jit, device_put
import jax.numpy as jnp
# NOTE(review): importing `config` from the `jax.config` submodule is deprecated
# in newer JAX releases (the current form is `import jax; jax.config.update(...)`)
# — confirm against the installed JAX version.
from jax.config import config

# Enable 64-bit types in JAX so the JAX transforms run in double precision,
# making their errors comparable with the NumPy/SSHT results evaluated below.
config.update("jax_enable_x64", True)

from s2fft.precompute.construct_legendre_matrix import load_legendre_matrix
from s2fft.precompute.transforms import *
from s2fft.utils import generate_flm

### Generate a random complex bandlimited field

In [None]:
# Band-limit and spin number used throughout this notebook.
L = 128
spin = 2
# Random test coefficients and the corresponding signal f from SSHT's inverse
# transform: f is the input of the forward benchmarks and the reference signal
# for the inverse error checks; flm is the reference for the forward checks.
flm = generate_flm(L, spin)
f = ssht.inverse(flm, L, spin)

### Load/construct relevant associated Legendre kernels

In [None]:
# Load (or construct) the forward and inverse associated Legendre kernels for
# this L and spin. Presumably the matrices are cached under save_dir between
# runs — TODO confirm against load_legendre_matrix.
legendre_forward = load_legendre_matrix(L=L, direction="forward", save_dir="../.matrices", spin=spin)
legendre_inverse = load_legendre_matrix(L=L, direction="inverse", save_dir="../.matrices", spin=spin)

# Place the kernels on the JAX device once, so the timed JAX calls below do not
# include the host-to-device copy of the kernels.
legendre_forward_jax = device_put(legendre_forward)
legendre_inverse_jax = device_put(legendre_inverse)

# Forward transform

Shape: $(L, 2L-1) \rightarrow L(2L-1)$, i.e. spherical harmonic coefficients stored in a triangularly oversampled layout.

### SSHT CPU Cython implementation

In [None]:
# Baseline: time SSHT's forward transform on the same signal f.
%timeit ssht.forward(f, L, spin)

### NumPy CPU implementation

In [None]:
# Run once to keep the coefficients for the error check, then time repeated calls.
flm_cpu = forward_transform_cpu(f, legendre_forward, L)
%timeit forward_transform_cpu(f, legendre_forward, L)

### JAX GPU implementation

In [None]:
f_jax = device_put(f)
forward_jit = jit(forward_transform_gpu, static_argnums=(2,))
flm_gpu = forward_jit(f_jax, legendre_forward_jax, L).block_until_ready()

%timeit forward_jit(f_jax, legendre_forward_jax, L)

### Evaluate transform error

In [None]:
# Bring the JAX result back to the host as a NumPy array.
flm_gpu = np.array(flm_gpu)
# Compare the nonzero entries of each result against the reference coefficients
# from index spin**2 onward; np.nanmean ignores any NaN differences.
# NOTE(review): this assumes the unused slots of the triangularly oversampled
# output (and the el < spin entries) are exactly zero, and that flm uses
# el**2 + el + m indexing so flm[spin**2:] starts at el = spin — confirm.
# A genuinely zero coefficient would be dropped by np.nonzero and misalign the
# comparison.
print("CPU: Forward mean absolute error = {}".format(np.nanmean(np.abs(flm_cpu[np.nonzero(flm_cpu)] - flm[spin**2:]))))
print("GPU: Forward mean absolute error = {}".format(np.nanmean(np.abs(flm_gpu[np.nonzero(flm_gpu)] - flm[spin**2:]))))

# Inverse transform 
Shape: $L(2L-1) \rightarrow (L, 2L-1)$ 

### SSHT CPU Cython implementation

In [None]:
# Baseline: time SSHT's inverse transform on the reference coefficients.
%timeit ssht.inverse(flm, L, spin)

### NumPy CPU implementation

In [None]:
# Inverse of the NumPy forward result (a CPU round trip): run once to keep the
# signal for the error check, then time repeated calls.
f_cpu = inverse_transform_cpu(flm_cpu, legendre_inverse, L)
%timeit inverse_transform_cpu(flm_cpu, legendre_inverse, L)

### JAX GPU implementation

In [None]:
flm_gpu = device_put(flm_gpu)
inverse_jit = jit(inverse_transform_gpu, static_argnums=(2,))
f_gpu = inverse_jit(flm_gpu, legendre_inverse_jax, L).block_until_ready()

%timeit inverse_jit(flm_gpu, legendre_inverse_jax, L)

### Evaluate transform error

In [None]:
# Bring the JAX result back to the host as a NumPy array.
f_gpu = np.array(f_gpu)
# Each backend's round trip (its own forward output fed to its inverse) is
# compared with the original signal f; np.nanmean ignores any NaN differences.
print("CPU: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_cpu - f))))
print("GPU: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_gpu - f))))