In [None]:
import starry
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
from jaxoplanet.experimental.starry.light_curve.ylm import light_curve
from jaxoplanet.experimental.starry.light_curve.inference import (
    design_matrix,
    cast,
    set_data,
    set_prior,
    map_solve,
    solve,
    get_lnlike,
    get_lnlike_woodbury,
    lnlike,
)
from jaxoplanet.test_utils import assert_allclose

In [None]:
# Run starry in non-lazy mode: later cells compare starry outputs directly
# against jaxoplanet arrays with numpy assertions.
starry.config.lazy = False

# light_curve() test has no occultations.

With the given values of zo (all zero), the light curve is invariant in theta. The condition zo <= 0 masks out every point, so no occultations are computed. Setting zo = xo instead would give the test some occultations.

In [None]:
# current test: every zo is 0, so no point is occulted (see the mask below)

l_max = 5
ro = 0.1
n_max = (l_max + 1) * (l_max + 1)

# occultor path: along the x axis, from the disk centre to beyond the limb
xo = jnp.linspace(0, ro + 2, 500)
yo = jnp.zeros(500)
zo = jnp.zeros(500)

# map orientation and rotational phase
inc = 0
obl = np.pi / 2
theta = jnp.linspace(0, np.pi, 500)

# random map coefficients with the Y_{0,0} term fixed at 1
y = np.random.uniform(low=0, high=1, size=n_max)
y[0] = 1.0

m = starry.Map(l_max)
# starry's op-level flux is rescaled by sqrt(pi) / 2 before the comparison
# (presumably a normalization-convention difference between the libraries)
expect = m.ops.flux(theta, xo, yo, zo, ro, inc, obl, y, m._u, m._f) * (
    0.5 * np.sqrt(np.pi)
)
calc = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)

assert_allclose(calc, expect)

In [None]:
calc[:30]  # first 30 fluxes: all identical, since nothing is occulted

In [None]:
# ... because every zo value is 0:
zo[:30]

In [None]:
# ... so every point is masked out and none is occulted.
b = jnp.sqrt(xo * xo + yo * yo)  # projected distance of the occultor centre

# occultation mask (True = point skipped by the occultation computation)
cond_rot = (zo <= 0.0) | (b >= 1.0 + ro) | (ro == 0.0)
cond_rot[:30]

In [None]:
# proposed change: zo == xo, so points with zo > 0 can now be occulted

l_max = 5
ro = 0.1
n_max = (l_max + 1) * (l_max + 1)

xo = jnp.linspace(0, ro + 2, 500)
yo = jnp.zeros(500)
zo = xo  # identical values to xo; jax arrays are immutable, so sharing is safe

inc = 0
obl = np.pi / 2
theta = jnp.linspace(0, np.pi, 500)

# random map coefficients with the Y_{0,0} term fixed at 1
y = np.random.uniform(low=0, high=1, size=n_max)
y[0] = 1.0

m = starry.Map(l_max)
# rescale starry's op-level flux by sqrt(pi) / 2 before comparing
expect = m.ops.flux(theta, xo, yo, zo, ro, inc, obl, y, m._u, m._f) * (
    0.5 * np.sqrt(np.pi)
)
calc = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)

# only agrees to an absolute tolerance of 1e-5 once occultations occur
np.testing.assert_allclose(expect, calc, atol=1e-5)

In [None]:
calc[:30]  # first 30 fluxes now vary: occultations occur

In [None]:
# With zo == xo, not every point is masked out, so some occultations are computed.
b = jnp.sqrt(xo * xo + yo * yo)  # projected distance of the occultor centre

# occultation mask (True = point skipped by the occultation computation)
cond_rot = (zo <= 0.0) | (b >= 1.0 + ro) | (ro == 0.0)
cond_rot[:30]

### NOTE: angle unit for inc, obl

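The map's angle unit is degrees (see `m._angle_unit` below). Yet passing `obl=np.pi/2`, which looks like a value in radians, gives the correct result. Does jaxoplanet expect these angles in degrees or in radians?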

In [None]:
# starry's stored inclination/obliquity, its angle unit, and the obl passed to ops.flux
(m.inc, m.obl, m._angle_unit, obl)

# Design matrix test

- I can't figure out how to create a jaxoplanet design matrix that matches starry's design matrix.
- I can't figure out how to compute equivalent starry design matrices using m.design_matrix() and m.ops.X().

In [None]:
# Shared setup for the flux and design-matrix comparisons below (zo == xo,
# so some points are occulted).
l_max = 5
ro = 0.1
xo = jnp.linspace(0, ro + 2, 500)
yo = jnp.zeros(500)
zo = jnp.linspace(0, ro + 2, 500)
inc = 0
obl = np.pi / 2
theta = jnp.linspace(0, np.pi, 500)
n_max = (l_max + 1) ** 2
y = np.random.uniform(0, 1, n_max)
y[0] = 1.0

m = starry.Map(l_max)
# Load the l >= 1 coefficients through starry's public indexing API.
# `m.y` is a read-only property; writing `m.y[:] = y` only worked because it
# happened to hand back the map's internal buffer. The Y_{0,0} term stays at
# starry's fixed value of 1, which matches y[0] above.
m[1:, :] = y[1:]

First, check how m.flux() compares with m.ops.flux().

In [None]:
# High-level API vs. the low-level op. m.flux gets no theta/inc/obl here, so
# it presumably falls back to the map's own settings; ops.flux gets them
# explicitly, together with the coefficient vector.
expect_A = m.flux(xo=xo, yo=yo, zo=zo, ro=ro)
expect_B = m.ops.flux(
    theta, xo, yo, zo, ro, inc, obl, y, m._u, m._f
)

In [None]:
np.testing.assert_allclose(actual=expect_A, desired=expect_B)  # high-level vs. op-level flux

In [None]:
# starry op-level flux, rescaled by sqrt(pi) / 2, and the jaxoplanet curve
expect_ops = m.ops.flux(
    theta, xo, yo, zo, ro, inc, obl, y, m._u, m._f
) * (0.5 * np.sqrt(np.pi))
calc = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)

In [None]:
np.testing.assert_allclose(actual=calc, desired=expect_ops, atol=1e-5)  # passes only at atol=1e-5

Now, check whether m.design_matrix() can be made equivalent to m.ops.X().

In [None]:
# Three candidate starry design matrices.
# NOTE(review): `theta` spans 0..pi (radians). `m.design_matrix` presumably
# reads theta in the map's angle unit (degrees by default, see m._angle_unit),
# while `m.ops.X` receives it as-is -- pass np.degrees(theta) to
# design_matrix if the two are meant to agree. design_matrix also presumably
# uses the map's own inc/obl rather than the `inc`/`obl` passed to ops.X.
st_X = m.design_matrix(xo=xo, yo=yo, zo=zo, ro=ro)
st_X_theta = m.design_matrix(theta=theta, xo=xo, yo=yo, zo=zo, ro=ro)
st_X_ops = m.ops.X(theta, xo, yo, zo, ro, inc, obl, m._u, m._f)

In [None]:
np.testing.assert_allclose(actual=st_X, desired=st_X_theta)  # default theta vs. explicit theta

In [None]:
np.testing.assert_allclose(actual=st_X, desired=st_X_ops)  # default theta vs. op-level X

In [None]:
np.testing.assert_allclose(actual=st_X_ops, desired=st_X_theta)  # op-level X vs. explicit theta

Compare against jaxoplanet implementation.

In [None]:
# jaxoplanet design matrix for the same geometry and coefficients
j_X = design_matrix(
    l_max=l_max, inc=inc, obl=obl, y=y, xo=xo, yo=yo, zo=zo, ro=ro, theta=theta
)

In [None]:
np.testing.assert_allclose(actual=j_X, desired=st_X)

In [None]:
np.testing.assert_allclose(actual=j_X, desired=st_X_theta)

In [None]:
np.testing.assert_allclose(actual=j_X, desired=st_X_ops)

# Tests in the style of test_solve_greedy.py

## test_solve()

In [None]:
from scipy.stats import multivariate_normal

Original code from test_solve_greedy()

In [None]:
def data():
    """Reflected-light dipole fixture copied from starry's test_solve_greedy.

    Returns
    -------
    tuple
        (map, kwargs, amp_true, inc_true, y_true, sigma, flux). ``kwargs`` holds
        the ``theta``/``xs``/``ys``/``zs`` arguments passed to ``map.flux``, and
        ``flux`` is that light curve plus white noise of std ``sigma``.
    """
    # True map parameters
    amp_true = 0.75
    inc_true = 60
    y_true = np.array([1, 0.1, 0.2, 0.3])

    # Instantiate a dipole map carrying those parameters
    dipole = starry.Map(ydeg=1, reflected=True)
    dipole.amp = amp_true
    dipole[1, :] = y_true[1:]
    dipole.inc = inc_true

    # Rotational phase and (presumably) illumination-source positions;
    # phi is in degrees and converted to radians once
    theta = np.linspace(0, 360, 100)
    phi = 3.5 * theta
    phi_rad = phi * np.pi / 180
    kwargs = dict(
        theta=theta,
        xs=np.cos(phi_rad),
        ys=0.1 * np.cos(phi_rad),
        zs=np.sin(phi_rad),
    )

    # Synthetic light curve with just a little noise
    flux = dipole.flux(**kwargs)
    sigma = 1e-5
    np.random.seed(1)
    flux += np.random.randn(len(theta)) * sigma

    return (dipole, kwargs, amp_true, inc_true, y_true, sigma, flux)


map, kwargs, amp_true, inc_true, y_true, sigma, flux = data()

# Generous prior on the map coefficients, and the noisy dataset
map.set_prior(L=1)
map.set_data(flux, C=sigma**2)

# Solve the linear problem at the true inclination
map.inc = inc_true
mu, cho_cov = map.solve(**kwargs)

# The true coefficients (scaled by the amplitude) should be nearly as likely
# as the MAP solution itself
cov = np.dot(cho_cov, cho_cov.T)
LnL0 = multivariate_normal.logpdf(mu, mean=mu, cov=cov)
LnL = multivariate_normal.logpdf(amp_true * y_true, mean=mu, cov=cov)
assert LnL0 - LnL < 5.00

In [None]:
(LnL0, LnL)  # log-likelihood at the MAP solution vs. at the true coefficients

Modify the test above slightly to make it suitable for jaxoplanet testing.

In [None]:
# starry version of test


def data():
    """Noisy starry light curve of a random l_max = 1 map with an occultor.

    Returns
    -------
    tuple
        (map, kwargs, inc, obl, y, sigma, syn_flux). ``inc`` and ``obl`` are
        returned but are not applied to the map.
    """
    l_max = 1
    n_max = (l_max + 1) ** 2
    ro = 0.1
    inc = 0
    obl = np.pi / 2

    # Occultor moving along x, with zo == xo
    xo = np.linspace(0, ro + 2, 500)
    yo = np.zeros(500)
    zo = np.linspace(0, ro + 2, 500)
    # NOTE(review): theta spans 0..pi (radians), but starry's Map.flux
    # presumably reads theta in the map's angle unit (degrees by default).
    theta = np.linspace(0, np.pi, 500)

    # NOTE(review): y is drawn before np.random.seed(1) below, so its value
    # depends on the RNG state left by earlier cells.
    y = np.random.uniform(0, 1, n_max)
    y[0] = 1.0

    # Dipole map carrying the random l = 1 coefficients
    occulted = starry.Map(ydeg=l_max)
    occulted[1, :] = y[1:]

    kwargs = dict(theta=theta, xo=xo, yo=yo, zo=zo, ro=ro)
    clean_flux = occulted.flux(**kwargs)

    # Add just a little white noise
    sigma = 1e-5
    np.random.seed(1)
    syn_flux = clean_flux + np.random.randn(len(theta)) * sigma

    return (occulted, kwargs, inc, obl, y, sigma, syn_flux)


map, kwargs, inc, obl, y, sigma, syn_flux = data()

# Generous prior and the noisy dataset
map.set_prior(L=1)
map.set_data(syn_flux, C=sigma**2)

# Solve the linear problem
mu, cho_cov = map.solve(**kwargs)

# The true coefficients should be nearly as likely as the MAP solution.
# allow_singular=True tolerates a (near-)singular posterior covariance.
cov = np.dot(cho_cov, cho_cov.T)
LnL0 = multivariate_normal.logpdf(mu, mean=mu, cov=cov, allow_singular=True)
LnL = multivariate_normal.logpdf(y, mean=mu, cov=cov, allow_singular=True)
assert LnL0 - LnL < 5.00

In [None]:
(LnL0, LnL)  # log-likelihood at the MAP solution vs. at the true coefficients

Jaxoplanet version of map_solve()

In [None]:
# TODO: placeholder -- the jaxoplanet version of this test is not written yet.

## test_map_solve()

Now, a jaxoplanet version of a similar test (testing map_solve() rather than solve()).

In [None]:
def data():
    """Synthetic jaxoplanet light curve and design matrix for an l_max = 1 map.

    The random seed is set before *any* draw, so both the map coefficients
    and the noise are reproducible no matter which cells ran earlier.
    (Previously ``y`` was drawn before seeding and inherited the RNG state
    left behind by other cells.)

    Returns
    -------
    tuple
        (l_max, true_flux, syn_flux, sigma, y, X): the noiseless and noisy light
        curves, the noise std, the true coefficients and the design matrix.
    """
    np.random.seed(1)

    l_max = 1
    ro = 0.1
    xo = np.linspace(0, ro + 2, 500)
    yo = np.zeros(500)
    zo = np.linspace(0, ro + 2, 500)
    inc = 0
    obl = np.pi / 2
    theta = np.linspace(0, np.pi, 500)
    n_max = (l_max + 1) ** 2

    # Random map coefficients with the Y_{0,0} term fixed at 1
    y = np.random.uniform(0, 1, n_max)
    y[0] = 1.0

    # Generate a synthetic light curve with just a little noise
    true_flux = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)
    sigma = 1e-5
    syn_flux = true_flux + np.random.randn(len(theta)) * sigma

    X = design_matrix(l_max, inc, obl, y, xo, yo, zo, ro, theta)

    return (l_max, true_flux, syn_flux, sigma, y, X)


l_max, true_flux, syn_flux, sigma, y, X = data()

# Generous prior on the map coefficients, and the noisy dataset
(calc_mu, calc_L) = set_prior(l_max, L=1)
(calc_flux, calc_C) = set_data(syn_flux, C=sigma**2)

# Solve the linear problem.
# NOTE(review): calc_C[1] and calc_L[2] pick entries of set_data/set_prior's
# return tuples whose meaning is not visible here -- confirm against
# jaxoplanet's inference module.
mu, cho_cov = map_solve(X, syn_flux, calc_C[1], calc_mu, calc_L[2])

# The true coefficients should be nearly as likely as the MAP solution
cov = np.dot(cho_cov, np.transpose(cho_cov))
LnL0 = multivariate_normal.logpdf(mu, mean=mu, cov=cov)
LnL = multivariate_normal.logpdf(y, mean=mu, cov=cov)
assert LnL0 - LnL < 5.00

In [None]:
(LnL0, LnL)  # log-likelihood at the MAP solution vs. at the true coefficients

## test_lnlike()

Original starry test:

In [None]:
def data():
    """Reflected-light dipole fixture from starry's lnlike test.

    Returns
    -------
    tuple
        (map, kwargs, amp_true, inc_true, y_true, sigma, flux). ``kwargs`` holds
        the ``theta``/``xs``/``ys``/``zs`` arguments passed to ``map.flux``, and
        ``flux`` is that light curve plus white noise of std ``sigma``.
    """
    # True map parameters
    amp_true = 0.75
    inc_true = 60
    y_true = np.array([1, 0.1, 0.2, 0.3])

    # Instantiate a dipole map carrying those parameters
    dipole = starry.Map(ydeg=1, reflected=True)
    dipole.amp = amp_true
    dipole[1, :] = y_true[1:]
    dipole.inc = inc_true

    # Rotational phase and (presumably) illumination-source positions;
    # phi is in degrees and converted to radians once
    theta = np.linspace(0, 360, 100)
    phi = 3.5 * theta
    phi_rad = phi * np.pi / 180
    kwargs = dict(
        theta=theta,
        xs=np.cos(phi_rad),
        ys=0.1 * np.cos(phi_rad),
        zs=np.sin(phi_rad),
    )

    # Synthetic light curve with just a little noise
    flux = dipole.flux(**kwargs)
    sigma = 1e-5
    np.random.seed(1)
    flux += np.random.randn(len(theta)) * sigma

    return (dipole, kwargs, amp_true, inc_true, y_true, sigma, flux)


map, kwargs, amp_true, inc_true, y_true, sigma, flux = data()

# Generous prior on the map coefficients, and the noisy dataset
map.set_prior(L=1)
map.set_data(flux, C=sigma**2)

# Marginal log likelihood (Woodbury form) over a grid of inclinations,
# given in the map's angle unit
incs = [15, 30, 45, 60, 75, 90]
ll = np.zeros_like(incs, dtype=float)
for i, inc in enumerate(incs):
    map.inc = inc
    ll[i] = map.lnlike(woodbury=True, **kwargs)

# The true inclination should win, at the benchmarked likelihood
i_best = np.argmax(ll)
assert incs[i_best] == 60
assert np.allclose(ll[i_best], 974.221605)  # benchmarked

Can we create a starry version of our jaxoplanet test to get the benchmark?

In [None]:
def data():
    """Noisy starry light curve of a random l_max = 1 map seen at inc = 60.

    Returns
    -------
    tuple
        (map, kwargs, inc, obl, y, sigma, syn_flux).
    """
    l_max = 1
    n_max = (l_max + 1) ** 2
    ro = 0.1
    inc = 60
    # NOTE(review): obl is returned but never applied to the map (no
    # `map.obl = obl`), so the map keeps starry's default obliquity.
    obl = 90

    # Occultor moving along x, with zo == xo
    xo = np.linspace(0, ro + 2, 500)
    yo = np.zeros(500)
    zo = np.linspace(0, ro + 2, 500)
    # NOTE(review): theta spans 0..pi (radians), but starry's Map.flux
    # presumably reads theta in the map's angle unit (degrees by default).
    theta = np.linspace(0, np.pi, 500)

    # NOTE(review): y is drawn before np.random.seed(1) below, so its value
    # (and the benchmark in the next cell) depends on earlier cells' RNG state.
    y = np.random.uniform(0, 1, n_max)
    y[0] = 1.0

    # Dipole map carrying the random l = 1 coefficients at inclination inc
    tilted = starry.Map(ydeg=l_max)
    tilted[1, :] = y[1:]
    tilted.inc = inc

    kwargs = dict(theta=theta, xo=xo, yo=yo, zo=zo, ro=ro)
    clean_flux = tilted.flux(**kwargs)

    # Add just a little white noise
    sigma = 1e-5
    np.random.seed(1)
    syn_flux = clean_flux + np.random.randn(len(theta)) * sigma

    return (tilted, kwargs, inc, obl, y, sigma, syn_flux)


map, kwargs, inc, obl, y, sigma, syn_flux = data()

# Generous prior on the map coefficients, and the noisy dataset
map.set_prior(L=1)
map.set_data(syn_flux, C=sigma**2)

# Marginal log likelihood (dense form) over a grid of inclinations,
# given in the map's angle unit
incs = [0, 15, 30, 45, 60, 75, 90]
ll = np.zeros_like(incs, dtype=float)
for i, inc in enumerate(incs):
    map.inc = inc
    ll[i] = map.lnlike(woodbury=False, **kwargs)

# The true inclination should win, at the benchmarked likelihood
i_best = np.argmax(ll)
assert incs[i_best] == 60
assert np.allclose(ll[i_best], 5015)  # benchmarked

Now, a jaxoplanet version of the test.

In [None]:
def data():
    """Noisy jaxoplanet light curve of a random l_max = 1 map seen at inc = pi/3.

    Returns
    -------
    tuple
        (l_max, n_max, syn_flux, sigma, y, kwargs). ``kwargs`` holds every
        ``design_matrix`` argument except ``inc``, so the caller can scan
        over inclinations.
    """
    l_max = 1
    n_max = (l_max + 1) ** 2
    ro = 0.1
    inc = np.pi / 3
    obl = np.pi / 2

    # Occultor moving along x, with zo == xo
    xo = jnp.linspace(0, ro + 2, 500)
    yo = jnp.zeros(500)
    zo = jnp.linspace(0, ro + 2, 500)
    theta = jnp.linspace(0, np.pi, 500)

    # NOTE(review): y is drawn before np.random.seed(1) below, so its value
    # depends on the RNG state left by earlier cells; the benchmarked
    # log-likelihood in the next cell presumably relies on that state.
    y = np.random.uniform(0, 1, n_max)
    y[0] = 1.0

    kwargs = {
        "l_max": l_max,
        "obl": obl,
        "y": y,
        "xo": xo,
        "yo": yo,
        "zo": zo,
        "ro": ro,
        "theta": theta,
    }

    clean_flux = light_curve(l_max, inc, obl, y, xo, yo, zo, ro, theta)

    # Add just a little white noise
    sigma = 1e-5
    np.random.seed(1)
    syn_flux = clean_flux + np.random.randn(len(theta)) * sigma

    return (l_max, n_max, syn_flux, sigma, y, kwargs)


l_max, n_max, syn_flux, sigma, y, kwargs = data()

# Place a generous prior on the map coefficients.
(calc_mu, calc_L) = set_prior(l_max, L=1)
# Dense prior vector; only needed by the non-Woodbury alternative
# get_lnlike(X, syn_flux, calc_C[0], calc_mu, L).
L = calc_L[0] * jnp.ones(n_max)
LInv = calc_L[2] * jnp.ones(n_max)
lndetL = cast([calc_L[3]])

# Provide the dataset
(calc_flux, calc_C) = set_data(syn_flux, C=sigma**2)

# Compute the marginal log likelihood on the same inclination grid as the
# starry test above (0, 15, 30, 45, 60, 75, 90 degrees), in radians.
incs = [0, np.pi / 12, np.pi / 6, np.pi / 4, np.pi / 3, 5 * np.pi / 12, np.pi / 2]
ll = np.zeros_like(incs, dtype=float)
for i, inc in enumerate(incs):
    X = design_matrix(inc=inc, **kwargs)
    ll[i] = get_lnlike_woodbury(
        X, syn_flux, calc_C[2], calc_mu, LInv, calc_C[3], lndetL
    )

# Verify that we recover the true inclination (pi / 3). `ll` is a NumPy
# array, so use np.argmax, and compare the float inclination with a tolerance.
best = int(np.argmax(ll))
assert np.isclose(incs[best], np.pi / 3)
assert_allclose(ll[best], 5002.211, rtol=1e-5)  # benchmarked