In [1]:
import math
from collections import defaultdict

import numpy as np
from scipy.special import gamma

### ptilde(n) (equation 7)

In [6]:
def ptilde(n):
    """return the power of x, y, z separately"""
    l = math.floor(math.sqrt(n))
    m = n - l * l - l
    mu = l - m
    nu = l + m
    if nu % 2 == 0:
        i = mu // 2
        j = nu // 2
        k = 0
    else:
        i = (mu - 1) // 2
        j = (nu - 1) // 2
        k = 1
    return (i, j, k)

In [7]:
ptilde(37)

(5, 0, 1)

### Alm (equation 44)

In [8]:
def Alm(l, m):
    return math.sqrt(
        (2 - int(m == 0))
        * (2 * l + 1)
        * math.factorial(l - m)
        / (4 * math.pi * math.factorial(l + m))
    )

In [11]:
Alm(5,0)

0.9356025796273888

### Blmjk (equation 47)

In [21]:
def Blmjk(l, m, j, k):
    a = l + m + k - 1
#     print("a", a)
    b = -l + m + k - 1
#     print("b", b)
    if (b < 0) and (b % 2 == 0):
        return 0
    else:
        ratio = gamma(0.5 * a + 1) / gamma(0.5 * b + 1)
#         print("gamma b", gamma(0.5*b+1))
    return (
        2**l
        * math.factorial(m)
        / (
            math.factorial(j)
            * math.factorial(k)
            * math.factorial(m - j)
            * math.factorial(l - m - k)
        )
        * ratio
    )

In [22]:
Blmjk(5,0,0,1)

1.8750000000000002

### Cpqk (equation 49)

