# 4-2-ランダム切片モデル

In [None]:
# -*- coding: utf-8 -*-
# Random-intercept Poisson GLMM with NumPyro
# Converted from R (brms/Stan) to Python (NumPyro)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive

# ---------------------------------------------------------------------
# 0. NumPyro setup
# ---------------------------------------------------------------------
# Configure the backend before any sampling is run.
numpyro.set_platform("cpu")  # CPU is sufficient (change to suit your environment)
numpyro.set_host_device_count(1)  # fixed at 1 to keep the script portable

# ---------------------------------------------------------------------
# 1. Load data
# ---------------------------------------------------------------------
# The CSV file name matches the original R code (adjust the path if needed).
df = pd.read_csv("4-2-1-fish-num-3.csv")

print("Head (first 3 rows):")
print(df.head(3).to_string())

print("\nData summary (pandas describe):")
# include='all' also summarizes the categorical columns.
print(df.describe(include='all').transpose().to_string())

# Expected columns: fish_num, weather, temperature, human.
# Cast explicitly: the grouping columns become `category`, and the numeric
# columns are coerced so unparseable entries become NaN instead of raising.
df["weather"] = df["weather"].astype("category")
df["human"] = df["human"].astype("category")
df["temperature"] = pd.to_numeric(df["temperature"], errors="coerce")
df["fish_num"] = pd.to_numeric(df["fish_num"], errors="coerce")

# Drop incomplete rows (adjust the preprocessing if needed).
df = df.dropna(subset=["fish_num", "weather", "temperature", "human"]).reset_index(drop=True)

# Dropping rows can leave category levels that no longer have any
# observation. Remove them so that the level lists and integer codes derived
# from the categories only describe levels actually present in the data
# (otherwise the model would get extra, data-free parameters).
df["weather"] = df["weather"].cat.remove_unused_categories()
df["human"] = df["human"].cat.remove_unused_categories()

# Integer codes used by the model to index per-level parameters.
weather_cat = df["weather"].cat
human_cat = df["human"].cat
df["weather_idx"] = weather_cat.codes
df["human_idx"] = human_cat.codes

# Level labels, kept for display purposes.
weather_levels = weather_cat.categories.tolist()
human_levels = human_cat.categories.tolist()

# Plain NumPy arrays handed to NumPyro/JAX.
y = df["fish_num"].to_numpy(dtype=np.int32)
temp = df["temperature"].to_numpy(dtype=np.float32)
w_idx = df["weather_idx"].to_numpy(dtype=np.int32)
h_idx = df["human_idx"].to_numpy(dtype=np.int32)

# Problem sizes: observations, weather levels, observers.
N = len(y)
W, H = len(weather_levels), len(human_levels)

print(f"\nN={N}, W={W} weather levels, H={H} observers")
print(f"Weather levels: {weather_levels}")
print(f"Observer levels: {human_levels}")

