# GPU timing, batching, and PPL inference (extension of the gradient-based inference tutorial)

This notebook extends the gradient-based transmission-string workflow with:

1. Single-run GPU timing and NumPyro/NUTS inference for one light curve.
2. Batched GPU timing and NumPyro/NUTS inference for 10 light curves with different `r0` and `rn_frac` values.

It follows the same model style used in the gradient-based inference tutorial.


In [None]:
import time
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import arviz as az
import numpyro
import numpyro.distributions as dist

from harmonica import HarmonicaTransit
from harmonica.jax import harmonica_transit_quad_ld
from harmonica.jax.custom_primitives import harmonica_transit_quad_ld_batch
from numpyro.infer import MCMC, NUTS, init_to_median


In [None]:
# Every timing and inference cell below pins its work to an explicit GPU
# device, so fail fast with an actionable message if JAX cannot see one.
gpu_devices = list(filter(lambda device: device.platform == "gpu", jax.devices()))
if len(gpu_devices) == 0:
    raise RuntimeError(
        "No GPU detected by JAX. Install a CUDA-enabled jax/jaxlib build and rerun this notebook."
    )

# Use the first visible GPU for all subsequent computations.
gpu_device = gpu_devices[0]
print("JAX backend:", jax.default_backend())
print("Selected GPU device:", gpu_device)


## Shared synthetic setup (same orbit and limb-darkening setup as the gradient tutorial)


In [None]:
# Legacy global seed. The two random draws below (r_dev, then the flux noise)
# must stay in this order to reproduce the tutorial's synthetic data.
np.random.seed(12)

n_times = 500
n_theta = 1000
times = np.linspace(-0.15, 0.15, n_times)
theta = np.linspace(-np.pi, np.pi, n_theta)

# Transmission string built as in the gradient tutorial: a mean radius
# followed by six further coefficients, each a fractional deviation
# (uniform in [-0.1, 0.1)) scaled by the mean radius.
r_mean = np.array([0.15])
r_dev = np.random.uniform(-0.1, 0.1, size=6)
injected_r = np.concatenate([r_mean, r_dev * r_mean])

ht = HarmonicaTransit(times)
ht.set_orbit(t0=0.0, period=4.0, a=11.0, inc=87.0 * np.pi / 180.0)
ht.set_stellar_limb_darkening(np.array([0.027, 0.246]), limb_dark_law="quadratic")
ht.set_planet_transmission_string(injected_r)
injected_transmission_string = ht.get_planet_transmission_string(theta)

# Constant 500 ppm Gaussian noise added to the model light curve.
flux_sigma_single = np.full(n_times, 500.0e-6)
flux_errs_single = np.random.normal(loc=0.0, scale=flux_sigma_single, size=n_times)
observed_flux_single = ht.get_transit_light_curve() + flux_errs_single

times_jax = jnp.asarray(times)


## Single light-curve GPU timing


In [None]:
# JIT-compiled forward model for one light curve. The orbit and limb darkening
# are fixed, so only the transmission-string coefficients `r` are traced
# inputs; `times_jax` is captured as a compile-time constant.
single_gpu_fn = jax.jit(
    lambda r: harmonica_transit_quad_ld(
        times_jax,
        t0=0.0,
        period=4.0,
        a=11.0,
        inc=87.0 * np.pi / 180.0,
        ecc=0.0,
        omega=0.0,
        u1=0.027,
        u2=0.246,
        r=r,
    )
)

# One millisecond-scale GPU call is too noisy to report on its own, so the
# steady-state latency is averaged over this many repeated calls.
n_timing_repeats = 20

with jax.default_device(gpu_device):
    # Move the input to the device once so that the host-to-device transfer
    # is not counted in either measurement.
    r_single_device = jnp.asarray(injected_r)

    # The first call traces and compiles the function, so time it separately.
    t0_compile = time.perf_counter()
    single_gpu_fn(r_single_device).block_until_ready()
    compile_plus_first_ms_single = (time.perf_counter() - t0_compile) * 1.0e3

    # Steady state: block on every call so that each sample measures full
    # dispatch plus execution, not only the asynchronous enqueue.
    run_times_ms_single = []
    for _ in range(n_timing_repeats):
        t0_run = time.perf_counter()
        single_flux_gpu = single_gpu_fn(r_single_device)
        single_flux_gpu.block_until_ready()
        run_times_ms_single.append((time.perf_counter() - t0_run) * 1.0e3)

