# 5-6-周期性のモデル化

In [None]:
# -*- coding: utf-8 -*-
# Basic Structural Time Series with seasonality (NumPyro version)
# - CSV loading: pandas
# - Visualization: matplotlib / seaborn / ArviZ
# - Bayesian inference: NumPyro (NUTS)
# - Model visualization: numpyro.contrib.render.render_model
# - Posterior visualization: ArviZ (uses hdi_prob, not credible_interval)
# - No .loc on pandas DataFrame

import warnings
warnings.filterwarnings("ignore")

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import arviz as az

import jax
import jax.numpy as jnp
from jax import random, lax

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# -----------------------------
# 1) Data Loading & Quick Plot
# -----------------------------
# CSV: 5-6-1-sales-ts-4.csv (columns: date,sales)
csv_path = "5-6-1-sales-ts-4.csv"
# Read the CSV and parse the "date" column into datetimes.
sales_df_4 = pd.read_csv(csv_path, parse_dates=["date"])

# Show the first three rows as a quick sanity check.
print("Head of data:", sales_df_4.head(3), sep="\n")

# Line plot of sales against date, labelled in English.
plt.figure(figsize=(9, 4))
sns.lineplot(data=sales_df_4, x="date", y="sales")
plt.gca().set(title="Sales Time Series", xlabel="Date", ylabel="Sales")
plt.tight_layout()
plt.show()

# -----------------------------------------
# 2) NumPyro Model (Basic Structural + Seasonality)
#    Stan code equivalent:
#      mu[i]    ~ Normal(2*mu[i-1] - mu[i-2], s_z)
#      gamma[i] ~ Normal(-sum(gamma[i-6:i-1]), s_s)  for i>=7
#      y[i]     ~ Normal(mu[i] + gamma[i], s_v)
# -----------------------------------------
def basic_structural_model(y=None, T: int = None):
    """Basic structural time-series model: second-order trend plus seasonal term.

    The latent series are built as

        mu[t]    = 2 * mu[t-1] - mu[t-2] + eps_mu[t-2]                 (t >= 2)
        gamma[t] = -(gamma[t-6] + ... + gamma[t-1]) + eps_gamma[t-6]   (t >= 6)
        y[t]     ~ Normal(mu[t] + gamma[t], s_v)

    with eps_mu ~ Normal(0, s_z) and eps_gamma ~ Normal(0, s_s).  The series
    ``mu``, ``gamma`` and ``alpha = mu + gamma`` are recorded as deterministic
    sites.

    Args:
        y: Observed series, passed as ``obs`` to the ``"y"`` site; may be None.
        T: Number of time steps.  Taken from the length of ``y`` when omitted.

    Raises:
        ValueError: If neither ``y`` nor ``T`` is given.
    """
    if T is None and y is None:
        raise ValueError("Either y or T must be provided.")
    if T is None:
        # Keep the length a plain Python int so array shapes stay static.
        T = int(np.asarray(y).shape[0])

    # Scale parameters: trend innovations, seasonal innovations, observation noise.
    s_z = numpyro.sample("s_z", dist.HalfNormal(10.0))
    s_s = numpyro.sample("s_s", dist.HalfNormal(10.0))
    s_v = numpyro.sample("s_v", dist.HalfNormal(10.0))

    # Initial states: the first two trend values and the first six seasonal values.
    mu0 = numpyro.sample("mu0", dist.Normal(0.0, 10.0))
    mu1 = numpyro.sample("mu1", dist.Normal(0.0, 10.0))
    gamma_init = numpyro.sample(
        "gamma_init", dist.Normal(jnp.zeros(6), 10.0).to_event(1)
    )

    # Innovations: one per time step beyond the initial states (Python-int lengths).
    n_mu = max(T - 2, 0)
    n_gamma = max(T - 6, 0)

    if n_mu > 0:
        eps_mu = numpyro.sample(
            "eps_mu", dist.Normal(0.0, s_z).expand([n_mu]).to_event(1)
        )
    else:
        eps_mu = jnp.zeros(0)

    if n_gamma > 0:
        eps_gamma = numpyro.sample(
            "eps_gamma", dist.Normal(0.0, s_s).expand([n_gamma]).to_event(1)
        )
    else:
        eps_gamma = jnp.zeros(0)

    # Trend series.
    if T < 1:
        mu = jnp.array([])
    elif T == 1:
        mu = jnp.array([mu0])
    else:
        def advance_trend(state, step):
            # state holds the two most recent trend values (older, newer).
            older, newer = state
            nxt = 2.0 * newer - older + eps_mu[step - 2]
            return (newer, nxt), nxt

        _, trend_tail = lax.scan(advance_trend, (mu0, mu1), jnp.arange(2, T))
        trend_head = jnp.array([mu0, mu1])
        mu = trend_head if T == 2 else jnp.concatenate([trend_head, trend_tail])

    # Seasonal series.
    if T < 6:
        gamma = gamma_init[:T]
    else:
        def advance_season(window, step):
            # window holds the six most recent seasonal values.
            nxt = -jnp.sum(window) + eps_gamma[step - 6]
            window = jnp.concatenate([window[1:], jnp.array([nxt])])
            return window, nxt

        _, season_tail = lax.scan(advance_season, gamma_init, jnp.arange(6, T))
        gamma = gamma_init if T == 6 else jnp.concatenate([gamma_init, season_tail])

    # Combined state, recorded together with its components.
    alpha = mu + gamma
    numpyro.deterministic("mu", mu)
    numpyro.deterministic("gamma", gamma)
    numpyro.deterministic("alpha", alpha)

    # Observation model.
    numpyro.sample("y", dist.Normal(alpha, s_v), obs=y)




