**Prior Predictive Checks**

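- Generate survival times from the priors only: parameters are drawn from their priors, conditioning on the observed covariates but not on the observed outcomes

- Check plausibility: simulated times must be positive and finite, and fall in a realistic range for survival (in years)

- Figure: prior predictive survival curves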

In [None]:
# ============================================================
# Prior Predictive Checks
# ============================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")
np.random.seed(42)  # global legacy seed; the prior-sampling cell draws via np.random.*

# ------------------------------------------------------------
# Load baseline dataset
# ------------------------------------------------------------

df = pd.read_csv("data/processed/pbc_clean.csv")

# One row per subject: the row of the earliest visit (smallest "year").
# drop_duplicates keeps whole rows. GroupBy.first() would instead take the
# first *non-null* value of each column independently, so a missing baseline
# value would be silently filled in from a later visit.
df_baseline = (
    df.sort_values(["id", "year"])
      .drop_duplicates(subset="id", keep="first")
      .reset_index(drop=True)
)

# Survival outcome
T = df_baseline["years"].values
E = df_baseline["status2"].values  # 1 = event, 0 = censored

# ------------------------------------------------------------
# Covariates
# ------------------------------------------------------------
continuous_vars = ["age", "serBilir", "albumin"]
categorical_vars = ["sex", "drug", "edema"]

# Standardize continuous covariates: z-scores using the population SD
# (ddof=0), the same transform sklearn's StandardScaler applies. This is done
# in pandas so the cell does not depend on an un-imported StandardScaler.
# The fitted parameters are kept in cont_mean / cont_sd for reuse on new data.
cont = df_baseline[continuous_vars].astype(float)
cont_mean = cont.mean()
# A constant column is only centred (scale 1), matching StandardScaler.
cont_sd = cont.std(ddof=0).replace(0.0, 1.0)
X_cont = ((cont - cont_mean) / cont_sd).to_numpy()

# Categorical covariates: numeric codes are kept as-is. Any column stored as
# text/category is one-hot encoded, with its first level dropped as the
# reference. This guarantees a numeric design matrix; stacking raw string
# labels would give an object array that breaks `X @ beta` downstream.
X_cat = pd.get_dummies(
    df_baseline[categorical_vars], drop_first=True, dtype=float
).to_numpy(dtype=float)

# Design matrix: float array of shape (N subjects, P covariates)
X = np.column_stack([X_cont, X_cat])
N, P = X.shape

In [None]:
# ------------------------------------------------------------
# Prior sampling
# ------------------------------------------------------------

n_prior_draws = 500

# Exponential survival model with a log-linear hazard:
#   log λ_i = α + x_iᵀβ,   T̃_i ~ Exponential(rate = λ_i)
# Narrower Normal(0, 0.5) priors keep exp(α + xβ) from producing extreme T̃
# on the standardized covariate scale.
alpha_prior = np.random.normal(0, 0.5, size=n_prior_draws)
beta_prior = np.random.normal(0, 0.5, size=(n_prior_draws, P))

# Linear predictor for every (draw, subject) pair at once:
# shape (n_prior_draws, N).
X_design = np.asarray(X, dtype=float)
log_lambda = alpha_prior[:, None] + beta_prior @ X_design.T

# Generate synthetic survival times. numpy parameterizes the exponential by
# scale = 1 / rate; exp(-log λ) gives that without a separate division.
# Draws are taken in row-major order, i.e. the same order as looping over
# draws and then subjects.
T_tilde = np.random.exponential(scale=np.exp(-log_lambda))

# Diagnostics
print("=== Prior Predictive Survival Time Diagnostics ===")
print(f"Min T̃: {T_tilde.min():.4f}")
print(f"Median T̃: {np.median(T_tilde):.2f}")
print(f"Mean T̃: {T_tilde.mean():.2f}")
print(f"Max T̃: {T_tilde.max():.2f}")

# Plausibility checks: survival times must be positive and finite, and should
# mostly sit on the scale of the observed follow-up T (years).
n_implausible = np.count_nonzero(~np.isfinite(T_tilde) | (T_tilde <= 0))
q05, q95 = np.percentile(T_tilde, [5, 95])
max_follow_up = np.max(T)
print(f"Non-positive or non-finite T̃: {n_implausible}")
print(f"5%-95% range of T̃: {q05:.2f} - {q95:.2f}")
print(
    f"Share of T̃ beyond max observed follow-up ({max_follow_up:.1f}): "
    f"{np.mean(T_tilde > max_follow_up):.1%}"
)

In [None]:
# ------------------------------------------------------------
# Figure 8: Prior Predictive Survival Curves
# ------------------------------------------------------------
from pathlib import Path

time_grid = np.linspace(0, 15, 200)  # years
n_curves = min(50, n_prior_draws)  # plot a subset for clarity

# Population-averaged (marginal) survival implied by each prior draw d:
#   S_d(t) = (1/N) Σ_i exp(-λ_{d,i} t)
# Averaging the per-subject curves is the survival of the sample under that
# draw. Plugging the mean hazard into exp(-λ̄ t) instead gives the curve of a
# single "average-hazard" subject, which by Jensen's inequality lies below
# the marginal curve.
X_design = np.asarray(X, dtype=float)
lam = np.exp(alpha_prior[:n_curves, None] + beta_prior[:n_curves] @ X_design.T)
survival_curves = np.array(
    [np.exp(-np.outer(lam_d, time_grid)).mean(axis=0) for lam_d in lam]
)  # shape (n_curves, len(time_grid))

fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(time_grid, survival_curves.T, color="gray", alpha=0.2)
ax.plot(
    time_grid,
    np.median(survival_curves, axis=0),
    color="black",
    linewidth=2,
    label="Pointwise median",
)

ax.set_xlabel("Time (years)")
ax.set_ylabel("Survival Probability")
ax.set_title("Prior Predictive Survival Curves")
ax.set_ylim(0, 1)
ax.legend()
fig.tight_layout()

# Create the output directory if needed; savefig does not create it.
out_path = Path("results/figures/prior_predictive_survival_curves.png")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=300)
plt.close(fig)