In [None]:
import lightkurve as lk
import numpy as np

def download_sector(tic_id, author="SPOC", flux_column="pdcsap_flux"):
    """Download one cleaned, normalized TESS light curve per sector for a target.

    Despite the name, every sector returned by the search is downloaded. When a
    sector has several products (e.g. 20 s and 120 s cadence), the one with the
    shortest exposure time is kept.

    Parameters
    ----------
    tic_id : int or str
        TESS Input Catalog identifier (without the "TIC " prefix).
    author : str, optional
        Pipeline author passed to ``lk.search_lightcurve`` (default "SPOC").
    flux_column : str, optional
        Column used as flux when downloading (default "pdcsap_flux").

    Returns
    -------
    dict
        Maps ``(sector, exptime)`` -> light curve after ``remove_nans()`` and
        ``normalize()``. Empty when the search finds nothing.
    """
    search = lk.search_lightcurve(f"TIC {tic_id}", mission="TESS", author=author)
    data = {}

    # An empty SearchResult may wrap a table without the "sequence_number" /
    # "exptime" columns used below, which would raise KeyError.
    if len(search) == 0:
        print(f"No TESS light curves found for TIC {tic_id} (author={author}).")
        return data

    tbl = search.table
    for sec in np.unique(tbl["sequence_number"]):
        in_sector = np.asarray(tbl["sequence_number"] == sec)
        sector_results = search[in_sector]
        exptimes = sector_results.table["exptime"]

        # Shortest exposure time = highest cadence product for this sector.
        best_idx = int(np.argmin(exptimes))
        exptime = exptimes[best_idx]
        # Slice (rather than scalar-index) so the selection stays a SearchResult.
        best = sector_results[best_idx:best_idx + 1]

        print(f"Downloading TIC {tic_id} Sector {sec} (exptime={exptime}s) ...")

        lc = best.download(flux_column=flux_column)
        if lc is None:
            print(f"  No light curve returned for Sector {sec}; skipping.")
            continue
        lc = lc.remove_nans().normalize()

        data[(int(sec), float(exptime))] = lc

    return data

# Fetch every available sector for the target; keys are (sector, exptime) tuples.
sector_data = download_sector(tic_id=29857954)





Downloading TIC 29857954 Sector 28 (exptime=20.0s) ...
Downloading TIC 29857954 Sector 68 (exptime=120.0s) ...
Downloading TIC 29857954 Sector 92 (exptime=120.0s) ...
Downloading TIC 29857954 Sector 95 (exptime=120.0s) ...


In [1]:
# =========================
# Transit fit (NumPyro + jaxoplanet)
# Sector-level global fit (u fixed)
# =========================

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from jaxoplanet.orbits import TransitOrbit
from jaxoplanet.light_curves import limb_dark_light_curve

# Turn on JAX's 64-bit mode and select the CPU platform for NumPyro.
# NOTE(review): this cell previously failed during `import jax` with
#   "jaxlib version 0.8.2 is newer than and incompatible with jax version 0.4.21".
# jax and jaxlib must be installed at matching versions, e.g.
#   %pip install -U "jax[cpu]"    (upgrades jax and jaxlib together)
# and then restart the kernel. numpyro / jaxoplanet may also need versions
# compatible with the new jax -- TODO confirm after upgrading.
jax.config.update("jax_enable_x64", True)
numpyro.set_platform("cpu")


# ---------- 0) Choose sector ----------
SECTOR = 95
# Exposure time [s] of the product to fit. `sector_data` (from download_sector)
# is keyed by (sector, exptime) tuples. Set EXPTIME = None to use whichever
# product was downloaded for SECTOR.
EXPTIME = 120.0
TIC = EXPTIME  # legacy alias for older cells -- this is an exposure time, NOT a TIC ID

if EXPTIME is None:
    _matching_keys = sorted(key for key in sector_data if key[0] == SECTOR)
else:
    _matching_keys = [key for key in [(SECTOR, float(EXPTIME))] if key in sector_data]

if not _matching_keys:
    raise KeyError(
        f"No light curve for sector {SECTOR} with exptime={EXPTIME}; "
        f"available (sector, exptime) keys: {sorted(sector_data)}"
    )

lc = sector_data[_matching_keys[0]]


# ---------- 1) Extract time/flux ----------
time_np, flux_np = lc.time.value, lc.flux.value

# Keep only cadences where both time and flux are finite.
mask = np.isfinite(time_np) & np.isfinite(flux_np)
time_np, flux_np = time_np[mask], flux_np[mask]

# Re-zero the time axis at the first cadence. This only improves numerical
# conditioning for the sampler; add t_ref back to report times in BTJD.
t_ref = float(np.min(time_np))
time0_np = time_np - t_ref

