# 3-9-ロジスティック回帰モデル

# In [None]:
# -*- coding: utf-8 -*-
# Logistic regression with binomial outcomes (NumPyro version)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, Predictive
from numpyro import render_model

import arviz as az

# -------------------------------------------------------------------
# 0) Utility
# -------------------------------------------------------------------
def logistic(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, element-wise.

    Evaluated as ``exp(-logaddexp(0, -x))``. This is algebraically identical
    to the textbook formula, but it never exponentiates a large positive
    number. Very negative inputs therefore yield 0.0 instead of tripping a
    NumPy overflow RuntimeWarning.

    Parameters
    ----------
    x : scalar or array_like
        Log-odds (linear predictor) values.

    Returns
    -------
    numpy scalar or array
        Probabilities in [0, 1], with the same shape as ``x``.
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, -x))

# Fixed root PRNG key for reproducibility (analogous to seed=1 in the R/Stan
# original; the actual draws will not match R's).
rng_key = jax.random.PRNGKey(1)

# -------------------------------------------------------------------
# 1) Load data and quick EDA (print & plot)
# -------------------------------------------------------------------
# Path is relative to the current working directory. The code below reads the
# columns "germination", "size", "solar" and "nutrition".
df = pd.read_csv("3-9-1-germination.csv")

print("Head (3 rows):")
print(df.head(3))

print("\nSummary:")
print(df.describe(include="all"))

# encode solar to dummy: 1 = sunshine, 0 = shade
# NOTE(review): any label other than exactly "sunshine" (typos, other cases,
# NaN) silently becomes 0 — confirm the column only contains shade/sunshine.
df["solar_dummy"] = (df["solar"] == "sunshine").astype(int)

# scatter: observed proportion (germination / size) vs nutrition, colored by solar
plt.figure(figsize=(6,4))
sns.scatterplot(
    data=df,
    x="nutrition", 
    y=df["germination"] / df["size"],
    hue="solar",
    style="solar"
)
plt.title("Seed germination proportion by nutrition and solar")
plt.xlabel("Nutrition")
plt.ylabel("Germination proportion")
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# 2) NumPyro model (equivalent to Stan binomial_logit)
# -------------------------------------------------------------------
def glm_binom_model(germination, size, solar, nutrition):
    """Binomial-logit GLM for seed germination counts.

    Each coefficient (``Intercept``, ``b_solar``, ``b_nutrition``) gets an
    independent Normal(0, 10) prior. The linear predictor
    ``Intercept + b_solar * solar + b_nutrition * nutrition`` is the logit of
    a Binomial with ``size`` trials per row.

    Args:
        germination: Observed number of germinated seeds per row, or ``None``
            to let the ``"germination"`` site be sampled (used for posterior
            prediction and graph rendering in this file).
        size: Number of seeds (Binomial ``total_count``) per row.
        solar: Solar dummy per row (1 = sunshine, 0 = shade).
        nutrition: Nutrition level per row.
    """
    # Register the coefficient sites in a fixed order:
    # Intercept, b_solar, b_nutrition.
    coef = {
        site: numpyro.sample(site, dist.Normal(0.0, 10.0))
        for site in ("Intercept", "b_solar", "b_nutrition")
    }
    eta = (
        coef["Intercept"]
        + coef["b_solar"] * solar
        + coef["b_nutrition"] * nutrition
    )
    numpyro.sample(
        "germination",
        dist.Binomial(total_count=size, logits=eta),
        obs=germination,
    )

# Model inputs as JAX arrays: integer counts for the Binomial likelihood,
# float covariates for the linear predictor.
germination, size = (
    jnp.array(df[col].values, dtype=jnp.int32)
    for col in ("germination", "size")
)
solar, nutrition = (
    jnp.array(df[col].values, dtype=jnp.float32)
    for col in ("solar_dummy", "nutrition")
)

# -------------------------------------------------------------------
# 3) Fit with MCMC (NUTS)
# -------------------------------------------------------------------
# NUTS sampler: 2 chains, 1000 warmup + 2000 retained draws each.
# NOTE(review): on a single device NumPyro may run these chains sequentially
# rather than in parallel — confirm if wall-clock time matters.
nuts = NUTS(glm_binom_model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=2000, num_chains=2, progress_bar=True)
mcmc.run(rng_key, germination=germination, size=size, solar=solar, nutrition=nutrition)

