In [1]:
import sympy
from sympy import lambdify

from sympy import Znm, Symbol, simplify

In [2]:
import math

from sympy import S
from sympy.core.numbers import I, pi

from sympy.functions.combinatorial.factorials import binomial, factorial

In [1]:
from functools import partial

from jax import jit
from jax import random
import jax.numpy as jnp

In [2]:
import spherical

## Standard SymPy Spherical Harmonics

In [20]:
def SH_real(l, m):
    """Build a JAX-callable ``f(theta, phi)`` for SymPy's real spherical harmonic ``Znm(l, m)``.

    ``theta`` is the polar angle and ``phi`` the azimuthal angle. The harmonic is expanded
    into elementary functions, simplified once with SymPy, and then lambdified with the
    ``"jax"`` backend so it can be traced inside ``jax.jit``.
    """
    polar = Symbol("theta", real=True)
    azimuth = Symbol("phi", real=True)

    expression = sympy.simplify(Znm(l, m, polar, azimuth).expand(func=True))
    return lambdify([polar, azimuth], expression, modules="jax")

def SH_real_cart(l, m):
    """Return a JAX version of the real spherical harmonic Y_lm in Cartesian coordinates.

    SymPy's real harmonic ``Znm(l, m, theta, phi)`` is expanded into trigonometric
    polynomials, the spherical angles are replaced by their Cartesian expressions, and the
    result is simplified, factored and lambdified with the ``"jax"`` backend.

    Args:
        l: Degree of the harmonic.
        m: Order of the harmonic.

    Returns:
        A JAX-compatible callable ``f(x, y, z)``. The expression divides by
        ``sqrt(x**2 + y**2)``-type terms, so points on the z-axis may evaluate to nan.
    """
    theta, phi = Symbol("theta", real=True), Symbol("phi", real=True)
    x, y, z = Symbol("x", real=True), Symbol("y", real=True), Symbol("z", real=True)

    ylm = sympy.simplify(Znm(l, m, theta, phi).expand(func=True))
    ylm = sympy.expand_trig(ylm)

    # Rewrite complex exponentials exp(I*m*phi) as cos/sin of m*phi, then expand the
    # multiple-angle functions into powers of the single-angle ones.
    ylm = sympy.expand(ylm, complex=True)
    ylm = sympy.expand_trig(ylm)

    # Replace spherical coordinates by Cartesian ones.
    # theta lies in [0, pi], so theta = acos(z / r) is valid everywhere.
    ylm = ylm.subs(theta, sympy.acos(z / sympy.sqrt(x**2 + y**2 + z**2)))
    # phi must be atan2(y, x). The former acos(x / rho) only covers [0, pi], so
    # sin(acos(x / rho)) = |y| / rho had the wrong sign for y < 0. SymPy evaluates
    # cos(atan2(y, x)) -> x / rho and sin(atan2(y, x)) -> y / rho directly.
    ylm = ylm.subs(phi, sympy.atan2(y, x))

    # Manipulating fractions
    ylm = simplify(ylm)
    ylm = sympy.cancel(ylm)
    ylm = sympy.simplify(ylm)
    ylm = sympy.expand_power_base(ylm, force=True)
    ylm = sympy.powdenest(ylm, force=True)
    ylm = sympy.simplify(ylm)
    ylm = sympy.factor(ylm)

    # Condon-Shortley phase: flip the sign for odd m.
    ylm = ylm if (m % 2 == 0) else -ylm

    return lambdify([x, y, z], ylm, modules="jax")

### Defining JITed Versions

In [21]:
@partial(jit, static_argnums=(1, 2))
def SH_real_jit(rs, l, m):
    """Evaluate ``SH_real(l, m)`` on points ``rs`` of shape [*, 3] (last axis = x, y, z).

    The points are first converted to spherical angles (theta via arccos, phi via arctan2).
    Because of that extra conversion this path is expected to be slower than evaluating a
    purely algebraic Cartesian expression.

    ``l`` and ``m`` are static arguments, so the SymPy work inside ``SH_real`` happens at
    trace time rather than on every call.
    """
    # TODO Think about best ordering of axes.
    x, y, z = (rs[..., axis] for axis in range(3))

    radius = jnp.sqrt(x**2 + y**2 + z**2)
    # At the origin z / radius is 0/0 = nan; map it to theta = 0 instead of propagating nan.
    polar = jnp.nan_to_num(jnp.arccos(z / radius), nan=0.0, copy=False)
    azimuth = jnp.arctan2(y, x)

    return SH_real(l, m)(polar, azimuth)

@partial(jit, static_argnums=(1, 2))
def SH_sympy_jit(rs, l, m):
    """Evaluate ``SH_real_cart(l, m)`` on points ``rs`` of shape [*, 3] (last axis = x, y, z).

    ``l`` and ``m`` are static arguments, so the (slow) SymPy construction inside
    ``SH_real_cart`` only runs when ``jit`` traces this function.
    """
    evaluate = SH_real_cart(l, m)
    return evaluate(*(rs[..., axis] for axis in range(3)))

## Custom Solid Harmonics

The real spherical harmonics $Y_{\ell m}(x, y, z)$ are obtained from the solid harmonics $r^{\ell} Y_{\ell m}$, which for $m > 0$ are given by:

$$r^{\ell}\left(\begin{array}{c}
Y_{\ell m} \\
Y_{\ell-m}
\end{array}\right)=\sqrt{\frac{2 \ell+1}{2 \pi}} \bar{\Pi}_{\ell}^m(z)\left(\begin{array}{c}
A_m \\
B_m
\end{array}\right), \quad m>0$$

$$\bar{\Pi}_{\ell}^m(z)=\left[\frac{(\ell-m) !}{(\ell+m) !}\right]^{1 / 2} \sum_{k=0}^{\lfloor(\ell-m) / 2\rfloor}(-1)^k 2^{-\ell}\left(\begin{array}{l}
\ell \\
k
\end{array}\right)\left(\begin{array}{c}
2 \ell-2 k \\
\ell
\end{array}\right) \frac{(\ell-2 k) !}{(\ell-2 k-m) !} r^{2 k} z^{\ell-2 k-m}$$

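$$A_m(x, y) \equiv \Re{[(x+i y)^m]}, \qquad B_m(x, y) \equiv \Im{[(x+i y)^m]}$$

Here $r = \sqrt{x^2 + y^2 + z^2}$. For $m = 0$ only $A_0 = 1$ contributes, and the normalisation changes to

$$r^{\ell} Y_{\ell 0} = \sqrt{\frac{2 \ell+1}{4 \pi}}\, \bar{\Pi}_{\ell}^0(z).$$

The code in `Pi` folds this case in through the extra factor $\sqrt{(2 - \delta_{m0})/2}$.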

In [17]:
def Scale(l):
    """Return ``sqrt((2l + 1) / (2 pi)) * r**(-l)`` as a SymPy expression in the symbol ``r``.

    ``r`` is declared real and positive. The ``r**(-l)`` factor turns the degree-l solid
    harmonic ``r**l * Y_lm`` back into ``Y_lm``.
    """
    radius = Symbol("r", real=True, positive=True)
    normalisation = sympy.sqrt((sympy.Integer(2) * l + 1) / (2 * pi))
    return normalisation * radius**(-l)

def AB(m):
    """Return ``(A_m, B_m) = (Re[(x + i y)**m], Im[(x + i y)**m])`` for real symbols x, y."""
    x = Symbol("x", real=True)
    y = Symbol("y", real=True)

    w = (x + I * y)**m
    return sympy.re(w), sympy.im(w)

def Pi(l, m):
    """Return the normalised polynomial bar-Pi_l^m(z) from the solid-harmonic formula.

    The result is a SymPy expression in ``z`` (real) and ``r`` (real, positive). It also
    carries the factor ``sqrt(2 - delta_{m0}) / sqrt(2)``. That factor is 1 for m != 0 and
    1/sqrt(2) for m == 0, so the common ``sqrt((2l + 1) / (2 pi))`` prefactor from ``Scale``
    also gives the correct ``sqrt((2l + 1) / (4 pi))`` normalisation for m == 0.
    """
    z = Symbol("z", real=True)
    r = Symbol("r", real=True, positive=True)

    norm = (
        sympy.sqrt(factorial(l - m) / factorial(l + m))
        * sympy.sqrt(2 - sympy.KroneckerDelta(m, 0))
        / sympy.sqrt(2)
    )

    def term(k):
        # k-th summand of the finite series defining bar-Pi_l^m.
        sign = S.NegativeOne**k
        coefficient = sympy.Integer(2)**(-l) * binomial(l, k) * binomial(2 * l - 2 * k, l)
        falling = factorial(l - 2 * k) / factorial(l - 2 * k - m)
        return sign * coefficient * falling * r**(2 * k) * z**(l - 2 * k - m)

    k_max = math.floor((l - m) / 2)
    return norm * sum((term(k) for k in range(k_max + 1)), 0)

def SolidHarmonics(l):
    """Return all real harmonics of degree ``l`` as simplified SymPy expressions.

    The result maps each order ``m`` in ``[-l, l]`` to an expression in the symbols
    ``x, y, z`` and ``r``. Order ``m >= 0`` comes from ``A_m = Re[(x + i y)**m]`` and order
    ``-m`` from ``B_m = Im[(x + i y)**m]``. Keys are inserted as 0, 1, -1, 2, -2, ...
    """
    harmonics = {}

    for m in range(l + 1):
        re_part, im_part = AB(m)
        prefactor = Pi(l, m) * Scale(l)

        harmonics[m] = simplify(prefactor * re_part)
        # B_0 is identically zero and has no -0 entry, so only build negative orders for m > 0.
        if m > 0:
            harmonics[-m] = simplify(prefactor * im_part)

    return harmonics

def SolidHarmonicsJax(l):
    """Lambdify every entry of ``SolidHarmonics(l)`` with the ``"jax"`` backend.

    Returns a dict mapping each order ``m`` to a callable ``f(x, y, z, r)``. The caller is
    expected to pass ``r = sqrt(x**2 + y**2 + z**2)``; it is not recomputed here.
    """
    expressions = SolidHarmonics(l)

    x, y, z = (Symbol(name, real=True) for name in ("x", "y", "z"))
    r = Symbol("r", real=True, positive=True)

    compiled = {}
    for order, expression in expressions.items():
        compiled[order] = lambdify([x, y, z, r], expression, modules="jax")
    return compiled

# Per-degree cache of lambdified solid harmonics. SolidHarmonic_jit is re-traced for every
# new (l, m) pair and every new input shape. Without this cache, each trace rebuilt and
# re-simplified all 2l + 1 harmonics of degree l even though only one of them is used.
# Re-running this cell resets the cache (e.g. after editing SolidHarmonics).
_SOLID_HARMONICS_JAX_CACHE = {}


def _solid_harmonics_jax_cached(l):
    """Return ``SolidHarmonicsJax(l)``, computing it at most once per degree ``l``."""
    if l not in _SOLID_HARMONICS_JAX_CACHE:
        _SOLID_HARMONICS_JAX_CACHE[l] = SolidHarmonicsJax(l)
    return _SOLID_HARMONICS_JAX_CACHE[l]


@partial(jit, static_argnums=(1, 2))
def SolidHarmonic_jit(coords, l, m):
    """Evaluate the real harmonic Y_lm on points ``coords`` of shape [*, 3].

    Args:
        coords: Array whose last axis holds the Cartesian coordinates (x, y, z).
        l: Degree of the harmonic (static under ``jit``).
        m: Order of the harmonic, expected in ``[-l, l]`` (static under ``jit``).

    Returns:
        Array of shape ``coords.shape[:-1]`` with the harmonic evaluated at each point.

    Raises:
        KeyError: If ``m`` is not an order available for degree ``l``.
    """
    harmonic = _solid_harmonics_jax_cached(l)[m]

    # Get views
    xs, ys, zs = coords[..., 0], coords[..., 1], coords[..., 2]
    rs = jnp.sqrt(xs**2 + ys**2 + zs**2)

    return harmonic(xs, ys, zs, rs)

## Benchmarks

In [3]:
# Fixed seed so every run benchmarks the same random point cloud.
key, subkey = random.split(random.PRNGKey(0))

# n standard-normal sample points in R^3, stored as an (n, 3) array of (x, y, z) rows.
n = 1000
coords = random.normal(subkey, shape=(n, 3))

# Per-axis views and radii of the sample points.
xs, ys, zs = (coords[..., axis] for axis in range(3))
rs = jnp.sqrt(xs**2 + ys**2 + zs**2)

# Degree and order of the harmonic being benchmarked.
l, m = 10, 3



In [4]:
# Times SolidHarmonic_jit from the imported `spherical` module (not the in-notebook
# definition) on the (n, 3) point array `coords`.
# %time: in a fresh kernel, the first call also pays for tracing and compilation of this (l, m).
# %timeit: repeated calls, which reuse jit's compiled version.
# block_until_ready() waits for JAX's asynchronous dispatch, so the computation itself is timed.
%time spherical.SolidHarmonic_jit(coords, l, m).block_until_ready()
%timeit spherical.SolidHarmonic_jit(coords, l, m).block_until_ready()

CPU times: user 1.34 s, sys: 0 ns, total: 1.34 s
Wall time: 1.34 s
1.94 µs ± 5.22 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [28]:
%time SH_sympy_jit(rs, l, m).block_until_ready()
%timeit SH_sympy_jit(rs, l, m).block_until_ready()

CPU times: user 1.02 s, sys: 0 ns, total: 1.02 s
Wall time: 1.01 s
1.31 µs ± 3.35 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [29]:
%time SH_real_jit(rs, l, m).block_until_ready()
%timeit SH_real_jit(rs, l, m).block_until_ready()

CPU times: user 181 ms, sys: 0 ns, total: 181 ms
Wall time: 179 ms
1.34 µs ± 4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