# ---------------------------------------------------------------------
# 2. NumPyro model: Poisson GLMM with random intercepts for 'human'
#    fish_num ~ weather + temperature + (1 | human)
# ---------------------------------------------------------------------
def model(fish_num=None, temperature=None, weather_idx=None, human_idx=None, W=1, H=1, N=None):
    """Poisson regression with a log link and one random intercept per observer.

    Linear predictor:
        log(mu) = Intercept + beta_temperature * temperature
                  + beta_weather[weather] + b_human[human]
    The first weather level is the baseline (its effect is fixed at 0), and
    b_human = z_human * sigma_human with z_human ~ Normal(0, 1).

    Args:
        fish_num: Observed counts, or None so that "obs" is sampled instead
            of conditioned on.
        temperature: Temperature covariate, one value per observation.
        weather_idx: Integer weather code per observation. It indexes the
            weather effect vector, so codes are assumed to lie in [0, W).
        human_idx: Integer observer code per observation. It indexes the
            length-H random intercepts, so codes are assumed to lie in [0, H).
        W: Number of weather levels; "beta_weather" is only sampled if W > 1.
        H: Number of observers (size of the "human" plate).
        N: Size of the "data" plate. When None it is taken from
            `temperature`, or from `fish_num` if `temperature` is None.
    """
    # Resolve the size of the observation plate.
    if N is not None:
        n_obs = N
    elif temperature is not None:
        n_obs = temperature.shape[0]
    else:
        n_obs = fish_num.shape[0]

    # Population-level intercept and temperature slope.
    alpha = numpyro.sample("Intercept", dist.Normal(0.0, 5.0))
    slope_temperature = numpyro.sample("beta_temperature", dist.Normal(0.0, 1.0))

    # Weather effects with baseline coding: entry 0 is the reference level,
    # the remaining W - 1 contrasts are sampled when they exist.
    weather_effects = jnp.array([0.0])
    if W > 1:
        contrasts = numpyro.sample(
            "beta_weather",
            dist.Normal(jnp.zeros(W - 1), jnp.ones(W - 1)),
        )
        weather_effects = jnp.concatenate([weather_effects, contrasts])

    # Observer-level random intercepts: standard-normal draws scaled by a
    # shared standard deviation.
    scale_human = numpyro.sample("sigma_human", dist.HalfNormal(1.0))
    with numpyro.plate("human", H):
        raw_human = numpyro.sample("z_human", dist.Normal(0.0, 1.0))
    human_offsets = numpyro.deterministic("b_human", raw_human * scale_human)

    # Log-scale linear predictor and the resulting Poisson rate.
    log_rate = (
        alpha
        + slope_temperature * temperature
        + weather_effects[weather_idx]
        + human_offsets[human_idx]
    )
    rate = jnp.exp(log_rate)

    # Recording mu makes prediction curves and PPC aggregation easier.
    numpyro.deterministic("mu", rate)

    # Likelihood.
    with numpyro.plate("data", n_obs):
        numpyro.sample("obs", dist.Poisson(rate), obs=fish_num)

# ---------------------------------------------------------------------
# 3. Inference (NUTS)
# ---------------------------------------------------------------------
rng_key = jax.random.PRNGKey(1)
kernel = NUTS(model)
# Chains are run sequentially for compatibility.
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=2, chain_method="sequential", progress_bar=True)
mcmc.run(rng_key, fish_num=y, temperature=temp, weather_idx=w_idx, human_idx=h_idx, W=W, H=H)

# Results are reported with print().
print("\nMCMC summary (core parameters):")
idata_core = az.from_numpyro(mcmc)  # observed_data is not passed (per the stated requirement)

# The model only creates the "beta_weather" site when there is more than one
# weather level, so request it only in that case; asking for a variable that
# is absent from the posterior would make the summary fail.
core_var_names = ["Intercept", "beta_temperature", "sigma_human"]
if W > 1:
    core_var_names.insert(2, "beta_weather")

print(
    az.summary(
        idata_core,
        var_names=core_var_names,
        kind="stats",
    ).to_string()
)

# Summarize the random effects (one intercept offset per observer).
# idata_core.posterior["b_human"] has shape (chain, draw, human).
b = idata_core.posterior["b_human"].values  # numpy array: (C, S, H)
b_flat = b.reshape(-1, b.shape[-1])  # (C*S, H): chains pooled for mean/sd

# Compute the HDI once, from the 3-D (chain, draw, human) array. A 2-D array
# is ambiguous for az.hdi (it warns that its interpretation of 2-D input will
# change), whereas the 3-D layout is read as chain/draw/extra dimension.
b_hdi = az.hdi(b, hdi_prob=0.94)  # (H, 2): lower and upper bound per observer

ranef_df = pd.DataFrame({
    "Observer": human_levels,
    "mean": b_flat.mean(axis=0),
    "sd": b_flat.std(axis=0),
    "hdi_3%": b_hdi[:, 0],
    "hdi_97%": b_hdi[:, 1],
})
print("\nRandom intercepts (per observer):")
print(ranef_df.to_string(index=False))