# Optional: visualize the model structure (NumPyro built-in)
try:
    from numpyro.contrib.render import render_model

    # Placeholder observations with the data's length, used only for rendering.
    dummy_y = jnp.ones(len(sales_df_4))
    gm = render_model(
        basic_structural_model,
        model_args=(dummy_y,),
        render_distributions=True,
        render_params=True,
    )
except Exception as e:
    # Rendering is best-effort: report the reason and carry on.
    print(f"Model rendering skipped (reason: {e})")
else:
    # In a notebook, display(gm) would show the graph; here it is only announced.
    print("Graphical model prepared via numpyro.contrib.render.render_model().")

# -------------------------
# 3) Run MCMC (NUTS)
# -------------------------
y_obs = jnp.array(sales_df_4["sales"].to_numpy())
# Plain Python int: the model uses T for array shapes and scan lengths.
T = int(y_obs.shape[0])

# Give the sampler its own key. `rng_key` is split again further down for the
# Predictive passes, and a JAX key should be consumed only once, so the key
# handed to mcmc.run must not be the one that is carried forward.
rng_key, rng_key_mcmc = random.split(random.PRNGKey(1))

kernel = NUTS(basic_structural_model, target_accept_prob=0.97, max_tree_depth=15)
mcmc = MCMC(kernel, num_warmup=2000, num_samples=1000, num_chains=4, progress_bar=True)
mcmc.run(rng_key_mcmc, y=y_obs, T=T)
mcmc.print_summary(exclude_deterministic=False)  # console summary
print("MCMC finished.")

# Diagnostics: total number of divergent transitions, when the field is present.
extra = mcmc.get_extra_fields()
if "diverging" in extra:
    total_div = int(np.sum(np.array(extra["diverging"])))
    print(f"Number of divergences: {total_div}")

# -------------------------
# 4) Posterior & posterior predictive as ArviZ InferenceData
#    - built with az.from_dict from per-chain arrays shaped (chain, draw, ...)
# -------------------------
# Fresh keys for the two Predictive passes below.
rng_key, rng_key_det, rng_key_ppc = random.split(rng_key, 3)

state_sites = ["alpha", "mu", "gamma"]

# Chain-merged draws; kept so that flat samples remain available afterwards.
posterior_samples = mcmc.get_samples()

# Per-chain draws, dict of arrays shaped (chain, draw, ...).
post = mcmc.get_samples(group_by_chain=True)

# State series for every posterior draw. Each Predictive pass is run once, on
# the per-chain layout; batch_ndims=2 declares the two leading batch
# dimensions (chain, draw).
det_out = Predictive(
    basic_structural_model,
    post,
    return_sites=state_sites,
    batch_ndims=2,
)(rng_key_det, T=T)

# Posterior predictive draws of y: y is not passed, so the model samples it.
ppc = Predictive(
    basic_structural_model,
    post,
    return_sites=["y"],
    batch_ndims=2,
)(rng_key_ppc, T=T)

# Flat posterior dictionary including the state series (chain and draw merged).
posterior_full = dict(posterior_samples)
for k in state_sites:
    posterior_full[k] = det_out[k].reshape((-1,) + det_out[k].shape[2:])

# Label the trailing axis of the series variables as "time".
coords = {"time": np.arange(T)}
dims = {"alpha": ["time"], "mu": ["time"], "gamma": ["time"], "y": ["time"]}

post = {k: np.asarray(v) for k, v in post.items()}
post.update({k: np.asarray(det_out[k]) for k in state_sites})  # (chain, draw, time)
pp_dict = {"y": np.asarray(ppc["y"])}                          # (chain, draw, time)

idata = az.from_dict(
    posterior=post,
    posterior_predictive=pp_dict,
    coords=coords,
    dims=dims,
)

# Print R-hat & ESS (Dataset printed via print(); az.hdi also returns Dataset here)
# Convergence diagnostics and 95% HDIs, each printed under its own heading.
print("\nR-hat (all variables):", az.rhat(idata), sep="\n")

print("\nEffective sample size (bulk):", az.ess(idata, method="bulk"), sep="\n")

print(
    "\n95% HDI for key scales (Dataset):",
    az.hdi(idata.posterior[["s_z", "s_s", "s_v"]], hdi_prob=0.95),
    sep="\n",
)