# Quick look at the cleaned light curve before fitting.
fig_raw, ax_raw = plt.subplots(figsize=(10, 4))
ax_raw.plot(time_np, flux_np, ".k", ms=2)
ax_raw.set_xlabel("Time [BTJD]")
ax_raw.set_ylabel("Flux")
ax_raw.set_title(f"Raw TESS Light Curve (Sector {SECTOR})")
fig_raw.tight_layout()
plt.show()

# ---------- 2) Error model ----------
# np.std(flux) over the whole sector includes the transit dips and any residual
# variability, so it tends to overstate the per-point noise and leaves the
# jitter term (sigma_jit) with nothing to absorb. Prefer the pipeline's
# per-cadence uncertainties; fall back to a robust MAD-based scatter where
# they are missing or non-positive.
sigma0 = float(1.4826 * np.median(np.abs(flux_np - np.median(flux_np))))  # MAD -> Gaussian sigma
if not (np.isfinite(sigma0) and sigma0 > 0):
    sigma0 = float(np.std(flux_np))

flux_err_np = np.asarray(lc.flux_err.value, dtype=float)[mask]
good_err = np.isfinite(flux_err_np) & (flux_err_np > 0)
err_np = np.where(good_err, flux_err_np, sigma0)

print("sigma0 =", sigma0)
print("using pipeline flux_err for", int(good_err.sum()), "of", good_err.size, "points")
print("time span (days) =", float(time_np.max() - time_np.min()))


# ---------- 3) Model (u fixed) ----------
# Limb-darkening coefficients (u1, u2) are held fixed during this debug stage
# to reduce parameter degeneracy.
U1_FIXED, U2_FIXED = 0.3, 0.2

def transit_model_global(
    t,
    yerr,
    y=None,
    *,
    t0_bounds=(0.0, 10.0),
    period_bounds=(8.0, 10.0),
    duration_bounds=(0.05, 0.3),
    r_bounds=(1e-4, 0.3),
    jitter_scale=0.0002,
    u=None,
):
    """
    Global transit model on (possibly) full-sector data.

    Free: t0, period, duration (sampled as logD), r, b, sigma_jit
    Fixed: limb-darkening coefficients

    Parameters
    ----------
    t : array
        Observation times, in the same (shifted) time system as ``t0_bounds``.
    yerr : array or float
        Per-point flux uncertainties, combined in quadrature with sigma_jit.
    y : array, optional
        Observed normalized flux. When None, the "obs" site is not conditioned.
    t0_bounds : (float, float)
        Uniform prior bounds for the reference transit time [days].
    period_bounds : (float, float)
        Uniform prior bounds for the orbital period [days].
    duration_bounds : (float, float)
        Bounds [days] of the log-uniform prior on the transit duration.
    r_bounds : (float, float)
        Uniform prior bounds for the radius ratio Rp/R*.
    jitter_scale : float
        Scale of the HalfNormal prior on the extra white-noise term.
    u : sequence of float, optional
        Limb-darkening coefficients; defaults to (U1_FIXED, U2_FIXED).

    The keyword defaults reproduce the original hard-coded priors exactly.
    """
    # NOTE(review): with the default t0 ~ U(0, 10) and period ~ U(8, 10),
    # t0 and t0 + period can both fall inside the t0 window, so the t0
    # posterior can be bimodal (first vs. second transit). Narrow t0_bounds
    # around a single transit if the trace plot shows this.
    t0 = numpyro.sample("t0", dist.Uniform(*t0_bounds))
    period = numpyro.sample("period", dist.Uniform(*period_bounds))

    # Log-uniform prior on the duration.
    d_lo, d_hi = duration_bounds
    logD = numpyro.sample("logD", dist.Uniform(jnp.log(d_lo), jnp.log(d_hi)))
    duration = numpyro.deterministic("duration", jnp.exp(logD))

    # Geometry: radius ratio and impact parameter (b restricted to [0, 1]).
    r = numpyro.sample("r", dist.Uniform(*r_bounds))
    b = numpyro.sample("b", dist.Uniform(0.0, 1.0))

    # Extra white noise added in quadrature to yerr.
    sigma_jit = numpyro.sample("sigma_jit", dist.HalfNormal(jitter_scale))

    orbit = TransitOrbit(
        period=period,
        duration=duration,
        time_transit=t0,
        impact_param=b,
        radius_ratio=r,
    )

    u_coeffs = jnp.array([U1_FIXED, U2_FIXED] if u is None else u)
    delta_flux = limb_dark_light_curve(orbit, u_coeffs)(t)
    model_flux = 1.0 + delta_flux

    sigma_tot = jnp.sqrt(yerr**2 + sigma_jit**2)
    numpyro.sample("obs", dist.Normal(model_flux, sigma_tot), obs=y)


