In [None]:
import starry
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
from jaxoplanet.experimental.starry.light_curve.ylm import light_curve
from jaxoplanet.experimental.starry.light_curve.inference import (
    design_matrix,
    cast,
    set_data,
    set_prior,
    map_solve,
    solve,
    get_lnlike,
    get_lnlike_woodbury,
    lnlike,
)
from jaxoplanet.test_utils import assert_allclose

In [None]:
# Switch starry to non-lazy (eager) evaluation so that `m.ops.flux(...)` below
# presumably yields concrete numeric arrays that `assert_allclose` can compare
# directly, rather than symbolic graph objects — see starry's `config.lazy` docs.
starry.config.lazy = False

# The `light_curve()` test has no occultations

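With the given values of `zo` (all zero), the light curve is constant for every `theta`. The occultation mask skips every point with `zo <= 0`, so no occultations are computed. With `inc = 0`, rotation alone does not change the total flux either. Setting `zo = xo` instead would give the test some occultations.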

In [None]:
# Current test: compare jaxoplanet's light_curve() with starry's reference flux.
# NOTE: zo is identically zero here, so the occultation mask flags every point
# as "not occulted" and this configuration never exercises the occultation path.

l_max = 5
ro = 0.1
xo = jnp.linspace(0, ro + 2, 500)
yo = jnp.zeros(500)
zo = jnp.zeros(500)
inc = 0
obl = np.pi / 2
theta = jnp.linspace(0, np.pi, 500)
n_max = (l_max + 1) ** 2  # number of Ylm coefficients up to degree l_max

# Seeded generator so the random map (and therefore the test) is reproducible
# on Restart & Run All; the unseeded global RNG changed the map on every run.
rng = np.random.default_rng(0)
y = rng.uniform(0, 1, n_max)
y[0] = 1.0  # pin the first (l = 0) coefficient

m = starry.Map(l_max)
# Reference flux from starry's low-level op, rescaled by 0.5 * sqrt(pi) so it is
# on the same normalization as light_curve() (the assertion below relies on it).
expect = m.ops.flux(theta, xo, yo, zo, ro, inc, obl, y, m._u, m._f) * (
    0.5 * np.sqrt(np.pi)
)

calc = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)

assert_allclose(calc, expect)

In [None]:
calc[:30]  # Leading 30 fluxes are identical: nothing is occulted in this setup.

In [None]:
# Every entry of zo is 0, which is why the mask in the next cell excludes all points.

zo[:30]

In [None]:
# Rebuild the occultation mask by hand (assumed to mirror the one inside
# light_curve()): True marks a point for which no occultation is computed.
# Because zo is 0 everywhere, the `zo <= 0` clause alone masks every point.

b = jnp.sqrt(xo**2 + yo**2)  # distance of (xo, yo) from the origin

# occultation mask, assembled from its three clauses (same order as before)
no_overlap = b >= (1.0 + ro)
zo_nonpositive = zo <= 0.0
zero_radius = ro == 0.0
cond_rot = no_overlap | zo_nonpositive | zero_radius
cond_rot[:30]

In [None]:
# Proposed change: set zo == xo so that zo > 0 for every point after the first.
# Points with b < 1 + ro then pass the occultation mask, so the test actually
# exercises the occultation computation.

l_max = 5

ro = 0.1
xo = jnp.linspace(0, ro + 2, 500)
yo = jnp.zeros(500)
zo = xo  # same values as xo; jax arrays are immutable, so sharing is safe
inc = 0
obl = np.pi / 2
theta = jnp.linspace(0, np.pi, 500)
n_max = (l_max + 1) ** 2  # number of Ylm coefficients up to degree l_max

# Seeded generator so the random map (and therefore the test) is reproducible.
rng = np.random.default_rng(0)
y = rng.uniform(0, 1, n_max)
y[0] = 1.0  # pin the first (l = 0) coefficient

m = starry.Map(l_max)
# Reference flux from starry's low-level op, rescaled by 0.5 * sqrt(pi) so it is
# on the same normalization as light_curve().
expect = m.ops.flux(theta, xo, yo, zo, ro, inc, obl, y, m._u, m._f) * (
    0.5 * np.sqrt(np.pi)
)

calc = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)

assert_allclose(calc, expect)
# Guard the purpose of this change: with occultations the light curve must not
# be flat (the zo == 0 setup above produces a constant light curve).
assert not jnp.allclose(calc, calc[0]), "expected occultations to vary the flux"

In [None]:
calc[:30]  # Leading 30 fluxes: occultations occur here, so the values vary.

In [None]:
# Same hand-built mask as before, now with zo == xo: only the first point has
# zo <= 0, so points with b < 1 + ro stay unmasked and get occultations computed.

b = jnp.sqrt(xo**2 + yo**2)  # distance of (xo, yo) from the origin

# occultation mask, assembled from its three clauses (same order as before)
no_overlap = b >= (1.0 + ro)
zo_nonpositive = zo <= 0.0
zero_radius = ro == 0.0
cond_rot = no_overlap | zo_nonpositive | zero_radius
cond_rot[:30]