# 2-6-Stanコーディングの詳細

# In [None]:
# -*- coding: utf-8 -*-
# R+Stan -> Python+NumPyro/ArviZ rewrite for sections 2-6-x
# Requirements satisfied:
# - print() for results
# - pandas for CSV loading
# - matplotlib/seaborn/arviz for visualization
# - labels in English
# - NumPyro for Bayesian inference
# - Use NumPyro built-in function for model visualization (render_model)
# - Use ArviZ for posterior visualization; use hdi_prob (NOT credible_interval)
# - Do NOT pass observed_data to az.from_numpyro
# - Do NOT pass kind to az.plot_density
# - For PPC plots, pass group="posterior" to az.plot_ppc

import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np

import jax
import jax.numpy as jnp
from jax import random

import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# ---- Utility ---------------------------------------------------------------

# Sampler configuration shared by every NumPyro fit in this script.
SEED = 1
NUM_CHAINS = 4     # number of NUTS chains per fit
NUM_WARMUP = 800   # warmup (adaptation) iterations per chain
NUM_SAMPLES = 1200 # posterior draws kept per chain

# Root PRNG key, seeded from SEED; the fitting code below draws its keys from it.
rng_key = random.PRNGKey(SEED)

def run_mcmc(model_fn, rng_key, **model_kwargs):
    """Fit ``model_fn`` with NUTS using the module-level sampler settings.

    Args:
        model_fn: NumPyro model callable.
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        **model_kwargs: keyword arguments forwarded to the model (data, sizes).

    Returns:
        The ``numpyro.infer.MCMC`` object after running, configured with
        NUM_WARMUP / NUM_SAMPLES / NUM_CHAINS and a progress bar.
    """
    sampler = MCMC(
        NUTS(model_fn),
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        progress_bar=True,
    )
    sampler.run(rng_key, **model_kwargs)
    return sampler

def print_summary_from_idata(idata, var_names, hdi_prob=0.95, title=None):
    """Print an ArviZ summary table, optionally preceded by a section banner.

    Args:
        idata: ArviZ InferenceData to summarize.
        var_names: variables to include in the table.
        hdi_prob: probability mass of the reported HDI interval.
        title: optional banner text; nothing is printed for it when falsy.
    """
    if title:
        banner = f"\n=== {title} ==="
        print(banner)
    table = az.summary(idata, var_names=var_names, hdi_prob=hdi_prob)
    print(table)

# Try to use NumPyro's built-in model renderer (graphviz required)
def try_render_model(model_fn, render_name, **model_args):
    """Save a graph of ``model_fn`` as ``<render_name>.svg`` (best effort).

    Uses ``numpyro.render_model`` with distributions and parameters shown.
    Rendering needs graphviz; any ``Exception`` raised along the way is
    reported with a "Skip" message instead of propagating, so the rest of
    the script keeps running.

    Args:
        model_fn: NumPyro model callable to trace.
        render_name: output file name without the ``.svg`` extension.
        **model_args: keyword arguments forwarded to the model while tracing.
    """
    svg_path = f"{render_name}.svg"
    try:
        graph = numpyro.render_model(
            model_fn,
            model_args=(),
            model_kwargs=model_args,
            render_distributions=True,
            render_params=True,
        )
        # graphviz writes <render_name>.svg; cleanup=True removes the DOT source.
        graph.render(render_name, format="svg", cleanup=True)
    except Exception as err:  # deliberate: the graph is optional output
        print(f"(Skip model rendering for {render_name}: {err})")
    else:
        print(f"Model graph saved to: {svg_path}")

# ---- Data loading (pandas) -------------------------------------------------

# 2-4-1 (single sales series)
df_sales1 = pd.read_csv("2-4-1-beer-sales-1.csv")
print("Head of 2-4-1-beer-sales-1.csv:")
print(df_sales1.head(3))
sales = jnp.array(df_sales1["sales"].to_numpy())
N = sales.shape[0]
print(f"Sample size (N): {N}")

# 2-6-1 (data for the A/B comparison)
df_ab = pd.read_csv("2-6-1-beer-sales-ab.csv")
print("\nHead of 2-6-1-beer-sales-ab.csv:")
print(df_ab.head(3))

# Split into Beer A and B like the original code (first 100 = A, next 100 = B).
# model_265_difference_mean places both groups in one plate of size N_ab, so
# each slice must contain exactly N_ab rows. Fail early with a clear message
# rather than with a shape error deep inside NumPyro.
N_ab = 100
sales_a = jnp.array(df_ab["sales"].iloc[:N_ab].to_numpy())
sales_b = jnp.array(df_ab["sales"].iloc[N_ab:2 * N_ab].to_numpy())
if sales_a.shape[0] != N_ab or sales_b.shape[0] != N_ab:
    raise ValueError(
        f"Expected at least {2 * N_ab} rows in 2-6-1-beer-sales-ab.csv "
        f"(got {len(df_ab)}); cannot form two groups of {N_ab}."
    )

