# Train/Test/Validation Split

In [11]:
from dsm import datasets
from dsm.dsm_api import DeepSurvivalMachines
import numpy as np
from sklearn.model_selection import ParameterGrid
from sksurv.metrics import concordance_index_ipcw, brier_score,cumulative_dynamic_auc

### Loading the SUPPORT dataset

In [27]:
# Load the SUPPORT dataset shipped with `dsm`. The loader's result is unpacked
# into covariates `x`, times `t` and event indicators `e`.
support_data = datasets.load_dataset('SUPPORT')
x, t, e = support_data

### Computing the evaluation horizons (quartiles of the observed event times)

In [28]:
# Evaluation horizons: the 25th, 50th and 75th percentiles of the times of
# samples whose event was observed (e == 1). `.tolist()` yields plain floats.
event_times = t[e == 1]
times = np.quantile(event_times, [0.25, 0.5, 0.75]).tolist()

### Defining the train, validation and test sizes (70% / 10% / 20%)

In [29]:
# 70% / 10% / 20% split of however many rows were loaded (9105 for the SUPPORT
# data used here). Deriving the sizes from len(x) keeps the split valid if the
# dataset changes.
n_samples = len(x)
train_size = round(n_samples * 0.70)
val_size = round(n_samples * 0.10)
# Test gets the remainder, so the three sizes always sum to exactly n_samples.
# Independent rounding of all three could over- or under-shoot by one row.
test_size = n_samples - train_size - val_size

### Splitting the data into train, test and validation sets

In [15]:
# Contiguous, non-overlapping split in stored row order:
#   [0, train_size)             -> train
#   [train_size, val_end)       -> validation
#   [val_end, len(x))           -> test
# NOTE(review): rows are not shuffled before slicing. If the loader returns rows
# in a non-random order, a seeded permutation applied first would be safer.
val_end = train_size + val_size

x_train = x[:train_size]
t_train = t[:train_size]
e_train = e[:train_size]

# The test set is everything after validation, rather than the last `test_size`
# rows. This way it can never overlap validation or leave rows unused.
x_test = x[val_end:]
t_test = t[val_end:]
e_test = e[val_end:]

x_val = x[train_size:val_end]
t_val = t[train_size:val_end]
e_val = e[train_size:val_end]

assert len(x_test) == test_size, "split sizes must sum to len(x)"

In [16]:
def _to_structured(events, durations):
    """Pack event indicators and times into one structured array for sksurv metrics.

    Parameters
    ----------
    events : array-like
        Event indicators; cast to bool.
    durations : array-like
        Times; stored as float so non-integer values are not truncated.

    Returns
    -------
    numpy.ndarray
        Structured array with fields ``('e', bool)`` and ``('t', float)``.

    Raises
    ------
    ValueError
        If ``events`` and ``durations`` differ in shape.
    """
    events = np.asarray(events)
    durations = np.asarray(durations)
    if events.shape != durations.shape:
        raise ValueError("events and durations must have the same shape")
    structured = np.empty(events.shape, dtype=[('e', bool), ('t', float)])
    structured['e'] = events
    structured['t'] = durations
    return structured


et_train = _to_structured(e_train, t_train)
et_test = _to_structured(e_test, t_test)
et_val = _to_structured(e_val, t_val)

In [17]:
# Sanity check: shapes of the train, test and validation target arrays.
print(*(arr.shape for arr in (et_train, et_test, et_val)))

(6374,) (1821,) (910,)


### Defining the parameter grid

In [26]:
# Search space: k from 1 to 10 crossed with two candidate distributions
# (20 configurations). `k` is presumably DSM's number of mixture
# components — confirm against the dsm documentation.
param_grid = {
    'k': list(range(1, 11)),
    'distribution': ['LogNormal', 'Weibull'],
}
grids = ParameterGrid(param_grid)

### Validation: grid search over `k` and `distribution`

