In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

# Matplotlib styling applied to every figure below: large tick/axis labels,
# slightly thicker axes, and no extra constrained-layout padding.
plot_style = {
    "text.usetex": False,
    "axes.labelsize": 20,
    "figure.labelsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "figure.constrained_layout.wspace": 0,
    "figure.constrained_layout.hspace": 0,
    "figure.constrained_layout.h_pad": 0,
    "figure.constrained_layout.w_pad": 0,
    "axes.linewidth": 1.2,
}
mpl.rcParams.update(plot_style)

# Switch JAX to 64-bit floats (JAX defaults to 32-bit). Do this before any
# computation so arrays created afterwards are double precision.
jax.config.update("jax_enable_x64", True)

## Multiband GP Fitting
This notebook presents a complete workflow for multiband light-curve fitting with **EzTaoX**. A damped random walk (DRW) GP kernel is assumed.

In [None]:
from eztaox.kernels.quasisep import Exp
from eztaox.simulator import UniVarSim
from eztaox.ts_utils import add_noise

### 1. Light Curve Simulation
We first simulate DRW light curves in two different bands. The intrinsic DRW timescales are set to be the same across bands, and the amplitudes are set to differ. We also add a five-day time delay between these two bands.

In [None]:
# True DRW parameters per band: same timescale, different amplitudes.
drw_scale = {"g": 100, "r": 100}
drw_sigma = {"g": 0.25, "r": 0.15}
lc_seed = 2
sampling_seed = {"g": 2, "r": 5}  # seed for random sampling
noise_seed = {"g": 1, "r": 2}  # seed for mocking observational noise
min_dt, max_dt = 1, 3650.0
n_obs = 200  # number of simulated observations per band
noise_level = 0.05  # constant photometric uncertainty (same units as y)
true_lag = 5  # time shift (days) applied to the r-band below
ts, ys, yerrs, ys_noisy = {}, {}, {}, {}

for band in "gr":
    sim_params = {
        "log_kernel_param": jnp.log(jnp.asarray([drw_scale[band], drw_sigma[band]]))
    }
    # NOTE(review): Exp is built from the log parameters here; presumably
    # UniVarSim simulates from `sim_params` instead -- TODO confirm.
    s = UniVarSim(Exp(*sim_params["log_kernel_param"]), min_dt, max_dt, sim_params)

    # Simulate the noise-free light curve. The same lc_seed is used for every
    # band (only the sampling seed differs), presumably so both bands share one
    # underlying realization.
    sim_t, sim_y = s.random(
        n_obs,
        jax.random.PRNGKey(lc_seed),
        jax.random.PRNGKey(sampling_seed[band]),
    )

    # The uncertainty must exist before it is stored: the original assigned
    # yerrs[band] before defining sim_yerr (NameError on a fresh kernel).
    sim_yerr = jnp.ones_like(sim_t) * noise_level

    ts[band] = sim_t
    ys[band] = sim_y
    yerrs[band] = sim_yerr
    # Perturb the light curve with the simulated photometric noise.
    ys_noisy[band] = add_noise(sim_y, sim_yerr, jax.random.PRNGKey(noise_seed[band]))

## add time lag: shift the r-band time stamps relative to the g-band
ts["r"] = ts["r"] + true_lag

for b in "gr":
    plt.errorbar(ts[b], ys_noisy[b], yerrs[b], fmt=".", label=f"{b}-band")

plt.xlabel("Time (day)")
plt.ylabel("Flux (mag)")
plt.legend(fontsize=15)

### 2. Light Curve Formatting

To fit multiband data, you need to put the light curves (LCs) into a specific format. If your LCs are stored in a dictionary keyed by band name (see the example in Section 1), you can use the following function to format them. The outputs are X, y, and yerr:

- **X**: A tuple of arrays in the format of (time, band index)
    - *time*: An array of time stamps for observations in all bands.
    - *band index*: An array of integers, starting with 0, with the same size as the *time* array. Observations with the same band index belong to the same band. The band assigned index 0 is treated as the 'reference' band.
- **y**: An array of observed values (from all bands).
- **yerr**: An array of observational uncertainties associated with **y**.

In [None]:
from eztaox.ts_utils import formatlc

In [None]:
# Map each band name to an integer index; index 0 ("g") is the reference band.
band_index = dict(g=0, r=1)
X, y, yerr = formatlc(ts, ys_noisy, yerrs, band_index)
(X, y, yerr)

### 3. The Inference Interface
Classes included in the `eztaox.models` module constitute the main interface for performing light curve modeling. 

