In [None]:
import os
import sklearn
import pandas as pd
import numpy as np
from sksurv.ensemble.survival_loss import CoxPH
from sksurv.preprocessing import OneHotEncoder
from sympy.codegen.fnodes import merge
from sklearn.impute import KNNImputer
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

# Load the raw table from the working directory.
raw_data = pd.read_csv("raw_data.csv")

# Data imputation: fill missing values with a 3-nearest-neighbour imputer.
# NOTE(review): the imputer is fit on every column of raw_data, which appears to
# include the outcome columns selected later ('Survival status', 'Survival months')
# as well as the categorical codes -- confirm this is intended.
imputer = KNNImputer(n_neighbors=3)
imputed_values = imputer.fit_transform(raw_data)
raw_data_imp = pd.DataFrame(imputed_values, columns=raw_data.columns).assign(
    # BMI = weight / (height / 100)^2; presumably kg and cm -- TODO confirm units.
    BMI=lambda frame: frame["Weight"] / (frame["Height"] / 100) ** 2
)

# Only the last expression of a cell is displayed, so show just the column index.
raw_data_imp.columns

# One sub-frame per condition: rows whose indicator column equals 1.
hypertension_df = raw_data_imp.loc[raw_data_imp['Hypertension'].eq(1)]
heart_failure_df = raw_data_imp.loc[raw_data_imp['Heart failure'].eq(1)]
coronary_heart_disease_df = raw_data_imp.loc[raw_data_imp['Coronary heart disease'].eq(1)]
angina_pectoris_df = raw_data_imp.loc[raw_data_imp['Angina pectoris'].eq(1)]
myocardial_infarction_df = raw_data_imp.loc[raw_data_imp['Myocardial infarction'].eq(1)]
stroke_df = raw_data_imp.loc[raw_data_imp['Stroke'].eq(1)]

# Report how many rows fall into each subset.
for disease_label, disease_subset in [
    ("Hypertension", hypertension_df),
    ("Heart failure", heart_failure_df),
    ("Coronary heart disease", coronary_heart_disease_df),
    ("Angina pectoris", angina_pectoris_df),
    ("Myocardial infarction", myocardial_infarction_df),
    ("Stroke", stroke_df),
]:
    print(f"{disease_label} cases:", len(disease_subset))

# Columns carried through unscaled: demographics, disease indicators and outcomes.
categorical_columns = [
    'Gender', 'Race', 'Education', 'Hypertension',
    'Heart failure', 'Coronary heart disease', 'Angina pectoris',
    'Myocardial infarction', 'Stroke', 'Survival status', 'Death cause',
    'Survival months',
]
# Continuous measurements that get rescaled in the next step.
# ('Urine  total arsenic' is spelled with two spaces, matching the column selected here.)
numeric_columns = [
    'Age', 'BMI', 'Urine  total arsenic', 'Urine arsenic acid',
    'Urine arsenous acid', 'Urine arsenobetaine', 'Urine arsenocholine',
    'Urine dimethylarsinic acid', 'Urine monomethylarsonic acid',
    'Blood lead', 'Urine lead', 'Blood Cadmium', 'Blood total mercury',
    'Urine mercury', 'Blood inorganic mercury', 'Urine Barium',
    'Urine cadmium', 'Urine cobalt', 'Urine cesium', 'Urine molybdenum',
    'Urine antimony', 'Urine thallium', 'Urine tungsten',
]

data_Catogory = hypertension_df[categorical_columns]
data_Int = hypertension_df[numeric_columns]
from sklearn.preprocessing import MinMaxScaler

# Rescale every continuous column to the range [0, 10].
# NOTE(review): the scaler is fit on all hypertension rows before the train/test
# split performed later in this cell -- confirm that is acceptable.
scaler = MinMaxScaler(feature_range=(0, 10))
data_scaled = pd.DataFrame(
    scaler.fit_transform(data_Int),
    columns=data_Int.columns,
).astype('float')

# Drop the original row labels so both pieces line up positionally when concatenated
# (the freshly built data_scaled already carries a default 0..n-1 index).
data_Catogory = data_Catogory.reset_index(drop=True)

# Rebuild hypertension_df as: unscaled columns followed by the scaled measurements.
hypertension_df = pd.concat([data_Catogory, data_scaled], axis=1)
# Outcome columns: follow-up time and event indicator.
survival_time = hypertension_df['Survival months']
survival_status = hypertension_df['Survival status']

