# 3-5-brmsの使い方

In [None]:
# -*- coding: utf-8 -*-
# Python port of the examples in "brmsの使い方｜RとStanではじめる ベイズ統計モデリング",
# using NumPyro + ArviZ + pandas + matplotlib/seaborn
# Conditions honored:
# - Use print() for numerical outputs
# - Read CSV via pandas
# - Use matplotlib/seaborn/arviz for visualization (labels in English)
# - Use NumPyro for Bayesian inference
# - Use NumPyro built-ins for model visualization (render_model if available)
# - Use ArviZ for posterior/posterior predictive visualization
# - az.plot_posterior(..., hdi_prob=...)  (do NOT use credible_interval)
# - Do NOT use observed_data arg in az.from_numpyro
# - Do NOT use kind arg in az.plot_density
# - When plotting PPC, set group="posterior"
# - Do NOT use group arg in az.plot_forest

import os
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
from numpyro import sample, deterministic
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# ---------------------------------------------------------------------
# 0) Reproducibility and plotting style
# ---------------------------------------------------------------------
# Expose four host (CPU) devices to XLA so the 4-chain MCMC run further down
# can draw its chains in parallel. With the default single CPU device NumPyro
# warns and falls back to drawing the chains one after another.
# This must run before the first JAX computation (the PRNGKey creation below),
# otherwise it has no effect.
numpyro.set_host_device_count(4)
numpyro.set_platform("cpu")
rng_key = jax.random.PRNGKey(1)  # fixed seed: every later key is split from this one
sns.set(style="whitegrid")

# ---------------------------------------------------------------------
# 1) Data: read via pandas (CSV must exist in the working directory)
# ---------------------------------------------------------------------
# R code used: read.csv("3-2-1-beer-sales-2.csv")
csv_path = "3-2-1-beer-sales-2.csv"
csv_is_present = os.path.exists(csv_path)
if not csv_is_present:
    # Fail early with a hint about where the file is expected.
    raise FileNotFoundError(
        f"{csv_path} not found. Place the CSV in the working directory."
    )

df = pd.read_csv(csv_path)

# Both modelling columns ('sales', 'temperature') must be present.
absent_columns = {"sales", "temperature"}.difference(df.columns)
if absent_columns:
    raise ValueError("CSV must contain 'sales' and 'temperature' columns.")

# Exploratory scatter of the raw data (labels in English)
scatter_fig, scatter_ax = plt.subplots(figsize=(6, 4))
sns.scatterplot(data=df, x="temperature", y="sales", s=50, ax=scatter_ax)
scatter_ax.set_xlabel("Temperature (°C)")
scatter_ax.set_ylabel("Sales")
scatter_ax.set_title("Beer Sales vs Temperature")
scatter_fig.tight_layout()
plt.show()

# float32 JAX arrays handed to the NumPyro models: x = temperature, y = sales
x, y = (
    jnp.asarray(df[column].values, dtype=jnp.float32)
    for column in ("temperature", "sales")
)

# ---------------------------------------------------------------------
# 2) NumPyro models (equivalents of brms formulations)
# ---------------------------------------------------------------------
# Default weakly-informative prior linear regression:
def model_default(temp, sales=None):
    """Gaussian linear regression of sales on temperature with weak priors.

    Sample sites: "b0" ~ Normal(0, 10), "b1" ~ Normal(0, 10),
    "sigma" ~ HalfNormal(10). The linear predictor b0 + b1 * temp is recorded
    as the deterministic site "mu" and used as the mean of the "obs" site.

    Args:
        temp: predictor values entering the linear predictor.
        sales: values passed as ``obs=`` to the "obs" site (default None).
    """
    intercept = numpyro.sample("b0", dist.Normal(loc=0.0, scale=10.0))
    slope = numpyro.sample("b1", dist.Normal(loc=0.0, scale=10.0))
    noise_sd = numpyro.sample("sigma", dist.HalfNormal(scale=10.0))

    expected_sales = intercept + slope * temp
    numpyro.deterministic("mu", expected_sales)
    numpyro.sample(
        "obs",
        dist.Normal(loc=expected_sales, scale=noise_sd),
        obs=sales,
    )

