In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Notebook-wide matplotlib defaults, grouped by what they control.
mpl.rcParams.update(
    {
        # text rendering: do not route text through an external LaTeX install
        "text.usetex": False,
        # label sizes and axes frame weight
        "axes.labelsize": 18,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
        "axes.linewidth": 1.3,
        # constrained-layout padding and inter-panel spacing all set to zero
        "figure.constrained_layout.h_pad": 0,
        "figure.constrained_layout.w_pad": 0,
        "figure.constrained_layout.hspace": 0,
        "figure.constrained_layout.wspace": 0,
    }
)

import jax
import jax.numpy as jnp

# Switch on JAX's "jax_enable_x64" flag right after importing jax, before any
# arrays are created in the cells below.
# NOTE(review): presumably required so the GP computations run in float64
# instead of JAX's float32 default — confirm against the EzTaoX documentation.
jax.config.update("jax_enable_x64", True)

## Damped Random Walk (DRW)

The Gaussian Process (GP) kernel corresponding to a Damped Random Walk (DRW) process is

$$
k(|\Delta t|) = \sigma^2 \exp \left(-\frac{|\Delta t|}{\ell}\right),
$$

where $|\Delta t|$ denotes the time separation between two observations. The parameter $\ell$ represents the correlation length scale of the process, while $\sigma^2$ is the asymptotic variance of the GP. In **EzTaoX**, this kernel is implemented via the `kernels.quasisep.Exp` class.

<div class="alert alert-info">

**Note**

In the astronomy literature, a DRW process is commonly parameterized by a damping timescale $\tau_{\rm DRW}$ and a root-mean-square (RMS) variability amplitude $\sigma_{\rm DRW}$. The correspondence with the `Exp` kernel parameters is:

* $\tau_{\rm DRW} = \ell$ (the correlation length scale),
* $\sigma_{\rm DRW} = \sigma$ (the standard deviation, i.e., square root of the asymptotic variance).

</div>


### 1. Light Curve Simulation
We use `UniVarSim` to simulate DRW light curves.

In [None]:
from eztaox.kernels.quasisep import Exp
from eztaox.simulator import UniVarSim
from eztaox.ts_utils import add_noise

In [None]:
# Ground-truth DRW parameters; the simulator takes them in natural log
drw_scale = 100.0
drw_sigma = 0.15
sim_params = {"log_kernel_param": jnp.log(jnp.asarray([drw_scale, drw_sigma]))}

# Build the univariate (i.e., single-band) simulator
min_dt = 10
max_dt = 3650.0
sim_kernel = Exp(*sim_params["log_kernel_param"])
s = UniVarSim(sim_kernel, min_dt, max_dt, sim_params)

# Draw a 200-point light curve, attach a constant 0.05 uncertainty to every
# point, then perturb the clean values with add_noise. Each random step gets
# its own fixed PRNG key so the cell is reproducible.
sim_t, sim_y = s.random(200, jax.random.PRNGKey(0), jax.random.PRNGKey(1))
sim_yerr = 0.05 * jnp.ones_like(sim_t)
sim_y_noisy = add_noise(sim_y, sim_yerr, jax.random.PRNGKey(2))

In [None]:
# Noisy simulated light curve, drawn as points with their error bars
plt.errorbar(sim_t, sim_y_noisy, yerr=sim_yerr, fmt=".")
plt.xlabel("Time (day)")
plt.ylabel("Flux (mag)")

### 2. Fitting
Here, we demonstrate how to use the `UniVarModel` for fitting single-band light curves.

In [None]:
import numpyro
import numpyro.distributions as dist
from eztaox.fitter import random_search
from eztaox.models import UniVarModel
from numpyro.handlers import seed as numpyro_seed

#### 2.1 Initialize Light Curve Model

In [None]:
# GP kernel handed to the model; note that these initial parameter values are
# placeholders and are not used in the fitting
k = Exp(sigma=1.0, scale=100.0)

# False -> do not assume the input light curve has a mean of zero
zero_mean = False

