# 3-8-ポアソン回帰モデル

# In [None]:
# -*- coding: utf-8 -*-
# Poisson regression in Python with NumPyro (translation of the provided R + Stan code)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive
from numpyro.contrib.render import render_model

# ------------------------------------------------------------
# 0) Utilities
# ------------------------------------------------------------
SEED = 1
rng_key = jax.random.PRNGKey(SEED)  # root key for the main MCMC run

def to_jnp(x):
    """Return *x* (list, NumPy array, pandas Series, ...) as a JAX array.

    The value is routed through ``np.asarray`` first so that pandas objects and
    plain Python sequences are converted uniformly before handing off to JAX.
    """
    as_numpy = np.asarray(x)
    return jnp.array(as_numpy)

# ------------------------------------------------------------
# 1) Data loading & EDA (labels in English)
# ------------------------------------------------------------
# Columns used below: "fish_num" (count), "temperature" (numeric) and "weather"
# (label; "sunny" is coded 1, every other label 0).
csv_path = "3-8-1-fish-num-1.csv"  # same file name as in R code
if not os.path.exists(csv_path):
    # pd.read_csv would fail on the next line anyway; stop here with an actionable
    # message instead of printing a warning and then crashing with a vaguer error.
    raise FileNotFoundError(
        f"CSV not found at {csv_path}. Please place the file in the working directory."
    )

df = pd.read_csv(csv_path)
print("Head (first 3 rows):")
print(df.head(3))

print("\nSummary (describe):")
print(df.describe(include="all"))

# Scatter plot: fish count vs temperature by weather (English labels)
plt.figure(figsize=(6, 4))
ax_eda = sns.scatterplot(data=df, x="temperature", y="fish_num", hue="weather")
plt.title("Fish Count vs Temperature by Weather")
plt.xlabel("Temperature (°C)")
plt.ylabel("Number of Fish")
# Retitle seaborn's own legend rather than calling plt.legend(): newer seaborn builds
# legend handles that are not attached to the axes, so a fresh plt.legend() would find
# no labeled artists and replace the hue legend with an empty one.
eda_legend = ax_eda.get_legend()
if eda_legend is not None:
    eda_legend.set_title("Weather")
plt.tight_layout()
plt.show()

# Prepare arrays
temp_np = df["temperature"].to_numpy()
fish_np = df["fish_num"].astype(int).to_numpy()
# Normalize case and surrounding whitespace before dummy-coding the weather label.
weather_labels = df["weather"].astype(str).str.strip().str.lower()
unexpected_labels = sorted(set(weather_labels) - {"sunny", "cloudy"})
if unexpected_labels:
    # Such rows end up coded as "not sunny"; make that visible instead of silent.
    print(f"WARNING: unexpected weather labels {unexpected_labels}; they are coded as 0 (not sunny).")
sunny_np = (weather_labels == "sunny").astype(int).to_numpy()  # 1 if sunny else 0

N = len(df)

# Design matrix as in the Stan design-matrix example: [Intercept, temperature, sunny]
X_np = np.column_stack([np.ones(N), temp_np, sunny_np])
K = X_np.shape[1]

# ------------------------------------------------------------
# 2) Models (NumPyro)
# ------------------------------------------------------------
# (a) Explicit exp() transform (equivalent to Stan 3-8-1-glm-pois-1)
def poisson_model_exp(temp, sunny, fish_num=None):
    """Poisson GLM with an explicit log link.

    log(rate) = Intercept + b_temp * temp + b_sunny * sunny
    fish_num  ~ Poisson(rate)

    Args:
        temp: Temperature values (observed data, or a prediction grid).
        sunny: 0/1 weather indicator aligned with ``temp``.
        fish_num: Observed counts, or ``None`` to draw ``obs`` from the model.

    The linear predictor and the rate are also recorded as the deterministic
    sites ``eta`` and ``rate``.
    """
    # Normal(0, 10) prior, shared by all three coefficients (sampled in this order).
    coef_prior = dist.Normal(0.0, 10.0)
    Intercept = numpyro.sample("Intercept", coef_prior)
    b_temp = numpyro.sample("b_temp", coef_prior)
    b_sunny = numpyro.sample("b_sunny", coef_prior)

    linear_pred = Intercept + b_temp * temp + b_sunny * sunny
    lam = jnp.exp(linear_pred)
    numpyro.deterministic("eta", linear_pred)
    numpyro.deterministic("rate", lam)
    numpyro.sample("obs", dist.Poisson(lam), obs=fish_num)

