# Spherical transform benchmarking

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import numpy as np
import pyssht as ssht 

from jax import jit, device_put
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

import s2fft
from s2fft.general_precompute import spin_spherical, construct

gpu


# Results

| Wall-Time | MW (SSHT)         | MW (numpy)        | MW (JAX)          | MWSS (SSHT)       | MWSS (numpy)     | MWSS (JAX)        | DH (SSHT)         | DH (numpy)        | DH (JAX)          |
|-----------|-------------------|-------------------|-------------------|-------------------|------------------|-------------------|-------------------|-------------------|-------------------|
| 8         | 30.8 µs ± 340 ns  | 27.4 µs ± 147 ns  | 60.4 µs ± 1.45 µs | 31.1 µs ± 347 ns  | 28.6 µs ± 149 ns | 66 µs ± 4.66 µs   | 61.3 µs ± 512 ns  | 33.1 µs ± 506 ns  | 63.5 µs ± 1.26 µs |
| 16        | 99.2 µs ± 379 ns  | 58.1 µs ± 354 ns  | 62.9 µs ± 1.63 µs | 65.8 µs ± 332 ns  | 47.4 µs ± 436 ns | 62.6 µs ± 815 ns  | 276 µs ± 4.01 µs  | 90.4 µs ± 952 ns  | 65.7 µs ± 2.77 µs |
| 32        | 270 µs ± 2.91 µs  | 161 µs ± 647 ns   | 66.8 µs ± 928 ns  | 292 µs ± 10.1 µs  | 159 µs ± 1.02 µs | 67.9 µs ± 2.14 µs | 1.29 ms ± 7.78 µs | 295 µs ± 2.98 µs  | 69 µs ± 2.42 µs   |
| 64        | 2.34 ms ± 29.5 µs | 1.46 ms ± 7.89 µs | 67.3 µs ± 1.39 µs | 2.09 ms ± 13.3 µs | 919 µs ± 8.44 µs | 68.8 µs ± 2.7 µs  | 7.68 ms ± 61.6 µs | 2.9 ms ± 32.2 µs  | 66.3 µs ± 2.1 µs  |
| 128       | 9.03 ms ± 71.6 µs | 7.01 ms ± 29.5 µs | 90.9 µs ± 3.95 µs | 14.7 ms ± 87.1 µs | 6.99 ms ± 54 µs  | 69.2 µs ± 1.07 µs | 45.3 ms ± 223 µs  | 14.2 ms ± 55.2 µs | 94.5 µs ± 1.7 µs  |
| 256       | 72.8 ms ± 385 µs  | 55.5 ms ± 372 µs  | 148 µs ± 247 ns   | 111 ms ± 534 µs   | 54.5 ms ± 197 µs | 135 µs ± 7.09 µs  | 315 ms ± 1.15 ms  | 111 ms ± 464 µs   | 251 µs ± 230 ns   |
| 512       | 644 ms ± 5.76 ms  | 441 ms ± 41.8 ms  | 655 µs ± 1.86 µs  | 1.05 s ± 13.1 ms  | 438 ms ± 1.7 ms  | 623 µs ± 1.01 µs  | 2.31 s ± 5.46 ms  | 884 ms ± 30.6 ms  | 1.25 ms ± 2.26 µs |
| 1024      | 5.69 s ± 35.4 ms  | 3.31 s ± 58.6 ms  | 4.93 ms ± 8.35 µs | 9.47 s ± 78.9 ms  | 3.74 s ± 355 ms  | 4.85 ms ± 7.32 µs | 12.6 s ± 6.63 ms  | 5.09 s ± 541 ms   | 9.25 ms ± 5.1 µs  |
| 2048      |                   |                   |                   |                   |                  |                   |                   |                   |                   |

