# 2-5-MCMCの結果の評価

In [None]:
# -*- coding: utf-8 -*-
# Rewriting the provided R + Stan code into Python with NumPyro + ArviZ

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# Stan works in double precision; switching JAX to 64-bit floats keeps the
# NumPyro results numerically comparable (optional).
numpyro.enable_x64()

# Seed shared by the MCMC runs and predictive simulations in this notebook.
SEED = 1
rng_key = random.PRNGKey(seed=SEED)

# --------------------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------------------
def run_mcmc(
    model,
    model_kwargs,
    num_chains=4,
    num_warmup=1000,
    num_samples=1000,
    thinning=1,
    seed=SEED,
):
    """Fit *model* with NUTS and return the finished ``MCMC`` object.

    Args:
        model: NumPyro model callable.
        model_kwargs: Keyword arguments forwarded to the model through ``mcmc.run``.
        num_chains: Number of chains.
        num_warmup: Warm-up iterations per chain.
        num_samples: Retained iterations per chain (before thinning).
        thinning: Keep every ``thinning``-th draw.
        seed: Integer used to build the ``jax.random.PRNGKey`` for the run.

    Returns:
        The ``numpyro.infer.MCMC`` instance after ``run`` has completed.
    """
    enough_devices = jax.local_device_count() >= num_chains
    sampler = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        thinning=thinning,
        # With fewer local devices than chains, run the chains sequentially from the start.
        chain_method="parallel" if enough_devices else "sequential",
        progress_bar=True,
    )
    sampler.run(random.PRNGKey(seed), **model_kwargs)
    return sampler

def numpyro_render_model(model, model_kwargs, filename_prefix, fmt="png"):
    """Render the graph of *model* to ``<filename_prefix>.<fmt>`` on a best-effort basis.

    Uses NumPyro's ``numpyro.contrib.render.render_model``. Visualization is
    optional in this notebook, so any failure (missing optional dependency,
    tracing error, rendering error) is reported and swallowed rather than
    raised.

    Args:
        model: NumPyro model callable.
        model_kwargs: Keyword arguments used to trace the model.
        filename_prefix: Output path without extension.
        fmt: Output format handed to ``graph.render`` (e.g. "png", "svg").
            Defaults to "png", matching the previous hard-wired behavior.

    Returns:
        The path returned by ``graph.render``, or None if rendering was skipped.
    """
    try:
        from numpyro.contrib.render import render_model

        graph = render_model(model, model_kwargs=model_kwargs, render_distributions=True)
        output_path = graph.render(filename_prefix, format=fmt, cleanup=True)
    except Exception as e:  # deliberate best effort: never stop the analysis for a picture
        print(f"(Model visualization skipped: {e})")
        return None
    print(f"Saved model graph to {output_path}")
    return output_path

def to_inferencedata(mcmc, posterior_predictive=None, observed_data=None, **from_numpyro_kwargs):
    """Convert a fitted NumPyro ``MCMC`` object into an ArviZ ``InferenceData``.

    Args:
        mcmc: Fitted ``numpyro.infer.MCMC`` instance.
        posterior_predictive: Optional dict of replicated draws, forwarded to
            ``az.from_numpyro``.
        observed_data: Accepted only so existing call sites keep working; it is
            intentionally NOT used. ``az.from_numpyro`` takes no
            ``observed_data`` argument and builds that group from the model's
            observed sample sites itself.
        **from_numpyro_kwargs: Extra keyword arguments (e.g. ``coords``,
            ``dims``) forwarded unchanged to ``az.from_numpyro``.

    Returns:
        The ``InferenceData`` produced by ``az.from_numpyro``.
    """
    del observed_data  # intentionally unused; see docstring
    return az.from_numpyro(
        mcmc,
        posterior_predictive=posterior_predictive,
        **from_numpyro_kwargs,
    )

## 5.2 MCMCの実行

In [None]:
# ==================================================================================================
# Part 1: MCMC evaluation analogous to the 'beer sales' example (Normal model: unknown mu, sigma)
# ==================================================================================================

# --- Data: read CSV with pandas (R: read.csv("2-4-1-beer-sales-1.csv")) ---------------------------
beer_csv = "2-4-1-beer-sales-1.csv"
if not os.path.exists(beer_csv):
    # pd.read_csv would fail on the next line anyway. Raise here with a message
    # that names the file and the directory searched, instead of printing a
    # "warning" that is immediately followed by a crash.
    raise FileNotFoundError(
        f"'{beer_csv}' not found in the working directory ({os.getcwd()})."
    )
file_beer_sales_1 = pd.read_csv(beer_csv)

# Sample size (number of rows in the CSV)
sample_size = len(file_beer_sales_1)
print(f"Sample size (beer sales): {sample_size}")