# ---------------------------------------------------------------------
# 4. Visualize the model structure (NumPyro built-in)
#    NOTE: skipped when rendering fails (e.g. graphviz is not available)
# ---------------------------------------------------------------------
try:
    from numpyro import render_model

    # The observed data is passed positionally so the graph reflects the
    # structure of the fitted model.
    render_inputs = (y, temp, w_idx, h_idx, W, H)
    g = render_model(
        model,
        model_args=render_inputs,
        render_distributions=True,
        render_params=True,
    )
    g.render("random_intercept_poisson_glmm", format="png", cleanup=True)
    print('\nSaved model graph to "random_intercept_poisson_glmm.png"')
except Exception as err:
    # Best effort only: report why rendering was skipped and carry on.
    print(f"\nModel rendering skipped: {err}")

# ---------------------------------------------------------------------
# 5. Posterior predictive & InferenceData for ArviZ
# ---------------------------------------------------------------------
# Posterior predictive draws on the training inputs; both the simulated
# counts ("obs") and the Poisson mean ("mu") are returned.
pp_rng = jax.random.PRNGKey(2)
posterior_draws = mcmc.get_samples()
predictive = Predictive(model, posterior_samples=posterior_draws, return_sites=["obs", "mu"])
pp_dict = predictive(
    pp_rng,
    fish_num=None,
    temperature=temp,
    weather_idx=w_idx,
    human_idx=h_idx,
    W=W,
    H=H,
)

# Bundle everything into an InferenceData (observed_data is not passed).
pp_arrays = {site: np.array(pp_dict[site]) for site in ("obs", "mu")}
idata = az.from_numpyro(mcmc, posterior_predictive=pp_arrays)

# ---------------------------------------------------------------------
# 6. Diagnostics & posterior visualization (ArviZ)  (labels are in English)
# ---------------------------------------------------------------------
# Trace
# "beta_weather" only exists in the posterior when the model has more than
# one weather level, so include it conditionally (same guard as the weather
# posterior plot below).
trace_var_names = ["Intercept", "beta_temperature", "sigma_human", "b_human"]
if W > 1:
    trace_var_names.insert(2, "beta_weather")
az.plot_trace(idata, var_names=trace_var_names)
plt.suptitle("Trace Plots", y=1.02)
plt.show()

# Posterior (use hdi_prob, NOT credible_interval)
az.plot_posterior(
    idata,
    var_names=["Intercept", "beta_temperature", "sigma_human"],
    hdi_prob=0.94,
)
plt.suptitle("Posterior Distributions (Core Parameters)", y=1.02)
plt.show()

if W > 1:
    az.plot_posterior(idata, var_names=["beta_weather"], hdi_prob=0.94)
    plt.suptitle("Posterior Distributions (Weather Effects)", y=1.02)
    plt.show()

# Density (do NOT pass 'kind' arg)
az.plot_density(idata, var_names=["Intercept", "beta_temperature", "sigma_human"])
plt.suptitle("Posterior Density (Selected Parameters)", y=1.02)
plt.show()

# Forest (do NOT pass 'group' arg)
az.plot_forest(idata, var_names=["Intercept", "beta_temperature", "sigma_human"], combined=True)
plt.title("Forest Plot (Core Parameters)")
plt.show()

# Posterior predictive check (group="posterior" per requirement)
az.plot_ppc(idata, group="posterior")
plt.suptitle("Posterior Predictive Check", y=1.02)
plt.show()

# ---------------------------------------------------------------------
# 7. Random-intercept regression curves like brms::marginal_effects
#    - Draw curves over temperature per Weather x Observer
#    - Use posterior mean of mu and 94% HDI ribbons
# ---------------------------------------------------------------------
# Temperature grid spanning the observed range.
t_min, t_max = df["temperature"].min(), df["temperature"].max()
t_grid = np.linspace(t_min, t_max, 50)

# Build the full grid of combinations (Weather x Observer x Temperature).
grid_list = []
for w in range(W):
    for h in range(H):
        tmp = pd.DataFrame({
            "temperature": t_grid,
            "weather_idx": w,
            "human_idx": h,
            "Weather": weather_levels[w],
            "Observer": human_levels[h],
        })
        grid_list.append(tmp)
