# 5-5-トレンドの構造

# In [None]:
# -*- coding: utf-8 -*-
# Requirements (example):
# pip install numpyro arviz jax jaxlib pandas matplotlib seaborn graphviz

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import jax.numpy as jnp
from jax.random import PRNGKey, split
from numpyro import sample, deterministic
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# ------------------------------------------------------------
# 0) Data loading & quick visualization
# ------------------------------------------------------------
# Read the sales CSV with pandas and turn the "date" column into datetimes
# so that it can serve as the time axis of the plots.
df = pd.read_csv("5-5-1-sales-ts-3.csv")
df["date"] = pd.to_datetime(df["date"])

# Observed series as a NumPy array; T (its length) is used by the model
# functions below.
y = df["sales"].to_numpy()
T = y.shape[0]

# Show the first rows; results are displayed with print().
print("Head of the data (first 3 rows):")
print(df.head(3))

# Line plot of the raw series (English labels).
plt.figure(figsize=(9, 3.6))
sns.lineplot(data=df, x="date", y="sales").set(
    title="Sales Time Series",
    xlabel="Date",
    ylabel="Sales",
)
plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 1) NumPyro models
# ------------------------------------------------------------
def model_local_level(y=None, n_steps=None):
    """
    Local level model (random-walk level observed with noise):
      mu_t ~ Normal(mu_{t-1}, s_w)
      y_t  ~ Normal(mu_t, s_v)

    Args:
        y: observed series (1-D array-like) or None. When None, "y" is not
           observed and is sampled instead (e.g. for posterior predictive draws).
        n_steps: number of time steps. Defaults to len(y) when y is given,
           otherwise to the module-level T. This keeps the latent state length
           equal to the length of the observations actually passed in. When y
           is given, n_steps must equal len(y).
    """
    if n_steps is None:
        n_steps = len(y) if y is not None else T

    s_w = sample("s_w", dist.TruncatedNormal(loc=2.0, scale=2.0, low=0.0))  # level noise sd, >= 0
    s_v = sample("s_v", dist.TruncatedNormal(loc=10.0, scale=5.0, low=0.0))  # observation noise sd, >= 0

    # Weak prior for the initial level, centred on the first observation when available.
    mu0 = sample("mu0", dist.Normal(loc=jnp.array(y[0]) if y is not None else 0.0, scale=10.0))
    mu = [mu0]
    for t in range(1, n_steps):
        # Each level is the previous level plus Normal(0, s_w) noise.
        mu.append(sample(f"mu[{t}]", dist.Normal(mu[-1], s_w)))
    mu = jnp.stack(mu)
    deterministic("mu", mu)

    sample("y", dist.Normal(mu, s_v), obs=jnp.asarray(y) if y is not None else None)


def model_smooth_trend(y=None, n_steps=None):
    """
    Smoothing trend (2nd-order local trend without explicit delta):
      mu_t ~ Normal(2*mu_{t-1} - mu_{t-2}, s_z)  for t>=2
      y_t  ~ Normal(mu_t, s_v)

    Args:
        y: observed series (1-D array-like) or None. When None, "y" is not
           observed and is sampled instead (e.g. for posterior predictive draws).
        n_steps: number of time steps. Defaults to len(y) when y is given,
           otherwise to the module-level T. When y is given, n_steps must
           equal len(y).
    """
    if n_steps is None:
        n_steps = len(y) if y is not None else T

    s_z = sample("s_z", dist.TruncatedNormal(loc=2.0, scale=2.0, low=0.0))  # trend noise sd, >= 0
    s_v = sample("s_v", dist.TruncatedNormal(loc=10.0, scale=5.0, low=0.0))  # observation noise sd, >= 0

    # Weak priors for the first two states, both centred on the first observation when available.
    init_loc = jnp.array(y[0]) if y is not None else 0.0
    mu = [sample("mu0", dist.Normal(loc=init_loc, scale=10.0))]
    # Only sample the second initial state when the series has one; otherwise a
    # length-1 series would get a 2-element state vector.
    if n_steps > 1:
        mu.append(sample("mu1", dist.Normal(loc=init_loc, scale=10.0)))
    for t in range(2, n_steps):
        # Linear extrapolation from the last two states: the second difference is Normal(0, s_z).
        pred = 2.0 * mu[-1] - mu[-2]
        mu.append(sample(f"mu[{t}]", dist.Normal(pred, s_z)))
    mu = jnp.stack(mu)
    deterministic("mu", mu)

    sample("y", dist.Normal(mu, s_v), obs=jnp.asarray(y) if y is not None else None)


