In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, vmap
import numpyro
from numpyro import distributions as dist
from numpyro import infer

from numpyro_ext import distributions as distx
from numpyro_ext import info, optim

# Request two host (CPU) devices, matching the two chains used by the MCMC run
# further down. NOTE(review): this presumably has to run before JAX initialises
# its backend to take effect — keep it ahead of any JAX computation.
numpyro.set_host_device_count(2)
# Switch JAX to 64-bit floating point.
jax.config.update("jax_enable_x64", True)


from jaxoplanet import light_curves, orbits
import arviz as az
import corner

In [None]:
# Show how many devices JAX sees in this process; after the
# set_host_device_count(2) call above this is expected to be 2 on a CPU-only
# machine — TODO confirm on the target machine.
jax.local_device_count()

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, vmap
from functools import partial

import numpyro
from numpyro import distributions as dist
from numpyro import infer

from numpyro_ext import distributions as distx
from numpyro_ext import info, optim

import arviz as az

In [None]:
# Debugging and precision flags, set through the top-level `jax.config` object
# (the spelling already used by the other cells of this notebook). The legacy
# `from jax.config import config` import is deliberately not used: the
# `jax.config` submodule no longer exists in recent JAX releases, so that
# import fails there.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)
jax.config.update("jax_enable_x64", True)

# Repeats the device-count request made in the first cell.
# NOTE(review): JAX has already been queried by then, so this second call is
# presumably a no-op — kept so the cell's behaviour is unchanged.
numpyro.set_host_device_count(2)
numpyro.enable_x64()

In [None]:
from jaxoplanet.experimental.starry import YlmLightCurve, RotationPhase
from jaxoplanet.orbits import KeplerianOrbit

# rotation phase

In [None]:
# Build a rotation phase from the period alone; every other argument of
# `RotationPhase.init` is left at its default.
phase = RotationPhase.init(period=10.0)
print(phase)

In [None]:
# User-defined values: override the defaults by passing `t_0` and `theta_0`
# explicitly alongside the period.
phase = RotationPhase.init(period=10.0, t_0=0.1, theta_0=0.1)
print(phase)

# Ylm light curve

In [None]:
# Ylm light curve with `l_max=5`; all other arguments of `YlmLightCurve.init`
# are left at their defaults.
LY = YlmLightCurve.init(l_max=5)
print(LY)

In [None]:
# User-defined values: pass the coefficients `y` together with `inc` and `obl`
# (presumably inclination and obliquity, in radians — confirm against the
# YlmLightCurve documentation).
# 36 coefficients are drawn for l_max=5, which equals (l_max + 1)**2 —
# presumably one per Ylm term; verify before changing l_max.
# NOTE(review): this draw is not seeded, so the printed object differs from
# run to run.
y_coeff = np.random.uniform(0, 1, 36)
inc = jnp.pi / 4
obl = 0.1
LY = YlmLightCurve.init(l_max=5, inc=inc, obl=obl, y=y_coeff)
print(LY)

# Example

In [None]:
# Time grid of 1000 samples from -0.1 to 1, slightly longer than the 1.0
# orbital period configured below.
t = jnp.linspace(-0.1, 1, 1000)

# Ingredients of the light curve: a Keplerian orbit and a rotation phase.
orbit = KeplerianOrbit.init(period=1.0, radius=0.1)
phase = RotationPhase.init(period=10.0)

# Evaluate a Ylm light curve (l_max = 5) on the time grid.
LY = YlmLightCurve.init(5)
lc = LY.light_curve(orbit=orbit, phase=phase, t=t)

plt.plot(t, lc, lw=2, color="C0")
plt.xlabel("time [days]")
plt.ylabel("relative flux")
_ = plt.xlim(t.min(), t.max())

In [None]:
%time jax.jacfwd(LY.light_curve, argnums=2)(orbit, phase, t)

In [None]:
def fold(period, t0, time, flux):
    """Phase-fold a time series and return it ordered by phase.

    Each time stamp is mapped to a phase in [-0.5, 0.5) such that ``t0`` (and
    every ``t0 + k * period``) lands on phase 0.

    Parameters
    ----------
    period : float
        Folding period, in the same units as ``time``.
    t0 : float
        Reference time that is mapped to phase 0.
    time : array
        Time stamps to fold.
    flux : array
        Values associated with ``time``; must support integer-array indexing.

    Returns
    -------
    tuple
        ``(phase, flux)`` with both arrays reordered by increasing phase.
    """
    half_period = 0.5 * period
    # Shift so that t0 sits in the middle of the folding window, wrap, then
    # rescale to a fraction of the period centred on zero.
    folded = ((time - (t0 - half_period)) % period) / period - 0.5
    order = np.argsort(folded)
    return folded[order], flux[order]


def light_curve(params, t, l_max=5):
    """Evaluate the Ylm light curve described by ``params`` at times ``t``.

    Parameters
    ----------
    params : mapping
        Must provide the keys ``"period"``, ``"radius"`` and ``"t0"``
        (forwarded to ``KeplerianOrbit.init`` as ``period``, ``radius`` and
        ``time_transit``) and ``"rot_period"`` (forwarded to
        ``RotationPhase.init`` as ``period``).
    t : array
        Times at which the light curve is evaluated.
    l_max : int, optional
        Value handed to ``YlmLightCurve.init``. Defaults to 5, the value that
        was previously hard-coded, so existing two-argument calls behave as
        before.

    Returns
    -------
    The result of ``YlmLightCurve.light_curve(orbit, phase, t)``.
    """
    # The light curve calculation requires an orbit ...
    orbit = KeplerianOrbit.init(
        period=params["period"], radius=params["radius"], time_transit=params["t0"]
    )
    # ... and a rotation phase.
    phase = RotationPhase.init(period=params["rot_period"])
    LY = YlmLightCurve.init(l_max)
    return LY.light_curve(orbit, phase, t)

