# 5-9-動的一般化線形モデル：ポアソン分布を仮定した例

In [None]:
# -*- coding: utf-8 -*-
# Dynamic GLM with Poisson observations (NumPyro version)
# Requirements:
#   pip install numpyro arviz jax jaxlib pandas matplotlib seaborn

import warnings
warnings.filterwarnings("ignore")

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
#from jax.config import config as jax_config
#jax_config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import arviz as az

In [None]:
# ------------------------------
# 1) Data loading & quick plot
# ------------------------------
# Read the CSV with pandas and convert the "date" column to datetimes
# (plain column access only; .loc is not used).
csv_path = "5-9-1-fish-num-ts.csv"
fish_ts = pd.read_csv(csv_path).assign(date=lambda df: pd.to_datetime(df["date"]))

# Results are shown through print().
print("Head of data:")
print(fish_ts.head(3))

# One panel per series, stacked vertically on a shared date axis.
# Each entry: (column to plot, y-axis label, panel title).
series_panels = (
    ("fish_num", "Catch count", "Time series of catch count"),
    ("temperature", "Temperature (°C)", "Time series of temperature"),
)
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
for ax, (column, y_label, panel_title) in zip(axes, series_panels):
    ax.plot(fish_ts["date"], fish_ts[column])
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
# Only the bottom panel carries the x label because the x axis is shared.
axes[1].set_xlabel("Date")
plt.tight_layout()
plt.show()

In [None]:
# ------------------------------
# 2) Prepare arrays
# ------------------------------
# Pull the model inputs out of the DataFrame as plain NumPy arrays:
#   T  - number of rows (time points)
#   ex - "temperature" column as float64 (covariate)
#   y  - "fish_num" column as int64 (observed counts)
T = len(fish_ts)
ex = fish_ts["temperature"].to_numpy(dtype=np.float64)
y = fish_ts["fish_num"].to_numpy(dtype=np.int64)

print(f"Number of time points T = {T}")

