# 3-3-モデルを用いた予測

# In [None]:
# -*- coding: utf-8 -*-
# 3-3: Prediction using a simple linear model (Python / NumPyro version)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# ------------------------------------------------------------
# 1) Load data (use pandas) & basic info
# ------------------------------------------------------------
# Read the data set with pandas and echo the first rows so the two columns
# used below ("temperature" and "sales") can be checked by eye.
df = pd.read_csv("3-2-1-beer-sales-2.csv")
print("Head of data:")
print(df.head())

sample_size = df.shape[0]
print(f"Sample size: {sample_size}")

# The model works on JAX arrays, so convert the two columns once here.
temperature, sales = (
    jnp.array(df[column].to_numpy()) for column in ("temperature", "sales")
)

# Temperatures (°C) at which sales are predicted: the integers 11, 12, ..., 30.
temperature_pred = jnp.arange(11, 31)
N_pred = len(temperature_pred)
print("Prediction temperatures (°C):")
print(np.array(temperature_pred))

# ------------------------------------------------------------
# 2) NumPyro model (Bayesian estimation with weakly-informative priors)
# ------------------------------------------------------------
def simple_lm(temperature, sales=None, *, prior_scale=100.0):
    """Simple linear regression of sales on temperature.

    sales ~ Normal(Intercept + beta * temperature, sigma)

    Args:
        temperature: Explanatory variable, one value per observation.
        sales: Observed sales. When ``None``, the ``"sales"`` site is sampled
            instead of conditioned on. This is how ``Predictive`` produces
            posterior-predictive draws.
        prior_scale: Scale of the Normal(0, scale) priors on ``Intercept``
            and ``beta`` and of the HalfNormal(scale) prior on ``sigma``.

    Sample sites are ``Intercept``, ``beta``, ``sigma`` and ``sales``.
    ``mu`` (the expected sales) is recorded as a deterministic site.
    """
    # The intent is a weakly-informative prior that behaves like the flat
    # priors of the original Stan model. A scale of 10 (the previous value)
    # is informative for this response. NOTE(review): sales are presumably in
    # the tens to ~100 range; verify with df["sales"].describe(). At that
    # scale, Normal(0, 10) on the intercept and HalfNormal(10) on sigma
    # noticeably pull the estimates toward 0. A scale of 100 keeps the priors
    # proper while letting the likelihood dominate.
    Intercept = numpyro.sample("Intercept", dist.Normal(0.0, prior_scale))
    beta = numpyro.sample("beta", dist.Normal(0.0, prior_scale))
    sigma = numpyro.sample("sigma", dist.HalfNormal(prior_scale))

    mu = Intercept + beta * temperature
    numpyro.deterministic("mu", mu)
    numpyro.sample("sales", dist.Normal(mu, sigma), obs=sales)

# (Optional) Draw the model graph with NumPyro's built-in render_model.
# This step is best-effort. If graphviz is unavailable or rendering fails for
# any other reason, the error is printed and the script continues.
try:
    import graphviz  # availability check; render_model relies on graphviz
    graph = numpyro.render_model(
        simple_lm,
        model_args=(temperature,),
        model_kwargs={"sales": sales},
    )
    graph.render("simple_lm_graph", format="png", cleanup=True)
except Exception as e:
    print(f"Model graph not rendered: {e}")
else:
    print("Saved model graph to simple_lm_graph.png")

# ------------------------------------------------------------
# 3) MCMC (NUTS)
# ------------------------------------------------------------
# 4 chains, each with 1000 warm-up and 2000 retained iterations, seeded with
# PRNGKey(1).
# NOTE: on a single CPU device NumPyro runs the chains one after another and
# warns. Calling numpyro.set_host_device_count(4) before JAX starts lets the
# chains run in parallel.
rng_key = random.PRNGKey(1)
nuts = NUTS(simple_lm)
mcmc = MCMC(
    nuts,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
    progress_bar=True,
)
mcmc.run(rng_key, temperature=temperature, sales=sales)

# Summary table from NumPyro itself.
mcmc.print_summary()

# The same parameters summarised by ArviZ. The MCMC object is passed as the
# `posterior` argument. No observed_data argument is given here.
idata_post = az.from_numpyro(posterior=mcmc)
print(
    az.summary(idata_post, var_names=["Intercept", "beta", "sigma"]).to_string()
)

# ------------------------------------------------------------
# 4) Posterior predictive for the observed data (PPC)
#    The InferenceData is built with az.from_dict; the observed_data argument
#    of az.from_numpyro is not used.
# ------------------------------------------------------------
ppc = Predictive(simple_lm, posterior_samples=mcmc.get_samples())(
    random.PRNGKey(2), temperature=temperature
)

# mcmc.get_samples() merges all chains (group_by_chain defaults to False).
# Each array in `ppc` is therefore (draws, obs). ArviZ expects
# (chain, draw, ...), so a chain axis of length 1 is prepended.
idata_ppc = az.from_dict(
    posterior_predictive={"sales": np.array(ppc["sales"])[np.newaxis]},
    observed_data={"sales": np.array(sales)},
    coords={"obs_index": np.arange(sample_size)},
    dims={"sales": ["obs_index"], "mu": ["obs_index"]},
)

