# 5-4 Time-varying coefficient model (時変係数モデル)

# In [None]:
# -*- coding: utf-8 -*-
# Time-varying coefficient model in Python (NumPyro + ArviZ)
# - CSV loading: pandas
# - Visualization: matplotlib / ArviZ
# - Bayesian inference: NumPyro (NUTS)
# - Model diagram: numpyro.render_model
# - Posterior visualization: ArviZ (plot_posterior with hdi_prob, plot_ppc with group="posterior")
# - No use of az.from_numpyro(observed_data=..., mcmc=..., sample_stats=...)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# -----------------------------
# 0) Reproducibility / JAX setup
# -----------------------------
rng_key = jax.random.PRNGKey(1)

# -----------------------------
# 1) Data loading and plotting
# -----------------------------
csv_path = "5-4-1-sales-ts-2.csv"
# Parse the "date" column, then order rows chronologically with a fresh 0..n-1 index.
df = (
    pd.read_csv(csv_path)
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .sort_values("date")
    .reset_index(drop=True)
)

print("Head (3 rows):")
print(df.head(3))

# Two stacked time-series panels that share the date axis (English labels).
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
top_ax, bottom_ax = axes

top_ax.plot(df["date"], df["sales"])
top_ax.set_title("Sales over Time")
top_ax.set_ylabel("Sales")

bottom_ax.plot(df["date"], df["publicity"])
bottom_ax.set_title("Publicity over Time")
bottom_ax.set_ylabel("Publicity")
bottom_ax.set_xlabel("Date")

plt.tight_layout()
plt.show()

# Model inputs as JAX arrays: predictor (publicity) and response (sales).
T = len(df)
x_full = jnp.asarray(df["publicity"].values)
y_full = jnp.asarray(df["sales"].values)

# --------------------------------------------
# 2) Simple Bayesian linear regressions (NumPyro)
#    - Full period
#    - First 50 rows
#    - Last 50 rows
# --------------------------------------------
def model_lm(ex, y=None):
    """Simple linear regression: y ~ Normal(intercept + beta * ex, sigma).

    Priors: intercept, beta ~ Normal(0, 10); sigma ~ HalfNormal(1).
    With ``y=None`` the ``"y"`` site is left unobserved (e.g. for model
    rendering or predictive sampling).
    """
    # NOTE(review): these prior scales assume data of roughly unit-to-tens
    # magnitude — confirm against the sales/publicity ranges.
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 10.0))
    slope = numpyro.sample("beta", dist.Normal(0.0, 10.0))
    noise_scale = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(intercept + slope * ex, noise_scale), obs=y)

def fit_lm(ex, y, key, warmup=2000, samples=1000, thinning=6, chains=2):
    """Fit ``model_lm`` with NUTS and wrap the result as ArviZ InferenceData.

    Parameters
    ----------
    ex : array-like
        Predictor values passed to ``model_lm``.
    y : array-like
        Observed responses.
    key : jax.random.PRNGKey
        Random key for the MCMC run.
    warmup, samples, thinning, chains : int
        Forwarded to ``numpyro.infer.MCMC`` as ``num_warmup``,
        ``num_samples``, ``thinning`` and ``num_chains``.

    Returns
    -------
    tuple
        ``(mcmc, idata)``: the fitted ``MCMC`` object and its InferenceData.
    """
    kernel = NUTS(model_lm)
    mcmc = MCMC(
        kernel,
        num_warmup=warmup,
        num_samples=samples,
        thinning=thinning,
        num_chains=chains,
        progress_bar=False,
    )
    mcmc.run(key, ex=ex, y=y)
    # Pass the MCMC object itself; passing the dict returned by
    # mcmc.get_samples(group_by_chain=True) is the wrong input for from_numpyro.
    idata = az.from_numpyro(posterior=mcmc)
    return mcmc, idata

# Three independent PRNG keys, one per fit. JAX keys must not be reused once
# they have been split, so the parent key is split once and not used again.
key_full, key_head, key_tail = jax.random.split(rng_key, 3)
lm_vars = ["intercept", "beta", "sigma"]