In [None]:
# -----------------------------------
# 3) NumPyro model (dynamic GLM)
# -----------------------------------
def dglm_poisson(y=None, ex=None, T=None):
    """NumPyro model: Poisson dynamic GLM with a smooth trend, a covariate and noise.

    The log-rate at time step ``t`` is ``mu[t] + b * ex[t] + r[t]`` where

    * ``mu`` evolves as ``mu[t] ~ Normal(2 * mu[t-1] - mu[t-2], s_z)`` from the
      third step on; the first two states ``mu_1`` and ``mu_2`` have
      ``Normal(0, 10)`` priors,
    * ``b`` is the coefficient on the covariate ``ex`` (prior ``Normal(0, 5)``),
    * ``r`` is an i.i.d. ``Normal(0, s_r)`` term, one per time step (plate ``"time"``).

    ``s_z`` and ``s_r`` both have ``HalfNormal(1)`` priors.

    Sample sites: ``s_z``, ``s_r``, ``b``, ``mu_1`` ... ``mu_T``, ``r``, ``y``.
    Deterministic sites:

    * ``lambda_exp``        = ``exp(mu + b * ex + r)`` (the Poisson rate),
    * ``lambda_smooth``     = ``exp(mu + b * ex)`` (``r`` left out),
    * ``lambda_smooth_fix`` = ``exp(mu + b * mean(ex))`` (``r`` left out, covariate
      replaced by its mean).

    Args:
        y: Observed counts, one per time step. When ``None`` the ``y`` site is
            sampled from the Poisson likelihood instead of being conditioned on.
        ex: Covariate values, one per time step (array-like).
        T: Number of time steps. When ``None`` it is taken from the leading
            length of ``ex``. Two initial states are always sampled, so the
            model is written for ``T >= 2``.
    """
    ex_arr = jnp.asarray(ex)
    # The signature advertises T=None, but int(None) raised TypeError; fall back
    # to the covariate length so the default is actually usable.
    n_steps = int(ex_arr.shape[0]) if T is None else int(T)

    # Priors
    s_z = numpyro.sample("s_z", dist.HalfNormal(1.0))   # drift sd
    s_r = numpyro.sample("s_r", dist.HalfNormal(1.0))   # random effect sd
    b = numpyro.sample("b", dist.Normal(0.0, 5.0))      # regression coef

    # Initial states (weakly-informative)
    mu_1 = numpyro.sample("mu_1", dist.Normal(0.0, 10.0))
    mu_2 = numpyro.sample("mu_2", dist.Normal(0.0, 10.0))
    mu_seq = [mu_1, mu_2]

    # i.i.d. random effects r_t
    with numpyro.plate("time", n_steps):
        r = numpyro.sample("r", dist.Normal(0.0, s_r))

    # State evolution: mu_t ~ Normal(2*mu_{t-1} - mu_{t-2}, s_z), for t>=3.
    # Sites are named with 1-based indices (mu_3, mu_4, ...), hence t + 1.
    for t in range(2, n_steps):
        mu_hat = 2.0 * mu_seq[t - 1] - mu_seq[t - 2]
        mu_seq.append(numpyro.sample(f"mu_{t+1}", dist.Normal(mu_hat, s_z)))

    mu = jnp.stack(mu_seq)  # one entry per sampled state

    # Linear predictor on the log scale and the resulting Poisson rate
    trend_and_covariate = mu + b * ex_arr
    lam = jnp.exp(trend_and_covariate + r)

    # Deterministic nodes (generated quantities in Stan)
    numpyro.deterministic("lambda_exp", lam)
    numpyro.deterministic("lambda_smooth", jnp.exp(trend_and_covariate))
    numpyro.deterministic("lambda_smooth_fix", jnp.exp(mu + b * jnp.mean(ex_arr)))

    # Observation: obs=None makes numpyro.sample draw y instead of conditioning,
    # so a single call covers both the fitting and the predictive case.
    obs = None if y is None else jnp.asarray(y)
    numpyro.sample("y", dist.Poisson(lam), obs=obs)

In [None]:
# -----------------------------------
# 4) (Optional) visualize model graph
#     * use NumPyro's built-in rendering if available
# -----------------------------------
# Best-effort step: any failure (old NumPyro without render_model, missing
# graphviz package, ...) is reported and the notebook carries on.
try:
    # render_model is a *function* exported by the numpyro package, not a
    # submodule. `from numpyro.render_model import render_model` therefore
    # always raised ModuleNotFoundError and the graph was never built.
    from numpyro import render_model
    dot = render_model(dglm_poisson, model_args=(), model_kwargs={"y": y, "ex": ex, "T": T})
    # In notebooks, just `dot` displays. Here, we only announce its creation.
    print("Rendered model graph (Graphviz Digraph object) created via numpyro.render_model.")
except Exception as e:
    print("Model rendering skipped (numpyro.render_model not available):", str(e))

In [None]:
# -----------------------------------
# 5) Run MCMC (matching Stan settings)
# -----------------------------------
rng_key = jax.random.PRNGKey(1)
nuts = NUTS(dglm_poisson, target_accept_prob=0.99, max_tree_depth=15)

# Stan reference run: iter=8000, warmup=2000, thin=6
#   -> (8000 - 2000) / 6 = 1000 retained draws per chain.
# In NumPyro, num_samples counts post-warmup iterations BEFORE thinning, and
# thinning then keeps every 6th one. num_samples=1000 with thinning=6 would
# retain only about 166 draws per chain, so 6000 post-warmup iterations are
# requested to end up with the same 1000 retained draws per chain as Stan.
mcmc = MCMC(nuts, num_warmup=2000, num_samples=6000, num_chains=4, thinning=6, progress_bar=True)
mcmc.run(rng_key, y=y, ex=ex, T=T)
mcmc.print_summary()  # prints the summary table to stdout