# Observed data vector: the "sales" column as a JAX array for the model
sales_np = file_beer_sales_1["sales"].to_numpy()
sales = jnp.asarray(sales_np)

# --- NumPyro model: Normal with weakly-informative priors (closest to Stan's implicit flat) ------
def normal_mean_sd_model(sales):
    """Normal likelihood for ``sales`` with unknown mean ``mu`` and sd ``sigma``.

    The Stan original leaves both parameters with implicit (improper) flat
    priors. Proper stand-ins are used here: ``mu ~ Normal(0, 1e6)`` and
    ``sigma ~ HalfCauchy(5)``.
    """
    mean = numpyro.sample("mu", dist.Normal(0.0, 1e6))
    scale = numpyro.sample("sigma", dist.HalfCauchy(5.0))
    likelihood = dist.Normal(mean, scale)
    numpyro.sample("sales", likelihood, obs=sales)

# Draw the model graph first (skipped quietly if rendering is unavailable).
numpyro_render_model(
    normal_mean_sd_model,
    {"sales": sales},
    "normal_mean_sd_model",
)

# --- Run MCMC: 4 chains x (1000 warm-up + 1000 kept draws), no thinning ---------------------------
mcmc_sales = run_mcmc(
    normal_mean_sd_model,
    {"sales": sales},
    num_chains=4,
    num_warmup=1000,
    num_samples=1000,
    thinning=1,
    seed=SEED,
)

## 5.3 MCMCサンプルの抽出

In [None]:
# Posterior draws kept separate per chain, like rstan::extract(..., permuted = FALSE):
# a dict mapping each site name to an array of shape (chains, draws, ...).
samples_by_chain = mcmc_sales.get_samples(group_by_chain=True)
param_names = [*samples_by_chain]
print(f"Parameter names: {param_names}")

# Per-parameter draws, indexed as [chain, draw]
mu_chain, sigma_chain = samples_by_chain["mu"], samples_by_chain["sigma"]

print(f"Class of samples: {type(samples_by_chain)} (dict of JAX arrays)")
for label, draws in (("mu", mu_chain), ("sigma", sigma_chain)):
    print(f"Shape of {label} samples (chains, draws): {draws.shape}")

## 5.4 MCMCサンプルの代表値の計算

In [None]:
# mu, chain 1: the first draw kept after warm-up
first_mu_chain1 = np.asarray(mu_chain[0, 0])
print(f"First mu sample (chain 1): {first_mu_chain1}")

# mu, chain 1: every kept draw
mu_chain1_all = np.asarray(mu_chain[0])
print(f"Number of mu samples in chain 1: {mu_chain1_all.size}")

# mu, all chains pooled into one flat vector
mu_all_chains = np.asarray(mu_chain).ravel()
print(f"Total number of mu samples (all chains): {mu_all_chains.size}")

# Point and interval summaries (R: median(), mean(), quantile(c(0.025, 0.975)))
for label, stat in (("median", np.median), ("mean", np.mean)):
    print(f"Posterior {label} of mu: {stat(mu_all_chains)}")
q025, q975 = np.quantile(mu_all_chains, [0.025, 0.975])
print(f"95% interval of mu (2.5%, 97.5%): ({q025}, {q975})")

# Tabular summary from NumPyro (counterpart of printing the stanfit object)
print("\nNumPyro summary:")
mcmc_sales.print_summary()

## 5.5 トレースプロットの描画

In [None]:
# ArviZ plotting works on InferenceData, so convert the fitted sampler first.
idata_sales = to_inferencedata(mcmc_sales, observed_data={"sales": sales})

# Trace plot of both parameters
az.plot_trace(data=idata_sales, var_names=["mu", "sigma"])
plt.suptitle("Trace Plots (mu, sigma)", y=1.02)
plt.show()

## 5.6 ggplot2による事後分布の可視化（Python版ではseabornで代替）

In [None]:
# Posterior density of mu from the pooled draws (seaborn counterpart of the ggplot2 figure)
sns.kdeplot(mu_all_chains, linewidth=1.5).set(
    xlabel="mu",
    ylabel="Density",
    title="Posterior Density of mu",
)
plt.show()

## 5.7 bayesplotによる事後分布の可視化（Python版ではArviZで代替）

In [None]:
# --- Posterior hist/density via ArviZ -------------------------------------------------------------
az.plot_posterior(idata_sales, var_names=["mu", "sigma"], hdi_prob=0.95)
plt.suptitle("Posterior Distributions with 95% HDI", y=1.02)
plt.show()

