# 3-10-交互作用

In [None]:
# -*- coding: utf-8 -*-
"""
Interactions with NumPyro: (1) categorical x categorical, (2) categorical x continuous, (3) continuous x continuous
- CSV loading: pandas
- Bayesian estimation: NumPyro (NUTS/MCMC)
- Model visualization: numpyro.contrib.render.render_model (built-in)
- Posterior visualization: ArviZ (plot_posterior with hdi_prob, plot_ppc with group="posterior", plot_forest without group)
- Data viz: matplotlib / seaborn / ArviZ
- All printed results via print()
- Labels are in English
- Prohibited usages avoided: credible_interval, az.from_numpyro(observed_data=...), az.plot_density(kind=...), az.plot_forest(group=...)
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
from numpyro import sample, plate
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import arviz as az
import xarray as xr

from numpyro.contrib.render import render_model  # NumPyro built-in model visualization

# ---------------------------
# Utility helpers
# ---------------------------

# Fixed seed so every run of the script starts from the same random state.
SEED = 1
# Root JAX PRNG key; the script derives per-step subkeys from it with jax.random.split.
rng_key = jax.random.PRNGKey(seed=SEED)

def print_header(title):
    """Print ``title`` between two 80-character ``=`` rules, preceded by a blank line."""
    rule = "=" * 80
    # One print call with a newline separator emits the same three lines.
    print("\n" + rule, title, rule, sep="\n")

def run_mcmc(model_fn, rng_key, num_warmup=1000, num_samples=1000, num_chains=4, **model_kwargs):
    """
    Fit ``model_fn`` with a NUTS kernel and return the MCMC object after sampling.

    model_fn: NumPyro model callable handed to the NUTS kernel
    rng_key: JAX PRNG key passed to MCMC.run
    num_warmup, num_samples, num_chains: forwarded unchanged to MCMC
    model_kwargs: keyword arguments forwarded to the model through MCMC.run

    The sampler's summary table is printed before the object is returned.
    """
    sampler = MCMC(
        NUTS(model_fn),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=False,
    )
    sampler.run(rng_key, **model_kwargs)
    sampler.print_summary()
    return sampler

#def idata_from_mcmc_and_ppc(mcmc, model_fn, rng_key, obs_name, observed_array, **predict_kwargs):
#    """
#    Create ArviZ InferenceData WITHOUT using observed_data argument in az.from_numpyro.
#    Then add observed_data manually so az.plot_ppc can work.
#    """
#    # Posterior predictive
#    predictive = Predictive(model_fn, posterior_samples=mcmc.get_samples(), return_sites=[obs_name])
#    pp = predictive(rng_key, **predict_kwargs)
#    idata = az.from_numpyro(mcmc, posterior_predictive={obs_name: pp[obs_name]})
#    # Manually attach observed_data (without using observed_data=...)
#    ds_obs = xr.Dataset({obs_name: (("obs_dim",), np.asarray(observed_array))})
#    idata.add_groups({"observed_data": ds_obs})
#    return idata

def idata_from_mcmc_and_ppc(mcmc, model_fn, rng_key, obs_name, **predict_kwargs):
    """
    Build an ArviZ InferenceData from a fitted MCMC run plus posterior predictive draws.

    mcmc: fitted MCMC object; its samples drive the posterior predictive
    model_fn: the model callable that was fitted
    rng_key: JAX PRNG key used for the posterior predictive draws
    obs_name: name of the observed sample site to return and store
    predict_kwargs: keyword arguments forwarded to the model when predicting
    """
    posterior_draws = mcmc.get_samples()
    ppc_sampler = Predictive(model_fn, posterior_samples=posterior_draws, return_sites=[obs_name])
    ppc_draws = ppc_sampler(rng_key, **predict_kwargs)[obs_name]
    # observed_data is attached automatically by az.from_numpyro (because the
    # MCMC object is passed as the posterior), so it is not supplied here.
    return az.from_numpyro(mcmc, posterior_predictive={obs_name: ppc_draws})


