In [None]:
import sys
import os
sys.path.append('../..')

# prevents a thread conflict between MKL/OpenMP (used by numpy/scipy) and PyTorch
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['VECLIB_MAX_THREADS'] = '1'

In [None]:
# libraries
import optuna
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sksurv.util import Surv
import pickle

# pytorch
import torch
torch.set_num_threads(1)
import torch.nn as nn
import torchtuples as tt
from torch.optim import AdamW


# random seed for reproducibility
np.random.seed(1212)
torch.manual_seed(1212)

# models
from sksurv.linear_model import CoxPHSurvivalAnalysis
from xgbse.converters import convert_to_structured
from xgbse import XGBSEDebiasedBCE
from sksurv.ensemble import RandomSurvivalForest
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox.models import DeepHitSingle, CoxPH
from pycox.evaluation import EvalSurv


# metrics
from sksurv.metrics import cumulative_dynamic_auc, concordance_index_censored, integrated_brier_score, concordance_index_ipcw


# survshap
from survshap import SurvivalModelExplainer, ModelSurvSHAP

# warnings
# import warnings




In [None]:
# data: feature matrices and target frames for the train / validation / test splits
X_train, y_train = (
    pd.read_parquet('../../data/processed/train_standard_features.parquet'),
    pd.read_parquet('../../data/processed/train_targets.parquet'),
)
X_val, y_val = (
    pd.read_parquet('../../data/processed/val_standard_features.parquet'),
    pd.read_parquet('../../data/processed/val_targets.parquet'),
)
X_test, y_test = (
    pd.read_parquet('../../data/processed/test_standard_features.parquet'),
    pd.read_parquet('../../data/processed/test_targets.parquet'),
)


def _to_event_time_struct(targets):
    """Pack the 'event' and 'time' columns of a targets frame into a structured array.

    The result has the fields ('event', bool) and ('time', float), one record per row.
    """
    records = [(bool(event), float(time))
               for event, time in zip(targets['event'], targets['time'])]
    return np.array(records, dtype=[('event', bool), ('time', float)])


# structured targets passed to the scikit-survival estimators and metrics below
y_train_struct = _to_event_time_struct(y_train)
y_val_struct = _to_event_time_struct(y_val)
y_test_struct = _to_event_time_struct(y_test)

# the same targets converted with xgbse's own helper
y_train_struct_xgbse = convert_to_structured(y_train['time'], y_train['event'])
y_val_struct_xgbse = convert_to_structured(y_val['time'], y_val['event'])
y_test_struct_xgbse = convert_to_structured(y_test['time'], y_test['event'])

In [None]:
# observed follow-up range of the test set
min_time = y_test_struct['time'].min()
max_time = y_test_struct['time'].max()
# whole-number evaluation grid kept within [min_time, max_time)
times_for_brier = np.arange(start=np.ceil(min_time), stop=np.floor(max_time))

In [None]:
# display the smallest and largest observed test time (the bounds used for times_for_brier)
min_time, max_time

## Cox model

In [None]:
def objective_cox(trial):
    """Optuna objective for the Cox model: return the validation C-index.

    Samples alpha, n_iter and tol from the trial, fits CoxPHSurvivalAnalysis on
    the training split and records the train/validation C-index, the mean
    cumulative/dynamic AUC on validation and the fitted coefficients as trial
    user attributes. Any exception is printed and the trial scores -inf.
    """
    # hyperparameter search space (sampled in this order)
    alpha = trial.suggest_float('alpha', 0.001, 10.0, log=True)
    n_iter = trial.suggest_int('n_iter', 100, 2000)
    tol = trial.suggest_float('tol', 1e-6, 1e-3, log=True)

    # evaluation times for the cumulative/dynamic AUC
    auc_times = np.arange(1, 361)

    try:
        cox = CoxPHSurvivalAnalysis(alpha=alpha, n_iter=n_iter, tol=tol)
        cox.fit(X_train, y_train_struct)

        # risk scores on both splits
        risk_train = cox.predict(X_train)
        risk_val = cox.predict(X_val)

        # concordance on both splits (first element of the returned tuple)
        c_train = concordance_index_censored(
            y_train_struct['event'], y_train_struct['time'], risk_train)[0]
        c_val = concordance_index_censored(
            y_val_struct['event'], y_val_struct['time'], risk_val)[0]

        # only the mean AUC is kept; the per-time values are discarded
        auc_mean = cumulative_dynamic_auc(
            y_train_struct, y_val_struct, risk_val, auc_times)[1]

        for key, value in (('train_c_index', c_train),
                           ('val_c_index', c_val),
                           ('mean_auc', auc_mean)):
            trial.set_user_attr(key, value)

        if hasattr(cox, 'coef_'):
            trial.set_user_attr(
                'coefficients',
                {name: float(coef) for name, coef in zip(X_train.columns, cox.coef_)},
            )

        return c_val

    except Exception as e:
        print(f"Trial failed: {e}")
        return float('-inf')

# run optimization: maximize the validation C-index over 20 trials
study_cox = optuna.create_study(direction='maximize')
study_cox.optimize(objective_cox, n_trials=20, show_progress_bar=True)

# report the best trial and the metrics stored on it
best_trial_cox = study_cox.best_trial
print(f"\nBest score: {best_trial_cox.value:.5f}")
print(f"Train c-index: {best_trial_cox.user_attrs['train_c_index']:.5f}")
print(f"Val c-index: {best_trial_cox.user_attrs['val_c_index']:.5f}")
print(f"Mean AUC: {best_trial_cox.user_attrs['mean_auc']:.5f}")
print(f"Best params: {best_trial_cox.params}")

# persist the winning hyperparameters as JSON
with open('../../models/best_cox_params.json', 'w') as f:
    f.write(json.dumps(study_cox.best_params, indent=2))


In [None]:
# prediction
# load the best hyperparameters written by the Cox search cell
with open('../../models/best_cox_params.json', 'r') as f:
    best_params = json.load(f)

print("best hyperparameters:")
print(json.dumps(best_params, indent=2))

# parameters for model initialization
cox_params = {
    'alpha': best_params['alpha'],
    'n_iter': best_params['n_iter'],
    'tol': best_params['tol'],
}

# train the final model on the training split
print("training final model")
cox_model = CoxPHSurvivalAnalysis(**cox_params)
cox_model.fit(X_train, y_train_struct)
print("training complete")

# predictions on test set
print("predictions on test")
# Predict the per-sample survival functions once. The original cell called
# predict_survival_function twice with equivalent arguments (return_array=False
# is the default), so both names refer to the same prediction.
cox_surv_funcs = cox_model.predict_survival_function(X_test, return_array=False)
cox_test_surv_probs = cox_surv_funcs

# evaluate every survival function on the Brier time grid: one row per test sample
cox_test_surv_probs_ibs = np.vstack([fn(times_for_brier) for fn in cox_surv_funcs])

cox_test_risks = cox_model.predict(X_test)

# c index
cox_test_c_index = concordance_index_censored(
    y_test_struct['event'],
    y_test_struct['time'],
    cox_test_risks
)[0]

# cumulative dynamic AUC
cox_td_auc, cox_test_mean_auc = cumulative_dynamic_auc(
    y_train_struct, y_test_struct, cox_test_risks, times=list(range(1, 361))
)

# ibs
cox_ibs_score = integrated_brier_score(
    survival_train=y_test_struct,  # use test set censoring distribution like pycox
    survival_test=y_test_struct,
    estimate=cox_test_surv_probs_ibs,
    times=times_for_brier
)

print("\nFINAL TEST RESULTS")
print(f"test C-index: {cox_test_c_index:.5f}")
print(f"test mean AUC: {cox_test_mean_auc:.5f}")
print(f"IBS: {cox_ibs_score:.4f}")

# save the final model (pickle is already imported in the imports cell)
print("saving final model")
with open('../../models/final_cox_model.pkl', 'wb') as f:
    pickle.dump(cox_model, f)

# save test predictions together with the headline metrics
cox_test_predictions = {
    'survival_probabilities': cox_test_surv_probs,
    'risk_scores': cox_test_risks,
    'c_index': cox_test_c_index,
    'mean_auc': cox_test_mean_auc,
    'ibs': cox_ibs_score
}

with open('../../models/cox_test_predictions.pkl', 'wb') as f:
    pickle.dump(cox_test_predictions, f)

print('model and predictions saved')

## Random Survival Forest

