# 4-3-ランダム係数モデル

# In [None]:
# -*- coding: utf-8 -*-
# Random-coefficient Poisson models in NumPyro (Python rewrite of the R/brms code)

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive

import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

# -------------------------------------------------------------------
# 0) Setup
# -------------------------------------------------------------------
numpyro.set_platform("cpu")  # CPU is fine; switch to "gpu" if one is available
# Expose 4 host devices so the num_chains=4 MCMC runs below execute in
# parallel. This must run before JAX initialises its backend (i.e. before the
# first PRNGKey / array is created). Without it NumPyro falls back to running
# the chains sequentially, and the fallback warning is hidden by the
# filterwarnings("ignore") call at the top of the file.
numpyro.set_host_device_count(4)
rng_master = random.PRNGKey(1)
sns.set(style="whitegrid")

# -------------------------------------------------------------------
# 1) Load data (pandas) & quick summary
# -------------------------------------------------------------------
df = pd.read_csv("4-3-1-fish-num-4.csv")
df["human"] = df["human"].astype("category")

print("Head (first 3 rows):")
print(df.head(3))

print("\nSummary (pandas describe):")
print(df.describe(include="all"))

# Basic arrays
y = df["fish_num"].astype(int).to_numpy()      # observed counts
temp = df["temperature"].to_numpy()            # covariate
humans = df["human"].cat.categories.tolist()   # group labels, in category order
K = len(humans)
human_idx = df["human"].cat.codes.to_numpy()   # 0..K-1, same order as `humans`
# A missing label gets code -1. Negative JAX indexing would then silently
# assign that row to the LAST group, so fail loudly instead.
if (human_idx < 0).any():
    raise ValueError("column 'human' contains missing values (category code -1)")

# -------------------------------------------------------------------
# 2) Fixed-effects Poisson GLM with interaction: fish_num ~ temperature * human
#    (treatment coding: base category = first level)
# -------------------------------------------------------------------
# Design matrices for intercept and slope adjustments (drop_first=True)
#   eta = alpha0 + beta0*temp
#         + d_int @ gamma  +  (temp[:,None] * d_int) @ delta
# dtype=float: pandas >= 2.0 returns bool dummy columns by default; an explicit
# numeric design matrix keeps the matrix products unambiguous.
d_int_df = pd.get_dummies(df["human"], drop_first=True, dtype=float)
d_int = d_int_df.to_numpy()  # N x (K-1); columns follow `humans[1:]`
d_slope = d_int  # same shape: slope shifts use the same dummies

def model_fixed(temp, d_int, d_slope, y=None):
    """Fixed-effects Poisson GLM: ``fish_num ~ temperature * human``.

    Treatment coding, with the group whose dummy row is all zeros as base:

        eta = alpha0 + beta0 * temp + d_int @ gamma + (d_slope * temp) @ delta
        y   ~ Poisson(exp(eta))

    Args:
        temp: covariate, one value per observation (length N).
        d_int: (N, K-1) dummy matrix for the per-group intercept shifts ``gamma``.
        d_slope: (N, K-1) dummy matrix for the per-group slope shifts ``delta``.
        y: observed counts (length N), or None to sample predictions.
    """
    temp = jnp.asarray(temp)
    # Normalise dummies to float (pandas >= 2.0 get_dummies yields bool columns).
    d_int = jnp.asarray(d_int, dtype=float)
    d_slope = jnp.asarray(d_slope, dtype=float)
    n_obs, n_shift = d_int.shape

    alpha0 = numpyro.sample("alpha0", dist.Normal(0.0, 10.0))
    beta0 = numpyro.sample("beta0", dist.Normal(0.0, 5.0))
    gamma = numpyro.sample("gamma", dist.Normal(0.0, 5.0).expand([n_shift]).to_event(1))
    delta = numpyro.sample("delta", dist.Normal(0.0, 5.0).expand([n_shift]).to_event(1))

    eta = (
        alpha0
        + beta0 * temp
        + d_int @ gamma
        + (d_slope * temp.reshape(-1, 1)) @ delta
    )

    # Observations are conditionally independent given the parameters.
    with numpyro.plate("obs", n_obs):
        numpyro.sample("y", dist.Poisson(jnp.exp(eta)), obs=y)

print("\nChosen priors for FIXED model:")
print(" alpha0 ~ Normal(0,10), beta0 ~ Normal(0,5)")
print(" gamma_k ~ Normal(0,5), delta_k ~ Normal(0,5)")

# Covariates shared by the MCMC run and the posterior-predictive draw.
fixed_inputs = dict(
    temp=jnp.array(temp),
    d_int=jnp.array(d_int),
    d_slope=jnp.array(d_slope),
)

