In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp

# Figure styling used by every plot in this notebook: large axis/tick labels,
# slightly thicker axes, and zero padding/spacing for constrained layouts.
_plot_style = {
    "text.usetex": False,
    "axes.linewidth": 1.2,
    "axes.labelsize": 20,
    "figure.labelsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "figure.constrained_layout.w_pad": 0,
    "figure.constrained_layout.h_pad": 0,
    "figure.constrained_layout.wspace": 0,
    "figure.constrained_layout.hspace": 0,
}
mpl.rcParams.update(_plot_style)

# JAX uses 32-bit floats by default. Switch to 64-bit before any JAX arrays are
# created; the GP likelihood computations in this notebook presumably need
# double precision to stay numerically stable.
jax.config.update("jax_enable_x64", True)

## Damped Harmonic Oscillator (DHO) Process

This notebook demonstrates how to fit the Damped Harmonic Oscillator (DHO) model to a single-band light curve. For multiband fitting, see notebook [02_Multiband](02_Multiband.ipynb).

The DHO model is a second-order continuous-time autoregressive moving average (CARMA) process. It is defined as the solution to the stochastic differential equation

$$
\mathrm{d}^{2}x + \alpha_{1}\mathrm{d}x + \alpha_{0}x
= \beta_{0}\mathrm{d}W + \beta_{1}\mathrm{d}\bigl(\mathrm{d}W\bigr),
$$

where $W$ denotes a Wiener process. The coefficients $\alpha_{0}$ and $\alpha_{1}$ are the autoregressive parameters, while $\beta_{0}$ and $\beta_{1}$ are the moving-average parameters.

<div class="alert alert-info">

**Note**