# Model inputs: three demographic columns plus the scaled measurements.
feature_columns = [
    'Gender', 'Race', 'Education', 'Age', 'BMI',
    'Urine  total arsenic', 'Urine arsenic acid',
    'Urine arsenous acid', 'Urine arsenobetaine', 'Urine arsenocholine',
    'Urine dimethylarsinic acid', 'Urine monomethylarsonic acid',
    'Blood lead', 'Urine lead', 'Blood Cadmium', 'Blood total mercury',
    'Urine mercury', 'Blood inorganic mercury', 'Urine Barium',
    'Urine cadmium', 'Urine cobalt', 'Urine cesium', 'Urine molybdenum',
    'Urine antimony', 'Urine thallium', 'Urine tungsten',
]
complete_df = hypertension_df[feature_columns]

from sklearn.model_selection import train_test_split

# Assemble the (event, time) target as a NumPy structured array with fields
# 'status' (bool) and 'time' (float64).
survival_target = pd.concat([survival_time, survival_status], axis=1)
survival_target['Survival status'] = survival_target['Survival status'].map({0: False, 1: True})
event_time_pairs = list(
    zip(survival_target['Survival status'], survival_target["Survival months"])
)
survival_target = np.array(event_time_pairs, dtype=[('status', 'bool'), ('time', '<f8')])

# 70/30 split with a fixed seed so the partition is reproducible.
survival_x_train, survival_x_test, new_y_train, new_y_test = train_test_split(
    complete_df,
    survival_target,
    test_size=0.3,
    random_state=0,
)
survival_x_train.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas
import seaborn as sns
from sklearn import set_config
from sklearn.model_selection import ShuffleSplit, GridSearchCV

from sksurv.datasets import load_veterans_lung_cancer
from sksurv.column import encode_categorical
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastSurvivalSVM

set_config(display="text")  # displays text representation of estimators
sns.set_style("whitegrid")  # apply seaborn's "whitegrid" style to subsequent plots
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis
from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score,
)
from sksurv.ensemble import RandomSurvivalForest
from sksurv.tree import SurvivalTree

In [None]:
from sksurv.metrics import concordance_index_censored, cumulative_dynamic_auc
from sksurv.util import Surv

In [None]:
## CoxPH
# Configure the Cox proportional hazards model, then fit it on the training split.
coxph = CoxPHSurvivalAnalysis(
    alpha=1.0,
    n_iter=50,
    ties="breslow",
    tol=1e-5,
)
coxph.fit(survival_x_train, new_y_train)

In [None]:
# CoxPH score on the training split (scikit-survival's score() reports the concordance index).
coxph.score(survival_x_train,new_y_train)

In [None]:
# CoxPH score on the held-out test split.
coxph.score(survival_x_test,new_y_test)

In [None]:
# Evaluation grid shared by all AUC cells: every 12 months from 12 to 180.
va_times = np.arange(12, 184, 12)

# Time-dependent AUC of the CoxPH risk scores on the test split, plus its mean.
cph_risk_scores = coxph.predict(survival_x_test)
cph_auc, cph_mean_auc = cumulative_dynamic_auc(new_y_train, new_y_test, cph_risk_scores, va_times)

plt.plot(va_times, cph_auc, marker="o")
plt.axhline(cph_mean_auc, linestyle="--")
# Follow-up comes from the 'Survival months' column, so the axis is in months
# (the label previously said "days", contradicting the combined comparison plot).
plt.xlabel("months from enrollment")
plt.ylabel("time-dependent AUC")
plt.grid(True)

In [None]:
##FastKernelSurvivalSVM
from sksurv.svm import FastKernelSurvivalSVM

In [None]:

# Kernel survival SVM with a pure ranking objective (rank_ratio=1.0) and an RBF kernel.
# random_state is fixed so repeated runs give the same fitted model, matching the
# seeded train/test split.
Fksvm = FastKernelSurvivalSVM(
    alpha=0.01,
    kernel='rbf',
    max_iter=20,
    optimizer='rbtree',
    rank_ratio=1.0,
    random_state=0,
)
Fksvm.fit(survival_x_train, new_y_train)

In [None]:
# Kernel survival SVM score on the training split.
Fksvm.score(survival_x_train, new_y_train)

In [None]:
# Kernel survival SVM score on the held-out test split.
Fksvm.score(survival_x_test, new_y_test)

In [None]:

# Time-dependent AUC of the kernel SVM predictions on the test split, on the va_times grid.
Fksvm_risk_scores = Fksvm.predict(survival_x_test)
Fksvm_auc, Fksvm_mean_auc = cumulative_dynamic_auc(new_y_train, new_y_test, Fksvm_risk_scores, va_times)

