In [3]:
from jax import random, jit
import jax.numpy as jnp

In [4]:
import spherical
import spherical_alt

## Custom Solid Harmonic

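The real spherical harmonics $Y_{\ell m}(x, y, z)$, written in Cartesian coordinates with $r = \sqrt{x^2 + y^2 + z^2}$, are given below. The left-hand side, $r^{\ell} Y_{\ell m}$, is the corresponding real (regular) solid harmonic, a homogeneous polynomial of degree $\ell$ in $x, y, z$: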

$$r^{\ell}\begin{pmatrix}
Y_{\ell m} \\
Y_{\ell,-m}
\end{pmatrix}=\sqrt{\frac{2 \ell+1}{2 \pi}}\, \bar{\Pi}_{\ell}^m(z)\begin{pmatrix}
A_m \\
B_m
\end{pmatrix}, \quad m>0, \qquad\qquad r^{\ell} Y_{\ell 0}=\sqrt{\frac{2 \ell+1}{4 \pi}}\, \bar{\Pi}_{\ell}^0(z)$$

$$\bar{\Pi}_{\ell}^m(z)=\left[\frac{(\ell-m)!}{(\ell+m)!}\right]^{1/2} \sum_{k=0}^{\lfloor(\ell-m)/2\rfloor}(-1)^k\, 2^{-\ell} \binom{\ell}{k} \binom{2\ell-2k}{\ell} \frac{(\ell-2k)!}{(\ell-2k-m)!}\, r^{2k} z^{\ell-2k-m}$$

$$A_m(x, y) \equiv \Re{[(x+i y)^m]}, \qquad B_m(x, y) \equiv \Im{[(x+i y)^m]}$$

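## Benchmarks

Each cell first runs `%time` on a single call; for a `jit`-compiled function this first call includes tracing and compilation. It then runs `%timeit` to measure the steady-state per-call cost. Every call ends in `.block_until_ready()` so that JAX's asynchronous dispatch does not make the timings look artificially fast.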

In [5]:
# Benchmark inputs, generated from a fixed seed so every run times the same data.
n = 1000       # number of sample points
l, m = 10, 7   # degree l and order m of the harmonic being timed

# Split the root key and use the fresh subkey for sampling (JAX key-threading idiom).
key, subkey = random.split(random.PRNGKey(0), num=2)

# n points whose x, y, z components are i.i.d. standard normal -> shape (n, 3).
coords = random.normal(subkey, shape=(n, 3))



In [6]:
# First call: wall time is presumably dominated by JAX tracing + XLA compilation
# (assumes l and m are static arguments of the jitted function -- TODO confirm in spherical.py).
# .block_until_ready() waits for JAX's asynchronous dispatch to finish, so the timers
# measure the computation itself rather than only the enqueue.
%time spherical.SolidHarmonic_jit(coords, l, m).block_until_ready()
# Repeated calls with the same inputs should reuse the compiled executable:
# this is the steady-state cost of evaluating a single (l, m) harmonic on all n points.
%timeit spherical.SolidHarmonic_jit(coords, l, m).block_until_ready()

CPU times: user 1.3 s, sys: 0 ns, total: 1.3 s
Wall time: 1.3 s
2 µs ± 6.15 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [7]:
# NOTE: This is slower than I'd like as-is (~156 us/call in the recorded run below).
# It takes only (coords, l), with no m, so it presumably evaluates every order m of
# degree l in one call -- TODO confirm in spherical.py; if so, its timing is not directly
# comparable to the cells that evaluate a single (l, m) pair.
# The %time line is the first call (includes JIT compilation); %timeit is steady state.
# .block_until_ready() makes both timers wait for JAX's asynchronous dispatch to finish.
%time spherical.SolidHarmonics_jit(coords, l).block_until_ready()
%timeit spherical.SolidHarmonics_jit(coords, l).block_until_ready()

CPU times: user 338 ms, sys: 0 ns, total: 338 ms
Wall time: 336 ms
156 µs ± 6.37 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# Alternative implementation from spherical_alt, timed on the same (coords, l, m) inputs.
# The name suggests a SymPy-generated closed form -- unverified, check spherical_alt.py.
# %time: first call, which includes JIT compilation; %timeit: steady-state per-call cost.
# .block_until_ready() makes both timers wait for JAX's asynchronous dispatch to finish.
%time spherical_alt.SH_sympy_jit(coords, l, m).block_until_ready()
%timeit spherical_alt.SH_sympy_jit(coords, l, m).block_until_ready()

CPU times: user 799 ms, sys: 0 ns, total: 799 ms
Wall time: 797 ms
1.96 µs ± 49.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [9]:
# Second alternative implementation from spherical_alt, timed on the same (coords, l, m) inputs.
# NOTE(review): whether it returns the same quantity (r^l * Y_lm vs. Y_lm on the unit sphere)
# as the other implementations is not checked here -- consider an allclose comparison.
# %time: first call, which includes JIT compilation; %timeit: steady-state per-call cost.
# .block_until_ready() makes both timers wait for JAX's asynchronous dispatch to finish.
%time spherical_alt.SH_real_jit(coords, l, m).block_until_ready()
%timeit spherical_alt.SH_real_jit(coords, l, m).block_until_ready()

CPU times: user 120 ms, sys: 0 ns, total: 120 ms
Wall time: 118 ms
46.2 µs ± 454 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
