# Train/Test/Validation Split

In [1]:
from dsm import datasets
from dsm.dsm_api import DeepSurvivalMachines

import numpy as np

from sklearn.model_selection import ParameterGrid
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

### Loading the Support Dataset 

In [2]:
x, t, e = datasets.load_dataset('SUPPORT')

### Computing the required quantile of the data 

In [3]:
times = np.quantile(t[e == 1], [0.25, 0.5, 0.75]).tolist()

### Defining the train, test and validation size

In [4]:
data_size = len(x)
tr_size = round(data_size * 0.70)
vl_size = round(data_size * 0.10)
te_size = round(data_size * 0.20)

### Splitting the data into train, test and validation sets

In [8]:
x_train = x[: tr_size]
t_train = t[: tr_size]
e_train = e[: tr_size]

x_test = x[-te_size :]
t_test = t[-te_size :]
e_test = e[-te_size :]

x_val = x[tr_size : tr_size + vl_size]
t_val = t[tr_size : tr_size + vl_size]
e_val = e[tr_size : tr_size + vl_size]

In [9]:
et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],
                 dtype = [('e', bool), ('t', int)])
et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],
                 dtype = [('e', bool), ('t', int)])
et_val = np.array([(e_val[i], t_val[i]) for i in range(len(e_val))],
                 dtype = [('e', bool), ('t', int)])

In [10]:
print(et_train.shape, et_test.shape, et_val.shape)

(6374,) (1821,) (910,)


### Defining the parameter grid

In [11]:
param_grid = {'k' : [3, 4, 6],
              'distribution' : ['LogNormal', 'Weibull'],
              'learning_rate' : [ float(1e-4), float(1e-3)],
              'layers' : [ [], [100], [100, 100] ]
             }
params = ParameterGrid(param_grid)

### Validation 