# Posterior predictive draws for the PPC: passing germination=None turns the
# observed site into a sampled one. The key is derived from rng_key via
# split(), so it differs from the key MCMC used.
ppc = Predictive(glm_binom_model, posterior_samples=mcmc.get_samples())(
    jax.random.split(rng_key, 1)[0],
    germination=None, size=size, solar=solar, nutrition=nutrition
)

# Convert to ArviZ InferenceData. observed_data is not passed explicitly here.
# NOTE(review): az.from_numpyro may still build an observed_data group from the
# MCMC run's observed site, which az.plot_ppc later relies on — confirm.
idata = az.from_numpyro(
    mcmc,
    posterior_predictive=ppc  # observed_data not passed explicitly (see note above)
)

# Print the ArviZ summary table (mean, sd, HDI, ESS, r_hat) for the coefficients.
summary_df = az.summary(idata, var_names=["Intercept","b_solar","b_nutrition"])
print("\nMCMC summary (ArviZ):")
print(summary_df.to_string())

# -------------------------------------------------------------------
# 4) Interpret coefficients: predicted probs & odds / odds ratios
#    (Analogous to R's newdata_1 and fitted values)
# -------------------------------------------------------------------
# Three illustrative rows: row 1 vs 2 changes only solar (at nutrition=2),
# and row 2 vs 3 changes only nutrition (at sunshine).
newdata_1 = pd.DataFrame(
    {
        "solar": ["shade", "sunshine", "sunshine"],
        "nutrition": [2, 2, 3],
        "size": [10, 10, 10],
    }
)
newdata_1["solar_dummy"] = newdata_1["solar"].eq("sunshine").astype(int)

# Posterior means of the three coefficients.
post = mcmc.get_samples()
b0_mean, b1_mean, b2_mean = (
    float(np.asarray(post[site]).mean())
    for site in ("Intercept", "b_solar", "b_nutrition")
)

# Plug-in linear predictor and probability at the posterior-mean coefficients.
linear_fit = (
    b0_mean
    + b1_mean * newdata_1["solar_dummy"].values
    + b2_mean * newdata_1["nutrition"].values
)
fit_prob = logistic(linear_fit)

print(
    "\nNew data (for predictions):",
    newdata_1[["solar", "nutrition", "size", "solar_dummy"]],
    sep="\n",
)
print(
    "\nPredicted probability (using posterior mean coefficients):",
    fit_prob,
    sep="\n",
)

# Odds p / (1 - p) for each of the three rows.
odds = fit_prob / (1 - fit_prob)
odds_1, odds_2, odds_3 = odds
print("\nOdds for each newdata row:", odds, sep="\n")

# Model-based odds ratios between rows that differ in a single covariate.
print("\nOdds ratio: solar shade→sunshine (nutrition=2):", odds_2 / odds_1, sep="\n")
print("Odds ratio: nutrition 2→3 (solar=sunshine):", odds_3 / odds_2, sep="\n")

# With a logit link, those ratios equal exp(coefficient); show both for comparison.
print("\nexp(b_solar) from posterior mean:", np.exp(b1_mean), sep="\n")
print(
    "exp(b_nutrition) from posterior mean (1-unit increase):",
    np.exp(b2_mean),
    sep="\n",
)

# -------------------------------------------------------------------
# 5) Marginal effects: regression curves with 95% HDI
#    (probability vs nutrition, for solar=shade/sunshine)
# -------------------------------------------------------------------
# grid over nutrition
# 80 evenly spaced nutrition values spanning the observed range.
x_grid = np.linspace(*df["nutrition"].agg(["min", "max"]), num=80)

# Per-draw coefficient vectors, each of shape (draws,).
b0, b1, b2 = (
    np.asarray(post[site]) for site in ("Intercept", "b_solar", "b_nutrition")
)

