# __S2FFT CUDA Implementation__
---

[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_HEALPix_frontend.ipynb)

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install s2fft and data if running on google colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null

In [12]:
!pip install healpy matplotlib seaborn &> /dev/null

Short comparaison between the pure JAX implementation and the CUDA implementation of the S2FFT algorithm.

In [1]:
import jax
from jax import numpy as jnp
import argparse
import time
from time import perf_counter
import matplotlib.pyplot as plt
import seaborn as sns

jax.config.update("jax_enable_x64", True)

from s2fft.utils.healpix_ffts import  healpix_fft_jax, healpix_ifft_jax, healpix_fft_cuda, healpix_ifft_cuda
from s2fft.sampling.reindex import flm_2d_to_hp_fast, flm_hp_to_2d_fast
import numpy as np
import s2fft 
from s2fft import forward , inverse
import healpy as hp
import numpy as np


### Initial Setup and Forward Transform Comparison

This section sets up the HEALPix parameters and performs a forward spherical harmonic transform using `s2fft`'s JAX CUDA implementation, comparing the results with `healpy`.

In [2]:
# Set up
nside = 16
npix = hp.nside2npix(nside)
map_random = jax.random.normal(jax.random.key(0) , shape=npix)

# Compute alms (spherical harmonic coefficients)
lmax = 3 * nside - 1
L = lmax + 1  # So S2FFT covers ell=0 to lmax inclusive

# healpy alms
alms_healpy = hp.map2alm(np.array(map_random), lmax=lmax , iter=3)
alm_healpy_2d = flm_hp_to_2d_fast(alms_healpy, L=L)

j_alms = forward(map_random, nside=nside, L=L, sampling='healpix' , method='jax_cuda' , iter=3 )
healpix_order_alms = flm_2d_to_hp_fast(j_alms, L=L)
print(f"shape of j_alms: {j_alms.shape}")
print(f"shape of healpix_order_alms: {healpix_order_alms.shape}")


print(f"MSE between j_alms and alms_healpy: {jnp.mean((healpix_order_alms - alms_healpy) ** 2)}")

shape of j_alms: (48, 95)
shape of healpix_order_alms: (1176,)
MSE between j_alms and alms_healpy: (-3.690730140133011e-30+3.982002422466866e-31j)


### VMAP and JAX Transforms Test

This cell demonstrates the use of `jax.vmap` with the forward transform and tests JAX's automatic differentiation capabilities (`jacfwd`, `jacrev`) with the CUDA implementation.

In [3]:
# Set up
nside = 16
npix = hp.nside2npix(nside)
map_random = jax.random.normal(jax.random.key(0) , shape=npix)
# Compute alms (spherical harmonic coefficients)
lmax = 3 * nside - 1
L = lmax + 1  # So S2FFT covers ell=0 to lmax inclusive

maps = jnp.stack([map_random, map_random, map_random , map_random], axis=0)
print(f"Shape of maps: {maps.shape}")

def forward_maps(maps):
    return forward(maps, nside=nside, L=L, sampling='healpix', method='jax_cuda').real

alm_maps = jax.vmap(forward_maps)(maps)

Shape of maps: (4, 3072)


### Inverse Transform Comparison

This cell performs an inverse spherical harmonic transform and compares the reconstructed map from `s2fft`'s JAX CUDA implementation with `healpy`'s reconstruction.

In [4]:
reconstruction_healpy = hp.alm2map(alms_healpy, nside=nside, lmax=lmax)
reconstruction_jax = inverse(j_alms, nside=nside, L=L, sampling='healpix', method='jax_cuda')

print(f"MSE between reconstruction_healpy and reconstruction_jax: {jnp.mean((reconstruction_healpy - reconstruction_jax) ** 2)}")

MSE between reconstruction_healpy and reconstruction_jax: (1.8236620334440454e-27-8.008792862185043e-31j)


### Performance Benchmarking Functions

This section defines helper functions to benchmark the forward and backward spherical harmonic transforms across different `nside` values, comparing `s2fft`'s JAX CUDA, pure JAX, and `healpy` implementations.

In [17]:
sampling = "healpix"
n_iter = 3  # Number of iterations for the forward and inverse transforms

def mse(x, y):
    return jnp.mean(jnp.abs(x - y)**2)