# --- Chain-wise overlays / autocorrelation / intervals (ArviZ analogues to bayesplot) ------------
# Chain-wise density overlay (cf. bayesplot::mcmc_dens_overlay).
# az.plot_density pools all chains of a single dataset into one curve, so it is
# given one single-chain posterior dataset per chain to draw one curve per chain.
chain_ids = idata_sales.posterior["chain"].values
az.plot_density(
    [idata_sales.posterior.sel(chain=[c]) for c in chain_ids],
    data_labels=[f"chain {c}" for c in chain_ids],
    var_names=["mu", "sigma"],
    outline=True,
    hdi_prob=0.95,  # show the 95% HDI explicitly, as required
    point_estimate="auto",
)
plt.suptitle("Chain-wise Density Overlays", y=1.02)
plt.show()

# Autocorrelation (correlogram)
az.plot_autocorr(idata_sales, var_names=["mu", "sigma"])
plt.suptitle("Autocorrelation (mu, sigma)", y=1.02)
plt.show()

# Forest (interval) plot similar to mcmc_intervals / mcmc_areas
az.plot_forest(idata_sales, var_names=["mu", "sigma"], hdi_prob=0.8)
plt.title("Parameter Intervals (80% HDI)")
plt.show()

## 確率分布による違い

In [None]:
# ==================================================================================================
# Part 2: Posterior Predictive Checks for Normal vs Poisson models (animal_num example)
# ==================================================================================================

# --- Data -----------------------------------------------------------------------------------------
animal_csv = "2-5-1-animal-num.csv"
if not os.path.exists(animal_csv):
    # pd.read_csv would fail on the next line anyway. Raise a descriptive error
    # here instead of printing a warning that is immediately followed by a crash.
    raise FileNotFoundError(
        f"'{animal_csv}' not found in the working directory ({os.getcwd()})."
    )
animal_df = pd.read_csv(animal_csv)
print("Head of animal_num data:")
print(animal_df.head(3))

# Number of observations (rows in the CSV)
N_animal = len(animal_df)
print(f"Sample size (animal_num): {N_animal}")

# Observed counts: NumPy copy for plotting, JAX array for the models
y_np = animal_df["animal_num"].to_numpy()
y = jnp.asarray(y_np)

# --- Models (Stan equivalents) --------------------------------------------------------------------
# Normal model (Stan: mu>0, sigma>0; here we mirror weakly-informative priors and positivity)
def animal_normal_model(y):
    """Normal likelihood for the animal counts ``y``.

    Stan declares ``mu`` and ``sigma`` with ``<lower=0>``. Half-normal priors
    (scales 1000 and 100) give the same positivity with weak information.
    """
    location = numpyro.sample("mu", dist.HalfNormal(1000.0))
    spread = numpyro.sample("sigma", dist.HalfNormal(100.0))
    numpyro.sample("y", dist.Normal(loc=location, scale=spread), obs=y)

# Poisson model (lambda>0)
def animal_poisson_model(y):
    """Poisson likelihood for ``y`` with a positive rate ``lambda`` (half-normal prior, scale 1000)."""
    rate = numpyro.sample("lambda", dist.HalfNormal(1000.0))
    numpyro.sample("y", dist.Poisson(rate=rate), obs=y)

# Render both model graphs first (best effort).
for model_fn, graph_name in (
    (animal_normal_model, "animal_normal_model"),
    (animal_poisson_model, "animal_poisson_model"),
):
    numpyro_render_model(model_fn, {"y": y}, graph_name)

# --- Fit both models with identical sampler settings (Normal first, then Poisson) ---------------
mcmc_normal, mcmc_poisson = (
    run_mcmc(model_fn, {"y": y}, num_chains=4, num_warmup=1000, num_samples=1000, seed=SEED)
    for model_fn in (animal_normal_model, animal_poisson_model)
)

for heading, fitted in (("Normal", mcmc_normal), ("Poisson", mcmc_poisson)):
    print(f"\nEstimated parameters ({heading} model):")
    fitted.print_summary(exclude_deterministic=False)

