In [2]:
import sys
sys.path.append('../')

from jax import random, jit
import jax.numpy as jnp

In [4]:
from jax.scipy.special import sph_harm

In [5]:
import equiformer.spherical as spherical
import equiformer.spherical_alt as spherical_alt

## Custom Solid Harmonic

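The real solid harmonics $r^{\ell} Y_{\ell m}(x, y, z)$, i.e. the real spherical harmonics $Y_{\ell m}$ scaled by $r^{\ell}$ with $r = \sqrt{x^2 + y^2 + z^2}$, are given by: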

$$r^{\ell}\left(\begin{array}{c}
Y_{\ell m} \\
Y_{\ell,-m}
\end{array}\right)=\sqrt{\frac{2 \ell+1}{2 \pi}} \bar{\Pi}_{\ell}^m(z)\left(\begin{array}{c}
A_m \\
B_m
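\end{array}\right), \quad m>0$$

For $m = 0$ there is a single function:

$$r^{\ell} Y_{\ell 0}=\sqrt{\frac{2 \ell+1}{4 \pi}} \bar{\Pi}_{\ell}^0(z)$$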

$$\bar{\Pi}_{\ell}^m(z)=\left[\frac{(\ell-m) !}{(\ell+m) !}\right]^{1 / 2} \sum_{k=0}^{\lfloor(\ell-m) / 2\rfloor}(-1)^k 2^{-\ell}\left(\begin{array}{l}
\ell \\
k
\end{array}\right)\left(\begin{array}{c}
2 \ell-2 k \\
\ell
\end{array}\right) \frac{(\ell-2 k) !}{(\ell-2 k-m) !} r^{2 k} z^{\ell-2 k-m}$$

$$A_m(x, y) \equiv \Re{[(x+i y)^m]}, \qquad B_m(x, y) \equiv \Im{[(x+i y)^m]}$$

### Playing with Individual Functions

In [4]:
# Evaluate JAX's complex spherical harmonic for order m=0 and degree n=0 at two
# sample angle pairs. The length-1 (m, n) arrays broadcast against the length-2
# angle arrays, so the result holds one value per angle pair.
sph_harm(
    m=jnp.array([0]),
    n=jnp.array([0]),
    theta=jnp.array([0, 1]),
    phi=jnp.array([0, 1]),
)

Array([0.28209478+0.j, 0.28209478+0.j], dtype=complex64)

## Benchmarks

In [6]:
# Degree (l) and order (m) of the harmonic used by the benchmark cells below.
l = 10
m = 7

# Number of random 3-D points the harmonics are evaluated on.
n = 1000

# Fixed seed so the sampled coordinates are identical on every run.
key = random.PRNGKey(0)
key, subkey = random.split(key)

# Standard-normal sample of shape (n, 3): one (x, y, z) triple per row.
coords = random.normal(subkey, (n, 3))

In [6]:
# Time the single-(l, m) evaluation from `equiformer.spherical` on `coords`
# (presumably jit-compiled, going by the `_jit` suffix -- verify in the module).
# block_until_ready() makes the call wait for JAX's asynchronous dispatch to
# finish, so the measurement covers the computation and not just its launch.
%timeit spherical.solid_harmonic_jit(coords, l, m).block_until_ready()

1.21 ms ± 46.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# NOTE This is slower than I'd like as-is
# NOTE(review): the timing recorded below (~115 µs) is faster than the
# single-(l, m) timings in the neighbouring cells -- confirm whether the note
# above is still current.
# This variant takes only `l` (no `m`); presumably it evaluates every order for
# that degree in one call -- verify against equiformer.spherical.
# block_until_ready() waits for JAX's asynchronous dispatch to complete.
%timeit spherical.solid_harmonics_jit(coords, l).block_until_ready()

115 µs ± 32.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Time the alternative single-(l, m) implementation from
# `equiformer.spherical_alt` (presumably SymPy-derived, going by its name --
# verify in the module). block_until_ready() waits for JAX's asynchronous
# dispatch to complete so the full computation is measured.
%timeit spherical_alt.SH_sympy_jit(coords, l, m).block_until_ready()

The slowest run took 5.05 times longer than the fastest. This could mean that an intermediate result is being cached.
1.08 ms ± 338 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# Time `SH_real_jit` from `equiformer.spherical_alt` for a single (l, m) on the
# same `coords`. block_until_ready() waits for JAX's asynchronous dispatch to
# complete so the full computation is measured.
%timeit spherical_alt.SH_real_jit(coords, l, m).block_until_ready()

1.37 ms ± 16.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Time `SH_real_jit_all`, which takes only `l` (no `m`); presumably it evaluates
# every order for that degree in one call, like `solid_harmonics_jit` -- verify
# against equiformer.spherical_alt. block_until_ready() waits for JAX's
# asynchronous dispatch to complete so the full computation is measured.
%timeit spherical_alt.SH_real_jit_all(coords, l).block_until_ready()

In [23]:
# Cross-check the two single-(l, m) implementations on the benchmark inputs.
# Each implementation is evaluated once so the same arrays feed both the
# comparison and the failure message.
sympy_values = spherical_alt.SH_sympy_jit(coords, l, m)
solid_values = spherical.solid_harmonic_jit(coords, l, m)

# allclose applies the same elementwise criterion as isclose(...).all():
# |a - b| <= atol + rtol * |b|. The message reports the worst-case deviation so
# a failure can be diagnosed without re-running the cell by hand.
assert jnp.allclose(sympy_values, solid_values, rtol=1e-5, atol=1e-5), (
    f"SH_sympy_jit and solid_harmonic_jit disagree for l={l}, m={m}: "
    f"max abs difference = {jnp.max(jnp.abs(sympy_values - solid_values))}"
)