grid_df = pd.concat(grid_list, ignore_index=True)

# The MCMC samples also contain the deterministic sites "mu" and "b_human"
# recorded on the training data; the stored "mu" has one entry per training
# row, not per grid row. Drop them so both are always recomputed from the
# sampled parameters for the grid inputs, independent of how Predictive
# treats deterministic entries in `posterior_samples`.
grid_posterior = {
    name: draws
    for name, draws in mcmc.get_samples().items()
    if name not in ("mu", "b_human")
}

# A single Predictive call yields mu for every grid row.
pred = Predictive(model, posterior_samples=grid_posterior, return_sites=["mu"])(
    jax.random.PRNGKey(3),
    fish_num=None,
    temperature=jnp.array(grid_df["temperature"].to_numpy(dtype=np.float32)),
    weather_idx=jnp.array(grid_df["weather_idx"].to_numpy(dtype=np.int32)),
    human_idx=jnp.array(grid_df["human_idx"].to_numpy(dtype=np.int32)),
    W=W,
    H=H,
)

mu_samps = np.array(pred["mu"])  # shape: (samples, grid_size)
mu_mean = mu_samps.mean(axis=0)
# Add an explicit leading chain axis -> (1, samples, grid_size). az.hdi warns
# that its interpretation of 2-D arrays will change, so the 3-D
# chain/draw/extra-dimension layout is used to keep the result well defined.
mu_hdi = az.hdi(mu_samps[np.newaxis, ...], hdi_prob=0.94)  # shape: (grid_size, 2)

grid_df["Expected count"] = mu_mean
grid_df["HDI_low"] = mu_hdi[:, 0]
grid_df["HDI_high"] = mu_hdi[:, 1]

# One panel per Observer, one colour per Weather level.
sns.set(style="whitegrid")
facet_options = {
    "col": "Observer",
    "col_wrap": 5,
    "hue": "Weather",
    "sharey": False,
    "height": 3.0,
    "aspect": 1.2,
}
g = sns.FacetGrid(grid_df, **facet_options)


def _ribbon(x, ylow, yhigh, **kwargs):
    """Shade the band between `ylow` and `yhigh` on the current axes.

    Extra keyword arguments (whatever FacetGrid.map supplies) are forwarded
    to matplotlib unchanged.
    """
    plt.gca().fill_between(x, ylow, yhigh, alpha=0.2, **kwargs)


# HDI band first, then the posterior-mean curve.
g.map(_ribbon, "temperature", "HDI_low", "HDI_high")
g.map(plt.plot, "temperature", "Expected count")
g.add_legend(title="Weather")
g.set_axis_labels("Temperature (°C)", "Expected fish count")
g.set_titles(col_template="Observer: {col_name}")
# Leave head-room for the figure-level title.
plt.subplots_adjust(top=0.9)
plt.suptitle("Random-Intercept Regression Curves", y=1.03)
plt.show()

# For reference: scatter plot of the raw data (colour = Weather, panel = Observer).
obs_grid = sns.lmplot(
    data=df,
    x="temperature",
    y="fish_num",
    hue="weather",
    col="human",
    col_wrap=5,
    fit_reg=False,
    height=3.0,
    scatter_kws={"alpha": 0.6},
)
# Label through the returned FacetGrid so the whole grid is labelled;
# plt.xlabel / plt.ylabel would only affect the current (last) panel.
obs_grid.set_axis_labels("Temperature (°C)", "Fish count")
# Leave head-room for the figure-level title.
plt.subplots_adjust(top=0.9)
plt.suptitle("Observed Data by Observer (color = Weather)", y=1.02)
plt.show()

# ---------------------------------------------------------------------
# 8. Print final notes
# ---------------------------------------------------------------------
# Completion message, emitted once all preceding steps have run.
print("\nFinished: model fitted, summaries printed, and figures displayed.")