nuts_fixed = NUTS(model_fixed)
mcmc_fixed = MCMC(nuts_fixed, num_warmup=1500, num_samples=2000, num_chains=4, progress_bar=True)
rng_master, subkey = random.split(rng_master)
mcmc_fixed.run(subkey, **fixed_inputs, y=jnp.array(y))

# Posterior summaries
idata_fixed = az.from_numpyro(posterior=mcmc_fixed)
print("\nMCMC summary (FIXED model):")
print(az.summary(idata_fixed, var_names=["alpha0", "beta0", "gamma", "delta"], kind="stats"))

# R-hat, one row per scalar parameter. az.rhat(...).to_dataframe() would
# broadcast gamma and delta (which have different dims) into a
# Cartesian-product table full of repeated values, so take the r_hat column
# of the diagnostics summary instead.
print("\nR-hat (FIXED model):")
print(az.summary(idata_fixed, kind="diagnostics")[["r_hat"]])

# Posterior predictive draws (y=None lets Predictive sample the "y" site)
rng_master, subkey = random.split(rng_master)
ppc_fixed = Predictive(model_fixed, posterior_samples=mcmc_fixed.get_samples())(
    subkey, **fixed_inputs, y=None
)
idata_fixed = az.from_numpyro(posterior=mcmc_fixed, posterior_predictive=ppc_fixed)

# Diagnostics/visualizations (ArviZ)
az.plot_trace(idata_fixed, var_names=["alpha0", "beta0"])
plt.suptitle("Trace plots (Fixed-effects model)", y=1.02)
plt.show()

az.plot_posterior(idata_fixed, var_names=["alpha0", "beta0"], hdi_prob=0.95)
plt.suptitle("Posterior (Fixed-effects model)", y=1.02)
plt.show()

az.plot_ppc(idata_fixed, group="posterior")
plt.suptitle("Posterior Predictive Check (Fixed-effects model)", y=1.02)
plt.show()

# Marginal effects: regression curves per human (posterior mean of exp(eta))
samples_fixed = mcmc_fixed.get_samples(group_by_chain=False)
alpha0 = np.asarray(samples_fixed["alpha0"])  # (S,)
beta0 = np.asarray(samples_fixed["beta0"])    # (S,)
gamma = np.asarray(samples_fixed["gamma"])    # (S, K-1)
delta = np.asarray(samples_fixed["delta"])    # (S, K-1)

t_grid = np.linspace(df["temperature"].min(), df["temperature"].max(), 50)
# One explicit colour per human, so each curve matches its scatter points.
palette = dict(zip(humans, sns.color_palette(n_colors=K)))

fig, ax = plt.subplots(figsize=(9, 6))
for k, h in enumerate(humans):
    # Treatment coding: the base level (k == 0) has no shift; level k > 0
    # uses column k-1 of gamma (intercept shift) and delta (slope shift).
    if k == 0:
        intercept, slope = alpha0, beta0
    else:
        intercept = alpha0 + gamma[:, k - 1]
        slope = beta0 + delta[:, k - 1]
    eta_samples = intercept[:, None] + slope[:, None] * t_grid[None, :]  # (S, 50)
    mu_mean = np.exp(eta_samples).mean(axis=0)
    ax.plot(t_grid, mu_mean, color=palette[h], label=f"Human {h}")

# scatter observed data
sns.scatterplot(
    x="temperature", y="fish_num", hue="human", data=df, ax=ax,
    alpha=0.6, legend=False, palette=palette,
)
ax.set_title("Marginal effects: Poisson GLM with interaction")
ax.set_xlabel("Temperature")
ax.set_ylabel("Fish count")
ax.legend(title="Human")
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# 3) Random-coefficient (varying intercept & slope) Poisson GLMM:
#    fish_num ~ temperature + (temperature || human)  (no correlation)
# -------------------------------------------------------------------
def model_random(temp, human_idx, n_humans, y=None):
    """Varying-intercept / varying-slope Poisson GLMM (non-centred form).

    Each group ``h`` gets its own intercept and temperature slope:

        alpha_h = alpha_bar + sigma_alpha * z_alpha_h
        beta_h  = beta_bar  + sigma_beta  * z_beta_h

    The z's are independent standard normals, so intercepts and slopes are
    uncorrelated. Observations follow
    ``y_i ~ Poisson(exp(alpha[g_i] + beta[g_i] * temp_i))``.

    Args:
        temp: covariate, one value per observation.
        human_idx: integer group index for each observation.
        n_humans: number of groups (size of the "human" plate).
        y: observed counts, or None to sample predictions.
    """
    # Population-level (hyper) parameters.
    mu_alpha = numpyro.sample("alpha_bar", dist.Normal(0.0, 10.0))
    mu_beta = numpyro.sample("beta_bar", dist.Normal(0.0, 5.0))
    tau_alpha = numpyro.sample("sigma_alpha", dist.HalfNormal(1.0))
    tau_beta = numpyro.sample("sigma_beta", dist.HalfNormal(1.0))

    # Standard-normal per-group offsets, scaled and shifted below.
    with numpyro.plate("human", n_humans):
        z_alpha = numpyro.sample("alpha_raw", dist.Normal(0.0, 1.0))
        z_beta = numpyro.sample("beta_raw", dist.Normal(0.0, 1.0))

    # Per-group coefficients, recorded in the trace for later plotting.
    group_alpha = numpyro.deterministic("alpha", mu_alpha + tau_alpha * z_alpha)
    group_beta = numpyro.deterministic("beta", mu_beta + tau_beta * z_beta)

    log_rate = group_alpha[human_idx] + group_beta[human_idx] * temp
    numpyro.sample("y", dist.Poisson(jnp.exp(log_rate)), obs=y)