# -------------------------
# 5) Plots: posterior parameters & diagnostics
# -------------------------
# Posterior for s_z, s_s, s_v (use hdi_prob)
# Posterior densities of the three scale parameters with 95% HDI.
az.plot_posterior(idata, var_names=["s_z", "s_s", "s_v"], hdi_prob=0.95)
plt.suptitle("Posterior of Scale Parameters", y=1.02)
plt.tight_layout()
plt.show()

# Forest plot with chains combined (no 'group' argument is passed).
az.plot_forest(idata, combined=True, var_names=["s_z", "s_s", "s_v"])
plt.gca().set_title("Forest Plot of Scale Parameters")
plt.tight_layout()
plt.show()

# Trace plots for the same parameters.
az.plot_trace(idata, var_names=["s_z", "s_s", "s_v"])
plt.tight_layout()
plt.show()

# -------------------------
# 6) State plots: alpha (all components), mu (without seasonal), gamma (seasonal)
#    HDI bands are drawn with ArviZ; labels are in English.
# -------------------------
# Dates used as the x-axis of the state plots.
time = sales_df_4["date"].to_numpy()

def plot_state_with_hdi(var_name, title, y_obs=None, ylabel="Value", hdi_prob=0.95):
    """Plot the posterior mean and HDI band of one state series over time.

    Reads the module-level ``idata`` (posterior draws) and ``time`` (x-axis).

    Args:
        var_name: Name of the variable in ``idata.posterior`` to plot.
        title: Figure title.
        y_obs: Optional observed values; when given they are overlaid as a
            scatter and a legend is shown.
        ylabel: Label of the y-axis.
        hdi_prob: Probability mass of the HDI band.
    """
    draws = idata.posterior[var_name]  # (chain, draw, time)
    posterior_mean = draws.mean(dim=("chain", "draw")).to_numpy()

    plt.figure(figsize=(10, 4))
    # smooth=False draws the band without interpolation. The original author's
    # noted alternative for a smoothed band is to convert the dates with
    # matplotlib.dates.date2num and restore date formatting on the axis.
    az.plot_hdi(x=time, y=draws, hdi_prob=hdi_prob, smooth=False)

    ax = plt.gca()
    ax.plot(time, posterior_mean, linewidth=2)
    if y_obs is not None:
        ax.scatter(time, np.array(y_obs), s=10, alpha=0.6, label="Observed")
        ax.legend()

    ax.set_title(title)
    ax.set_xlabel("Date")
    ax.set_ylabel(ylabel)
    plt.tight_layout()
    plt.show()

# All components included (alpha): overlay observed y
# Full state (alpha = mu + gamma), with the observed sales overlaid.
plot_state_with_hdi(
    var_name="alpha",
    title="State Estimate (All Components)",
    ylabel="Sales",
    y_obs=y_obs,
)

# Trend only (mu), i.e. the state without the seasonal component.
plot_state_with_hdi(
    var_name="mu",
    title="State Estimate (Without Seasonal Component)",
    ylabel="Level + Drift",
)

# Seasonal component only (gamma).
plot_state_with_hdi(
    var_name="gamma",
    title="Seasonal Component",
    ylabel="Gamma",
)

# -------------------------
# 7) Posterior Predictive Check
#    Must pass group="posterior" when plotting PPC of posterior
# -------------------------
coords = {"time": np.arange(T)}
dims = {name: ["time"] for name in ("alpha", "mu", "gamma", "y")}

# Per-chain posterior draws as NumPy arrays, plus the state series.
post = {
    name: np.asarray(draws)
    for name, draws in mcmc.get_samples(group_by_chain=True).items()
}
post.update(
    {site: np.asarray(det_out[site]) for site in ("alpha", "mu", "gamma")}
)  # each (chain, draw, time)
pp_dict = {"y": np.asarray(ppc["y"])}  # (chain, draw, time)

# Rebuild the InferenceData, this time with the observed series attached
# (shape (T,)) so that the posterior predictive check can be drawn.
idata = az.from_dict(
    posterior=post,
    posterior_predictive=pp_dict,
    observed_data={"y": np.asarray(y_obs)},
    coords=coords,
    dims=dims,
)

# Observed and posterior predictive data are both present now.
az.plot_ppc(idata, group="posterior", num_pp_samples=200)
plt.gca().set_title("Posterior Predictive Check (Posterior)")
plt.tight_layout()
plt.show()

# -------------------------
# 8) Print a compact summary (percentiles etc.)
# -------------------------
# Summary table of the three scale parameters with 95% HDI bounds.
summary_df = az.summary(idata, var_names=["s_z", "s_s", "s_v"], hdi_prob=0.95)
print("\nPosterior summary (scales):", summary_df, sep="\n")

# Done.
print("\nCompleted: Basic structural time series with seasonality (NumPyro).")


In [None]:
# -------------------------
# 8) Print a compact summary (percentiles etc.)
# -------------------------
# Same summary of the scale parameters as at the end of the previous cell,
# recomputed from the current `idata`.
summary_df = az.summary(idata, var_names=["s_z", "s_s", "s_v"], hdi_prob=0.95)
print("\nPosterior summary (scales):", summary_df, sep="\n")

# Done.
print("\nCompleted: Basic structural time series with seasonality (NumPyro).")