**Model Comparison**

- LOO-CV and WAIC for Model 1 (non-hierarchical) vs Model 2 (hierarchical)

- compare Exponential vs Weibull likelihoods

- Tables: absolute model comparison metrics (Table 4) and ΔLOO / ΔWAIC (Table 5)

In [None]:
import numpy as np
import pandas as pd
import arviz as az

In [None]:
# ------------------------------------------------------------
# Load baseline dataset
# ------------------------------------------------------------
# One row per patient: the earliest visit (smallest `year`) is taken as the
# baseline record. `T` is the follow-up time column `years`, `E` the event
# indicator `status2`, and the design matrix `X` holds the standardised
# continuous covariates followed by the categorical covariates, i.e. its
# columns are `continuous_vars + categorical_vars` in that order.

df = pd.read_csv("data/processed/pbc_clean.csv")

# `groupby(...).head(1)` keeps the first *row* of each patient. By contrast,
# `groupby(...).first()` takes the first non-null value of every column
# independently, which can mix values from later visits into the "baseline"
# record whenever the first visit has a missing entry.
df_baseline = (
    df.sort_values(["id", "year"])
      .groupby("id")
      .head(1)
      .reset_index(drop=True)
)

T = df_baseline["years"].values
E = df_baseline["status2"].values

# Separate continuous and categorical covariates
continuous_vars = ["age", "serBilir", "albumin"]
categorical_vars = ["sex", "drug", "edema"]

# Standardise continuous covariates to zero mean / unit variance.
# Population SD (ddof=0) and a divisor of 1 for zero-variance columns match
# sklearn's StandardScaler defaults without requiring that dependency.
cont_mean = df_baseline[continuous_vars].mean()
cont_std = df_baseline[continuous_vars].std(ddof=0).replace(0.0, 1.0)
X_cont = ((df_baseline[continuous_vars] - cont_mean) / cont_std).to_numpy(dtype=float)

# Keep categorical covariates as-is.
# NOTE(review): assumes these columns are already numerically coded in
# pbc_clean.csv; string labels would make X an object array — confirm.
X_cat = df_baseline[categorical_vars].values

# Combine design matrix: columns = continuous_vars + categorical_vars
X = np.column_stack([X_cont, X_cat])

In [None]:
# ------------------------------------------------------------
# Load saved posterior traces
# ------------------------------------------------------------
# Model 1 is reported as "Non-Hierarchical" and model 2 as "Hierarchical"
# in the comparison tables of this notebook.

trace_m1, trace_m2 = (
    az.from_netcdf(trace_path)
    for trace_path in (
        "results/models/model1_trace.nc",
        "results/models/model2_trace.nc",
    )
)

In [None]:
def compute_loglik_exponential(trace, X, T, E, hierarchical=False):
    """
    Pointwise log-likelihood of a censored exponential survival model.

    For posterior draw ``s`` and observation ``i`` the log-hazard is
    ``alpha[s] + beta[s] @ X[i]`` (or ``alpha_h[s, i] + beta[s] @ X[i]`` when
    ``hierarchical`` is True). The contribution is
    ``E[i] * log_lambda - exp(log_lambda) * T[i]``, i.e. the log-density for
    events and the log-survival for censored observations.

    Parameters
    ----------
    trace : object with a ``posterior`` mapping (e.g. arviz InferenceData)
        Must provide ``beta`` plus ``alpha`` (non-hierarchical) or
        ``alpha_h`` (hierarchical). Leading chain/draw dimensions are
        flattened into a single draw axis.
    X : array-like, shape (n_observations, n_covariates)
        Design matrix; columns must be in the same order as ``beta``.
    T : array-like, shape (n_observations,)
        Follow-up times.
    E : array-like, shape (n_observations,)
        Event indicator (1 = event observed, 0 = censored).
    hierarchical : bool, default False
        Use the per-observation intercept ``alpha_h`` instead of ``alpha``.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_draws, n_observations), where n_draws is the
        number of chains times the number of draws.
    """
    posterior = trace.posterior

    # Coerce to float so object-dtype inputs (e.g. from column_stack of mixed
    # frames) still support the matrix product below.
    X = np.asarray(X, dtype=float)
    T = np.asarray(T, dtype=float)
    E = np.asarray(E, dtype=float)
    n_obs, n_cov = X.shape

    beta = posterior["beta"].values.reshape(-1, n_cov)

    if hierarchical:
        intercept = posterior["alpha_h"].values.reshape(-1, n_obs)
    else:
        intercept = posterior["alpha"].values.reshape(-1, 1)

    log_lambda = intercept + beta @ X.T

    # Stay on the log scale. Computing log(exp(log_lambda)) underflows to
    # log(0) = -inf for very negative log-hazards, and then 0 * -inf = nan
    # for censored rows (E = 0). Algebraically this equals
    # E*(log λ - λT) + (1-E)*(-λT).
    return E * log_lambda - np.exp(log_lambda) * T