plt.plot(va_times, Fksvm_auc, marker="o")
plt.axhline(Fksvm_mean_auc, linestyle="--")
# Follow-up comes from the 'Survival months' column, so the axis is in months
# (the label previously said "days").
plt.xlabel("months from enrollment")
plt.ylabel("time-dependent AUC")
plt.grid(True)

In [None]:
###GradientBoostingSurvivalAnalysis
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

In [None]:

# Gradient-boosted survival model: 500 shallow (depth-2) trees with at least 30 samples per leaf.
# random_state is fixed so repeated runs give the same fitted model, matching the
# seeded train/test split.
gbs_model = GradientBoostingSurvivalAnalysis(
    max_depth=2,
    min_samples_leaf=30,
    min_samples_split=2,
    n_estimators=500,
    random_state=0,
)

In [None]:

# Fit on the training split and display the training-set score in one chained call
# (fit() returns the fitted estimator).
gbs_model.fit(survival_x_train, new_y_train).score(survival_x_train, new_y_train)

In [None]:
# Gradient-boosting score on the held-out test split.
gbs_model.score(survival_x_test, new_y_test)

In [None]:
# Time-dependent AUC of the gradient-boosting risk scores on the test split, on the va_times grid.
gbs_model_risk_scores = gbs_model.predict(survival_x_test)
gbs_model_auc, gbs_model_mean_auc = cumulative_dynamic_auc(new_y_train, new_y_test, gbs_model_risk_scores, va_times)

plt.plot(va_times, gbs_model_auc, marker="o")
plt.axhline(gbs_model_mean_auc, linestyle="--")
# Follow-up comes from the 'Survival months' column, so the axis is in months
# (the label previously said "days").
plt.xlabel("months from enrollment")
plt.ylabel("time-dependent AUC")
plt.grid(True)

In [None]:
###RandomSurvivalForest
from sksurv.ensemble import RandomSurvivalForest

In [None]:

# Random survival forest: 500 unpruned trees with at least 5 samples per leaf.
# The forest is randomized, so random_state is fixed to make fitted results
# reproducible across kernel restarts (the train/test split is seeded the same way).
rsf_model = RandomSurvivalForest(
    max_depth=None,
    min_samples_leaf=5,
    min_samples_split=2,
    n_estimators=500,
    random_state=0,
)

In [None]:
# Fit on the training split and display the training-set score in one chained call
# (fit() returns the fitted estimator).
rsf_model.fit(survival_x_train, new_y_train).score(survival_x_train, new_y_train)

In [None]:
# Random survival forest score on the held-out test split.
rsf_model.score(survival_x_test, new_y_test)

In [None]:
# Use the cumulative hazard evaluated at each va_times point as a time-specific risk score:
# one row per test sample, one column per evaluation time.
rsf_chf_funcs = rsf_model.predict_cumulative_hazard_function(survival_x_test, return_array=False)
rsf_risk_scores = np.vstack([chf(va_times) for chf in rsf_chf_funcs])

rsf_auc, rsf_mean_auc = cumulative_dynamic_auc(new_y_train, new_y_test, rsf_risk_scores, va_times)
plt.plot(va_times, rsf_auc, marker="o")
plt.axhline(rsf_mean_auc, linestyle="--")
# Follow-up comes from the 'Survival months' column, so the axis is in months
# (the label previously said "days").
plt.xlabel("months from enrollment")
plt.ylabel("time-dependent AUC")
plt.grid(True)

In [None]:
###ExtraSurvivalTrees
from sksurv.ensemble import ExtraSurvivalTrees
from sksurv.metrics import concordance_index_censored

In [None]:

# Extremely randomized survival trees: 100 unpruned trees.
# The ensemble is randomized, so random_state is fixed to make fitted results
# reproducible across kernel restarts (the train/test split is seeded the same way).
extra_trees_model = ExtraSurvivalTrees(
    max_depth=None,
    max_leaf_nodes=None,
    min_samples_leaf=3,
    min_samples_split=10,
    min_weight_fraction_leaf=0.0,
    n_estimators=100,
    random_state=0,
)

In [None]:
# Fit the extra-trees survival model on the training split.
extra_trees_model.fit(survival_x_train, new_y_train)

In [None]:

