In [219]:
from dsm import datasets
from dsm.dsm_api import DeepSurvivalMachines
import numpy as np
from sklearn.model_selection import ParameterGrid
from sksurv.metrics import concordance_index_ipcw, brier_score,cumulative_dynamic_auc

In [220]:
# Load the SUPPORT dataset shipped with dsm. It is unpacked into x, t, e, which are
# later passed to model.fit(x, t, e); `e == 1` is used below to select observed events.
DATASET_NAME = 'SUPPORT'
x, t, e = datasets.load_dataset(DATASET_NAME)

In [221]:
# Evaluation horizons: the 25th, 50th and 75th percentiles of the times of samples
# whose event indicator equals 1 (stored as plain Python floats).
observed_event_times = t[e == 1]
times = [float(q) for q in np.quantile(observed_event_times, [0.25, 0.5, 0.75])]

In [222]:
# Show the evaluation horizons.
list(times)

[14.0, 58.0, 252.0]

In [223]:
# Sanity check: features, times and event indicators should share the first dimension.
print(x.shape, t.shape, e.shape, sep='\n')

(9105, 44)
(9105,)
(9105,)


In [227]:
# Contiguous 70/10/20 train/validation/test split in the row order returned by
# load_dataset (no shuffling is performed here).
# Sizes are derived from the data rather than a hard-coded 9105. The test split takes
# every row after the validation split, so the three splits are disjoint and
# exhaustive for any dataset size.
n_samples = len(x)
train_size = round(n_samples * 0.70)
val_size = round(n_samples * 0.10)
val_end = train_size + val_size
test_size = n_samples - val_end

x_train, t_train, e_train = x[:train_size], t[:train_size], e[:train_size]
x_val, t_val, e_val = x[train_size:val_end], t[train_size:val_end], e[train_size:val_end]
# Slice from val_end instead of [-test_size:]: with test_size == 0, x[-0:] is the whole array.
x_test, t_test, e_test = x[val_end:], t[val_end:], e[val_end:]

assert len(x_train) + len(x_val) + len(x_test) == n_samples

In [228]:
# Test split size (rows, features).
np.shape(x_test)

(1821, 44)

In [229]:
# Validation split size (rows, features).
np.shape(x_val)

(910, 44)

In [230]:
# Training split size (rows, features).
np.shape(x_train)

(6374, 44)

In [231]:
# One line per split (train, val, test): shapes of features, times and event indicators.
for split_arrays in ((x_train, t_train, e_train),
                     (x_val, t_val, e_val),
                     (x_test, t_test, e_test)):
    print(*(arr.shape for arr in split_arrays))

(6374, 44) (6374,) (6374,)
(910, 44) (910,) (910,)
(1821, 44) (1821,) (1821,)


In [232]:
def to_structured_outcome(events, durations):
    """Pack event indicators and observed times into the record array sksurv expects.

    scikit-survival's metrics (concordance_index_ipcw, brier_score,
    cumulative_dynamic_auc) take a structured array whose first field is a boolean
    event indicator and whose second field is the observed time.

    The time field is stored as float. The previous ``('t', int)`` dtype truncated
    any fractional times, which would misalign them with the float horizons in ``times``.

    Parameters
    ----------
    events : array-like, 1-D
        Event indicator per sample; non-zero values become True.
    durations : array-like, 1-D
        Observed time per sample.

    Returns
    -------
    numpy.ndarray
        Structured array with fields ``('e', bool)`` and ``('t', float)``.

    Raises
    ------
    ValueError
        If ``events`` and ``durations`` have different shapes.
    """
    events = np.asarray(events)
    durations = np.asarray(durations)
    if events.shape != durations.shape:
        raise ValueError(
            f"events and durations must have the same shape, got {events.shape} and {durations.shape}"
        )
    outcome = np.empty(len(events), dtype=[('e', bool), ('t', float)])
    outcome['e'] = events.astype(bool)
    outcome['t'] = durations
    return outcome


