In [None]:
import numpy as np
import jax.numpy as jnp
import s2fft.wigner as wigner
import s2fft.samples as samples
import pyssht as ssht

import jax

# `from jax.config import config` is deprecated and removed in recent JAX
# releases; the configuration object is exposed as `jax.config`. The `config`
# alias is kept so any other cell using that name keeps working.
config = jax.config

# Enable 64-bit floats so the JAX results can be compared against the
# float64 reference at rtol=1e-12 (see test_dls).
config.update("jax_enable_x64", True)

In [None]:
# Parameters shared by every cell below.
L = 32                         # passed as L to generate_dl / compute_slice
el = L - 1                     # degree of the slice under test: the largest, L - 1
beta, spin = 2 * np.pi / 3, 0  # compute_slice is called with the index -spin
sampling = "mw"                # not referenced in the cells shown here

In [None]:
# Test all dl() terms up to L.
dl = np.zeros(2 * L - 1, dtype=np.float64)

dl_array = ssht.generate_dl(beta, L)[el][L-1-spin]

%timeit ssht.generate_dl(beta, L)

def test_dls(dl):
    np.testing.assert_allclose(
        dl, dl_array, atol=1e-10, rtol=1e-12
    )

In [None]:
dl = np.zeros(2 * L - 1, dtype=np.float64)
dl_turok = wigner.turok.compute_slice(dl, beta, el, L, -spin)

test_dls(dl_turok)

%timeit wigner.turok.compute_slice(dl, beta, el, L, -spin)

In [None]:
dl_turok_gpu = wigner.turok_gpu.compute_slice(beta, el, L, -spin)
test_dls(dl_turok_gpu)

%timeit wigner.turok_gpu.compute_slice(beta, el, L, -spin)

In [None]:
from jax import jit, device_put
forward_jit = jit(wigner.turok_gpu.compute_slice, static_argnums=(2,3))

beta_jax = device_put(beta)
el_jax = device_put(el)
L_jax = device_put(L)
spin_jax = device_put(-spin)

dl_turok_gpu_jit = forward_jit(beta, el, L, -spin).block_until_ready()
test_dls(dl_turok_gpu_jit)

%timeit forward_jit(beta, el, L, -spin)

In [None]:
# Eyeball the first ten entries of each slice, in this order:
# JAX eager, JAX jit, NumPy Turok, pyssht reference.
for _slice in (dl_turok_gpu, dl_turok_gpu_jit, dl_turok, dl_array):
    print(_slice[:10])