# C-index of the extra-trees model on the TRAINING split.
predicted_risk_scores = extra_trees_model.predict(survival_x_train)
c_index = concordance_index_censored(
    event_indicator=new_y_train['status'],
    event_time=new_y_train['time'],
    estimate=predicted_risk_scores,
)
# The first element of the returned tuple is the concordance index itself.
print(f"Concordance Index (C-index): {c_index[0]}")

In [None]:

# C-index of the extra-trees model on the held-out TEST split.
predicted_risk_scores = extra_trees_model.predict(survival_x_test)
c_index = concordance_index_censored(
    event_indicator=new_y_test['status'],
    event_time=new_y_test['time'],
    estimate=predicted_risk_scores,
)
# The first element of the returned tuple is the concordance index itself.
print(f"Concordance Index (C-index): {c_index[0]}")

In [None]:
# Use the cumulative hazard evaluated at each va_times point as a time-specific risk score:
# one row per test sample, one column per evaluation time.
et_chf_funcs = extra_trees_model.predict_cumulative_hazard_function(survival_x_test, return_array=False)
et_risk_scores = np.vstack([chf(va_times) for chf in et_chf_funcs])

et_auc, et_mean_auc = cumulative_dynamic_auc(new_y_train, new_y_test, et_risk_scores, va_times)
plt.plot(va_times, et_auc, marker="o")
plt.axhline(et_mean_auc, linestyle="--")
# Follow-up comes from the 'Survival months' column, so the axis is in months
# (the label previously said "days").
plt.xlabel("months from enrollment")
plt.ylabel("time-dependent AUC")
plt.grid(True)

In [None]:
# Overlay the time-dependent AUC curves of all five models on one figure.
# Each entry: (legend name, AUC per evaluation time, mean AUC).
auc_curves = [
    ("CoxPHSurvival", cph_auc, cph_mean_auc),
    ("Random survival forest", rsf_auc, rsf_mean_auc),
    ("GradientBoostingSurvival", gbs_model_auc, gbs_model_mean_auc),
    ("FastKernelSurvivalSVM", Fksvm_auc, Fksvm_mean_auc),
    ("ExtraSurvivalTrees", et_auc, et_mean_auc),
]
for model_label, auc_values, mean_auc in auc_curves:
    plt.plot(va_times, auc_values, "o-", label=f"{model_label} (mean AUC = {mean_auc:.3f})")

plt.xlabel("months from enrollment")
plt.ylabel("time-dependent AUC")
plt.ylim(0.4, 1)
plt.legend(loc="lower right", prop={'size': 9})
plt.grid(True)
plt.show()

Compute the C-index and Brier score (with bootstrap 95% confidence intervals)

In [None]:
import numpy as np
import pandas as pd
from sksurv.metrics import concordance_index_censored, integrated_brier_score
from sksurv.util import Surv
from sklearn.utils import resample

# Rebuild the targets as structured arrays via Surv.from_arrays (arguments: event, then time).
surv_train = Surv.from_arrays(new_y_train['status'], new_y_train['time'])
surv_test = Surv.from_arrays(new_y_test['status'], new_y_test['time'])

def calculate_cindex(model, x_test, y_test):
    """Return the concordance index of ``model`` on ``(x_test, y_test)``.

    Parameters
    ----------
    model : fitted estimator exposing ``predict``.
    x_test : feature matrix passed to ``model.predict``.
    y_test : structured array with ``'status'`` (event indicator) and ``'time'`` fields.

    Returns
    -------
    The first element of the tuple returned by ``concordance_index_censored``.
    """
    cindex, *_ = concordance_index_censored(
        y_test['status'],
        y_test['time'],
        model.predict(x_test),
    )
    return cindex

def calculate_integrated_brier_score(model, x_train, y_train, x_test, y_test, time_points):
    """Return the integrated Brier score of ``model`` over ``time_points``.

    Parameters
    ----------
    model : fitted estimator exposing ``predict_survival_function``.
    x_train : accepted for signature symmetry; not used by this function.
    y_train : survival outcome of the training data, forwarded to ``integrated_brier_score``.
    x_test : feature matrix to predict survival functions for.
    y_test : survival outcome of the test data, forwarded to ``integrated_brier_score``.
    time_points : times at which each predicted survival function is evaluated.

    Returns
    -------
    The value returned by ``integrated_brier_score``.
    """
    survival_functions = model.predict_survival_function(x_test)
    # One row per sample, one column per evaluation time.
    survival_probs = np.vstack([surv_fn(time_points) for surv_fn in survival_functions])
    return integrated_brier_score(y_train, y_test, survival_probs, time_points)