| Precision | MW (numpy) | MW (JAX) | MWSS (numpy) | MWSS (JAX) | DH (numpy) | DH (JAX) |Memory(MB)| MW       | MWSS     | DH       | 
|-----------|------------|----------|--------------|------------|------------|----------|----------|----------|----------|----------|
| 8         | 1.12e-07   | 1.64e-07 | 1.062e-07    | 1.45e-07   | 9.81e-08   | 1.84e-07 |     8    | 0.002048 | 0.002304 | 0.004096 |
| 16        | 2.30e-07   | 3.28e-07 | 2.35e-07     | 3.51e-07   | 2.47e-07   | 3.27e-07 |    16    | 0.016384 | 0.017408 | 0.032768 |
| 32        | 5.77e-07   | 8.62e-07 | 6.13e-07     | 7.66e-07   | 6.10e-07   | 8.21e-07 |    32    | 0.131072 | 0.135168 | 0.262144 |
| 64        | 1.69e-06   | 1.79e-06 | 1.67e-06     | 1.99e-06   | 1.72e-06   | 1.77e-06 |    64    | 1.048576 | 1.06496  | 2.097152 |
| 128       | 4.74e-06   | 5.21e-06 | 4.80e-06     | 4.20e-06   | 4.50e-06   | 5.11e-06 |   128    | 8.388608 | 8.454144 | 16.77721 |
| 256       | 1.26e-05   | 8.01e-06 | 1.32e-05     | 1.39e-05   | 1.26e-05   | 8.00e-06 |   256    | 67.10886 | 67.37101 | 134.2177 |
| 512       | 3.56e-05   | 1.98e-05 | 3.61e-05     | 2.49e-05   | 3.65e-05   | 1.99e-05 |   512    | 536.8709 | 537.9195 | 1073.742 |
| 1024      | 1.02e-04   | 3.67e-05 | 1.03e-04     | 4.72e-05   | 1.02e-04   | 3.67e-05 |  1024    | 4294.967 | 4299.162 | 8589.935 |
| 2048      |            |          |              |            |            |          |  2048    |          |          |          |

# Code

In [2]:
# Generate a random bandlimited field
L = 1024
spin = 0
sampling = "dh"
reality = True
rng = np.random.default_rng(193412341234)

flm = s2fft.utils.generate_flm(rng, L, spin).astype(np.complex64)
flm_1d = s2fft.samples.flm_2d_to_1d(flm, L)

In [3]:
kernel_i = construct.spin_spherical_kernel(L, spin, reality, sampling, forward=False)
kernel_i_jax = device_put(kernel_i)
print(f"Kernel memory = {kernel_i.nbytes*(1e-6)} MB")

Kernel memory = 8589.934592 MB


### SSHT C implementation

In [4]:
f = ssht.inverse(flm_1d, L, spin, Reality=True, Method=sampling.upper())
%timeit ssht.inverse(flm_1d, L, spin, Method=sampling.upper())

12.6 s ± 6.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### numpy implementation

In [5]:
f_numpy = spin_spherical.inverse_transform(flm, kernel_i, L, sampling, reality, spin, None)
%timeit spin_spherical.inverse_transform(flm, kernel_i, L, sampling, reality, spin, None)

5.09 s ± 541 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### JAX implementation

In [6]:
flm_jax = device_put(flm)
inverse_jit = jit(spin_spherical.inverse_transform_jax, static_argnums=(2, 3, 4, 5, 6))
f_jax = inverse_jit(flm_jax, kernel_i_jax, L, sampling, reality, spin, None).block_until_ready()

%timeit inverse_jit(flm_jax, kernel_i_jax, L, sampling, reality, spin, None)

9.25 ms ± 5.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Evaluate transform error

In [7]:
f_jax = np.array(f_jax)
print("numpy: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_numpy - f))))
print("jax: Inverse mean absolute error = {}".format(np.nanmean(np.abs(f_jax - f))))

numpy: Inverse mean absolute error = 0.0001023178604685384
jax: Inverse mean absolute error = 3.6663851805025655e-05
