In [6]:
import numpy as np

import pyssht as ssht
import s2fft as s2f
import s2fft.samples as samples
import s2fft.quadrature as quadrature
import s2fft.resampling as resampling

import jax
from jax import jit, device_put
import jax.numpy as jnp
from jax.config import config

import matplotlib.pyplot as plt

# Turn on 64-bit floats/complex numbers in JAX; without this JAX computes in
# 32-bit and could not meet the atol=1e-14 comparisons used below.
jax.config.update("jax_enable_x64", True)

In [7]:
## Sample data

# Transform parameters. The s2fft tests use L = 5, spin in [0, 1, 2] and
# sampling in ["mw", "mwss", "dh"]; L = 128 and sampling = "dh" are
# alternative settings that were also tried here.
L = 5
spin = 2
sampling = "mwss"

# Ground-truth harmonic coefficients drawn from a seeded NumPy generator, so a
# fresh run always produces the same flm.
# TODO: consider switching to a JAX PRNG key.
DEFAULT_SEED = 8966433580120847635
rn_gen = np.random.default_rng(DEFAULT_SEED)
flm_gt = s2f.utils.generate_flm(rn_gen, L, spin, reality=False)

# Synthesise the signal on the sphere (pixel space) with pyssht; this is the
# starting point for the forward transforms below. pyssht takes 1D-indexed
# coefficients and an upper-case sampling name (here "MWSS").
flm_gt_1d = samples.flm_2d_to_1d(flm_gt, L)
f = ssht.inverse(flm_gt_1d, L, Method=sampling.upper(), Spin=spin, Reality=False)

In [8]:
# Forward transform via separation of variables (SOV) + FFT, vectorised
# (non-JAX) implementation.
method_str = "sov_fft_vectorized"
flm_sov_fft_vec = s2f.transform._forward(f, L, spin, sampling, method=method_str)

# Check against the ground-truth coefficients at near machine precision.
sov_fft_vec_ok = np.allclose(flm_gt, flm_sov_fft_vec, atol=1e-14)
print(sov_fft_vec_ok)

# Benchmark (disabled):
# %timeit s2f.transform._forward(f, L, spin, sampling, method=method_str)

True


In [9]:
# Same SOV + FFT forward transform, routed through the JAX implementation.
method_str = "sov_fft_vectorized_jax"
flm_sov_fft_vec_jax = s2f.transform._forward(f, L, spin, sampling, method=method_str)

sov_fft_vec_jax_ok = np.allclose(flm_gt, flm_sov_fft_vec_jax, atol=1e-14)
print(sov_fft_vec_jax_ok)

# NOTE(review): the check above printed False in the saved run. According to the
# author's note, comparing against the NumPy result (line below) returns True
# only if `_compute_forward_sov_fft_vectorized` is patched to use
# `np.array(wigner.turok_jax.compute_slice(theta, el, L, -spin))` instead of
# `wigner.turok.compute_slice(theta, el, L, -spin)`. That suggests the mismatch
# comes from the Wigner-d slice computation — TODO confirm.
# print(np.allclose(flm_sov_fft_vec, flm_sov_fft_vec_jax, atol=1e-14))

# Benchmark (disabled):
# %timeit s2f.transform._forward(f, L, spin, sampling, method=method_str)

False


In [None]:
# Call the JAX SOV + FFT kernel directly, doing the resampling / quadrature
# pre-processing by hand. This presumably mirrors what `s2f.transform._forward`
# does internally — verify against the library.
#
# Work on cell-local copies (`f_ext`, `sampling_ext`) instead of reassigning the
# notebook-wide `f` and `sampling`. Otherwise re-running this cell would
# upsample an already-upsampled signal, and earlier cells would see the wrong
# array shape and sampling scheme.
f_ext = f
sampling_ext = sampling

# MW is first resampled onto the MWSS grid.
if sampling_ext.lower() == "mw":
    f_ext = resampling.mw_to_mwss(f_ext, L, spin)

# MW/MWSS signals are then upsampled by two in theta, so the theta grid is
# built for 2 * L; other schemes (e.g. DH) use their native L-grid.
if sampling_ext.lower() in ["mw", "mwss"]:
    sampling_ext = "mwss"
    f_ext = resampling.upsample_by_two_mwss(f_ext, L, spin)
    thetas = samples.thetas(2 * L, sampling_ext)
else:
    thetas = samples.thetas(L, sampling_ext)

# Quadrature weights are computed with spin 0 for the (possibly updated) scheme.
weights = quadrature.quad_weights_transform(L, sampling_ext, 0, nside=None)

flm_sov_fft_vec_2 = s2f.transform._compute_forward_sov_fft_vectorized_jax(
    f_ext, L, spin, sampling_ext, thetas, weights, nside=None
)

# Validate the coefficients produced by THIS cell against the ground truth.
print(np.allclose(flm_gt, flm_sov_fft_vec_2, atol=1e-14))

# Benchmark (disabled):
# %timeit s2f.transform._compute_forward_sov_fft_vectorized_jax(f_ext, L, spin, sampling_ext, thetas, weights, nside=None)

In [None]:
# NOTE(review): this whole cell is commented out and would fail if re-enabled as-is:
#   - it assigns `flm_jax` but then compares `flm_jax_1`;
#   - `flm_sov_fft_vec_jax_turok` is not defined in any of the cells visible here;
#   - confirm that `s2f.transform.forward_sov_fft_vectorized_jax_1`,
#     `forward_sov_fft_vectorized` and `forward_sov_fft_vectorized_jax_turok` still
#     exist in s2fft before uncommenting.
# Either fix and re-enable this comparison, or delete the cell.
# # JAX approach 1 (vectorised jaxifying)
# flm_jax = s2f.transform.forward_sov_fft_vectorized_jax_1(
#     f, L, spin, sampling
# )
# # compare to alternative computation of flm: flm_sov_fft_vec_jax_turok
# print(np.allclose(flm_jax_1,
#                   flm_sov_fft_vec_jax_turok, #w/o padding: flm_sov_fft_vec_jax_turok[spin:L, :],
#                   atol=1e-14))

# # %timeit s2f.transform.forward_sov_fft_vectorized_jax_turok(f, L, spin, sampling)
# %timeit s2f.transform.forward_sov_fft_vectorized(f, L, spin, sampling)
# %timeit s2f.transform.forward_sov_fft_vectorized_jax_1(f, L, spin, sampling).block_until_ready()