# --- Posterior Predictive (matching Stan's generated quantities 'pred') ---------------------------
def posterior_predictive_draws(model, mcmc, num_draws=None):
    """Simulate replicated ``y`` data sets from the posterior predictive distribution.

    Mirrors Stan's ``generated quantities`` (``pred``). For each posterior draw,
    the latent sites are fixed to that draw and a fresh ``y`` with the observed
    data's shape is sampled from the likelihood.

    Why not ``Predictive(...)(key, y=jnp.zeros_like(y))``: an observed NumPyro
    site returns its ``obs`` value, so that call just hands back the zero
    placeholder for every "replicate". Calling it with ``y=None`` would instead
    lose the per-observation shape, because the models in this notebook use no
    plate over the data.

    Args:
        model: NumPyro model called as ``model(y=...)`` whose likelihood site is
            named ``"y"``.
        mcmc: Fitted ``MCMC`` object supplying the posterior samples.
        num_draws: If given, randomly subsample this many posterior draws
            (without replacement, seeded with ``SEED``) before simulating.

    Returns:
        NumPy array of shape ``(draws, *y.shape)``.

    Note:
        Reads the module-level observed array ``y`` and the seed ``SEED``.
    """
    from numpyro.handlers import seed, substitute, trace

    posterior = mcmc.get_samples()
    total_draws = next(iter(posterior.values())).shape[0]
    if num_draws is not None:
        # randomly subsample draws if requested
        idx = np.random.default_rng(SEED).choice(total_draws, size=num_draws, replace=False)
        posterior = {k: v[idx] for k, v in posterior.items()}
        total_draws = num_draws

    def _replicate(rng_key, params):
        """Simulate one replicated data set for a single posterior draw."""
        seed_key, sample_key = random.split(rng_key)
        # Trace the model with every latent site fixed to this draw; the observed
        # site then exposes the likelihood distribution under "fn".
        model_trace = trace(substitute(seed(model, seed_key), data=params)).get_trace(y=y)
        obs_site = model_trace["y"]
        likelihood = obs_site["fn"]
        # Broadcast the likelihood to the observed batch shape so each replicate
        # has one value per observation.
        obs_shape = jnp.shape(obs_site["value"])
        batch_shape = obs_shape[: len(obs_shape) - len(likelihood.event_shape)]
        return likelihood.expand(batch_shape).sample(sample_key)

    keys = random.split(random.PRNGKey(SEED), total_draws)
    yrep = jax.vmap(_replicate)(keys, posterior)
    # Shape: (draws, N)
    return np.asarray(yrep)

# Replicated data sets for both models (the likelihood site is named "y" in each).
y_rep_normal = posterior_predictive_draws(animal_normal_model, mcmc_normal)  # (draws, N)
y_rep_poisson = posterior_predictive_draws(animal_poisson_model, mcmc_poisson)

print(f"Shape of y_rep_normal: {y_rep_normal.shape}  # should be (4000, {N_animal}) with 4x1000")
print("First replicate (Normal):", y_rep_normal[0])
print("First replicate (Poisson):", y_rep_poisson[0])

# --- Reference histograms: observed data, then the first replicate of each model -----------------
for values, title, xlabel in (
    (y_np, "Observed Data Histogram", "Count"),
    (y_rep_normal[0], "Posterior Predictive (Normal) — 1st Draw", "Replicated Count"),
    (y_rep_poisson[0], "Posterior Predictive (Poisson) — 1st Draw", "Replicated Count"),
):
    plt.hist(values, bins="auto")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.show()

# --- Build InferenceData for PPC plots with ArviZ -------------------------------------------------
# Only the replicated draws are passed; observed_data is not supplied explicitly.
idata_normal, idata_poisson = (
    az.from_numpyro(fitted, posterior_predictive={"y": replicated})
    for fitted, replicated in ((mcmc_normal, y_rep_normal), (mcmc_poisson, y_rep_poisson))
)

ppc_targets = (("Normal", idata_normal), ("Poisson", idata_poisson))

# PPC with 5 replicated samples (like 1:5 in the R code), then with 10
# (like ppc_dens / ppc_dens_overlay). Each pass plots Normal, then Poisson.
for n_rep, heading in ((5, "PPC"), (10, "PPC KDE")):
    for label, idata in ppc_targets:
        az.plot_ppc(idata, group="posterior", num_pp_samples=n_rep)
        plt.suptitle(f"{heading} — {label} Model ({n_rep} replicated samples)", y=1.02)
        plt.show()

# --- Posterior distributions, traces and autocorrelation for both models -------------------------
# Each entry: (model label, InferenceData, parameters to show)
diagnostic_targets = (
    ("Normal", idata_normal, ("mu", "sigma")),
    ("Poisson", idata_poisson, ("lambda",)),
)

# Posterior densities with a 95% HDI (hdi_prob, not the old credible_interval argument)
for label, idata, names in diagnostic_targets:
    az.plot_posterior(idata, var_names=list(names), hdi_prob=0.95)
    plt.suptitle(f"Posterior — {label} Model (95% HDI)", y=1.02)
    plt.show()

# Trace plots (akin to bayesplot's combo/trace)
for label, idata, names in diagnostic_targets:
    az.plot_trace(idata, var_names=list(names))
    plt.suptitle(f"Trace — {label} Model", y=1.02)
    plt.show()

# Autocorrelation bars
for label, idata, names in diagnostic_targets:
    az.plot_autocorr(idata, var_names=list(names))
    plt.suptitle(f"Autocorrelation — {label} Model", y=1.02)
    plt.show()

print("Done.")
