# 3-4-デザイン行列を用いた一般化線形モデルの推定

# In [None]:
# glm_design_matrix_numpyro.py
# R+Stan の「デザイン行列を用いた一般化線形モデル」を Python+NumPyro に移植
# - CSV読み込み: pandas
# - ベイズ推定: NumPyro (NUTS)
# - モデル可視化: numpyro.render_model（組み込み）
# - 事後分布: ArviZ (plot_posterior(hdi_prob=...), plot_ppc(group="posterior"))
# - 可視化: matplotlib
# - 計算結果の表示: print()

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import arviz as az

# ------------------------------
# 1) Data loading & design matrix
# ------------------------------
# Read the data set from a CSV file located relative to the working directory.
csv_path = "3-2-1-beer-sales-2.csv"
df = pd.read_csv(csv_path)

# Number of observations = number of rows in the data frame.
N = len(df)

# Column labels, used only to make the printed design matrix readable.
x_cols = ["Intercept", "temperature"]

# Response vector, then the design matrix. The leading column of ones supplies
# the intercept, mirroring what R's model.matrix() adds by default.
Y = df["sales"].to_numpy()
X = np.column_stack((np.ones(N), df["temperature"].to_numpy()))

print("First five rows of the design matrix (X):")
print(pd.DataFrame(X, columns=x_cols).head(5))

print(f"\nSample size N = {N}")
print(f"Number of predictors K = {X.shape[1]}")

# ------------------------------
# 2) NumPyro model (GLM with design matrix)
# ------------------------------
def model(X, Y=None):
    """
    Bayesian linear regression written in terms of a design matrix.

    Likelihood:
      Y ~ Normal(X @ b, sigma)
    Priors:
      b[k]  ~ Normal(0, 10)   for each column k of X
      sigma ~ HalfNormal(10)

    Args:
        X: Design matrix; its second dimension sets the number of
            coefficients in ``b``.
        Y: Observed responses passed as ``obs`` to the "Y" sample site.
            Left as ``None`` when the site should not be conditioned on data.
    """
    n_coef = X.shape[1]

    # Priors: one independent coefficient per design-matrix column, plus a
    # positive noise scale.
    coef_prior = dist.Normal(jnp.zeros(n_coef), 10.0)
    noise_prior = dist.HalfNormal(10.0)
    b = numpyro.sample("b", coef_prior)
    sigma = numpyro.sample("sigma", noise_prior)

    # Linear predictor, then the Gaussian likelihood around it.
    linear_predictor = jnp.matmul(X, b)
    numpyro.sample("Y", dist.Normal(linear_predictor, sigma), obs=Y)

# ------------------------------
# 3) MCMC (NUTS)
# ------------------------------
# A fixed seed feeds the JAX PRNG key handed to the sampler.
seed = 1
rng_key = random.PRNGKey(seed)

# NUTS kernel over the regression model: 4 chains, each with 1000 warm-up
# iterations followed by 2000 retained draws.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
    progress_bar=True,
)
mcmc.run(rng_key, X=jnp.array(X), Y=jnp.array(Y))

# Show the results (using print).
print("\n=== NumPyro MCMC summary ===")
# mcmc.print_summary() writes directly to standard output; to follow the
# "use print" requirement, the ArviZ summary table is print()ed instead.
idata = az.from_numpyro(mcmc)  # observed_data is not passed (per the stated restriction)
summary_table = az.summary(idata, var_names=["b", "sigma"])
print(summary_table.to_string())

# Posterior means, averaged over both the chain and draw dimensions.
sample_dims = ("chain", "draw")
b_mean = idata.posterior["b"].mean(dim=sample_dims).values
sigma_mean = idata.posterior["sigma"].mean(dim=sample_dims).values
print("\nPosterior means:")
print(f"  b (Intercept, temperature): {b_mean}")
print(f"  sigma: {float(sigma_mean)}")

# ------------------------------
# 4) Posterior predictive & InferenceData
# ------------------------------
# Posterior draws collected by the sampler run above.
posterior_samples = mcmc.get_samples()

# Split the key: one half replaces rng_key, the other drives the predictive
# draws below.
rng_key, rng_key_ppc = random.split(rng_key)

# Y is omitted from the call, so only draws of Y come back;
# observed_data is not handed to ArviZ here.
ppc = Predictive(model, posterior_samples)
pp_samples = ppc(rng_key_ppc, X=jnp.array(X))

# Rebuild the InferenceData, now carrying the posterior_predictive group
# (observed_data is still not passed in).
idata = az.from_numpyro(mcmc, posterior_predictive=pp_samples)

# ------------------------------
# 5) Visualizations
# ------------------------------

# 5-1) Observed data with the posterior-mean regression line drawn on top
fig, ax = plt.subplots()
ax.scatter(df["temperature"], df["sales"], alpha=0.6, label="Observed data")
# Evaluate the line on 100 evenly spaced temperatures spanning the observed
# range, using a grid design matrix with the same (intercept, temperature)
# column layout as X.
temperature_col = df["temperature"]
x_grid = np.linspace(temperature_col.min(), temperature_col.max(), 100)
X_grid = np.column_stack((np.ones_like(x_grid), x_grid))
y_mean_line = np.matmul(X_grid, b_mean)
ax.plot(x_grid, y_mean_line, linewidth=2, label="Posterior mean line")
ax.set_xlabel("Temperature")
ax.set_ylabel("Sales")
ax.set_title("Sales vs Temperature with Posterior Mean Line")
ax.legend()
fig.tight_layout()
plt.show()

# 5-2) Posterior distributions with a 95% HDI
#      (the keyword is hdi_prob, not credible_interval)
az.plot_posterior(idata, var_names=["b", "sigma"], hdi_prob=0.95)
posterior_fig = plt.gcf()  # figure that plot_posterior left as current
posterior_fig.suptitle("Posterior Distributions (95% HDI)", y=1.02)
posterior_fig.tight_layout()
plt.show()

# 5-3) Posterior predictive check (group must be 'posterior')
# Only 100 of the posterior-predictive draws are plotted. That subset is chosen
# at random, so the script's seed is passed to make the figure repeatable
# from run to run.
az.plot_ppc(idata, group="posterior", num_pp_samples=100, random_seed=seed)
plt.suptitle("Posterior Predictive Check", y=1.02)
plt.tight_layout()
plt.show()

# ------------------------------
# 6) Model visualization (NumPyro built-in)
# ------------------------------
# The Bayesian model is visualized with NumPyro's built-in render_model.
# In environments without Graphviz, print a notice instead of raising.
try:
    # The model is rendered with both the design matrix and the observed
    # responses supplied as positional model arguments.
    g = numpyro.render_model(model, model_args=(jnp.array(X), jnp.array(Y)))
    out_path = "glm_design_matrix_model_graph"
    # Write the graph as a PNG; the file name is out_path plus ".png".
    g.render(out_path, format="png", cleanup=True)
    print(f'\nSaved NumPyro model graph to "{out_path}.png"')
except Exception as e:
    # Best-effort step: any failure in the try body is reported and the
    # script carries on rather than aborting.
    print("\nModel graph could not be rendered (Graphviz may be missing).")
    print(f"Reason: {e}")
