# 4-1-階層ベイズモデルと一般化線形混合モデルの基本

In [None]:
# -*- coding: utf-8 -*-
# 4-1 階層ベイズモデルと一般化線形混合モデルの基本（Python/NumPyro版）
# - Data I/O: pandas
# - Visualization: matplotlib / seaborn / ArviZ
# - Bayesian inference: NumPyro (NUTS)
# - Model visualization: numpyro.render_model
# - Posterior visualization: ArviZ (with hdi_prob)
# - PPC: az.plot_ppc(..., group="posterior")
# - 禁止事項遵守:
#   - az.from_numpyro(observed_data=...) を使わない
#   - az.plot_density(kind=...) を使わない
#   - az.plot_forest(group=...) を使わない
#   - az.plot_posterior は credible_interval ではなく hdi_prob を使う
#   - 計算結果の表示は print() を使用する
#   - プロットのラベルは英語表記

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import arviz as az
import jax
import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax.random import PRNGKey

# Reproducibility: enable 64-bit precision in JAX and fix the key for the first MCMC run
numpyro.enable_x64()
rng_key_master = PRNGKey(1)

# ------------------------------------------------------------------------------
# 1) Load data (pandas) & light preprocessing
# ------------------------------------------------------------------------------
df = pd.read_csv("4-1-1-fish-num-2.csv")
# Overwrite "id" (as a categorical, like an R factor) and "weather" (lower-cased
# text) in place; assign keeps existing columns at their original positions.
df = df.assign(
    id=df["id"].astype("category"),
    weather=df["weather"].astype(str).str.lower(),
)
# Integer code per id category, used later to index the per-id random intercepts
df["id_code"] = df["id"].cat.codes
# 0/1 dummy: 1 when weather is exactly "sunny" (like R's "weathersunny")
df["sunny"] = df["weather"].eq("sunny").astype(int)

print("Head of data (first 3 rows):")
print(df.head(3))

N = df.shape[0]
id_idx = df["id_code"].to_numpy()
n_id = df["id_code"].nunique()

# Plain NumPy views of the columns the models consume
data_list_1 = dict(
    N=N,
    fish_num=df["fish_num"].to_numpy(),
    temp=df["temperature"].to_numpy(),
    sunny=df["sunny"].to_numpy(),
)
print("\nData dictionary for the GLMM (Stan-analogue) style:")
# Show array shapes instead of full contents; scalars are shown as-is
print({key: getattr(value, "shape", value) for key, value in data_list_1.items()})

# ------------------------------------------------------------------------------
# 2) Poisson GLM (no random effects)  ~ brm(fish_num ~ weather + temperature)
# ------------------------------------------------------------------------------

def model_glm_pois(sunny, temp, fish=None):
    """Poisson regression with a log link and no random effects.

    The log of the Poisson rate is ``Intercept + b_sunny * sunny +
    b_temp * temp``. Each of the three coefficients has a Normal(0, 10) prior.

    Args:
        sunny: Sunny-weather indicator, multiplied by ``b_sunny``.
        temp: Temperature covariate, multiplied by ``b_temp``.
        fish: Values passed as ``obs`` to the ``fish_num`` site. The default
            ``None`` leaves that site unobserved, which is how the script
            calls the model for rendering and posterior prediction.
    """
    coef_prior = dist.Normal(0.0, 10.0)
    # Sites are declared in the same order as before: Intercept, b_sunny, b_temp.
    intercept = numpyro.sample("Intercept", coef_prior)
    slope_sunny = numpyro.sample("b_sunny", coef_prior)
    slope_temp = numpyro.sample("b_temp", coef_prior)

    log_rate = intercept + slope_sunny * sunny + slope_temp * temp
    numpyro.sample("fish_num", dist.Poisson(jnp.exp(log_rate)), obs=fish)

# Fit with NUTS
kernel_glm = NUTS(model_glm_pois)
mcmc_glm = MCMC(kernel_glm, num_warmup=1000, num_samples=2000, num_chains=4)
# Covariates are converted to JAX arrays once and shared by the fit, the model
# rendering and the posterior predictive call below.
covariates_glm = dict(
    sunny=jnp.array(data_list_1["sunny"]),
    temp=jnp.array(data_list_1["temp"]),
)
mcmc_glm.run(
    rng_key_master,
    fish=jnp.array(data_list_1["fish_num"]),
    **covariates_glm,
)