The CARMA parameter notation follows the convention of [Kelly+14](https://arxiv.org/abs/1402.5978).

</div>


### 1. Light Curve Simulation

In [None]:
from eztaox.kernels.quasisep import CARMA
from eztaox.simulator import UniVarSim
from eztaox.ts_utils import add_noise

In [None]:
# CARMA(2,1)/DHO parameters in the Kelly+14 convention:
# alphas = [alpha_0, alpha_1] (AR), betas = [beta_0, beta_1] (MA)
alphas = jnp.asarray([0.0002, 0.05])
betas = jnp.asarray([0.0006, 0.03])
sim_params = {"log_kernel_param": jnp.log(jnp.hstack([alphas, betas]))}

# simulation configuration
n_obs = 200  # number of simulated observations
lc_seed = 2  # PRNG seeds passed to UniVarSim.random / add_noise
random_seed = 2
noise_seed = 11
min_dt, max_dt = 1, 3650.0  # sampling bounds passed to UniVarSim (days)

# simulate
# CARMA.init takes the parameters in linear scale (as in the other CARMA.init
# calls in this notebook), not their logarithms.
s = UniVarSim(CARMA.init(alphas, betas), min_dt, max_dt, sim_params)
sim_t, sim_y = s.random(
    n_obs, jax.random.PRNGKey(lc_seed), jax.random.PRNGKey(random_seed)
)
sim_yerr = jnp.ones_like(sim_t) * 0.05  # constant per-point uncertainty
sim_y_noisy = add_noise(sim_y, sim_yerr, jax.random.PRNGKey(noise_seed))

plt.errorbar(sim_t, sim_y_noisy, sim_yerr, fmt=".")
plt.xlabel("Time (day)")
plt.ylabel("Magnitude");

### 2. Fitting
Here, we demonstrate how to use the `UniVarModel` for fitting single-band light curves.

In [None]:
import numpyro
import numpyro.distributions as dist
from eztaox.fitter import random_search
from eztaox.models import UniVarModel
from numpyro.handlers import seed as numpyro_seed

#### 2.1 Initialize Light Curve Model

In [None]:
# Build a template DHO kernel and the single-band model for the simulated data.
# The kernel values below are initial placeholders; the fitted values are
# supplied later through the sampled parameter dicts (presumably overriding
# these — confirm against eztaox's UniVarModel).
zero_mean = False
p = 2  # CARMA p-order: the first p entries are alphas, the rest are betas
test_params = {"log_kernel_param": jnp.log(np.array([0.1, 1.1, 1.0, 3.0]))}

# split the linear-scale parameters into the AR (alpha) and MA (beta) parts
template_alphas, template_betas = jnp.split(
    jnp.exp(test_params["log_kernel_param"]), [p]
)
k = CARMA.init(template_alphas, template_betas)

# univariate (single-band) light-curve model
m = UniVarModel(sim_t, sim_y_noisy, sim_yerr, k, zero_mean=zero_mean)
m

#### 2.2 Define InitSampler

In [None]:
def initSampler():
    """Draw one random parameter set from the priors used for the DHO fit.

    Registers the NumPyro sample sites ``log_alpha`` (2 values),
    ``log_beta`` (2 values) and ``mean``, plus the deterministic site
    ``log_kernel_param``.

    Returns
    -------
    dict
        ``{"log_kernel_param": [log_alpha_0, log_alpha_1, log_beta_0,
        log_beta_1], "mean": mean}``. This is the same layout that
        ``numpyro_model`` passes to ``UniVarModel.sample``.
    """
    # uniform priors on the natural-log AR (alpha) and MA (beta) coefficients
    alpha_prior = dist.Uniform(low=-16.0, high=0.0).expand([2])
    beta_prior = dist.Uniform(low=-10.0, high=2.0).expand([2])
    log_alpha = numpyro.sample("log_alpha", alpha_prior)
    log_beta = numpyro.sample("log_beta", beta_prior)

    # alphas first, then betas (Kelly+14 ordering)
    packed = jnp.hstack([log_alpha, log_beta])
    numpyro.deterministic("log_kernel_param", packed)

    # constant mean of the light curve
    mean_prior = dist.Uniform(low=-0.2, high=0.2)
    mean_value = numpyro.sample("mean", mean_prior)

    return {"log_kernel_param": packed, "mean": mean_value}

In [None]:
# Draw one random initial guess from the priors. The seed handler fixes the
# PRNG key so the draw is reproducible.
sample_key = jax.random.PRNGKey(1)
seeded_init_sampler = numpyro_seed(initSampler, rng_seed=sample_key)
prior_sample = seeded_init_sampler()
prior_sample

#### 2.3 MLE Fitting

In [None]:
%%time
model = m
sampler = initSampler
fit_key = jax.random.PRNGKey(1)
nSample = 10_000
nBest = 10  # it seems like this number needs to be high

bestP, ll = random_search(model, initSampler, fit_key, nSample, nBest)
bestP

In [None]:
# Compare the true DHO parameters with the MLE.
# Note that EzTao follows the CARMA notation from Moreno+19, while EzTaoX
# adopts the CARMA notation from Kelly+14; the main difference is that the
# alpha parameter index is reversed.
# Both arrays are ordered [log_alpha_0, log_alpha_1, log_beta_0, log_beta_1].
print("True DHO Params (in natural log):")
print(np.log(np.hstack([alphas, betas])))
print("MLE DHO Params (in natural log):")
print(bestP["log_kernel_param"])

### 3. MCMC

In [None]:
import arviz as az
from numpyro.infer import MCMC, NUTS

In [None]:
def numpyro_model(t, yerr, y=None):
    """NumPyro model for the single-band DHO fit.

    Samples the natural-log CARMA(2,1) parameters and a constant mean, builds a
    ``UniVarModel`` for the light curve ``(t, y, yerr)`` passed in, and calls
    its ``sample`` method with the sampled parameters (presumably adding the
    GP likelihood to the NumPyro trace).

    Parameters
    ----------
    t : array
        Observation times.
    yerr : array
        Per-point measurement uncertainties.
    y : array, optional
        Observed values. NOTE(review): the model is always built from ``y``,
        so ``y=None`` (e.g. prior-predictive use) presumably is not supported
        by UniVarModel — confirm before relying on it.
    """
    log_alpha = numpyro.sample(
        "log_alpha", dist.Uniform(low=-16.0, high=0.0).expand([2])
    )
    log_beta = numpyro.sample("log_beta", dist.Uniform(low=-10.0, high=2.0).expand([2]))

    # alphas first, then betas (Kelly+14 ordering)
    log_kernel_param = jnp.hstack([log_alpha, log_beta])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # mean: use a normal prior for better convergence
    mean = numpyro.sample("mean", dist.Normal(0.0, 0.1))

    sample_params = {"log_kernel_param": log_kernel_param, "mean": mean}

    # Unlike initSampler, build the light-curve model here and condition it on
    # the data. The template kernel values come from `test_params`; the sampled
    # values are supplied through `sample_params`.
    zero_mean = False
    p = 2  # CARMA p-order

    k = CARMA.init(
        jnp.exp(test_params["log_kernel_param"][:p]),
        jnp.exp(test_params["log_kernel_param"][p:]),
    )
    # use the data passed to the model, not notebook globals
    m = UniVarModel(t, y, yerr, k, zero_mean=zero_mean)
    m.sample(sample_params)

In [None]:
%%time
nuts_kernel = NUTS(
    numpyro_model,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=numpyro.infer.init_to_median,
)

mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=5000,
    num_chains=1,
    # progress_bar=False,
)