def render_and_save_model(model_fn, filename_base, **model_kwargs):
    """
    Draw the model graph with NumPyro's render_model and save it as "<filename_base>.gv".

    model_fn: model callable to visualize
    filename_base: output path without the ".gv" extension
    model_kwargs: keyword arguments the model is rendered with

    Best effort by design: any exception raised while rendering or saving is
    caught and reported with print(), so a failure here never stops the caller.
    """
    try:
        graph = render_model(model_fn, model_kwargs=model_kwargs, render_distributions=True)
        gv_path = "{}.gv".format(filename_base)
        graph.save(gv_path)
        print("Model graph saved (Graphviz .gv): {}".format(gv_path))
    except Exception as err:
        print("Model rendering skipped: {}".format(err))

def summarize_predictions(name, pred_array, hdi_prob=0.95):
    """
    Print the posterior predictive mean and HDI for each prediction row.

    name: label used in the printed header
    pred_array: array-like of shape (draws, N), one column per prediction row
    hdi_prob: probability mass of the highest-density interval (default 0.95)
    """
    draws = np.asarray(pred_array)
    mean = draws.mean(axis=0)
    # Pass an explicit leading chain axis. ArviZ reads 3-D arrays as
    # (chain, draw, N); a bare 2-D array only works through a deprecated
    # (draw, N) interpretation that emits a FutureWarning and is slated to change.
    hdi = az.hdi(draws[np.newaxis, ...], hdi_prob=hdi_prob)
    # Format the percentage instead of truncating it: int(0.29 * 100) is 28.
    hdi_label = f"{hdi_prob * 100:g}% HDI"
    print_header(f"{name} | Posterior Predictive Summary (mean and HDI)")
    for i, (m, (lo, hi)) in enumerate(zip(mean, hdi)):
        print(f"Row {i}: mean={m:.2f}, {hdi_label}=({lo:.2f}, {hi:.2f})")

# ---------------------------
# 1) categorical x categorical
# ---------------------------

def model_cat_cat(publicity, bargen, sales=None):
    """
    Normal linear model for sales with two predictors and their product term.

        mu = alpha + beta_publicity*publicity + beta_bargen*bargen
             + beta_interaction*(publicity*bargen)

    publicity, bargen: predictor arrays (the script passes 0/1 codes);
        the observation plate is sized from publicity.shape[0]
    sales: observed outcomes, or None to draw them from the model
    """
    n_obs = publicity.shape[0]

    # Priors. NOTE(review): Normal(0, 10) is tight if sales is on a much
    # larger scale -- confirm against the data.
    intercept = sample("alpha", dist.Normal(0, 10))
    b_publicity = sample("beta_publicity", dist.Normal(0, 10))
    b_bargen = sample("beta_bargen", dist.Normal(0, 10))
    b_interaction = sample("beta_interaction", dist.Normal(0, 10))
    noise_sd = sample("sigma", dist.HalfNormal(10))

    # Linear predictor; terms are added in the same left-to-right order as before.
    expected_sales = (
        intercept
        + b_publicity * publicity
        + b_bargen * bargen
        + b_interaction * (publicity * bargen)
    )

    with plate("obs", n_obs):
        sample("sales", dist.Normal(expected_sales, noise_sd), obs=sales)

# ---------------------------
# 2) categorical x continuous
# ---------------------------

def model_cat_cont(publicity, temperature, sales=None):
    """
    Normal linear model for sales with a coded factor, a numeric predictor and their product.

        mu = alpha + beta_publicity*publicity + beta_temperature*temperature
             + beta_interaction*(publicity*temperature)

    publicity: predictor array (the script passes 0/1 codes);
        the observation plate is sized from publicity.shape[0]
    temperature: numeric predictor array
    sales: observed outcomes, or None to draw them from the model
    """
    n_obs = publicity.shape[0]

    # Priors. NOTE(review): Normal(0, 10) is tight if sales is on a much
    # larger scale -- confirm against the data.
    intercept = sample("alpha", dist.Normal(0, 10))
    b_publicity = sample("beta_publicity", dist.Normal(0, 10))
    b_temperature = sample("beta_temperature", dist.Normal(0, 10))
    b_interaction = sample("beta_interaction", dist.Normal(0, 10))
    noise_sd = sample("sigma", dist.HalfNormal(10))

    # Linear predictor; terms are added in the same left-to-right order as before.
    expected_sales = (
        intercept
        + b_publicity * publicity
        + b_temperature * temperature
        + b_interaction * (publicity * temperature)
    )

    with plate("obs", n_obs):
        sample("sales", dist.Normal(expected_sales, noise_sd), obs=sales)