In [None]:
def objective_rsf(trial):
    """Optuna objective for the random survival forest: return the validation C-index.

    Samples the forest hyperparameters from the trial, fits RandomSurvivalForest
    on the training split and stores the train/validation C-index and the mean
    cumulative/dynamic AUC on validation as trial user attributes. Any
    exception is printed and the trial scores -inf.
    """
    # hyperparameter search space (sampled in this order); n_jobs and the seed are fixed
    forest_kwargs = dict(
        n_estimators=trial.suggest_int('n_estimators', 80, 200),
        min_samples_split=trial.suggest_int('min_samples_split', 10, 30),
        min_samples_leaf=trial.suggest_int('min_samples_leaf', 5, 20),
        max_depth=trial.suggest_int('max_depth', 10, 20),
        max_samples=trial.suggest_float('max_samples', 0.6, 0.9),
        n_jobs=-1,
        random_state=1212,
    )

    # evaluation times for the cumulative/dynamic AUC
    auc_times = np.arange(1, 361)

    try:
        forest = RandomSurvivalForest(**forest_kwargs)
        forest.fit(X_train, y_train_struct)

        # risk scores on both splits
        risk_train = forest.predict(X_train)
        risk_val = forest.predict(X_val)

        # concordance on both splits (first element of the returned tuple)
        c_train = concordance_index_censored(
            y_train_struct['event'], y_train_struct['time'], risk_train)[0]
        c_val = concordance_index_censored(
            y_val_struct['event'], y_val_struct['time'], risk_val)[0]

        # only the mean AUC is kept
        auc_mean = cumulative_dynamic_auc(
            y_train_struct, y_val_struct, risk_val, auc_times)[1]

        for key, value in (('train_c_index', c_train),
                           ('val_c_index', c_val),
                           ('mean_auc', auc_mean)):
            trial.set_user_attr(key, value)

        return c_val

    except Exception as e:
        print(f"Trial failed: {e}")
        return float('-inf')

# run optimization for the RSF model: maximize the validation C-index over 20 trials
study_rsf = optuna.create_study(direction='maximize')
study_rsf.optimize(objective_rsf, n_trials=20, show_progress_bar=True)

# report the best trial and the metrics stored on it
print("\nRSF best trial")
best_trial_rsf = study_rsf.best_trial

print(f"best score: {best_trial_rsf.value:.5f}")
print(f"train C-index: {best_trial_rsf.user_attrs['train_c_index']:.5f}")
print(f"val C-index: {best_trial_rsf.user_attrs['val_c_index']:.5f}")
print(f"mean AUC: {best_trial_rsf.user_attrs['mean_auc']:.5f}")
print(f"best params: {best_trial_rsf.params}")

# save parameters
with open('../../models/best_rsf_params.json', 'w') as f:
    json.dump(study_rsf.best_params, f, indent=2)

# this cell writes only the hyperparameters; no model or predictions are saved
# here, so the message now says what was actually written
print("best RSF parameters saved")

In [None]:
# prediction
# load the best hyperparameters written by the RSF search cell
with open('../../models/best_rsf_params.json', 'r') as f:
    best_params = json.load(f)

print("best hyperparameters:")
print(json.dumps(best_params, indent=2))

# parameters for model initialization (n_jobs and the seed match the search cell)
rsf_params = {
    'n_estimators': best_params['n_estimators'],
    'min_samples_split': best_params['min_samples_split'],
    'min_samples_leaf': best_params['min_samples_leaf'],
    'max_depth': best_params['max_depth'],
    'max_samples': best_params['max_samples'],
    'n_jobs': -1,
    'random_state': 1212
}

# train the final model on the training split
print("training final model")
rsf_model = RandomSurvivalForest(**rsf_params)
rsf_model.fit(X_train, y_train_struct)
print("model training complete")

# predict on test set
print("predictions on test")
# Predict the per-sample survival functions once. The original cell called
# predict_survival_function twice with equivalent arguments (return_array=False
# is the default), repeating the forest prediction for no benefit.
rsf_surv_funcs = rsf_model.predict_survival_function(X_test, return_array=False)
rsf_test_surv_probs = rsf_surv_funcs

# evaluate every survival function on the Brier time grid: one row per test sample
rsf_test_surv_probs_ibs = np.vstack([fn(times_for_brier) for fn in rsf_surv_funcs])

rsf_test_risks = rsf_model.predict(X_test)

# c index
rsf_test_c_index = concordance_index_censored(
    y_test_struct['event'],
    y_test_struct['time'],
    rsf_test_risks
)[0]

# cumulative dynamic AUC
rsf_td_auc, rsf_test_mean_auc = cumulative_dynamic_auc(
    y_train_struct, y_test_struct, rsf_test_risks, times=list(range(1, 361))
)

# ibs (censoring distribution estimated from the test set, as for the Cox model)
rsf_ibs_score = integrated_brier_score(
    survival_train=y_test_struct,
    survival_test=y_test_struct,
    estimate=rsf_test_surv_probs_ibs,
    times=times_for_brier
)

print("\nFINAL TEST RESULTS")
print(f"test C-index: {rsf_test_c_index:.5f}")
print(f"test mean AUC: {rsf_test_mean_auc:.5f}")
print(f"IBS: {rsf_ibs_score:.4f}")

# save the final model (pickle is already imported in the imports cell)
print("saving final model")
with open('../../models/final_rsf_model.pkl', 'wb') as f:
    pickle.dump(rsf_model, f)

# save test predictions together with the headline metrics
rsf_test_predictions = {
    'survival_probabilities': rsf_test_surv_probs,
    'risk_scores': rsf_test_risks,
    'c_index': rsf_test_c_index,
    'mean_auc': rsf_test_mean_auc,
    'ibs': rsf_ibs_score
}

with open('../../models/rsf_test_predictions.pkl', 'wb') as f:
    pickle.dump(rsf_test_predictions, f)

print("model and predictions saved")

## XGBSE

In [None]:
# hyperparameter search
def get_risk_scores(survival_probs):
    """Turn survival probabilities into one risk score per row for the C-index.

    The score is one minus the row-wise mean (axis=1) of the survival
    probabilities, so lower average survival gives a higher risk.
    """
    mean_survival = survival_probs.mean(axis=1)
    return 1 - mean_survival

def objective_xgbse(trial):
    """Optuna objective for XGBSEDebiasedBCE: return the validation C-index.

    Samples the XGBoost settings, the logistic-regression settings and the
    number of boosting rounds from the trial, fits the model on the training
    split with fixed time bins, converts the predicted survival probabilities
    to risk scores with get_risk_scores and stores the train/validation C-index
    and the mean cumulative/dynamic AUC on validation as trial user attributes.
    Any exception is printed and the trial scores -inf.
    """
    # XGBoost hyperparameters (sampled in this order); the seed is fixed
    booster_kwargs = dict(
        max_depth=trial.suggest_int('max_depth', 3, 12),
        learning_rate=trial.suggest_float('learning_rate', 0.01, 0.3),
        reg_alpha=trial.suggest_float('reg_alpha', 0, 2),
        min_child_weight=trial.suggest_int('min_child_weight', 1, 10),
        random_state=1212,
    )

    # BCE specific hyperparameters, stored under the 'lr_' prefixed trial names
    logistic_kwargs = dict(
        C=trial.suggest_float('lr_C', 0.001, 100, log=True),
        max_iter=trial.suggest_int('lr_max_iter', 100, 1000),
    )

    boost_rounds = trial.suggest_int('num_boost_round', 50, 300)

    # time bins handed to the model's fit
    bin_edges = np.arange(0, 361, 30).tolist()

    try:
        xgbse = XGBSEDebiasedBCE(
            xgb_params=booster_kwargs,
            lr_params=logistic_kwargs,
            n_jobs=-1
        )
        xgbse.fit(
            X_train,
            y_train_struct,
            time_bins=bin_edges,
            num_boost_round=boost_rounds)

        # survival probabilities per split, collapsed to one risk score per row
        risk_train = get_risk_scores(xgbse.predict(X_train))
        risk_val = get_risk_scores(xgbse.predict(X_val))

        # concordance on both splits (first element of the returned tuple)
        c_train = concordance_index_censored(
            y_train_struct['event'], y_train_struct['time'], risk_train)[0]
        c_val = concordance_index_censored(
            y_val_struct['event'], y_val_struct['time'], risk_val)[0]

        # only the mean AUC is kept
        auc_mean = cumulative_dynamic_auc(
            y_train_struct, y_val_struct, risk_val, times=list(range(1, 361)))[1]

        for key, value in (('train_c_index', c_train),
                           ('val_c_index', c_val),
                           ('mean_auc', auc_mean)):
            trial.set_user_attr(key, value)

        return c_val

    except Exception as e:
        print(f"Trial failed: {e}")
        return float('-inf')
    
# run optimization: maximize the validation C-index over 20 trials
study_xgbse = optuna.create_study(direction='maximize')
study_xgbse.optimize(objective_xgbse, n_trials=20, show_progress_bar=True)