# "Uninformative" (very wide) priors for Intercept and sigma
# (NumPyro requires proper priors, so we emulate flat priors with huge scales)
def model_uninformative(temp, sales=None):
    """Linear regression with very wide priors on the intercept and sigma.

    Sample sites: "b0" ~ Normal(0, 1e6), "b1" ~ Normal(0, 10),
    "sigma" ~ HalfNormal(1e6). The linear predictor b0 + b1 * temp is recorded
    as the deterministic site "mu" and used as the mean of the "obs" site.

    Args:
        temp: predictor values entering the linear predictor.
        sales: values passed as ``obs=`` to the "obs" site (default None).
    """
    intercept = numpyro.sample("b0", dist.Normal(loc=0.0, scale=1e6))
    # The slope keeps the same Normal(0, 10) prior as the default model.
    slope = numpyro.sample("b1", dist.Normal(loc=0.0, scale=10.0))
    noise_sd = numpyro.sample("sigma", dist.HalfNormal(scale=1e6))

    expected_sales = intercept + slope * temp
    numpyro.deterministic("mu", expected_sales)
    numpyro.sample(
        "obs",
        dist.Normal(loc=expected_sales, scale=noise_sd),
        obs=sales,
    )

# Custom prior for slope: Normal(0, 100000) as in set_prior("normal(0,100000)", class="b")
def model_slope_wide(temp, sales=None):
    """Linear regression whose slope prior is Normal(0, 100000).

    Sample sites: "b0" ~ Normal(0, 10), "b1" ~ Normal(0, 100000),
    "sigma" ~ HalfNormal(10). The linear predictor b0 + b1 * temp is recorded
    as the deterministic site "mu" and used as the mean of the "obs" site.

    Args:
        temp: predictor values entering the linear predictor.
        sales: values passed as ``obs=`` to the "obs" site (default None).
    """
    intercept = numpyro.sample("b0", dist.Normal(loc=0.0, scale=10.0))
    slope = numpyro.sample("b1", dist.Normal(loc=0.0, scale=100000.0))
    noise_sd = numpyro.sample("sigma", dist.HalfNormal(scale=10.0))

    expected_sales = intercept + slope * temp
    numpyro.deterministic("mu", expected_sales)
    numpyro.sample(
        "obs",
        dist.Normal(loc=expected_sales, scale=noise_sd),
        obs=sales,
    )

# ---------------------------------------------------------------------
# 3) Fit the default model with MCMC (NUTS): equivalent to brm(..., chains=4, iter=2000, warmup=1000)
# ---------------------------------------------------------------------
nuts = NUTS(model_default)
mcmc = MCMC(
    nuts,
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
    thinning=1,
    progress_bar=True,
)
mcmc.run(rng_key, temp=x, sales=y)

# Sampler diagnostics table (the counterpart of printing simple_lm_brms)
mcmc.print_summary()

# Site names from the flattened draws, then per-site shapes grouped by chain.
print("\nMCMC sample keys:", list(mcmc.get_samples()))
per_chain_draws = mcmc.get_samples(group_by_chain=True)
print(
    "Posterior sample shapes:",
    {site: draws.shape for site, draws in per_chain_draws.items()},
)

# ---------------------------------------------------------------------
# 4) Posterior, Prior Predictive, and InferenceData for ArviZ
# (Do NOT pass observed_data to az.from_numpyro)
# ---------------------------------------------------------------------
# Flattened posterior draws, reused by the predictive steps below
posterior_samples = mcmc.get_samples()

# Posterior predictive at the observed temperatures
rng_key, subkey = jax.random.split(rng_key)
posterior_predictive_fn = Predictive(model_default, posterior_samples)
ppc = posterior_predictive_fn(subkey, temp=x)

# Prior predictive (500 draws), kept for completeness
rng_key, subkey = jax.random.split(rng_key)
prior_predictive_fn = Predictive(model_default, num_samples=500)
prior_pred = prior_predictive_fn(subkey, temp=x)

# Bundle everything into one ArviZ InferenceData (no observed_data argument)
idata = az.from_numpyro(mcmc, prior=prior_pred, posterior_predictive=ppc)

# Numerical summary of the regression parameters, emitted via print()
summary_df = az.summary(idata, var_names=["b0", "b1", "sigma"], hdi_prob=0.95)
print("\nPosterior summary (95% HDI):", summary_df, sep="\n")