def prob_grid(solar_val, hdi_prob=0.95):
    """Posterior mean and HDI of the germination probability along ``x_grid``.

    Reads the module-level per-draw coefficients ``b0``, ``b1``, ``b2`` and
    the nutrition grid ``x_grid``.

    Parameters
    ----------
    solar_val : float
        Value of the solar dummy (0.0 = shade, 1.0 = sunshine).
    hdi_prob : float, optional
        Probability mass of the highest-density interval (default 0.95).

    Returns
    -------
    tuple of numpy.ndarray
        ``(mean, lower, upper)``, each of length ``len(x_grid)``.
    """
    # Linear predictor per draw and grid point: shape (draws, len(x_grid)).
    lp = b0[:, None] + b1[:, None] * solar_val + b2[:, None] * x_grid[None, :]
    p = logistic(lp)
    mean = p.mean(axis=0)
    # az.hdi reads a bare 2-D ndarray as (draw, shape) only with a
    # FutureWarning that the layout will become (chain, draw). Adding an
    # explicit single-chain axis makes the layout (chain=1, draw, grid)
    # unambiguous, so the result has shape (len(x_grid), 2).
    hdi = az.hdi(p[np.newaxis, ...], hdi_prob=hdi_prob)
    return mean, hdi[:, 0], hdi[:, 1]

# Posterior mean curve and 95% HDI band of P(germination) over x_grid,
# one per solar condition.
mean0, low0, high0 = prob_grid(0.0)  # shade
mean1, low1, high1 = prob_grid(1.0)  # sunshine

plt.figure(figsize=(7,5))
# raw data points (observed proportion germination / size)
sns.scatterplot(
    data=df,
    x="nutrition",
    y=df["germination"]/df["size"],
    hue="solar",
    style="solar",
    alpha=0.6
)

# NOTE(review): curve/band colors come from matplotlib's color cycle, not the
# seaborn hue palette above, so they may not match the point colors — confirm.
# The "95% HDI" labels assume prob_grid was called with a 95% HDI.

# shade curve
plt.plot(x_grid, mean0, label="shade (mean)", linewidth=2)
plt.fill_between(x_grid, low0, high0, alpha=0.2, label="shade 95% HDI")

# sunshine curve
plt.plot(x_grid, mean1, label="sunshine (mean)", linewidth=2)
plt.fill_between(x_grid, low1, high1, alpha=0.2, label="sunshine 95% HDI")

plt.title("Logistic regression fit with 95% HDI")
plt.xlabel("Nutrition")
plt.ylabel("Germination probability")
plt.legend(title="Legend")
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# 6) Posterior visualization with ArviZ (respecting argument rules)
# -------------------------------------------------------------------
# Posterior distributions of coefficients
az.plot_posterior(idata, var_names=["Intercept","b_solar","b_nutrition"], hdi_prob=0.95)
plt.suptitle("Posterior of coefficients (95% HDI)", y=1.02)
plt.tight_layout()
plt.show()

# Forest plot (do NOT use group=...)
az.plot_forest(idata, var_names=["Intercept","b_solar","b_nutrition"])
plt.title("Forest plot of coefficients")
plt.tight_layout()
plt.show()

# Posterior predictive checks (explicitly set group='posterior')
az.plot_ppc(idata, group="posterior")
plt.suptitle("Posterior predictive check", y=1.02)
plt.tight_layout()
plt.show()

# -------------------------------------------------------------------
# 7) Model visualization (NumPyro built-in)
# -------------------------------------------------------------------
# This uses numpyro.render_model to visualize the model graph.
# It requires graphviz installed in your environment to display nicely.
# Build the model graph and keep it in `model_graph` rather than discarding it.
# It is not shown automatically. To view it, evaluate `model_graph` as the last
# expression of a notebook cell, or save it with graphviz, e.g.
# `model_graph.render("glm_binom_model", format="png")`.
# Rendering is best-effort: any failure (e.g. graphviz missing) is reported and
# the script continues with model_graph = None.
model_graph = None
try:
    model_graph = render_model(
        glm_binom_model,
        model_args=(),
        model_kwargs=dict(germination=None, size=size, solar=solar, nutrition=nutrition),
        render_params=True,
        render_distributions=True,
    )
except Exception as e:
    print("Model graph rendering skipped:", e)

# -------------------------------------------------------------------
# 8) (Optional) Bernoulli case sketch when total_count==1 always
# -------------------------------------------------------------------
# def bernoulli_model(y01, x_solar, x_nutrition):
#     Intercept = numpyro.sample("Intercept", dist.Normal(0., 10.))
#     b_solar   = numpyro.sample("b_solar",   dist.Normal(0., 10.))
#     b_nutri   = numpyro.sample("b_nutrition", dist.Normal(0., 10.))
#     logits = Intercept + b_solar * x_solar + b_nutri * x_nutrition
#     numpyro.sample("y", dist.Bernoulli(logits=logits), obs=y01)