et_train = to_structured_outcome(e_train, t_train)
et_test = to_structured_outcome(e_test, t_test)
et_val = to_structured_outcome(e_val, t_val)

In [233]:
# Record-array lengths for train, test and validation (in that order).
print(*(arr.shape for arr in (et_train, et_test, et_val)))

(6374,) (1821,) (910,)


In [234]:
# Search space: number of mixture components k = 1..10 for each output distribution.
param_grid = {
    'k': list(range(1, 11)),
    'distribution': ['LogNormal', 'Weibull'],
}

In [235]:
# Expand the search space into the Cartesian product of settings (one dict per config).
grids = ParameterGrid(param_grid=param_grid)

In [205]:
# Grid search over (k, distribution): fit on the training split, score on the
# validation split at every horizon in `times`, and keep each result so that the
# best configuration is selected programmatically rather than by reading the log.
scores = []
for grid in grids:
    model = DeepSurvivalMachines(k=grid['k'], distribution=grid['distribution'])
    model.fit(x_train, t_train, e_train)
    out_risk = model.predict_risk(x_val, times)
    out_survival = model.predict_survival(x_val, times)

    # Column i of the predictions corresponds to horizon times[i]. The training
    # outcomes (et_train) are passed as survival_train, which sksurv uses to
    # estimate the censoring distribution for the IPCW weights.
    cis = [concordance_index_ipcw(et_train, et_val, out_risk[:, i], times[i])[0]
           for i in range(len(times))]
    # brier_score scores all horizons in one call; element [1] is the per-time scores.
    brs = brier_score(et_train, et_val, out_survival, times)[1]
    print("Concordance Index:", np.mean(cis))
    print("Brier Score:", np.mean(brs))

    cdauc = [cumulative_dynamic_auc(et_train, et_val, out_risk[:, i], times[i])[0]
             for i in range(len(times))]
    print("Cumulative_dynamic AUC ", np.mean(cdauc))
    print(grid)

    scores.append({
        'params': grid,
        'ci': np.mean(cis),
        'brier': np.mean(brs),
        'auc': np.mean(cdauc),
    })

# Higher validation C-index is better; max() keeps the earliest config on ties.
best_score = max(scores, key=lambda row: row['ci'])
print("Best configuration by validation C-index:", best_score['params'], best_score['ci'])
    

 12%|█▏        | 1240/10000 [00:02<00:16, 531.95it/s]
100%|██████████| 1/1 [00:00<00:00,  5.54it/s]
  0%|          | 44/10000 [00:00<00:22, 437.38it/s]

Concordance Index: 0.5092904206321293
Brier Score: 0.2094134782813962
Cumulative_dynamic AUC  0.5086922832567786
{'distribution': 'LogNormal', 'k': 1}


 12%|█▏        | 1240/10000 [00:02<00:16, 524.13it/s]
100%|██████████| 1/1 [00:00<00:00,  7.85it/s]
  1%|          | 59/10000 [00:00<00:16, 589.59it/s]

Concordance Index: 0.6225696126219331
Brier Score: 0.19564178498171694
Cumulative_dynamic AUC  0.6463304993181905
{'distribution': 'LogNormal', 'k': 2}


 12%|█▏        | 1240/10000 [00:02<00:17, 511.63it/s]
100%|██████████| 1/1 [00:00<00:00,  7.04it/s]
  0%|          | 39/10000 [00:00<00:25, 386.10it/s]

Concordance Index: 0.6364015485450529
Brier Score: 0.19561785074208204
Cumulative_dynamic AUC  0.6648284665873677
{'distribution': 'LogNormal', 'k': 3}


 12%|█▏        | 1240/10000 [00:02<00:16, 518.60it/s]
100%|██████████| 1/1 [00:00<00:00,  3.31it/s]
  1%|          | 60/10000 [00:00<00:16, 592.42it/s]

Concordance Index: 0.5991297199199518
Brier Score: 0.19911243021089486
Cumulative_dynamic AUC  0.6168034589116687
{'distribution': 'LogNormal', 'k': 4}


 12%|█▏        | 1240/10000 [00:02<00:15, 574.78it/s]
