**6. Posterior Inference**

- HMC / NUTS sampling: 4 chains, 2000 draws (1000 warmup)

- Convergence Diagnostics:

    - R-hat (<1.1), Effective Sample Size (ESS), trace plots

    - Any divergences noted and mitigated

- Figures:

    - Trace plots

    - Posterior distributions of α and β parameters

In [None]:
# ============================================================
# Posterior Inference
# ============================================================

import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

In [None]:
# ------------------------------------------------------------
# Load baseline data
# ------------------------------------------------------------

# Covariate names; their order is reused later to label the beta rows
# and the hazard-ratio columns.
covariates = ["age", "sex", "drug", "serBilir", "albumin", "edema"]

df = pd.read_csv("data/processed/pbc_clean.csv")

# Collapse to one row per subject: sort by id then year, and take the
# first entry of every column within each id.
# NOTE(review): GroupBy.first() returns the first *non-null* value per
# column, so a missing value in a subject's earliest row would be filled
# from a later row -- confirm pbc_clean.csv has no NaNs in these columns.
visits_sorted = df.sort_values(["id", "year"])
df_baseline = visits_sorted.groupby("id", as_index=False).first()

# ------------------------------------------------------------
# Load posterior trace (from Section 5)
# ------------------------------------------------------------

trace = az.from_netcdf("results/models/model1_trace.nc")

In [None]:
# ------------------------------------------------------------
# Sampling diagnostics
# ------------------------------------------------------------

# Per-parameter summary for alpha and beta, rounded to 3 decimals.
summary = az.summary(
    trace,
    var_names=["alpha", "beta"],
    round_to=3
)

print("\n=== Sampling Diagnostics ===")
print(summary[["r_hat", "ess_bulk", "ess_tail"]])

# Divergence check: total number of draws flagged as diverging across
# everything stored in sample_stats["diverging"].
divergences = trace.sample_stats["diverging"].values.sum()
print(f"\nNumber of divergences: {divergences}")

# Convergence gate. Raised explicitly instead of using `assert`, because
# assert statements are removed when Python runs with -O and the check
# would then silently disappear. The exception type and message are the
# same as before. `not (x < 1.1)` also fails when the max R-hat is NaN.
if not summary["r_hat"].max() < 1.1:
    raise AssertionError("R-hat indicates lack of convergence")

In [None]:
 ------------------------------------------------------------
# Figure 10: Trace plots
# ------------------------------------------------------------

az.plot_trace(
    trace,
    var_names=["alpha", "beta"],
    compact=False
)

plt.tight_layout()
plt.savefig("results/figures/trace_plots.png", dpi=300)
plt.close()

In [None]:
# ------------------------------------------------------------
# Table 2: Posterior summaries
# ------------------------------------------------------------

# Handle on the posterior group; the hazard-ratio cell reads
# posterior["beta"] from it.
posterior = trace.posterior

# Posterior mean and 95% HDI bounds for alpha and every beta.
table2_columns = ["mean", "hdi_2.5%", "hdi_97.5%"]
full_summary = az.summary(trace, var_names=["alpha", "beta"], hdi_prob=0.95)
posterior_summary = full_summary[table2_columns]

# Rename beta rows after their covariates. The labels are assigned
# positionally, so this relies on the summary listing alpha first and
# then the betas in `covariates` order.
beta_names = [f"beta_{name}" for name in covariates]
posterior_summary.index = ["alpha", *beta_names]

posterior_summary = posterior_summary.round(3)
posterior_summary.to_csv("results/tables/table2_posterior_estimates.csv")

print("\n=== Table 2: Posterior Estimates ===")
print(posterior_summary)

In [None]:
# ------------------------------------------------------------
# Table 3: Hazard ratios
# ------------------------------------------------------------

# Flatten the beta draws into a 2-D array with one column per covariate,
# then exponentiate to move from the coefficient scale to hazard ratios.
n_covariates = len(covariates)
beta_samples = posterior["beta"].values.reshape(-1, n_covariates)
hazard_ratios = np.exp(beta_samples)

# Column-wise mean and 2.5 / 97.5 percentiles. These are percentile
# bounds, unlike the HDI bounds reported in Table 2.
hr_columns = {
    "HR_mean": hazard_ratios.mean(axis=0),
    "HR_2.5%": np.percentile(hazard_ratios, 2.5, axis=0),
    "HR_97.5%": np.percentile(hazard_ratios, 97.5, axis=0),
}
hr_summary = pd.DataFrame(hr_columns, index=covariates)

hr_summary = hr_summary.round(3)
hr_summary.to_csv("results/tables/table3_hazard_ratios.csv")

print("\n=== Table 3: Hazard Ratios ===")
print(hr_summary)

In [None]:
# ------------------------------------------------------------
# Figure 11: Posterior densities
# ------------------------------------------------------------

plt.figure(figsize=(10, 6))

# Overlay one filled KDE per covariate on the same axes; column i of
# hazard_ratios is paired with covariates[i].
kde_style = {"fill": True, "alpha": 0.4}
for cov, hr_draws in zip(covariates, hazard_ratios.T):
    sns.kdeplot(hr_draws, label=cov, **kde_style)

# Dashed reference line at a hazard ratio of 1.
plt.axvline(1, color="black", linestyle="--", linewidth=1)

plt.xlabel("Hazard Ratio")
plt.ylabel("Density")
plt.title("Posterior Distributions of Hazard Ratios")
plt.legend()

plt.tight_layout()
plt.savefig("results/figures/posterior_hazard_ratios.png", dpi=300)
plt.close()