# __S2FFT CUDA Implementation__
---

[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_HEALPix_frontend.ipynb)

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

# Install s2fft if running on Google Colab.
if IN_COLAB:
    !pip install s2fft &> /dev/null

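A short comparison of the pure JAX and CUDA implementations of the HEALPix FFTs used in S2FFT. It covers the forward transform (`healpix_fft_*`) and its inverse (`healpix_ifft_*`), checking both run time and agreement of the results. The `*_cuda` functions require a CUDA-capable GPU.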

In [1]:
import jax
from jax import numpy as jnp

# Keep JAX in single precision (float32); the comparison tolerances below reflect this.
jax.config.update("jax_enable_x64", False)

# Forward and inverse HEALPix FFTs: pure JAX versions and CUDA versions.
from s2fft.utils.healpix_ffts import (
    healpix_fft_jax,
    healpix_ifft_jax,
    healpix_fft_cuda,
    healpix_ifft_cuda,
)

JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.


In [8]:
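nside = 256                           # HEALPix resolution parameter
L = 2 * nside                         # harmonic band-limit
total_pixels = 12 * nside**2          # number of pixels in a HEALPix map
ftm_shape = (4 * nside - 1, 2 * L)    # (rings, Fourier coefficients per ring) after the FFT
ftm_size = ftm_shape[0] * ftm_shape[1]
# Random real-valued test map, one value per HEALPix pixel.
arr = jax.random.normal(jax.random.PRNGKey(0), (total_pixels,))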

In [22]:
%%time
cuda_res = healpix_fft_cuda(arr, L, nside, reality=False).block_until_ready()

CPU times: user 1.54 ms, sys: 2.81 ms, total: 4.35 ms
Wall time: 3.41 ms


In [15]:
%%time
jax_res = healpix_fft_jax(arr, L, nside, reality=False).block_until_ready()

CPU times: user 4.1 ms, sys: 3.46 ms, total: 7.55 ms
Wall time: 5.23 ms


In [7]:
jnp.allclose(cuda_res, jax_res, atol=1e-3, rtol=1e-3)

Array(True, dtype=bool)
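`jnp.allclose` returns a single boolean, so it says nothing about how large a discrepancy is. A small helper that also reports the size of the error makes these checks more informative. The sketch below uses NumPy and a hypothetical `compare_results` helper (not part of s2fft): it reports the maximum absolute error, that error relative to the largest reference magnitude, and the number of entries failing the same `atol`/`rtol` test used above.

```python
import numpy as np


def compare_results(test, reference, rtol=1e-3, atol=1e-3):
    """Summarise how far `test` is from `reference` (hypothetical helper, not an s2fft API)."""
    # np.asarray also copies JAX device arrays to host memory.
    test = np.asarray(test)
    reference = np.asarray(reference)
    if test.shape != reference.shape:
        raise ValueError(f"shape mismatch: {test.shape} vs {reference.shape}")

    abs_err = np.abs(test - reference)  # works for real and complex arrays
    max_abs = float(np.max(abs_err))
    scale = float(np.max(np.abs(reference)))
    if scale > 0:
        max_rel = max_abs / scale  # error relative to the largest reference magnitude
    else:
        max_rel = 0.0 if max_abs == 0 else float("inf")

    # Same elementwise criterion as allclose: |test - ref| <= atol + rtol * |ref|.
    failing = ~np.isclose(test, reference, rtol=rtol, atol=atol)
    return {
        "max_abs_err": max_abs,
        "max_rel_err": max_rel,
        "n_mismatch": int(np.count_nonzero(failing)),
    }
```

For example, `compare_results(cuda_res, jax_res)` can be run on the forward-transform outputs above.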

In [26]:
# Random input for the inverse transform, with shape (4 * nside - 1, 2 * L).
arr = jax.random.normal(jax.random.PRNGKey(0), ftm_shape)

In [27]:
%%time
cuda_res = healpix_ifft_cuda(arr, L, nside, reality=False).block_until_ready()

f.shape (1023, 1024)
CPU times: user 117 ms, sys: 53.2 ms, total: 170 ms
Wall time: 182 ms


In [30]:
%%time
jax_res = healpix_ifft_jax(arr, L, nside, reality=False).block_until_ready()

CPU times: user 480 µs, sys: 12.3 ms, total: 12.8 ms
Wall time: 12.2 ms
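Each timing above comes from a single `%%time` call, and the cells were executed out of order (see the `In [..]` counters). For jitted functions the first call also includes tracing and compilation, which may account for part of the 182 ms recorded for the CUDA inverse transform. A fairer comparison warms each function up and reports the median of several runs. The `benchmark` helper below is a hypothetical sketch of this, not an s2fft utility. It calls `block_until_ready()` on the result when that method exists, because JAX dispatches work asynchronously.

```python
import statistics
import time


def benchmark(fn, *args, n_warmup=2, n_runs=10, **kwargs):
    """Median wall time in seconds of fn(*args, **kwargs), excluding warm-up calls.

    Hypothetical helper for illustration; not part of s2fft.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    def run():
        out = fn(*args, **kwargs)
        # JAX returns before the computation finishes; wait for it if possible.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    # Warm-up calls absorb one-off costs such as JIT compilation.
    for _ in range(n_warmup):
        run()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        run()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

At this point in the notebook `arr` holds the Fourier-space input, so the inverse transforms can be compared with `benchmark(healpix_ifft_cuda, arr, L, nside, reality=False)` and the same call using `healpix_ifft_jax`. The median is used rather than the mean so that an occasional slow run does not dominate the result.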


In [31]:
jnp.allclose(cuda_res, jax_res, atol=1e-3, rtol=1e-3)

Array(False, dtype=bool)
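The inverse transforms disagree at this tolerance. The input here is real-valued Gaussian noise rather than the output of a forward transform. One way to narrow down the cause is a round-trip check within each backend: transform a map forward and back, then measure how well the original map is recovered. `roundtrip_error` below is a hypothetical helper written for this purpose. It only assumes that `forward` and `inverse` are single-argument callables.

```python
import numpy as np


def roundtrip_error(forward, inverse, x):
    """Return max |inverse(forward(x)) - x| as a Python float.

    Hypothetical helper; wrap functions needing extra arguments (e.g. L, nside)
    in a lambda so that `forward` and `inverse` take a single argument.
    """
    recovered = np.asarray(inverse(forward(x)))  # moves JAX arrays to host memory
    original = np.asarray(x)
    if recovered.shape != original.shape:
        raise ValueError(f"shape mismatch: {recovered.shape} vs {original.shape}")
    # np.abs handles a complex reconstruction compared against a real input.
    return float(np.max(np.abs(recovered - original)))
```

For example, with a map `m = jax.random.normal(jax.random.PRNGKey(0), (total_pixels,))`, compare `roundtrip_error(lambda x: healpix_fft_cuda(x, L, nside, reality=False), lambda f: healpix_ifft_cuda(f, L, nside, reality=False), m)` with the same call using the `_jax` functions.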