# Brier-score time grid: unit steps between the 10th and 90th percentile of
# the training follow-up times.
percentile_bounds = np.percentile(new_y_train["time"], [10, 90])
lower, upper = percentile_bounds
time_points = np.arange(lower, upper + 1)

# Bootstrap configuration and the shared score accumulators.
n_iterations = 1000
cindex_scores, brier_scores_list = [], []


In [None]:
#CoxPH
# Start from empty accumulators so the summary below reflects this model only.
# Without the reset, scores pile up across the bootstrap cells and across re-runs
# of this cell, contaminating the mean and the confidence interval.
cindex_scores = []
brier_scores_list = []
# Seeded generator so the bootstrap intervals are reproducible.
bootstrap_rng = np.random.RandomState(0)

for _ in range(n_iterations):
    # Bootstrap resample of the test split (features and targets drawn together).
    x_test_resampled, y_test_resampled = resample(survival_x_test, new_y_test, random_state=bootstrap_rng)

    # C-index
    cindex = calculate_cindex(coxph, x_test_resampled, y_test_resampled)
    cindex_scores.append(cindex)

    # Integrated Brier Score
    brier_score_val = calculate_integrated_brier_score(coxph, survival_x_train, surv_train, x_test_resampled, y_test_resampled, time_points)
    brier_scores_list.append(brier_score_val)

# Percentile 95% confidence intervals
cindex_lower, cindex_upper = np.percentile(cindex_scores, [2.5, 97.5])
brier_lower, brier_upper = np.percentile(brier_scores_list, [2.5, 97.5])

print(f"C-index: {np.mean(cindex_scores):.3f}")
print(f"95% CI for C-index: [{cindex_lower:.3f}, {cindex_upper:.3f}]")
print(f"Integrated Brier Score: {np.mean(brier_scores_list):.3f}")
print(f"95% CI for Integrated Brier Score: [{brier_lower:.3f}, {brier_upper:.3f}]")

In [None]:
#SVM
# Start from an empty accumulator so the summary below reflects this model only.
# Without the reset, C-index values from previously evaluated models (and from
# re-runs of this cell) are mixed into the mean and the confidence interval.
cindex_scores = []
# Seeded generator so the bootstrap interval is reproducible.
bootstrap_rng = np.random.RandomState(0)

for _ in range(n_iterations):
    # Bootstrap resample of the test split (features and targets drawn together).
    x_test_resampled, y_test_resampled = resample(survival_x_test, new_y_test, random_state=bootstrap_rng)

    # C-index
    cindex = calculate_cindex(Fksvm, x_test_resampled, y_test_resampled)
    cindex_scores.append(cindex)

# No integrated Brier score for this model: the corresponding lines were disabled in
# the original cell, presumably because the kernel SVM does not provide
# predict_survival_function -- confirm.

# Percentile 95% confidence interval
cindex_lower, cindex_upper = np.percentile(cindex_scores, [2.5, 97.5])

print(f"C-index: {np.mean(cindex_scores):.3f}")
print(f"95% CI for C-index: [{cindex_lower:.3f}, {cindex_upper:.3f}]")

In [None]:
#gbs_model
# Start from empty accumulators so the summary below reflects this model only.
# Without the reset, scores from previously evaluated models (and from re-runs of
# this cell) are mixed into the mean and the confidence interval.
cindex_scores = []
brier_scores_list = []
# Seeded generator so the bootstrap intervals are reproducible.
bootstrap_rng = np.random.RandomState(0)

for _ in range(n_iterations):
    # Bootstrap resample of the test split (features and targets drawn together).
    x_test_resampled, y_test_resampled = resample(survival_x_test, new_y_test, random_state=bootstrap_rng)

    # C-index
    cindex = calculate_cindex(gbs_model, x_test_resampled, y_test_resampled)
    cindex_scores.append(cindex)

    # Integrated Brier Score
    brier_score_val = calculate_integrated_brier_score(gbs_model, survival_x_train, surv_train, x_test_resampled, y_test_resampled, time_points)
    brier_scores_list.append(brier_score_val)

# Percentile 95% confidence intervals
cindex_lower, cindex_upper = np.percentile(cindex_scores, [2.5, 97.5])
brier_lower, brier_upper = np.percentile(brier_scores_list, [2.5, 97.5])

print(f"C-index: {np.mean(cindex_scores):.3f}")
print(f"95% CI for C-index: [{cindex_lower:.3f}, {cindex_upper:.3f}]")
print(f"Integrated Brier Score: {np.mean(brier_scores_list):.3f}")
print(f"95% CI for Integrated Brier Score: [{brier_lower:.3f}, {brier_upper:.3f}]")