# ---------------------------
# 3) continuous x continuous
# ---------------------------

def model_cont_cont(product, clerk, sales=None):
    """
    Normal linear model for sales with two numeric predictors and their product term.

        mu = alpha + beta_product*product + beta_clerk*clerk
             + beta_interaction*(product*clerk)

    product, clerk: numeric predictor arrays;
        the observation plate is sized from product.shape[0]
    sales: observed outcomes, or None to draw them from the model
    """
    n_obs = product.shape[0]

    # Priors. NOTE(review): Normal(0, 10) is tight if sales is on a much
    # larger scale -- confirm against the data.
    intercept = sample("alpha", dist.Normal(0, 10))
    b_product = sample("beta_product", dist.Normal(0, 10))
    b_clerk = sample("beta_clerk", dist.Normal(0, 10))
    b_interaction = sample("beta_interaction", dist.Normal(0, 10))
    noise_sd = sample("sigma", dist.HalfNormal(10))

    # Linear predictor; terms are added in the same left-to-right order as before.
    expected_sales = (
        intercept
        + b_product * product
        + b_clerk * clerk
        + b_interaction * (product * clerk)
    )

    with plate("obs", n_obs):
        sample("sales", dist.Normal(expected_sales, noise_sd), obs=sales)

# ---------------------------
# Main analysis
# ---------------------------