# report the best trial and the metrics stored on it
best_trial_xgbse = study_xgbse.best_trial
print(f"\nBest score: {best_trial_xgbse.value:.5f}")
print(f"Train c-index: {best_trial_xgbse.user_attrs['train_c_index']:.5f}")
print(f"Val c-index: {best_trial_xgbse.user_attrs['val_c_index']:.5f}")
print(f"Mean AUC: {best_trial_xgbse.user_attrs['mean_auc']:.5f}")
print(f"Best params: {best_trial_xgbse.params}")

# persist the winning hyperparameters as JSON
with open('../../models/best_xgbse_params.json', 'w') as f:
    f.write(json.dumps(study_xgbse.best_params, indent=2))

In [None]:
# prediction
def calculate_ibs_for_xgbse(model, X_train, y_train_struct, X_test, y_test_struct):
    """Integrated Brier score of a fitted XGBSE model on the test set.

    The model's predictions at its own time bins are linearly interpolated onto
    a whole-number time grid inside the observed test times, then scored with
    integrated_brier_score. The censoring distribution is taken from the test
    set itself. X_train and y_train_struct are accepted but not used.

    Returns the integrated Brier score.
    """
    observed_times = y_test_struct['time']
    # evaluation grid stays within [min_time, max_time)
    grid = np.arange(np.ceil(np.min(observed_times)),
                     np.floor(np.max(observed_times)))

    predicted = model.predict(X_test)

    # time bins the model was fitted with
    bin_edges = model.time_bins

    # one row per test sample, one column per grid time
    interpolated_probs = np.zeros((X_test.shape[0], len(grid)))

    for row_idx in range(len(predicted)):
        curve = predicted.iloc[row_idx].values
        interpolated_probs[row_idx, :] = np.interp(
            x=grid,
            xp=bin_edges,
            fp=curve,
            left=1.0,         # survival probability is 1 before the first bin
            right=curve[-1],  # hold the last value after the final bin
        )

    return integrated_brier_score(
        survival_train=y_test_struct,  # use test set censoring distribution like pycox
        survival_test=y_test_struct,
        estimate=interpolated_probs,
        times=grid,
    )

# load the best hyperparameters written by the XGBSE search cell
with open('../../models/best_xgbse_params.json', 'r') as f:
    best_params = json.load(f)

print("best hyperparameters:")
print(json.dumps(best_params, indent=2))

# extract parameters for model initialization (the seed matches the search cell)
xgb_params = {
    'max_depth': best_params['max_depth'],
    'learning_rate': best_params['learning_rate'],
    'reg_alpha': best_params['reg_alpha'],
    'min_child_weight': best_params['min_child_weight'],
    'random_state': 1212
}

# the search stored the logistic-regression settings under 'lr_' prefixed names
lr_params = {
    'C': best_params['lr_C'],
    'max_iter': best_params['lr_max_iter']
}

num_boost_round = best_params['num_boost_round']

# time bins (same grid as in objective_xgbse)
times_xgbse = np.arange(0, 361, 30).tolist()
print(f"Time bins: {times_xgbse}")

# train the final model on the training split
print("training final model")
xgbse_model = XGBSEDebiasedBCE(
    xgb_params=xgb_params,
    lr_params=lr_params,
    n_jobs=-1
)

xgbse_model.fit(
    X_train,
    y_train_struct,
    time_bins=times_xgbse,
    num_boost_round=num_boost_round
)

print("model training complete")

# predict on test set
print("predictions on test")
xgbse_test_survival_probs = xgbse_model.predict(X_test)
xgbse_test_risks = get_risk_scores(xgbse_test_survival_probs)

# C index
xgbse_test_c_index = concordance_index_censored(
    y_test_struct['event'],
    y_test_struct['time'],
    xgbse_test_risks
)[0]

# cumulative dynamic AUC
xgbse_td_auc, xgbse_test_mean_auc = cumulative_dynamic_auc(
    y_train_struct, y_test_struct, xgbse_test_risks, times=list(range(1, 361))
)

# ibs
xgbse_ibs_score = calculate_ibs_for_xgbse(
    xgbse_model,
    X_train,
    y_train_struct,
    X_test,
    y_test_struct
)

print("\nFINAL TEST RESULTS")
print(f"Test C-index: {xgbse_test_c_index:.5f}")
print(f"Test Mean AUC: {xgbse_test_mean_auc:.5f}")
print(f"IBS: {xgbse_ibs_score:.4f}")

# save the final model (pickle is already imported in the imports cell)
print("Saving final model...")
with open('../../models/final_xgbse_model.pkl', 'wb') as f:
    pickle.dump(xgbse_model, f)

# save test predictions; 'ibs' is included so the saved metrics match the
# Cox and RSF prediction files (it was computed above but previously not stored)
xgbse_test_predictions = {
    'survival_probabilities': xgbse_test_survival_probs,
    'risk_scores': xgbse_test_risks,
    'c_index': xgbse_test_c_index,
    'mean_auc': xgbse_test_mean_auc,
    'ibs': xgbse_ibs_score,
    'time_bins': times_xgbse
}

with open('../../models/xgbse_test_predictions.pkl', 'wb') as f:
    pickle.dump(xgbse_test_predictions, f)

print("model and predictions saved")

## DeepHit

In [None]:
# llms were useful for making this work

def get_risk_scores_repurchase_expected_time(model, X):
    """Risk score per individual: the negative expected survival time.

    The expected time is the area under each predicted survival curve,
    approximated as the sum of S(t_i) * (t_i - t_{i-1}) with t_0 = 0.

    Parameters
    ----------
    model : object exposing ``predict_surv_df(X)``. The returned DataFrame is
        expected in the pycox layout: index = time points, one column per
        individual.
    X : input passed straight through to ``model.predict_surv_df``.

    Returns
    -------
    1-D numpy array with one score per individual. A longer expected time
    gives a lower (more negative) score.
    """
    # pycox returns time points as rows and individuals as columns, so always
    # transpose to (individuals, time points). The previous shape comparison
    # (transpose only when rows < columns) skipped the transpose whenever the
    # batch had no more individuals than time points, which then treated the
    # individual ids as time points.
    surv_probs_df = model.predict_surv_df(X).T

    # time points and the width of each interval
    time_points = surv_probs_df.columns.values

    # prepend t=0 so the first interval width is (t_1 - 0)
    time_points_with_zero = np.insert(time_points, 0, 0)

    # interval durations: [t_1-0, t_2-t_1, t_3-t_2, ...]
    time_diffs = np.diff(time_points_with_zero)

    # area under each curve as a dot product of probabilities and interval widths
    expected_times = np.dot(surv_probs_df.values, time_diffs)

    # negate so that a shorter expected time means a higher risk
    return -expected_times

# time discretization: fixed cut points, fitted on the training targets only
# (validation targets are not transformed because scoring uses the raw times)
cuts = np.arange(0, 361, 30, dtype=np.float32)
labtrans = LabTransDiscreteTime(cuts=cuts)
y_train_transformed = labtrans.fit_transform(y_train['time'], y_train['event'])

DEVICE = torch.device('cpu')


