In [1]:
import numpy as np
import jax.numpy as jnp
import s2fft.wigner as wigner
import s2fft.samples as samples
import pyssht as ssht

# Import the configuration object from the top-level package: the old
# `jax.config` submodule path (`from jax.config import config`) was
# deprecated and later removed from JAX.
from jax import config

# Switch JAX to 64-bit floats; the comparisons further down use
# atol=1e-10 / rtol=1e-12, which single precision cannot satisfy.
config.update("jax_enable_x64", True)

In [2]:
# Problem size shared by every cell below; the slice arrays have 2L - 1 entries.
L = 128

# Index of the single slice that is computed and compared.
el = L - 1

# Angle argument passed to every dl routine (radians, presumably -- TODO confirm).
beta = 2 * np.pi / 3

# Spin number; the s2fft calls below receive -spin.
spin = 0

# Sampling-scheme label; not referenced by the cells visible in this notebook.
sampling = "mw"

In [3]:
# Test all dl() terms up to L.
dl = np.zeros(2 * L - 1, dtype=np.float64)

dl_array = ssht.generate_dl(beta, L)[el][L-1-spin]

%timeit ssht.generate_dl(beta, L)

def test_dls(dl):
    np.testing.assert_allclose(
        dl, dl_array, atol=1e-10, rtol=1e-12
    )

23.1 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
dl = np.zeros(2 * L - 1, dtype=np.float64)
dl_turok = wigner.turok.compute_slice(dl, beta, el, L, -spin)

test_dls(dl_turok)

%timeit wigner.turok.compute_slice(dl, beta, el, L, -spin)

2.16 ms ± 606 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
dl_turok_gpu = wigner.turok_gpu.compute_slice(beta, el, L, -spin)
test_dls(dl_turok_gpu)

%timeit wigner.turok_gpu.compute_slice(beta, el, L, -spin)

852 µs ± 93.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
from jax import jit
forward_jit = jit(wigner.turok_gpu.compute_slice, static_argnums=(0,1,2,3))

dl_turok_gpu_jit = forward_jit(beta, el, L, -spin).block_until_ready()
test_dls(dl_turok_gpu_jit)

%timeit forward_jit(beta, el, L, -spin)

772 µs ± 195 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [7]:
# Show the first ten entries of each result, one array per line, in the order:
# eager JAX, jitted JAX, NumPy Turok, SSHT reference.
print(
    dl_turok_gpu[0:10],
    dl_turok_gpu_jit[0:10],
    dl_turok[0:10],
    dl_array[0:10],
    sep="\n",
)

[-2.60578212e-09 -2.39769717e-08 -1.53234936e-07 -7.84791884e-07
 -3.41335692e-06 -1.30086701e-05 -4.42896338e-05 -1.36451038e-04
 -3.83824869e-04 -9.91975415e-04]
[-2.60578212e-09 -2.39769717e-08 -1.53234936e-07 -7.84791884e-07
 -3.41335692e-06 -1.30086701e-05 -4.42896338e-05 -1.36451038e-04
 -3.83824869e-04 -9.91975415e-04]
[-2.60578212e-09 -2.39769717e-08 -1.53234936e-07 -7.84791884e-07
 -3.41335692e-06 -1.30086701e-05 -4.42896338e-05 -1.36451038e-04
 -3.83824869e-04 -9.91975415e-04]
[-2.60578212e-09 -2.39769717e-08 -1.53234936e-07 -7.84791884e-07
 -3.41335692e-06 -1.30086701e-05 -4.42896338e-05 -1.36451038e-04
 -3.83824869e-04 -9.91975415e-04]