# ---- Quick visualization of A/B (matplotlib / seaborn) --------------------
# The DataFrame is passed as `data=` explicitly: in seaborn 0.11 the first
# positional parameter of kdeplot is `x`, so a positional DataFrame would
# collide with x="sales".
plt.figure()
sns.histplot(data=df_ab, x="sales", hue="beer_name", element="step",
             stat="density", common_norm=False, alpha=0.5)
sns.kdeplot(data=df_ab, x="sales", hue="beer_name", common_norm=False, fill=False)
plt.title("Sales Distribution by Beer Type")
plt.xlabel("Sales")
plt.ylabel("Density")
plt.tight_layout()
plt.show()

# ---- Models (NumPyro) ------------------------------------------------------
# Stan: 2-6-1-normal-prior
def model_261_normal_prior(sales):
    """Normal model with explicit, very weak priors.

    mu ~ Normal(0, 1e6), sigma ~ HalfNormal(1e6), and the observation site
    ``y`` is Normal(mu, sigma) expanded to ``sales.shape[0]`` elements and
    conditioned on ``sales``.

    Args:
        sales: observed sales array (expected 1-D; its length sizes ``y``).
    """
    mu = numpyro.sample("mu", dist.Normal(0., 1e6))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1e6))
    n_obs = sales.shape[0]
    likelihood = dist.Normal(mu, sigma).expand([n_obs])
    numpyro.sample("y", likelihood, obs=sales)

# Stan: 2-6-2-lp (target += for likelihood; priors kept very weak and proper)
def model_262_lp(sales):
    """Likelihood added as explicit log-density terms, one per observation.

    Mirrors Stan's loop ``target += normal_lpdf(sales[i] | mu, sigma)`` by
    registering a ``numpyro.factor`` site ``ll_<i>`` for every element of
    ``sales``. mu and sigma get proper but very weak priors.

    Args:
        sales: observed sales array, indexed element by element.
    """
    mu = numpyro.sample("mu", dist.Normal(0., 1e6))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1e6))
    # The per-observation density is the same for every i; build it once.
    per_obs = dist.Normal(mu, sigma)
    for idx in range(sales.shape[0]):
        numpyro.factor(f"ll_{idx}", per_obs.log_prob(sales[idx]))

# Stan: 2-6-3-lp-normal-prior (explicit priors + target += for likelihood)
# In NumPyro, priors are already explicit via sample(); we still show factor-based likelihood.
def model_263_lp_normal_prior(sales):
    """Explicit-prior variant of the looped ``target +=`` model.

    Priors are declared with ``numpyro.sample`` (mu ~ Normal(0, 1e6),
    sigma ~ HalfNormal(1e6)); the likelihood is one ``numpyro.factor`` site
    ``ll_<i>`` per observation, computed the same way as in model_262_lp.

    Args:
        sales: observed sales array, indexed element by element.
    """
    mu = numpyro.sample("mu", dist.Normal(0., 1e6))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1e6))
    n_obs = sales.shape[0]
    obs_density = dist.Normal(mu, sigma)
    for pos in range(n_obs):
        site_name = f"ll_{pos}"
        numpyro.factor(site_name, obs_density.log_prob(sales[pos]))

# Stan: 2-6-4-lp-normal-prior-vec (vectorized target +=)
def model_264_lp_normal_prior_vec(sales):
    """Vectorized ``target +=`` model: a single factor for the whole data set.

    Element-wise Normal(mu, sigma) log-densities of ``sales`` are summed and
    added through one ``numpyro.factor`` site named ``ll_vec``.

    Args:
        sales: observed sales array.
    """
    mu = numpyro.sample("mu", dist.Normal(0., 1e6))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1e6))
    elementwise_ll = dist.Normal(mu, sigma).log_prob(sales)
    total_ll = elementwise_ll.sum()
    numpyro.factor("ll_vec", total_ll)

# Stan: 2-6-5-difference-mean  + generated quantities (diff)
def model_265_difference_mean(N, sales_a=None, sales_b=None):
    """Two independent normal groups plus the difference of their means.

    Args:
        N: number of observations per group (size of the shared plate "N").
        sales_a: observed sales for group A, or None to sample ``y_a``.
        sales_b: observed sales for group B, or None to sample ``y_b``.
            Passing None for both is how the script draws posterior
            predictive samples.

    The deterministic site ``diff`` stores ``mu_b - mu_a``.
    """
    # Site order is kept as-is: it determines how the seeded RNG is consumed.
    mu_a = numpyro.sample("mu_a", dist.Normal(0., 1e6))
    sigma_a = numpyro.sample("sigma_a", dist.HalfNormal(1e6))
    mu_b = numpyro.sample("mu_b", dist.Normal(0., 1e6))
    sigma_b = numpyro.sample("sigma_b", dist.HalfNormal(1e6))

    group_a = dist.Normal(mu_a, sigma_a)
    group_b = dist.Normal(mu_b, sigma_b)
    with numpyro.plate("N", N):
        numpyro.sample("y_a", group_a, obs=sales_a)
        numpyro.sample("y_b", group_b, obs=sales_b)

    numpyro.deterministic("diff", mu_b - mu_a)