def objective_deephit(trial):
    """Optuna objective for DeepHitSingle: return the validation C-index.

    Samples the network and optimizer settings from the trial, builds an MLP
    with labtrans.out_features outputs, trains it for 30 epochs on the
    discretized training targets and scores it with the negative expected time
    from get_risk_scores_repurchase_expected_time. The train/validation
    C-index and the mean cumulative/dynamic AUC on validation are stored as
    trial user attributes. On any exception the traceback is printed and the
    trial scores -inf.
    """
    # hyperparameter ranges (sampled in this order)
    batch_norm = trial.suggest_categorical('batch_norm', [True, False])
    dropout_prob = trial.suggest_float('dropout_prob', 0.1, 0.3)
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    num_layers = trial.suggest_int('num_layers', 2, 3)
    num_nodes = trial.suggest_categorical('num_nodes', [32, 64, 128])
    output_bias = trial.suggest_categorical('output_bias', [True, False])
    reg = trial.suggest_float('reg', 1e-5, 1e-3, log=True)

    # evaluation times for the cumulative/dynamic AUC
    auc_times = list(range(1, 361))

    try:
        print(f"data shapes: X_train={X_train.shape}, y_train_transformed={y_train_transformed[0].shape if isinstance(y_train_transformed, tuple) else len(y_train_transformed)}")
        print(f"labtrans cuts: {len(labtrans.cuts)}, out_features: {labtrans.out_features}")

        # build network: num_layers hidden blocks followed by the output layer
        modules = []
        in_features = X_train.shape[1]
        for _ in range(num_layers):
            modules += [nn.Linear(in_features, num_nodes), nn.ReLU()]
            if batch_norm:
                modules.append(nn.BatchNorm1d(num_nodes))
            modules.append(nn.Dropout(dropout_prob))
            in_features = num_nodes
        modules.append(nn.Linear(in_features, labtrans.out_features, bias=output_bias))
        net = nn.Sequential(*modules).to(DEVICE)

        # model setup: `reg` is applied as AdamW weight decay
        optimizer = AdamW(net.parameters(), lr=lr, weight_decay=reg)
        deephit = DeepHitSingle(net, optimizer, alpha=0.2, sigma=0.1,
                                duration_index=labtrans.cuts, device=DEVICE)

        print("starting model training")
        deephit.fit(X_train.values, y_train_transformed,
                    batch_size=256, epochs=30, verbose=False)

        print("model training complete, calculating predictions")
        train_inputs = torch.tensor(X_train.values, dtype=torch.float32).to(DEVICE)
        val_inputs = torch.tensor(X_val.values, dtype=torch.float32).to(DEVICE)

        print(f"input tensor shapes: X_train={train_inputs.shape}, X_val={val_inputs.shape}")

        # risk scores for both splits
        risk_train = get_risk_scores_repurchase_expected_time(deephit, train_inputs)
        risk_val = get_risk_scores_repurchase_expected_time(deephit, val_inputs)

        print("risk scores calculated")

        c_train = concordance_index_censored(
            y_train_struct['event'], y_train_struct['time'], risk_train)[0]
        c_val = concordance_index_censored(
            y_val_struct['event'], y_val_struct['time'], risk_val)[0]
        auc_mean = cumulative_dynamic_auc(
            y_train_struct, y_val_struct, risk_val, auc_times)[1]

        for key, value in (('train_c_index', c_train),
                           ('val_c_index', c_val),
                           ('mean_auc', auc_mean)):
            trial.set_user_attr(key, value)

        return c_val

    except Exception as e:
        import traceback
        print(f"trial failed with full traceback:")
        print(traceback.format_exc())
        return float('-inf')

# optimization: maximize the validation C-index over 20 trials
study_deephit = optuna.create_study(direction='maximize')
study_deephit.optimize(objective_deephit, n_trials=20, show_progress_bar=True)

# Bind the best trial before branching so the name exists even when every
# trial failed (previously it was only defined in the success branch).
best_trial_deephit = study_deephit.best_trial

# save results; objective_deephit returns -inf for a failed trial, so a best
# value of -inf means no trial succeeded
if best_trial_deephit.value != float('-inf'):
    print(f"\nBest score: {best_trial_deephit.value:.5f}")
    print(f"Train c-index: {best_trial_deephit.user_attrs['train_c_index']:.5f}")
    print(f"Val c-index: {best_trial_deephit.user_attrs['val_c_index']:.5f}")
    print(f"Mean AUC: {best_trial_deephit.user_attrs['mean_auc']:.5f}")
    print(f"Best params: {best_trial_deephit.params}")

    # save parameters (json is already imported in the imports cell)
    with open('../../models/best_deephit_params.json', 'w') as f:
        json.dump(study_deephit.best_params, f, indent=2)
else:
    print("all trials failed")

In [None]:
def get_risk_scores_from_surv_df(surv_probs_df):
    """Negative expected survival time for every row of a survival DataFrame.

    The column labels are used as time points and each row as one survival
    curve. The expected time is the sum of S(t_i) * (t_i - t_{i-1}) with
    t_0 = 0, and the result is negated.

    Returns a 1-D numpy array with one value per row.
    """
    # column labels are the time points; put t=0 in front so the first
    # interval width is (t_1 - 0)
    grid_with_zero = np.insert(surv_probs_df.columns.values, 0, 0)

    # consecutive differences: [t_1-0, t_2-t_1, t_3-t_2, ...]
    interval_widths = grid_with_zero[1:] - grid_with_zero[:-1]

    # area under each curve: probabilities weighted by interval widths
    area_under_curve = surv_probs_df.values @ interval_widths

    # negated expected time
    return -area_under_curve


# load the best hyperparameters written by the DeepHit search cell
with open('../../models/best_deephit_params.json', 'r') as f:
    best_params = json.load(f)

print("best hyperparameters:")
print(json.dumps(best_params, indent=2))

# time discretization
# NOTE(review): the hyperparameter search cell discretizes with a step of 30
# (np.arange(0, 361, 30)), but this final fit uses a step of 10, so the tuned
# settings are applied to a network with a different number of outputs.
# Running this cell also rebinds the globals `cuts`, `labtrans` and
# `y_train_transformed` that objective_deephit reads. Confirm which grid is
# intended; the value is left unchanged here.
cuts = np.arange(0, 361, 10, dtype=np.float32)
labtrans = LabTransDiscreteTime(cuts=cuts)
y_train_transformed = labtrans.fit_transform(y_train['time'], y_train['event'])

DEVICE = torch.device('cpu')

# build network: the same architecture recipe as in the search cell
layers = []
current_size = X_train.shape[1]

for i in range(best_params['num_layers']):
    layers.append(nn.Linear(current_size, best_params['num_nodes']))
    layers.append(nn.ReLU())
    if best_params['batch_norm']:
        layers.append(nn.BatchNorm1d(best_params['num_nodes']))
    layers.append(nn.Dropout(best_params['dropout_prob']))
    current_size = best_params['num_nodes']

layers.append(nn.Linear(current_size, labtrans.out_features, bias=best_params['output_bias']))
net = nn.Sequential(*layers).to(DEVICE)

# initialize: `reg` is applied as AdamW weight decay
optimizer = AdamW(net.parameters(), lr=best_params['lr'], weight_decay=best_params['reg'])
deephit_model = DeepHitSingle(net, optimizer, alpha=0.2, sigma=0.1,
                              duration_index=labtrans.cuts, device=DEVICE)

# train final model
print("training final model")
deephit_model.fit(X_train.values, y_train_transformed,
                  batch_size=256, epochs=30, verbose=False)

print("model training complete")

# put the network in evaluation mode before predicting
deephit_model.net.eval()

print("making predictions on test")
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32).to(DEVICE)

# predict the survival probabilities once; every metric below reuses them
with torch.no_grad():
    surv_array = deephit_model.predict_surv(X_test_tensor)

# one row per test sample, one column per entry of the model's duration index (no transpose)
deephit_test_surv_df = pd.DataFrame(
    surv_array,
    columns=deephit_model.duration_index
)
print(f"predictions generated, shape: {deephit_test_surv_df.shape}")

# risk scores using the existing DataFrame
print("calculating risk scores")
deephit_test_risks = get_risk_scores_from_surv_df(deephit_test_surv_df)

# c-index
print("calculating C-index")
deephit_test_c_index = concordance_index_censored(
    y_test_struct['event'],
    y_test_struct['time'],
    deephit_test_risks
)[0]

# cumulative dynamic AUC
print("calculating AUC")
deephit_td_auc, deephit_test_mean_auc = cumulative_dynamic_auc(
    y_train_struct, y_test_struct, deephit_test_risks, times=list(range(1, 361))
)

# linearly interpolate each survival curve onto the Brier time grid
surv_probs_for_sksurv = np.empty((deephit_test_surv_df.shape[0], len(times_for_brier)), dtype=np.float32)
for i, row in enumerate(deephit_test_surv_df.values):
    surv_probs_for_sksurv[i, :] = np.interp(
        times_for_brier,
        deephit_test_surv_df.columns.values,
        row
    )

# ibs
deephit_ibs_score = integrated_brier_score(
    y_test_struct,            # survival_train: test set used for the censoring model
    y_test_struct,            # survival_test
    surv_probs_for_sksurv,    # interpolated survival probabilities
    times_for_brier           # time points for those predictions
)

print("\nFINAL TEST RESULTS")
print(f"test C-index: {deephit_test_c_index:.5f}")  # label typo ("etst") fixed
print(f"test mean AUC: {deephit_test_mean_auc:.5f}")
print(f"IBS: {deephit_ibs_score:.4f}")

# save final model
print("saving final model")
deephit_model.save_net('../../models/final_deephit_model.pt')

# save test predictions
deephit_test_predictions = {
    'survival_probabilities': deephit_test_surv_df,
    'risk_scores': deephit_test_risks,
    'c_index': deephit_test_c_index,
    'mean_auc': deephit_test_mean_auc,
    # stored so the saved metrics match the Cox and RSF prediction files
    'ibs': deephit_ibs_score,
    # The survival DataFrame has one column per entry of `cuts` (its columns are
    # the duration index), so save the full grid; `cuts[:-1]` was one entry
    # short of the saved survival probabilities.
    'time_bins': cuts.tolist()
}

with open('../../models/deephit_test_predictions.pkl', 'wb') as f:
    pickle.dump(deephit_test_predictions, f)

print("model and predictions saved")

## Deepsurv