# (b) Log-scale linear predictor version (conceptually similar to Stan's poisson_log).
#     dist.Poisson is built from a rate here, so the predictor is exponentiated inline.
def poisson_model_logstyle(temp, sunny, fish_num=None):
    """Poisson GLM expressed through its log-scale linear predictor ``eta``.

    Same priors and likelihood as ``poisson_model_exp``; only ``eta`` is recorded
    as a deterministic site (no separate ``rate`` site).

    Args:
        temp: Temperature values.
        sunny: 0/1 weather indicator aligned with ``temp``.
        fish_num: Observed counts, or ``None`` to draw ``obs`` from the model.
    """
    # Coefficient sites are sampled in the fixed order Intercept, b_temp, b_sunny.
    Intercept, b_temp, b_sunny = (
        numpyro.sample(site, dist.Normal(0.0, 10.0))
        for site in ("Intercept", "b_temp", "b_sunny")
    )
    eta = Intercept + b_temp * temp + b_sunny * sunny
    numpyro.deterministic("eta", eta)
    rate = jnp.exp(eta)
    numpyro.sample("obs", dist.Poisson(rate), obs=fish_num)

# (c) Design-matrix version (equivalent to Stan 3-8-3-glm-pois-design-matrix)
def poisson_model_design_mat(X, Y=None):
    """Poisson GLM written with a design matrix: ``log(rate) = X @ b``.

    Args:
        X: Design matrix with one row per observation and one column per
            coefficient. In this script it is ``X_np = [1, temperature, sunny]``,
            so ``b[0]`` plays the role of the intercept.
        Y: Observed counts (one per row of ``X``), or ``None`` to draw ``obs``.

    Sites:
        b   -- coefficient vector (length ``X.shape[1]``), independent Normal(0, 10).
        eta -- deterministic linear predictor ``X @ b``.
        obs -- Poisson likelihood with rate ``exp(eta)``.
    """
    n_coef = X.shape[1]  # local name; avoids shadowing the module-level K
    # One vector-valued site (like Stan's ``vector[K] b``): expand the scalar prior to
    # n_coef independent components and declare them a single event, so the site has
    # event shape (n_coef,) instead of an un-plated batch of n_coef scalars.
    b = numpyro.sample("b", dist.Normal(0.0, 10.0).expand([n_coef]).to_event(1))
    eta = jnp.dot(X, b)
    numpyro.deterministic("eta", eta)
    numpyro.sample("obs", dist.Poisson(jnp.exp(eta)), obs=Y)

# ------------------------------------------------------------
# 3) Fit the "brms equivalent" model: fish_num ~ weather + temperature (use model_exp)
# ------------------------------------------------------------
nuts_kernel = NUTS(poisson_model_exp)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=True)
mcmc.run(
    rng_key,
    temp=to_jnp(temp_np),
    sunny=to_jnp(sunny_np),
    fish_num=to_jnp(fish_np),
)

# Flattened draws (chains concatenated); reused by Predictive in later steps.
posterior_samples = mcmc.get_samples()
# ArviZ's NumPyro converter re-traces the model with the arguments of this run and
# stores observed sites (here `obs`) in the observed_data group; the PPC plot in step 6
# relies on that group.
idata = az.from_numpyro(mcmc)
summary_df = az.summary(idata, var_names=["Intercept", "b_temp", "b_sunny"], hdi_prob=0.95)
print("\nMCMC summary (95% HDI):")
print(summary_df)

# Helpful exponentiations like in R example (values quoted from the R output)
print("\nReference exponentiations (as in R):")
print("exp(-0.59) =", np.exp(-0.59))
print("exp(0.08)  =", np.exp(0.08))

# The same kind of quantity computed from *this* fit. With a log link, exp(coefficient)
# is the multiplicative change in the expected count per unit change of the predictor.
print("\nExponentiated posterior means (this fit):")
for coef_name in ("Intercept", "b_temp", "b_sunny"):
    coef_mean = float(np.mean(posterior_samples[coef_name]))
    print(f"exp({coef_name}) = exp({coef_mean:.2f}) = {np.exp(coef_mean):.4f}")