run_ms_single = float(np.mean(run_times_ms_single))

print(f"Single compile + first run: {compile_plus_first_ms_single:.3f} ms")
print(
    f"Single steady-state GPU run (mean of {n_timing_repeats}): {run_ms_single:.3f} ms "
    f"(min {min(run_times_ms_single):.3f} ms)"
)


In [None]:
# Overlay the noisy synthetic observations and the GPU-evaluated model curve.
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(times, observed_flux_single, yerr=flux_sigma_single, fmt=".k", alpha=0.3, label="Noisy observations")
ax.plot(times, np.asarray(single_flux_gpu), color=cm.BuGn(0.8), lw=2.0, label="GPU model")
ax.set_xlabel("Time / days")
ax.set_ylabel("Relative flux")
ax.legend(loc="best")
plt.show()


## Single light-curve PPL model (NumPyro/NUTS)

This mirrors the gradient-tutorial model: infer the baseline radius `r0` and the six
fractional terms `rn_frac`, then build `r = [r0, rn_frac * r0]` before calling Harmonica.


In [None]:
def numpyro_model_single(t, flux_sigma, f_obs=None):
    """NumPyro model for one transit light curve with a free transmission string.

    Samples a baseline radius ``r0`` and six fractional terms ``rn_frac``,
    forms the coefficients ``r = [r0, rn_frac * r0]`` (recorded as the
    deterministic site ``"r"``), evaluates the quadratic-limb-darkening
    Harmonica model at ``t`` with the orbit held fixed, and adds a Gaussian
    likelihood.

    Parameters
    ----------
    t : array
        Observation times.
    flux_sigma : array
        Per-point Gaussian standard deviation of the observed flux.
    f_obs : array, optional
        Observed fluxes. If ``None``, the ``"obs"`` site is not conditioned.
    """
    # Keep the sample sites in this order (r0, then rn_frac).
    r0 = numpyro.sample("r0", dist.Uniform(0.15 - 0.05, 0.15 + 0.05))
    rn_frac = numpyro.sample("rn_frac", dist.Normal(0.0, 0.1), sample_shape=(6,))
    r_coeffs = jnp.concatenate([jnp.atleast_1d(r0), r0 * rn_frac])
    r = numpyro.deterministic("r", r_coeffs)

    model_flux = harmonica_transit_quad_ld(
        t,
        r=r,
        t0=0.0,
        period=4.0,
        a=11.0,
        inc=87.0 * np.pi / 180.0,
        u1=0.027,
        u2=0.246,
    )
    numpyro.sample("obs", dist.Normal(model_flux, flux_sigma), obs=f_obs)

# NUTS with dense mass-matrix adaptation over all sampled parameters, a
# capped tree depth, and initialization at the prior median.
nuts_single = NUTS(
    numpyro_model_single,
    init_strategy=init_to_median(),
    target_accept_prob=0.75,
    max_tree_depth=7,
    dense_mass=True,
    adapt_mass_matrix=True,
)

# One chain: 200 warmup iterations followed by 500 retained draws.
mcmc_single = MCMC(
    nuts_single,
    num_chains=1,
    num_warmup=200,
    num_samples=500,
    chain_method="sequential",
    progress_bar=True,
)

# Wall-clock time of the whole run, including the sampler's JIT compilation
# and warmup. Arrays are created on the selected GPU.
with jax.default_device(gpu_device):
    t0_ppl = time.perf_counter()
    mcmc_single.run(
        jax.random.PRNGKey(2),
        jnp.asarray(times),
        f_obs=jnp.asarray(observed_flux_single),
        flux_sigma=jnp.asarray(flux_sigma_single),
    )
    ppl_single_s = time.perf_counter() - t0_ppl

print(f"Single-curve NumPyro run time: {ppl_single_s:.2f} s")


In [None]:
single_data = az.from_numpyro(mcmc_single)

# Injected truth in the same row order as the summary: r0 first, then the six
# rn_frac terms. The model uses r = [r0, rn_frac * r0] and the data used
# injected_r = [r_mean, r_dev * r_mean], so the true rn_frac is exactly r_dev.
injected_single = np.concatenate([injected_r[:1], r_dev])
az.summary(single_data, var_names=["r0", "rn_frac"], round_to=5).assign(
    injected=np.round(injected_single, 5)
)


## Build a 10-curve batch with varied `r0` and `rn_frac`