In [None]:
# hyperparameter search
# (time, event) value arrays for the training and validation splits
y_train_tuple = tuple(y_train[column].values for column in ('time', 'event'))
y_val_tuple = tuple(y_val[column].values for column in ('time', 'event'))


def objective_deepsurv(trial):
    """Optuna objective: fit a CoxPH network for one trial and score it on validation data.

    Samples the architecture and optimiser settings from ``trial``, trains for a
    fixed 30 epochs on the training split, and records the train C-index,
    validation C-index and mean validation AUC as trial user attributes.

    Returns the validation C-index, or ``float('-inf')`` if anything inside the
    training/evaluation block raises (the traceback is printed).
    """
    # search space — the suggestion order is kept fixed
    use_batch_norm = trial.suggest_categorical('batch_norm', [True, False])
    dropout_rate = trial.suggest_float('dropout_prob', 0.1, 0.5)
    learning_rate = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    n_hidden = trial.suggest_int('num_layers', 2, 4)
    width = trial.suggest_categorical('num_nodes', [32, 64, 128])
    use_output_bias = trial.suggest_categorical('output_bias', [True, False])
    weight_decay = trial.suggest_float('reg', 1e-5, 1e-2, log=True)

    try:
        # hidden stack: Linear -> ReLU -> (BatchNorm) -> Dropout, repeated n_hidden times
        modules = []
        in_features = X_train.shape[1]
        for _ in range(n_hidden):
            modules.append(nn.Linear(in_features, width))
            modules.append(nn.ReLU())
            if use_batch_norm:
                modules.append(nn.BatchNorm1d(width))
            modules.append(nn.Dropout(dropout_rate))
            in_features = width

        # single output node: the risk score
        modules.append(nn.Linear(in_features, 1, bias=use_output_bias))
        net = nn.Sequential(*modules).to(DEVICE)

        model = CoxPH(
            net,
            AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay),
            device=DEVICE,
        )

        print("start coxph model training")
        # number of epochs is fixed
        model.fit(X_train.values, y_train_tuple, batch_size=256, epochs=30, verbose=False)

        print("training complete, calculating predictions")

        def _predict_risks(features):
            # float32 tensor on DEVICE -> squeezed model output used as risk score
            inputs = torch.tensor(features.values, dtype=torch.float32).to(DEVICE)
            return model.predict(inputs).squeeze()

        train_risks = _predict_risks(X_train)
        val_risks = _predict_risks(X_val)

        print("risk scores calculated")

        def _c_index(labels, risks):
            return concordance_index_censored(labels['event'], labels['time'], risks)[0]

        train_c_index = _c_index(y_train_struct, train_risks)
        val_c_index = _c_index(y_val_struct, val_risks)

        # daily grid from the smallest training time up to day 360
        times = np.arange(int(y_train['time'].min()), 361, 1)
        mean_auc = cumulative_dynamic_auc(y_train_struct, y_val_struct, val_risks, times)[1]

        # store trial results
        for key, value in (
            ('train_c_index', train_c_index),
            ('val_c_index', val_c_index),
            ('mean_auc', mean_auc),
        ):
            trial.set_user_attr(key, value)

        return val_c_index

    except Exception as e:
        import traceback
        # a failed trial is reported as the worst possible score instead of aborting the study
        print("trial failed with full traceback:")
        print(traceback.format_exc())
        return float('-inf')

# run the Optuna study (20 trials, maximising the validation C-index returned by the objective)
study_deepsurv = optuna.create_study(direction='maximize')
study_deepsurv.optimize(objective_deepsurv, n_trials=20, show_progress_bar=True)

# report and persist the winner; a best value of -inf means every trial hit the failure path
if study_deepsurv.best_trial.value > float('-inf'):
    best_trial_deepsurv = study_deepsurv.best_trial
    print("\ndeepSurv best trial")
    print(f"best score (val C-index): {best_trial_deepsurv.value:.5f}")
    print(f"train C-index: {best_trial_deepsurv.user_attrs['train_c_index']:.5f}")
    print(f"mean val AUC: {best_trial_deepsurv.user_attrs['mean_auc']:.5f}")
    print(f"best params: {best_trial_deepsurv.params}")

    # json is already imported at the top of the notebook
    with open('../../models/best_deepsurv_params.json', 'w') as f:
        f.write(json.dumps(best_trial_deepsurv.params, indent=2))
else:
    print("all trials failed for deepsurv")

In [None]:
# prediction
# reload the hyperparameters saved by the search cell
with open('../../models/best_deepsurv_params.json', 'r') as f:
    best_params = json.load(f)

print("best deepsurv hyperparameters:")
print(json.dumps(best_params, indent=2))

DEVICE = torch.device('cpu')

# build network: num_layers x [Linear -> ReLU -> (BatchNorm) -> Dropout]
layers = []
current_size = X_train.shape[1]

for i in range(best_params['num_layers']):
    layers += [nn.Linear(current_size, best_params['num_nodes']), nn.ReLU()]
    if best_params['batch_norm']:
        layers += [nn.BatchNorm1d(best_params['num_nodes'])]
    layers += [nn.Dropout(best_params['dropout_prob'])]
    current_size = best_params['num_nodes']

# output layer for CoxPH must have 1 node for the risk score
layers += [nn.Linear(current_size, 1, bias=best_params['output_bias'])]
net = nn.Sequential(*layers).to(DEVICE)

# initialize model
optimizer = AdamW(
    net.parameters(),
    lr=best_params['lr'],
    weight_decay=best_params['reg'],
)
deepsurv_model = CoxPH(net, optimizer, device=DEVICE)

# train the final model on the training data

print("training final model")
# pycox models are fitted on a (time, event) tuple
y_train_tuple = (y_train['time'].values, y_train['event'].values)
deepsurv_model.fit(
    X_train.values,
    y_train_tuple,
    batch_size=256,
    epochs=30,
    verbose=False,
)

print("model training complete")

# MAKE PREDICTIONS

# switch the network to evaluation mode before predicting
deepsurv_model.net.eval()
print("predictions on test set")

# the baseline hazards are computed from the training data
deepsurv_model.compute_baseline_hazards(X_train.values, y_train_tuple)

# test features as a float32 tensor placed on DEVICE
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32, device=DEVICE)

# risk scores for the C-index and AUC
deepsurv_test_risks = deepsurv_model.predict(X_test_tensor).squeeze()

# full survival curves for the IBS
deepsurv_test_surv_df = deepsurv_model.predict_surv_df(X_test_tensor)

print("predictions generated")

# PERFORMANCE METRICS

print("calculating C-index")
deepsurv_test_c_index = concordance_index_censored(
    y_test_struct['event'],
    y_test_struct['time'],
    deepsurv_test_risks
)[0]

# cumulative dynamic AUC, evaluated at every day from 1 to 360
print("calculating AUC")
deepsurv_td_auc, deepsurv_test_mean_auc = cumulative_dynamic_auc(
    y_train_struct, y_test_struct, deepsurv_test_risks, times=list(range(1, 361))
)

print("calculating IBS")

# deepsurv_test_surv_df was already computed in the prediction step above, so it is
# reused here instead of calling predict_surv_df a second time.
# Its rows are time points and its columns are samples (one survival curve per column).

# interpolate survival probabilities at the exact times needed
n_samples = len(X_test)
surv_probs_for_ibs = np.zeros((n_samples, len(times_for_brier)))

for i in range(n_samples):
    # survival curve of sample i: index = times, values = survival probabilities
    surv_func = deepsurv_test_surv_df.iloc[:, i]
    # before the first grid time survival is 1.0; after the last it stays at the final value
    surv_probs_for_ibs[i, :] = np.interp(
        times_for_brier,
        surv_func.index,
        surv_func.values,
        left=1.0,
        right=surv_func.values[-1]
    )

# IBS
# NOTE(review): the test labels are passed as both survival_train and survival_test,
# matching the DeepHit evaluation — confirm this is intended rather than y_train_struct
deepsurv_ibs_score = integrated_brier_score(
    survival_train=y_test_struct,
    survival_test=y_test_struct,
    estimate=surv_probs_for_ibs,
    times=times_for_brier
)

print(f"\nFINAL DEEPSURV TEST RESULTS")
print(f"test C-index: {deepsurv_test_c_index:.5f}")
print(f"test mean AUC: {deepsurv_test_mean_auc:.5f}")
print(f"IBS: {deepsurv_ibs_score:.4f}")

# SAVING
# write the trained network weights and a pickle with the test-set predictions and metrics
print("saving final model and predictions")
deepsurv_model.save_net('../../models/final_deepsurv_model.pt')