def run_fwd_test(nside):
    L = 2 * nside 

    total_pixels = 12 * nside**2
    arr = jax.random.normal(jax.random.PRNGKey(0), (total_pixels, ))

    method = "jax_cuda"
    start = time.perf_counter()
    cuda_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()
    end = time.perf_counter()
    cuda_jit_time = end - start

    start = time.perf_counter()
    cuda_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()
    end = time.perf_counter()
    cuda_run_time = end - start

    method = "jax"
    start = time.perf_counter()
    jax_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()
    end = time.perf_counter()
    jax_jit_time = end - start

    start = time.perf_counter()
    jax_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()
    end = time.perf_counter()
    jax_run_time = end - start

    method = "jax_healpy"
    arr += 0j
    arr = jax.device_put(arr, jax.devices("cpu")[0])
    start = time.perf_counter()
    flm = s2fft.forward(arr, L, nside=nside, sampling=sampling, method=method, iter=n_iter ).block_until_ready()
    end = time.perf_counter()
    healpy_jit_time = end - start

    start = time.perf_counter()
    flm = s2fft.forward(arr, L, nside=nside, sampling=sampling, method=method, iter=n_iter ).block_until_ready()
    end = perf_counter()
    healpy_run_time = end - start

    print(f"For nside {nside}")
    print(f" -> FWD")
    print(f" -> -> cuda_jit_time: {cuda_jit_time:.4f}, cuda_run_time: {cuda_run_time:.4f} mse against hp {mse(cuda_res, flm)}")
    print(f" -> -> jax_jit_time: {jax_jit_time:.4f}, jax_run_time: {jax_run_time:.4f} mse against hp {mse(cuda_res, flm)}")
    print(f" -> -> healpy_jit_time: {healpy_jit_time:.4f}, healpy_run_time: {healpy_run_time:.4f}")

    return cuda_jit_time , cuda_run_time, jax_jit_time, jax_run_time , healpy_jit_time, healpy_run_time


def run_bwd_test(nside):
    
    sampling = "healpix"
    L = 2 * nside
    total_pixels = 12 * nside**2
    arr = jax.random.normal(jax.random.PRNGKey(0), (total_pixels, )) + 0j
    alm = forward(arr, L, nside=nside, sampling=sampling, method="jax_healpy")
    
    method = "jax"
    start = time.perf_counter()
    jax_res = inverse(alm, L, nside=nside,sampling=sampling, method=method).block_until_ready()
    end = time.perf_counter()
    jax_jit_time = end - start
    start = time.perf_counter()
    jax_res = inverse(alm, L, nside=nside,sampling=sampling, method=method ).block_until_ready()
    end = time.perf_counter()
    jax_run_time = end - start
    
    method = "jax_cuda"
    start = time.perf_counter()
    cuda_res = inverse(alm, L, nside=nside,sampling=sampling, method=method ).block_until_ready()
    end = time.perf_counter()
    cuda_jit_time = end - start
    start = time.perf_counter()
    cuda_res = inverse(alm, L, nside=nside,sampling=sampling, method=method ).block_until_ready()
    end = time.perf_counter()
    cuda_run_time = end - start


    method = "jax_healpy"
    sampling = "healpix"

    alm = jax.device_put(alm, jax.devices("cpu")[0])
    start = time.perf_counter()
    f = inverse(alm, L, nside=nside, sampling=sampling, method=method).block_until_ready()
    end = time.perf_counter()
    healpy_jit_time = end - start

    start = time.perf_counter()
    f = inverse(alm, L, nside=nside, sampling=sampling, method=method ).block_until_ready()
    end = time.perf_counter()
    healpy_run_time = end - start

    print(f"For nside {nside}")
    print(f" -> BWD")
    print(f" -> -> cuda_jit_time: {cuda_jit_time:.4f}, cuda_run_time: {cuda_run_time:.4f} mse against hp {mse(cuda_res, f)}")
    print(f" -> -> jax_jit_time: {jax_jit_time:.4f}, jax_run_time: {jax_run_time:.4f} mse against hp {mse(jax_res, f)}")
    print(f" -> -> healpy_jit_time: {healpy_jit_time:.4f}, healpy_run_time: {healpy_run_time:.4f} ")

    return cuda_jit_time , cuda_run_time, jax_jit_time, jax_run_time , healpy_jit_time, healpy_run_time

### Clear JAX Caches

Clears JAX's internal caches to ensure fresh compilation for benchmarking.

In [18]:
jax.clear_caches()

### Run Benchmarking

Executes the benchmarking functions for various `nside` values to collect performance data.

In [None]:
fwd_times = []
bwd_times = []
nsides = [4 , 8 , 16 , 32 , 64 , 128 , 256 ]
for nside in nsides:
    fwd_times.append(run_fwd_test(nside))
    bwd_times.append(run_bwd_test(nside))