# Full
mcmc_full, idata_full = fit_lm(x_full, y_full, key_full)
print("\nSimple Bayesian linear regression (full) — posterior summary:")
print(az.summary(idata_full, var_names=lm_vars).to_string())

# First 50
n_head = 50
x_head = jnp.asarray(df["publicity"].values[:n_head])
y_head = jnp.asarray(df["sales"].values[:n_head])
mcmc_head, idata_head = fit_lm(x_head, y_head, key_head)
print("\nSimple Bayesian linear regression (first 50) — posterior summary:")
print(az.summary(idata_head, var_names=lm_vars).to_string())

# Last 50
n_tail = 50
x_tail = jnp.asarray(df["publicity"].values[-n_tail:])
y_tail = jnp.asarray(df["sales"].values[-n_tail:])
mcmc_tail, idata_tail = fit_lm(x_tail, y_tail, key_tail)
print("\nSimple Bayesian linear regression (last 50) — posterior summary:")
print(az.summary(idata_tail, var_names=lm_vars).to_string())

# Render the simple LM graph with numpyro.render_model and save it as a PDF.
# This needs graphviz; on any failure only an informational note is printed.
try:
    dot_lm = numpyro.render_model(
        model_lm,
        model_args=(x_full,),
        model_kwargs={"y": None},
        render_distributions=True,
        render_params=True,
    )
    out_path = dot_lm.render("lm_model", format="pdf")
except Exception as e:
    print(f"\n[Info] Could not render simple LM model diagram (install graphviz to enable). Reason: {e}")
else:
    print(f"\nSaved NumPyro model diagram for simple LM to: {out_path}")

# -------------------------------------------------------
# 3) Time-varying coefficient model (state-space / DLM)
#    mu_t = mu_{t-1} + eps_mu_t,     eps_mu_t ~ N(0, s_w)
#    b_t  = b_{t-1}  + eps_b_t,      eps_b_t  ~ N(0, s_t)
#    y_t  ~ N(mu_t + b_t * x_t, s_v)
#    (the second argument of N(., .) is a standard deviation)
# -------------------------------------------------------
def _random_walk(init, innovations):
    """Return the random-walk path ``[init, init + cumsum(innovations)]``.

    ``init`` is an array of shape ``(...)`` and ``innovations`` has shape
    ``(..., n)``; the result has shape ``(..., n + 1)`` with time on the last axis.
    """
    start = init[..., None]
    return jnp.concatenate([start, start + jnp.cumsum(innovations, axis=-1)], axis=-1)


def model_tvc(ex, y=None):
    """Local-level model whose regression coefficient also follows a random walk.

    Parameters
    ----------
    ex : array
        Predictor series; ``ex.shape[0]`` is taken as the number of time
        points T (presumably 1-D of length T — confirm for other inputs).
    y : array, optional
        Observed response of length T; ``None`` leaves ``"y"`` unobserved.

    Deterministic sites ``"mu"``, ``"b"`` and ``"alpha"`` (= mu + b * ex) are
    recorded so they can be retrieved from posterior / predictive samples.
    """
    T = ex.shape[0]
    # NOTE(review): HalfNormal(1) scales and Normal(0, 10) initial states assume
    # data of modest magnitude — confirm against the sales/publicity ranges.
    s_w = numpyro.sample("s_w", dist.HalfNormal(1.0))
    s_t = numpyro.sample("s_t", dist.HalfNormal(1.0))
    s_v = numpyro.sample("s_v", dist.HalfNormal(1.0))

    mu0 = numpyro.sample("mu0", dist.Normal(0.0, 10.0))
    b0 = numpyro.sample("b0", dist.Normal(0.0, 10.0))

    # Non-centered parameterization: draw standard-normal innovations over
    # time, then multiply by the scales s_w / s_t.
    eps_mu_std = numpyro.sample("eps_mu_std", dist.Normal(0.0, 1.0).expand([T - 1]).to_event(1))
    eps_b_std = numpyro.sample("eps_b_std", dist.Normal(0.0, 1.0).expand([T - 1]).to_event(1))

    # Append a trailing axis to each scale before multiplying so it broadcasts
    # against the time axis: (..., T-1). The walks then have shape (..., T).
    mu = _random_walk(mu0, eps_mu_std * s_w[..., None])
    b = _random_walk(b0, eps_b_std * s_t[..., None])

    alpha = mu + b * ex  # state
    numpyro.deterministic("mu", mu)
    numpyro.deterministic("b", b)
    numpyro.deterministic("alpha", alpha)

    numpyro.sample("y", dist.Normal(alpha, s_v), obs=y)