In [19]:
# Fit one model per configuration on the training split and score it on the
# validation split at every horizon in `times`. Each configuration's mean
# metrics are collected in `scores` so the best one is chosen in code, not by eye.
scores = []
for grid in grids:
    model = DeepSurvivalMachines(k=grid['k'], distribution=grid['distribution'])
    model.fit(x_train, t_train, e_train)
    out_risk = model.predict_risk(x_val, times)
    out_survival = model.predict_survival(x_val, times)

    # out_risk has one column per horizon; each horizon is scored separately.
    cis = [
        concordance_index_ipcw(et_train, et_val, out_risk[:, i], horizon)[0]
        for i, horizon in enumerate(times)
    ]
    # brier_score returns (times, scores); index 1 is the per-horizon scores.
    brs = brier_score(et_train, et_val, out_survival, times)[1]
    print("Concordance Index:", np.mean(cis))
    print("Brier Score:", np.mean(brs))

    cdauc = [
        cumulative_dynamic_auc(et_train, et_val, out_risk[:, i], horizon)[0]
        for i, horizon in enumerate(times)
    ]
    print("Cumulative_dynamic AUC ", np.mean(cdauc))
    print(grid)

    scores.append({
        'params': grid,
        'ci': np.mean(cis),
        'brier': np.mean(brs),
        'auc': np.mean(cdauc),
    })

