# 5-2-ローカルレベルモデル

# In [None]:
# -*- coding: utf-8 -*-
# Local Level Model in Python with NumPyro + ArviZ
# - Data I/O: pandas
# - Visualization: matplotlib / ArviZ
# - Bayesian inference: NumPyro (JAX)
# - Model visualization: numpyro.render_model
# - All printed results shown via print()

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from numpyro import sample, deterministic
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.contrib.control_flow import scan

import arviz as az

# ------------------------------------------------------------------------------
# 0. Reproducibility
# ------------------------------------------------------------------------------
# Seed both RNG families: NumPy's global generator drives the toy series in
# this notebook, while the JAX key is consumed later by the Bayesian fit.
np.random.seed(1)
rng_key = jax.random.PRNGKey(1)

# ------------------------------------------------------------------------------
# 1. White noise & Random walk (single series)
# ------------------------------------------------------------------------------
# 100 i.i.d. standard-normal draws.
wn = np.random.normal(loc=0.0, scale=1.0, size=100)

# Cumulative-sum demo, mirroring R's cumsum(c(1, 3, 2)).
print("cumsum([1, 3, 2]) =", np.cumsum([1, 3, 2]))

# A random walk is the running sum of white noise.
rw = np.cumsum(wn)

# One figure each: white noise first, then the random walk.
for _series, _title, _ylabel in (
    (wn, "White Noise", "Value"),
    (rw, "Random Walk", "Level"),
):
    plt.figure(figsize=(7, 3))
    plt.plot(_series, linewidth=1.5)
    plt.title(_title)
    plt.xlabel("Time")
    plt.ylabel(_ylabel)
    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------------------------
# 2. Multiple white noises & random walks
# ------------------------------------------------------------------------------
T, K = 100, 20
# K independent white-noise series of length T (one per column); a running sum
# down the time axis (axis 0) turns every column into its own random walk.
wn_mat = np.random.normal(0.0, 1.0, size=(T, K))
rw_mat = wn_mat.cumsum(axis=0)

# One figure per family; each column becomes its own line (no legend).
for _data, _title, _ylabel in (
    (wn_mat, "White Noise (20 series)", "Value"),
    (rw_mat, "Random Walk (20 series)", "Level"),
):
    plt.figure(figsize=(7, 3))
    plt.plot(_data, alpha=0.7)
    plt.title(_title)
    plt.xlabel("Time")
    plt.ylabel(_ylabel)
    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------------------------
# 3. Load CSV & convert to datetime (POSIXct equivalent)
# ------------------------------------------------------------------------------
# Expecting columns: "date", "sales"
# Load the sales series and turn the "date" column into pandas datetimes.
sales_df = pd.read_csv("5-2-1-sales-ts-1.csv")
sales_df["date"] = pd.to_datetime(sales_df["date"])

print("Head of sales_df:")
print(sales_df.head(3))

# Equivalent of R's as.numeric(<POSIXct>): whole seconds since the Unix epoch.
posix_time = pd.Timestamp("1970-01-01 00:00:05", tz="UTC")
seconds_since_epoch = int(posix_time.timestamp())
print("Seconds since epoch for 1970-01-01 00:00:05 UTC:", int(seconds_since_epoch))