In [None]:
# -----------------------------------
# 6) Convert to ArviZ InferenceData
#     (per the notebook's constraints, only the `posterior` argument is passed;
#      the mcmc / observed_data / sample_stats arguments are not used)
# -----------------------------------
idata = az.from_numpyro(posterior=mcmc)
group_names = idata.groups()
print("InferenceData groups:", group_names)

# R-hat for reference: az.rhat returns a Dataset, which is indexed with a list
# of variable names and printed as a (smaller) Dataset.
rhat_ds = az.rhat(idata)
selected_rhat = rhat_ds[["b", "s_z", "s_r"]]
print("R-hat (selected):")
print(selected_rhat)

In [None]:
# -----------------------------------
# 7) Posterior predictive samples with Predictive
#     * batch_ndims=1: the posterior samples carry a single leading sample axis
# -----------------------------------
rng_key_ppc = jax.random.PRNGKey(123)

# Chains are merged (group_by_chain=False) so every site has one leading
# sample axis, which is what batch_ndims=1 below declares.
# NOTE(review): an earlier draft of this cell passed group_by_chain=True samples
# together with batch_ndims=1; it was replaced by this flattened form.
posterior_samples = mcmc.get_samples(group_by_chain=False)

# Sites to return: the simulated counts plus the three deterministic rate curves.
predictive = Predictive(
    dglm_poisson,
    posterior_samples=posterior_samples,
    return_sites=["y", "lambda_exp", "lambda_smooth", "lambda_smooth_fix"],
    batch_ndims=1,
)
# y=None so the model samples "y" rather than conditioning on the data.
ppc_samples = predictive(rng_key_ppc, y=None, ex=ex, T=T)

# Rebuild the InferenceData, this time with a posterior_predictive group
# (still without the mcmc / observed_data / sample_stats arguments).
idata = az.from_numpyro(posterior=mcmc, posterior_predictive=ppc_samples)
print("Added posterior_predictive to InferenceData.")

In [None]:
# -----------------------------------
# 8) Parameter posterior visualization (ArviZ)
#     * az.plot_posterior takes hdi_prob (credible_interval is not used)
# -----------------------------------
# The three scalar parameters shown in every plot of this cell; a fresh list is
# handed to each ArviZ call.
main_params = ("b", "s_z", "s_r")

# Posterior densities with the 95% HDI marked
az.plot_posterior(idata, var_names=list(main_params), hdi_prob=0.95)
plt.suptitle("Posterior distributions (95% HDI)", y=1.02)
plt.show()

# Trace plots (for reference)
az.plot_trace(idata, var_names=list(main_params))
plt.suptitle("Trace plots", y=1.02)
plt.show()

# Forest plot (the `group` argument is not used)
az.plot_forest(idata, var_names=list(main_params))
plt.title("Forest plot of parameters")
plt.show()

In [None]:
# -----------------------------------
# 9) Posterior predictive check
#     * group="posterior" is passed explicitly
# -----------------------------------
# data_pairs maps the observed variable name to the predictive variable name;
# both are "y" here.
az.plot_ppc(
    idata,
    group="posterior",
    data_pairs={"y": "y"},
)
plt.suptitle("Posterior Predictive Check", y=1.02)
plt.show()

In [None]:
# NOTE: this whole cell is commented out; an active version of the same HDI
# extraction and fill_between calls appears in section 10.
# Compute the HDI (az.hdi returns a Dataset)
#hdi_ds = az.hdi(idata.posterior_predictive, hdi_prob=0.94)

# Pick the DataArray by variable name, then select lower/higher along the
# "hdi" coordinate (this is the important part)
#low_exp  = hdi_ds["lambda_exp"].sel(hdi="lower").values
#high_exp = hdi_ds["lambda_exp"].sel(hdi="higher").values

#low_smooth  = hdi_ds["lambda_smooth"].sel(hdi="lower").values
#high_smooth = hdi_ds["lambda_smooth"].sel(hdi="higher").values