# ------------------------------------------------------------
# 5) Predictions for new temperatures (11..30°C)
#    Generates mu_pred and sales_pred.
# ------------------------------------------------------------
def pred_model(temperature_pred):
    """Prediction-only version of the linear model for new temperatures.

    Intended to be run through ``Predictive`` with ``posterior_samples``. The
    ``Intercept``, ``beta`` and ``sigma`` sites are then substituted by
    posterior draws, so the priors below only declare those sites.

    Sites:
        mu_pred: deterministic expected sales, Intercept + beta * temperature_pred.
        sales_pred: sales drawn from Normal(mu_pred, sigma), i.e. including noise.
    """
    # Placeholder priors (replaced by the posterior samples).
    intercept = numpyro.sample("Intercept", dist.Normal(0.0, 10.0))
    slope = numpyro.sample("beta", dist.Normal(0.0, 10.0))
    noise_sd = numpyro.sample("sigma", dist.HalfNormal(10.0))

    expected = intercept + slope * temperature_pred
    numpyro.deterministic("mu_pred", expected)
    numpyro.sample("sales_pred", dist.Normal(expected, noise_sd))

# Draw mu_pred and sales_pred at every new temperature, once per posterior sample.
pred_new = Predictive(pred_model, posterior_samples=mcmc.get_samples())(
    random.PRNGKey(3), temperature_pred=temperature_pred
)

# Store the draws as the "predictions" group of an InferenceData built with
# az.from_dict. Each array is (draws, pred_index); a length-1 chain axis is
# prepended to match ArviZ's (chain, draw, pred_index) layout.
#   mu_pred    : posterior of the expected sales
#   sales_pred : predicted sales including observation noise
idata_pred = az.from_dict(
    predictions={
        site: np.array(pred_new[site])[np.newaxis]
        for site in ("mu_pred", "sales_pred")
    },
    coords={"pred_index": np.arange(N_pred)},
    dims={site: ["pred_index"] for site in ("mu_pred", "sales_pred")},
)

# ------------------------------------------------------------
# 6) Visualization
#    - matplotlib / ArviZ
#    - labels in English
#    - az.plot_posterior uses hdi_prob (not credible_interval)
#    - az.plot_ppc uses group="posterior"
#    - az.from_numpyro(observed_data=...) is not used
#    - az.plot_density(kind=...) is not used
# ------------------------------------------------------------

# 6-1) Raw data scatter + posterior mean line + 95% prediction interval
sales_pred_samples = np.array(pred_new["sales_pred"])  # (draws, pred_index)
mu_pred_samples = np.array(pred_new["mu_pred"])        # (draws, pred_index)
mean_mu = np.mean(mu_pred_samples, axis=0)
# Central 95% interval of the noisy predictions at each temperature.
lower_pi, upper_pi = np.percentile(sales_pred_samples, [2.5, 97.5], axis=0)

fig, ax = plt.subplots()
ax.scatter(
    np.array(temperature), np.array(sales), alpha=0.6, label="Observed sales"
)
ax.plot(
    np.array(temperature_pred), mean_mu, label="Posterior mean of expected sales"
)
ax.fill_between(
    np.array(temperature_pred),
    lower_pi,
    upper_pi,
    alpha=0.3,
    label="95% prediction interval",
)
ax.set_xlabel("Temperature (°C)")
ax.set_ylabel("Sales")
ax.set_title("Beer Sales vs Temperature with Posterior Predictions")
ax.legend()
plt.show()

# 6-2) Posterior predictive check on the observed data. The observed sales
#      are drawn against draws from the posterior predictive distribution
#      (group="posterior", as required above).
az.plot_ppc(data=idata_ppc, group="posterior")
plt.title("Posterior Predictive Check (Observed vs Posterior Predictive)")
plt.show()

# 6-3) 95% intervals of the predicted sales at each temperature from 11 to
#      30°C (counterpart of bayesplot::mcmc_intervals). The predictions
#      group is passed as a plain xarray Dataset, so plot_forest plots these
#      variables rather than the posterior group of an InferenceData.
az.plot_forest(
    data=idata_pred.predictions,
    var_names=["sales_pred"],
    hdi_prob=0.95,
    combined=True,
)
plt.title("95% intervals for predicted sales by temperature (11–30°C)")
plt.xlabel("Predicted sales")
plt.show()

# 6-4) 95% intervals of mu_pred[1] vs sales_pred[1] (the R comparison).
#      Python indexes from 0, so pred_index=0 is the first grid point, 11°C.
#      As in 6-3, the predictions Dataset is passed directly.
az.plot_forest(
    data=idata_pred.predictions,
    var_names=["mu_pred", "sales_pred"],
    coords={"pred_index": [0]},
    hdi_prob=0.95,
    combined=True,
)
plt.title("95% interval: expected vs predicted sales at 11°C")
plt.xlabel("Value")
plt.show()

# 6-5) Posterior of the predicted sales at both ends of the grid:
#      pred_index 0 (11°C) and pred_index N_pred - 1 (30°C), with 99% HDIs.
az.plot_posterior(
    data=idata_pred,
    group="predictions",
    var_names=["sales_pred"],
    coords={"pred_index": [0, N_pred - 1]},
    hdi_prob=0.99,
)
plt.suptitle("Posterior of predicted sales at 11°C and 30°C", y=1.02)
plt.show()

# ------------------------------------------------------------
# 7) Print a few numerical results
# ------------------------------------------------------------
print("Posterior mean of Intercept, beta, sigma:")
# az.extract stacks chains and draws into a single "sample" dimension.
post = az.extract(idata_post, group="posterior")
for p in ("Intercept", "beta", "sigma"):
    vals = np.asarray(post[p]).reshape(-1)
    print(f"{p}: mean={np.mean(vals):.3f}, sd={np.std(vals, ddof=1):.3f}")

print("First 5 predictions for 11–30°C (posterior mean of expected sales):")
for t, m in zip(np.array(temperature_pred)[:5], mean_mu[:5]):
    print(f"T={t}°C -> E[sales] ≈ {m:.2f}")