# ---------- 4) Run MCMC ----------
# Sampler settings, named so they can be tuned in one place.
NUM_WARMUP = 20000  # NOTE(review): 10x NUM_SAMPLES is a very long warm-up; check if fewer steps give the same diagnostics
NUM_SAMPLES = 2000
NUM_CHAINS = 1
MAX_TREE_DEPTH = 8
TARGET_ACCEPT = 0.8
MCMC_SEED = 42

# JAX copies of the data; the model is fit on the shifted time axis.
t = jnp.array(time0_np)
f = jnp.array(flux_np)
e = jnp.array(err_np)

kernel = NUTS(
    transit_model_global,
    max_tree_depth=MAX_TREE_DEPTH,
    target_accept_prob=TARGET_ACCEPT,
)
mcmc = MCMC(
    kernel,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    progress_bar=True,
)

rng_key = jax.random.PRNGKey(MCMC_SEED)
mcmc.run(rng_key, t, e, y=f)

mcmc.print_summary()
samples = mcmc.get_samples()


# ---------- 5) Posterior summary helpers ----------
def summarize_param(name, draws=None):
    """Return the posterior mean and standard deviation of one parameter.

    Parameters
    ----------
    name : str
        Site name in the draws dictionary (e.g. "t0", "period").
    draws : mapping, optional
        Posterior draws keyed by site name. Defaults to the notebook-level
        ``samples`` (from ``mcmc.get_samples()``), so one-argument calls behave
        exactly as before.

    Returns
    -------
    tuple
        ``(mean, std)`` of the draws; std is the population std (ddof=0).
    """
    if draws is None:
        draws = samples
    values = np.asarray(draws[name])
    return values.mean(), values.std()

# Posterior mean/std for every free (and derived) parameter of interest.
(
    (t0_mean, t0_std),
    (P_mean, P_std),
    (D_mean, D_std),
    (r_mean, r_std),
    (b_mean, b_std),
    (sj_mean, sj_std),
) = [
    summarize_param(param)
    for param in ("t0", "period", "duration", "r", "b", "sigma_jit")
]

report_lines = [
    "\n===== Global transit fit (u fixed) =====",
    f"t0       = {t0_mean:.6f} ± {t0_std:.6f}   (NOTE: in shifted time; add t_ref back)",
    f"Period   = {P_mean:.6f} ± {P_std:.6f} days",
    f"Duration = {D_mean:.6f} ± {D_std:.6f} days",
    f"r (Rp/R*)= {r_mean:.6f} ± {r_std:.6f}",
    f"b        = {b_mean:.6f} ± {b_std:.6f}",
    f"sigma_jit= {sj_mean:.6f} ± {sj_std:.6f}",
    "\nConvert t0 back to BTJD:",
    f"t0_BTJD ≈ {t0_mean + t_ref:.6f} ± {t0_std:.6f}",
]
print("\n".join(report_lines))


# ---------- 6) Plot best-fit (posterior median) ----------
# Per-parameter posterior medians, used to draw a single representative model.
theta_med = {name: np.median(np.array(draws)) for name, draws in samples.items()}

orbit_map = TransitOrbit(
    period=theta_med["period"],
    duration=theta_med["duration"],
    time_transit=theta_med["t0"],
    impact_param=theta_med["b"],
    radius_ratio=theta_med["r"],
)
u_map = jnp.array([U1_FIXED, U2_FIXED])

# Evaluate on the shifted time grid `t` that the sampler used.
delta_flux_map = limb_dark_light_curve(orbit_map, u_map)(t)
model_flux_map = np.array(delta_flux_map) + 1.0

print("model_flux_map min, max:", model_flux_map.min(), model_flux_map.max())

fig_fit, ax_fit = plt.subplots(figsize=(10, 4))
ax_fit.plot(time_np, flux_np, ".k", ms=2, alpha=0.35, label="Data")
# Undo the time shift so the model lines up with the BTJD axis.
ax_fit.plot(time0_np + t_ref, model_flux_map, "-", lw=1.5, label="Median model")
ax_fit.set_xlabel("Time [BTJD]")
ax_fit.set_ylabel("Flux")
ax_fit.set_title(f"Global Transit Fit (Sector {SECTOR}, u fixed)")
ax_fit.legend()
fig_fit.tight_layout()
plt.show()


# ---------- 7) Trace plot (optional, uses ArviZ) ----------
# Best effort on purpose: if ArviZ is missing or fails, report it and move on.
TRACE_VARS = ["t0", "period", "logD", "r", "b", "sigma_jit"]
try:
    import arviz as az

    idata = az.from_numpyro(mcmc)
    az.plot_trace(idata, var_names=TRACE_VARS)
    plt.tight_layout()
    plt.show()
except Exception as err:
    print("ArviZ trace plot failed (optional):", err)


RuntimeError: jaxlib version 0.8.2 is newer than and incompatible with jax version 0.4.21. Please update your jax and/or jaxlib packages.