In [None]:
#rsf
# Start from empty accumulators so the summary below reflects this model only.
# Without the reset, scores from previously evaluated models (and from re-runs of
# this cell) are mixed into the mean and the confidence interval.
cindex_scores = []
brier_scores_list = []
# Seeded generator so the bootstrap intervals are reproducible.
bootstrap_rng = np.random.RandomState(0)

for _ in range(n_iterations):
    # Bootstrap resample of the test split (features and targets drawn together).
    x_test_resampled, y_test_resampled = resample(survival_x_test, new_y_test, random_state=bootstrap_rng)

    # C-index
    cindex = calculate_cindex(rsf_model, x_test_resampled, y_test_resampled)
    cindex_scores.append(cindex)

    # Integrated Brier Score
    brier_score_val = calculate_integrated_brier_score(rsf_model, survival_x_train, surv_train, x_test_resampled, y_test_resampled, time_points)
    brier_scores_list.append(brier_score_val)

# Percentile 95% confidence intervals
cindex_lower, cindex_upper = np.percentile(cindex_scores, [2.5, 97.5])
brier_lower, brier_upper = np.percentile(brier_scores_list, [2.5, 97.5])

print(f"C-index: {np.mean(cindex_scores):.3f}")
print(f"95% CI for C-index: [{cindex_lower:.3f}, {cindex_upper:.3f}]")
print(f"Integrated Brier Score: {np.mean(brier_scores_list):.3f}")
print(f"95% CI for Integrated Brier Score: [{brier_lower:.3f}, {brier_upper:.3f}]")

In [None]:
#extrasurvivaltrees
# Start from empty accumulators so the summary below reflects this model only.
# Without the reset, scores from previously evaluated models (and from re-runs of
# this cell) are mixed into the mean and the confidence interval.
cindex_scores = []
brier_scores_list = []
# Seeded generator so the bootstrap intervals are reproducible.
bootstrap_rng = np.random.RandomState(0)

for _ in range(n_iterations):
    # Bootstrap resample of the test split (features and targets drawn together).
    x_test_resampled, y_test_resampled = resample(survival_x_test, new_y_test, random_state=bootstrap_rng)

    # C-index
    cindex = calculate_cindex(extra_trees_model, x_test_resampled, y_test_resampled)
    cindex_scores.append(cindex)

    # Integrated Brier Score
    brier_score_val = calculate_integrated_brier_score(extra_trees_model, survival_x_train, surv_train, x_test_resampled, y_test_resampled, time_points)
    brier_scores_list.append(brier_score_val)

# Percentile 95% confidence intervals
cindex_lower, cindex_upper = np.percentile(cindex_scores, [2.5, 97.5])
brier_lower, brier_upper = np.percentile(brier_scores_list, [2.5, 97.5])

print(f"C-index: {np.mean(cindex_scores):.3f}")
print(f"95% CI for C-index: [{cindex_lower:.3f}, {cindex_upper:.3f}]")
print(f"Integrated Brier Score: {np.mean(brier_scores_list):.3f}")
print(f"95% CI for Integrated Brier Score: [{brier_lower:.3f}, {brier_upper:.3f}]")

Calibration curve

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sksurv.metrics import concordance_index_ipcw, cumulative_dynamic_auc
from sksurv.util import Surv
from lifelines import KaplanMeierFitter
from sklearn.model_selection import KFold

# Calibration of the CoxPH configuration, cross-validated within the test split.
from sklearn.base import clone  # refit unfitted copies instead of the trained model

# Ten evaluation times between the 5th and 95th percentile of the test follow-up times.
time_points = np.percentile(new_y_test['time'], np.linspace(5, 95, 10))

kf = KFold(n_splits=5)
all_predicted_prob = []
all_actual_prob = []

# K fold
for train_index, test_index in kf.split(survival_x_test):
    X_train, X_test = survival_x_test.iloc[train_index], survival_x_test.iloc[test_index]
    y_train, y_test = new_y_test[train_index], new_y_test[test_index]

    # Fit a fresh clone: calling coxph.fit() here would overwrite the model that was
    # trained on survival_x_train and silently change every later use of `coxph`.
    fold_model = clone(coxph).fit(X_train, y_train)
    predicted_survival = fold_model.predict_survival_function(X_test)
    # Rows: evaluation times, columns: samples of the held-out fold.
    predicted_probabilities = np.array([surv_fn(time_points) for surv_fn in predicted_survival]).T

    # Observed survival in the held-out fold (Kaplan-Meier estimate).
    kmf = KaplanMeierFitter()
    kmf.fit(y_test['time'], event_observed=y_test['status'])
    actual_probabilities = kmf.survival_function_at_times(time_points).values

    all_predicted_prob.append(predicted_probabilities.mean(axis=1))
    all_actual_prob.append(actual_probabilities)

