# 3-7-正規線形モデル

# In [None]:
# -*- coding: utf-8 -*-
# Normal linear model in Python with NumPyro + ArviZ
# Translation of the R/brms example

import pandas as pd
import numpy as np
import jax.numpy as jnp
from jax import random

import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.diagnostics import hpdi

# -----------------------------
# 1) Load the data set and print a quick overview
# -----------------------------
# The CSV path is relative, so it is resolved against the working directory.
df = pd.read_csv("3-7-1-beer-sales-4.csv")

# First three rows, rendered without the index column.
head_text = df.head(3).to_string(index=False)
print("Head (3 rows):")
print(head_text)

# Descriptive statistics over every column (include="all").
summary_text = df.describe(include="all").to_string()
print("\nSummary:")
print(summary_text)

# -----------------------------
# 2) Scatter plot of the raw data (English labels)
# -----------------------------
sns.set(style="whitegrid")
plt.figure(figsize=(7, 5))

# Sales against temperature, one colour per weather level.
raw_ax = sns.scatterplot(
    x="temperature",
    y="sales",
    hue="weather",
    data=df,
    edgecolor="white",
)
raw_ax.set_title("Beer Sales vs Temperature and Weather")
raw_ax.set_xlabel("Temperature")
raw_ax.set_ylabel("Sales")

plt.tight_layout()
plt.show()

# -------------------------------------
# 3) Design matrix in the spirit of R's model.matrix
#    sales ~ weather + temperature
# -------------------------------------
# Treat weather as categorical; the first category is the reference level
# because get_dummies(drop_first=True) drops its indicator column.
df["weather"] = df["weather"].astype("category")
weather_categories = list(df["weather"].cat.categories)
baseline_weather = weather_categories[0]
dummies = pd.get_dummies(df["weather"], prefix="weather", drop_first=True)

# Column order: intercept, temperature, then one indicator per non-reference level.
intercept_col = pd.Series(1.0, index=df.index, name="Intercept")
temperature_col = df["temperature"].rename("temperature")
X = pd.concat([intercept_col, temperature_col, dummies], axis=1)
y = df["sales"].to_numpy()

X_cols = list(X.columns)
print("\nDesign matrix columns:", X_cols)
print("Baseline weather level (reference):", baseline_weather)
print("\nDesign matrix (head):", X.head().to_string(index=False), sep="\n")

# float32 copy of the design matrix plus JAX arrays for the sampler.
X_np = X.to_numpy(dtype=np.float32)
X_jnp = jnp.array(X_np)
y_jnp = jnp.array(y)

# -----------------------------
# 4) NumPyro model
# -----------------------------
def model(X, y=None):
    """Normal linear regression with sample sites "beta", "sigma" and "sales".

    Args:
        X: Design matrix; the length of its second axis sets the number of
            regression coefficients.
        y: Observed responses for the "sales" site, or None to leave that
            site unobserved (e.g. for predictive sampling).
    """
    # One Normal(0, 10) prior per design-matrix column, treated as a single
    # vector-valued site.
    coef_prior = dist.Normal(0.0, 10.0).expand([X.shape[1]]).to_event(1)
    beta = numpyro.sample("beta", coef_prior)

    # Residual scale, constrained positive by the half-normal prior.
    sigma = numpyro.sample("sigma", dist.HalfNormal(10.0))

    # Likelihood centred on the linear predictor X . beta.
    numpyro.sample("sales", dist.Normal(jnp.dot(X, beta), sigma), obs=y)

# -----------------------------
# 5) Draw posterior samples with NUTS
# -----------------------------
rng_key = random.PRNGKey(1)  # fixed seed for the sampler
nuts = NUTS(model)

# 1000 warm-up and 2000 kept draws for each of 2 chains, no progress bar.
mcmc_settings = {
    "num_warmup": 1000,
    "num_samples": 2000,
    "num_chains": 2,
    "progress_bar": False,
}
mcmc = MCMC(nuts, **mcmc_settings)
mcmc.run(rng_key, X=X_jnp, y=y_jnp)

# -----------------------------
# 6) InferenceData for ArviZ
#    Built once from the fitted MCMC object plus posterior-predictive draws;
#    the observed responses are attached afterwards if they are missing.
# -----------------------------
posterior_samples = mcmc.get_samples()

# Use a fresh key for the predictive draws; y is omitted so "sales" is sampled.
rng_key, rng_ppc = random.split(rng_key)
post_pred = Predictive(model, posterior_samples)(rng_ppc, X=X_jnp)

# Label the coefficient axis with the design-matrix column names and the
# observation axis with a running index.
coords = {
    "coef": X_cols,
    "obs_id": np.arange(len(df)),
}
dims = {
    "beta": ["coef"],
    "sales": ["obs_id"],
}

idata = az.from_numpyro(
    mcmc,
    coords=coords,
    dims=dims,
    posterior_predictive=post_pred,
)

# Observed responses as a stand-alone InferenceData.
idata_obs = az.from_dict(
    observed_data={"sales": y},
    coords={"obs_id": coords["obs_id"]},
    dims={"sales": dims["sales"]},
)

