In [20]:
import os
import sys

# Make the project directory importable; presumably this is where the local
# `utils` / `eval_utils` modules imported in the next cell live.
# The location can be overridden with the SURVIVAL_PROJECT_DIR environment
# variable so the notebook is not tied to a single machine; the default is the
# original author's checkout path.
PROJECT_DIR = os.environ.get(
    "SURVIVAL_PROJECT_DIR",
    "/Users/shakedcaspi/Documents/tau/survival_analysis_ml/project",
)
# Guard keeps the cell idempotent: re-running it does not append duplicates.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [21]:
from pycox.evaluation import EvalSurv
import torch
import torchtuples as tt
import os
import pickle
from lifelines.utils import concordance_index

from utils import dataset_name
from eval_utils import MODEL_CLASS, get_results, DL_MODELS, ML_MODELS, MODELS
import numpy as np
import pandas as pd
# Fix the global NumPy and PyTorch RNG seeds so the stochastic parts of the
# experiments below are repeatable. torch.manual_seed returns a Generator,
# which is discarded into `_`.
np.random.seed(seed=42)
_ = torch.manual_seed(seed=42)

# Announce which dataset (configured in the project's `utils`) is being used.
print("Experiments over", dataset_name)

Experiments over METABRIC


In [22]:
# Read the pre-split train and test CSV files for the configured dataset.
train = pd.read_csv(f"../../datasets/train/{dataset_name}.csv")
test = pd.read_csv(f"../../datasets/test/{dataset_name}.csv")

# Separate the target columns (`event`, `duration`) from the feature columns.
X = train.drop(columns=["event", "duration"])
y = train[["event", "duration"]]

X_test = test.drop(columns=["event", "duration"])
y_test = test[["event", "duration"]]

# Load the best parameters found for each model and evaluate the models on the test data

In [23]:
# Gather the saved best-model statistics of every model into a single table
# with one row per model, indexed by the model name.
frames = []
for model in MODELS:
    # `with` closes the file deterministically; the original `open()` was never closed.
    # NOTE(review): pickle.load can execute arbitrary code. These files are assumed
    # to be this project's own outputs, so never point this at untrusted data.
    with open(f"statistics/{model}/best_model.pkl", "rb") as stats_file:
        stats = pickle.load(stats_file)
    df = pd.DataFrame(stats, index=[model])
    frames.append(df)

# Concatenate once instead of growing a DataFrame inside the loop. Repeated
# concat is quadratic, and seeding it with an empty frame triggers pandas'
# empty-entry concat deprecation warning.
results = pd.concat(frames, axis=0)

# One leaderboard per metric, sorted best-first: descending for the two
# concordance metrics, ascending for the integrated Brier score.
c_index_df = results[["c_index", "c_index_std", "c_index_params"]].sort_values("c_index", ascending=False)
concordance_td_df = results[["concordance_td", "concordance_td_std", "concordance_td_params"]].sort_values("concordance_td", ascending=False)
ibs_df = results[["ibs", "ibs_std", "ibs_params"]].sort_values("ibs", ascending=True)

## C-index (`c_index`)

In [24]:
# Stored best score per model with its std and parameter string, sorted by c-index (highest first).
c_index_df

Unnamed: 0,c_index,c_index_std,c_index_params
deep_surv,0.651445,0.026593,deep_surv_dropout_0.16_num_nodes_[41]_activati...
cox_time,0.649788,0.021436,"cox_time_dropout_0.1_num_nodes_[32, 32]_activa..."
rsf,0.645472,0.018103,rsf_n_estimators_100_max_depth_5_min_samples_s...
gbst,0.639102,0.017931,gbst_learning_rate_0.1_n_estimators_50_subsamp...
reg_coxph,0.633514,0.026686,reg_coxph_l1_ratio_0.9_tol_1_max_iter_1000000_...
pc_hazard,0.561752,0.022399,pc_hazard_dropout_0.1_num_nodes_[100]_activati...
deep_hit,0.506651,0.093355,"deep_hit_dropout_0.1_num_nodes_[3, 5]_activati..."


In [25]:
# Re-run every model with the parameters that gave its best stored c-index
# (`c_index_df`) and print the c-index returned by `get_results`. It receives
# both the training split (X, y) and the test split (X_test, y_test).
results = []
column = "c_index_params"

for model_name, params in zip(c_index_df.index, c_index_df[column]):
    # `get_results` is called identically for deep-learning and classical models,
    # so the former `if model_name in DL_MODELS` split (two identical branches) is gone.
    c_index, concordance_td, ibs = get_results(params, model_name, X, y, X_test, y_test)
    results.append((model_name, c_index))

