# __JAX HEALPix frontend__
---

[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_HEALPix_frontend.ipynb)

In [1]:
# import sys
# IN_COLAB = 'google.colab' in sys.modules

# # Install s2fft and data if running on google colab.
# if IN_COLAB:
#     !pip install s2fft &> /dev/null

This short tutorial demonstrates how to use the custom JAX frontend that `S2FFT` provides for the [`HEALPix`](https://healpix.jpl.nasa.gov) C++ library. Calling the C++ routines directly avoids the long JIT compile times otherwise incurred for HEALPix transforms when running on CPU.

As with the other introductions, let's import some packages and define an arbitrary bandlimited signal to work with.

In [10]:
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import s2fft 

L = 128
nside = 64
method = "jax_healpy"
sampling = "healpix"
rng = np.random.default_rng(23457801234570)
# spherical harmonic coefficients (flm)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)

In [26]:
flm.shape

(100, 199)

### Calling forward HEALPix C++ function from JAX.

---

In [11]:
flm = s2fft.forward(f, L, nside=nside, sampling=sampling, method=method)

### Calling inverse HEALPix C++ function from JAX.

---

In [12]:
f_recov = s2fft.inverse(flm, L, nside=nside, sampling=sampling,  method=method)

### Computing the roundtrip error

---

Let's check the associated error, which should be around 1e-5. `HEALPix` is not an exact sampling of the sphere, so the roundtrip is not expected to reach machine precision. Note that increasing `iters` will reduce the numerical error slightly, at the cost of linearly increased compute.

In [6]:
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")

Mean absolute error = 4.970630544067721e-05
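
The effect of iterating can be illustrated with a toy example. The sketch below is not the `HEALPix` or `S2FFT` implementation: `synth` and `approx_analysis` are hypothetical stand-ins for an exact inverse transform and an inexact forward transform. It applies a residual-correction update of the kind iterative analysis schemes rely on. Each pass costs one extra synthesis and one extra analysis, which is why compute grows linearly with the number of iterations.

```python
# Toy residual-correction iteration (illustrative only, not the HEALPix algorithm).
# synth and approx_analysis are hypothetical stand-ins for the inverse and forward transforms.

def synth(coeffs):
    """Exact 'inverse transform': here simply a per-element scaling by 2."""
    return [2.0 * c for c in coeffs]

def approx_analysis(signal):
    """Inexact 'forward transform': scales by 0.45 instead of the true inverse 0.5."""
    return [0.45 * s for s in signal]

def iterative_analysis(signal, iters):
    """Refine the coefficient estimate by repeatedly correcting the residual."""
    estimate = approx_analysis(signal)
    for _ in range(iters):
        # Residual between the input signal and the resynthesised estimate
        residual = [s - r for s, r in zip(signal, synth(estimate))]
        correction = approx_analysis(residual)
        estimate = [e + c for e, c in zip(estimate, correction)]
    return estimate

true_coeffs = [1.0, -2.0, 0.5]
signal = synth(true_coeffs)

def max_error(iters):
    """Largest absolute coefficient error after the given number of iterations."""
    estimate = iterative_analysis(signal, iters)
    return max(abs(e - t) for e, t in zip(estimate, true_coeffs))

errors = [max_error(k) for k in range(4)]
print(errors)
```

In this toy setting the error shrinks by a constant factor per iteration. The real transforms behave less cleanly, so the gain per iteration there is smaller.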


### Differentiating through HEALPix C++ functions.

---

So far, all this does is provide an interface between `JAX` and `HEALPix`. The real novelty comes when we differentiate through the C++ library.

In [7]:
# Define an arbitrary JAX function
def differentiable_test(flm) -> jax.Array:
    f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
    return jax.numpy.nanmean(jax.numpy.abs(f)**2)

# Create the JAX reverse mode gradient function
gradient_func = jax.grad(differentiable_test)

# Compute the gradient automatically
gradient = gradient_func(flm)

### Validating these gradients

---
This is all well and good, but how do we know these gradients are correct? Thankfully, `JAX` provides a simple function to check this.

In [9]:
from jax.test_util import check_grads
# modes expects a sequence of mode names, so pass a one-element tuple
check_grads(differentiable_test, (flm,), order=1, modes=("rev",))



None
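
`check_grads` compares the automatically computed derivatives against numerical finite-difference estimates; it completes silently when they agree and raises an error otherwise. The idea can be sketched in plain Python on a toy function. This is a simplified illustration with a hand-written gradient standing in for automatic differentiation, not the routine `JAX` runs internally.

```python
# Finite-difference gradient check on a toy function (illustrative only).

def loss(x):
    """Mean of squared entries, a real-valued analogue of the function above."""
    return sum(v * v for v in x) / len(x)

def analytic_grad(x):
    """Exact gradient of `loss`: d/dx_i mean(x^2) = 2 * x_i / n."""
    return [2.0 * v / len(x) for v in x]

def numerical_grad(func, x, eps=1e-6):
    """Central finite differences, perturbing one coordinate at a time."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        up[i] += eps
        down = list(x)
        down[i] -= eps
        grad.append((func(up) - func(down)) / (2.0 * eps))
    return grad

x = [0.3, -1.2, 2.5]
max_diff = max(
    abs(a - n) for a, n in zip(analytic_grad(x), numerical_grad(loss, x))
)
print(f"Largest gradient mismatch = {max_diff}")
```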


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Show plots inline in the notebook
%matplotlib inline


# Make sure we configure 64 bit precision. 
# 32 bit can be faster but you will be (potentially much) less precise.
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import healpy as hp
import s2wav       # Wavelet transforms on the sphere and rotation group
import s2fft       # Spherical harmonic and Wigner transforms
PLA_Data = pd.read_csv("data/planck_simulation/PLA-Results.csv")

# Store the names of the datasets into PLA_Data_List
PLA_Data_List = PLA_Data['SIMULATED_MAP.FILE_ID'].to_list()



# Initialize PLA_Data_Dict: (key: frequency, value: [csv path])
PLA_Data_Dict = dict()
for each_csv_path in PLA_Data_List:
    PLA_Data_Dict[each_csv_path[20:23]] = ["data/planck_simulation/"+each_csv_path]
# display(PLA_Data_Dict)

# Read the CMB datasets and store them inside PLA_Data_Dict
# Update PLA_Data_Dict: (key: frequency, value: [csv path, hp map data])
 
for frequency, storage_list in PLA_Data_Dict.items():
    storage_list.append(hp.read_map(storage_list[0], dtype=jnp.float64))
# display(PLA_Data_Dict)

# Convert the units from MJy/sr to K_CMB
# Source: https://wiki.cosmos.esa.int/planckpla2015/index.php/UC_CC_Tables

PLA_Data_Dict["545"][1] = PLA_Data_Dict["545"][1]/58.0356
PLA_Data_Dict["857"][1] = PLA_Data_Dict["857"][1]/2.2681
# PLA_Data_Dict["857"][1]
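
The unit conversion at the end of the cell above can be kept in one place with a small helper. The sketch below is plain Python and reuses the two factors quoted in that cell. The names `MJY_SR_PER_K_CMB` and `to_k_cmb` are hypothetical. Passing other channels through unchanged rests on the assumption that they are already in K_CMB.

```python
# Conversion factors quoted in the cell above (MJy/sr per K_CMB).
MJY_SR_PER_K_CMB = {"545": 58.0356, "857": 2.2681}

def to_k_cmb(values, frequency):
    """Convert values from MJy/sr to K_CMB for the given frequency key.

    Channels without a listed factor are assumed to be in K_CMB already
    and are returned unchanged.
    """
    factor = MJY_SR_PER_K_CMB.get(frequency)
    if factor is None:
        return list(values)
    return [v / factor for v in values]

converted = to_k_cmb([58.0356, 116.0712], "545")
print(converted)
```

The same division applies element-wise to a full map array, as in the cell above.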

In [15]:
L = 3998
nside = 1024
method = "jax_healpy"
sampling = "healpix"

# rng = np.random.default_rng(23457801234570)
# # spherical harmonic coefficients (flm)
# flm = s2fft.utils.signal_generator.generate_flm(rng, L)
# f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)

In [1]:

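# Perform spherical harmonic analysis with healpy
lmax = 100  # Maximum degree of spherical harmonics (inclusive in healpy)
nside = hp.get_nside(PLA_Data_Dict["030"][1])  # Resolution of HEALPix grid
# map2alm computes the spherical harmonic coefficients as a 1D array
alm = hp.map2alm(PLA_Data_Dict["030"][1], lmax=lmax)
method = "jax_healpy"
sampling = "healpix"
# s2fft expects a 2D (L, 2L-1) array with L = lmax + 1, so convert the 1D healpy array first
flm_2d = s2fft.sampling.s2_samples.flm_hp_to_2d(alm, lmax + 1)
f = s2fft.inverse(flm_2d, lmax + 1, nside=nside, sampling=sampling, method=method)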


In [22]:
f.shape

(12582912,)
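
The length of a HEALPix map is fixed by its resolution: a map with parameter `nside` has `12 * nside**2` pixels, which is where the shape above comes from for `nside = 1024`. A minimal sketch of the relation in plain Python follows; the function names are hypothetical, and `healpy` offers `hp.nside2npix` and `hp.npix2nside` for the same purpose.

```python
import math

def nside_to_npix(nside):
    """Number of pixels in a HEALPix map of the given nside."""
    return 12 * nside * nside

def npix_to_nside(npix):
    """Recover nside from a pixel count, rejecting invalid map lengths."""
    nside = math.isqrt(npix // 12)
    if nside_to_npix(nside) != npix:
        raise ValueError(f"{npix} is not a valid HEALPix map length")
    return nside

print(nside_to_npix(1024))
print(npix_to_nside(12582912))
```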

In [18]:
f = PLA_Data_Dict["030"][1]

In [27]:
alm.shape

(5151,)
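
The length of this 1D array follows from how `healpy` stores coefficients: only the orders `m >= 0` are kept, for all degrees up to and including `lmax`, giving `(lmax + 1) * (lmax + 2) / 2` entries. The sketch below reproduces that count and the position of a given `(l, m)` pair in plain Python. The function names are hypothetical, and the index formula mirrors `hp.Alm.getidx`.

```python
def healpy_alm_size(lmax):
    """Number of stored coefficients when mmax equals lmax."""
    return (lmax + 1) * (lmax + 2) // 2

def healpy_alm_index(lmax, l, m):
    """Position of coefficient (l, m) in the 1D healpy ordering (m-major)."""
    if not (0 <= m <= l <= lmax):
        raise ValueError("require 0 <= m <= l <= lmax")
    return m * (2 * lmax + 1 - m) // 2 + l

print(healpy_alm_size(100))
print(healpy_alm_index(100, 100, 100))
```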

In [23]:
flm = s2fft.forward(PLA_Data_Dict["030"][1], lmax, nside=nside, sampling=sampling, method=method)
flm.shape

(100, 199)
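
In contrast to the 1D `healpy` ordering, `S2FFT` stores coefficients in a 2D array of shape `(L, 2L - 1)`. Row `el` holds degree `el`, and order `m` sits at column `L - 1 + m`, so only the columns from `L - 1 - el` to `L - 1 + el` can be non-zero in a given row. A minimal sketch of this layout in plain Python, with hypothetical helper names:

```python
def flm_shape(L):
    """Shape of the 2D coefficient array for bandlimit L."""
    return (L, 2 * L - 1)

def flm_index(L, el, m):
    """Row and column of coefficient (el, m) in the 2D layout."""
    if not (0 <= el < L and -el <= m <= el):
        raise ValueError("require 0 <= el < L and |m| <= el")
    return (el, L - 1 + m)

def populated_columns(L, el):
    """Inclusive column range that can hold non-zero values in row el."""
    return (L - 1 - el, L - 1 + el)

print(flm_shape(100))
print(populated_columns(100, 1))
```

For `L = 100` and `el = 1`, only columns 98 to 100 are populated. The leading entries of that row are therefore zero by construction.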

In [29]:
flm[1]

Array([ 0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
        0.00000000e+00+0.0000000e+00j,  0.00000000e+00+0.0000000e+00j,
      

In [24]:
f_recov = s2fft.inverse(flm, lmax, nside=nside, sampling=sampling,  method=method)
f_recov.shape

(12582912,)

In [21]:
print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")

Mean absolute error = 0.00022402470120383457