# ---------------------------------------------------------------------
# 5) Posterior visualization (ArviZ)
#    - Use hdi_prob (NOT credible_interval)
#    - Avoid kind in plot_density
#    - Avoid group in plot_forest
# ---------------------------------------------------------------------
# Parameters shown in every posterior plot below; a fresh list is passed to
# each ArviZ call.
POSTERIOR_PARAMS = ("b0", "b1", "sigma")

az.plot_posterior(idata, var_names=list(POSTERIOR_PARAMS), hdi_prob=0.95)
plt.suptitle("Posterior Distributions (95% HDI)")
plt.show()

az.plot_density(idata, var_names=list(POSTERIOR_PARAMS))
plt.suptitle("Posterior Density")
plt.show()

# Interval view, the counterpart of stanplot(..., type="intervals")
az.plot_forest(idata, var_names=list(POSTERIOR_PARAMS))
plt.title("Forest Plot of Coefficients")
plt.show()

# Posterior predictive check (group="posterior" is required here)
az.plot_ppc(idata, group="posterior")
plt.title("Posterior Predictive Check")
plt.show()

# ---------------------------------------------------------------------
# 6) NumPyro built-in model visualization (plate/graph) if available
# ---------------------------------------------------------------------
# Best-effort: rendering depends on the optional render module / graphviz, so
# any failure is reported and the script carries on.
try:
    from numpyro.contrib.render import render_model

    dot = render_model(
        model_default,
        model_args=(x,),
        model_kwargs=dict(sales=y),
        render_distributions=True,
        render_params=True,
    )
except Exception as render_error:
    print("\nModel rendering skipped (graphviz or render module not available):", render_error)
else:
    # In notebooks, display(dot) shows the graph; here only a notice is printed.
    print("\nRendered model graph (use display(dot) in a notebook to view).")

# ---------------------------------------------------------------------
# 7) "Fitted" values and "Predict" for new data (temperature = 20)
#    - Replicates fitted(...) and predict(...) behavior
# ---------------------------------------------------------------------
new_temp = jnp.array([20.0], dtype=jnp.float32)

# Pull the parameter draws out as NumPy arrays; these are reused further down.
b0, b1, sigma = (
    np.asarray(posterior_samples[site]) for site in ("b0", "b1", "sigma")
)

# fitted(...): posterior of the expected value at 20°C, mu = b0 + b1 * 20
mu20 = b0 + 20.0 * b1
mu20_hdi = az.hdi(mu20, hdi_prob=0.95)  # indexed below as [low, high]
print("\nFitted mean at 20°C:")
print(f"mean = {mu20.mean():.3f}, 95% HDI = [{mu20_hdi[0]:.3f}, {mu20_hdi[1]:.3f}]")

# predict(...): add observation noise, y ~ Normal(mu20, sigma), one draw per
# posterior sample
rng = np.random.default_rng(1)
y20 = rng.normal(loc=mu20, scale=sigma)
y20_hdi = az.hdi(y20, hdi_prob=0.95)
print("\nPredictive at 20°C:")
print(f"mean = {y20.mean():.3f}, 95% HDI = [{y20_hdi[0]:.3f}, {y20_hdi[1]:.3f}]")

# Cross-check the hand-rolled draws against NumPyro's Predictive at the new point
rng_key, subkey = jax.random.split(rng_key)
predict_new_fn = Predictive(model_default, posterior_samples)
ppc_new = predict_new_fn(subkey, temp=new_temp)
# Flatten the "obs" draws to 1-D before summarising
y20_numpyro = np.asarray(ppc_new["obs"]).ravel()
y20n_hdi = az.hdi(y20_numpyro, hdi_prob=0.95)
print("\nNumPyro Predictive (cross-check) at 20°C:")
print(f"mean = {y20_numpyro.mean():.3f}, 95% HDI = [{y20n_hdi[0]:.3f}, {y20n_hdi[1]:.3f}]")

# ---------------------------------------------------------------------
# 8) "Marginal effects" style plots:
#    - Regression line with 95% credible interval (for mean)
#    - 95% prediction interval (for new observations)
# ---------------------------------------------------------------------
# Temperature grid (100 points) spanning the observed range
x_grid = np.linspace(df["temperature"].min(), df["temperature"].max(), 100).astype(np.float32)

# Posterior draws of mu over the grid: one row per draw, one column per grid point
mu_grid = b0[:, None] + b1[:, None] * x_grid[None, :]