In [None]:
# ------------------------------------------------------------
# Model Comparison (LOO / WAIC)
# ------------------------------------------------------------
# Paths of the saved posterior traces (the same files loaded earlier).
# They are defined here so this cell runs on its own.
file_m1 = "results/models/model1_trace.nc"
file_m2 = "results/models/model2_trace.nc"

idata_m1 = az.from_netcdf(file_m1)
idata_m2 = az.from_netcdf(file_m2)

# LOO (pointwise=True keeps the per-observation elpd contributions)
loo_m1 = az.loo(idata_m1, pointwise=True)
loo_m2 = az.loo(idata_m2, pointwise=True)

print("\n=== LOO Results ===")
print("Model 1:", loo_m1)
print("Model 2:", loo_m2)

# WAIC
waic_m1 = az.waic(idata_m1, pointwise=True)
waic_m2 = az.waic(idata_m2, pointwise=True)

print("\n=== WAIC Results ===")
print("Model 1:", waic_m1)
print("Model 2:", waic_m2)

In [None]:
# ------------------------------------------------------------
# Table 4: Absolute scores
# ------------------------------------------------------------
# One row per model. The LOO / WAIC columns hold the estimated elpd values
# (`elpd_loo` / `elpd_waic`), and the *_SE columns hold their standard errors.

table4 = pd.DataFrame(
    [
        {
            "Model": label,
            "LOO": loo_res.elpd_loo,
            "LOO_SE": loo_res.se,
            "WAIC": waic_res.elpd_waic,
            "WAIC_SE": waic_res.se,
        }
        for label, loo_res, waic_res in (
            ("Non-Hierarchical", loo_m1, waic_m1),
            ("Hierarchical", loo_m2, waic_m2),
        )
    ]
).round(2)
table4.to_csv("results/tables/table4_model_comparison.csv", index=False)

print("\n=== Table 4: LOO / WAIC ===")
print(table4)

In [None]:
# ------------------------------------------------------------
# Table 5: Delta comparisons
# ------------------------------------------------------------
# Differences are Hierarchical - NonHierarchical on the elpd scale, so a
# positive value favours the hierarchical model. A difference can only be
# judged against its standard error. That SE is computed from the pointwise
# elpd contributions as sqrt(n * var(diff)), the estimator az.compare
# reports as `dse`.


def _elpd_diff_se(pointwise_a, pointwise_b):
    """Return the standard error of sum(pointwise_a - pointwise_b)."""
    diff = (
        np.asarray(pointwise_a, dtype=float).ravel()
        - np.asarray(pointwise_b, dtype=float).ravel()
    )
    return float(np.sqrt(diff.size * np.var(diff)))


delta_loo = loo_m2.elpd_loo - loo_m1.elpd_loo
delta_waic = waic_m2.elpd_waic - waic_m1.elpd_waic

# `loo_i` / `waic_i` exist because the scores were computed with pointwise=True.
se_delta_loo = _elpd_diff_se(loo_m2.loo_i, loo_m1.loo_i)
se_delta_waic = _elpd_diff_se(waic_m2.waic_i, waic_m1.waic_i)

table5 = pd.DataFrame({
    "Metric": ["ΔLOO", "ΔWAIC"],
    "Hierarchical - NonHierarchical": [delta_loo, delta_waic],
    "SE": [se_delta_loo, se_delta_waic],
}).round(2)

table5.to_csv("results/tables/table5_delta_comparison.csv", index=False)

print("\n=== Table 5: ΔLOO / ΔWAIC ===")
print(table5)

**Parameter Interpretation**

- Posterior means and 95% credible intervals for the β coefficients

- Compute hazard ratios exp(β)

- Highlight covariates whose 95% credible interval excludes HR = 1, with a clear clinical interpretation

- Optional: forest plot for effect sizes

In [None]:
# ============================================================
# Parameter Interpretation
# ============================================================

import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt

# Global matplotlib theme for every figure produced below.
_PLOT_STYLE = "seaborn-v0_8-whitegrid"
plt.style.use(_PLOT_STYLE)

In [None]:
# ------------------------------------------------------------
# Load posterior samples
# ------------------------------------------------------------

trace = az.from_netcdf("results/models/model1_trace.nc")

posterior = trace.posterior