In [None]:
# Seed NumPy's global generator so the drawn parameters and noise are
# repeatable; the order of the random calls below therefore matters.
np.random.seed(11)
period_true = np.random.uniform(5, 20)
t0_true = np.random.uniform(low=0, high=5.0)
rot_period_true = 10.0
t = np.arange(0, 80, 0.02)
yerr = 5e-4

true_params = {
    "period": period_true,
    "t0": t0_true,
    "radius": 0.1,
    "rot_period": rot_period_true,
}

print(true_params)

# Noise-free Ylm light curve, then the simulated data: the same curve plus
# Gaussian noise with standard deviation `yerr`.
lc_true = light_curve(true_params, t)
lc = lc_true + yerr * np.random.normal(size=len(t))

# Fold data and model on the true period and transit time.
phase, flux = fold(period_true, t0_true, t, lc)
_, flux_true = fold(period_true, t0_true, t, lc_true)

fig = plt.figure()
ax, ax1 = fig.subplots(2, 1)

# Top panel: full time series.
ax.plot(t, lc, "C0.")
ax.plot(t, lc_true, color="k")
ax.set_ylabel("relative flux")
ax.set_xlabel("time [days]")

# Bottom panel: folded curve, zoomed in around phase 0. The limits are set on
# ax1 explicitly instead of relying on pyplot's "current axes".
ax1.plot(phase, flux, "C0.")
ax1.plot(phase, flux_true, color="k")
ax1.set_ylabel("relative flux")
ax1.set_xlabel("phase")
ax1.set_xlim(-0.05, 0.05)

# Keep the top panel's x label from colliding with the bottom panel.
fig.tight_layout()

## numpyro model

In [None]:
def model(t, yerr, y=None):
    """NumPyro model of the simulated light curve.

    All four light-curve parameters and a jitter term are sampled (none is
    held fixed). The priors are Normal distributions centred on the notebook
    globals ``period_true``, ``t0_true`` and ``rot_period_true``, on
    ``log(0.1)`` for the log-radius, and on ``log(yerr)`` for the log-jitter.

    Parameters
    ----------
    t : array
        Times at which the light curve is evaluated.
    yerr : float
        Reported per-point uncertainty; added in quadrature with the jitter.
    y : array, optional
        Observed flux. When ``None`` the ``"flux"`` site is sampled instead of
        being conditioned on.
    """
    # Sample sites are declared in a fixed order; keep it stable.
    jitter_prior = dist.Normal(jnp.log(yerr), 1.0)
    log_jitter = numpyro.sample("log_jitter", jitter_prior)

    period = numpyro.sample("period", dist.Normal(period_true, 0.001))

    t0 = numpyro.sample("t0", dist.Normal(t0_true, 0.01))
    # Same quantity scaled by 24 * 60 (days to minutes, assuming t0 is in days).
    numpyro.deterministic("t0_minutes", t0 * 24 * 60)

    log_radius = numpyro.sample("log_r", dist.Normal(jnp.log(0.1), 2.0))
    radius = numpyro.deterministic("r", jnp.exp(log_radius))

    rot_period = numpyro.sample("rot_period", dist.Normal(rot_period_true, 0.001))

    predicted_flux = light_curve(
        {"t0": t0, "radius": radius, "period": period, "rot_period": rot_period},
        t,
    )
    # Reported error and jitter combined in quadrature.
    total_sigma = jnp.sqrt(yerr**2 + jnp.exp(2 * log_jitter))

    numpyro.sample("flux", dist.Normal(predicted_flux, total_sigma), obs=y)

In [None]:
# Starting values handed to the sampler's init strategy, keyed by sample-site
# name. They sit at the centres of the corresponding priors; "log_jitter" has
# no entry here.
init_params = dict(
    period=period_true,
    t0=t0_true,
    log_r=jnp.log(0.1),
    rot_period=rot_period_true,
)

In [None]:
# Make sure the `jax_log_compiles` flag is off before sampling.
jax.config.update("jax_log_compiles", False)

In [None]:
sampler_wn = infer.MCMC(
    infer.NUTS(
        model,
        target_accept_prob=0.9,
        dense_mass=False,
        init_strategy=infer.init_to_value(values=init_params),
        regularize_mass_matrix=False,
    ),
    num_warmup=2,
    num_samples=3,
    num_chains=2,
    progress_bar=True,
)
%time sampler_wn.run(jax.random.PRNGKey(11), t, yerr, lc)

In [None]:
# Stand-alone NUTS kernel configured exactly like the one wrapped by
# `sampler_wn`; it is only inspected in the following cells.
sampler = infer.NUTS(
    model,
    init_strategy=infer.init_to_value(values=init_params),
    dense_mass=False,
    regularize_mass_matrix=False,
    target_accept_prob=0.9,
)

In [None]:
# Interactive inspection: list the attribute names of the NUTS kernel.
dir(sampler)

In [None]:
# Interactive inspection: print every attribute of the NUTS kernel next to its
# name (includes dunder and private attributes, so the output is long).
for key in dir(sampler):
    print(key, getattr(sampler, key))

In [None]:
# Convert the finished MCMC run to an ArviZ object and tabulate the selected
# sample sites.
# NOTE(review): the run above keeps only 3 draws per chain (num_samples=3), so
# the summary statistics and diagnostics shown here are not meaningful until
# the run is lengthened.
inf_data_wn = az.from_numpyro(sampler_wn)
az.summary(inf_data_wn, var_names=["t0", "r", "period", "rot_period", "log_jitter"])