In [1]:
import numpy as np

import pyssht as ssht
import s2fft as s2f
import s2fft.samples as samples
import s2fft.quadrature as quadrature
import s2fft.resampling as resampling
import s2fft.wigner as wigner
import s2fft.healpix_ffts as hp
import healpy as hpy

import jax
from jax import jit, device_put
import jax.numpy as jnp
from jax.config import config

import matplotlib.pyplot as plt

config.update("jax_enable_x64", True)

In [2]:
## Sample data

# input params
L = 5  # 128 # in tests: 5
spin = 2  # 2 # in tests: [0, 1, 2]
sampling = "healpix"  #'dh' # in tests: ["mw", "mwss", "dh", "healpix"]
nside = 2 #[2,4,8] in tests, None if sampling not healpix

# generate spherical harmonics (ground truth)
# random---modify to use JAX random key approach?
DEFAULT_SEED = 8966433580120847635
rn_gen = np.random.default_rng(DEFAULT_SEED)


# compute signal in time domain with ssht (starting point)
# if healpix: we use the inverse from 'direct'? (or 'sov_fft_vectorized'?)
if sampling == 'healpix':
    flm_gt0 = s2f.utils.generate_flm(rn_gen, L, spin, reality=False) # shape L, 2L-1
    f = s2f.transform._inverse(flm_gt0, L, sampling=sampling, method='direct', nside=nside)
    flm_gt = hpy.sphtfunc.map2alm(np.real(f), lmax=L - 1, iter=0) #ground-truth
    flm_gt = s2f.samples.flm_hp_to_2d(flm_gt, L)
    
else:
    flm_gt = s2f.utils.generate_flm(rn_gen, L, spin, reality=False) # shape L, 2L-1
    f = ssht.inverse(
        s2f.samples.flm_2d_to_1d(flm_gt, L),  # 2D indexed coeffs to 1D indexed
        L,
        Method=sampling.upper(),
        Spin=spin,
        Reality=False,
    )

In [26]:
# Using SOV + FFT Vectorised
method_str = "sov_fft_vectorized"
flm_sov_fft_vec = s2f.transform._forward(f, L, spin, sampling, method=method_str, nside=nside)

print(np.allclose(flm_gt, flm_sov_fft_vec, atol=1e-14)) #returns False for healpix

# %timeit s2f.transform._forward(f, L, spin, sampling, method=method_str)

False


In [27]:
# Using SOV + FFT Vectorised JAXed
method_str = "sov_fft_vectorized_jax"
flm_sov_fft_vec_jax = s2f.transform._forward(f, L, spin, sampling, method=method_str, nside=nside) # shape L, 2L-1

print(np.allclose(flm_gt, flm_sov_fft_vec_jax, atol=1e-14))
# print(np.allclose(flm_sov_fft_vec, flm_sov_fft_vec_jax, atol=1e-14)) 
# ---returns True if replacing `wigner.turok.compute_slice(theta, el, L, -spin)` with 
# `np.array(wigner.turok_jax.compute_slice(theta, el, L, -spin))` in _compute_forward_sov_fft_vectorized

# %timeit s2f.transform._forward(f, L, spin, sampling, method=method_str)


False


In [3]:
# Using SOV + FFT Vectorised JAXed --boiler plate
# method_str = "sov_fft_vectorized_jax"
if sampling.lower() == "mw":
    f = resampling.mw_to_mwss(f, L, spin)

if sampling.lower() in ["mw", "mwss"]:
    sampling = "mwss"
    f = resampling.upsample_by_two_mwss(f, L, spin)
    thetas = samples.thetas(2 * L, sampling)
else:
    thetas = samples.thetas(L, sampling,nside=nside)

weights = quadrature.quad_weights_transform(L, sampling, spin, nside=nside)

# nside=None
# flm_sov_fft_vec_2 = s2f.transform._compute_forward_sov_fft_vectorized_jax(
#     f, L, spin, sampling, thetas, weights, nside=None
# )  # , method=method_str)

# print(np.allclose(flm_gt, flm_sov_fft_vec_2, atol=1e-14))

# %timeit s2f.transform._compute_forward_sov_fft_vectorized_jax(f, L, spin, sampling, thetas, weights, nside=None) #, method=method_str)

In [4]:
# Phase shift array per theta
phase_shift_per_theta = np.zeros((len(thetas), 2*L-1),dtype=complex)
if sampling.lower() == "healpix":
    for t, theta in enumerate(thetas):
        print(t)
        phase_shift_per_theta[t,:] = samples.ring_phase_shift_hp(L, t, nside, forward=True)
                
phase_shift_per_theta.shape

0
1
2
3
4
5
6


(7, 9)

In [5]:
# Phase shift array per theta
fn_aux = lambda L, t, nside, forward: samples.ring_phase_shift_hp(L, t, nside, forward=forward)

phase_shift_per_theta2 = np.zeros((len(thetas), 2*L-1),dtype=complex)
if sampling.lower() == "healpix":
    for t, theta in enumerate(thetas):
        phase_shift_per_theta2[t,:] = fn_aux(L, t, nside, True)
                
phase_shift_per_theta2.shape
np.allclose(phase_shift_per_theta,phase_shift_per_theta2)

True

In [6]:
# Compute phase shift vmapped
if sampling.lower() == "healpix":
    phase_shift_vmapped = jax.vmap(fn_aux, in_axes=(None,0,None,None), out_axes=-1) #samples.ring_phase_shift_hp,in_axes=((None,0,None,None)))
    # phase_shift_vmapped(L, jnp.array(range(len(thetas))), nside, forward=True)
    # phase_shift = jnp.expand_dims(phase_shift_vmapped(L, thetas, nside, forward=True), axis=(0,-1))  
else:
    phase_shift = 1.0    

In [7]:
phase_shift_vmapped(L, jnp.array(range(len(thetas))), nside, True) 
# https://github.com/google/jax/issues/7465

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/0)> with
  val = DeviceArray([False,  True,  True,  True,  True,  True, False], dtype=bool)
  batch_dim = 0
The problem arose with the `bool` function. 
This Tracer was created on line /Users/sofia/Documents_local/SAX project/s2fft/s2fft/samples.py:350 (p2phi_ring)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [9]:
# Compute dl_vmapped fn
dl_vmapped = jax.vmap(
    jax.vmap(
        wigner.turok_jax.compute_slice,
        in_axes=(0, None, None, None),
        out_axes=-1,
    ),
    in_axes=(None, 0, None, None),
    out_axes=0,
)
