**Model Specification**
- Model 2 – Hierarchical Survival Model

    - Random intercepts by patient or group: α_{h[i]} ~ Normal(μ_α, σ_α), where h[i] denotes the patient (group) that observation i belongs to

    - Priors:

        - μ_α ~ Normal(0,1)

        - σ_α ~ Exponential(1) (the prior used in the code below); HalfNormal(1) is an alternative

In [None]:
# ============================================================
# Bayesian Survival Models
# ============================================================

import os

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

# ------------------------------------------------------------
# Load baseline dataset
# ------------------------------------------------------------

df = pd.read_csv("data/processed/pbc_clean.csv")

# Collapse to one row per patient: sort by id and visit time ("year"), then
# take the first entry of every column within each id group.
# NOTE(review): GroupBy.first() returns the first *non-null* value per
# column, so a missing baseline value would be filled from a later visit --
# confirm the cleaned file has no NaNs in the columns used below.
df_baseline = (
    df.sort_values(["id", "year"])
      .groupby("id", as_index=False)
      .first()
)

# Survival data
T = df_baseline["years"].values    # follow-up time
E = df_baseline["status2"].values  # 1 = event, 0 = censored

# Separate continuous and categorical covariates
continuous_vars = ["age", "serBilir", "albumin"]
categorical_vars = ["sex", "drug", "edema"]

# Scale continuous covariates to zero mean / unit variance with NumPy.
# Uses the population standard deviation (ddof=0) and ignores NaNs when
# estimating the statistics; a constant column keeps scale 1 so it is not
# divided by zero. The fitted statistics stay available in cont_mean /
# cont_scale for transforming new data later.
cont_values = df_baseline[continuous_vars].to_numpy(dtype=float)
cont_mean = np.nanmean(cont_values, axis=0)
cont_scale = np.nanstd(cont_values, axis=0)
cont_scale[cont_scale == 0.0] = 1.0
X_cont = (cont_values - cont_mean) / cont_scale

# Keep categorical covariates as-is
# (assumes they are already numerically encoded in the cleaned file --
# TODO confirm; string labels would turn X into a non-numeric array).
X_cat = df_baseline[categorical_vars].values

# Combine design matrix: continuous columns first, then categorical ones
X = np.column_stack([X_cont, X_cat])

# N = number of rows in the design matrix, P = number of covariates
N, P = X.shape


In [None]:
# ============================================================
# Model 2: Hierarchical Exponential Survival Model
# ============================================================

# Map every row's patient id to a contiguous integer code so it can be used
# to index the vector of random intercepts.
patient_idx = df_baseline["id"].astype("category").cat.codes.values
n_patients = len(np.unique(patient_idx))

with pm.Model() as model2:

    # Hyperpriors for the population of patient intercepts
    mu_alpha = pm.Normal("mu_alpha", mu=0, sigma=1)
    sigma_alpha = pm.Exponential("sigma_alpha", 1.0)

    # Random intercepts (patient-level frailty), one per distinct patient.
    # NOTE(review): df_baseline holds one row per id, so each intercept is
    # informed by a single observation; if sampling reports divergences,
    # consider a non-centered parameterisation.
    alpha_patient = pm.Normal(
        "alpha_patient",
        mu=mu_alpha,
        sigma=sigma_alpha,
        shape=n_patients
    )

    # Fixed effects: one coefficient per column of the design matrix X
    beta = pm.Normal("beta", mu=0, sigma=1, shape=P)

    # Linear predictor on the log-hazard scale
    log_lambda = alpha_patient[patient_idx] + pm.math.dot(X, beta)
    lambda_ = pm.math.exp(log_lambda)

    # Manual exponential log-likelihood with right censoring:
    #   event    (E = 1): log f(T) = log(lambda) - lambda * T
    #   censored (E = 0): log S(T) =             - lambda * T
    # Both cases share the -lambda * T term, so the sum collapses to
    # E * log(lambda) - lambda * T. log_lambda is used directly instead of
    # log(exp(log_lambda)) to avoid a needless exp/log round trip.
    loglik = E * log_lambda - lambda_ * T

    # Add the total log-likelihood to the model's log-probability.
    # The pointwise terms in `loglik` are not stored as a named variable.
    pm.Potential("likelihood", loglik.sum())

    trace_model2 = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        return_inferencedata=True
    )

# Save results
file_m2 = "results/models/model2_new_trace.nc"

# Make sure the output directory exists; writing the file fails otherwise.
os.makedirs(os.path.dirname(file_m2), exist_ok=True)

# Remove any previous trace first so the new file is written from scratch.
if os.path.exists(file_m2):
    os.remove(file_m2)
az.to_netcdf(trace_model2, file_m2)