# Average the per-fold curves.
mean_predicted_prob = np.mean(all_predicted_prob, axis=0)
mean_actual_prob = np.mean(all_actual_prob, axis=0)

plt.figure(figsize=(8, 6))
plt.plot(mean_predicted_prob, mean_actual_prob, marker='o', linestyle='-', label='Average Calibration')
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
plt.xlabel('Predicted survival probability')
plt.ylabel('Observed survival probability')
plt.title('Average Calibration Plot for Survival Analysis')
plt.legend()
plt.show()

In [None]:
########### GBS
# Calibration of the gradient-boosting configuration, cross-validated within the test split.
from sklearn.base import clone  # refit unfitted copies instead of the trained model

# Ten evaluation times between the 5th and 95th percentile of the test follow-up times.
time_points = np.percentile(new_y_test['time'], np.linspace(5, 95, 10))

kf = KFold(n_splits=5)
all_predicted_prob = []
all_actual_prob = []

# K fold
for train_index, test_index in kf.split(survival_x_test):
    X_train, X_test = survival_x_test.iloc[train_index], survival_x_test.iloc[test_index]
    y_train, y_test = new_y_test[train_index], new_y_test[test_index]

    # Fit a fresh clone: calling gbs_model.fit() here would overwrite the model that was
    # trained on survival_x_train and silently change every later use of `gbs_model`.
    fold_model = clone(gbs_model).fit(X_train, y_train)
    predicted_survival = fold_model.predict_survival_function(X_test)
    # Rows: evaluation times, columns: samples of the held-out fold.
    predicted_probabilities = np.array([surv_fn(time_points) for surv_fn in predicted_survival]).T

    # Observed survival in the held-out fold (Kaplan-Meier estimate).
    kmf = KaplanMeierFitter()
    kmf.fit(y_test['time'], event_observed=y_test['status'])
    actual_probabilities = kmf.survival_function_at_times(time_points).values

    all_predicted_prob.append(predicted_probabilities.mean(axis=1))
    # This append was missing, which left all_actual_prob empty and made
    # mean_actual_prob NaN, so no calibration curve could be drawn.
    all_actual_prob.append(actual_probabilities)

# Average the per-fold curves.
mean_predicted_prob = np.mean(all_predicted_prob, axis=0)
mean_actual_prob = np.mean(all_actual_prob, axis=0)

plt.figure(figsize=(8, 6))
plt.plot(mean_predicted_prob, mean_actual_prob, marker='o', linestyle='-', label='Average Calibration')
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
plt.xlabel('Predicted survival probability')
plt.ylabel('Observed survival probability')
plt.title('Average Calibration Plot for Survival Analysis')
plt.legend()
plt.show()

In [None]:
########### rsf
# Calibration of the random-survival-forest configuration, cross-validated within the test split.
from sklearn.base import clone  # refit unfitted copies instead of the trained model

# Ten evaluation times between the 5th and 95th percentile of the test follow-up times.
time_points = np.percentile(new_y_test['time'], np.linspace(5, 95, 10))

kf = KFold(n_splits=5)
all_predicted_prob = []
all_actual_prob = []

# K fold
for train_index, test_index in kf.split(survival_x_test):
    X_train, X_test = survival_x_test.iloc[train_index], survival_x_test.iloc[test_index]
    y_train, y_test = new_y_test[train_index], new_y_test[test_index]

    # Fit a fresh clone: calling rsf_model.fit() here would overwrite the model that was
    # trained on survival_x_train and silently change every later use of `rsf_model`.
    fold_model = clone(rsf_model).fit(X_train, y_train)
    predicted_survival = fold_model.predict_survival_function(X_test)
    # Rows: evaluation times, columns: samples of the held-out fold.
    predicted_probabilities = np.array([surv_fn(time_points) for surv_fn in predicted_survival]).T

    # Observed survival in the held-out fold (Kaplan-Meier estimate).
    kmf = KaplanMeierFitter()
    kmf.fit(y_test['time'], event_observed=y_test['status'])
    actual_probabilities = kmf.survival_function_at_times(time_points).values

    all_predicted_prob.append(predicted_probabilities.mean(axis=1))
    all_actual_prob.append(actual_probabilities)

