# 3-2-単回帰モデル

In [None]:
# 3-2-単回帰モデル (R+Stan -> Python+NumPyro/ArviZ)

# === Imports ===
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import jax
import jax.numpy as jnp
from jax import random
from numpyro import distributions as dist
from numpyro import sample
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro import render_model

# ---- Utility ---------------------------------------------------------------
# Notebook-wide sampler settings.
# NOTE(review): run_mcmc below carries its own literal defaults (4 / 1000 / 1000);
# NUM_CHAINS, NUM_WARMUP and NUM_SAMPLES are not referenced by it in the visible
# code -- confirm whether they are meant to be passed in by callers.
NUM_CHAINS = 4
NUM_WARMUP = 800
NUM_SAMPLES = 1200
SEED = 1
rng_key = random.PRNGKey(SEED)


def run_mcmc(model, num_chains=4, num_warmup=1000, num_samples=1000, thinning=1, seed=SEED, **model_kwargs):
    """Run NUTS on ``model`` and return the MCMC object after sampling.

    Args:
        model: NumPyro model callable handed to the NUTS kernel.
        num_chains: Number of chains to run.
        num_warmup: Warm-up iterations per chain.
        num_samples: Sampling iterations per chain.
        thinning: Thinning value forwarded to ``MCMC``.
        seed: Integer used to build the PRNG key for ``MCMC.run``.
        **model_kwargs: Forwarded unchanged to ``MCMC.run`` (i.e. to the model).

    Returns:
        The ``MCMC`` instance on which ``run`` has been called.
    """
    # Decide up front: run chains in parallel only when there is at least one
    # local device per chain; otherwise run them sequentially.
    if jax.local_device_count() >= num_chains:
        chain_method = "parallel"
    else:
        chain_method = "sequential"

    sampler = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        thinning=thinning,
        chain_method=chain_method,
        progress_bar=True,
    )
    sampler.run(random.PRNGKey(seed), **model_kwargs)
    return sampler

## データを読み込む

In [None]:
# === Data loading (pandas) ===
# Read the beer-sales CSV and show the first three rows.
df = pd.read_csv("3-2-1-beer-sales-2.csv")
print("Head (n=3):")
print(df.head(3))

# Number of rows; passed to the models as N further down.
sample_size = df.shape[0]
print(f"Sample size: {sample_size}")

# Convert the two columns used by the models to JAX arrays.
temperature, sales = (
    jnp.array(df[column].values) for column in ("temperature", "sales")
)

In [None]:
# === Visualization of raw data (seaborn/matplotlib, labels in English) ===
# Scatter plot of sales against temperature on a fresh figure.
plt.figure()
ax = sns.scatterplot(x=df["temperature"], y=df["sales"])
ax.set_title("Beer Sales vs. Temperature")
ax.set_xlabel("Temperature")
ax.set_ylabel("Sales")
plt.tight_layout()
plt.show()

## モデルを定義する

In [None]:
# === NumPyro models ===
# Non-vectorized version (loop) ~ Stan: 3-2-1-simple-lm
def model_loop(sales, temperature, N):
    """Simple linear regression with one likelihood site per observation.

    Sample sites: ``Intercept`` and ``beta`` ~ Normal(0, 10),
    ``sigma`` ~ HalfCauchy(5), and ``sales_0`` ... ``sales_{N-1}``, each
    Normal(Intercept + beta * temperature[i], sigma) observed at ``sales[i]``.

    Args:
        sales: Indexable observations; element ``i`` is the obs of ``sales_{i}``.
        temperature: Indexable predictor values.
        N: Number of observations to loop over.
    """
    intercept = sample("Intercept", dist.Normal(0.0, 10.0))
    slope = sample("beta", dist.Normal(0.0, 10.0))
    noise_scale = sample("sigma", dist.HalfCauchy(5.0))

    # Deliberately non-vectorized: one observed site per data point.
    for idx in range(N):
        expected_sales = intercept + slope * temperature[idx]
        likelihood = dist.Normal(expected_sales, noise_scale)
        sample(f"sales_{idx}", likelihood, obs=sales[idx])

# Vectorized version ~ Stan: 3-2-2-simple-lm-vec
def model_vec(sales=None, temperature=None, N=None):
    """Simple linear regression with a single vectorized likelihood site.

    Sample sites: ``Intercept`` and ``beta`` ~ Normal(0, 10),
    ``sigma`` ~ HalfCauchy(5), and ``sales`` ~
    Normal(Intercept + beta * temperature, sigma) with ``obs=sales``.

    Args:
        sales: Passed as ``obs`` of the ``sales`` site; this file also calls
            the model with ``sales=None`` when drawing predictive samples.
        temperature: Predictor values. Defaults to ``None`` in the signature,
            but the body multiplies it by ``beta``, so it must be supplied.
        N: Accepted for signature parity with ``model_loop``; not used in the body.
    """
    intercept = sample("Intercept", dist.Normal(0.0, 10.0))
    slope = sample("beta", dist.Normal(0.0, 10.0))
    noise_scale = sample("sigma", dist.HalfCauchy(5.0))

    likelihood = dist.Normal(intercept + slope * temperature, noise_scale)
    sample("sales", likelihood, obs=sales)