# ------------------------------------------------------------
# 4) Visualize the Bayesian model structure with NumPyro's built-in function
# ------------------------------------------------------------
try:
    # Best effort: a missing IPython or graphviz backend (or any other failure)
    # only produces the message below instead of stopping the script.
    from IPython.display import display  # safe if running in notebook
    graph_args = (to_jnp(temp_np), to_jnp(sunny_np), to_jnp(fish_np))
    g = render_model(
        poisson_model_exp,
        model_args=graph_args,
        render_distributions=True,
        render_params=True,
    )
    display(g)
except Exception as err:
    print(f"Model graph could not be rendered (install graphviz if needed): {err}")

# ------------------------------------------------------------
# 5) Posterior plots (ArviZ; use hdi_prob; do NOT use credible_interval)
# ------------------------------------------------------------
# Marginal posterior density of each coefficient, annotated with its 95% HDI
az.plot_posterior(
    idata,
    var_names=["Intercept", "b_temp", "b_sunny"],
    hdi_prob=0.95,
)
plt.gcf().suptitle("Posterior Distributions (95% HDI)", y=1.02)
plt.show()

# One combined interval per coefficient (no 'group' argument; use hdi_prob)
az.plot_forest(
    idata,
    var_names=["Intercept", "b_temp", "b_sunny"],
    combined=True,
    hdi_prob=0.95,
)
plt.gca().set_title("Forest Plot of Coefficients (95% HDI)")
plt.show()

# ------------------------------------------------------------
# 6) Posterior predictive checks on observed data
#    Use az.plot_ppc with group="posterior" (required)
# ------------------------------------------------------------
ppc = Predictive(poisson_model_exp, posterior_samples, return_sites=["obs", "eta", "rate"])(
    jax.random.PRNGKey(SEED + 10),
    temp=to_jnp(temp_np),
    sunny=to_jnp(sunny_np),
    fish_num=None,
)
# Attach the predictive draws together with the MCMC run. ArviZ can then reshape the
# flattened draws (num_chains * num_samples, ...) back into (chain, draw, ...), matching
# the posterior group. Converting `ppc` on its own would give ArviZ no chain information
# (from_numpyro defaults to num_chains=1, i.e. one 4000-draw chain).
# The result also keeps the posterior and the observed_data group (the observed `obs` site
# of the MCMC model) that plot_ppc compares against, so no concat step is needed.
idata = az.from_numpyro(mcmc, posterior_predictive=ppc)

az.plot_ppc(idata, group="posterior", num_pp_samples=200)
plt.title("Posterior Predictive Check (Observed Scale)")
plt.show()

# ------------------------------------------------------------
# 7) "Marginal effects" style curves (temperature x weather)
#    - 95% HDI band for expected rate (lambda)
#    - 99% predictive interval for counts
# ------------------------------------------------------------
temp_grid = np.linspace(temp_np.min(), temp_np.max(), 100)

def predict_over_grid(sunny_flag, rng_key):
    """Return mean rate, 95% HDI for rate, and 99% predictive interval for counts on a temp grid.

    Args:
        sunny_flag: Weather indicator (1 = sunny, 0 = cloudy), broadcast over ``temp_grid``.
        rng_key: JAX PRNG key used by ``Predictive`` to draw the predictive counts.

    Returns:
        Tuple ``(mean_rate, hdi95_rate, hdi99_y)``: posterior mean of the rate at each
        grid point (shape ``(grid,)``), the 95% HDI of the rate (shape ``(grid, 2)``),
        and the 99% HDI of the predictive counts (shape ``(grid, 2)``).
    """
    # Condition only on the latent coefficients. posterior_samples also carries the
    # deterministic "eta"/"rate" sites evaluated at the *observed* temperatures, which
    # must never stand in for grid values (whether Predictive drops them automatically
    # depends on the NumPyro version).
    coef_samples = {name: posterior_samples[name] for name in ("Intercept", "b_temp", "b_sunny")}
    predictive = Predictive(poisson_model_exp, coef_samples, return_sites=["rate", "obs"])
    out = predictive(
        rng_key,
        temp=to_jnp(temp_grid),
        sunny=to_jnp(np.full_like(temp_grid, sunny_flag)),
        fish_num=None,
    )
    # Shapes: [draws, grid]
    rate_draws = np.asarray(out["rate"])
    y_draws = np.asarray(out["obs"])
    mean_rate = rate_draws.mean(axis=0)
    # az.hdi reads a 2-D ndarray as (draw, shape) only with a FutureWarning, and that
    # interpretation is slated to become (chain, draw). Add an explicit single-chain
    # axis so the input is unambiguously (chain, draw, grid).
    hdi95_rate = az.hdi(rate_draws[np.newaxis], hdi_prob=0.95)  # (grid, 2)
    hdi99_y = az.hdi(y_draws[np.newaxis], hdi_prob=0.99)        # (grid, 2)
    return mean_rate, hdi95_rate, hdi99_y

