In [1]:
import numpy as np
import jax
import jax.numpy as jnp

# Enable 64-bit floats in JAX. This must run before any JAX arrays are created,
# otherwise the jnp.float64 arrays requested later are silently created as float32.
# `jax.config.update` replaces the deprecated `from jax.config import config` form,
# which newer JAX releases no longer provide.
jax.config.update("jax_enable_x64", True)

In [2]:
import s2fft
import s2fft.wigner as wigner
import pyssht as ssht

In [3]:
# Band-limit used by every cell below; the Wigner-d plane is allocated as
# (2L - 1) x (2L - 1).
L = 128

In [4]:
# Seed the (2L - 1) x (2L - 1) float64 plane with wigner.trapani.init, then apply
# compute_full once per degree el = 1 .. L - 1, rebinding `dl` after every step.
# `el` is left bound to L - 1 after the loop.
dl = wigner.trapani.init(
    np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64), L
)
for el in range(1, L):
    dl = wigner.trapani.compute_full(dl, L, el)

In [5]:
%timeit wigner.trapani.compute_full(dl, L, el)

83.5 ms ± 4.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# Same initialisation as the loop version, but each degree is advanced with the
# compute_full_vectorized variant; `el` is again left bound to L - 1.
dl_vect = wigner.trapani.init(
    np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64), L
)
for el in range(1, L):
    dl_vect = wigner.trapani.compute_full_vectorized(dl_vect, L, el)

In [7]:
%timeit wigner.trapani.compute_full_vectorized(dl_vect, L, el)

4.47 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
# Repeat the recursion with the JAX implementation (init_jax / compute_full_jax).
# jnp.float64 is honoured only because jax_enable_x64 is set in the first cell.
# JAX arrays are immutable, so the result is rebound to `dl_jax` at every step.
# TODO: nothing here compares against SSHT yet; pyssht is imported above but unused
# in this section. Add a check against its (extensively validated) routines.
dl_jax = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
dl_jax = wigner.trapani.init_jax(dl_jax, L)
for el in range(1, L):
    dl_jax = wigner.trapani.compute_full_jax(dl_jax, L, el)



In [9]:
%timeit wigner.trapani.compute_full_jax(dl_jax, L, el)

1.06 ms ± 179 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