In [None]:
# Ten synthetic curves that differ only in their transmission strings.
batch_size = 10
rng = np.random.default_rng(21)

# The baseline radius is spaced evenly over 0.12-0.18. Each curve also gets
# six rn/r0 terms drawn from N(0, 0.06), scaled by its own r0.
r0_values_true = np.linspace(0.12, 0.18, batch_size)
rn_frac_values_true = rng.normal(0.0, 0.06, size=(batch_size, 6))
r0_column = r0_values_true[:, None]
r_batch_true = np.hstack([r0_column, rn_frac_values_true * r0_column])

# Orbit and limb darkening are identical across curves, but the batched
# function is called with one value per curve, so each scalar is repeated.
t0_batch = np.full(batch_size, 0.0)
period_batch = np.full(batch_size, 4.0)
a_batch = np.full(batch_size, 11.0)
inc_batch = np.full(batch_size, 87.0 * np.pi / 180.0)
ecc_batch = np.full(batch_size, 0.0)
omega_batch = np.full(batch_size, 0.0)
u1_batch = np.full(batch_size, 0.027)
u2_batch = np.full(batch_size, 0.246)


def _batch_forward(r):
    """Evaluate all batched light curves on `times_jax` for coefficients `r`."""
    return harmonica_transit_quad_ld_batch(
        times_jax,
        t0=jnp.asarray(t0_batch),
        period=jnp.asarray(period_batch),
        a=jnp.asarray(a_batch),
        inc=jnp.asarray(inc_batch),
        ecc=jnp.asarray(ecc_batch),
        omega=jnp.asarray(omega_batch),
        u1=jnp.asarray(u1_batch),
        u2=jnp.asarray(u2_batch),
        r=r,
    )


batch_gpu_fn = jax.jit(_batch_forward)

with jax.default_device(gpu_device):
    batch_flux_true = batch_gpu_fn(jnp.asarray(r_batch_true))
    batch_flux_true.block_until_ready()

# Independent 500 ppm Gaussian noise on every point of every curve.
flux_sigma_batch = np.full((batch_size, times.shape[0]), 500.0e-6)
flux_errs_batch = rng.normal(loc=0.0, scale=flux_sigma_batch)
observed_flux_batch = np.asarray(batch_flux_true) + flux_errs_batch


## Batched GPU timing and visualization


In [None]:
# Number of repeated calls averaged for the steady-state timing. A single
# GPU call is too noisy to report on its own.
n_timing_repeats = 20

with jax.default_device(gpu_device):
    # Transfer the input once so the transfer is excluded from both timings.
    r_batch_device = jnp.asarray(r_batch_true)

    # `batch_gpu_fn` was already called when the synthetic batch was built,
    # so its compiled executable is cached. Clear JAX's caches so that this
    # measurement really includes tracing and compilation (this also keeps it
    # honest when the cell is re-run).
    jax.clear_caches()
    t0_compile_batch = time.perf_counter()
    batch_gpu_fn(r_batch_device).block_until_ready()
    compile_plus_first_ms_batch = (time.perf_counter() - t0_compile_batch) * 1.0e3

    # Steady state: block on every call so each sample measures full latency.
    run_times_ms_batch = []
    for _ in range(n_timing_repeats):
        t0_run_batch = time.perf_counter()
        batch_flux_gpu = batch_gpu_fn(r_batch_device)
        batch_flux_gpu.block_until_ready()
        run_times_ms_batch.append((time.perf_counter() - t0_run_batch) * 1.0e3)

run_ms_batch = float(np.mean(run_times_ms_batch))

print(f"Batch compile + first run (B={batch_size}): {compile_plus_first_ms_batch:.3f} ms")
print(
    f"Batch steady-state GPU run (B={batch_size}, mean of {n_timing_repeats}): "
    f"{run_ms_batch:.3f} ms (min {min(run_times_ms_batch):.3f} ms)"
)
print("Batch output shape:", batch_flux_gpu.shape)

batch_flux_np = np.asarray(batch_flux_gpu)
rn_frac_rms = np.sqrt(np.mean(rn_frac_values_true**2, axis=1))

fig, ax = plt.subplots(figsize=(11, 7))
# Normalize the colormap index; the guard avoids dividing by zero for a one-curve batch.
color_denominator = max(batch_size - 1, 1)
for i in range(batch_size):
    color = cm.viridis(i / color_denominator)
    label = f"{i + 1}: r0={r0_values_true[i]:.3f}, rn_frac_rms={rn_frac_rms[i]:.3f}"
    ax.plot(times, batch_flux_np[i], color=color, lw=1.7, label=label)