def model_local_linear_trend(y=None, n_steps=None):
    """
    Local linear trend model:
      s_w ~ Normal(2,2) truncated >=0
      s_z ~ Normal(0.5,0.5) truncated >=0
      s_v ~ Normal(10,5) truncated >=0
      mu_t    ~ Normal(mu_{t-1} + delta_{t-1}, s_w)
      delta_t ~ Normal(delta_{t-1}, s_z)
      y_t     ~ Normal(mu_t, s_v)

    Args:
        y: observed series (1-D array-like) or None. When None, "y" is not
           observed and is sampled instead (e.g. for posterior predictive draws).
        n_steps: number of time steps. Defaults to len(y) when y is given,
           otherwise to the module-level T. When y is given, n_steps must
           equal len(y).
    """
    if n_steps is None:
        n_steps = len(y) if y is not None else T

    s_w = sample("s_w", dist.TruncatedNormal(loc=2.0, scale=2.0, low=0.0))
    s_z = sample("s_z", dist.TruncatedNormal(loc=0.5, scale=0.5, low=0.0))
    s_v = sample("s_v", dist.TruncatedNormal(loc=10.0, scale=5.0, low=0.0))

    # Weak priors for the initial level (centred on the first observation when
    # available) and the initial drift.
    mu = [sample("mu0", dist.Normal(loc=jnp.array(y[0]) if y is not None else 0.0, scale=10.0))]
    delta = [sample("delta0", dist.Normal(0.0, 10.0))]
    for t in range(1, n_steps):
        # Both lines read the *previous* delta: mu is appended first, delta is
        # only appended on the second line.
        mu.append(sample(f"mu[{t}]", dist.Normal(mu[-1] + delta[-1], s_w)))
        delta.append(sample(f"delta[{t}]", dist.Normal(delta[-1], s_z)))
    mu = jnp.stack(mu)
    delta = jnp.stack(delta)
    deterministic("mu", mu)
    deterministic("delta", delta)

    sample("y", dist.Normal(mu, s_v), obs=jnp.asarray(y) if y is not None else None)


# ------------------------------------------------------------
# 2) Run MCMC for each model with NumPyro
# ------------------------------------------------------------
rng = PRNGKey(1)  # module-level key; each run_mcmc call splits a fresh subkey from it
n_chains = 4
n_warmup = 2000
n_samples = 1000

def run_mcmc(model, y, target_accept=0.9, max_tree_depth=15):
    """
    Fit `model` to `y` with NUTS and return the finished MCMC object.

    Advances the module-level `rng`, so successive calls use different keys.
    Chain count, warm-up length and number of draws come from the
    module-level n_chains / n_warmup / n_samples settings.
    """
    global rng
    rng, run_key = split(rng)
    sampler = MCMC(
        NUTS(model, target_accept_prob=target_accept, max_tree_depth=max_tree_depth),
        num_warmup=n_warmup,
        num_samples=n_samples,
        num_chains=n_chains,
        progress_bar=True,
    )
    sampler.run(run_key, y=y)
    return sampler

# Local level: default-ish sampler settings
mcmc_ll = run_mcmc(model_local_level, y, target_accept=0.8, max_tree_depth=10)
# Smooth trend: tighter sampler control (as in the R version)
mcmc_st = run_mcmc(model_smooth_trend, y, target_accept=0.9, max_tree_depth=15)
# Local linear trend
mcmc_llt = run_mcmc(model_local_linear_trend, y, target_accept=0.8, max_tree_depth=10)

