# Spin functional JAX-SHT by exploiting kernel precomputes

Let's start by importing the required packages.

In [None]:
import os
# Select which GPU JAX may see. An externally exported CUDA_VISIBLE_DEVICES is
# respected; otherwise default to device "2" (machine-specific — adjust as
# needed). Export CUDA_VISIBLE_DEVICES="" beforehand to force the CPU backend.
# This must happen before JAX initialises its backend.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

import numpy as np
import healpy as hp 

from jax import jit, device_put
import jax.numpy as jnp

from jax.config import config
config.update("jax_enable_x64", True)

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

from s2fft.utils import *
from s2fft.samples import flm_2d_to_hp
from s2fft.general_precompute import spin_spherical, construct

### Generate a random complex bandlimited field

In [None]:
# Transform configuration: HEALPix grid with band-limit L = 2 * nside.
nside = 4
L = 2 * nside
spin = 0
sampling = "healpix"
reality = False
rng = np.random.default_rng(193412341234)

# Draw random harmonic coefficients, then scale each degree el (row el of the
# 2D coefficient array) by sqrt((2 el + 1) / (4 pi)).
flm = generate_flm(rng, L, spin, reality)
degree_norm = np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi))
flm *= degree_norm[:, np.newaxis]

# Convert to healpy's coefficient layout and synthesise the corresponding map.
flm_hp = flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)


### Load/construct relevant associated Legendre kernels

In [None]:
# Precompute the inverse (forward=False) spin-spherical kernel on the host and
# report its size in megabytes (1 MB = 1e6 bytes).
kernel_i = construct.spin_spherical_kernel(
    L, spin, reality, sampling, nside, forward=False
)
kernel_mb = kernel_i.nbytes * 1e-6
print(f"Kernel memory = {kernel_mb} MB")
kernel_i_jax = device_put(kernel_i)  # copy to the default JAX device

# HEALPix phase shifts; the third argument is presumably the forward/inverse
# flag (False here, matching the inverse kernel) — confirm against s2fft.
phases_i = construct.healpix_phase_shifts(L, nside, False)
phases_i_jax = device_put(phases_i)

# Inverse transform 

In [None]:
# Reference map from healpy; the error cells below compare against f_check.
f_check = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
# Benchmark healpy's inverse transform as the baseline timing.
%timeit hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)

### NumPy CPU implementation

In [None]:
# Inverse transform using the host (NumPy) kernel and phase shifts.
f_numpy = spin_spherical.inverse_transform(flm, kernel_i, L, sampling, spin, nside, phases_i)
# Benchmark the NumPy implementation with the same inputs.
%timeit spin_spherical.inverse_transform(flm, kernel_i, L, sampling, spin, nside, phases_i)

### JAX GPU implementation

In [None]:
flm_jax = device_put(flm)
f_jax = spin_spherical.inverse_transform_jax(flm_jax, kernel_i_jax, L, sampling, spin, nside, phases_i_jax)
%timeit spin_spherical.inverse_transform_jax(flm_jax, kernel_i_jax, L, sampling, spin, nside, phases_i_jax)

### Evaluate transform error

In [None]:
# Pull the JAX result back to host memory as a NumPy array.
f_jax = np.array(f_jax)

# Mean absolute difference of the real parts against the healpy reference map;
# NaN entries are ignored by nanmean.
ref_real = np.real(f_check)
mae_numpy = np.nanmean(np.abs(np.real(f_numpy) - ref_real))
mae_jax = np.nanmean(np.abs(np.real(f_jax) - ref_real))

print(f"numpy: Inverse mean absolute error = {mae_numpy}")
print(f"jax: Inverse mean absolute error = {mae_jax}")