In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax.config import config

import pyssht as ssht
import s2fft as s2f

import s2fft.samples as samples
import s2fft.resampling as resampling
import s2fft.wigner as wigner


# Enable 64-bit floats/ints in JAX so the JAX results can be compared against
# NumPy float64 at tight tolerances. This must run before any JAX arrays are
# created. Uses the public `jax.config.update` API rather than the deprecated
# `jax.config.config` object.
jax.config.update("jax_enable_x64", True)

In [2]:
## Example data

L = 5  # harmonic band-limit (tests use 5; e.g. 128 for a larger run)
spin = 0  # tests use [0, 1, 2]
sampling = "dh"  # tests use ["mw", "mwss", "dh"]

# Generate random spherical-harmonic coefficients (ground truth).
DEFAULT_SEED = 8966433580120847635
rn_gen = np.random.default_rng(DEFAULT_SEED)
flm_gt = s2f.utils.generate_flm(rn_gen, L, spin, reality=False)

# Synthesise the signal in pixel (spatial) space with ssht: the coefficients are
# converted from 2D to 1D indexing, and the sampling name is upper-cased for
# ssht's `Method` argument.
f = ssht.inverse(
    samples.flm_2d_to_1d(flm_gt, L),
    L,
    Method=sampling.upper(),
    Spin=spin,
    Reality=False,
)

# Convert f to a JAX array committed to the first available device.
device = jax.devices()[0]
f = jax.device_put(f, device)

In [3]:
# Resample f onto an MWSS grid when required and compute the matching thetas.
# Results are written to new names instead of overwriting `f` / `sampling`, so
# re-running this cell never upsamples an already-upsampled signal.
sampling_resampled = sampling.lower()
f_resampled = f

if sampling_resampled == "mw":
    f_resampled = resampling.mw_to_mwss(f_resampled, L, spin)
    sampling_resampled = "mwss"

if sampling_resampled == "mwss":
    # MWSS signals are upsampled by two, so thetas are taken at band-limit 2 * L.
    f_resampled = resampling.upsample_by_two_mwss(f_resampled, L, spin)
    thetas = samples.thetas(2 * L, sampling_resampled)
else:
    thetas = samples.thetas(L, sampling_resampled)

# Degrees el = spin, ..., L - 1 (the same range the naive loop iterates over).
el_array = jnp.arange(spin, L, dtype=jnp.int64)

In [4]:
## Compute dl using wigner.turok.compute_slice (naive approach)
# Fill a (n_theta, n_el, 2L - 1) array one slice at a time with the NumPy
# implementation; this is the reference the JAX result is compared against.

degrees = range(spin, L)
dl = np.zeros((len(thetas), len(degrees), 2 * L - 1), dtype=np.float64)

for theta_idx, theta in enumerate(thetas):
    for el_idx, el in enumerate(degrees):
        dl[theta_idx, el_idx, :] = wigner.turok.compute_slice(theta, el, L, -spin)

print(dl.shape)
print(type(dl))

(10, 5, 9)
<class 'numpy.ndarray'>


In [5]:
## Compute dl w/ JAX approach (vmap + wigner.turok_jax.compute_slice)
# The inner vmap maps over theta and stacks it as output axis 0. The outer vmap
# maps over el and stacks it as output axis 1. This gives the same
# (theta, el, slice) axis order as the naive `dl` array.
dl_vmapped_theta = jax.vmap(
    wigner.turok_jax.compute_slice, in_axes=(0, None, None, None), out_axes=0
)
dl_vmapped = jax.vmap(dl_vmapped_theta, in_axes=(None, 0, None, None), out_axes=1)

# Evaluate the transform once and reuse the result instead of recomputing it
# for every inspection.
dl_jax = dl_vmapped(thetas, el_array, L, -spin)
print(dl_jax.shape)
print(type(dl_jax))

(10, 5, 9)
<class 'jaxlib.xla_extension.DeviceArray'>


In [8]:
# First entry of the last axis for every (theta, el) of the NumPy result.
# NOTE(review): presumably m = -(L - 1) under s2fft's m + L - 1 indexing -- confirm.
dl[..., 0]

array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        3.13154605e-04],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        2.22134542e-02],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.30728129e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        3.29573720e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        4.97632511e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        4.97632511e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        3.29573720e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.30728129e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        2.22134542e-02],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        3.13154605e-04]])

In [9]:
# Same view of the JAX (vmapped) result, for side-by-side comparison.
dl_vmapped(thetas, el_array, L, -spin)[..., 0]

DeviceArray([[ 0.00000000e+00, -1.10615871e-01,  1.89233490e-01,
               2.14004219e-03,  3.13154605e-04],
             [ 0.00000000e+00, -3.21019761e-01,  4.95419707e-01,
               5.23076616e-02,  2.22134542e-02],
             [ 0.00000000e+00, -5.00000000e-01,  6.12372436e-01,
               1.97642354e-01,  1.30728129e-01],
             [ 0.00000000e+00, -6.30036755e-01,  4.95419707e-01,
               3.95428223e-01,  3.29573720e-01],
             [ 0.00000000e+00, -6.98401123e-01,  1.89233490e-01,
               5.38622873e-01,  4.97632511e-01],
             [ 0.00000000e+00, -6.98401123e-01, -1.89233490e-01,
               5.38622873e-01,  4.97632511e-01],
             [ 0.00000000e+00, -6.30036755e-01, -4.95419707e-01,
               3.95428223e-01,  3.29573720e-01],
             [ 0.00000000e+00, -5.00000000e-01, -6.12372436e-01,
               1.97642354e-01,  1.30728129e-01],
             [ 0.00000000e+00, -3.21019761e-01, -4.95419707e-01,
               5.230766

In [11]:
# Compare the naive (NumPy) and vmapped (JAX) results.
# Evaluate the JAX transform once and bring it to the host as a NumPy array.
dl_jax = np.asarray(dl_vmapped(thetas, el_array, L, -spin))
abs_diff = np.abs(dl - dl_jax)

print(abs_diff.max())
print(np.allclose(dl, dl_jax, atol=1e-14))

# Largest discrepancy for each el (axis 1), to see which degrees disagree.
# NOTE(review): in the dl[:, :, 0] outputs above, the NumPy slices are zero for
# every el < L - 1, the JAX slices are non-zero for 0 < el < L - 1, and the
# el = L - 1 column agrees. That points at a convention/masking difference
# between turok.compute_slice and turok_jax.compute_slice rather than at the
# vmap step -- TODO confirm against the two tests in tests/test_wigner.py.
print(abs_diff.max(axis=(0, 2)))

1.9149505459360456
False


#### Some observations:
- I think the vmap step is fine: if the naive loop calls the JAX version of the function (`wigner.turok_jax.compute_slice`) instead of the NumPy version (`wigner.turok.compute_slice`), the naive and vmapped results match.

- In the tests, the [test for turok_jax.compute_slice](https://github.com/astro-informatics/s2fft/blob/main/tests/test_wigner.py#L190) and [the test for turok.compute_slice](https://github.com/astro-informatics/s2fft/blob/main/tests/test_wigner.py#L167) don't use the same ground truth: one applies an `np.flip`, and they appear to check different slices. This suggests the two functions may not share the same output convention — or maybe I'm missing something there?