# Higher c-index is better: list the best model first.
for m, val in sorted(results, key=lambda x: x[1], reverse=True):
    print(m, np.round(val, 4))

deep_surv 0.6717
rsf 0.6643
reg_coxph 0.6639
gbst 0.6552
cox_time 0.6392
deep_hit 0.5798
pc_hazard 0.4927


___

## Time-dependent concordance (`concordance_td`)

In [26]:
# Stored best score per model with its std and parameter string, sorted by concordance_td (highest first).
concordance_td_df

Unnamed: 0,concordance_td,concordance_td_std,concordance_td_params
cox_time,0.669009,0.012341,"cox_time_dropout_0.1_num_nodes_[32, 128, 128]_..."
deep_hit,0.663193,0.012744,"deep_hit_dropout_0.3_num_nodes_[28, 28, 100, 2..."
rsf,0.661191,0.01431,rsf_n_estimators_200_max_depth_5_min_samples_s...
deep_surv,0.651445,0.026593,deep_surv_dropout_0.16_num_nodes_[41]_activati...
gbst,0.639095,0.017929,gbst_learning_rate_0.1_n_estimators_50_subsamp...
reg_coxph,0.633514,0.026686,reg_coxph_l1_ratio_0.9_tol_1_max_iter_1000000_...
pc_hazard,0.563487,0.05081,pc_hazard_dropout_0.1_num_nodes_[100]_activati...


In [27]:
# Re-run every model with the parameters that gave its best stored
# concordance_td (`concordance_td_df`) and print the concordance_td returned by
# `get_results`. It receives both the training split (X, y) and the test split
# (X_test, y_test).
results = []
column = "concordance_td_params"

for model_name, params in zip(concordance_td_df.index, concordance_td_df[column]):
    # `get_results` is called identically for deep-learning and classical models,
    # so the former `if model_name in DL_MODELS` split (two identical branches) is gone.
    c_index, concordance_td, ibs = get_results(params, model_name, X, y, X_test, y_test)
    results.append((model_name, concordance_td))

# Higher concordance is better: list the best model first.
for m, val in sorted(results, key=lambda x: x[1], reverse=True):
    print(m, np.round(val, 4))

deep_surv 0.6717
reg_coxph 0.6639
rsf 0.659
cox_time 0.6577
gbst 0.6552
deep_hit 0.6038
pc_hazard 0.4976


## Integrated Brier score (`ibs`)

In [28]:
# Stored best score per model with its std and parameter string, sorted by integrated Brier score (lowest first).
ibs_df

Unnamed: 0,ibs,ibs_std,ibs_params
deep_hit,0.113506,0.001865,"deep_hit_dropout_0.1_num_nodes_[16, 32, 64, 64..."
deep_surv,0.166265,0.006674,deep_surv_dropout_0.16_num_nodes_[41]_activati...
cox_time,0.16658,0.006295,"cox_time_dropout_0.1_num_nodes_[32, 32]_activa..."
rsf,0.168998,0.004748,rsf_n_estimators_200_max_depth_5_min_samples_s...
gbst,0.170241,0.003956,gbst_learning_rate_0.1_n_estimators_50_subsamp...
reg_coxph,0.170414,0.007362,reg_coxph_l1_ratio_0.1_tol_0.1_max_iter_1000_v...
pc_hazard,0.24723,0.008502,"pc_hazard_dropout_0.3_num_nodes_[16, 16, 16, 1..."


In [29]:
# Re-run every model with the parameters that gave its best stored integrated
# Brier score (`ibs_df`) and print the IBS returned by `get_results`. It receives
# both the training split (X, y) and the test split (X_test, y_test).
results = []
column = "ibs_params"

for model_name, params in zip(ibs_df.index, ibs_df[column]):
    # `get_results` is called identically for deep-learning and classical models,
    # so the former `if model_name in DL_MODELS` split (two identical branches) is gone.
    c_index, concordance_td, ibs = get_results(params, model_name, X, y, X_test, y_test)
    results.append((model_name, ibs))

# Lower IBS is better: list the best model first.
for m, val in sorted(results, key=lambda x: x[1], reverse=False):
    print(m, np.round(val, 4))

reg_coxph 0.1361
deep_surv 0.1419
cox_time 0.1479
rsf 0.1513
gbst 0.1528
deep_hit 0.3381
pc_hazard 0.4477