deepsurv_test_predictions = {
    'survival_probabilities': deepsurv_test_surv_df,
    'risk_scores': deepsurv_test_risks,
    'c_index': deepsurv_test_c_index,
    'mean_auc': deepsurv_test_mean_auc,
    'ibs': deepsurv_ibs_score
}
with open('../../models/deepsurv_test_predictions.pkl', 'wb') as f:
    pickle.dump(deepsurv_test_predictions, f)

print("model and predictions saved")

## Model comparison

In [None]:
# c index, ibs
# gather the test C-index and IBS of every model into one comparison table

# Cox model
cox_test_c_index, cox_test_ibs = cox_test_predictions['c_index'], cox_test_predictions['ibs']

# Random Survival Forest
rsf_test_c_index, rsf_test_ibs = rsf_test_predictions['c_index'], rsf_test_predictions['ibs']

# XGBSE (IBS taken from the standalone score variable)
xgbse_test_c_index, xgbse_test_ibs = xgbse_test_predictions['c_index'], xgbse_ibs_score

# DeepHit (IBS taken from the standalone score variable)
deephit_test_c_index, deephit_test_ibs = deephit_test_predictions['c_index'], deephit_ibs_score

# DeepSurv
deepsurv_test_c_index, deepsurv_test_ibs = (
    deepsurv_test_predictions['c_index'],
    deepsurv_test_predictions['ibs'],
)

# show results: one row per model
results_df = pd.DataFrame(
    [
        ('Cox', cox_test_c_index, cox_test_ibs),
        ('Random Survival Forest', rsf_test_c_index, rsf_test_ibs),
        ('XGBSE', xgbse_test_c_index, xgbse_test_ibs),
        ('DeepHit', deephit_test_c_index, deephit_test_ibs),
        ('DeepSurv', deepsurv_test_c_index, deepsurv_test_ibs),
    ],
    columns=['Model', 'Test C-index', 'Test IBS'],
)

print(results_df)

In [None]:
# Shared plotting configuration for all figures below.
# NOTE(review): the rcParamsDefault update further down runs AFTER these two calls, so it
# appears to overwrite whatever the style sheet and seaborn context set — confirm the
# intended order before relying on either of them.
plt.style.use('seaborn-v0_8-paper')
sns.set_context("paper", font_scale=1.1)

# Global rcParams
# first reset every rcParam to matplotlib's defaults, then apply the overrides below
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 11,
    'axes.titlesize': 13,
    'axes.labelsize': 11,
    'legend.frameon': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'lines.linewidth': 2,
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'serif']
    
})


# Consistent colors
# five-colour palette, also registered as the seaborn default palette
COLORS = ['#2E86C1', '#E74C3C', '#28B463', '#F39C12', '#8E44AD']
sns.set_palette(COLORS)

In [None]:
# exponentiate the log-hazard risk scores of the Cox and DeepSurv models
cox_test_risks_exp, deepsurv_test_risks_exp = map(np.exp, (cox_test_risks, deepsurv_test_risks))

In [None]:
# note: most of the plotting code below was written with the help of LLMs

In [None]:
# risk score distributions
# one KDE panel per model, contrasting repurchasers (event==1) with non-repurchasers (event==0)

model_names = [
    ("Cox", cox_test_risks_exp),
    ("RSF", rsf_test_risks),
    ("XGBSE", xgbse_test_risks),
    ("DeepHit", deephit_test_risks),
    ("DeepSurv", deepsurv_test_risks_exp)
]

event = y_test['event'].values

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for ax, (name, risks) in zip(axes, model_names):
    # numpy array so the boolean event masks can index it
    risks_np = np.asarray(risks)
    for event_flag, group_label, group_color in (
        (1, "Repurchase (event=1)", "tab:blue"),
        (0, "No Repurchase (event=0)", "tab:orange"),
    ):
        sns.kdeplot(risks_np[event == event_flag], label=group_label, fill=True, color=group_color, ax=ax)
    ax.set_title(f"{name}")
    ax.set_xlabel("Risk Score")
    ax.set_ylabel("Density")

# one shared legend below the panels, built from the first panel's entries
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.05), ncol=2)

plt.tight_layout()
plt.show()



In [None]:

from lifelines import KaplanMeierFitter

def plot_calibration_on_ax(ax, model_name, predictions, pred_type, y_test_data, is_first_plot=False):
    """Draw calibration curves at 90, 180 and 360 days for one model on ``ax``.

    For each horizon the predicted event probability (1 - survival probability)
    is binned into up to 10 groups; the mean prediction per bin is plotted
    against the Kaplan-Meier event estimate of that bin at the same horizon.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    model_name : str, used as the subplot title.
    predictions : survival predictions, interpreted according to ``pred_type``.
    pred_type : {'funcs', 'df', 'df_transposed'}
        'funcs'         - iterable of callables, one per sample, evaluated at the horizon;
        'df'            - DataFrame with one column per time point (float or int label);
        'df_transposed' - DataFrame with one row per time point (float label).
    y_test_data : object supporting ``['time']`` and ``['event']`` lookups, in the
        same row order as the predictions.
    is_first_plot : bool, default False
        Only the first panel labels the diagonal, so a legend collected from it
        holds exactly one 'Perfect calibration' entry.

    Raises
    ------
    ValueError
        If ``pred_type`` is not one of the supported values.
    """
    if pred_type not in ('funcs', 'df', 'df_transposed'):
        raise ValueError(
            f"unknown pred_type {pred_type!r}; expected 'funcs', 'df' or 'df_transposed'"
        )

    time_points = [90, 180, 360]
    colors = ['#FFC300',       '#FF5733',        '#8E44AD']
    linestyles = ['-', '-', '-']

    # '_nolegend_' keeps the diagonal of the other panels out of any legend
    diag_label = 'Perfect calibration' if is_first_plot else '_nolegend_'
    ax.plot([0, 1], [0, 1], 'k', linestyle=':', label=diag_label, linewidth=0.8)

    # labels as plain arrays: predictions and labels are matched by position, never
    # by pandas index (differing indexes would otherwise be aligned and yield NaNs)
    observed_times = np.asarray(y_test_data['time'])
    observed_events = np.asarray(y_test_data['event'])

    for i, t0 in enumerate(time_points):
        # get survival probabilities based on prediction type
        if pred_type == 'funcs':
            survival_probs_at_t0 = np.array([fn(t0) for fn in predictions])
        elif pred_type == 'df':
            # the time column may be labelled as float or int
            column = float(t0) if float(t0) in predictions.columns else int(t0)
            survival_probs_at_t0 = predictions[column]
        else:  # 'df_transposed'
            survival_probs_at_t0 = predictions.T[float(t0)]

        event_probs_at_t0 = 1 - np.asarray(survival_probs_at_t0)

        # calculation logic
        cal_df = pd.DataFrame({
            'predicted_prob': event_probs_at_t0,
            'time': observed_times,
            'event': observed_events
        })
        try:
            # decile bins; duplicate edges are dropped when predictions are heavily tied
            cal_df['prob_bin'] = pd.qcut(cal_df['predicted_prob'], q=10, labels=False, duplicates='drop')
        except ValueError:
            # fall back to equal-width bins when quantile binning is not possible
            cal_df['prob_bin'] = pd.cut(cal_df['predicted_prob'], bins=10, labels=False, include_lowest=True)

        binned_groups = cal_df.groupby('prob_bin')
        mean_predicted_prob = binned_groups['predicted_prob'].mean()

        # observed event probability per bin: 1 - Kaplan-Meier survival estimate at t0
        observed_frequency = []
        for _, group in binned_groups:
            kmf = KaplanMeierFitter()
            kmf.fit(group['time'], event_observed=group['event'])
            observed_frequency.append(1 - kmf.predict(t0, interpolate=True))

        # plotting, ordered by mean predicted probability
        order = np.argsort(mean_predicted_prob.values)
        ax.plot(mean_predicted_prob.values[order], np.asarray(observed_frequency)[order],
                c=colors[i],
                linestyle=linestyles[i],
                linewidth=3,
                label=f'{t0} days')

    # subplot formatting: square panel limited to [0, 0.35] on both axes
    upper_limit = 0.35
    ax.set_xlim(0, upper_limit)
    ax.set_ylim(0, upper_limit)
    ax.set_title(f'{model_name}')
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, linestyle='--', alpha=0.2)
    

# CREATE FIGURE AND PLOT MODELS


# one row of five panels sharing the y-axis
fig, axes = plt.subplots(1, 5, figsize=(24, 5), sharey=True)

# (title, predictions, prediction layout) for every model
models_to_plot = [
    ('Cox', cox_surv_funcs, 'funcs'),
    ('RSF', rsf_surv_funcs, 'funcs'),
    ('XGBSE', xgbse_test_survival_probs, 'df'),
    ('DeepHit', deephit_test_surv_df, 'df'),
    ('DeepSurv', deepsurv_test_surv_df, 'df_transposed')
]