100%|██████████| 1/1 [00:00<00:00,  5.39it/s]
  1%|          | 60/10000 [00:00<00:16, 590.91it/s]

Concordance Index: 0.5878718504990289
Brier Score: 0.2002449075616305
Cumulative_dynamic AUC  0.6084366081959983
{'distribution': 'LogNormal', 'k': 5}


 12%|█▏        | 1240/10000 [00:02<00:14, 589.23it/s]
100%|██████████| 1/1 [00:00<00:00,  4.88it/s]
  1%|          | 60/10000 [00:00<00:16, 596.45it/s]

Concordance Index: 0.6368460429833823
Brier Score: 0.1979542865459052
Cumulative_dynamic AUC  0.6519796443481246
{'distribution': 'LogNormal', 'k': 6}


 12%|█▏        | 1240/10000 [00:02<00:15, 574.59it/s]
100%|██████████| 1/1 [00:00<00:00,  4.32it/s]
  1%|          | 60/10000 [00:00<00:16, 593.98it/s]

Concordance Index: 0.6511055559469626
Brier Score: 0.1955849172557882
Cumulative_dynamic AUC  0.6810514142065088
{'distribution': 'LogNormal', 'k': 7}


 12%|█▏        | 1240/10000 [00:02<00:15, 580.79it/s]
100%|██████████| 1/1 [00:00<00:00,  3.64it/s]
  0%|          | 29/10000 [00:00<00:34, 285.95it/s]

Concordance Index: 0.6331446267971144
Brier Score: 0.19876817075613976
Cumulative_dynamic AUC  0.6548590739887442
{'distribution': 'LogNormal', 'k': 8}


 12%|█▏        | 1240/10000 [00:02<00:17, 490.66it/s]
100%|██████████| 1/1 [00:00<00:00,  3.58it/s]
  1%|          | 59/10000 [00:00<00:17, 582.17it/s]

Concordance Index: 0.6475819691793351
Brier Score: 0.19713614022219517
Cumulative_dynamic AUC  0.6700905021936281
{'distribution': 'LogNormal', 'k': 9}


 12%|█▏        | 1240/10000 [00:02<00:16, 538.26it/s]
100%|██████████| 1/1 [00:00<00:00,  3.14it/s]
  1%|          | 66/10000 [00:00<00:15, 657.18it/s]

Concordance Index: 0.6766311332976391
Brier Score: 0.1949544626973324
Cumulative_dynamic AUC  0.7053828500566288
{'distribution': 'LogNormal', 'k': 10}


 18%|█▊        | 1845/10000 [00:02<00:12, 648.37it/s]
100%|██████████| 1/1 [00:00<00:00, 10.31it/s]
  0%|          | 48/10000 [00:00<00:21, 469.70it/s]

Concordance Index: 0.5217620646680267
Brier Score: 0.22052591107210942
Cumulative_dynamic AUC  0.5113355137174557
{'distribution': 'Weibull', 'k': 1}


 18%|█▊        | 1845/10000 [00:02<00:13, 620.63it/s]
100%|██████████| 1/1 [00:00<00:00,  8.43it/s]
  1%|          | 67/10000 [00:00<00:14, 665.06it/s]

Concordance Index: 0.5655511571626577
Brier Score: 0.2029583594877187
Cumulative_dynamic AUC  0.5752601364418766
{'distribution': 'Weibull', 'k': 2}


 18%|█▊        | 1845/10000 [00:03<00:13, 597.92it/s]
100%|██████████| 1/1 [00:00<00:00,  7.94it/s]
  1%|          | 64/10000 [00:00<00:15, 634.77it/s]

Concordance Index: 0.6011102388064652
Brier Score: 0.20006664777293817
Cumulative_dynamic AUC  0.6150050341748386
{'distribution': 'Weibull', 'k': 3}


 18%|█▊        | 1845/10000 [00:03<00:14, 546.07it/s]