# ---- Fit models ------------------------------------------------------------
# Split the root key once so that every stochastic step (five MCMC fits plus
# one posterior-predictive draw) gets its own key. Previously the root key was
# consumed directly by the first fit and was also re-split for every later
# step, which JAX's PRNG guidance advises against (never reuse a key).
key_261, key_262, key_263, key_264, key_265, key_ppc = random.split(rng_key, 6)

# 2-6-1
mcmc_261 = run_mcmc(model_261_normal_prior, key_261, sales=sales)
idata_261 = az.from_numpyro(mcmc_261)
print_summary_from_idata(idata_261, ["mu", "sigma"], hdi_prob=0.95, title="2-6-1 normal-prior")

# NumPyro built-in "visualization" of model structure (graph)
try_render_model(model_261_normal_prior, "model_261_normal_prior", sales=sales)

# 2-6-2 (lp with loop)
mcmc_262 = run_mcmc(model_262_lp, key_262, sales=sales)
idata_262 = az.from_numpyro(mcmc_262)
print_summary_from_idata(idata_262, ["mu", "sigma"], hdi_prob=0.95, title="2-6-2 lp (loop likelihood)")

# 2-6-3 (lp + explicit prior; same as 2-6-2 in NumPyro with comments)
mcmc_263 = run_mcmc(model_263_lp_normal_prior, key_263, sales=sales)
idata_263 = az.from_numpyro(mcmc_263)
print_summary_from_idata(idata_263, ["mu", "sigma"], hdi_prob=0.95, title="2-6-3 lp + normal prior")

# 2-6-4 (lp vectorized)
mcmc_264 = run_mcmc(model_264_lp_normal_prior_vec, key_264, sales=sales)
idata_264 = az.from_numpyro(mcmc_264)
print_summary_from_idata(idata_264, ["mu", "sigma"], hdi_prob=0.95, title="2-6-4 lp vectorized")

# 2-6-5 (difference of means with generated quantities)
mcmc_265 = run_mcmc(model_265_difference_mean, key_265, N=N_ab, sales_a=sales_a, sales_b=sales_b)
idata_265 = az.from_numpyro(mcmc_265)
print_summary_from_idata(idata_265, ["mu_a", "sigma_a", "mu_b", "sigma_b", "diff"], hdi_prob=0.95, title="2-6-5 difference of means")

# ---- Posterior & PPC visualization (ArviZ) ---------------------------------
# Posterior plots (use hdi_prob; DO NOT use credible_interval)
# NOTE: az.plot_posterior returns matplotlib Axes (not a Figure); the name
# `fig` is kept only for compatibility with any later references.
fig = az.plot_posterior(idata_261, var_names=["mu", "sigma"], hdi_prob=0.95)
plt.suptitle("Posterior of mu and sigma (2-6-1)", y=1.02)
plt.show()

# Density (no 'kind' argument passed)
az.plot_density(idata_265, var_names=["mu_a", "mu_b", "diff"])
plt.suptitle("Posterior density: mu_a, mu_b, diff (2-6-5)")
plt.show()

# Posterior Predictive Check for 2-6-5
# Build posterior predictive draws WITHOUT passing observed_data to from_numpyro;
# from_numpyro recovers the observations from the kwargs mcmc_265 was run with.
predictive_265 = Predictive(model_265_difference_mean, posterior_samples=mcmc_265.get_samples(), return_sites=["y_a", "y_b"])
ppc_samples = predictive_265(key_ppc, N=N_ab)  # sales_a/sales_b omitted -> generated

idata_265_ppc = az.from_numpyro(mcmc_265, posterior_predictive=ppc_samples)
az.plot_ppc(idata_265_ppc, group="posterior")  # group must be "posterior"
plt.suptitle("Posterior Predictive Check (2-6-5)")
plt.show()

# ---- Extra: quick posterior diff-only plot (as in bayesplot::mcmc_dens for 'diff') ----
az.plot_posterior(idata_265, var_names=["diff"], hdi_prob=0.95)
plt.suptitle("Posterior of difference (mu_b - mu_a)")
plt.show()

print("\nAll analyses complete.")