# ------------------------------------------------------------
# 3) Convert to ArviZ InferenceData (built from the MCMC object itself, so observed "y" is included; posterior predictive draws are added separately)
# ------------------------------------------------------------
#def to_idata(model, mcmc, y):
    # posterior samples grouped by chain for rhat etc.
#    posterior = mcmc.get_samples(group_by_chain=True)

    # posterior predictive (returns 'y' plus state nodes if asked)
#    rng_ppc = split(PRNGKey(999), 1)[0]
#    predictive = Predictive(model, posterior_samples=mcmc.get_samples(), return_sites=["y", "mu", "delta"])
#    ppc = predictive(rng_ppc, y=None)  # generate from posterior

    # coords/dims
#    coords = {"time": np.arange(T)}
#    dims = {"y": ["time"], "mu": ["time"], "delta": ["time"]}

    # Build InferenceData (NOTE: do NOT pass observed_data/mcmc/sample_stats)
#    idata = az.from_numpyro(
#        posterior=posterior,
#        posterior_predictive=ppc,
#        coords=coords,
#        dims=dims,
#    )
#    return idata
def to_idata(model, mcmc, y):
    """
    Bundle an MCMC fit and its posterior predictive draws into ArviZ InferenceData.

    Args:
        model: the NumPyro model function that was fitted.
        mcmc: the finished numpyro MCMC object for that model.
        y: the observed series; only its length is used here (for the "time" coord).

    Returns:
        InferenceData from az.from_numpyro, given the MCMC object itself rather
        than a dict of samples, plus posterior predictive draws of "y", "mu" and
        "delta". Only the local linear trend model defines "delta".
    """
    # Draws keep their (chain, draw) layout, so Predictive must treat the
    # first two axes as batch dimensions.
    ppc = Predictive(
        model,
        posterior_samples=mcmc.get_samples(group_by_chain=True),
        return_sites=["y", "mu", "delta"],
        batch_ndims=2,
    )(PRNGKey(999), y=None)  # y=None -> "y" is sampled rather than observed

    time_index = np.arange(len(y))
    return az.from_numpyro(
        posterior=mcmc,
        posterior_predictive=ppc,
        coords={"time": time_index},
        dims={"y": ["time"], "mu": ["time"], "delta": ["time"]},
    )


idata_ll  = to_idata(model_local_level,        mcmc_ll,  y)
idata_st  = to_idata(model_smooth_trend,       mcmc_st,  y)
idata_llt = to_idata(model_local_linear_trend, mcmc_llt, y)

# ------------------------------------------------------------
# 4) Print summaries (use print()) to display results
# ------------------------------------------------------------
def _report_parameter_diagnostics():
    """
    Print the posterior summary of each model's scale parameters, then their R-hat.

    All three summaries are printed first, followed by all three R-hat tables.
    """
    # (summary header, R-hat header, InferenceData, parameter names)
    reports = [
        ("\n=== Local Level: parameter summary ===", "\nR-hat (Local Level):",
         idata_ll, ["s_w", "s_v"]),
        ("\n=== Smooth Trend: parameter summary ===", "\nR-hat (Smooth Trend):",
         idata_st, ["s_z", "s_v"]),
        ("\n=== Local Linear Trend: parameter summary ===", "\nR-hat (Local Linear Trend):",
         idata_llt, ["s_w", "s_z", "s_v"]),
    ]

    for summary_header, _rhat_header, idata, params in reports:
        print(summary_header)
        print(az.summary(idata, var_names=params, kind="stats", round_to=2))

    # Reference R-hat values, restricted to the same parameters via var_names.
    for _summary_header, rhat_header, idata, params in reports:
        print(rhat_header)
        print(az.rhat(idata, var_names=params).to_array().to_series())


_report_parameter_diagnostics()