In [None]:
# === Visualize the Bayesian statistical model structure (NumPyro built-in) ===
# Plate/graph of the vectorized model.
# Build the graph exactly once, outside any try block, so a rendering failure
# surfaces directly instead of being swallowed and then re-raised from a second
# render attempt.
model_graph = render_model(
    model_vec,
    model_kwargs={"sales": sales, "temperature": temperature, "N": sample_size},
)

# Displaying is optional: only the IPython import is guarded, so nothing else
# is silently ignored.
try:
    from IPython.display import display
except ImportError:
    # IPython is unavailable: the graph is built but not displayed.
    pass
else:
    display(model_graph)

## モデルのフィッティング(loop)

In [None]:
# === MCMC: non-vectorized (reference) ===
# Sample the loop model through the shared run_mcmc helper (helper defaults).
loop_inputs = {"sales": sales, "temperature": temperature, "N": sample_size}
mcmc_loop = run_mcmc(model_loop, **loop_inputs)

# Convert to InferenceData; only the MCMC object is passed (no observed_data).
idata_loop = az.from_numpyro(mcmc_loop)

# Summarize the three parameters with a 95% HDI (columns hdi_2.5% / hdi_97.5%).
summary_loop = az.summary(
    idata_loop,
    var_names=["Intercept", "beta", "sigma"],
    hdi_prob=0.95,
)
loop_report_columns = ["mean", "sd", "hdi_2.5%", "hdi_97.5%"]
print("\nNon-vectorized model summary (95% HDI):")
print(summary_loop[loop_report_columns].to_string())

## モデルのフィッティング(vec)

In [None]:
# === MCMC: vectorized (main) ===
# Sample the vectorized model through the shared run_mcmc helper (helper defaults).
vec_inputs = {"sales": sales, "temperature": temperature, "N": sample_size}
mcmc_vec = run_mcmc(model_vec, **vec_inputs)

# Convert the finished run to ArviZ InferenceData.
idata_vec = az.from_numpyro(mcmc_vec)

# Summarize the three parameters with a 95% HDI (columns hdi_2.5% / hdi_97.5%).
summary_vec = az.summary(
    idata_vec,
    var_names=["Intercept", "beta", "sigma"],
    hdi_prob=0.95,
)
vec_report_columns = ["mean", "sd", "hdi_2.5%", "hdi_97.5%"]
print("\nVectorized model summary (95% HDI):")
print(summary_vec[vec_report_columns].to_string())

## パラメータの事後分布(vec)

In [None]:
# === Posterior visualization (ArviZ) ===
# Posterior plots for the three parameters of the vectorized model, with a
# 95% HDI (the keyword is hdi_prob, not credible_interval).
posterior_params = ["Intercept", "beta", "sigma"]
az.plot_posterior(idata_vec, var_names=posterior_params, hdi_prob=0.95)
plt.suptitle("Posterior Distributions (Vectorized Model)", y=1.02)
plt.tight_layout()
plt.show()

## 事後予測分布からサンプリングする

In [None]:
# === Posterior Predictive Check (PPC) ===
# Split off a dedicated key for the predictive draw and advance the shared key.
rng_key, pred_key = random.split(rng_key)

# Fetch the posterior samples grouped by chain so the leading chain axis is kept.
posterior_samples = mcmc_vec.get_samples(group_by_chain=True)

# batch_ndims=2 matches the grouped samples: it is the number of leading batch
# axes (chain and draw), not the number of chains.
predictive = Predictive(
    model_vec,
    posterior_samples=posterior_samples,
    batch_ndims=2,
)

# sales=None leaves the "sales" site unobserved so it is sampled.
ppc_samples = predictive(pred_key, sales=None, temperature=temperature, N=sample_size)

## サンプリング結果を結合する

In [None]:
# Build a second InferenceData from the vectorized run, this time with the
# posterior predictive draws attached (observed_data is still not passed).
idata_vec_ppc = az.from_numpyro(
    mcmc_vec,
    posterior_predictive=ppc_samples,
)

## 事後予測チェック(PPC)をプロットする

In [None]:
# Plot the PPC with group="posterior" set explicitly.
# The default legend location is pinned to the upper right for this plot only;
# the rc override is scoped to the with-block.
ppc_legend_rc = {"legend.loc": "upper right"}
with plt.rc_context(ppc_legend_rc):
    axes = az.plot_ppc(idata_vec_ppc, group="posterior")
    plt.suptitle("Posterior Predictive Check", y=1.02)
    plt.tight_layout()
    plt.show()

## 回帰直線をプロットする

In [None]:
# === Optional: data + fitted regression line (using posterior means) ===
# Point estimates come from the "mean" column of the vectorized-model summary.
intercept_mean, beta_mean = (
    float(summary_vec.loc[param, "mean"]) for param in ("Intercept", "beta")
)

# Evaluate the line on 100 evenly spaced points across the observed temperatures.
temp_low = df["temperature"].min()
temp_high = df["temperature"].max()
xline = np.linspace(temp_low, temp_high, 100)
yline = intercept_mean + beta_mean * xline

# Raw data with the fitted line drawn on the same axes.
plt.figure()
ax = sns.scatterplot(x=df["temperature"], y=df["sales"])
ax.plot(xline, yline)
ax.set_title("Beer Sales vs. Temperature with Fitted Line")
ax.set_xlabel("Temperature")
ax.set_ylabel("Sales")
plt.tight_layout()
plt.show()