if __name__ == "__main__":
    sns.set(style="whitegrid")

    # =============================================================================
    # Categorical x Categorical
    # =============================================================================
    print_header("Categorical x Categorical: Load & Summarize")
    interaction_1 = pd.read_csv("3-10-1-interaction-1.csv")
    print("Head (3 rows):")
    print(interaction_1.head(3))
    print("\nSummary:")
    print(interaction_1.describe(include='all'))

    # Design matrix (like R's model.matrix). Printed for inspection only; the
    # model is fitted on the hand-coded 0/1 columns built below.
    import patsy as pt
    mm1 = pt.dmatrix("C(publicity) * C(bargen)", interaction_1, return_type="dataframe")
    print_header("Design Matrix (C(publicity) * C(bargen))")
    print(mm1.head())

    # Prepare data: encode to {0,1} with baseline "not"
    pub_map = {"not": 0, "to_implement": 1}
    bar_map = {"not": 0, "to_implement": 1}
    y1 = interaction_1["sales"].to_numpy()
    x_pub1 = interaction_1["publicity"].map(pub_map).to_numpy()
    x_bar1 = interaction_1["bargen"].map(bar_map).to_numpy()
    # Convert once and share between fitting, rendering and the PPC.
    sales1 = jnp.array(y1)
    predictors1 = {"publicity": jnp.array(x_pub1), "bargen": jnp.array(x_bar1)}

    # Fit model.
    # Split before sampling so the root key is never consumed directly: a JAX
    # key that has already been used must not be split or reused afterwards.
    print_header("Categorical x Categorical: MCMC Summary")
    rng_key, subkey = jax.random.split(rng_key)
    mcmc1 = run_mcmc(model_cat_cat, subkey, sales=sales1, **predictors1)

    # NumPyro model visualization (built-in)
    print_header("Model Visualization (NumPyro built-in)")
    render_and_save_model(
        model_cat_cat,
        filename_base="model_cat_cat",
        sales=sales1, **predictors1
    )

    # Posterior + posterior predictive as ArviZ InferenceData
    # (observed_data is not passed explicitly to az.from_numpyro)
    rng_key, subkey = jax.random.split(rng_key)
    idata1 = idata_from_mcmc_and_ppc(
        mcmc1, model_cat_cat, subkey, obs_name="sales",
        sales=None, **predictors1
    )

    # Posterior distributions (use hdi_prob, NOT credible_interval)
    az.plot_posterior(idata1, var_names=["alpha","beta_publicity","beta_bargen","beta_interaction","sigma"], hdi_prob=0.95)
    plt.suptitle("Posterior (Categorical x Categorical)", y=1.02)
    plt.tight_layout()
    plt.show()

    # Forest plot (no group argument)
    az.plot_forest(idata1, var_names=["alpha","beta_publicity","beta_bargen","beta_interaction"])
    plt.title("Forest Plot (Categorical x Categorical)")
    plt.tight_layout()
    plt.show()

    # PPC (must pass group="posterior")
    az.plot_ppc(idata1, group="posterior")
    plt.suptitle("Posterior Predictive Check (Categorical x Categorical)", y=1.02)
    plt.tight_layout()
    plt.show()

    # Interaction effect check (predict on new data): all four factor combinations
    print_header("Categorical x Categorical: Predictions for New Data")
    newdata_1 = pd.DataFrame({
        "publicity": ["not", "to_implement", "not", "to_implement"],
        "bargen":    ["not", "not", "to_implement", "to_implement"],
    })
    print(newdata_1)
    new_pub1 = newdata_1["publicity"].map(pub_map).to_numpy()
    new_bar1 = newdata_1["bargen"].map(bar_map).to_numpy()

    pred1 = Predictive(model_cat_cat, posterior_samples=mcmc1.get_samples(), return_sites=["sales"])
    rng_key, subkey = jax.random.split(rng_key)
    ypred1 = pred1(subkey, publicity=jnp.array(new_pub1), bargen=jnp.array(new_bar1), sales=None)["sales"]  # (draws, 4)
    summarize_predictions("Categorical x Categorical", np.asarray(ypred1), hdi_prob=0.95)

    # Marginal effect-like plot
    # (Mean predictions for each combination, one line per bargen level)
    grid1 = newdata_1.copy()
    grid1["pub_code"] = grid1["publicity"].map(pub_map)
    mean_pred1 = np.mean(ypred1, axis=0)
    plt.figure()
    for bar_label, sub in grid1.groupby("bargen"):
        # Row labels of the default RangeIndex double as column positions in ypred1.
        idx = sub.index.to_numpy()
        plt.plot(sub["pub_code"], mean_pred1[idx], marker="o", label=f"Bargen: {bar_label}")
    plt.xticks([0,1], ["publicity=not", "publicity=to_implement"])
    plt.xlabel("Publicity")
    plt.ylabel("Predicted Sales")
    plt.title("Interaction Plot (Categorical x Categorical)")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # =============================================================================
    # Categorical x Continuous
    # =============================================================================
    print_header("Categorical x Continuous: Load & Summarize")
    interaction_2 = pd.read_csv("3-10-2-interaction-2.csv")
    print("Head (3 rows):")
    print(interaction_2.head(3))
    print("\nSummary:")
    print(interaction_2.describe(include='all'))

    # Design matrix, printed for inspection only.
    mm2 = pt.dmatrix("C(publicity) * temperature", interaction_2, return_type="dataframe")
    print_header("Design Matrix (C(publicity) * temperature)")
    print(mm2.head())

    y2 = interaction_2["sales"].to_numpy()
    x_pub2 = interaction_2["publicity"].map(pub_map).to_numpy()
    temp2 = interaction_2["temperature"].to_numpy()

    print_header("Categorical x Continuous: MCMC Summary")
    rng_key, subkey = jax.random.split(rng_key)
    mcmc2 = run_mcmc(
        model_cat_cont, subkey, sales=jnp.array(y2), publicity=jnp.array(x_pub2), temperature=jnp.array(temp2)
    )

    print_header("Model Visualization (NumPyro built-in)")
    render_and_save_model(
        model_cat_cont, filename_base="model_cat_cont",
        sales=jnp.array(y2), publicity=jnp.array(x_pub2), temperature=jnp.array(temp2)
    )

    # Posterior + posterior predictive as ArviZ InferenceData
    rng_key, subkey = jax.random.split(rng_key)
    idata2 = idata_from_mcmc_and_ppc(
        mcmc2, model_cat_cont, subkey, obs_name="sales",
        publicity=jnp.array(x_pub2), temperature=jnp.array(temp2), sales=None
    )

    az.plot_posterior(idata2, var_names=["alpha","beta_publicity","beta_temperature","beta_interaction","sigma"], hdi_prob=0.95)
    plt.suptitle("Posterior (Categorical x Continuous)", y=1.02)
    plt.tight_layout()
    plt.show()

    az.plot_forest(idata2, var_names=["alpha","beta_publicity","beta_temperature","beta_interaction"])
    plt.title("Forest Plot (Categorical x Continuous)")
    plt.tight_layout()
    plt.show()

    az.plot_ppc(idata2, group="posterior")
    plt.suptitle("Posterior Predictive Check (Categorical x Continuous)", y=1.02)
    plt.tight_layout()
    plt.show()

    print_header("Categorical x Continuous: Predictions for New Data")
    newdata_2 = pd.DataFrame({
        "publicity":   ["not", "not", "to_implement", "to_implement"],
        "temperature": [0, 10, 0, 10]
    })
    print(newdata_2)
    new_pub2 = newdata_2["publicity"].map(pub_map).to_numpy()
    new_temp2 = newdata_2["temperature"].to_numpy()

    # One Predictive object serves every prediction below; only the inputs
    # and the PRNG subkey change between calls.
    pred2 = Predictive(model_cat_cont, posterior_samples=mcmc2.get_samples(), return_sites=["sales"])
    rng_key, subkey = jax.random.split(rng_key)
    ypred2 = pred2(subkey, publicity=jnp.array(new_pub2), temperature=jnp.array(new_temp2), sales=None)["sales"]
    summarize_predictions("Categorical x Continuous", np.asarray(ypred2), hdi_prob=0.95)

    # Regression lines (marginal effects-like): mean prediction over the
    # observed temperature range, one line per publicity level.
    tgrid = np.linspace(temp2.min(), temp2.max(), 50)
    fig, ax = plt.subplots()
    for pub_label, code in pub_map.items():
        rng_key, subkey = jax.random.split(rng_key)
        yline = pred2(subkey, publicity=jnp.array(np.full_like(tgrid, code)), temperature=jnp.array(tgrid), sales=None)["sales"]
        ax.plot(tgrid, np.mean(yline, axis=0), label=f"Publicity: {pub_label}")
    ax.set_xlabel("Temperature")
    ax.set_ylabel("Predicted Sales")
    ax.set_title("Regression Lines (Categorical x Continuous)")
    ax.legend()
    plt.tight_layout()
    plt.show()

    # =============================================================================
    # Continuous x Continuous
    # =============================================================================
    print_header("Continuous x Continuous: Load & Summarize")
    interaction_3 = pd.read_csv("3-10-3-interaction-3.csv")
    print("Head (3 rows):")
    print(interaction_3.head(3))
    print("\nSummary:")
    print(interaction_3.describe(include='all'))

    # Scatter plot (labels in English)
    plt.figure()
    sns.scatterplot(
        data=interaction_3, x="product", y="sales", hue=interaction_3["clerk"].astype(str), legend=False
    )
    plt.xlabel("Product")
    plt.ylabel("Sales")
    plt.title("Scatter: Sales vs Product (Colored by Clerk)")
    plt.tight_layout()
    plt.show()

    # Design matrix, printed for inspection only.
    mm3 = pt.dmatrix("product * clerk", interaction_3, return_type="dataframe")
    print_header("Design Matrix (product * clerk)")
    print(mm3.head())

    y3 = interaction_3["sales"].to_numpy()
    prod3 = interaction_3["product"].to_numpy()
    clerk3 = interaction_3["clerk"].to_numpy()

    print_header("Continuous x Continuous: MCMC Summary")
    rng_key, subkey = jax.random.split(rng_key)
    mcmc3 = run_mcmc(
        model_cont_cont, subkey, sales=jnp.array(y3), product=jnp.array(prod3), clerk=jnp.array(clerk3)
    )

    print_header("Model Visualization (NumPyro built-in)")
    render_and_save_model(
        model_cont_cont, filename_base="model_cont_cont",
        sales=jnp.array(y3), product=jnp.array(prod3), clerk=jnp.array(clerk3)
    )

    # Posterior + posterior predictive as ArviZ InferenceData
    rng_key, subkey = jax.random.split(rng_key)
    idata3 = idata_from_mcmc_and_ppc(
        mcmc3, model_cont_cont, subkey, obs_name="sales",
        product=jnp.array(prod3), clerk=jnp.array(clerk3), sales=None
    )

    az.plot_posterior(idata3, var_names=["alpha","beta_product","beta_clerk","beta_interaction","sigma"], hdi_prob=0.95)
    plt.suptitle("Posterior (Continuous x Continuous)", y=1.02)
    plt.tight_layout()
    plt.show()

    az.plot_forest(idata3, var_names=["alpha","beta_product","beta_clerk","beta_interaction"])
    plt.title("Forest Plot (Continuous x Continuous)")
    plt.tight_layout()
    plt.show()

    az.plot_ppc(idata3, group="posterior")
    plt.suptitle("Posterior Predictive Check (Continuous x Continuous)", y=1.02)
    plt.tight_layout()
    plt.show()

    # Predictions at specified points (like newdata_3 in R)
    print_header("Continuous x Continuous: Predictions for New Data")
    newdata_3 = pd.DataFrame({
        "product": [0, 10, 0, 10],
        "clerk":   [0, 0, 10, 10]
    })
    print(newdata_3)
    new_prod3 = newdata_3["product"].to_numpy()
    new_clerk3 = newdata_3["clerk"].to_numpy()

    # One Predictive object serves every prediction below; only the inputs
    # and the PRNG subkey change between calls.
    pred3 = Predictive(model_cont_cont, posterior_samples=mcmc3.get_samples(), return_sites=["sales"])
    rng_key, subkey = jax.random.split(rng_key)
    ypred3 = pred3(subkey, product=jnp.array(new_prod3), clerk=jnp.array(new_clerk3), sales=None)["sales"]
    summarize_predictions("Continuous x Continuous", np.asarray(ypred3), hdi_prob=0.95)

    # Regression lines: vary product over its observed range for the first
    # (up to) nine distinct clerk values. np.unique already returns them sorted.
    pgrid = np.linspace(prod3.min(), prod3.max(), 50)
    cond_clerks = np.unique(clerk3)[:9]

    # Compute each mean line once; both figures below draw the same lines.
    line_means = []
    for c in cond_clerks:
        rng_key, subkey = jax.random.split(rng_key)
        yline = pred3(subkey, product=jnp.array(pgrid), clerk=jnp.array(np.full_like(pgrid, c)), sales=None)["sales"]
        line_means.append(np.mean(yline, axis=0))

    # All lines on one axis. The legend is intentionally omitted to keep the
    # figure readable when there are many lines.
    plt.figure()
    for c, line in zip(cond_clerks, line_means):
        plt.plot(pgrid, line, label=f"clerk={c}")
    plt.xlabel("Product")
    plt.ylabel("Predicted Sales")
    plt.title("Regression Lines by Number of Clerks")
    plt.tight_layout()
    plt.show()

    # Small multiples: one panel per clerk value, three panels per row
    ncols = 3
    nrows = int(np.ceil(len(cond_clerks)/ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 4*nrows), squeeze=False)
    for i, (c, line) in enumerate(zip(cond_clerks, line_means)):
        r, cc = divmod(i, ncols)
        ax = axes[r, cc]
        ax.plot(pgrid, line)
        ax.set_title(f"Clerk = {c}")
        ax.set_xlabel("Product")
        ax.set_ylabel("Predicted Sales")
    # Hide any unused subplots in the last row
    for j in range(len(cond_clerks), nrows*ncols):
        r, cc = divmod(j, ncols)
        axes[r, cc].axis("off")
    fig.suptitle("Regression Lines by Clerk (Small Multiples)", y=1.02)
    plt.tight_layout()
    plt.show()

    print_header("Done")
