In [1]:
import numpy as np
import jax.numpy as jnp

In [2]:
# Enable 64-bit (double precision) arrays in JAX. Without this, JAX creates float32
# arrays even when float64 is requested. This must run before any JAX arrays are created.
# `jax.config.update` replaces the deprecated `from jax.config import config` import path,
# which newer JAX releases no longer provide.
import jax

jax.config.update("jax_enable_x64", True)

In [3]:
import os
import sys

# Make the parent of sys.path[0] importable, so the local s2fft checkout can be imported.
# NOTE(review): assumes this notebook lives one directory below the package root; confirm.
# The entry is inserted at index 1, so it is searched right after sys.path[0].
# The membership check keeps re-runs of this cell from stacking duplicate entries.
_parent_dir = os.path.join(sys.path[0], '..')
if _parent_dir not in sys.path:
    sys.path.insert(1, _parent_dir)

In [4]:
import s2fft
import s2fft.wigner as wigner
import pyssht as ssht

In [5]:
L = 2**7  # band-limit: the cells below recurse over degrees el = 1 .. L-1 on (2L-1) x (2L-1) planes

In [6]:
# Compare to routines in SSHT, which have been validated extensively.
dl = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
dl = wigner.trapani.init(dl, L)
for el in range(1, L):
    dl = wigner.trapani.compute_full(dl, L, el)

In [7]:
%timeit wigner.trapani.compute_full(dl, L, el)

45 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# Compare to routines in SSHT, which have been validated extensively.
dl_vect = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
dl_vect = wigner.trapani.init(dl_vect, L)
for el in range(1, L):
    dl_vect = wigner.trapani.compute_full_vectorized(dl_vect, L, el)   

In [9]:
%timeit wigner.trapani.compute_full_vectorized(dl_vect, L, el)

2.69 ms ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# Compare to routines in SSHT, which have been validated extensively.
dl_jax = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
dl_jax = wigner.trapani.init_jax(dl_jax, L)
for el in range(1, L):
    dl_jax = wigner.trapani.compute_full_jax(dl_jax, L, el)

In [11]:
%timeit wigner.trapani.compute_full_jax(dl_jax, L, el)

1.04 s ± 63.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
# Check that x64 mode is active. With it disabled, JAX would have created float32 arrays
# even though float64 was requested, and the comparisons above would lose precision.
assert dl_jax.dtype == jnp.float64, f"expected float64, got {dl_jax.dtype}"
dl_jax.dtype

dtype('float64')

In [13]:
from jax import grad, jit

In [14]:
# JIT-compile the JAX recursion. Positional argument 1 (L) is static: JAX treats it as a
# compile-time Python value and recompiles for each distinct L. The array and degree
# arguments are traced.
compute_full_jax_jit = jit(wigner.trapani.compute_full_jax, static_argnums=1)

In [15]:
%timeit compute_full_jax_jit(dl_jax, L, el)

911 µs ± 68.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