# Posterior predictive & InferenceData (NO observed_data)
# fish=None leaves "fish_num" unobserved so Predictive simulates it.
pred_glm = Predictive(model_glm_pois, posterior_samples=mcmc_glm.get_samples())
ppc_glm = pred_glm(PRNGKey(123), fish=None, **covariates_glm)
# Build the InferenceData a single time; the summary and every plot reuse it
# instead of re-converting the MCMC object for each call.
idata_glm = az.from_numpyro(mcmc_glm, posterior_predictive=ppc_glm)

# Print summary (use print())
print("\n[Poisson GLM] Posterior summary:")
print(az.summary(idata_glm, var_names=["Intercept", "b_sunny", "b_temp"], hdi_prob=0.95).to_string())

# Model visualization with numpyro built-in (render_model)
# Best effort: any failure here is reported and skipped so it does not stop
# the analysis.
try:
    gv_glm = numpyro.render_model(
        model_glm_pois,
        model_args=(),
        model_kwargs=dict(fish=None, **covariates_glm),
        render_distributions=True,
        render_params=True
    )
    # Save graph if Graphviz is available
    gv_glm.render(filename="glm_pois_model", format="pdf", cleanup=True)
    print("\n[Poisson GLM] Model graph saved as 'glm_pois_model.pdf'")
except Exception as e:
    print("\n[Poisson GLM] render_model skipped:", e)

# Posterior plots (ArviZ, hdi_prob)
# NOTE: plot_posterior returns matplotlib axes; the name `fig` is kept only
# for backward compatibility with the original script.
fig = az.plot_posterior(
    idata_glm,
    var_names=["Intercept", "b_sunny", "b_temp"],
    hdi_prob=0.95
)
plt.suptitle("Posterior distributions: Poisson GLM", y=1.02)
plt.show()

# PPC (ArviZ, group="posterior")
az.plot_ppc(idata_glm, group="posterior", num_pp_samples=200)
plt.title("Posterior predictive check: Poisson GLM")
plt.xlabel("Observed fish count")
plt.ylabel("Density")
plt.show()

# "Marginal effects" like brms: predict across temperature by weather with 99% PI
# 100 evenly spaced temperatures spanning the observed range
temp_grid = np.linspace(
    start=df["temperature"].min(),
    stop=df["temperature"].max(),
    num=100,
)
weather_levels = [0, 1]  # 0=not sunny, 1=sunny

# Copy the posterior draws of each coefficient into plain NumPy arrays
samples_glm = mcmc_glm.get_samples()
Intercept_s, b_sunny_s, b_temp_s = (
    np.array(samples_glm[site]) for site in ("Intercept", "b_sunny", "b_temp")
)

def predict_pi(temp_vals, sunny_flag):
    """Return the 0.5%, 50% and 99.5% predictive quantiles along a temperature grid.

    Uses the module-level posterior draw arrays ``Intercept_s``, ``b_sunny_s``
    and ``b_temp_s``. For every (draw, temperature) pair a Poisson count is
    simulated with rate ``exp(Intercept + b_sunny * sunny_flag + b_temp * temp)``
    and quantiles are taken across draws.

    Args:
        temp_vals: 1-D NumPy array of temperatures to predict at.
        sunny_flag: Value multiplied by the ``b_sunny`` draws (0 or 1 in this
            script).

    Returns:
        Tuple ``(lower, median, upper)`` with one value per entry of
        ``temp_vals``.
    """
    # Rows index posterior draws, columns index temperatures.
    per_draw_offset = Intercept_s[:, np.newaxis] + b_sunny_s[:, np.newaxis] * sunny_flag
    log_rate = per_draw_offset + b_temp_s[:, np.newaxis] * temp_vals[np.newaxis, :]
    # Fixed seed: repeated calls with the same inputs give the same band.
    simulated = np.random.default_rng(42).poisson(np.exp(log_rate))
    bands = np.quantile(simulated, [0.005, 0.5, 0.995], axis=0)
    return bands[0], bands[1], bands[2]

# One axes holding predictive bands, median curves and the raw observations
ax = plt.figure(figsize=(7, 5)).gca()
for sunny_flag, label in zip(weather_levels, ["Cloudy/Other", "Sunny"]):
    lower, median, upper = predict_pi(temp_grid, sunny_flag)
    # Shaded band between the 0.5% and 99.5% predictive quantiles
    ax.fill_between(temp_grid, lower, upper, alpha=0.2, label=f"{label} (99% PI)")
    ax.plot(temp_grid, median, label=f"{label} median")

# Overlay the observed counts, coloured by the weather column
sns.scatterplot(data=df, x="temperature", y="fish_num", hue="weather", alpha=0.6, ax=ax)
ax.set_title("Predicted fish count by temperature and weather (Poisson GLM)")
ax.set_xlabel("Temperature")
ax.set_ylabel("Fish count")
ax.legend()
plt.tight_layout()
plt.show()

