In [1]:
# Imports and global configuration (run first on a fresh kernel).
import numpy as np

import pyssht as ssht
import s2fft as s2f
import s2fft.samples as samples
import s2fft.quadrature as quadrature
import s2fft.resampling as resampling
import s2fft.wigner as wigner

import jax
from jax import jit, device_put
import jax.numpy as jnp
# NOTE: `from jax.config import config` relied on the `jax.config` submodule,
# which newer JAX releases have removed; the config object is importable
# directly from the top-level `jax` package instead.
from jax import config

import matplotlib.pyplot as plt

# Enable 64-bit JAX arrays (needed for the tight 1e-14 tolerances used below).
# Per the JAX docs this flag should be set before any JAX arrays are created.
config.update("jax_enable_x64", True)

In [5]:
## Sample data
# Build a random band-limited test signal: draw ground-truth harmonic
# coefficients, then synthesise the corresponding signal on the sphere
# with pyssht.

# Input parameters. Previously-tried / test-suite values:
#   L: 128 (tests use 5); spin: tests use [0, 1, 2];
#   sampling: 'dh' also tried; tests use ["mw", "mwss", "dh"].
L = 5
spin = 2
sampling = "mwss"

# Ground-truth harmonic coefficients from a seeded NumPy generator.
# TODO: consider switching to the JAX PRNG-key approach.
DEFAULT_SEED = 8966433580120847635
rn_gen = np.random.default_rng(DEFAULT_SEED)
flm_gt = s2f.utils.generate_flm(rn_gen, L, spin, reality=False)

# Inverse transform with pyssht (the starting point). pyssht takes the 1D
# (flattened) coefficient layout, so convert from the 2D-indexed array first.
f = ssht.inverse(
    samples.flm_2d_to_1d(flm_gt, L),
    L,
    Method=sampling.upper(),
    Spin=spin,
    Reality=False,
)

# Theta sample positions used by the Wigner-d computations below.
nside = None
if sampling.lower() not in ("mw", "mwss"):
    thetas = samples.thetas(L, sampling, nside)
else:
    # MW and MWSS are both evaluated on an MWSS theta grid at band-limit 2L.
    # NOTE: `f` itself is NOT upsampled here (the call
    # resampling.upsample_by_two_mwss(f, L, spin) is disabled), so `f` is
    # still on the band-limit-L grid while `thetas` is the 2L grid.
    sampling = "mwss"
    thetas = samples.thetas(2 * L, sampling)

In [7]:
# Vectorise wigner.turok_jax.compute_slice(theta, el, L, -spin) with two vmaps:
#   - the inner vmap maps over theta (argument 0) and places that axis LAST;
#   - the outer vmap maps over el (argument 1) and places that axis FIRST.
dl_vmapped = jax.vmap(
    jax.vmap(
        wigner.turok_jax.compute_slice,
        in_axes=(0, None, None, None),
        out_axes=-1,
    ),
    in_axes=(None, 0, None, None),
    out_axes=0,
)

# Degrees el = |spin|, ..., L - 1 at which the slices are evaluated.
el_array = jnp.arange(abs(spin), L, dtype=jnp.int64)

# Evaluate at every (el, theta) pair. Arguments keep the order of the
# original function, but the OUTPUT axes are (el, 2L-1, theta) -- watch out,
# theta is the last axis, not the first.
dl_3D = dl_vmapped(thetas, el_array, L, -spin)
assert dl_3D.shape == (el_array.shape[0], 2 * L - 1, len(thetas)), dl_3D.shape

# Example: dl_3D[1, :, 0] is the length-(2L-1) slice for degree el_array[1]
# evaluated at thetas[0].
print(type(dl_3D[0, :, 0]))
print(dl_3D[0, :, 0].shape)

<class 'jaxlib.xla_extension.DeviceArray'>
(9,)


In [None]:
# Cross-check the vmapped JAX slices (dl_3D) against the NumPy implementation
# wigner.turok.compute_slice, evaluated one (theta, el) pair at a time.
tol = 1e-14
flag_same = []        # full length-(2L-1) vectors agree
flag_same_slice = []  # only the central 2*el+1 entries agree
for t, theta in enumerate(thetas):
    for el_i, el in enumerate(range(abs(spin), L)):
        dl = wigner.turok.compute_slice(theta, el, L, -spin)  # length 2L-1
        dl_jax = dl_3D[el_i, :, t]
        # Central 2*el+1 entries around index L-1 (presumably m = -el..el,
        # assuming index L-1 corresponds to m = 0 -- TODO confirm).
        m_slice = slice(L - 1 - el, L + el)

        same = np.allclose(dl, dl_jax, atol=tol)
        flag_same.append(same)
        flag_same_slice.append(np.allclose(dl[m_slice], dl_jax[m_slice], atol=tol))

        # Only dump the vectors when they disagree, rather than for every pair.
        if not same:
            print(f"theta index {t}, el = {el}:")
            print(dl)
            print(dl_jax)
            print("---")

print(np.all(flag_same))
print(np.all(flag_same_slice))

In [28]:
# Compare wigner.turok.compute_slice (NumPy) with wigner.turok_jax.compute_slice
# (JAX) for identical (theta, el, L, -spin), one pair at a time.
# Saved output of this cell: False / True -- the full length-(2L-1) vectors do
# not all agree, but the central 2*el+1 entries do, so the two implementations
# only differ outside that central range.
tol = 1e-14
flag_same_one_theta_el = []  # full vectors agree
flag_same_slice = []         # central 2*el+1 entries agree
for theta in thetas:
    for el in range(abs(spin), L):
        dl = wigner.turok.compute_slice(theta, el, L, -spin)
        dl_jax = wigner.turok_jax.compute_slice(theta, el, L, -spin)
        # Central entries around index L-1 (presumably m = -el..el -- TODO confirm).
        m_slice = slice(L - 1 - el, L + el)

        flag_same_one_theta_el.append(np.allclose(dl, dl_jax, atol=tol))
        flag_same_slice.append(
            np.allclose(dl[m_slice], dl_jax[m_slice], atol=tol)
        )

print(np.all(flag_same_one_theta_el))
print(np.all(flag_same_slice))

False
True


In [19]:
# Number of degrees el in [|spin|, L), i.e. the length of the el axis of dl_3D.
max(L - abs(spin), 0)

3