#low_fix  = hdi_ds["lambda_smooth_fix"].sel(hdi="lower").values
#high_fix = hdi_ds["lambda_smooth_fix"].sel(hdi="higher").values

# The fill_between calls below work because the array lengths match
#fig, axes = plt.subplots(3, 1, figsize=(11, 9), sharex=True)
#axes[0].fill_between(fish_ts["date"].values, low_exp,  high_exp,  alpha=0.3, label="94% HDI")
#axes[1].fill_between(fish_ts["date"].values, low_smooth,high_smooth,alpha=0.3, label="94% HDI")
#axes[2].fill_between(fish_ts["date"].values, low_fix,  high_fix,  alpha=0.3, label="94% HDI")

In [None]:
# -----------------------------------
# 10) State visualization (lambda_exp / smooth / smooth_fix)
#      * HDIs come from az.hdi; the returned Dataset is indexed by variable name
# -----------------------------------
# Raw draws of the three rate curves from the posterior_predictive group.
# Presumably laid out as (chain, draw, T); the reshape below only relies on the
# last axis having length T.
ppc_group = idata.posterior_predictive
lam_exp = ppc_group["lambda_exp"].values
lam_smooth = ppc_group["lambda_smooth"].values
lam_smooth_fix = ppc_group["lambda_smooth_fix"].values

# Median per time point, pooling every draw into one (n_draws, T) table first.
med_exp = np.median(lam_exp.reshape(-1, T), axis=0)
med_smooth = np.median(lam_smooth.reshape(-1, T), axis=0)
med_smooth_fix = np.median(lam_smooth_fix.reshape(-1, T), axis=0)

# 94% HDI for every variable in the group (az.hdi returns a Dataset). The
# bounds are selected by label along the "hdi" coordinate.
hdi_ds = az.hdi(idata.posterior_predictive, hdi_prob=0.94)
low_exp, high_exp = (
    hdi_ds["lambda_exp"].sel(hdi=bound).values for bound in ("lower", "higher")
)
low_smooth, high_smooth = (
    hdi_ds["lambda_smooth"].sel(hdi=bound).values for bound in ("lower", "higher")
)
low_fix, high_fix = (
    hdi_ds["lambda_smooth_fix"].sel(hdi=bound).values for bound in ("lower", "higher")
)

# One panel per curve, top to bottom:
#   (HDI lower, HDI upper, median, legend label for the median line, panel title)
state_panels = (
    (low_exp, high_exp, med_exp,
     "Median of state (all effects)",
     "State: lambda_exp (with random effect and covariate)"),
    (low_smooth, high_smooth, med_smooth,
     "Median of state (no random effect)",
     "State: lambda_smooth (random effect removed)"),
    (low_fix, high_fix, med_smooth_fix,
     "Median of state (covariate fixed)",
     "State: lambda_smooth_fix (random effect removed, temperature fixed)"),
)

dates = fish_ts["date"]
fig, axes = plt.subplots(3, 1, figsize=(11, 9), sharex=True)
for ax, (lower, upper, median, median_label, panel_title) in zip(axes, state_panels):
    # Shaded HDI band, median line on top, observed counts as points.
    ax.fill_between(dates.values, lower, upper, alpha=0.3, label="94% HDI")
    ax.plot(dates, median, label=median_label)
    ax.scatter(dates, y, s=15, alpha=0.6, label="Observed")
    ax.set_ylabel("Expected catch")
    ax.set_title(panel_title)
    ax.legend()
# The x axis is shared, so only the bottom panel is labelled.
axes[2].set_xlabel("Date")

plt.tight_layout()
plt.show()

In [None]:
# -----------------------------------
# 11) Print a concise parameter summary (with HDI)
# -----------------------------------
# Summary table restricted to the three scalar parameters, with 95% HDI bounds.
summary_df = az.summary(
    idata,
    var_names=["b", "s_z", "s_r"],
    hdi_prob=0.95,
)
print("Parameter summary (95% HDI):")
print(summary_df)