# Labels for the columns of `beta`. They must follow the column order of the
# design matrix X, which in this notebook is continuous_vars followed by
# categorical_vars (the same X that compute_loglik_exponential multiplies
# with `beta`).
# NOTE(review): confirm this is also the column order used when model 1 was fitted.
covariates = ["age", "serBilir", "albumin", "sex", "drug", "edema"]

beta_raw = posterior["beta"].values
# Fail loudly if the labels and the coefficient vector disagree in length.
# Otherwise the reshape below could silently mis-assign coefficients to draws.
if beta_raw.shape[-1] != len(covariates):
    raise ValueError(
        f"posterior 'beta' has {beta_raw.shape[-1]} coefficients but "
        f"{len(covariates)} covariate labels were given"
    )

# Flatten chains and draws: shape (n_chains * n_draws, n_covariates)
beta_samples = beta_raw.reshape(-1, len(covariates))

# Convert to hazard ratios. If the model was fitted on the standardised X
# built above, HRs of continuous covariates are per 1-SD increase.
hazard_ratios = np.exp(beta_samples)

In [None]:
# ------------------------------------------------------------
# Table: Hazard ratios with credible intervals
# ------------------------------------------------------------

# Posterior mean and equal-tailed 95% interval of each hazard ratio, one row
# per covariate (rounded to 3 decimals before saving).
hr_lower, hr_upper = np.percentile(hazard_ratios, [2.5, 97.5], axis=0)

hr_summary = (
    pd.DataFrame(
        {
            "HR_mean": np.mean(hazard_ratios, axis=0),
            "HR_2.5%": hr_lower,
            "HR_97.5%": hr_upper,
        },
        index=covariates,
    )
    .round(3)
)
hr_summary.to_csv("results/tables/hazard_ratio_interpretation.csv")

print("\n=== Hazard Ratio Summary ===")
print(hr_summary)

In [None]:
# Flag a covariate when its 95% interval does not contain HR = 1.
# The flag is evaluated on the rounded (3-decimal) interval bounds stored in
# hr_summary.
ci_contains_one = hr_summary["HR_2.5%"].le(1.0) & hr_summary["HR_97.5%"].ge(1.0)
hr_summary["Significant"] = ~ci_contains_one

print("\n=== Significance Flag ===")
print(hr_summary[["HR_mean", "Significant"]])

In [None]:
# ------------------------------------------------------------
# Generate clinical interpretation statements
# ------------------------------------------------------------


def _interpret_covariate(cov, hr, lo, hi, sig):
    """Return a one-sentence reading of a covariate's hazard ratio and 95% CI."""
    if not sig:
        return f"{cov} shows no statistically meaningful association with survival (HR={hr:.2f}, 95% CI [{lo:.2f}, {hi:.2f}])."
    if hr > 1:
        return f"Higher {cov} is associated with increased hazard (HR={hr:.2f}, 95% CI [{lo:.2f}, {hi:.2f}]), indicating worse survival."
    return f"Higher {cov} is protective (HR={hr:.2f}, 95% CI [{lo:.2f}, {hi:.2f}]), associated with improved survival."


interpretations = [
    _interpret_covariate(
        cov,
        hr_summary.loc[cov, "HR_mean"],
        hr_summary.loc[cov, "HR_2.5%"],
        hr_summary.loc[cov, "HR_97.5%"],
        hr_summary.loc[cov, "Significant"],
    )
    for cov in covariates
]

# Save interpretations, one statement per line
with open("results/tables/clinical_interpretations.txt", "w") as out_file:
    out_file.writelines(f"{statement}\n" for statement in interpretations)

print("\n=== Clinical Interpretations ===")
for statement in interpretations:
    print("•", statement)

In [None]:
# ------------------------------------------------------------
# Figure 14: Forest plot
# ------------------------------------------------------------

fig, ax = plt.subplots(figsize=(6, 4))

positions = np.arange(len(covariates))
hr_mean = hr_summary["HR_mean"]
# Asymmetric horizontal error bars: distance from the mean to each CI bound.
lower_err = hr_mean - hr_summary["HR_2.5%"]
upper_err = hr_summary["HR_97.5%"] - hr_mean

ax.errorbar(
    hr_mean,
    positions,
    xerr=[lower_err, upper_err],
    fmt="o",
    color="black",
    ecolor="gray",
    capsize=4,
)

# Reference line at HR = 1 (no effect)
ax.axvline(1.0, color="red", linestyle="--")

ax.set_yticks(positions)
ax.set_yticklabels(covariates)
ax.set_xlabel("Hazard Ratio (log scale)")
ax.set_xscale("log")
ax.set_title("Posterior Hazard Ratios with 95% Credible Intervals")

fig.tight_layout()
fig.savefig("results/figures/forest_plot_hazard_ratios.png", dpi=300)
plt.close(fig)