In [None]:
from eztaox.kernels.quasisep import MultibandLowRank
from eztaox.models import MultiVarModel
from eztaox.fitter import random_search
import numpyro
from numpyro.handlers import seed as numpyro_seed
import numpyro.distributions as dist

#### 3.1 Initialize a light curve model

In [None]:
# Model configuration
nBand = 2  # number of bands in the provided light curve (X, y, yerr)
has_lag = True  # True: fit for the inter-band lag
zero_mean = True  # True: fit for the light curve mean

# The kernel's constructor values are placeholders; they are not used in the fit.
k = Exp(scale=100.0, sigma=1.0)
m = MultiVarModel(X, y, yerr, k, nBand, has_lag=has_lag, zero_mean=zero_mean)
m

#### 3.2 Maximum Likelihood (MLE) Fitting

To find the best-fit parameters, one can start at a random point in the parameter space and optimize the likelihood function until it stops changing. The likelihood for a given set of parameters can be obtained by calling `MultiVarModel.log_prob(params)`. However, I find that this approach often gets stuck in local optima. **EzTaoX** provides a fitter function (`random_search`) to alleviate this issue (to some extent). The `random_search` function first performs a random search (i.e., evaluates the likelihood at a large number of randomly chosen positions in the parameter space). It then selects a few (defaults to five) positions with the highest likelihood for further non-linear optimization (e.g., using L-BFGS-B).