In [23]:
def Cpqk(p, q, k):
    return math.factorial(k // 2) / (
        math.factorial(q // 2)
        * math.factorial((k - p) // 2)
        * math.factorial((p - q) // 2)
    )

In [25]:
Cpqk(1,1,3)

1.0

### Ylm (equation 50)

In [75]:
def Ylm(l, m):
    """the output is a dictionary, where the key (xp,yp,zp) is 
       the power to x, y, z separately, and the key value is the 
       coefficient
    """
    res = defaultdict(lambda: 0)
#     print("res", res)
    A = Alm(l, abs(m))
    for j in range(int(m < 0), abs(m) + 1, 2):
        for k in range(0, l - abs(m) + 1, 2):
            B = Blmjk(l, abs(m), j, k)
            if not B:
                continue
            factor = A * B
            for p in range(0, k + 1, 2):
                for q in range(0, p + 1, 2):
                    ind = (abs(m) - j + p - q, j + q, 0)
                    res[ind] += (
                        (-1) ** ((j + p - (m < 0)) // 2) * factor * Cpqk(p, q, k)
                    )
#                     print("res_pq", res)
        for k in range(1, l - abs(m) + 1, 2):
            B = Blmjk(l, abs(m), j, k)
            if not B:
                continue
            factor = A * B
            for p in range(0, k, 2):
                for q in range(0, p + 1, 2):
                    ind = (abs(m) - j + p - q, j + q, 1)
                    res[ind] += (
                        (-1) ** ((j + p - (m < 0)) // 2) * factor * Cpqk(p, q, k - 1)
                    )
#                     print("res_pq", dict(res))

    return dict(res)

In [80]:
Ylm(0,0)

{(0, 0, 0): 0.28209479177387814}

In [81]:
Ylm(1,-1)

{(0, 1, 0): 0.4886025119029199}

In [82]:
Ylm(1,0)

{(0, 0, 1): 0.4886025119029199}

In [83]:
Ylm(1,1)

{(1, 0, 0): 0.4886025119029199}

In [67]:
def p_Y(p, l, m, res):
    for k, v in Ylm(l, m).items():
        print(k,v)
        if k not in p:
            continue
        res[p[k]] = v
    return res

In [74]:
res = np.zeros((9,9))
po = {ptilde(i): i for i in range(9)}
p_Y(po,2,1,res[:,0])

(1, 0, 1) 1.0925484305920792


array([0.        , 0.        , 0.        , 0.        , 0.        ,
       1.09254843, 0.        , 0.        , 0.        ])

### A1 (equation 52)

In [62]:
def A1(lmax):
    """Note: The normalization here matches the starry paper, but not the
    code. To get the code's normalization, multiply the result by 2 /
    sqrt(pi).
    """
    n = (lmax + 1) ** 2
    res = np.zeros((n, n))
    p = {ptilde(i): i for i in range(n)}
    print(p)
    n = 0
    for l in range(lmax + 1):
        for m in range(-l, l + 1):
            temp = p_Y(p, l, m, res[:, n])
#             print("py", temp)
            n += 1
    return res

In [65]:
A1(2)

{(0, 0, 0): 0, (1, 0, 0): 1, (0, 0, 1): 2, (0, 1, 0): 3, (2, 0, 0): 4, (1, 0, 1): 5, (1, 1, 0): 6, (0, 1, 1): 7, (0, 2, 0): 8}


array([[ 0.28209479,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.63078313,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.48860251,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.48860251,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.48860251,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        , -0.9461747 ,  0.        ,  0.54627422],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.09254843,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  1.09254843,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0

### gtilde(n) (equation 11)

In [87]:
def gtilde(n):
    l = math.floor(math.sqrt(n))
    m = n - l * l - l
    mu = l - m
    nu = l + m
    if nu % 2 == 0:
        I = [mu // 2]
        J = [nu // 2]
        K = [0]
        C = [(mu + 2) // 2]
    elif (l == 1) and (m == 0):
        I = [0]
        J = [0]
        K = [1]
        C = [1]
    elif (mu == 1) and (l % 2 == 0):
        I = [l - 2]
        J = [1]
        K = [1]
        C = [3]
    elif mu == 1:
        I = [l - 3, l - 1, l - 3]
        J = [0, 0, 2]
        K = [1, 1, 1]
        C = [-1, 1, 4]
    else:
        I = [(mu - 5) // 2, (mu - 5) // 2, (mu - 1) // 2]
        J = [(nu - 1) // 2, (nu + 3) // 2, (nu - 1) // 2]
        K = [1, 1, 1]
        C = [(mu - 3) // 2, -(mu - 3) // 2, -(mu + 3) // 2]
    res = {}
    for i, j, k, c in zip(I, J, K, C):
        res[(i, j, k)] = c
    return res



In [95]:
gtilde(6)

{(1, 1, 0): 2}

### A2 (equation 53)

In [96]:
def p_G(p, n, res):
    for k, v in gtilde(n).items():
        if k not in p:
            continue
        res[p[k]] = v
    return res

In [100]:
def A2_inv(lmax):
    n = (lmax + 1) ** 2
    res = np.zeros((n, n))
    p = {ptilde(m): m for m in range(n)}
    n = 0
    for l in range(lmax + 1):
        for _ in range(-l, l + 1):
            p_G(p, n, res[:, n])
            n += 1
    return res

In [102]:
a2_inv = A2_inv(2)

In [103]:
a2 = np.linalg.inv(a2_inv)

In [104]:
a2

array([[ 1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.5       ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.33333333,
         0.        ,  0.        ,  0.        ,  0.        ],
       [-0.        , -0.        , -0.        , -0.        , -0.        ,
        -0.33333333, -0.        , -0.        , -0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.5       ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0

### A (equation 14)

In [107]:
def A(lmax):
    a1 = A1(lmax)
    a2_inv = A2_inv(lmax)
    a2 = np.linalg.inv(a2_inv)
    return np.dot(a1,a2)

In [108]:
A(2)

{(0, 0, 0): 0, (1, 0, 0): 1, (0, 0, 1): 2, (0, 1, 0): 3, (2, 0, 0): 4, (1, 0, 1): 5, (1, 1, 0): 6, (0, 1, 1): 7, (0, 2, 0): 8}
(0, 0, 0) 0.28209479177387814
(0, 1, 0) 0.4886025119029199
(0, 0, 1) 0.4886025119029199
(1, 0, 0) 0.4886025119029199
(1, 1, 0) 1.0925484305920792
(0, 1, 1) 1.0925484305920792
(0, 0, 0) 0.6307831305050402
(2, 0, 0) -0.9461746957575603
(0, 2, 0) -0.9461746957575603
(1, 0, 1) 1.0925484305920792
(2, 0, 0) 0.5462742152960396
(0, 2, 0) -0.5462742152960396


array([[ 0.28209479,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.31539157,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.48860251,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.48860251,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.24430126,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        , -0.47308735,  0.        ,  0.54627422],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.36418281,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.36418281,
         0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0

### jaxoplanet dependencies

In [112]:
from functools import partial
from typing import Callable, Tuple

In [113]:
import jax
import jax.numpy as jnp
from scipy.special import roots_legendre

#### types

In [109]:
from typing import Any

Scalar = Any
Array = Any
PyTree = Any

#### kite area

In [110]:
def kite_area(a: Array, b: Array, c: Array) -> Array:
    def sort2(a: Array, b: Array) -> Tuple[Array, Array]:
        return jnp.minimum(a, b), jnp.maximum(a, b)

    a, b = sort2(a, b)
    b, c = sort2(b, c)
    a, b = sort2(a, b)

    square_area = (a + (b + c)) * (c - (a - b)) * (c + (a - b)) * (a + (b - c))
    return jnp.sqrt(jnp.maximum(square_area, 0.0))

In [114]:
def kappas(b: Array, r: Array) -> Tuple[Array, Array]:
    b2 = jnp.square(b)
    factor = (r - 1) * (r + 1)
    b_cond = jnp.logical_and(jnp.greater(b, jnp.abs(1 - r)), jnp.less(b, 1 + r))
    b_ = jnp.where(b_cond, b, 1)
    area = jnp.where(b_cond, kite_area(r, b_, 1), 0)
    return jnp.arctan2(area, b2 + factor), jnp.arctan2(area, b2 - factor)


In [115]:
def q_integral(l_max: int, lam: Array) -> Array:
    zero = jnp.zeros_like(lam)
    c = jnp.cos(lam)
    s = jnp.sin(lam)
    h = {
        (0, 0): 2 * lam + jnp.pi,
        (0, 1): -2 * c,
    }

    def get(u: int, v: int) -> Array:
        if (u, v) in h:
            return h[(u, v)]
        if u >= 2:
            comp = 2 * c ** (u - 1) * s ** (v + 1) + (u - 1) * get(u - 2, v)
        else:
            assert v >= 2
            comp = -2 * c ** (u + 1) * s ** (v - 1) + (v - 1) * get(u, v - 2)
        comp /= u + v
        h[(u, v)] = comp
        return comp

    U = []
    for l in range(l_max + 1):  # noqa
        for m in range(-l, l + 1):
            if l == 1 and m == 0:
                U.append((np.pi + 2 * lam) / 3)
                continue
            mu = l - m
            nu = l + m
            if (mu % 2) == 0 and (mu // 2) % 2 == 0:
                u = mu // 2 + 2
                v = nu // 2
                assert u % 2 == 0
                U.append(get(u, v))
            else:
                U.append(zero)

    return jnp.stack(U)

In [116]:
b = 1.5
r = 1

In [117]:
k0,k1 = kappas(b,r)

In [118]:
k0

DeviceArray(0.7227342, dtype=float32, weak_type=True)

In [119]:
k1

DeviceArray(0.7227342, dtype=float32, weak_type=True)

In [120]:
q = q_integral(2, 0.5*jnp.pi - k1)

In [121]:
q

DeviceArray([ 2.9149368,  0.       ,  1.6125723, -0.1929193,  2.2947197,
              0.       ,  0.       ,  0.       ,  0.6202171],            dtype=float32, weak_type=True)

In [122]:
q.shape

(9,)

In [127]:
def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Array:
    b2 = jnp.square(b)
    r2 = jnp.square(r)

    # This is a hack for when r -> 0 or b -> 0, so k2 -> inf
    factor = 4 * b * r
    k2_cond = jnp.less(factor, 10 * jnp.finfo(factor.dtype).eps)
    factor = jnp.where(k2_cond, 1, factor)
    k2 = jnp.maximum(0, (1 - r2 - b2 + 2 * b * r) / factor)

    # And for when r -> 0
    r_cond = jnp.less(r, 10 * jnp.finfo(r.dtype).eps)
    delta = (b - r) / (2 * jnp.where(r_cond, 1, r))

    roots, weights = roots_legendre(order)
    rng = 0.5 * kappa0
    phi = rng * roots
    c = jnp.cos(phi + 0.5 * kappa0)
    s = jnp.sin(phi)
    s2 = jnp.square(s)

    f0 = jnp.maximum(0, jnp.where(k2_cond, 1 - r2, factor * (k2 - s2))) ** 1.5
    a1 = s2 - jnp.square(s2)
    a2 = jnp.where(r_cond, 0, delta + s2)
    a4 = 1 - 2 * s2

    ind = []
    arg = []
    n = 0
    for l in range(l_max + 1):  # noqa
        fa3 = (2 * r) ** (l - 1) * f0
        for m in range(-l, l + 1):
            mu = l - m
            nu = l + m

            if mu == 1 and l == 1:
                omz2 = r2 + b2 - 2 * b * r * c
                cond = jnp.less(omz2, 10 * jnp.finfo(omz2.dtype).eps)
                omz2 = jnp.where(cond, 1, omz2)
                z2 = jnp.maximum(0, 1 - omz2)
                result = 2 * r * (r - b * c) * (1 - z2 * jnp.sqrt(z2)) / (3 * omz2)
                arg.append(jnp.where(cond, 0, result))

            elif mu % 2 == 0 and (mu // 2) % 2 == 0:
                arg.append(
                    2 * (2 * r) ** (l + 2) * a1 ** (0.25 * (mu + 4)) * a2 ** (0.5 * nu)
                )

            elif mu == 1 and l % 2 == 0:
                arg.append(fa3 * a1 ** (l // 2 - 1) * a4)

            elif mu == 1:
                arg.append(fa3 * a1 ** ((l - 3) // 2) * a2 * a4)

            elif (mu - 1) % 2 == 0 and ((mu - 1) // 2) % 2 == 0:
                arg.append(2 * fa3 * a1 ** ((mu - 1) // 4) * a2 ** (0.5 * (nu - 1)))

            else:
                n += 1
                continue

            ind.append(n)
            n += 1

    P0 = rng * jnp.sum(jnp.stack(arg) * weights[None, :], axis=1)
    P = jnp.zeros(l_max**2 + 2 * l_max + 1)

    # Yes, using np not jnp here: 'ind' is always static.
    inds = np.stack(ind)

    return P.at[inds].set(P0)

In [123]:
def solution_vector(l_max: int, order: int = 20) -> Callable[[Array, Array], Array]:
    n_max = l_max**2 + 2 * l_max + 1

    @jax.jit
    @partial(jnp.vectorize, signature=f"(),()->({n_max})")
    def impl(b: Array, r: Array) -> Array:
        b = jnp.abs(b)
        r = jnp.abs(r)
        kappa0, kappa1 = kappas(b, r)
        P = p_integral(order, l_max, b, r, kappa0)
        Q = q_integral(l_max, 0.5 * jnp.pi - kappa1)
        return Q - P

    return impl

In [137]:
res = st(np.float32(b),np.float32(r))

In [138]:
res.shape

(9,)

In [139]:
res

DeviceArray([ 2.6882808 ,  0.        ,  1.8490419 , -0.33998373,
              2.233245  ,  0.        ,  0.        , -0.5238213 ,
              0.5238184 ], dtype=float32)