# Spherical harmonic transform

Let's start by importing the required packages.

In [None]:
import numpy as np
import pyssht as ssht 

from jax import jit, device_put
import jax.numpy as jnp
from jax.config import config
from jax.lib import xla_bridge
# Enable 64-bit types before anything queries or initialises the JAX backend:
# JAX documents jax_enable_x64 as a setting to apply at startup. Double
# precision lets the JAX results below be compared like-for-like with the
# NumPy / pyssht results.
config.update("jax_enable_x64", True)

# Report which platform (e.g. cpu or gpu) JAX will run the benchmarks on.
print(xla_bridge.get_backend().platform)

from s2ball.construct.legendre_constructor import load_legendre_matrix
from s2ball.transform.harmonic import *
from s2ball.utils import *

### Generate a random complex bandlimited field

Here we generate random harmonic coefficients `flm_2d`, which we then convert into a bandlimited signal `f` on $\mathbb{S}^2$ using an inverse spherical harmonic transform.

In [None]:
# Bandlimit L and spin weight of the test field.
L, spin = 64, 0

# Fixed seed so the coefficients (and therefore the reported errors) are reproducible.
rng = np.random.default_rng(seed=193412341234)

# Random harmonic coefficients in the 2D layout used by s2ball, flattened to the
# 1D layout that pyssht expects, then synthesised into the pixel-space signal f.
flm_2d = generate_flm(rng, L, spin)
flm = flm_2d_to_1d(flm_2d, L)
f = ssht.inverse(flm, L, spin)

### Load/construct relevant associated Legendre matrices

Load the precomputed associated Legendre matrices used to evaluate the spherical harmonic transform. If these matrices have already been computed, the load function will attempt to locate them inside the hidden `.matrices` directory. Note that you can specify a directory of your choice; `.matrices` is simply the default.

In [None]:
# Legendre matrices for the forward transform (forward=True) and the inverse
# transform (forward=False), loaded in that order.
legendre_forward, legendre_inverse = (
    load_legendre_matrix(L=L, forward=is_forward, spin=spin)
    for is_forward in (True, False)
)

# Device-resident copies of both matrices for the jitted JAX transforms.
legendre_forward_jax, legendre_inverse_jax = map(
    device_put, (legendre_forward, legendre_inverse)
)

# Forward transform

Shape: $(L, 2L-1)$ signal samples $\rightarrow$ $(L, 2L-1)$ triangularly oversampled spherical harmonic coefficients.

### SSHT CPU Cython implementation

In [None]:
# Reference timing: pyssht's forward transform of the same signal f.
%timeit ssht.forward(f, L, spin)

### NumPy CPU implementation

In [None]:
# Keep one NumPy forward result for the error check further down, then time it.
flm_numpy = forward_transform(f, legendre_forward)
%timeit forward_transform(f, legendre_forward)

### JAX GPU implementation

In [None]:
f_jax = device_put(f)
forward_jit = jit(forward_transform_jax)
flm_jax = forward_jit(f_jax, legendre_forward_jax).block_until_ready()

%timeit forward_jit(f_jax, legendre_forward_jax)

### Evaluate transform error

In [None]:
# Bring the JAX result back to a host NumPy array for comparison.
flm_jax = np.array(flm_jax)

# Mean absolute deviation from the ground-truth coefficients. nanmean skips NaN
# entries (presumably unused slots of the 2D coefficient layout; confirm).
forward_err_numpy = np.nanmean(np.abs(flm_numpy - flm_2d))
forward_err_jax = np.nanmean(np.abs(flm_jax - flm_2d))

print("Numpy: Forward mean absolute error = {}".format(forward_err_numpy))
print("JAX: Forward mean absolute error = {}".format(forward_err_jax))

# Inverse transform 
Shape: $(L, 2L-1)$ triangularly oversampled spherical harmonic coefficients $\rightarrow$ $(L, 2L-1)$ signal samples.

### SSHT CPU Cython implementation

In [None]:
# Reference timing: pyssht's inverse transform of the 1D ground-truth coefficients.
%timeit ssht.inverse(flm, L, spin)

### NumPy CPU implementation

In [None]:
# Inverse of the NumPy forward output (not the ground-truth flm_2d), kept for
# the error check further down, then timed.
f_numpy = inverse_transform(flm_numpy, legendre_inverse)
%timeit inverse_transform(flm_numpy, legendre_inverse)

### JAX GPU implementation

In [None]:
flm_jax = device_put(flm_jax)
inverse_jit = jit(inverse_transform_jax)
f_jax = inverse_jit(flm_jax, legendre_inverse_jax).block_until_ready()

%timeit inverse_jit(flm_jax, legendre_inverse_jax)

### Evaluate transform error

In [None]:
# Bring the JAX result back to a host NumPy array for comparison.
f_jax = np.array(f_jax)

# Both inverses were fed the *computed* forward coefficients (flm_numpy /
# flm_jax), so these figures are round-trip errors against the original signal.
inverse_err_numpy = np.nanmean(np.abs(f_numpy - f))
inverse_err_jax = np.nanmean(np.abs(f_jax - f))

print("Numpy: Inverse mean absolute error = {}".format(inverse_err_numpy))
print("JAX: Inverse mean absolute error = {}".format(inverse_err_jax))