The `random_search` function takes the following arguments:
- **model**: an instance of `MultiVarModel`
- **initSampler**: a custom function (you need to provide) for generating random samples for the random search step.
- **prng_key**: a **JAX** random number generator key.
- **nSample**: number of random samples to draw.
- **nBest**: number of best samples to keep for continued optimization.
- **jaxoptMethod**: fine optimization method to use. => see [**here**](https://jaxopt.github.io/stable/_autosummary/jaxopt.ScipyMinimize.html#jaxopt.ScipyMinimize.method) for supported methods.
- **batch_size**: The number of likelihood evaluations performed per batch. Defaults to 1000. For simpler models (and if you have enough memory), you can set this to `nSample`.

#### InitSampler
The `initSampler` defines a prior distribution from which random samples are drawn to evaluate the likelihood. Its structure is similar to the model used to perform MCMC with `numpyro` (see Section 4). The distribution for a parameter can take any shape, as long as it has a `numpyro` implementation. A list of `numpyro` distributions can be found [here](https://num.pyro.ai/en/stable/distributions.html).

In [None]:
def initSampler():
    """Draw one random parameter set to seed ``random_search``.

    Sample sites are declared in a fixed order: DRW scale, DRW sigma,
    amplitude scale, means, lag. The kernel parameters are sampled in
    natural-log space, so the sites "drw_scale" and "drw_sigma" hold log
    values.

    Returns:
        dict with keys "log_kernel_param" (stacked [log scale, log sigma]),
        "log_amp_scale", "mean" (one entry per band; two bands here) and "lag".
    """
    # Latent DRW kernel: log timescale in [log 0.01, log 1000],
    # log amplitude in [log 0.01, log 10].
    log_kernel_param = jnp.stack(
        [
            numpyro.sample("drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))),
            numpyro.sample("drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))),
        ]
    )
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # Log ratio of the non-reference band amplitude to the latent GP.
    log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

    # Per-band light-curve means.
    mean_lo = jnp.asarray([-0.1, -0.1])
    mean_hi = jnp.asarray([0.1, 0.1])
    mean = numpyro.sample("mean", dist.Uniform(low=mean_lo, high=mean_hi))

    # Lag of the non-reference band relative to the reference band.
    lag = numpyro.sample("lag", dist.Uniform(-10, 10))

    return dict(
        log_kernel_param=log_kernel_param,
        log_amp_scale=log_amp_scale,
        mean=mean,
        lag=lag,
    )

In [None]:
# Draw one random parameter set from initSampler to inspect its structure.
sample_key = jax.random.PRNGKey(1)
seeded_sampler = numpyro_seed(initSampler, rng_seed=sample_key)
prior_sample = seeded_sampler()
prior_sample

#### **A note on model parameters**:
- **`log_kernel_param`**: The parameters of the latent GP process. 
- **`log_amp_scale`**: This parameter characterizes the log of the ratio between the amplitude of the GP in each band relative to the latent GP (i.e., the $S$ parameter in the kernel function). Since the $S$ parameter is set to 1 by default, `log_amp_scale` is an array of size M-1, where M is the number of bands.
- **`mean`**: The mean of the light curve in each band, with a size M. 
- **`lag`**: The inter-band lags with respect to the reference band. `lag` is an array of size M-1.

#### Try MLE Fitting

In [None]:
%%time
model = m
sampler = initSampler
fit_key = jax.random.PRNGKey(1)
nSample = 1_000
nBest = 5  # it seems like this number needs to be high

bestP, ll = random_search(model, initSampler, fit_key, nSample, nBest)
bestP

### 4.0 MCMC
MCMC sampling is carried out using the `numpyro` package, which is native to `JAX`. In this example, I use the NUTS algorithm; however, there is a large collection of sampling algorithms to choose from (see [here](https://num.pyro.ai/en/stable/infer.html)). In addition, you can freely specify more flexible (no longer just flat!) prior distributions for each parameter in the light curve model.

#### Define `numpyro` MCMC model

In [None]:
def numpyro_model(X, yerr, y=None):
    """NumPyro model for MCMC over the multiband DRW light-curve parameters.

    Priors match ``initSampler``, except the DRW sites are named
    "log_drw_scale" / "log_drw_sigma". Unlike ``initSampler``, this function
    also builds a ``MultiVarModel`` on the data and calls its ``sample`` method,
    which presumably registers the likelihood of ``y`` -- TODO confirm.

    Args:
        X: (time, band index) tuple as returned by ``formatlc``.
        yerr: observational uncertainties matching ``y``.
        y: observed values; ``None`` when run without data (e.g. predictive use).
    """
    # Latent DRW kernel parameters, sampled in log space.
    log_drw_scale = numpyro.sample(
        "log_drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
    )
    log_drw_sigma = numpyro.sample(
        "log_drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
    )
    log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # Log amplitude ratio of the second band relative to the latent GP.
    log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

    # Per-band means (two bands).
    mean_lo = jnp.asarray([-0.1, -0.1])
    mean_hi = jnp.asarray([0.1, 0.1])
    mean = numpyro.sample("mean", dist.Uniform(low=mean_lo, high=mean_hi))

    # Inter-band lag with respect to the reference band.
    lag = numpyro.sample("lag", dist.Uniform(-10, 10))

    # Build the light-curve model on the supplied data. The kernel's
    # constructor values are placeholders and are not used.
    light_curve_model = MultiVarModel(
        X,
        y,
        yerr,
        Exp(scale=100.0, sigma=1.0),
        2,  # nBand
        has_lag=True,
        zero_mean=True,
    )
    light_curve_model.sample(
        {
            "log_kernel_param": log_kernel_param,
            "log_amp_scale": log_amp_scale,
            "mean": mean,
            "lag": lag,
        }
    )

#### Run MCMC

In [None]:
from numpyro.infer import MCMC, NUTS, init_to_median
import arviz as az

In [None]:
%%time
nuts_kernel = NUTS(
    numpyro_model,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=init_to_median,
)

mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
    # progress_bar=False,
)

mcmc_seed = 0
mcmc.run(jax.random.PRNGKey(mcmc_seed), X, yerr, y=y)
data = az.from_numpyro(mcmc)
mcmc.print_summary()

#### Visualize Chains, Posterior Distributions

In [None]:
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

In [None]:
trace_vars = ["log_drw_scale", "log_drw_sigma", "log_amp_scale", "lag"]
az.plot_trace(data, var_names=trace_vars)
# extra vertical space between rows so titles do not overlap
plt.subplots_adjust(hspace=0.4)

In [None]:
# True simulation parameters, drawn as orange squares. The latent GP is the
# g-band (reference) DRW, so the true log_amp_scale is log(sigma_r / sigma_g).
truths = {
    "log_drw_scale": np.log(drw_scale["g"]),
    "log_drw_sigma": np.log(drw_sigma["g"]),
    "log_amp_scale": np.log(drw_sigma["r"] / drw_sigma["g"]),
    "lag": 5.0,  # r-band time shift applied in Section 1
}
az.plot_pair(
    data,
    var_names=list(truths),
    reference_values=truths,
    reference_values_kwargs=dict(color="orange", markersize=20, marker="s"),
    kind="scatter",
    marginals=True,
    textsize=25,
)
plt.subplots_adjust(hspace=0.0, wspace=0.0)

### 5.0 Second-order Statistics
`EzTaoX` provides a unified class (`eztaox.kernel_stat2.gpStat2`) for generating the second-order statistic functions (ACF, SF, and PSD) of any supported kernel. All you need to do is initialize a `gpStat2` instance with your desired kernel.

In [None]:
from eztaox.kernel_stat2 import gpStat2

In [None]:
# Evaluation grids for the second-order statistics.
# NOTE: this rebinds `ts`, replacing the per-band light-curve time dict from
# Section 1 with a grid of time lags. Re-running earlier cells that expect the
# dict (e.g. formatlc) after this cell will fail.
ts = np.logspace(-1, 3, num=50)  # 50 log-spaced lags from 0.1 to 1000
fs = np.logspace(-3, 3, num=50)  # 50 log-spaced frequencies; unused below

#### 5.1 Get MCMC Samples

In [None]:
# Merge (chain, draw) into a single trailing "sample" dimension.
flatPost = data.posterior.stack(sample=["chain", "draw"])

# Transpose so the sample axis comes first (log_drw_draws is then indexed as
# [sample, parameter]).
log_drw_draws, log_amp_scale_draws, lag_draws = (
    flatPost[name].values.T for name in ("log_kernel_param", "log_amp_scale", "lag")
)

#### 5.2 `g`-band SF

In [None]:
# Second-order statistics of the true g-band DRW kernel.
g_drw = Exp(scale=drw_scale["g"], sigma=drw_sigma["g"])
gpStat2_g = gpStat2(g_drw)

# SF for every posterior draw: vmap over the draws (axis 0 of the kernel
# parameters) while sharing the lag grid `ts`. The latent GP is the reference
# (g) band, so its parameters are used without amplitude rescaling.
drw_draws = jnp.exp(log_drw_draws)
mcmc_sf_g = jax.vmap(gpStat2_g.sf, in_axes=(None, 0))(ts, drw_draws)

In [None]:
# True g-band SF (black) against posterior-draw SFs (green).
plt.loglog(ts, gpStat2_g.sf(ts), c="k", label="True g-band SF", zorder=100, lw=2)
plt.loglog(ts, mcmc_sf_g[0], label="MCMC g-band SF", c="tab:green", alpha=0.8, lw=2)
plt.legend(fontsize=15)

# Every 20th posterior draw, drawn faintly to show the spread.
thinned_sf_g = mcmc_sf_g[::20]
for sf_draw in thinned_sf_g:
    plt.loglog(ts, sf_draw, c="tab:green", alpha=0.15)

plt.xlabel("Time")
plt.ylabel("SF")

#### 5.3 `r`-band SF

In [None]:
# Second-order statistics of the true r-band DRW kernel (the original comment
# mislabelled this as the g-band kernel).
r_drw = Exp(scale=drw_scale["r"], sigma=drw_sigma["r"])
gpStat2_r = gpStat2(r_drw)

# r-band kernel parameters for every posterior draw: same timescale as the
# latent GP, amplitude multiplied by exp(log_amp_scale), i.e. the log amplitude
# (column 1) is shifted by log_amp_scale. np.array makes a writable copy, so
# log_drw_draws stays untouched even if it is not a plain numpy array.
log_drw_draws_r = np.array(log_drw_draws)
log_drw_draws_r[:, 1] += log_amp_scale_draws
mcmc_sf_r = jax.vmap(gpStat2_r.sf, in_axes=(None, 0))(ts, jnp.exp(log_drw_draws_r))

In [None]:
# True r-band SF (black) against posterior-draw SFs (red).
plt.loglog(ts, gpStat2_r.sf(ts), c="k", label="True r-band SF", zorder=100, lw=2)
plt.loglog(ts, mcmc_sf_r[0], label="MCMC r-band SF", c="tab:red", alpha=0.8, lw=2)
plt.legend(fontsize=15)

# Every 20th posterior draw, drawn faintly to show the spread.
thinned_sf_r = mcmc_sf_r[::20]
for sf_draw in thinned_sf_r:
    plt.loglog(ts, sf_draw, c="tab:red", alpha=0.15)

plt.xlabel("Time")
plt.ylabel("SF")

### 6.0 Lag distribution

In [None]:
# Posterior of the g-r lag against the true 5-day shift applied in Section 1.
# The y-axis autoscales: the previous fixed ylim(0, 0.5) clipped the histogram
# whenever the posterior density peaked above 0.5. axvline spans the full
# axes height regardless of the y-range.
fig, ax = plt.subplots()
ax.hist(lag_draws, density=True)
ax.axvline(5.0, color="k", lw=2, label="True g-r Lag")
ax.legend(fontsize=15, loc=2)
ax.set_xlabel("Lag")
ax.set_ylabel("Probability density");