In [1]:
import numpy as np
import jax.numpy as jnp

In [2]:
# Enable 64-bit (double precision) types in JAX. This must run before any JAX
# arrays are created; otherwise explicit float64 requests are truncated to
# float32 (with only a warning). `from jax.config import config` is deprecated
# and removed in recent JAX releases, so use the public `jax.config` object.
import jax

jax.config.update("jax_enable_x64", True)

In [3]:
import os
import sys
# Make the parent directory importable, presumably so the local `s2fft`
# checkout imported below resolves without installation. This assumes the
# kernel's working directory is the notebook's own directory (Jupyter's
# default). The path is made absolute so a later os.chdir cannot break it, and
# the insert is skipped on re-runs so sys.path does not accumulate duplicates.
_repo_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if _repo_root not in sys.path:
    sys.path.insert(1, _repo_root)

In [4]:
import s2fft
import s2fft.wigner as wigner
import pyssht as ssht

In [5]:
# Validate the loop implementation of the Trapani recursion against SSHT, whose
# Wigner-d routines have been validated extensively, for every degree el < L at
# beta = pi/2.
L = 5

# Reference values from SSHT, presumably indexed as
# dl_array[el, m + L - 1, mm + L - 1] (TODO confirm against the pyssht docs).
beta = np.pi / 2.0
dl_array = ssht.generate_dl(beta, L)

# Recurse degree by degree and compare each plane as soon as it is produced.
# Only the |m|, |mm| <= el sub-block is meaningful for degree el, so only that
# region is compared.
dl = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
dl = wigner.trapani.init(dl, L)
for el in range(1, L):
    dl = wigner.trapani.compute_full(dl, L, el)
    valid = slice(L - 1 - el, L + el)
    np.testing.assert_allclose(
        dl_array[el][valid, valid], dl[valid, valid], atol=1e-10
    )

In [6]:
%timeit wigner.trapani.compute_full(dl, L, el)

92.1 µs ± 3.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
# Validate the vectorized Trapani recursion against SSHT for every degree
# el < L at beta = pi/2.
L = 5

# The reference is computed here so this cell does not depend on earlier state.
dl_ssht = ssht.generate_dl(np.pi / 2.0, L)

dl_vect = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
dl_vect = wigner.trapani.init(dl_vect, L)
for el in range(1, L):
    dl_vect = wigner.trapani.compute_full_vectorized(dl_vect, L, el)
    # Only the |m|, |mm| <= el sub-block is meaningful for degree el.
    valid = slice(L - 1 - el, L + el)
    np.testing.assert_allclose(
        dl_ssht[el][valid, valid], dl_vect[valid, valid], atol=1e-10
    )

In [8]:
%timeit wigner.trapani.compute_full_vectorized(dl_vect, L, el)

156 µs ± 5.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
# Validate the JAX Trapani recursion against SSHT for every degree el < L at
# beta = pi/2. float64 JAX arrays rely on jax_enable_x64 being set at the top
# of the notebook.
L = 5

# The reference is computed here so this cell does not depend on earlier state.
dl_ssht = ssht.generate_dl(np.pi / 2.0, L)

dl_jax = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
dl_jax = wigner.trapani.init_jax(dl_jax, L)
for el in range(1, L):
    dl_jax = wigner.trapani.compute_full_jax(dl_jax, L, el)
    # Only the |m|, |mm| <= el sub-block is meaningful for degree el.
    valid = slice(L - 1 - el, L + el)
    np.testing.assert_allclose(
        dl_ssht[el][valid, valid], np.asarray(dl_jax)[valid, valid], atol=1e-10
    )



In [10]:
%timeit wigner.trapani.compute_full_jax(dl_jax, L, el)

30.1 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
# Sanity check that jax_enable_x64 took effect. Without it JAX truncates the
# requested float64 to float32, so fail loudly instead of relying on a reader
# to notice the displayed dtype.
assert dl_jax.dtype == jnp.float64, f"expected float64, got {dl_jax.dtype}"
dl_jax.dtype

dtype('float64')

In [12]:
from jax import grad, jit

In [13]:
# JIT-compile the JAX recursion. Positional argument 1 (L at the call sites
# here) is static, i.e. a compile-time constant. The array and the degree are
# traced, so changing el does not trigger recompilation.
compute_full_jax_jit = jit(wigner.trapani.compute_full_jax, static_argnums=1)

In [14]:
%timeit compute_full_jax_jit(dl_jax, L, el)

12.1 µs ± 6.78 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
