In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, vmap
import numpyro
from numpyro import distributions as dist
from numpyro import infer

from numpyro_ext import distributions as distx
from numpyro_ext import info, optim

# Switch JAX to 64-bit floating point.
jax.config.update("jax_enable_x64", True)
# Ask NumPyro for 2 host devices; the MCMC run further down uses num_chains=2.
# NOTE(review): NumPyro documents that this should be called before JAX
# initialises its backend — confirm nothing earlier triggers that.
numpyro.set_host_device_count(2)

from jaxoplanet import light_curves, orbits
import arviz as az
import corner

# Transit fitting tutorial in jaxoplanet:

In this tutorial, I will attempt to reproduce the transit fitting example from the exoplanet documentation (https://gallery.exoplanet.codes/tutorials/transit/) using the package jaxoplanet. jaxoplanet is built on the JAX autodiff package, which allows us to do gradient-based inference. I will use NumPyro for the sampling and will also demonstrate how to switch between Hamiltonian Monte Carlo and nested sampling with jaxns.

Warning: this code is meant to reproduce the exoplanet docs as closely as possible! Plagiarism is the point (with [attribution](https://gallery.exoplanet.codes/tutorials/transit/), of course).

*jaxoplanet* includes methods for computing the light curves of transiting planets. In its simplest form this can be used to evaluate a light curve like you would do with [batman](https://github.com/lkreidberg/batman), for example:

In [None]:
from jaxoplanet.light_curves import LimbDarkLightCurve

# A light curve is evaluated on an orbit, so construct that first.
orbit = orbits.KeplerianOrbit.init(period=3.456, radius=0.1)

# Evaluation times and the two limb-darkening coefficients.
t = jnp.linspace(-0.1, 0.1, 1000)
u = [0.3, 0.2]

# Evaluate the limb-darkened light curve on the time grid.
lc = LimbDarkLightCurve(u).light_curve(orbit, t=t)

# Draw the first element of the result on explicit axes.
fig, ax = plt.subplots()
ax.plot(t, lc[0], lw=2, color="C0")
ax.set_xlabel("time [days]")
ax.set_ylabel("relative flux")
_ = ax.set_xlim(t.min(), t.max())

In [None]:
# Forward-mode Jacobian of the light curve with respect to its second
# positional argument (argnums=1), i.e. the times `t`.
limb_dark_model = LimbDarkLightCurve(u)
dlc_dt = jax.jacfwd(limb_dark_model.light_curve, argnums=1)
dlc_dt(orbit, t)

The cell above uses `jax.jacfwd` to differentiate the light curve with respect to time, but the real power comes from being able to do gradient-based inference with JAX.

# The transit model in NumPyro:

In this section, we will construct a simple transit fit model using NumPyro and then fit it to simulated data for a single planet. To start, let’s randomly sample a period and a transit time and then create a test light curve:

In [None]:
import numpy as np


def fold(period, t0, time, flux):
    """Phase-fold a time series and return it sorted by phase.

    The phase is measured in units of ``period`` and wrapped into
    ``[-0.5, 0.5)`` so that ``time == t0`` maps to phase 0.

    Args:
        period: Folding period, in the same units as ``time``.
        t0: Reference time that is mapped to phase 0.
        time: Array of sample times.
        flux: Array of values sampled at ``time``; indexed with the same
            ordering as ``time``.

    Returns:
        A ``(phase, flux)`` tuple, both reordered by increasing phase.
    """
    # Measure time from half a period before t0, wrap into one cycle, then
    # recentre so that t0 sits at zero instead of at 0.5.
    half_cycle_before_t0 = t0 - 0.5 * period
    wrapped = (time - half_cycle_before_t0) % period
    phase = wrapped / period - 0.5

    order = np.argsort(phase)
    return phase[order], flux[order]


def light_curve(params, t):
    """Evaluate a limb-darkened transit light curve at times ``t``.

    Args:
        params: Mapping with the keys ``"period"``, ``"radius"``, ``"t0"``
            (passed to the orbit as ``time_transit``) and ``"u"`` (passed to
            ``LimbDarkLightCurve``).
        t: Times at which to evaluate the light curve.

    Returns:
        The first element (index 0) of what
        ``LimbDarkLightCurve.light_curve`` returns.
    """
    # The light curve is computed on an orbit, so build that first.
    keplerian = orbits.KeplerianOrbit.init(
        time_transit=params["t0"],
        period=params["period"],
        radius=params["radius"],
    )
    limb_dark_model = LimbDarkLightCurve(params["u"])
    result = limb_dark_model.light_curve(keplerian, t=t)
    return result[0]

In [None]:
# Draw the "true" period and transit time reproducibly (legacy global NumPy seed).
# The order of the np.random calls in this cell determines the simulated data.
np.random.seed(11)
period_true = np.random.uniform(5, 20)
t0_true = np.random.uniform(low=0, high=5.0)

# Observation times and the per-point flux uncertainty.
t = np.arange(0, 80, 0.02)
yerr = 5e-4

true_params = {
    "period": period_true,
    "t0": t0_true,
    "radius": 0.1,
    "u": jnp.array([0.3, 0.2]),
}

print(true_params)

# Noise-free model light curve at the observation times.
lc_true = light_curve(true_params, t)

# Simulated observations: the model plus Gaussian noise scaled by yerr.
lc = lc_true + yerr * np.random.normal(size=len(t))

# Fold data and model on the true period and transit time. Both calls use the
# same times, so the two folded arrays share one phase ordering.
phase, flux = fold(period_true, t0_true, t, lc)
_, flux_true = fold(period_true, t0_true, t, lc_true)

fig = plt.figure()
ax, ax1 = fig.subplots(2, 1)

# Top panel: the full time series.
ax.plot(t, lc, "C0.")
ax.plot(t, lc_true, color="k")
ax.set_ylabel("relative flux")
ax.set_xlabel("time [days]")

# Bottom panel: phase-folded view, zoomed in around phase 0.
ax1.plot(phase, flux, "C0.")
ax1.plot(phase, flux_true, color="k")
ax1.set_ylabel("relative flux")
ax1.set_xlabel("phase")
# Set the zoom on the explicit axes instead of relying on pyplot's "current axes".
_ = ax1.set_xlim(-0.05, 0.05)

# Keep the top panel's x label from colliding with the bottom panel.
fig.tight_layout()

In [None]:
# Phase of every sample in `t`, using the same expression as `fold` (without
# the sort).
# NOTE(review): `ph` is not referenced by any other cell visible in this
# notebook — looks like a leftover scratch cell; confirm and remove.
ph = ((t - (t0_true - 0.5 * period_true)) % period_true) / period_true - 0.5

In [None]:
# Quick look at the full phase-folded simulated data (no zoom), with labelled
# axes so the figure stands on its own.
plt.plot(phase, flux)
plt.xlabel("phase")
# Assign the last call's return value so the cell does not echo a repr.
_ = plt.ylabel("relative flux")

In [None]:
# Enable 64-bit precision through NumPyro.
# NOTE(review): the first cell already sets "jax_enable_x64" via
# jax.config.update — this call looks redundant; confirm before removing.
numpyro.enable_x64()


def model(t, yerr, y=None):
    """NumPyro model of a single transit light curve with a jitter term.

    Args:
        t: Times at which the light curve is evaluated.
        yerr: Flux uncertainty; it enters the noise width and also centres
            the prior on ``log_jitter``.
        y: Observed flux, passed as ``obs`` to the ``"flux"`` site.

    The priors on ``period`` and ``t0`` are centred on the notebook-level
    ``period_true`` and ``t0_true``.
    """
    # Duration and impact parameter are kept fixed for simplicity. To fit for
    # them as well, one could add:
    #
    #   log_duration = numpyro.sample("log_duration", dist.Uniform(jnp.log(0.08), jnp.log(0.2)))
    #   b = numpyro.sample("b", dist.Uniform(0.0, 1.0))

    # The sample sites below are declared in a fixed order; keep it when editing.
    ln_jitter = numpyro.sample("log_jitter", dist.Normal(jnp.log(yerr), 1.0))
    limb_dark = numpyro.sample("u", distx.QuadLDParams())
    orbital_period = numpyro.sample("period", dist.Normal(period_true, 0.001))
    t_transit = numpyro.sample("t0", dist.Normal(t0_true, 0.01))
    # Record t0 multiplied by 24 * 60 (minutes, assuming t0 is in days — TODO confirm).
    numpyro.deterministic("t0_minutes", t_transit * 24 * 60)
    ln_radius = numpyro.sample("log_r", dist.Normal(jnp.log(0.1), 2.0))
    radius = numpyro.deterministic("r", jnp.exp(ln_radius))

    predicted_flux = light_curve(
        {
            "u": limb_dark,
            "t0": t_transit,
            "radius": radius,
            "period": orbital_period,
        },
        t,
    )

    # Total noise width: yerr and the jitter added in quadrature.
    sigma = jnp.sqrt(yerr**2 + jnp.exp(2 * ln_jitter))

    numpyro.sample("flux", dist.Normal(predicted_flux, sigma), obs=y)

In [None]:
# Starting values handed to the sampler's init strategy, keyed by sample-site name.
init_params = dict(
    period=period_true,
    t0=t0_true,
    log_r=jnp.log(0.1),
    u=jnp.array([0.3, 0.2]),
)

In [None]:
%timeit light_curve(true_params, t)

In [None]:
sampler_wn = infer.MCMC(
    infer.NUTS(
        model,
        target_accept_prob=0.9,
        dense_mass=False,
        init_strategy=infer.init_to_value(values=init_params),
        regularize_mass_matrix=False,
    ),
    num_warmup=1000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
%time sampler_wn.run(jax.random.PRNGKey(11), t, yerr, lc)

In [None]:
# Show how many local JAX devices this process sees (the first cell requests
# 2 host devices via numpyro.set_host_device_count).
jax.local_device_count()

In [None]:
# Convert the finished MCMC run to ArviZ and tabulate the variables of interest.
inf_data_wn = az.from_numpyro(sampler_wn)
summary_vars = ["t0", "r", "period", "u", "log_jitter"]
az.summary(inf_data_wn, var_names=summary_vars)

In [None]:
# Truth values for the corner plot. Build a new dict so that adding the "r"
# entry (the model's name for the radius) does not mutate `true_params`.
corner_params = {**true_params, "r": true_params["radius"]}
fig = corner.corner(
    inf_data_wn,
    bins=30,
    var_names=["t0", "r", "period", "u"],
    color="k",
    truths=corner_params,
)

## Nested Sampling Tests:

Using the package jaxns, we should be able to use the same NumPyro model to run a fit using nested sampling instead of NUTS/HMC. Here's some code to do this; I still need to document it further, including how to get jaxns working alongside NumPyro.

In [None]:
from numpyro.contrib.nested_sampling import NestedSampler
from jaxns.prior_transforms.prior import UniformBase

In [None]:
# Run nested sampling on the same model and data used for NUTS above.
ns = NestedSampler(model)
ns_run_key = jax.random.PRNGKey(2)
ns.run(ns_run_key, t, yerr, lc)

# Draw 10000 samples from the finished run, using a separate key.
ns_draw_key = jax.random.PRNGKey(121)
ns_samples = ns.get_samples(ns_draw_key, num_samples=10000)

In [None]:
# Wrap the nested-sampling draws with ArviZ; the dict is passed as
# from_dict's first positional argument.
# NOTE(review): ArviZ normally expects leading (chain, draw) dimensions —
# confirm the arrays in ns_samples are interpreted as intended, and that
# `dropna` is accepted by the installed ArviZ version.
inf_data = az.from_dict(ns_samples, dropna=True)

In [None]:
# Tabulate ArviZ summary statistics for the nested-sampling draws, restricted
# to the listed variables.
az.summary(inf_data, var_names=["t0", "r", "period", "u"])

In [None]:
import pandas as pd

# Flatten the nested-sampling InferenceData into a DataFrame and drop rows with
# missing values; the corner plot below reads from `df`.
df = inf_data.to_dataframe().dropna()

In [None]:
# Truth values for the corner plot. Build a new dict so that adding the "r"
# entry (the model's name for the radius) does not mutate `true_params`.
corner_params = {**true_params, "r": true_params["radius"]}
fig = corner.corner(
    df, bins=30, var_names=["t0", "r", "period", "u"], color="k", truths=corner_params
)