# Average the per-fold curves.
mean_predicted_prob = np.mean(all_predicted_prob, axis=0)
mean_actual_prob = np.mean(all_actual_prob, axis=0)

plt.figure(figsize=(8, 6))
plt.plot(mean_predicted_prob, mean_actual_prob, marker='o', linestyle='-', label='Average Calibration')
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
plt.xlabel('Predicted survival probability')
plt.ylabel('Observed survival probability')
plt.title('Average Calibration Plot for Survival Analysis')
plt.legend()
plt.show()

In [None]:
########### extratree
# Calibration of the extra-trees configuration, cross-validated within the test split.
from sklearn.base import clone  # refit unfitted copies instead of the trained model

# Ten evaluation times between the 5th and 95th percentile of the test follow-up times.
time_points = np.percentile(new_y_test['time'], np.linspace(5, 95, 10))

kf = KFold(n_splits=5)
all_predicted_prob = []
all_actual_prob = []

# K fold
for train_index, test_index in kf.split(survival_x_test):
    X_train, X_test = survival_x_test.iloc[train_index], survival_x_test.iloc[test_index]
    y_train, y_test = new_y_test[train_index], new_y_test[test_index]

    # Fit a fresh clone: calling extra_trees_model.fit() here would overwrite the model
    # that was trained on survival_x_train and silently change every later use of it.
    fold_model = clone(extra_trees_model).fit(X_train, y_train)
    predicted_survival = fold_model.predict_survival_function(X_test)
    # Rows: evaluation times, columns: samples of the held-out fold.
    predicted_probabilities = np.array([surv_fn(time_points) for surv_fn in predicted_survival]).T

    # Observed survival in the held-out fold (Kaplan-Meier estimate).
    kmf = KaplanMeierFitter()
    kmf.fit(y_test['time'], event_observed=y_test['status'])
    actual_probabilities = kmf.survival_function_at_times(time_points).values

    all_predicted_prob.append(predicted_probabilities.mean(axis=1))
    all_actual_prob.append(actual_probabilities)

# Average the per-fold curves.
mean_predicted_prob = np.mean(all_predicted_prob, axis=0)
mean_actual_prob = np.mean(all_actual_prob, axis=0)

plt.figure(figsize=(8, 6))
plt.plot(mean_predicted_prob, mean_actual_prob, marker='o', linestyle='-', label='Average Calibration')
plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
plt.xlabel('Predicted survival probability')
plt.ylabel('Observed survival probability')
plt.title('Average Calibration Plot for Survival Analysis')
plt.legend()
plt.show()

In [None]:
##gbs_model
import shap
import numpy as np

# K-means clustering: condense the training data into 50 centroids that serve
# as the background set for the explainer.
small_background = shap.kmeans(survival_x_train, 50)

# Build a KernelExplainer around the gradient-boosting model's predict function
# and explain every test sample.
explainer = shap.KernelExplainer(gbs_model.predict, small_background)
shap_values = explainer.shap_values(survival_x_test)

# Rank features by mean absolute SHAP value and keep the 20 largest
# (argsort is ascending, so the most important feature comes last).
feature_importance = np.mean(np.abs(shap_values), axis=0)
top_indices = np.argsort(feature_importance)[-20:]

# Restrict SHAP values, feature matrix and names to those top-20 features.
# NOTE(review): the full expected_value is paired with only 20 of the SHAP columns
# below, so the plotted contributions may not add up to the model output -- confirm.
top_shap_values = shap_values[:, top_indices]
top_survival_x_test = survival_x_test.iloc[:, top_indices]
top_feature_names = list(top_survival_x_test.columns)

# Decision plot over all test samples for the top-20 features.
shap.decision_plot(
    explainer.expected_value,
    top_shap_values,
    top_survival_x_test,
    feature_names=top_feature_names,
)

# Waterfall plot for a single test sample.
sample_index = 0
sample_explanation = shap.Explanation(
    values=top_shap_values[sample_index],
    base_values=explainer.expected_value,
    data=top_survival_x_test.iloc[sample_index, :],
    feature_names=top_feature_names,
)
shap.waterfall_plot(sample_explanation)

In [None]:
# SHAP summary plot restricted to the top-20 features selected in the previous cell.
shap.summary_plot(top_shap_values, top_survival_x_test, feature_names=top_feature_names)