# Fit the TVC model with NUTS, using the same sampler settings as fit_lm's defaults.
nuts_tvc = NUTS(model_tvc)
mcmc_tvc = MCMC(
    nuts_tvc,
    num_warmup=2000,
    num_samples=1000,
    thinning=6,
    num_chains=2,
    progress_bar=False,
)
mcmc_tvc.run(jax.random.PRNGKey(2), ex=x_full, y=y_full)

# Render the TVC graph with numpyro.render_model and save it as a PDF.
# This needs graphviz; on any failure only an informational note is printed.
try:
    dot_tvc = numpyro.render_model(
        model_tvc,
        model_args=(x_full,),
        model_kwargs={"y": None},
        render_distributions=True,
        render_params=True,
    )
    out_path_tvc = dot_tvc.render("tvc_model", format="pdf")
except Exception as e:
    print(f"\n[Info] Could not render TVC model diagram (install graphviz to enable). Reason: {e}")
else:
    print(f"\nSaved NumPyro model diagram for TVC to: {out_path_tvc}")

# Posterior draws grouped by chain, so every site has leading shape (chain, draw).
posterior_tvc = mcmc_tvc.get_samples(group_by_chain=True)

# Re-run the model on each posterior draw to collect the deterministic states
# (mu, b, alpha) and posterior-predictive y.
pred_sites = ["mu", "b", "alpha", "y"]
pred_gen = Predictive(
    model_tvc,
    posterior_tvc,
    return_sites=pred_sites,
    batch_ndims=2,  # the samples carry two batch axes: chain and draw
)
ppc = pred_gen(jax.random.PRNGKey(3), ex=x_full, y=None)  # y=None, so y is generated


# Build InferenceData WITHOUT explicit observed_data/mcmc/sample_stats arguments:
# the MCMC object itself is passed as `posterior` (a dict of samples is the
# wrong input here), and the Predictive draws are added as extra groups.
coords = {"time": np.arange(T)}
dims = {"mu": ["time"], "b": ["time"], "alpha": ["time"], "y": ["time"]}

idata_tvc = az.from_numpyro(
    posterior=mcmc_tvc,
    posterior_predictive={"y": ppc["y"]},
    predictions={"mu": ppc["mu"], "b": ppc["b"], "alpha": ppc["alpha"]},
    coords=coords,
    dims=dims,
    # `dims` is not applied to the predictions group; pred_dims gives its
    # time axis the same "time" coordinate instead of default dim names.
    pred_dims={"mu": ["time"], "b": ["time"], "alpha": ["time"]},
)

# -------------------------------------------
# 4) Print results analogous to the R/Stan code
# -------------------------------------------
scale_vars = ["s_w", "s_t", "s_v"]

print("\nTime-varying model — posterior summary for scale parameters:")
print(az.summary(idata_tvc, var_names=scale_vars).to_string())

# b[100] in the 1-based R/Stan notation is the 100th time point, index 99 here.
if T >= 100:
    b_draws = np.asarray(idata_tvc.predictions["b"])  # (chain, draw, time)
    b_100 = b_draws[:, :, 99].reshape(-1)  # pool chains and draws
    b_100_mean = float(np.mean(b_100))
    b_100_hdi = az.hdi(b_100, hdi_prob=0.95)
    print("\nPosterior of b[100] (mean and 95% HDI):")
    print(f"mean = {b_100_mean:.4f}, hdi_low = {float(b_100_hdi[0]):.4f}, hdi_high = {float(b_100_hdi[1]):.4f}")