# 95% HDI for the mean (credible band).
# az.hdi is given an explicit leading chain axis, i.e. (chain=1, draw, grid).
# A bare 2-D array triggers a FutureWarning and is announced to be read as
# (chain, draw) in a future ArviZ release, which would break the
# [:, 0] / [:, 1] indexing used when plotting. The result is still one
# (lower, upper) pair per grid point.
mu_hdi = az.hdi(mu_grid[None, ...], hdi_prob=0.95)
mu_mean = mu_grid.mean(axis=0)

# Prediction interval via predictive draws y ~ Normal(mu, sigma);
# sigma is broadcast across the grid so each draw keeps its own noise scale.
rng = np.random.default_rng(1)
y_grid_draws = rng.normal(loc=mu_grid, scale=sigma[:, None])
y_hdi = az.hdi(y_grid_draws[None, ...], hdi_prob=0.95)
y_mean = y_grid_draws.mean(axis=0)

# Plot: fitted mean with its 95% credible band
mean_fig, mean_ax = plt.subplots(figsize=(7, 4.5))
sns.scatterplot(
    data=df, x="temperature", y="sales", s=40, alpha=0.7, label="Observed", ax=mean_ax
)
mean_ax.plot(x_grid, mu_mean, label="Fitted mean")
# Column 0 holds the lower HDI bound, column 1 the upper bound.
mean_band_low, mean_band_high = mu_hdi[:, 0], mu_hdi[:, 1]
mean_ax.fill_between(
    x_grid, mean_band_low, mean_band_high, alpha=0.3, label="95% Credible Interval"
)
mean_ax.set_xlabel("Temperature (°C)")
mean_ax.set_ylabel("Sales")
mean_ax.set_title("Regression with 95% Credible Interval (Mean)")
mean_ax.legend()
mean_fig.tight_layout()
plt.show()

# Plot: predictive mean with the 95% prediction band for new observations
pred_fig, pred_ax = plt.subplots(figsize=(7, 4.5))
sns.scatterplot(
    data=df, x="temperature", y="sales", s=40, alpha=0.7, label="Observed", ax=pred_ax
)
pred_ax.plot(x_grid, y_mean, label="Predictive mean")
# Column 0 holds the lower HDI bound, column 1 the upper bound.
pred_band_low, pred_band_high = y_hdi[:, 0], y_hdi[:, 1]
pred_ax.fill_between(
    x_grid, pred_band_low, pred_band_high, alpha=0.3, label="95% Prediction Interval"
)
pred_ax.set_xlabel("Temperature (°C)")
pred_ax.set_ylabel("Sales")
pred_ax.set_title("Regression with 95% Prediction Interval")
pred_ax.legend()
pred_fig.tight_layout()
plt.show()

# ---------------------------------------------------------------------
# 9) Alternative prior configurations (like prior_summary/changes in brms)
#    These runs demonstrate how to alter priors; summaries printed via print().
# ---------------------------------------------------------------------
# Variant 1: very wide priors on the intercept and sigma
rng_key, subkey = jax.random.split(rng_key)
nuts_u = NUTS(model_uninformative)
mcmc_u = MCMC(
    nuts_u,
    num_warmup=800,
    num_samples=800,
    num_chains=2,
    progress_bar=True,
)
mcmc_u.run(subkey, temp=x, sales=y)
print("\nUninformative prior model summary:")
mcmc_u.print_summary()

# Variant 2: slope prior widened to N(0, 100000)
rng_key, subkey = jax.random.split(rng_key)
nuts_w = NUTS(model_slope_wide)
mcmc_w = MCMC(
    nuts_w,
    num_warmup=800,
    num_samples=800,
    num_chains=2,
    progress_bar=True,
)
mcmc_w.run(subkey, temp=x, sales=y)
print("\nWide slope prior model summary (b1 ~ Normal(0, 100000)):")
mcmc_w.print_summary()

# Optional: posterior plots for the two alternative-prior fits (posterior only)
idata_u = az.from_numpyro(mcmc_u)
idata_w = az.from_numpyro(mcmc_w)
for alt_idata, alt_title in (
    (idata_u, "Posterior (Uninformative Intercept & Sigma)"),
    (idata_w, "Posterior (Wide Slope Prior)"),
):
    az.plot_posterior(alt_idata, var_names=["b0", "b1", "sigma"], hdi_prob=0.95)
    plt.suptitle(alt_title)
    plt.show()