# ------------------------------------------------------------
# 5) Visualize models using NumPyro built-in function (render_model)
#    If graphviz is not installed, skip gracefully.
# ------------------------------------------------------------
try:
    from numpyro import render_model
    import matplotlib.image as mpimg

    def show_model_graph(model, y, title, filename):
        """Render `model` with numpyro.render_model, save it as a PNG and display it."""
        dot = render_model(model, model_args=(y,), render_distributions=True)
        # Save and show: dot.render returns the path of the written PNG file.
        outpath = dot.render(filename=filename, format="png", cleanup=True)
        img = mpimg.imread(outpath)
        plt.figure(figsize=(7.5, 5))
        plt.imshow(img)
        plt.axis("off")
        plt.title(title)
        plt.tight_layout()
        plt.show()

    show_model_graph(model_local_level, y, "Local Level Model (Graph)", "local_level_graph")
    show_model_graph(model_smooth_trend, y, "Smooth Trend Model (Graph)", "smooth_trend_graph")
    show_model_graph(model_local_linear_trend, y, "Local Linear Trend Model (Graph)", "local_linear_trend_graph")
except Exception as e:
    # Graph rendering is optional, so keep going. Still print the underlying
    # error, so that a genuine bug is not silently reported as
    # "graphviz not available".
    print("\n[Info] Skipped model graph rendering (numpyro.render_model/graphviz not available).")
    print(f"[Info] Reason: {type(e).__name__}: {e}")

# ------------------------------------------------------------
# 6) Posterior distributions (ArviZ) — use hdi_prob (NOT credible_interval)
# ------------------------------------------------------------
# One figure per model: posterior of its scale parameters with 95% HDI.
for _idata, _params, _title in (
    (idata_ll, ["s_w", "s_v"], "Local Level: Posterior of Parameters"),
    (idata_st, ["s_z", "s_v"], "Smooth Trend: Posterior of Parameters"),
    (idata_llt, ["s_w", "s_z", "s_v"], "Local Linear Trend: Posterior of Parameters"),
):
    fig = az.plot_posterior(_idata, var_names=_params, hdi_prob=0.95)
    plt.suptitle(_title, y=1.02)
    plt.show()

# ------------------------------------------------------------
# 7) Posterior predictive checks (group must be 'posterior')
#    NOTE: to_idata passes the MCMC object (not a dict) to az.from_numpyro,
#    which should carry the observed "y" into the InferenceData, so the
#    predictive draws are compared against the data. Verify with idata_ll.groups().
# ------------------------------------------------------------
for _idata, _title in (
    (idata_ll, "Local Level: Posterior Predictive Check"),
    (idata_st, "Smooth Trend: Posterior Predictive Check"),
    (idata_llt, "Local Linear Trend: Posterior Predictive Check"),
):
    az.plot_ppc(_idata, group="posterior")
    plt.title(_title)
    plt.xlabel("Observed/Predicted")
    plt.ylabel("Density")
    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------
# 8) Plot smoothed states (mu) with HDI bands; and drift (delta) for LLT
# ------------------------------------------------------------
#def plot_state_hdi(idata, state_name, title, ylabel):
    # posterior samples -> xarray: (chain, draw, time)
#    post = idata.posterior
#    if state_name not in post:
#        return
#    state = post[state_name]  # dims: chain, draw, time
#    median = state.median(dim=("chain","draw")).values
#    hdi = az.hdi(state, hdi_prob=0.95)
#    lower = hdi.sel(hdi="lower").values
#    upper = hdi.sel(hdi="higher").values

#    plt.figure(figsize=(9, 3.6))
#    plt.fill_between(df["date"], lower, upper, alpha=0.25, label="95% HDI")
#    plt.plot(df["date"], median, lw=2, label="Posterior median")
#    if state_name == "mu":
#        plt.scatter(df["date"], y, s=10, alpha=0.6, label="Observed")
#    plt.title(title)
#    plt.xlabel("Date")
#    plt.ylabel(ylabel)
#    plt.legend()
#    plt.tight_layout()
#    plt.show()

#import matplotlib.dates as mdates

#def plot_state_hdi(idata, state_name, title, ylabel):
#    post = idata.posterior
#    if state_name not in post:
#        return
#    state = post[state_name]  # dims: chain, draw, time
#    median = state.median(dim=("chain","draw")).values
#    hdi = az.hdi(state, hdi_prob=0.95)
#    lower = hdi.sel(hdi="lower").values.astype(float)
#    upper = hdi.sel(hdi="higher").values.astype(float)

    # ★ 日付をMatplotlibの内部数値（float）に変換