# ------------------------------------------------------------------------------
# 3) Stan GLMM analogue: observation-level random effect (Poisson-lognormal)
#    Equivalent to the provided Stan where r ~ Normal(0, sigma_r) per observation
# ------------------------------------------------------------------------------

def model_glmm_overdisp(sunny, temp, fish=None):
    """Poisson GLMM with one Normal random effect per observation.

    The log of the Poisson rate is ``Intercept + b_sunny * sunny +
    b_temp * temp + r``, where ``r`` is drawn from Normal(0, sigma_r) inside
    the ``"obs"`` plate. Coefficients have Normal(0, 10) priors and
    ``sigma_r`` has a HalfNormal(1) prior.

    Args:
        sunny: Sunny-weather indicator; its first dimension sets the size of
            the ``"obs"`` plate.
        temp: Temperature covariate.
        fish: Values passed as ``obs`` to the ``fish_num`` site. The default
            ``None`` leaves that site unobserved.
    """
    coef_prior = dist.Normal(0.0, 10.0)
    # Sites are declared in the same order as before: Intercept, b_temp,
    # b_sunny, sigma_r, then r and fish_num inside the plate.
    intercept = numpyro.sample("Intercept", coef_prior)
    slope_temp = numpyro.sample("b_temp", coef_prior)
    slope_sunny = numpyro.sample("b_sunny", coef_prior)
    noise_scale = numpyro.sample("sigma_r", dist.HalfNormal(1.0))

    # Fixed-effect part of the linear predictor; it does not involve the plate.
    fixed_part = intercept + slope_sunny * sunny + slope_temp * temp
    with numpyro.plate("obs", sunny.shape[0]):
        noise = numpyro.sample("r", dist.Normal(0.0, noise_scale))
        numpyro.sample("fish_num", dist.Poisson(jnp.exp(fixed_part + noise)), obs=fish)

kernel_over = NUTS(model_glmm_overdisp)
mcmc_over = MCMC(kernel_over, num_warmup=1000, num_samples=2000, num_chains=4)
# Covariates are converted to JAX arrays once and shared by the fit, the model
# rendering and the posterior predictive call below.
covariates_over = dict(
    sunny=jnp.array(data_list_1["sunny"]),
    temp=jnp.array(data_list_1["temp"]),
)
mcmc_over.run(
    PRNGKey(2),
    fish=jnp.array(data_list_1["fish_num"]),
    **covariates_over,
)

# Posterior predictive draws (fish=None leaves "fish_num" unobserved)
pred_over = Predictive(model_glmm_overdisp, posterior_samples=mcmc_over.get_samples())
ppc_over = pred_over(PRNGKey(222), fish=None, **covariates_over)
# Build the InferenceData a single time; the summaries and every plot reuse
# it instead of re-converting the MCMC object for each call.
idata_over = az.from_numpyro(mcmc_over, posterior_predictive=ppc_over)

print("\n[GLMM (overdispersion)] Posterior summary:")
print(az.summary(
    idata_over,
    var_names=["Intercept", "b_sunny", "b_temp", "sigma_r"],
    hdi_prob=0.95
).to_string())

# Check R-hat only (like mcmc_rhat in R)
# No var_names here, so every posterior variable (including each r) is listed.
rhat_over = az.summary(idata_over, hdi_prob=0.95)[["r_hat"]]
print("\n[GLMM (overdispersion)] R-hat diagnostics:")
print(rhat_over.to_string())

# Trace-like plots & forest (no 'group' arg)
az.plot_forest(idata_over, var_names=["Intercept", "b_sunny", "b_temp", "sigma_r"])
plt.title("Forest plot: GLMM (overdispersion)")
plt.show()

# Model visualization
# Best effort: any failure here is reported and skipped so it does not stop
# the analysis.
try:
    gv_over = numpyro.render_model(
        model_glmm_overdisp,
        model_kwargs=dict(fish=None, **covariates_over),
        render_distributions=True,
        render_params=True
    )
    gv_over.render(filename="glmm_overdisp_model", format="pdf", cleanup=True)
    print("[GLMM (overdispersion)] Model graph saved as 'glmm_overdisp_model.pdf'")
except Exception as e:
    print("[GLMM (overdispersion)] render_model skipped:", e)

# Posterior plots & PPC
az.plot_posterior(idata_over, var_names=["Intercept", "b_sunny", "b_temp", "sigma_r"], hdi_prob=0.95)
plt.suptitle("Posterior distributions: GLMM (overdispersion)", y=1.02)
plt.show()