In [12]:
for param in params:
    model = DeepSurvivalMachines(k = param['k'],
                                 distribution = param['distribution'],
                                 layers = param[ 'layers'])
    model.fit(x_train, t_train, e_train, learning_rate = param['learning_rate'])
    out_risk = model.predict_risk(x_val, times)
    out_survival = model.predict_survival(x_val, times)

    cis = []
    brs = []
    for i, _ in enumerate(times):
        cis.append(concordance_index_ipcw(et_train, et_val, out_risk[:, i], times[i])[0])

    brs.append(brier_score(et_train, et_val, out_survival, times)[1])
    print("Concordance Index:", np.mean(cis))
    print("Brier Score:", np.mean(brs))
    cdauc = []
    for i, _ in enumerate(times):
        cdauc.append(cumulative_dynamic_auc(et_train, et_val, out_risk[:, i], times[i])[0])
    print("Cumulative_dynamic AUC ", np.mean(cdauc))
    print(param)
    #scores  = scores + [[grid, np.mean(cis), np.mean(brs), np.mean(cdauc)]]

  Variable._execution_engine.run_backward(
 12%|█▏        | 1240/10000 [00:02<00:20, 430.34it/s]
100%|██████████| 1/1 [00:00<00:00,  2.40it/s]
  0%|          | 46/10000 [00:00<00:22, 452.18it/s]

Concordance Index: 0.45340340953800856
Brier Score: 0.2120856405467805
Cumulative_dynamic AUC  0.4325276050509048
{'distribution': 'LogNormal', 'k': 3, 'layers': [], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:18, 482.10it/s]
100%|██████████| 1/1 [00:00<00:00,  6.98it/s]
  1%|          | 61/10000 [00:00<00:16, 602.23it/s]

Concordance Index: 0.5868530920885563
Brier Score: 0.2000913347679684
Cumulative_dynamic AUC  0.6063103268431748
{'distribution': 'LogNormal', 'k': 3, 'layers': [], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:14, 592.76it/s]
100%|██████████| 1/1 [00:00<00:00,  4.29it/s]
  1%|          | 58/10000 [00:00<00:17, 572.61it/s]

Concordance Index: 0.5196558945077734
Brier Score: 0.20453532747201106
Cumulative_dynamic AUC  0.5207638770294832
{'distribution': 'LogNormal', 'k': 3, 'layers': [100], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:14, 592.35it/s]
100%|██████████| 1/1 [00:00<00:00,  5.26it/s]
  1%|          | 64/10000 [00:00<00:15, 636.81it/s]

Concordance Index: 0.6859982768146832
Brier Score: 0.18952182660274805
Cumulative_dynamic AUC  0.717028369155304
{'distribution': 'LogNormal', 'k': 3, 'layers': [100], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:16, 540.19it/s]
100%|██████████| 1/1 [00:00<00:00,  4.35it/s]
  0%|          | 33/10000 [00:00<00:30, 326.17it/s]

Concordance Index: 0.6365066758765581
Brier Score: 0.20278208010307205
Cumulative_dynamic AUC  0.6556533791977494
{'distribution': 'LogNormal', 'k': 3, 'layers': [100, 100], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:15, 570.21it/s]
100%|██████████| 1/1 [00:00<00:00,  4.71it/s]
  1%|          | 60/10000 [00:00<00:16, 598.40it/s]

Concordance Index: 0.6980410911851789
Brier Score: 0.186164633385437
Cumulative_dynamic AUC  0.7278002430124927
{'distribution': 'LogNormal', 'k': 3, 'layers': [100, 100], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:14, 587.06it/s]
100%|██████████| 1/1 [00:00<00:00,  5.90it/s]
  1%|          | 51/10000 [00:00<00:19, 506.41it/s]

Concordance Index: 0.46215208842170163
Brier Score: 0.20839325631897473
Cumulative_dynamic AUC  0.46804170098645265
{'distribution': 'LogNormal', 'k': 4, 'layers': [], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:15, 582.93it/s]
100%|██████████| 1/1 [00:00<00:00,  6.32it/s]
  1%|          | 60/10000 [00:00<00:16, 593.53it/s]

Concordance Index: 0.6354333768613575
Brier Score: 0.1960336595711897
Cumulative_dynamic AUC  0.6636556147146971
{'distribution': 'LogNormal', 'k': 4, 'layers': [], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:14, 588.32it/s]
100%|██████████| 1/1 [00:00<00:00,  5.39it/s]
  1%|          | 59/10000 [00:00<00:16, 585.72it/s]

Concordance Index: 0.5942250292158068
Brier Score: 0.2024949091777255
Cumulative_dynamic AUC  0.6038333298925812
{'distribution': 'LogNormal', 'k': 4, 'layers': [100], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:14, 585.69it/s]
100%|██████████| 1/1 [00:00<00:00,  5.37it/s]
  1%|          | 60/10000 [00:00<00:16, 595.54it/s]

Concordance Index: 0.680827491886742
Brier Score: 0.1909926945502677
Cumulative_dynamic AUC  0.7087961885128095
{'distribution': 'LogNormal', 'k': 4, 'layers': [100], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:01<00:14, 620.06it/s]
100%|██████████| 1/1 [00:00<00:00,  5.20it/s]
  1%|          | 59/10000 [00:00<00:16, 586.40it/s]

Concordance Index: 0.6529878889301487
Brier Score: 0.20270261177977067
Cumulative_dynamic AUC  0.6795711191738834
{'distribution': 'LogNormal', 'k': 4, 'layers': [100, 100], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:01<00:13, 625.88it/s]
100%|██████████| 1/1 [00:00<00:00,  5.21it/s]
  1%|          | 61/10000 [00:00<00:16, 602.37it/s]

Concordance Index: 0.6861435130326664
Brier Score: 0.19044132852802934
Cumulative_dynamic AUC  0.7207597059035754
{'distribution': 'LogNormal', 'k': 4, 'layers': [100, 100], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:15, 570.97it/s]
100%|██████████| 1/1 [00:00<00:00,  3.40it/s]
  1%|          | 59/10000 [00:00<00:16, 587.93it/s]

Concordance Index: 0.5302643526526526
Brier Score: 0.20577185843790813
Cumulative_dynamic AUC  0.5259841488637905
{'distribution': 'LogNormal', 'k': 6, 'layers': [], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:15, 556.44it/s]
100%|██████████| 1/1 [00:00<00:00,  4.89it/s]
  1%|          | 57/10000 [00:00<00:17, 569.55it/s]

Concordance Index: 0.6406275340947074
Brier Score: 0.19703457713583417
Cumulative_dynamic AUC  0.6640191026181977
{'distribution': 'LogNormal', 'k': 6, 'layers': [], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:14, 605.16it/s]
100%|██████████| 1/1 [00:00<00:00,  4.40it/s]
  1%|          | 59/10000 [00:00<00:17, 576.88it/s]

Concordance Index: 0.5550278989857907
Brier Score: 0.20467083817040485
Cumulative_dynamic AUC  0.5732934023211531
{'distribution': 'LogNormal', 'k': 6, 'layers': [100], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:15, 556.80it/s]
100%|██████████| 1/1 [00:00<00:00,  4.44it/s]
  1%|          | 58/10000 [00:00<00:17, 574.02it/s]

Concordance Index: 0.6935590743040007
Brier Score: 0.1897312861601562
Cumulative_dynamic AUC  0.7242164820273551
{'distribution': 'LogNormal', 'k': 6, 'layers': [100], 'learning_rate': 0.001}


 12%|█▏        | 1240/10000 [00:02<00:14, 612.05it/s]
100%|██████████| 1/1 [00:00<00:00,  4.24it/s]
  1%|          | 60/10000 [00:00<00:16, 591.39it/s]

Concordance Index: 0.6271994007602202
Brier Score: 0.2038126153471652
Cumulative_dynamic AUC  0.6461344733276411
{'distribution': 'LogNormal', 'k': 6, 'layers': [100, 100], 'learning_rate': 0.0001}


 12%|█▏        | 1240/10000 [00:02<00:14, 599.20it/s]
100%|██████████| 1/1 [00:00<00:00,  4.21it/s]
  0%|          | 3/10000 [00:00<05:41, 29.30it/s]

Concordance Index: 0.6893133200733587
Brier Score: 0.19022108973346144
Cumulative_dynamic AUC  0.7209618134702657
{'distribution': 'LogNormal', 'k': 6, 'layers': [100, 100], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:12, 649.49it/s]
100%|██████████| 1/1 [00:00<00:00,  8.34it/s]
  1%|          | 67/10000 [00:00<00:14, 664.84it/s]

Concordance Index: 0.5080991660444875
Brier Score: 0.2106066316999924
Cumulative_dynamic AUC  0.5001444692832061
{'distribution': 'Weibull', 'k': 3, 'layers': [], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:12, 670.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.40it/s]
  1%|          | 65/10000 [00:00<00:15, 641.34it/s]

Concordance Index: 0.5425327274077133
Brier Score: 0.20629606293671854
Cumulative_dynamic AUC  0.5391945715171822
{'distribution': 'Weibull', 'k': 3, 'layers': [], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:12, 676.74it/s]
100%|██████████| 1/1 [00:00<00:00,  6.06it/s]
  1%|          | 68/10000 [00:00<00:14, 671.83it/s]

Concordance Index: 0.5931411960701646
Brier Score: 0.20299695028498696
Cumulative_dynamic AUC  0.6115004776175379
{'distribution': 'Weibull', 'k': 3, 'layers': [100], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:12, 678.30it/s]
100%|██████████| 1/1 [00:00<00:00,  7.34it/s]
  1%|          | 64/10000 [00:00<00:15, 632.59it/s]

Concordance Index: 0.7102077987736953
Brier Score: 0.1849977845614227
Cumulative_dynamic AUC  0.7388547775107428
{'distribution': 'Weibull', 'k': 3, 'layers': [100], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:12, 676.02it/s]
100%|██████████| 1/1 [00:00<00:00,  6.43it/s]
  1%|          | 66/10000 [00:00<00:15, 650.75it/s]

Concordance Index: 0.620941038907524
Brier Score: 0.2036507194684802
Cumulative_dynamic AUC  0.6288558992463418
{'distribution': 'Weibull', 'k': 3, 'layers': [100, 100], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:12, 676.90it/s]
100%|██████████| 1/1 [00:00<00:00,  6.33it/s]
  1%|          | 64/10000 [00:00<00:15, 639.46it/s]

Concordance Index: 0.7134676097735823
Brier Score: 0.181088136045393
Cumulative_dynamic AUC  0.7446297372894702
{'distribution': 'Weibull', 'k': 3, 'layers': [100, 100], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:03<00:15, 530.43it/s]
100%|██████████| 1/1 [00:00<00:00,  7.35it/s]
  1%|          | 66/10000 [00:00<00:15, 653.42it/s]

Concordance Index: 0.47659870040858504
Brier Score: 0.21757234461873573
Cumulative_dynamic AUC  0.4538623560114727
{'distribution': 'Weibull', 'k': 4, 'layers': [], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:03<00:15, 522.57it/s]
100%|██████████| 1/1 [00:00<00:00,  7.17it/s]
  1%|          | 64/10000 [00:00<00:15, 637.43it/s]

Concordance Index: 0.6031573793902991
Brier Score: 0.20091932551942968
Cumulative_dynamic AUC  0.611529985311081
{'distribution': 'Weibull', 'k': 4, 'layers': [], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:13, 622.90it/s]
100%|██████████| 1/1 [00:00<00:00,  5.22it/s]
  1%|          | 67/10000 [00:00<00:14, 668.95it/s]

Concordance Index: 0.6150765213526143
Brier Score: 0.20079626063857792
Cumulative_dynamic AUC  0.6266080206140123
{'distribution': 'Weibull', 'k': 4, 'layers': [100], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:12, 678.73it/s]
100%|██████████| 1/1 [00:00<00:00,  6.28it/s]
  1%|          | 67/10000 [00:00<00:14, 669.45it/s]

Concordance Index: 0.7042037614520913
Brier Score: 0.18520804374817076
Cumulative_dynamic AUC  0.7334330816600197
{'distribution': 'Weibull', 'k': 4, 'layers': [100], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:12, 628.36it/s]
100%|██████████| 1/1 [00:00<00:00,  5.86it/s]
  1%|          | 67/10000 [00:00<00:14, 666.89it/s]

Concordance Index: 0.6466840182073358
Brier Score: 0.20301336888640495
Cumulative_dynamic AUC  0.6591577721857974
{'distribution': 'Weibull', 'k': 4, 'layers': [100, 100], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:12, 676.36it/s]
100%|██████████| 1/1 [00:00<00:00,  5.67it/s]
  1%|          | 67/10000 [00:00<00:14, 669.20it/s]

Concordance Index: 0.7079620752337589
Brier Score: 0.18433734466778193
Cumulative_dynamic AUC  0.738622393958475
{'distribution': 'Weibull', 'k': 4, 'layers': [100, 100], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:12, 630.85it/s]
100%|██████████| 1/1 [00:00<00:00,  4.74it/s]
  0%|          | 38/10000 [00:00<00:26, 376.41it/s]

Concordance Index: 0.528617409132099
Brier Score: 0.20587386707056146
Cumulative_dynamic AUC  0.5365072096396085
{'distribution': 'Weibull', 'k': 6, 'layers': [], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:13, 624.17it/s]
100%|██████████| 1/1 [00:00<00:00,  5.77it/s]
  1%|          | 66/10000 [00:00<00:15, 658.18it/s]

Concordance Index: 0.6054747974071183
Brier Score: 0.19937701740665517
Cumulative_dynamic AUC  0.61779817030874
{'distribution': 'Weibull', 'k': 6, 'layers': [], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:02<00:12, 675.32it/s]
100%|██████████| 1/1 [00:00<00:00,  5.19it/s]
  1%|          | 67/10000 [00:00<00:14, 663.90it/s]

Concordance Index: 0.5443869274404641
Brier Score: 0.2057537264300032
Cumulative_dynamic AUC  0.5398483361385489
{'distribution': 'Weibull', 'k': 6, 'layers': [100], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:12, 675.90it/s]
100%|██████████| 1/1 [00:00<00:00,  4.64it/s]
  1%|          | 65/10000 [00:00<00:15, 648.30it/s]

Concordance Index: 0.7085613962141747
Brier Score: 0.18772750978036887
Cumulative_dynamic AUC  0.7349156454958966
{'distribution': 'Weibull', 'k': 6, 'layers': [100], 'learning_rate': 0.001}


 18%|█▊        | 1845/10000 [00:03<00:13, 589.89it/s]
100%|██████████| 1/1 [00:00<00:00,  4.32it/s]
  1%|          | 65/10000 [00:00<00:15, 646.28it/s]

Concordance Index: 0.6568985178048318
Brier Score: 0.202961664621087
Cumulative_dynamic AUC  0.6705386354241947
{'distribution': 'Weibull', 'k': 6, 'layers': [100, 100], 'learning_rate': 0.0001}


 18%|█▊        | 1845/10000 [00:02<00:11, 680.03it/s]
100%|██████████| 1/1 [00:00<00:00,  4.08it/s]

Concordance Index: 0.6995013567080287
Brier Score: 0.18538782612760984
Cumulative_dynamic AUC  0.7319456134570768
{'distribution': 'Weibull', 'k': 6, 'layers': [100, 100], 'learning_rate': 0.001}





### Training  

In [13]:
model = DeepSurvivalMachines(k = 3, distribution = 'Weibull', layers = [100])
model.fit(x_train, t_train, e_train, learning_rate = float(1e-3))

 18%|█▊        | 1845/10000 [00:03<00:14, 573.04it/s]
100%|██████████| 1/1 [00:00<00:00,  7.03it/s]


### Prediction 

In [14]:
out_risk = model.predict_risk(x_test, times)
out_survival = model.predict_survival(x_test, times)

### Evaluation

In [15]:
cis = []
brs = []
for i, _ in enumerate(times):
    cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])

brs.append(brier_score(et_train, et_test, out_survival, times)[1])
print("Concordance Index:", np.mean(cis))
print("Brier Score:", np.mean(brs))
cdauc = []
for i, _ in enumerate(times):
    cdauc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0])
print("Cumulative_dynamic AUC ", np.mean(cdauc))

Concordance Index: 0.7023360857164943
Brier Score: 0.17641956933412173
Cumulative_dynamic AUC  0.7260571933183376
