# S2FFT CUDA Implementation - Performance and JAX Compatibility

This notebook demonstrates the CUDA-accelerated HEALPix spherical harmonic transforms in S2FFT using the `forward()` and `inverse()` API.

[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_CUDA_HEALPix.ipynb)

In [12]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    !pip install s2fft healpy &> /dev/null

## Imports and Configuration

In [3]:
import jax
import jax.numpy as jnp
import healpy as hp
import s2fft
from s2fft import forward, inverse

jax.config.update("jax_enable_x64", True)

## Compilation Requirements

To use the CUDA implementation, you need:
- NVIDIA GPU with CUDA support
- CUDA Toolkit 12.0+ installed
- NVCC compiler in PATH (check with `!which nvcc`)

The package must be installed from source with:
```bash
pip install -e . --verbose
```
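
The toolchain prerequisites listed above can also be checked from Python before attempting the source build. The following is a minimal sketch using only the standard library; the helper names `find_nvcc` and `nvcc_version` are hypothetical and not part of S2FFT.

```python
import shutil
import subprocess


def find_nvcc():
    """Return the path to nvcc if it is on PATH, else None."""
    return shutil.which("nvcc")


def nvcc_version(nvcc_path):
    """Return the text printed by `nvcc --version`, or None if the call fails."""
    try:
        result = subprocess.run(
            [nvcc_path, "--version"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout.strip()


nvcc_path = find_nvcc()
if nvcc_path is None:
    print("nvcc not found on PATH; the CUDA extension cannot be built.")
else:
    print(f"nvcc found at {nvcc_path}")
    print(nvcc_version(nvcc_path))
```

The version text is printed as-is so that it can be compared by eye against the CUDA Toolkit 12.0+ requirement.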

## Setup Test Parameters

We use `nside=32` for the performance tests, with maximum harmonic degree `lmax = 3*nside - 1 = 95` and band limit `L = lmax + 1 = 96`.

In [5]:
nside = 32
npix = hp.nside2npix(nside)
lmax = 3 * nside - 1
L = lmax + 1

print(f"nside: {nside}")
print(f"lmax: {lmax}")
print(f"L (band limit): {L}")
print(f"Number of pixels: {npix}")

# Generate test maps
hp_maps = jnp.stack([jax.random.normal(jax.random.PRNGKey(i), shape=(npix,)) for i in range(2)], axis=0)
hp_map = hp_maps[0]
print(f"\nMaps shape: {hp_maps.shape}")

nside: 32
lmax: 95
L (band limit): 96
Number of pixels: 12288

Maps shape: (2, 12288)


## Forward Transform - JIT Compilation Time

The first call includes JIT compilation overhead. Compare the CUDA backend (`method='jax_cuda'`) with the pure JAX backend (`method='jax'`).

In [6]:
def forward_cuda(f):
    return forward(f, nside=nside, L=L, sampling='healpix', method='jax_cuda')

def forward_jax(f):
    return forward(f, nside=nside, L=L, sampling='healpix', method='jax')

print("CUDA Forward (with JIT compilation):")
%time alm_cuda = forward_cuda(hp_map).block_until_ready()

print("\nJAX Forward (with JIT compilation):")
%time alm_jax = forward_jax(hp_map).block_until_ready()

print(f"\nCUDA result shape: {alm_cuda.shape}")
print(f"JAX result shape: {alm_jax.shape}")

CUDA Forward (with JIT compilation):
CPU times: user 5.92 ms, sys: 8.95 ms, total: 14.9 ms
Wall time: 20.1 ms

JAX Forward (with JIT compilation):
CPU times: user 2.83 s, sys: 204 ms, total: 3.03 s
Wall time: 2.42 s

CUDA result shape: (96, 191)
JAX result shape: (96, 191)
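
The sizes reported above follow from closed-form relations: a HEALPix map has `12 * nside**2` pixels, and the coefficient array printed above has shape `(L, 2L - 1)`. The sketch below reproduces these numbers in plain Python. The helper names are hypothetical, and the column convention `L - 1 + m` for order `m` is an assumption about the array layout rather than something shown in this notebook.

```python
def healpix_sizes(nside):
    """Return the map and coefficient sizes implied by a HEALPix nside."""
    npix = 12 * nside**2          # number of HEALPix pixels
    lmax = 3 * nside - 1          # maximum harmonic degree used in this notebook
    L = lmax + 1                  # band limit
    alm_shape = (L, 2 * L - 1)    # one row per degree, one column per order m
    return {"npix": npix, "lmax": lmax, "L": L, "alm_shape": alm_shape}


def m_to_column(m, L):
    """Column holding order m, ASSUMING orders run from -(L-1) to L-1."""
    if abs(m) > L - 1:
        raise ValueError(f"|m| must be at most L - 1 = {L - 1}, got {m}")
    return L - 1 + m


sizes = healpix_sizes(32)
print(sizes)
# {'npix': 12288, 'lmax': 95, 'L': 96, 'alm_shape': (96, 191)}
```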


## Forward Transform - Execution Time

With both functions already compiled, `%timeit` now measures steady-state execution time only.

In [7]:
print("CUDA Forward (execution only):")
%timeit forward_cuda(hp_map).block_until_ready()

print("\nJAX Forward (execution only):")
%timeit forward_jax(hp_map).block_until_ready()

CUDA Forward (execution only):
9.08 ms ± 45.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

JAX Forward (execution only):
9.16 ms ± 31.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Why is CUDA JIT Faster?

The CUDA implementation has **faster JIT compilation** because:
1. Core FFT operations use the pre-compiled cuFFT library
2. Custom CUDA kernels are compiled ahead of time with nvcc
3. Less XLA optimization is needed than for pure JAX

The pure JAX implementation must compile everything through XLA at runtime.

## Forward Transform - Accuracy

Verify that the CUDA and JAX backends agree to within floating-point precision.

In [8]:
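# Note: the coefficients are complex, so this squares the complex difference itself;
# use jnp.abs(...) ** 2 instead for a real-valued MSE.
mse_forward = jnp.mean((alm_cuda - alm_jax) ** 2)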
print(f"Forward MSE: {mse_forward}")
print(f"Max absolute difference: {jnp.max(jnp.abs(alm_cuda - alm_jax))}")
assert mse_forward < 1e-14, "Forward transform accuracy check failed!"
print("✓ Forward transform accuracy verified")

Forward MSE: (2.116946123121528e-37-6.195930970282342e-39j)
Max absolute difference: 2.8609792490763984e-17
✓ Forward transform accuracy verified
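
Because the coefficients are complex, squaring the raw difference yields a complex number rather than an error magnitude, which is why the MSE printed above carries an imaginary part. A minimal NumPy sketch of a magnitude-based metric follows; `complex_mse` is a hypothetical helper, not part of S2FFT, and the same expression works with `jnp` in place of `np`.

```python
import numpy as np


def complex_mse(a, b):
    """Mean squared magnitude of a - b; real and non-negative by construction."""
    diff = np.asarray(a) - np.asarray(b)
    return float(np.mean(np.abs(diff) ** 2))


a = np.array([1.0 + 1.0j, 2.0 + 0.0j])
b = np.array([1.0 + 0.0j, 2.0 + 1.0j])

naive = np.mean((a - b) ** 2)   # squares the complex difference itself
robust = complex_mse(a, b)      # squares its magnitude

print(f"naive:  {naive}")       # real part is negative for this input
print(f"robust: {robust}")
```

For this input the naive value has a negative real part, so a check such as `naive < tol` would pass even though the two arrays differ by one unit in every element.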


## Inverse Transform

Test the inverse (synthesis) transform, timing both the first call (with JIT compilation) and steady-state execution.

In [9]:
def inverse_cuda(flm):
    return inverse(flm, nside=nside, L=L, sampling='healpix', method='jax_cuda')

def inverse_jax(flm):
    return inverse(flm, nside=nside, L=L, sampling='healpix', method='jax')

print("CUDA Inverse (with JIT):")
%time f_recon_cuda = inverse_cuda(alm_cuda).block_until_ready()

print("\nJAX Inverse (with JIT):")
%time f_recon_jax = inverse_jax(alm_jax).block_until_ready()

print("\n" + "="*50)
print("CUDA Inverse (execution only):")
%timeit inverse_cuda(alm_cuda).block_until_ready()
print("\nJAX Inverse (execution only):")
%timeit inverse_jax(alm_jax).block_until_ready()

CUDA Inverse (with JIT):
CPU times: user 827 ms, sys: 38.8 ms, total: 866 ms
Wall time: 893 ms

JAX Inverse (with JIT):
CPU times: user 3.59 s, sys: 148 ms, total: 3.74 s
Wall time: 3.53 s

CUDA Inverse (execution only):
8.6 ms ± 25.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

JAX Inverse (execution only):
8.89 ms ± 43.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Inverse Transform - Accuracy

In [11]:
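# Note: the reconstructed maps are complex arrays, so this squares the complex
# difference itself; use jnp.abs(...) ** 2 instead for a real-valued MSE.
mse_inverse = jnp.mean((f_recon_cuda - f_recon_jax) ** 2)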
print(f"Inverse MSE: {mse_inverse}")
print(f"Max absolute difference: {jnp.max(jnp.abs(f_recon_cuda - f_recon_jax))}")
assert mse_inverse < 1e-14, "Inverse transform accuracy check failed!"
print("✓ Inverse transform accuracy verified")

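# Round-trip test (reported only, not asserted).
# hp_map is white noise with 12*nside**2 independent pixel values, more than the
# L**2 = 9*nside**2 coefficients kept at this band limit, so the map is not
# band-limited and the round-trip error is not expected to be near zero.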
mse_roundtrip = jnp.mean((hp_map - f_recon_cuda) ** 2)
print(f"\nRound-trip MSE: {mse_roundtrip}")
print("✓ Round-trip verified")

Inverse MSE: (2.51994956383088e-32+6.030965351560405e-34j)
Max absolute difference: 2.0517516650209028e-15
✓ Inverse transform accuracy verified

Round-trip MSE: (0.27765063408156754+1.276835988193701e-18j)
✓ Round-trip verified


## JAX Transformations Compatibility

Test compatibility with JAX transformations such as `vmap`, `jacfwd`, `jacrev`, and `grad`. The cell below exercises `vmap` and `grad`.

We use `nside=16` for these tests to avoid memory issues with Jacobian computations.

In [13]:
# Setup for transform tests
nside_test = 16
npix_test = hp.nside2npix(nside_test)
lmax_test = 3 * nside_test - 1
L_test = lmax_test + 1

batch_size = 3
f_batch = jnp.stack([jax.random.normal(jax.random.PRNGKey(i), shape=(npix_test,)) for i in range(batch_size)])
f_single = f_batch[0].real

print(f"Test nside: {nside_test}")
print(f"Batch shape: {f_batch.shape}")
print(f"Single map shape: {f_single.shape}")

def fwd_cuda_test(x):
    return forward(x, nside=nside_test, L=L_test, sampling='healpix', method='jax_cuda').real

def fwd_jax_test(x):
    return forward(x, nside=nside_test, L=L_test, sampling='healpix', method='jax').real

# VMAP tests
alm_batch_cuda = jax.vmap(fwd_cuda_test)(f_batch)
alm_batch_jax = jax.vmap(fwd_jax_test)(f_batch)
print(f"Is close (batch)? {jnp.allclose(alm_batch_cuda, alm_batch_jax, atol=1e-14)}")

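# GRAD tests
# Note: because of the @jax.grad decorator, loss_cuda and loss_jax return the
# gradient of the loss with respect to x, not the loss value itself.
@jax.grad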
def loss_cuda(x):
    alm = fwd_cuda_test(x)
    return jnp.sum(alm ** 2)

@jax.grad
def loss_jax(x):
    alm = fwd_jax_test(x)
    return jnp.sum(alm ** 2)


# Gradients are evaluated on a single map (f_single), not on the batch.
grad_loss_cuda = loss_cuda(f_single)
grad_loss_jax = loss_jax(f_single)

print(f"Is close (grad batch)? {jnp.allclose(grad_loss_cuda, grad_loss_jax, atol=1e-14)}")

Test nside: 16
Batch shape: (3, 3072)
Single map shape: (3072,)
Is close (batch)? True
Is close (grad batch)? True
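
The same pattern extends to `jacfwd` and `jacrev`. The sketch below uses a small stand-in linear transform (the real part of a 1D FFT), not an S2FFT call, so that the Jacobians stay tiny; `toy_transform` and `toy_loss` are hypothetical names. It checks that forward- and reverse-mode Jacobians agree and that `grad` matches the closed form for a linear map.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_transform(x):
    # Stand-in for a real-valued linear transform (NOT an s2fft call).
    return jnp.fft.fft(x).real


def toy_loss(x):
    return jnp.sum(toy_transform(x) ** 2)


x = jnp.arange(8.0)
batch = jnp.stack([x, 2.0 * x, 3.0 * x])

batched = jax.vmap(toy_transform)(batch)   # one transform per row
jac_fwd = jax.jacfwd(toy_transform)(x)     # forward-mode Jacobian, shape (8, 8)
jac_rev = jax.jacrev(toy_transform)(x)     # reverse-mode Jacobian, shape (8, 8)
g = jax.grad(toy_loss)(x)

# For a linear map y = J x, the gradient of sum(y**2) is 2 J^T y.
g_expected = 2.0 * jac_fwd.T @ toy_transform(x)

print(f"Batched shape: {batched.shape}")
print(f"jacfwd == jacrev? {bool(jnp.allclose(jac_fwd, jac_rev))}")
print(f"grad matches 2 J^T y? {bool(jnp.allclose(g, g_expected))}")
```

Swapping `toy_transform` for `fwd_cuda_test` or `fwd_jax_test` would run the same checks on the HEALPix transforms, at the cost of a Jacobian with one row per coefficient and one column per pixel.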


## Advanced: Out-of-Place Shift Strategy

The CUDA implementation supports two shift strategies:

- **`in_place`** (default): Cooperative kernel with grid synchronization
- **`out_of_place`**: Regular kernel with scratch buffer

### ⚠️ WARNING

The environment variable must be set **before** importing s2fft:
1. Restart the kernel
2. Set `S2FFT_CUDA_SHIFT_STRATEGY='out_of_place'`
3. Import s2fft
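
The ordering requirement above can be enforced in code. This is a minimal sketch assuming only what is stated here: the variable name `S2FFT_CUDA_SHIFT_STRATEGY` and its two values. The helper `set_shift_strategy` is hypothetical and not part of S2FFT; it refuses to act if `s2fft` is already present in `sys.modules`.

```python
import os
import sys

VALID_STRATEGIES = ("in_place", "out_of_place")


def set_shift_strategy(strategy):
    """Set S2FFT_CUDA_SHIFT_STRATEGY, refusing if s2fft is already imported."""
    if strategy not in VALID_STRATEGIES:
        raise ValueError(
            f"strategy must be one of {VALID_STRATEGIES}, got {strategy!r}"
        )
    if "s2fft" in sys.modules:
        # Setting the variable now would be too late to take effect.
        raise RuntimeError("s2fft is already imported; restart the kernel first.")
    os.environ["S2FFT_CUDA_SHIFT_STRATEGY"] = strategy


set_shift_strategy("out_of_place")
print(os.environ["S2FFT_CUDA_SHIFT_STRATEGY"])  # out_of_place
```

Raising an error, rather than silently setting the variable, makes a forgotten kernel restart visible immediately.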

In [None]:
# To test out_of_place mode, restart the kernel and run this cell BEFORE any other imports.
import os
os.environ['S2FFT_CUDA_SHIFT_STRATEGY'] = 'out_of_place'
#os.environ['S2FFT_CUDA_SHIFT_STRATEGY'] = 'in_place'

import jax
import jax.numpy as jnp
import healpy as hp
jax.config.update("jax_enable_x64", True)
from s2fft import forward

nside = 32
npix = hp.nside2npix(nside)
L = 3 * nside
f = jax.random.normal(jax.random.PRNGKey(0), shape=(npix,)) 

print("JIT Out-of-place mode timing:")
%time forward(f, nside=nside, L=L, sampling='healpix', method='jax_cuda').block_until_ready()

print("Execution only timing:")
%timeit forward(f, nside=nside, L=L, sampling='healpix', method='jax_cuda').block_until_ready()

JIT Out-of-place mode timing:
CPU times: user 804 ms, sys: 56.7 ms, total: 861 ms
Wall time: 895 ms
Execution only timing:
9.05 ms ± 14.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