az.plot_ppc(idata_over, group="posterior", num_pp_samples=200)
plt.title("Posterior predictive check: GLMM (overdispersion)")
plt.xlabel("Observed fish count")
plt.ylabel("Density")
plt.show()

# ------------------------------------------------------------------------------
# 4) GLMM with random intercept by ID: fish_num ~ weather + temp + (1 | id)
# ------------------------------------------------------------------------------

def model_glmm_by_id(sunny, temp, id_idx, n_id, fish=None):
    """Poisson GLMM with a Normal random intercept per id.

    The log of the Poisson rate is ``Intercept + b_sunny * sunny +
    b_temp * temp + r_id[id_idx]``. ``r_id`` holds ``n_id`` values drawn from
    Normal(0, sigma_id) inside the ``"id"`` plate. Coefficients have
    Normal(0, 10) priors and ``sigma_id`` has a HalfNormal(1) prior.

    Args:
        sunny: Sunny-weather indicator.
        temp: Temperature covariate.
        id_idx: Integer index into ``r_id`` for each observation.
        n_id: Size of the ``"id"`` plate, i.e. the number of random intercepts.
        fish: Values passed as ``obs`` to the ``fish_num`` site. The default
            ``None`` leaves that site unobserved.
    """
    coef_prior = dist.Normal(0.0, 10.0)
    # Sites are declared in the same order as before: Intercept, b_temp,
    # b_sunny, sigma_id, r_id, fish_num.
    intercept = numpyro.sample("Intercept", coef_prior)
    slope_temp = numpyro.sample("b_temp", coef_prior)
    slope_sunny = numpyro.sample("b_sunny", coef_prior)
    group_scale = numpyro.sample("sigma_id", dist.HalfNormal(1.0))

    with numpyro.plate("id", n_id):
        group_offset = numpyro.sample("r_id", dist.Normal(0.0, group_scale))

    # Look up each observation's group offset and add it to the fixed part.
    log_rate = intercept + slope_sunny * sunny + slope_temp * temp + group_offset[id_idx]
    numpyro.sample("fish_num", dist.Poisson(jnp.exp(log_rate)), obs=fish)

kernel_id = NUTS(model_glmm_by_id)
mcmc_id = MCMC(kernel_id, num_warmup=1000, num_samples=2000, num_chains=4)
# Model inputs are prepared once and shared by the fit, the model rendering
# and the posterior predictive call below.
inputs_id = dict(
    sunny=jnp.array(data_list_1["sunny"]),
    temp=jnp.array(data_list_1["temp"]),
    id_idx=jnp.array(id_idx),
    n_id=n_id,
)
mcmc_id.run(
    PRNGKey(3),
    fish=jnp.array(data_list_1["fish_num"]),
    **inputs_id,
)

# Posterior predictive draws (fish=None leaves "fish_num" unobserved)
pred_id = Predictive(model_glmm_by_id, posterior_samples=mcmc_id.get_samples())
ppc_id = pred_id(PRNGKey(333), fish=None, **inputs_id)
# Build the InferenceData a single time; the summary and every plot reuse it
# instead of re-converting the MCMC object for each call.
idata_id = az.from_numpyro(mcmc_id, posterior_predictive=ppc_id)

print("\n[GLMM (random intercept by id)] Posterior summary:")
print(az.summary(
    idata_id,
    var_names=["Intercept", "b_sunny", "b_temp", "sigma_id"],
    hdi_prob=0.95
).to_string())

# Model visualization
# Best effort: any failure here is reported and skipped so it does not stop
# the analysis.
try:
    gv_id = numpyro.render_model(
        model_glmm_by_id,
        model_kwargs=dict(fish=None, **inputs_id),
        render_distributions=True,
        render_params=True
    )
    gv_id.render(filename="glmm_by_id_model", format="pdf", cleanup=True)
    print("[GLMM (by id)] Model graph saved as 'glmm_by_id_model.pdf'")
except Exception as e:
    print("[GLMM (by id)] render_model skipped:", e)

# Posterior plots & PPC
az.plot_posterior(idata_id, var_names=["Intercept", "b_sunny", "b_temp", "sigma_id"], hdi_prob=0.95)
plt.suptitle("Posterior distributions: GLMM (random intercept by id)", y=1.02)
plt.show()

az.plot_ppc(idata_id, group="posterior", num_pp_samples=200)
plt.title("Posterior predictive check: GLMM (random intercept by id)")
plt.xlabel("Observed fish count")
plt.ylabel("Density")
plt.show()

print("\nAll modeling finished.")
