# Spherical transform benchmarking

In [43]:
from jax import config

# Switch JAX to 64-bit (double precision) types via the "jax_enable_x64" flag.
# This is done first, before any other JAX usage in the notebook.
config.update("jax_enable_x64", True)


import os
# Expose only CUDA device index 2 to this process.
# NOTE(review): the index is machine-specific, and presumably this must be set
# before the JAX backend is first initialised to take effect -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import numpy as np
import pyssht as ssht 

from jax import jit, device_put
from jax.lib import xla_bridge
# Report which platform the active XLA backend runs on, so the benchmark
# numbers below can be attributed to the right hardware.
print(xla_bridge.get_backend().platform)

from s2fft import samples
from s2fft.precompute_transform import spin_spherical, construct

gpu


# Code

In [44]:
# Benchmark configuration: harmonic band-limit, spin number, sampling scheme
# and whether the signal is treated as real.
L_approx = 512
L = 3 * (L_approx // 3)  # largest multiple of 3 not exceeding L_approx
spin = 0
sampling = "dh"
reality = False

# Seeded generator: every random draw below goes through it so that the test
# signal (and therefore the reported errors) is identical on each fresh run.
rng = np.random.default_rng(193412341234)

# Number of samples in theta and phi for this band-limit and sampling scheme.
ntheta = samples.ntheta(L, sampling)
nphi = samples.nphi_equiang(L, sampling)

# Random complex signal on the sphere. The sum of a float64 array and
# 1j * float64 array is already complex128, so no explicit cast is needed.
f = rng.standard_normal((ntheta, nphi)) + 1j * rng.standard_normal((ntheta, nphi))

# Compute the band-limited harmonic coefficients of the random signal.
flm_1d = ssht.forward(f, L, spin, Method=sampling.upper())
flm = samples.flm_1d_to_2d(flm_1d, L)
# Zero the first |spin| rows of the 2D coefficient array
# (presumably the degrees el < |spin| -- verify against the s2fft layout).
flm[:abs(spin)] = 0
flm_1d = samples.flm_2d_to_1d(flm, L)

In [45]:
# Build the precomputed kernel for the inverse transform (forward=False)
# and report how much memory the host copy occupies.
kernel_i = construct.spin_spherical_kernel(
    L, spin, reality, sampling, forward=False
)
print(f"Kernel memory = {kernel_i.nbytes*(1e-6)} MB")

# Hand a copy of the kernel to JAX via device_put for use by the JAX transform.
kernel_i_jax = device_put(kernel_i)

Kernel memory = 4240.6704 MB


### SSHT C implementation

In [46]:
f = ssht.inverse(flm_1d, L, spin, Reality=reality, Method=sampling.upper())
%timeit ssht.inverse(flm_1d, L, spin, Method=sampling.upper())

1.63 s ± 5.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### NumPy implementation

In [47]:
f_numpy = spin_spherical.inverse_transform(flm, kernel_i, L, sampling, reality, spin, None)
%timeit spin_spherical.inverse_transform(flm, kernel_i, L, sampling, reality, spin, None)

1.64 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### JAX implementation

In [48]:
flm_jax = device_put(flm)
f_jax = spin_spherical.inverse_transform_jax(flm_jax, kernel_i_jax, L, sampling, reality, spin, None)
%timeit spin_spherical.inverse_transform_jax(flm_jax, kernel_i_jax, L, sampling, reality, spin, None)

7.16 ms ± 620 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Evaluate transform error

In [49]:
# Mean absolute deviation of each implementation from the SSHT reference `f`,
# ignoring NaN entries (np.nanmean).
mae_numpy = np.nanmean(np.abs(f_numpy - f))
print("numpy: Inverse mean absolute error = {}".format(mae_numpy))

mae_jax = np.nanmean(np.abs(f_jax - f))
print("jax: Inverse mean absolute error = {}".format(mae_jax))

numpy: Inverse mean absolute error = 2.7959168883063927e-11
jax: Inverse mean absolute error = 2.795924569525004e-11