print("\nChosen priors for RANDOM-COEFFICIENT model:")
print(" alpha_bar ~ Normal(0,10), beta_bar ~ Normal(0,5)")
print(" sigma_alpha ~ HalfNormal(1), sigma_beta ~ HalfNormal(1)")
print(" alpha_h = alpha_bar + sigma_alpha*z_a, beta_h = beta_bar + sigma_beta*z_b (independent)")

# Covariates shared by the MCMC run and the posterior-predictive draw.
rand_inputs = dict(
    temp=jnp.array(temp),
    human_idx=jnp.array(human_idx),
    n_humans=K,
)

nuts_rand = NUTS(model_random)
mcmc_rand = MCMC(nuts_rand, num_warmup=2000, num_samples=3000, num_chains=4, progress_bar=True)
rng_master, subkey = random.split(rng_master)
mcmc_rand.run(subkey, **rand_inputs, y=jnp.array(y))

# Posterior summaries
idata_rand = az.from_numpyro(posterior=mcmc_rand)
print("\nMCMC summary (RANDOM-COEFFICIENT model):")
print(az.summary(idata_rand, var_names=["alpha_bar", "beta_bar", "sigma_alpha", "sigma_beta"], kind="stats"))

# R-hat, one row per scalar parameter. az.rhat(...).to_dataframe() would
# broadcast the per-human variables (which have different dims) into a
# Cartesian-product table full of repeated values, so take the r_hat column
# of the diagnostics summary instead.
print("\nR-hat (RANDOM-COEFFICIENT model):")
print(az.summary(idata_rand, kind="diagnostics")[["r_hat"]])

# Posterior predictive draws (y=None lets Predictive sample the "y" site)
rng_master, subkey = random.split(rng_master)
ppc_rand = Predictive(model_random, posterior_samples=mcmc_rand.get_samples())(
    subkey, **rand_inputs, y=None
)
idata_rand = az.from_numpyro(posterior=mcmc_rand, posterior_predictive=ppc_rand)

# Diagnostics/visualizations (ArviZ)
az.plot_trace(idata_rand, var_names=["alpha_bar", "beta_bar", "sigma_alpha", "sigma_beta"])
plt.suptitle("Trace plots (Random-coefficient model)", y=1.02)
plt.show()

az.plot_posterior(idata_rand, var_names=["alpha_bar", "beta_bar", "sigma_alpha", "sigma_beta"], hdi_prob=0.95)
plt.suptitle("Posterior (Random-coefficient model)", y=1.02)
plt.show()

az.plot_ppc(idata_rand, group="posterior")
plt.suptitle("Posterior Predictive Check (Random-coefficient model)", y=1.02)
plt.show()

# Regression curves by human (posterior mean of exp(eta))
samples_rand = mcmc_rand.get_samples(group_by_chain=False)
alpha_h = np.asarray(samples_rand["alpha"])  # (S, K)
beta_h = np.asarray(samples_rand["beta"])    # (S, K)

t_grid = np.linspace(df["temperature"].min(), df["temperature"].max(), 50)
# One explicit colour per human, so each curve matches its scatter points.
palette = dict(zip(humans, sns.color_palette(n_colors=K)))

fig, ax = plt.subplots(figsize=(9, 6))
for k, h in enumerate(humans):
    eta_samples = alpha_h[:, k][:, None] + beta_h[:, k][:, None] * t_grid[None, :]  # (S, 50)
    mu_mean = np.exp(eta_samples).mean(axis=0)
    ax.plot(t_grid, mu_mean, color=palette[h], label=f"Human {h}")

sns.scatterplot(
    x="temperature", y="fish_num", hue="human", data=df, ax=ax,
    alpha=0.6, legend=False, palette=palette,
)
ax.set_title("Regression curves: Random-coefficient Poisson GLMM")
ax.set_xlabel("Temperature")
ax.set_ylabel("Fish count")
ax.legend(title="Human")
plt.tight_layout()
plt.show()