# one model per panel; only the first panel is flagged as the first plot
for panel_number, (panel_ax, (model_name, preds, pred_type)) in enumerate(zip(axes, models_to_plot)):
    plot_calibration_on_ax(
        panel_ax, model_name, preds, pred_type, y_test,
        is_first_plot=(panel_number == 0),
    )

# figure formatting
# y label on the leftmost panel only (the axis is shared)
axes[0].set_ylabel('Observed Frequency of Repurchase')
# common x label centred under the panels
fig.text(
    0.5, 0.05,
    'Mean Predicted Probability of Repurchase',
    ha='center', va='center', fontsize=12,
)

# single legend for the whole figure, taken from the first panel
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels,
    loc='lower center', bbox_to_anchor=(0.5, -0.05), ncol=4, fontsize=12,
)

# tidy the layout, then reserve room at the bottom for the legend
plt.tight_layout()
plt.subplots_adjust(bottom=0.15)
plt.show()

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

# line colours for the 90 / 180 / 360 day ROC snapshots
ROC_COLORS = ['#FFC300',       '#FF5733',        '#8E44AD']

# (model name, test risk scores)
models_to_plot = [
    ('Cox', cox_test_risks),
    ('RSF', rsf_test_risks),
    ('XGBSE', xgbse_test_risks),
    ('DeepHit', deephit_test_risks),
    ('DeepSurv', deepsurv_test_risks)
]

# AUC curves and their means are computed once here and reused when plotting:
# every 10th time unit between the smallest and largest observed test time
first_time = int(y_test['time'].min())
last_time = int(y_test['time'].max())
evaluation_times = np.arange(first_time, last_time, 10)
all_auc_scores = pd.DataFrame(index=evaluation_times)
mean_aucs = {}

print("calculating cumulative/dynamic AUC over time")
for model_name, risks in models_to_plot:
    auc_curve, integrated_mean_auc = cumulative_dynamic_auc(
        y_train_struct, y_test_struct, risks, evaluation_times
    )
    all_auc_scores[model_name], mean_aucs[model_name] = auc_curve, integrated_mean_auc
print("calculation complete")


# HELPER FUNCTION FOR ROC SNAPSHOTS

def time_dependent_roc_at_t(y_true, risk_scores, time_point):
    """ROC curve for "event occurred by ``time_point``" versus "still event-free after it".

    Samples with an event at or before ``time_point`` are positives; samples whose
    observed time exceeds ``time_point`` are negatives. Samples censored at or
    before ``time_point`` belong to neither group and are left out.

    Parameters
    ----------
    y_true : object supporting ``['time']`` and ``['event']`` lookups, one entry per sample.
    risk_scores : array-like of per-sample risk scores, in the same row order as ``y_true``.
    time_point : horizon at which the outcome is dichotomised.

    Returns
    -------
    (fpr, tpr, thresholds)
        The result of ``roc_curve``. If either group is empty the chance diagonal
        ``[0, 1], [0, 1]`` is returned with placeholder thresholds, so callers can
        always unpack three values.
    """
    # plain boolean arrays so the masks index positionally
    events_by_t = np.asarray((y_true['time'] <= time_point) & (y_true['event'] == 1))
    no_events_by_t = np.asarray(y_true['time'] > time_point)
    n_events, n_event_free = int(events_by_t.sum()), int(no_events_by_t.sum())

    if n_events == 0 or n_event_free == 0:
        # degenerate case: no ROC can be computed; return the same 3-tuple shape as roc_curve
        return np.array([0, 1]), np.array([0, 1]), np.array([np.inf, -np.inf])

    # coerce to numpy so boolean-mask indexing works for any array-like input
    scores = np.asarray(risk_scores)
    y_true_binary = np.concatenate([np.ones(n_events), np.zeros(n_event_free)])
    y_scores_binary = np.concatenate([scores[events_by_t], scores[no_events_by_t]])
    return roc_curve(y_true_binary, y_scores_binary)


# 2x5 FIGURE: top row = AUC over time, bottom row = ROC snapshots, one column per model
fig, axes = plt.subplots(2, 5, figsize=(24, 10), sharex=False, sharey=False)

time_points = [90, 180, 360]

for ax_top, ax_bottom, (model_name, risks) in zip(axes[0], axes[1], models_to_plot):
    # AUC over time (top), with the mean AUC as a dashed reference line
    model_mean_auc = mean_aucs[model_name]
    ax_top.axhline(
        model_mean_auc,
        color='gray', linestyle='--', linewidth=1.5, alpha=0.9,
        label=f'Mean AUC = {model_mean_auc:.3f}',
    )
    ax_top.plot(evaluation_times, all_auc_scores[model_name], color=COLORS[0], linewidth=2.5)
    ax_top.set_title(model_name)
    ax_top.grid(True, linestyle='--', alpha=0.6)
    ax_top.legend(loc='lower right', fontsize=9)
    ax_top.set_ylim(0.45, 1.0)

    # ROC curves (bottom), one per horizon
    for t, curve_color in zip(time_points, ROC_COLORS):
        fpr, tpr, _ = time_dependent_roc_at_t(y_test, risks, t)
        auc_score = auc(fpr, tpr)
        ax_bottom.plot(fpr, tpr, color=curve_color, label=f'AUC at {t} days = {auc_score:.3f}', lw=2)

    # chance diagonal
    ax_bottom.plot([0, 1], [0, 1], 'k--', alpha=0.6)
    ax_bottom.set_aspect('equal', adjustable='box')
    ax_bottom.legend(fontsize=9)

# figure formatting
# x labels on the bottom row; axes are tied to the first panel of their row
for top_panel, bottom_panel in zip(axes[0], axes[1]):
    bottom_panel.set_xlabel('False Positive Rate')
    top_panel.sharey(axes[0, 0])
    bottom_panel.sharey(axes[1, 0])
    bottom_panel.sharex(axes[1, 0])

axes[0, 0].set_ylabel('Cumulative/Dynamic AUC')
axes[1, 0].set_ylabel('True Positive Rate')

# adjust layout, leaving a small margin on the left and at the top
plt.tight_layout(rect=[0.02, 0.0, 1, 0.95])
plt.show()

## Model explainability

In [None]:
import pandas as pd
import survshap
from survshap import SurvivalModelExplainer, ModelSurvSHAP
from sksurv.util import Surv

def stratified_sample(X, y, n_samples=150, random_state=42):
    """Sample rows of ``X`` so the event share roughly matches the event rate in ``y``.

    The number of event rows is ``int(n_samples * event_rate)`` and the remainder
    is drawn from the non-event rows. Each quota is capped at the size of its
    stratum, so fewer than ``n_samples`` rows can be returned.

    Parameters
    ----------
    X : DataFrame of features.
    y : DataFrame with an 'event' column, indexed like ``X``.
    n_samples : target number of rows.
    random_state : seed used for both draws.

    Returns
    -------
    DataFrame with the sampled event rows first, followed by the non-event rows.
    """
    n_event_1 = int(n_samples * y['event'].mean())
    n_event_0 = n_samples - n_event_1

    def _draw(event_flag, quota):
        # labels of the stratum, then a seeded draw capped at the stratum size
        stratum_index = X[y['event'] == event_flag].index
        return X.loc[stratum_index].sample(
            n=min(quota, len(stratum_index)),
            random_state=random_state,
        )

    return pd.concat([_draw(1, n_event_1), _draw(0, n_event_0)])

# draw 150-row stratified samples from the train and test features
X_train_sample = stratified_sample(X_train, y_train, n_samples=150)
X_test_sample = stratified_sample(X_test, y_test, n_samples=150)

# matching structured survival arrays for the sampled rows
y_train_sample = Surv.from_arrays(
    event=y_train.loc[X_train_sample.index, 'event'].astype(bool),
    time=y_train.loc[X_train_sample.index, 'time'],
)

y_test_sample = Surv.from_arrays(
    event=y_test.loc[X_test_sample.index, 'event'].astype(bool),
    time=y_test.loc[X_test_sample.index, 'time'],
)

# wrap the random survival forest with the sampled training data
explainer = SurvivalModelExplainer(
    model=rsf_model,
    data=X_train_sample,
    y=y_train_sample
)

# explain the sampled test observations
rsf_model_survshap = ModelSurvSHAP(
    calculation_method="sampling",
    B=20,
    random_state=1212,
)
rsf_model_survshap.fit(explainer=explainer, new_observations=X_test_sample)
print("ModelSurvSHAP fitted")

In [None]:
# first max_vars rows of the SurvSHAP result table, one line per feature
max_vars = 8
df_to_plot = rsf_model_survshap.result.head(max_vars)
# time columns are named 't = <value>'; the number after '=' is the x position
time_cols = [col for col in df_to_plot.columns if col.startswith('t = ')]
time_values = [float(col.split('=')[1]) for col in time_cols]