mcmc_seed = 10
mcmc.run(jax.random.PRNGKey(mcmc_seed), sim_t, sim_yerr, y=sim_y_noisy)
data = az.from_numpyro(mcmc)
mcmc.print_summary()

#### Visualize Chains and Posterior Distributions

In [None]:
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

In [None]:
# Trace plots of the sampled parameters. Extra vertical spacing keeps the
# panel rows from overlapping.
trace_vars = ["log_alpha", "log_beta", "mean"]
az.plot_trace(data, var_names=trace_vars)
plt.subplots_adjust(hspace=0.4)

In [None]:
# Posterior pair plot, with the true simulation parameters shown as orange
# squares.
log_true_alphas = np.log(alphas)
log_true_betas = np.log(betas)
true_reference = {
    "log_alpha 0": log_true_alphas[0],
    "log_alpha 1": log_true_alphas[1],
    "log_beta 0": log_true_betas[0],
    "log_beta 1": log_true_betas[1],
    "mean": 0.0,  # presumably the simulated light curve has zero mean — confirm
}
true_marker_style = {"color": "orange", "markersize": 20, "marker": "s"}

az.plot_pair(
    data,
    var_names=["log_alpha", "log_beta", "mean"],
    kind="scatter",
    marginals=True,
    textsize=25,
    reference_values=true_reference,
    reference_values_kwargs=true_marker_style,
)
plt.subplots_adjust(hspace=0.0, wspace=0.0)

### 4. Second-order Statistics

In [None]:
from eztaox.kernel_stat2 import gpStat2

# 50 log-spaced points each:
# time lags for the structure function (same units as sim_t), 1 .. 1e4
ts = np.logspace(0, 4, num=50)
# frequencies for the power spectral density, 1e-4 .. 1
fs = np.logspace(-4, 0, num=50)

In [None]:
# Merge the (chain, draw) dimensions into a single "sample" dimension, then
# transpose so each row holds one posterior draw of the four log CARMA
# parameters. The transpose assumes stacking leaves the array as
# (param, sample), i.e. the new "sample" dim is last — confirm if xarray changes.
flatPost = data.posterior.stack(sample=["chain", "draw"])
log_param_samples = flatPost["log_kernel_param"]
log_carma_draws = np.transpose(log_param_samples.values)

In [None]:
# Kernel with the true simulation parameters (linear scale, Kelly+14 order).
dho_k = CARMA.init(alphas, betas)

# Second-order-statistics helper. Its `sf`/`psd` methods are called below both
# with no parameters (true kernel) and with explicit parameter arrays
# (posterior draws).
gpStat2_dho = gpStat2(dho_k)

#### 4.1 Structure Function

In [None]:
# Structure function for every posterior draw. vmap maps over the rows of the
# (linear-scale) parameter array and shares the same lag grid `ts`.
batched_sf = jax.vmap(gpStat2_dho.sf, in_axes=(None, 0))
mcmc_sf = batched_sf(ts, jnp.exp(log_carma_draws))

In [None]:
# Plot the true SF (black, on top) against a thinned set of posterior SFs (green).
ax = plt.gca()
# true SF
ax.loglog(ts, gpStat2_dho.sf(ts), c="k", label="True SF", zorder=100, lw=2)
ax.legend(fontsize=15)
# every 50th posterior draw, drawn faintly to show the spread
for sf in mcmc_sf[::50]:
    ax.loglog(ts, sf, c="tab:green", alpha=0.15)

ax.set_xlabel("Time")
ax.set_ylabel("SF")

#### 4.2 Power Spectral Density (PSD)

In [None]:
# compute the PSD for every MCMC draw (vmap over the rows of the linear-scale
# parameter array, sharing the frequency grid `fs`)
mcmc_psd = jax.vmap(gpStat2_dho.psd, in_axes=(None, 0))(fs, jnp.exp(log_carma_draws))

In [None]:
# Plot the true PSD (black, on top) against a thinned set of posterior PSDs (green).
ax = plt.gca()
# true PSD
ax.loglog(fs, gpStat2_dho.psd(fs), c="k", label="True PSD", zorder=100, lw=2)
ax.legend(fontsize=15)

# every 50th posterior draw, drawn faintly to show the spread
for psd in mcmc_psd[::50]:
    ax.loglog(fs, psd, c="tab:green", alpha=0.15)

ax.set_xlabel("Frequency")
ax.set_ylabel("PSD")