#    x = mdates.date2num(df["date"].dt.to_pydatetime())

#    fig, ax = plt.subplots(figsize=(9, 3.6))
#    ax.fill_between(x, lower, upper, alpha=0.25, label="95% HDI")
#    ax.plot(x, median, lw=2, label="Posterior median")
#    if state_name == "mu":
#        ax.scatter(x, y, s=10, alpha=0.6, label="Observed")
#    ax.set_title(title)
#    ax.set_xlabel("Date")
#    ax.set_ylabel(ylabel)
#    ax.legend()

    # 軸を日付フォーマットに
#    ax.xaxis_date()
#    fig.autofmt_xdate()
#    plt.tight_layout()
#    plt.show()

import matplotlib.dates as mdates
import xarray as xr  # 念のため

def plot_state_hdi(idata, state_name, title, ylabel, hdi_prob=0.95, dates=None, observed=None):
    """
    Plot the posterior median and HDI band of a latent state over time.

    Args:
        idata: InferenceData whose posterior holds `state_name` with dims
            (chain, draw, time).
        state_name: posterior variable to plot (e.g. "mu" or "delta"). If it is
            not in the posterior, the function returns without plotting.
        title: axes title.
        ylabel: y-axis label.
        hdi_prob: probability mass of the HDI band (default 0.95). The legend
            label is derived from it, so the two cannot disagree.
        dates: x-axis dates (array-like of datetimes); defaults to the
            module-level df["date"].
        observed: values scattered on top when state_name == "mu"; defaults to
            the module-level y.
    """
    post = idata.posterior
    if state_name not in post:
        return
    if dates is None:
        dates = df["date"]
    if observed is None:
        observed = y

    state = post[state_name]  # dims: (chain, draw, time)
    median = state.median(dim=("chain", "draw")).to_numpy()

    # az.hdi returns a Dataset here, so pull the DataArray back out by name;
    # it carries an extra "hdi" dim with "lower"/"higher" entries.
    hdi_da = az.hdi(state, hdi_prob=hdi_prob)[state_name]
    lower = hdi_da.sel(hdi="lower").to_numpy()
    upper = hdi_da.sel(hdi="higher").to_numpy()

    # Convert dates to Matplotlib's float representation before fill_between.
    x = mdates.date2num(pd.DatetimeIndex(dates).to_pydatetime())

    fig, ax = plt.subplots(figsize=(9, 3.6))
    ax.fill_between(x, lower, upper, alpha=0.25, label=f"{hdi_prob:.0%} HDI")
    ax.plot(x, median, lw=2, label="Posterior median")
    if state_name == "mu":
        ax.scatter(x, observed, s=10, alpha=0.6, label="Observed")
    ax.set_title(title)
    ax.set_xlabel("Date")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.xaxis_date()
    fig.autofmt_xdate()
    plt.tight_layout()
    plt.show()



# Smoothed level (mu) for every model, then the drift (delta) of the local linear trend model.
for _idata, _state, _title, _ylabel in (
    (idata_ll, "mu", "Local Level: Smoothed Level (mu)", "Level"),
    (idata_st, "mu", "Smooth Trend: Smoothed Level (mu)", "Level"),
    (idata_llt, "mu", "Local Linear Trend: Smoothed Level (mu)", "Level"),
    (idata_llt, "delta", "Local Linear Trend: Drift (delta)", "Drift"),
):
    plot_state_hdi(_idata, _state, _title, _ylabel)

# ------------------------------------------------------------
# 9) Optional: Trace plots (no banned args used)
# ------------------------------------------------------------
# One trace figure per model for its scale parameters.
for _idata, _params, _title in (
    (idata_ll, ["s_w", "s_v"], "Local Level: Trace"),
    (idata_st, ["s_z", "s_v"], "Smooth Trend: Trace"),
    (idata_llt, ["s_w", "s_z", "s_v"], "Local Linear Trend: Trace"),
):
    az.plot_trace(_idata, var_names=_params)
    plt.suptitle(_title, y=1.02)
    plt.show()