ax.set_xlabel("Time / days")
ax.set_ylabel("Relative flux")
ax.set_title("Batched GPU light curves with varying r0 and rn_frac")
ax.legend(loc="best", fontsize=8, ncol=2)
plt.show()


## Batched PPL model (NumPyro/NUTS)

This is the batched analog of the single model:

- infer `r0` for each of the 10 curves,
- infer `rn_frac` for each curve,
- build batched `r`,
- evaluate `harmonica_transit_quad_ld_batch`,
- condition on the 10 observed light curves.


In [None]:
def numpyro_model_batch(t, flux_sigma, f_obs=None):
    """Batched NumPyro model: one free transmission string per light curve.

    For each of the ``batch_size`` curves, samples a baseline radius ``r0``
    and six fractional terms ``rn_frac``. It builds the per-curve coefficients
    ``r = [r0, rn_frac * r0]`` (deterministic site ``"r"``, one row per
    curve), evaluates the batched Harmonica model with the shared
    orbit/limb-darkening arrays defined earlier in the notebook, and
    conditions on the observed fluxes.

    Parameters
    ----------
    t : array
        Observation times shared by all curves.
    flux_sigma : array
        Per-point Gaussian standard deviation; presumably shaped like the
        model output (curves x times) — TODO confirm against the batch kernel.
    f_obs : array, optional
        Observed fluxes for all curves. If ``None``, ``"obs"`` is not conditioned.
    """
    # Keep the sample sites in this order (r0, then rn_frac).
    r0 = numpyro.sample("r0", dist.Uniform(0.15 - 0.05, 0.15 + 0.05), sample_shape=(batch_size,))
    rn_frac = numpyro.sample("rn_frac", dist.Normal(0.0, 0.1), sample_shape=(batch_size, 6))
    r0_col = r0[:, None]
    r = numpyro.deterministic("r", jnp.hstack([r0_col, rn_frac * r0_col]))

    # Shared orbital and limb-darkening parameters, one entry per curve.
    shared_params = {
        "t0": jnp.asarray(t0_batch),
        "period": jnp.asarray(period_batch),
        "a": jnp.asarray(a_batch),
        "inc": jnp.asarray(inc_batch),
        "ecc": jnp.asarray(ecc_batch),
        "omega": jnp.asarray(omega_batch),
        "u1": jnp.asarray(u1_batch),
        "u2": jnp.asarray(u2_batch),
    }
    model_flux = harmonica_transit_quad_ld_batch(t, r=r, **shared_params)

    numpyro.sample("obs", dist.Normal(model_flux, flux_sigma), obs=f_obs)

# NUTS with diagonal mass-matrix adaptation over all per-curve parameters, a
# capped tree depth, and initialization at the prior median.
nuts_batch = NUTS(
    numpyro_model_batch,
    init_strategy=init_to_median(),
    target_accept_prob=0.75,
    max_tree_depth=7,
    dense_mass=False,
    adapt_mass_matrix=True,
)

# One chain: 150 warmup iterations followed by 300 retained draws.
mcmc_batch = MCMC(
    nuts_batch,
    num_chains=1,
    num_warmup=150,
    num_samples=300,
    chain_method="sequential",
    progress_bar=True,
)

# Wall-clock time of the whole run, including the sampler's JIT compilation
# and warmup. Arrays are created on the selected GPU.
with jax.default_device(gpu_device):
    t0_ppl_batch = time.perf_counter()
    mcmc_batch.run(
        jax.random.PRNGKey(7),
        jnp.asarray(times),
        f_obs=jnp.asarray(observed_flux_batch),
        flux_sigma=jnp.asarray(flux_sigma_batch),
    )
    ppl_batch_s = time.perf_counter() - t0_ppl_batch

print(f"Batched NumPyro run time (B={batch_size}): {ppl_batch_s:.2f} s")


In [None]:
batch_data = az.from_numpyro(mcmc_batch)

# Posterior summary of the per-curve baseline radii (rows r0[0]..r0[B-1]),
# with the injected values alongside so that recovery can be checked.
az.summary(batch_data, var_names=["r0"], round_to=5).assign(
    r0_injected=np.round(r0_values_true, 5)
)