mean_rate_sunny, hdi95_rate_sunny, hdi99_y_sunny = predict_over_grid(1, jax.random.PRNGKey(SEED + 21))
mean_rate_cloud, hdi95_rate_cloud, hdi99_y_cloud = predict_over_grid(0, jax.random.PRNGKey(SEED + 22))

# Figure 1: posterior mean of lambda with its 95% HDI band per weather, over the raw data
fig, ax = plt.subplots(figsize=(7, 4))
ax.set_title("Poisson Regression: Expected Count (95% HDI)")
ax.set_xlabel("Temperature (°C)")
ax.set_ylabel("Expected Number of Fish")
cloud_lo, cloud_hi = hdi95_rate_cloud[:, 0], hdi95_rate_cloud[:, 1]
sunny_lo, sunny_hi = hdi95_rate_sunny[:, 0], hdi95_rate_sunny[:, 1]
ax.plot(temp_grid, mean_rate_cloud, label="Cloudy: mean λ")
ax.fill_between(temp_grid, cloud_lo, cloud_hi, alpha=0.25, label="Cloudy: 95% HDI (λ)")
ax.plot(temp_grid, mean_rate_sunny, label="Sunny: mean λ")
ax.fill_between(temp_grid, sunny_lo, sunny_hi, alpha=0.25, label="Sunny: 95% HDI (λ)")
ax.scatter(temp_np, fish_np, s=20, alpha=0.6, c=(sunny_np > 0), cmap="coolwarm", label="Observed")
ax.legend(title="Legend", loc="best")
fig.tight_layout()
plt.show()

# Figure 2: same mean curves, shaded with the 99% predictive interval for counts
fig, ax = plt.subplots(figsize=(7, 4))
ax.set_title("Poisson Regression: Predictive Interval (99%)")
ax.set_xlabel("Temperature (°C)")
ax.set_ylabel("Predicted Number of Fish")
cloud_lo, cloud_hi = hdi99_y_cloud[:, 0], hdi99_y_cloud[:, 1]
sunny_lo, sunny_hi = hdi99_y_sunny[:, 0], hdi99_y_sunny[:, 1]
ax.plot(temp_grid, mean_rate_cloud, label="Cloudy: mean of λ")
ax.fill_between(temp_grid, cloud_lo, cloud_hi, alpha=0.25, label="Cloudy: 99% PI (counts)")
ax.plot(temp_grid, mean_rate_sunny, label="Sunny: mean of λ")
ax.fill_between(temp_grid, sunny_lo, sunny_hi, alpha=0.25, label="Sunny: 99% PI (counts)")
ax.legend(title="Legend", loc="best")
fig.tight_layout()
plt.show()

# ------------------------------------------------------------
# 8) Alternative fits for reference (mirroring the Stan examples)
# ------------------------------------------------------------
# Sampler settings shared by both reference fits (same as the main fit above).
mcmc_settings = dict(num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=True)

# (b) Log-style model fit (conceptually same posterior)
mcmc_log = MCMC(NUTS(poisson_model_logstyle), **mcmc_settings)
mcmc_log.run(
    jax.random.PRNGKey(SEED + 30),
    temp=to_jnp(temp_np),
    sunny=to_jnp(sunny_np),
    fish_num=to_jnp(fish_np),
)
idata_log = az.from_numpyro(mcmc_log)
log_summary = az.summary(idata_log, var_names=["Intercept", "b_temp", "b_sunny"], hdi_prob=0.95)
print("\nSummary (log-style model, 95% HDI):")
print(log_summary)

# (c) Design-matrix model fit
mcmc_dm = MCMC(NUTS(poisson_model_design_mat), **mcmc_settings)
mcmc_dm.run(
    jax.random.PRNGKey(SEED + 40),
    X=to_jnp(X_np),
    Y=to_jnp(fish_np),
)
idata_dm = az.from_numpyro(mcmc_dm)
dm_summary = az.summary(idata_dm, var_names=["b"], hdi_prob=0.95)
print("\nSummary (design matrix model, 95% HDI):")
print(dm_summary)

# Done
print("\nAll done.")