100%|██████████| 1/1 [00:00<00:00,  5.64it/s]
  1%|          | 52/10000 [00:00<00:19, 515.34it/s]

Concordance Index: 0.5334196125687835
Brier Score: 0.2058152474159565
Cumulative_dynamic AUC  0.5452291852340513
{'distribution': 'Weibull', 'k': 4}


 18%|█▊        | 1845/10000 [00:03<00:14, 553.79it/s]
100%|██████████| 1/1 [00:00<00:00,  6.71it/s]
  0%|          | 28/10000 [00:00<00:36, 275.12it/s]

Concordance Index: 0.5545690402614445
Brier Score: 0.20617123598941967
Cumulative_dynamic AUC  0.5476095391281718
{'distribution': 'Weibull', 'k': 5}


 18%|█▊        | 1845/10000 [00:03<00:14, 551.13it/s]
100%|██████████| 1/1 [00:00<00:00,  6.12it/s]
  1%|          | 67/10000 [00:00<00:14, 663.42it/s]

Concordance Index: 0.6476051251161218
Brier Score: 0.19588918479440143
Cumulative_dynamic AUC  0.6569182166503739
{'distribution': 'Weibull', 'k': 6}


 18%|█▊        | 1845/10000 [00:02<00:12, 631.03it/s]
100%|██████████| 1/1 [00:00<00:00,  5.23it/s]
  1%|          | 67/10000 [00:00<00:14, 667.21it/s]

Concordance Index: 0.5858539801383983
Brier Score: 0.2034606185365795
Cumulative_dynamic AUC  0.5913725838715068
{'distribution': 'Weibull', 'k': 7}


 18%|█▊        | 1845/10000 [00:02<00:11, 688.18it/s]
100%|██████████| 1/1 [00:00<00:00,  5.08it/s]
  0%|          | 48/10000 [00:00<00:20, 478.21it/s]

Concordance Index: 0.5361255567804992
Brier Score: 0.20410873790652573
Cumulative_dynamic AUC  0.5468374972837957
{'distribution': 'Weibull', 'k': 8}


 18%|█▊        | 1845/10000 [00:03<00:13, 596.08it/s]
100%|██████████| 1/1 [00:00<00:00,  3.78it/s]
  1%|          | 65/10000 [00:00<00:15, 646.81it/s]

Concordance Index: 0.6072673284548974
Brier Score: 0.20075580259939282
Cumulative_dynamic AUC  0.6229017351810217
{'distribution': 'Weibull', 'k': 9}


 18%|█▊        | 1845/10000 [00:03<00:13, 603.63it/s]
100%|██████████| 1/1 [00:00<00:00,  4.37it/s]

Concordance Index: 0.6125137904891208
Brier Score: 0.1990688415320855
Cumulative_dynamic AUC  0.6332807149415878
{'distribution': 'Weibull', 'k': 10}





In [236]:
# Final held-out evaluation: refit the configuration with the highest validation
# C-index in the grid-search log (k=10, LogNormal) and score it on the test split.
final_params = {'k': 10, 'distribution': 'LogNormal'}
model = DeepSurvivalMachines(**final_params)
model.fit(x_train, t_train, e_train)
out_risk = model.predict_risk(x_test, times)
out_survival = model.predict_survival(x_test, times)

cis = [
    concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0]
    for i in range(len(times))
]
brs = [brier_score(et_train, et_test, out_survival, times)[1]]
print("Concordance Index:", np.mean(cis))
print("Brier Score:", np.mean(brs))

cdauc = [
    cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0]
    for i in range(len(times))
]
print("Cumulative_dynamic AUC ", np.mean(cdauc))

 12%|█▏        | 1240/10000 [00:02<00:15, 568.10it/s]
100%|██████████| 1/1 [00:00<00:00,  3.37it/s]

Concordance Index: 0.6581903218379163
Brier Score: 0.18598571651278992
Cumulative_dynamic AUC  0.6829802892087405