m = UniVarModel(sim_t, sim_y_noisy, sim_yerr, k, zero_mean=zero_mean)
m

#### 2.2 Define InitSampler

In [None]:
def initSampler():
    """Draw one random set of model parameters from flat priors.

    The two DRW kernel parameters are sampled uniformly in natural-log space
    (scale between log(0.01) and log(1000), sigma between log(0.01) and
    log(10)); the light-curve mean is sampled uniformly in [-0.2, 0.2].

    Returns
    -------
    dict
        ``{"log_kernel_param": <stacked [log scale, log sigma]>, "mean": <mean>}``
    """
    # prior bounds of the GP kernel parameters, in natural log
    scale_lo, scale_hi = jnp.log(0.01), jnp.log(1000)
    sigma_lo, sigma_hi = jnp.log(0.01), jnp.log(10)

    # sample sites are visited in a fixed order: scale, sigma, then mean
    log_drw_scale = numpyro.sample("drw_scale", dist.Uniform(scale_lo, scale_hi))
    log_drw_sigma = numpyro.sample("drw_sigma", dist.Uniform(sigma_lo, sigma_hi))

    # record the stacked kernel parameters as a deterministic site
    log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # mean of the light curve
    mean = numpyro.sample("mean", dist.Uniform(-0.2, 0.2))

    return {"log_kernel_param": log_kernel_param, "mean": mean}

In [None]:
# Seed initSampler with a fixed key and call it once to get a random initial guess
sample_key = jax.random.PRNGKey(1)
seeded_sampler = numpyro_seed(initSampler, rng_seed=sample_key)
prior_sample = seeded_sampler()
prior_sample

#### 2.3 MLE Fitting

In [None]:
%%time
model = m
sampler = initSampler
fit_key = jax.random.PRNGKey(1)
nSample = 10_000
nBest = 10  # it seems like this number needs to be high

bestP, ll = random_search(model, initSampler, fit_key, nSample, nBest)
bestP

In [None]:
# Compare the input (true) DRW parameters with the best-fit values returned by
# the random search; both are shown in natural log, ordered [scale, sigma].
# The second label previously read "DHO", but the fitted model is a DRW.
print("True DRW Params (in natural log):")
print(np.log(np.hstack([drw_scale, drw_sigma])))
print("MLE DRW Params (in natural log):")
print(bestP["log_kernel_param"])

### 3. MCMC

In [None]:
import arviz as az
from numpyro.infer import MCMC, NUTS, init_to_median

In [None]:
def numpyro_model(t, yerr, y=None):
    """NumPyro model used to sample the DRW posterior with MCMC.

    Samples the natural-log DRW kernel parameters and the light-curve mean
    from their priors, then passes them to a ``UniVarModel`` built from the
    light curve given as arguments. Previously the arguments were ignored and
    the model was always built from the notebook-level simulated arrays.

    Parameters
    ----------
    t : array-like
        Observation times of the light curve.
    yerr : array-like
        Measurement uncertainties of the light curve.
    y : array-like, optional
        Observed light-curve values.
        NOTE(review): ``None`` is forwarded to ``UniVarModel`` unchanged;
        confirm whether the model supports a missing ``y`` before relying on
        the default.
    """
    # GP kernel param: flat priors in natural-log space
    log_drw_scale = numpyro.sample(
        "log_drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
    )
    log_drw_sigma = numpyro.sample(
        "log_drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
    )
    log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # mean: use a normal prior for better convergence
    mean = numpyro.sample("mean", dist.Normal(0.0, 0.1))

    sample_params = {"log_kernel_param": log_kernel_param, "mean": mean}

    # the following is different from the initSampler
    zero_mean = False

    k = Exp(scale=100.0, sigma=1.0)  # init params for k are not used
    # build the model from the data passed in, not from notebook globals
    m = UniVarModel(t, y, yerr, k, zero_mean=zero_mean)
    m.sample(sample_params)

In [None]:
%%time
nuts_kernel = NUTS(
    numpyro_model,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=init_to_median,
)

mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=5000,
    num_chains=1,
    # progress_bar=False,
)

mcmc_seed = 0
mcmc.run(jax.random.PRNGKey(mcmc_seed), sim_t, sim_yerr, y=sim_y_noisy)
data = az.from_numpyro(mcmc)
mcmc.print_summary()

#### Visualize Chains, Posterior Distributions

In [None]:
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

In [None]:
# Trace plots of the three sampled parameters
trace_vars = ["log_drw_scale", "log_drw_sigma", "mean"]
az.plot_trace(data, var_names=trace_vars)
# extra vertical room between the rows of panels
plt.subplots_adjust(hspace=0.4)

In [None]:
# Pair plot of the posterior, with reference values marked as orange squares
pair_vars = ["log_drw_scale", "log_drw_sigma", "mean"]
reference_point = {
    "log_drw_scale": np.log(drw_scale),
    "log_drw_sigma": np.log(drw_sigma),
    "mean": 0.0,
}
reference_style = {"color": "orange", "markersize": 20, "marker": "s"}

az.plot_pair(
    data,
    var_names=pair_vars,
    kind="scatter",
    marginals=True,
    reference_values=reference_point,
    reference_values_kwargs=reference_style,
    textsize=25,
)
# remove the gaps between panels
plt.subplots_adjust(hspace=0.0, wspace=0.0)

### 4. Second-order Statistics

In [None]:
from eztaox.kernel_stat2 import gpStat2

# Log-spaced evaluation grids (50 points each) for the second-order statistics
ts = np.logspace(0, 4, num=50)  # 1 ... 1e4, used for the structure function
fs = np.logspace(-4, 0, num=50)  # 1e-4 ... 1, used for the PSD

In [None]:
# Merge the chain and draw dimensions of the posterior into one "sample" dimension
flatPost = data.posterior.stack(sample=["chain", "draw"])
# Transposed copy of the kernel-parameter samples
# (presumably one row per posterior sample afterwards — TODO confirm shape)
log_drw_draws = np.transpose(flatPost["log_kernel_param"].values)

In [None]:
# Second-order statistics object, built on a kernel set to the true DRW parameters
drw_k = Exp(sigma=drw_sigma, scale=drw_scale)
gpStat2_drw = gpStat2(drw_k)

#### 4.1 Structure Function

In [None]:
# Structure function for every MCMC draw: the time grid is shared, while the
# (exponentiated) kernel parameters are mapped over their leading axis
sf_over_draws = jax.vmap(gpStat2_drw.sf, in_axes=(None, 0))
mcmc_sf = sf_over_draws(ts, jnp.exp(log_drw_draws))

In [None]:
## plot
# true SF, drawn on top (high zorder) and the only labelled curve in the legend
plt.loglog(
    ts, gpStat2_drw.sf(ts), color="k", linewidth=2, zorder=100, label="True SF"
)
plt.legend(fontsize=15)

# every 50th MCMC SF, faint so the spread stays readable
for sf in mcmc_sf[::50]:
    plt.loglog(ts, sf, color="tab:green", alpha=0.15)

plt.xlabel("Time")
plt.ylabel("SF")

#### 4.2 Power Spectral Density (PSD)

In [None]:
# PSD for every MCMC draw: the frequency grid is shared, while the
# (exponentiated) kernel parameters are mapped over their leading axis
psd_over_draws = jax.vmap(gpStat2_drw.psd, in_axes=(None, 0))
mcmc_psd = psd_over_draws(fs, jnp.exp(log_drw_draws))

In [None]:
## plot
# true PSD, drawn on top (high zorder) and the only labelled curve in the legend
plt.loglog(
    fs, gpStat2_drw.psd(fs), color="k", linewidth=2, zorder=100, label="True PSD"
)
plt.legend(fontsize=15)

# every 50th MCMC PSD, faint so the spread stays readable
for psd in mcmc_psd[::50]:
    plt.loglog(fs, psd, color="tab:green", alpha=0.15)

plt.xlabel("Frequency")
plt.ylabel("PSD")