else:
    print(f"\nDataset has only {T} rows; skipping b[100] summary.")

# R-hat (convergence) for key parameters
print("\nR-hat for key parameters (s_w, s_t, s_v):")
print(az.rhat(idata_tvc, var_names=scale_vars))

# -------------------------------------------
# 5) Diagnostic & posterior plots (ArviZ)
#    - We obey all given plotting constraints.
# -------------------------------------------
def _finish_figure():
    """Tighten the layout of the current matplotlib figure and display it."""
    plt.tight_layout()
    plt.show()


# Trace plots for the three scale parameters.
az.plot_trace(idata_tvc, var_names=["s_w", "s_t", "s_v"])
_finish_figure()

# Marginal posteriors annotated with a 95% HDI.
az.plot_posterior(idata_tvc, var_names=["s_w", "s_t", "s_v"], hdi_prob=0.95)
_finish_figure()

# Posterior predictive check (group must be "posterior").
az.plot_ppc(idata_tvc, group="posterior")
_finish_figure()

# -------------------------------------------
# 6) Plot estimated states over time (mu, alpha, b)
#    - Use posterior means + 95% HDIs computed from predictions
# -------------------------------------------
def mean_and_hdi(x_chain_draw_time, hdi=0.95):
    """Return the pooled posterior mean and HDI band of a time-indexed quantity.

    Parameters
    ----------
    x_chain_draw_time : array-like, shape (chain, draw, time)
        Posterior draws; all leading axes are pooled into one sample axis.
    hdi : float
        Probability mass of the highest-density interval.

    Returns
    -------
    mean : ndarray, shape (time,)
        Posterior mean at each time point.
    hdi_band : ndarray, shape (time, 2)
        Lower and upper HDI bounds at each time point.
    """
    arr = np.asarray(x_chain_draw_time)
    n_time = arr.shape[-1]
    # Pool every leading axis into a single "chain": ArviZ treats arrays with
    # ndim >= 3 as (chain, draw, *shape). A 2-D (samples, time) array would
    # trigger ArviZ's FutureWarning that 2-D input will later be read as
    # (chain, draw), which would collapse the per-time band.
    pooled = arr.reshape(1, -1, n_time)
    mean = pooled.mean(axis=(0, 1))
    hdi_band = az.hdi(pooled, hdi_prob=hdi)  # shape: (time, 2)
    return mean, hdi_band

# Posterior draws of each state from the predictions group, shape (chain, draw, time).
mu_draws, alpha_draws, b_draws = (
    np.asarray(idata_tvc.predictions[site]) for site in ("mu", "alpha", "b")
)

# Pooled posterior means and 95% HDI bands, in the same mu -> alpha -> b order.
(mu_mean, mu_hdi), (alpha_mean, alpha_hdi), (b_mean, b_hdi) = (
    mean_and_hdi(draws, 0.95) for draws in (mu_draws, alpha_draws, b_draws)
)

time_axis = df["date"].values


def plot_state(title, ylabel, mean, hdi, x=time_axis):
    """Plot a posterior-mean line with a shaded HDI band against ``x`` and show it.

    ``hdi`` is indexed as ``hdi[:, 0]`` (lower) and ``hdi[:, 1]`` (upper).
    """
    _, axis = plt.subplots(figsize=(10, 3.2))
    lower, upper = hdi[:, 0], hdi[:, 1]
    axis.plot(x, mean, label="Posterior mean")
    axis.fill_between(x, lower, upper, alpha=0.3, label="95% HDI")
    axis.set(title=title, ylabel=ylabel, xlabel="Date")
    axis.legend()
    plt.tight_layout()
    plt.show()


for state_title, state_ylabel, state_mean, state_hdi in (
    ("Estimated state: alpha (mu + b * publicity)", "Sales (state)", alpha_mean, alpha_hdi),
    ("Estimated baseline: mu", "Sales (baseline)", mu_mean, mu_hdi),
    ("Evolution of publicity effect: b", "Coefficient", b_mean, b_hdi),
):
    plot_state(state_title, state_ylabel, state_mean, state_hdi)

print("\nDone.")