# plotting

# color palette
colors = ['#ae2c87', '#46bac2', '#ffa58c', '#f05a71', '#8bdcbe', '#4378bf', '#FED61E', '#371ea3']

# figure and axes
fig, ax = plt.subplots(figsize=(20, 6))

# one line per feature; colours cycle if there are more features than palette entries
for position, (_, feature_row) in enumerate(df_to_plot.iterrows()):
    ax.plot(
        time_values,
        feature_row[time_cols].values,
        label=feature_row['variable_name'],
        color=colors[position % len(colors)],
        linewidth=2.5,
        alpha=0.9,
    )

# customization

# axis labels
ax.set_xlabel('Time', fontsize=12)
ax.set_ylabel('mean(|SHAP value|)', fontsize=12)

# frameless legend below plot
ax.legend(
    loc='upper center',
    bbox_to_anchor=(0.5, -0.1),
    ncol=8,
    frameon=False,
    fontsize=12,
)

# light grid; keep only the left and bottom spines
ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
for side, visible in (('top', False), ('right', False), ('left', True), ('bottom', True)):
    ax.spines[side].set_visible(visible)
fig.set_facecolor('white')

# space for the legend at the bottom
plt.subplots_adjust(bottom=0.25)

plt.show()

In [None]:
def create_inverted_survshap_beeswarm(model_survshap, top_n=8):
    """Plot a feature-importance lollipop chart next to a sign-inverted SurvSHAP beeswarm.

    Left panel: the ``top_n`` features ranked by ``aggregated_change``.
    Right panel: for the same features, one point per row of ``full_result`` with
    ``B == 0``, placed at the negated mean of its time columns and coloured by the
    feature value, min-max scaled per feature.

    Parameters
    ----------
    model_survshap : fitted explainer exposing ``result`` and ``full_result`` DataFrames.
    top_n : int, number of features to show.

    Returns
    -------
    pandas.DataFrame
        The plotted points with columns ``feature``, ``aggregated_shap``,
        ``feature_value`` and ``feature_value_norm``.

    Notes
    -----
    The vertical jitter comes from a local, fixed-seed random generator, so the
    figure is reproducible and the global numpy random state is left untouched.
    """
    # top features based on aggregated change
    importance_df = model_survshap.result[['variable_name', 'aggregated_change']].copy()
    importance_df.columns = ['feature', 'importance']
    importance_df = importance_df.sort_values('importance', ascending=False)
    top_features = importance_df.head(top_n)['feature'].tolist()

    # extract SHAP values for each observation (rows with B == 0 only)
    full_result_b0 = model_survshap.full_result[model_survshap.full_result['B'] == 0].copy()
    time_columns = [col for col in full_result_b0.columns if col.startswith('t = ')]

    plot_data = []
    for _, row in full_result_b0.iterrows():
        if row['variable_name'] in top_features:
            # INVERT the SHAP values for intuitive meaning
            mean_shap = -row[time_columns].mean()
            plot_data.append({
                'feature': row['variable_name'],
                'aggregated_shap': mean_shap,
                'feature_value': row['variable_value']
            })

    # explicit columns keep the frame usable even when no row was selected
    df_plot = pd.DataFrame(plot_data, columns=['feature', 'aggregated_shap', 'feature_value'])

    # create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8),
                                   gridspec_kw={'width_ratios': [1, 2.5]},
                                   sharey=True)

    # sorted ascending for correct order in plot
    importance_sorted = importance_df.head(top_n).sort_values('importance', ascending=True)
    feature_order = importance_sorted['feature'].tolist()

    # left plot: lollipop chart
    ax1.hlines(
        y=importance_sorted['feature'],
        xmin=0,
        xmax=importance_sorted['importance'],
        color='#4682B4',
        alpha=0.6,
        linewidth=2
    )
    ax1.scatter(
        x=importance_sorted['importance'],
        y=importance_sorted['feature'],
        s=100,
        color='#4682B4',
        alpha=0.9,
        zorder=3
    )

    # right plot: beeswarm
    # normalize for coloring (0 to 1); constant features keep the neutral 0.5
    df_plot['feature_value_norm'] = 0.5
    for feature in top_features:
        mask = df_plot['feature'] == feature
        if mask.sum() > 0:
            vals = df_plot.loc[mask, 'feature_value'].values
            if vals.std() > 0:
                normalized = (vals - vals.min()) / (vals.max() - vals.min())
                df_plot.loc[mask, 'feature_value_norm'] = normalized

    # plot with jitter from a local generator (same draws as seeding the global RNG with 42)
    rng = np.random.RandomState(42)
    scatter = None
    # feature_order ensures beeswarm aligns with lollipops
    for i, feature in enumerate(feature_order):
        feature_data = df_plot[df_plot['feature'] == feature]
        if len(feature_data) > 0:
            y_jitter = rng.normal(0, 0.08, len(feature_data))
            y_positions = i + y_jitter

            # keep a reference to the scatter plot for the colorbar
            scatter = ax2.scatter(
                feature_data['aggregated_shap'].values,
                y_positions,
                c=feature_data['feature_value_norm'].values,
                cmap='RdBu_r', vmin=0, vmax=1,
                s=50, alpha=0.6,
                edgecolors='none'
            )

    # styling

    # left plot (ax1)
    ax1.set_xlabel('Average |Aggregated SurvSHAP(t)| Value')
    ax1.set_title('(a)', loc='center', y=-0.1, weight='bold', fontsize=12)
    ax1.tick_params(axis='y', length=0)  # hide tick marks on y-axis
    ax1.spines[['top', 'right', 'left']].set_visible(False)
    ax1.xaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.25)
    ax1.set_xlim(0)

    # right plot (ax2)
    ax2.set_xlabel('Aggregated SurvSHAP(t) Value')
    ax2.set_title('(b)', loc='center', y=-0.1, weight='bold', fontsize=12)
    ax2.axvline(x=0, color='grey', linestyle='--', linewidth=1.5)
    ax2.spines[['top', 'right', 'left']].set_visible(False)
    ax2.xaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.25)

    # hide the tick marks
    ax2.tick_params(axis='y', length=0)

    # clean colorbar (only possible when at least one group of points was drawn)
    if scatter is not None:
        cbar = fig.colorbar(scatter, ax=ax2, aspect=40, pad=0.03)
        cbar.set_label('Feature Value (High / Low)', rotation=270, labelpad=15)
        cbar.outline.set_visible(False)

    # layout adjustment: leave a small margin at the top of the figure
    plt.tight_layout(rect=[0, 0, 1, 0.94])
    plt.show()

    return df_plot

# draw the figure for the fitted RSF explainer and keep the returned plot data
df_inverted_plot_data = create_inverted_survshap_beeswarm(rsf_model_survshap, top_n=8)

In [None]:
# collect all best parameters for appendix
print("BEST HYPERPARAMETERS FOR ALL MODELS\n")

# (section heading, saved parameter file) in report order:
# cox model, random survival forest, xgbse, deephit, deepsurv
_param_sections = [
    ("1. COX PROPORTIONAL HAZARDS MODEL", '../../models/best_cox_params.json'),
    ("2. RANDOM SURVIVAL FOREST", '../../models/best_rsf_params.json'),
    ("3. XGBSE (XGBoost Survival Embeddings)", '../../models/best_xgbse_params.json'),
    ("4. DEEPHIT", '../../models/best_deephit_params.json'),
    ("5. DEEPSURV", '../../models/best_deepsurv_params.json'),
]

_loaded_params = []
for _heading, _params_path in _param_sections:
    if _loaded_params:
        # blank line between sections (none after the last one)
        print()
    print(_heading)
    print("-" * 40)
    with open(_params_path, 'r') as f:
        _loaded_params.append(json.load(f))
    for param, value in _loaded_params[-1].items():
        print(f"{param}: {value}")

# keep the per-model dictionaries available under their usual names
cox_params, rsf_params, xgbse_params, deephit_params, deepsurv_params = _loaded_params

In [None]:
# full SurvSHAP results df
df = rsf_model_survshap.result

# ascending sort, so the largest aggregated_change comes last in the horizontal bar chart
df_sorted = df.sort_values('aggregated_change', ascending=True)

# figure height grows with the number of features, with a minimum of 6
plt.figure(figsize=(10, max(6, 0.4 * len(df_sorted))))
plt.barh(
    df_sorted['variable_name'],
    df_sorted['aggregated_change'],
    color='#4682B4',
    alpha=0.8,
)
plt.xlabel('Aggregated |SurvSHAP(t)| Value')
plt.ylabel('Feature')
plt.title('Feature Importance (Aggregated SurvSHAP) - Full List')
plt.tight_layout()
plt.show()