# Highest mean validation concordance wins; ties go to the first config seen.
best = max(scores, key=lambda s: s['ci'])
print("Best by validation concordance:", best['params'], best['ci'])

  Variable._execution_engine.run_backward(
 12%|█▏        | 1240/10000 [00:02<00:18, 461.10it/s]
100%|██████████| 1/1 [00:00<00:00,  3.54it/s]
  1%|          | 59/10000 [00:00<00:16, 585.51it/s]

Concordance Index: 0.560607191724184
Brier Score: 0.2021620748498746
Cumulative_dynamic AUC  0.5750584452945019
{'distribution': 'LogNormal', 'k': 1}


 12%|█▏        | 1240/10000 [00:02<00:14, 589.56it/s]
100%|██████████| 1/1 [00:00<00:00,  4.49it/s]
  1%|          | 59/10000 [00:00<00:16, 589.67it/s]

Concordance Index: 0.5821705308796168
Brier Score: 0.2009241698319314
Cumulative_dynamic AUC  0.5854638786373166
{'distribution': 'LogNormal', 'k': 2}


 12%|█▏        | 1240/10000 [00:02<00:14, 617.82it/s]
100%|██████████| 1/1 [00:00<00:00,  7.23it/s]
  0%|          | 35/10000 [00:00<00:28, 348.73it/s]

Concordance Index: 0.6391032999998018
Brier Score: 0.19475588885625786
Cumulative_dynamic AUC  0.667565202099054
{'distribution': 'LogNormal', 'k': 3}


 12%|█▏        | 1240/10000 [00:02<00:16, 525.22it/s]
100%|██████████| 1/1 [00:00<00:00,  6.05it/s]
  1%|          | 58/10000 [00:00<00:17, 578.54it/s]

Concordance Index: 0.5691713341823809
Brier Score: 0.20206039508740525
Cumulative_dynamic AUC  0.5812913957548488
{'distribution': 'LogNormal', 'k': 4}


 12%|█▏        | 1240/10000 [00:02<00:16, 529.57it/s]
100%|██████████| 1/1 [00:00<00:00,  5.47it/s]
  1%|          | 61/10000 [00:00<00:16, 606.60it/s]

Concordance Index: 0.638391286210458
Brier Score: 0.19552422017250817
Cumulative_dynamic AUC  0.667041705838899
{'distribution': 'LogNormal', 'k': 5}


 12%|█▏        | 1240/10000 [00:02<00:15, 577.03it/s]
100%|██████████| 1/1 [00:00<00:00,  4.83it/s]
  1%|          | 61/10000 [00:00<00:16, 603.20it/s]

Concordance Index: 0.6597539885696767
Brier Score: 0.19655277378417849
Cumulative_dynamic AUC  0.6821135013162273
{'distribution': 'LogNormal', 'k': 6}


 12%|█▏        | 1240/10000 [00:02<00:14, 589.92it/s]
100%|██████████| 1/1 [00:00<00:00,  4.57it/s]
  1%|          | 57/10000 [00:00<00:17, 562.29it/s]

Concordance Index: 0.6076619002812838
Brier Score: 0.20006168194128313
Cumulative_dynamic AUC  0.6221518856844712
{'distribution': 'LogNormal', 'k': 7}


 12%|█▏        | 1240/10000 [00:02<00:15, 572.40it/s]
100%|██████████| 1/1 [00:00<00:00,  3.60it/s]
  1%|          | 60/10000 [00:00<00:16, 594.98it/s]

Concordance Index: 0.6415817265114007
Brier Score: 0.1962276338812652
Cumulative_dynamic AUC  0.6752203150348444
{'distribution': 'LogNormal', 'k': 8}


 12%|█▏        | 1240/10000 [00:02<00:15, 577.73it/s]
100%|██████████| 1/1 [00:00<00:00,  3.51it/s]
  1%|          | 58/10000 [00:00<00:17, 575.65it/s]

Concordance Index: 0.6437421174327902
Brier Score: 0.19711185898098124
Cumulative_dynamic AUC  0.6714473812081884
{'distribution': 'LogNormal', 'k': 9}


 12%|█▏        | 1240/10000 [00:02<00:15, 580.48it/s]
100%|██████████| 1/1 [00:00<00:00,  3.34it/s]
  0%|          | 9/10000 [00:00<01:51, 89.49it/s]

Concordance Index: 0.6026501592204184
Brier Score: 0.19977362494150183
Cumulative_dynamic AUC  0.6252710590213463
{'distribution': 'LogNormal', 'k': 10}


 18%|█▊        | 1845/10000 [00:02<00:13, 616.33it/s]
100%|██████████| 1/1 [00:00<00:00,  8.79it/s]
  1%|          | 67/10000 [00:00<00:14, 669.76it/s]

Concordance Index: 0.5930035504562823
Brier Score: 0.20354793233124888
Cumulative_dynamic AUC  0.5980019720469677
{'distribution': 'Weibull', 'k': 1}


 18%|█▊        | 1845/10000 [00:02<00:12, 640.17it/s]
100%|██████████| 1/1 [00:00<00:00,  9.14it/s]
  1%|          | 67/10000 [00:00<00:14, 668.76it/s]

Concordance Index: 0.6344473936632284
Brier Score: 0.1938714579798196
Cumulative_dynamic AUC  0.6564552629449063
{'distribution': 'Weibull', 'k': 2}


 18%|█▊        | 1845/10000 [00:02<00:12, 640.77it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
  1%|          | 68/10000 [00:00<00:14, 672.98it/s]

Concordance Index: 0.607695901992059
Brier Score: 0.2020926742272414
Cumulative_dynamic AUC  0.6084716762798367
{'distribution': 'Weibull', 'k': 3}


 18%|█▊        | 1845/10000 [00:03<00:13, 611.24it/s]
100%|██████████| 1/1 [00:00<00:00,  5.16it/s]
  1%|          | 66/10000 [00:00<00:15, 656.33it/s]

Concordance Index: 0.6278288485536119
Brier Score: 0.19723910014812573
Cumulative_dynamic AUC  0.6465486995855049
{'distribution': 'Weibull', 'k': 4}


 18%|█▊        | 1845/10000 [00:02<00:12, 639.00it/s]
100%|██████████| 1/1 [00:00<00:00,  6.62it/s]
  1%|          | 68/10000 [00:00<00:14, 675.63it/s]

Concordance Index: 0.623721870224501
Brier Score: 0.1987119092353963
Cumulative_dynamic AUC  0.6423785076577243
{'distribution': 'Weibull', 'k': 5}


 18%|█▊        | 1845/10000 [00:02<00:12, 641.75it/s]
100%|██████████| 1/1 [00:00<00:00,  5.96it/s]
  1%|          | 67/10000 [00:00<00:14, 663.26it/s]

Concordance Index: 0.623719372487444
Brier Score: 0.19890029152656974
Cumulative_dynamic AUC  0.6307186810695534
{'distribution': 'Weibull', 'k': 6}


 18%|█▊        | 1845/10000 [00:02<00:12, 647.39it/s]
100%|██████████| 1/1 [00:00<00:00,  5.49it/s]
  1%|          | 68/10000 [00:00<00:14, 678.26it/s]

Concordance Index: 0.5480812745613538
Brier Score: 0.20493838163735953
Cumulative_dynamic AUC  0.558182664119938
{'distribution': 'Weibull', 'k': 7}


 18%|█▊        | 1845/10000 [00:02<00:12, 634.11it/s]
100%|██████████| 1/1 [00:00<00:00,  5.00it/s]
  1%|          | 67/10000 [00:00<00:14, 669.50it/s]

Concordance Index: 0.5534216773892559
Brier Score: 0.20349522562214395
Cumulative_dynamic AUC  0.5694979875996248
{'distribution': 'Weibull', 'k': 8}


 18%|█▊        | 1845/10000 [00:03<00:13, 611.31it/s]
100%|██████████| 1/1 [00:00<00:00,  4.60it/s]
  1%|          | 67/10000 [00:00<00:14, 666.44it/s]

Concordance Index: 0.509918346008642
Brier Score: 0.20704238283775456
Cumulative_dynamic AUC  0.5136175482948849
{'distribution': 'Weibull', 'k': 9}


 18%|█▊        | 1845/10000 [00:03<00:13, 601.67it/s]
100%|██████████| 1/1 [00:00<00:00,  3.90it/s]

Concordance Index: 0.5457867524346355
Brier Score: 0.20434785349912266
Cumulative_dynamic AUC  0.5518614097391491
{'distribution': 'Weibull', 'k': 10}





### Training  

In [23]:
# Final model: LogNormal with k=6. In the grid search above this configuration
# had the highest mean validation concordance (0.6598) and cumulative/dynamic
# AUC (0.6821). Test-set predictions are made in the "Prediction" cell below,
# so they are not computed twice here.
model = DeepSurvivalMachines(k=6, distribution='LogNormal')
model.fit(x_train, t_train, e_train)

 12%|█▏        | 1240/10000 [00:02<00:14, 586.27it/s]
100%|██████████| 1/1 [00:00<00:00,  4.83it/s]


### Prediction 

In [24]:
# Risk and survival predictions on the held-out test split, one column per
# horizon in `times`.
out_risk, out_survival = (
    model.predict_risk(x_test, times),
    model.predict_survival(x_test, times),
)

### Evaluation

In [25]:
# Test-set metrics, computed at each horizon in `times` and averaged across horizons.
cis = [
    concordance_index_ipcw(et_train, et_test, out_risk[:, col], horizon)[0]
    for col, horizon in enumerate(times)
]
# brier_score returns (times, scores); index 1 holds the per-horizon scores.
brs = [brier_score(et_train, et_test, out_survival, times)[1]]
print("Concordance Index:", np.mean(cis))
print("Brier Score:", np.mean(brs))
cdauc = [
    cumulative_dynamic_auc(et_train, et_test, out_risk[:, col], horizon)[0]
    for col, horizon in enumerate(times)
]
print("Cumulative_dynamic AUC ", np.mean(cdauc))

Concordance Index: 0.6576888784160551
Brier Score: 0.18412562484999506
Cumulative_dynamic AUC  0.684763312097342