For nside 128
 -> FWD
 -> -> cuda_jit_time: 4.4200, cuda_run_time: 0.6231 mse against hp 2.3766630166715178e-29
 -> -> jax_jit_time: 38.6306, jax_run_time: 0.6253 mse against hp 2.3766630166715178e-29
 -> -> healpy_jit_time: 0.8766, healpy_run_time: 0.4540
For nside 128
 -> BWD
 -> -> cuda_jit_time: 1.3143, cuda_run_time: 0.0907 mse against hp 2.5339123457221976e-25
 -> -> jax_jit_time: 15.6730, jax_run_time: 0.1263 mse against hp 2.5339096506006936e-25
 -> -> healpy_jit_time: 0.0512, healpy_run_time: 0.0041 
For nside 256
 -> FWD
 -> -> cuda_jit_time: 8.7759, cuda_run_time: 4.6370 mse against hp 4.332503429570958e-10
 -> -> jax_jit_time: 88.8303, jax_run_time: 4.6417 mse against hp 4.332503429570958e-10
 -> -> healpy_jit_time: 2.5950, healpy_run_time: 1.7487


KeyboardInterrupt: 

: 

### Plotting Utility

This cell defines a utility function to plot the compilation and execution times obtained from the benchmarking tests.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.plotting_context("poster")
sns.set(font_scale=1.4)


def plot_times(title, nsides, chrono_times):

    # Extracting times from the chrono_times
    cuda_jit_times = [times[0] for times in chrono_times]
    cuda_run_times = [times[1] for times in chrono_times]
    jax_jit_times = [times[2] for times in chrono_times]
    jax_run_times = [times[3] for times in chrono_times]
    healpy_jit_times = [times[4] for times in chrono_times]
    healpy_run_times = [times[5] for times in chrono_times]

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))

    f2 = lambda a: np.log2(a)
    g2 = lambda b: b**2


    # Plot for JIT times
    ax1.plot(nsides, cuda_jit_times, 'g-o', label='ours')
    ax1.plot(nsides, jax_jit_times, 'b-o', label='s2fft base')
    ax1.plot(nsides, healpy_jit_times, 'r-o', label='Healpy')
    ax1.set_title('Compilation Times (first run)')
    ax1.set_xlabel('nside')
    ax1.set_ylabel('Time (seconds)')
    ax1.set_xscale('function', functions=(f2, g2))
    ax1.set_xticks(nsides)
    ax1.set_xticklabels(nsides)
    ax1.legend()
    ax1.grid(True, which="both", ls="--")

    # Plot for Run times
    ax2.plot(nsides, cuda_run_times, 'g-o', label='ours')
    ax2.plot(nsides, jax_run_times, 'b-o', label='s2fft base')
    ax2.plot(nsides, healpy_run_times, 'r-o', label='Healpy')
    ax2.set_title('Execution Times')
    ax2.set_xlabel('nside')
    ax2.set_ylabel('Time (seconds)')
    ax2.set_xscale('function', functions=(f2, g2))
    ax2.set_xticks(nsides)
    ax2.set_xticklabels(nsides)
    ax2.legend()
    ax2.grid(True, which="both", ls="--")

    # Set the overall title for the figure
    fig.suptitle(title, fontsize=16)

    # Show the plots
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust rect to make space for the suptitle
    plt.show()

### Visualize Performance Results

This cell calls the plotting function to visualize the benchmark results for forward and backward transforms.

In [None]:
plot_times("Forward FFT Times", nsides, fwd_times)
plot_times("Backward FFT Times", nsides, bwd_times)

### Final Reconstruction and Error Check

This cell performs a final inverse transform to reconstruct the map and calculates the Mean Squared Error (MSE) against the `healpy` reconstructed map to verify accuracy.

In [None]:
# Test backward transform
map_reconstructed = inverse(j_alms, nside=nside, L=L, sampling='healpix', method='jax_cuda')
print(f"shape of map_reconstructed: {map_reconstructed.shape}")
hp_reconstructed = hp.alm2map(alms_healpy, nside=nside, lmax=lmax)

# Compute the mean squared error between the two maps
mse = jnp.mean((map_reconstructed - hp_reconstructed) ** 2)
print(f"Mean Squared Error between reconstructed map and healpy map: {mse}")

shape of map_reconstructed: (3072,)
Mean Squared Error between reconstructed map and healpy map: (1.8236620334440454e-27-8.008792862185043e-31j)