# extend() only adds groups that `idata` does not already have, so this is a
# safeguard: from_numpyro is expected to have recorded the observed "sales"
# site itself, in which case nothing changes here.
idata.extend(idata_obs)

# -----------------------------
# 7) Print the posterior summary table
# -----------------------------
hdi_level = 0.94                   # HDI mass used for the table and the plots
reported_vars = ["beta", "sigma"]  # parameters to report

summary_df = az.summary(idata, var_names=list(reported_vars), hdi_prob=hdi_level)
print("\nMCMC summary (HDI 94%):")
print(summary_df.to_string())

# -----------------------------
# 8) Posterior distribution plots (ArviZ takes hdi_prob, not credible_interval)
# -----------------------------
az.plot_posterior(idata, var_names=list(reported_vars), hdi_prob=hdi_level)
plt.tight_layout()
plt.show()

# -----------------------------
# 9) Posterior predictive check
#    group="posterior" is passed explicitly.
# -----------------------------
az.plot_ppc(idata, group="posterior")
plt.tight_layout()
plt.show()

# ---------------------------------------------------------
# 10) "Marginal effects": regression lines by weather level
#     (uncertainty bands are computed with NumPyro's hpdi)
# ---------------------------------------------------------
# 60 evenly spaced temperatures spanning the observed range.
t_min = df["temperature"].min()
t_max = df["temperature"].max()
t_grid = np.linspace(t_min, t_max, num=60)

# Helper to build X for a chosen weather:
def build_X_for(weather_label, temps):
    """Build a design matrix for one weather level over a set of temperatures.

    The result has one row per temperature and the same columns, in the same
    order, as the training design matrix (module-level ``X_cols``).

    Args:
        weather_label: One of the levels in module-level ``weather_categories``.
        temps: Scalar or 1-D array-like of temperatures (list, ndarray or
            pandas Series).

    Returns:
        ``np.float32`` array of shape ``(len(temps), len(X_cols))``.

    Raises:
        ValueError: If ``weather_label`` is not a known weather level. Without
            this check an unknown label would be encoded with all indicator
            columns at 0, i.e. silently treated as the reference level.
    """
    if weather_label not in weather_categories:
        raise ValueError(
            f"Unknown weather level {weather_label!r}; "
            f"expected one of {list(weather_categories)}"
        )

    # Work on a plain float array so that a pandas index on `temps` cannot
    # misalign the rows when the pieces are concatenated below.
    temps = np.atleast_1d(np.asarray(temps, dtype=float))

    base = pd.DataFrame({
        "Intercept": np.ones_like(temps),
        "temperature": temps,
    })

    # Indicator columns present in the training matrix (reference level dropped).
    dummy_cols = [c for c in X_cols if c.startswith("weather_")]
    d = pd.get_dummies(
        pd.Categorical([weather_label] * len(temps), categories=weather_categories),
        prefix="weather",
        drop_first=True,
    )
    # Align to the training indicator columns; any that are missing become 0.
    d = d.reindex(columns=dummy_cols, fill_value=0)

    # Same column order as the training design matrix.
    X_new = pd.concat([base, d], axis=1)[X_cols]
    return X_new.to_numpy(dtype=np.float32)

betas = posterior_samples["beta"]  # shape: (samples, n_features)
# Convert to NumPy and transpose once; the same matrix is reused for every
# weather level in the loop below.
betas_T = np.asarray(betas).T

plt.figure(figsize=(7, 5))
sns.scatterplot(
    data=df,
    x="temperature",
    y="sales",
    hue="weather",
    alpha=0.6,
    edgecolor="white"
)

# NOTE(review): the line labels set below only show up if the legend is
# rebuilt after the loop — confirm whether that is wanted.
for w in weather_categories:
    Xg = build_X_for(w, t_grid)
    mu_draws = Xg @ betas_T  # shape: (len(t_grid), n_samples)
    # hpdi expects (n_samples, n_points)
    hdi_bounds = np.asarray(hpdi(jnp.array(mu_draws.T), prob=0.94))  # (2, n_points)
    mu_med = np.median(mu_draws, axis=1)

    (line,) = plt.plot(t_grid, mu_med, label=f"{w} (median)")
    # fill_between takes its default colour from a different cycle than plot,
    # so tie the band to its median line explicitly.
    plt.fill_between(
        t_grid,
        hdi_bounds[0],
        hdi_bounds[1],
        color=line.get_color(),
        alpha=0.2,
    )

plt.title("Fitted Regression Lines with 94% HDI")
plt.xlabel("Temperature")
plt.ylabel("Sales")
plt.tight_layout()
plt.show()

# -----------------------------
# 11) (Optional) report array shapes and the categorical encoding
# -----------------------------
shape_report = [
    "\nShapes:",
    f"X: {X_jnp.shape}, y: {y_jnp.shape}",
    f"Num categories (weather): {len(weather_categories)}; baseline: {baseline_weather}",
]
print("\n".join(shape_report))