# ------------------------------------------------------------------------------
# 4. Local Level Model in NumPyro (Stan translation)
#    Stan model:
#    mu[1] ~ (implicit diffuse prior)
#    for t=2..T: mu[t] ~ normal(mu[t-1], s_w)
#    for t=1..T: y[t]  ~ normal(mu[t],   s_v)
# ------------------------------------------------------------------------------
def local_level_model(T, y=None, *, s_w_scale=1.0, s_v_scale=1.0, mu0_scale=10.0):
    """Local level (random walk plus noise) state-space model in NumPyro.

    Translation of the Stan model
        mu[t] ~ normal(mu[t-1], s_w)   for t = 2..T
        y[t]  ~ normal(mu[t],   s_v)   for t = 1..T

    The random walk is written in non-centred form. Standardised innovations
    ``eps_raw ~ Normal(0, 1)`` are scaled by ``s_w`` and cumulatively summed.
    This gives the same distribution over ``mu`` as sampling
    ``Normal(0, s_w)`` increments directly. It also removes the prior coupling
    between the innovations and ``s_w`` (the "funnel"), which tends to make
    NUTS sample more reliably when ``s_w`` is small.

    Args:
        T: Number of time points; must be >= 1.
        y: Observed series of length T, or None to draw ``y`` from the model.
        s_w_scale: HalfNormal scale of the prior on the process noise ``s_w``.
        s_v_scale: HalfNormal scale of the prior on the observation noise ``s_v``.
        mu0_scale: Std of the Normal(0, mu0_scale) prior on the initial level.
            This stands in for Stan's implicit diffuse prior on mu[1].

    Sites:
        Sampled: ``s_w``, ``s_v``, ``mu0``, ``eps_raw``, and ``y`` (observed when ``y`` is given).
        Deterministic: ``eps`` (= s_w * eps_raw, the level increments) and
        ``mu`` (the length-T state trajectory).

    Raises:
        ValueError: If ``T`` < 1.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    s_w = sample("s_w", dist.HalfNormal(s_w_scale))    # process noise (>= 0)
    s_v = sample("s_v", dist.HalfNormal(s_v_scale))    # observation noise (>= 0)
    mu0 = sample("mu0", dist.Normal(0.0, mu0_scale))   # initial level

    # T-1 standardised innovations for periods 2..T (an empty vector when T == 1).
    eps_raw = sample("eps_raw", dist.Normal(0.0, 1.0).expand([T - 1]))
    # Rescale to the actual level increments ~ N(0, s_w); kept under the old
    # site name "eps" so downstream code reading it still works.
    increments = deterministic("eps", s_w * eps_raw)
    # Random walk: the initial level plus the cumulative sum of the increments.
    mu = mu0 + jnp.concatenate([jnp.zeros(1), jnp.cumsum(increments)])
    deterministic("mu", mu)

    # Observation model.
    sample("y", dist.Normal(mu, s_v), obs=y)

# ------------------------------------------------------------------------------
# 5. Prepare data & run MCMC
# ------------------------------------------------------------------------------
y_data = sales_df["sales"].to_numpy()
T_data = y_data.shape[0]

nuts = NUTS(local_level_model)
# NOTE: on a single CPU device NumPyro runs these 4 chains sequentially unless
# numpyro.set_host_device_count(...) is called before JAX initialises.
mcmc = MCMC(nuts, num_warmup=1000, num_samples=2000, num_chains=4, progress_bar=True)

# JAX keys must not be reused. Give MCMC its own subkey and advance `rng_key`,
# so later consumers (e.g. Predictive) draw from an independent stream rather
# than from the exact key MCMC consumed.
rng_key, mcmc_key = jax.random.split(rng_key)
mcmc.run(mcmc_key, T=T_data, y=y_data)

# ------------------------------------------------------------------------------
# 6. Diagnostics & numeric results (use print())
# ------------------------------------------------------------------------------
# Posterior draws with chains concatenated (get_samples defaults to
# group_by_chain=False), then posterior-predictive draws of y and mu.
posterior_samples = mcmc.get_samples()
predictive = Predictive(local_level_model, posterior_samples, return_sites=["y", "mu"])
ppc_samples = predictive(rng_key, T=T_data)  # posterior predictive (y, mu)

# observed_data is intentionally not passed here. NOTE(review): plot_ppc below
# needs an observed_data group, which from_numpyro is expected to recover
# from the model's observed "y" site; confirm with the installed ArviZ.
idata = az.from_numpyro(posterior=mcmc, posterior_predictive=ppc_samples)

# R-hat for the scalar parameters only. Calling .to_dataframe() on the full
# rhat Dataset would broadcast each scalar R-hat across the Cartesian product
# of the vector sites' indices, producing one duplicated row per combination.
rhat_ds = az.rhat(idata, var_names=["s_w", "s_v", "mu0"])
rhat_df = pd.DataFrame({"r_hat": {name: float(rhat_ds[name]) for name in rhat_ds.data_vars}})
print("\nR-hat diagnostics (selected):")
print(rhat_df)

# Summary for key parameters
summary_df = az.summary(idata, var_names=["s_w", "s_v", "mu0"], hdi_prob=0.95)
print("\nParameter summary (s_w, s_v, mu0):")
print(summary_df)

# Posterior draws of the state trajectory as (S, T). xarray's stack() appends
# the new "sample" dimension at the END, so transpose to put samples first.
# Without the transpose the array is (T, S), and column 0 would be a single
# draw's whole trajectory instead of mu[1] across all draws.
mu_samples = (
    idata.posterior["mu"]
    .stack(sample=("chain", "draw"))
    .transpose("sample", ...)
    .values
)
q_low, q_med, q_high = np.quantile(mu_samples[:, 0], [0.025, 0.5, 0.975])
print("\nQuantiles for mu[1] (time index 1 in R):")
print({"2.5%": q_low, "50%": q_med, "97.5%": q_high})

# ------------------------------------------------------------------------------
# 7. Visualization
#    7.1 Model graph with numpyro.render_model (built-in model visualization)
# ------------------------------------------------------------------------------
try:
    from numpyro import render_model

    # Trace the model with the observed data so that "y" appears as an observed node.
    dot = render_model(
        local_level_model,
        model_args=(T_data,),
        model_kwargs=dict(y=y_data),
        render_params=True,
        render_distributions=True,
    )
    # Writes local_level_model.png; cleanup=True deletes the intermediate DOT source.
    dot.render("local_level_model", format="png", cleanup=True)
    print('\nModel graph saved to "local_level_model.png"')
except Exception as err:  # best effort: a missing graphviz must not stop the notebook
    print("\nModel graph rendering skipped (graphviz may be missing):", err)

# ------------------------------------------------------------------------------
#    7.2 Posterior of key parameters with ArviZ (use hdi_prob; do NOT use credible_interval)
# ------------------------------------------------------------------------------
# One panel per scalar parameter, each annotated with its 95% HDI.
# plot_posterior returns the axes it drew on; the name `fig` is kept as-is.
fig = az.plot_posterior(idata, hdi_prob=0.95, var_names=["s_w", "s_v", "mu0"])
plt.suptitle("Posterior Distributions (s_w, s_v, mu0)", y=1.02)
plt.tight_layout()
plt.show()

# ------------------------------------------------------------------------------
#    7.3 Posterior Predictive Check (set group="posterior" as required)
# ------------------------------------------------------------------------------
# Compare 200 posterior-predictive draws against the observed data.
fig = az.plot_ppc(
    idata,
    num_pp_samples=200,
    group="posterior",
)
plt.suptitle("Posterior Predictive Check", y=1.02)
plt.tight_layout()
plt.show()

# ------------------------------------------------------------------------------
#    7.4 Reproduce the R-style state plot (median + 95% interval) with matplotlib
# ------------------------------------------------------------------------------
def summarize_state_intervals(mu_samples_2d, time_index):
    """Summarise posterior state draws into a per-time-point interval table.

    Args:
        mu_samples_2d: Array-like of shape (S, T), with S posterior draws as rows
            and T time points as columns.
        time_index: Sequence of length T that labels the time points. It is
            converted to a plain array first, so a pandas Series with a
            non-default index is placed by position. Assigning it directly would
            align on its index and could silently produce NaN.

    Returns:
        DataFrame with T rows and columns "lwr" (2.5% quantile), "fit" (median),
        "upr" (97.5% quantile) and "time".

    Raises:
        ValueError: If ``mu_samples_2d`` is not 2-D, or if ``time_index`` is not
            a 1-D sequence of length T.
    """
    draws = np.asarray(mu_samples_2d)
    if draws.ndim != 2:
        raise ValueError(f"mu_samples_2d must be 2-D (S, T), got shape {draws.shape}")
    times = np.asarray(time_index)
    if times.ndim != 1 or times.shape[0] != draws.shape[1]:
        raise ValueError(
            f"time_index must have length T={draws.shape[1]}, got shape {times.shape}"
        )

    # Quantiles across draws (axis 0), one value per time point.
    lwr, med, upr = np.quantile(draws, [0.025, 0.5, 0.975], axis=0)
    out = pd.DataFrame({"lwr": lwr, "fit": med, "upr": upr})
    out["time"] = times
    return out

# Posterior draws of mu as (S, T): merge chain/draw into one "sample" dimension,
# then transpose, because xarray's stack() appends the stacked dimension last.
mu_da = idata.posterior["mu"]
time_dim = [d for d in mu_da.dims if d not in ("chain", "draw")][0]
mu_stacked = mu_da.stack(sample=("chain", "draw")).transpose("sample", time_dim).values  # (S, T)

# Per-time 2.5% / 50% / 97.5% quantiles plus time stamps, then the observations.
state_df = summarize_state_intervals(mu_stacked, sales_df["date"].values)
state_df["obs"] = y_data

print("\nHead of state summary (lwr, fit, upr, time, obs):")
print(state_df.head(3))

plt.figure(figsize=(9, 4))
plt.plot(state_df["time"], state_df["fit"], linewidth=2.0, label="State (median)")
# lwr/upr are the 2.5% and 97.5% quantiles: an equal-tailed interval, not an HDI.
plt.fill_between(state_df["time"], state_df["lwr"], state_df["upr"], alpha=0.3,
                 label="95% credible interval")
plt.scatter(state_df["time"], state_df["obs"], s=10, alpha=0.7, label="Observed")
plt.title("Local Level Model: Estimated State and Observations")
plt.xlabel("Time")
plt.ylabel("Sales")
plt.legend()
plt.tight_layout()
plt.show()

# ------------------------------------------------------------------------------
#    7.5 Generic plotting function like R's plotSSM (English labels)
# ------------------------------------------------------------------------------
def plot_ssm(mu_samples_any, time_vec, obs_vec=None, title="Local Level Model Result", y_label="Sales"):
    """Plot the posterior median state with its central 95% band (Python analogue of R's plotSSM).

    Args:
        mu_samples_any: Posterior draws of the state trajectory, in one of
            three layouts. (chain, draw, T) is flattened to (chain * draw, T).
            (S, T) is used as-is. A 2-D array is transposed only when its
            second axis does not match ``len(time_vec)`` but its first does,
            so (T, S) input is also accepted.
        time_vec: Sequence of length T with the x-axis values (e.g. dates).
        obs_vec: Optional observations of length T, drawn as points.
        title: Figure title.
        y_label: y-axis label.

    Raises:
        ValueError: If the draws cannot be arranged as (S, len(time_vec)).
    """
    draws = np.asarray(mu_samples_any)
    n_time = len(time_vec)

    if draws.ndim == 3:
        # (chain, draw, time) -> (chain * draw, time)
        draws = draws.reshape(draws.shape[0] * draws.shape[1], draws.shape[2])
    if draws.ndim != 2:
        raise ValueError(f"expected 2-D or 3-D state draws, got shape {draws.shape}")
    if draws.shape[1] != n_time and draws.shape[0] == n_time:
        draws = draws.T  # input was (T, S)
    if draws.shape[1] != n_time:
        raise ValueError(
            f"state draws of shape {draws.shape} do not match len(time_vec)={n_time}"
        )

    # Summarise across draws (axis 0); one value per time point.
    lwr, med, upr = np.quantile(draws, [0.025, 0.5, 0.975], axis=0)

    plt.figure(figsize=(9, 4))
    plt.plot(time_vec, med, linewidth=2.0, label="State (median)")
    # 2.5%-97.5% quantile band: equal-tailed, not a highest-density interval.
    plt.fill_between(time_vec, lwr, upr, alpha=0.3, label="95% credible interval")
    if obs_vec is not None:
        plt.scatter(time_vec, obs_vec, s=10, alpha=0.7, label="Observed")
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel(y_label)
    plt.legend()
    plt.tight_layout()
    plt.show()



# Call the helper to reproduce the final plot
# Rebuild mu_samples as (S, T) before plotting. stack() merges chain/draw into
# a single "sample" dimension but appends it last, so transpose explicitly to
# put samples on axis 0 and time on axis 1.
mu_da = idata.posterior["mu"]  # dims: ("chain", "draw", <time dim>)
time_dim = [d for d in mu_da.dims if d not in {"chain", "draw"}][0]
mu_samples = mu_da.stack(sample=("chain", "draw")).transpose("sample", time_dim).values

plot_ssm(
    mu_samples,
    sales_df["date"],
    obs_vec=y_data,
    title="Local Level Model: Estimated State",
    y_label="Sales",
)
