# DSM on SUPPORT Dataset

SUPPORT: This dataset comes from the Vanderbilt University study to estimate survival for seriously ill hospitalized adults. (Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc for the original data source.)
In this notebook, we demonstrate the application of Deep Survival Machines (DSM) to survival prediction on the SUPPORT dataset.

In [1]:
from dsm import datasets
from dsm import DeepSurvivalMachines

import numpy as np

from sklearn.model_selection import ParameterGrid
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

### Load the Dataset

In [2]:
# Load SUPPORT through the dsm helper and unpack the three returned arrays.
# Presumably x = covariates, t = observed times, e = event indicators
# (that is how the later cells use them) -- confirm against dsm.datasets.
support_data = datasets.load_dataset('SUPPORT')
x, t, e = support_data

### Computing the required quantiles of the event times

In [3]:
# Evaluation horizons: quantiles of the times restricted to rows with e == 1
# (presumably the uncensored / observed events -- confirm against the loader).
req_quantiles = [0.25, 0.5, 0.75]
event_times = t[e == 1]
times = np.quantile(event_times, req_quantiles).tolist()

### Defining the train, test and validation sizes

In [4]:
# Split sizes as 70% / 10% / 20% of the data, each truncated to an integer.
# Because of the truncation the three sizes may sum to less than data_size.
data_size = len(x)
tr_size, vl_size, te_size = (
    int(data_size * fraction) for fraction in (0.70, 0.10, 0.20)
)

### Splitting the data into train, test and validation sets

In [5]:
# Positional (unshuffled) split: train takes the leading rows, validation the
# rows right after it, and test the trailing rows.
train_rows = slice(None, tr_size)
val_rows = slice(tr_size, tr_size + vl_size)
test_rows = slice(-te_size, None)

x_train, t_train, e_train = x[train_rows], t[train_rows], e[train_rows]
x_test, t_test, e_test = x[test_rows], t[test_rows], e[test_rows]
x_val, t_val, e_val = x[val_rows], t[val_rows], e[val_rows]

In [6]:
# Pack (event, time) pairs into NumPy structured arrays with a boolean 'e'
# field and a float 't' field; these are what the sksurv metric calls in the
# evaluation cell receive.
surv_dtype = [('e', bool), ('t', float)]
et_train = np.array(list(zip(e_train, t_train)), dtype = surv_dtype)
et_test = np.array(list(zip(e_test, t_test)), dtype = surv_dtype)
et_val = np.array(list(zip(e_val, t_val)), dtype = surv_dtype)

In [7]:
# Sanity check: number of records in the train, test and validation arrays.
print(*(split.shape for split in (et_train, et_test, et_val)))

(6373,) (1821,) (910,)


### Defining the parameter grid

In [8]:
# Hyperparameter search space (3 * 2 * 2 * 3 = 36 combinations); the keys are
# read back by name when each model is built and fitted.
param_grid = dict(
    k = [3, 4, 6],
    distribution = ['LogNormal', 'Weibull'],
    learning_rate = [1e-4, 1e-3],
    layers = [[], [100], [100, 100]],
)
params = ParameterGrid(param_grid)

### Validation

In [10]:
# Hyperparameter search: fit one model per grid point on the training split
# and score it by the negative log-likelihood (NLL) on the validation split.
loss = []
for param in params:
    model = DeepSurvivalMachines(k = param['k'],
                                 distribution = param['distribution'],
                                 layers = param['layers'])
    model.fit(x_train, t_train, e_train, iters = 100, learning_rate = param['learning_rate'])

    # Score the model once and reuse the value, so the number that is logged
    # is exactly the number stored for model selection.
    val_nll = model.compute_nll(x_val, t_val, e_val)

    # Each entry is a one-element list wrapping [nll, param]; the parameters
    # of an entry are therefore reached as entry[0][1].
    loss.append([[val_nll, param]])
    print("Loss : ", val_nll)

 12%|█▏        | 1242/10000 [00:02<00:16, 523.27it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:24<00:00,  4.05it/s]
  0%|          | 48/10000 [00:00<00:20, 475.56it/s]

Loss :  1.791779679102967


 12%|█▏        | 1242/10000 [00:02<00:18, 475.42it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 31%|███       | 31/100 [00:09<00:20,  3.44it/s]
  1%|          | 61/10000 [00:00<00:16, 603.40it/s]

Loss :  1.892768051162644


 12%|█▏        | 1242/10000 [00:02<00:16, 544.16it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 69%|██████▉   | 69/100 [00:28<00:12,  2.40it/s]
  0%|          | 49/10000 [00:00<00:20, 488.47it/s]

Loss :  1.901654906764768


 12%|█▏        | 1242/10000 [00:02<00:16, 544.89it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:03<00:29,  3.04it/s]
  0%|          | 48/10000 [00:00<00:20, 479.02it/s]

Loss :  1.8302455224742051


 12%|█▏        | 1242/10000 [00:02<00:15, 559.42it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 37%|███▋      | 37/100 [00:16<00:28,  2.19it/s]
  0%|          | 25/10000 [00:00<00:39, 249.79it/s]

Loss :  1.9027040020995745


 12%|█▏        | 1242/10000 [00:03<00:23, 369.39it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  7%|▋         | 7/100 [00:03<00:39,  2.33it/s]
  0%|          | 26/10000 [00:00<00:39, 255.62it/s]

Loss :  1.822405936729107


 12%|█▏        | 1242/10000 [00:02<00:20, 417.84it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:37<00:00,  2.65it/s]
  1%|          | 58/10000 [00:00<00:17, 571.10it/s]

Loss :  1.7920227562610282


 12%|█▏        | 1242/10000 [00:02<00:15, 576.21it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 27%|██▋       | 27/100 [00:08<00:23,  3.17it/s]
  0%|          | 46/10000 [00:00<00:21, 454.04it/s]

Loss :  1.8986661311966089


 12%|█▏        | 1242/10000 [00:02<00:15, 568.67it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 60%|██████    | 60/100 [00:29<00:19,  2.03it/s]
  1%|          | 55/10000 [00:00<00:18, 541.57it/s]

Loss :  1.9286739525572432


 12%|█▏        | 1242/10000 [00:02<00:14, 592.43it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:05<00:45,  1.97it/s]
  0%|          | 21/10000 [00:00<00:47, 209.28it/s]

Loss :  1.8398291981803336


 12%|█▏        | 1242/10000 [00:03<00:22, 388.16it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 36%|███▌      | 36/100 [00:17<00:31,  2.06it/s]
  0%|          | 21/10000 [00:00<00:47, 209.10it/s]

Loss :  1.9015395994294026


 12%|█▏        | 1242/10000 [00:02<00:20, 427.45it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  8%|▊         | 8/100 [00:04<00:54,  1.68it/s]
  1%|          | 53/10000 [00:00<00:18, 529.12it/s]

Loss :  1.8217662619632706


 12%|█▏        | 1242/10000 [00:02<00:15, 573.31it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:48<00:00,  2.07it/s]
  0%|          | 35/10000 [00:00<00:28, 345.83it/s]

Loss :  1.7920228826407076


 12%|█▏        | 1242/10000 [00:02<00:16, 541.72it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 31%|███       | 31/100 [00:16<00:35,  1.93it/s]
  1%|          | 52/10000 [00:00<00:19, 515.63it/s]

Loss :  1.8808444877323536


 12%|█▏        | 1242/10000 [00:02<00:15, 570.95it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 67%|██████▋   | 67/100 [00:37<00:18,  1.80it/s]
  1%|          | 55/10000 [00:00<00:18, 544.09it/s]

Loss :  1.9074769414689978


 12%|█▏        | 1242/10000 [00:03<00:21, 403.81it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:07<01:07,  1.34it/s]
  1%|          | 52/10000 [00:00<00:19, 518.99it/s]

Loss :  1.8344213323043301


 12%|█▏        | 1242/10000 [00:02<00:16, 537.25it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 36%|███▌      | 36/100 [00:30<00:53,  1.20it/s]
  1%|          | 60/10000 [00:00<00:16, 597.66it/s]

Loss :  1.9013559302687577


 12%|█▏        | 1242/10000 [00:03<00:23, 372.04it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  7%|▋         | 7/100 [00:04<01:04,  1.45it/s]
  0%|          | 44/10000 [00:00<00:22, 436.51it/s]

Loss :  1.8270592550232851


 18%|█▊        | 1845/10000 [00:04<00:20, 401.10it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:43<00:00,  2.31it/s]
  0%|          | 48/10000 [00:00<00:20, 478.77it/s]

Loss :  4.6093532490047675


 18%|█▊        | 1845/10000 [00:05<00:25, 317.52it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:38<00:00,  2.60it/s]
  0%|          | 27/10000 [00:00<00:37, 268.23it/s]

Loss :  4.608030412650856


 18%|█▊        | 1845/10000 [00:06<00:28, 282.90it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 94%|█████████▍| 94/100 [00:46<00:02,  2.02it/s]
  0%|          | 28/10000 [00:00<00:36, 276.44it/s]

Loss :  4.670185381833166


 18%|█▊        | 1845/10000 [00:05<00:23, 351.93it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:04<00:40,  2.23it/s]
  0%|          | 24/10000 [00:00<00:42, 233.92it/s]

Loss :  4.569357715390566


 18%|█▊        | 1845/10000 [00:04<00:21, 379.78it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 41%|████      | 41/100 [00:19<00:28,  2.08it/s]
  0%|          | 24/10000 [00:00<00:43, 227.72it/s]

Loss :  4.659017084529439


 18%|█▊        | 1845/10000 [00:04<00:19, 418.31it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:04<00:39,  2.27it/s]
  0%|          | 34/10000 [00:00<00:29, 334.36it/s]

Loss :  4.561482861739587


 18%|█▊        | 1845/10000 [00:05<00:23, 346.70it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:39<00:00,  2.52it/s]
  0%|          | 44/10000 [00:00<00:22, 434.72it/s]

Loss :  4.591173457889463


 18%|█▊        | 1845/10000 [00:03<00:17, 474.91it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:39<00:00,  2.51it/s]
  1%|          | 58/10000 [00:00<00:17, 572.76it/s]

Loss :  4.545935570638167


 18%|█▊        | 1845/10000 [00:04<00:18, 443.60it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 65%|██████▌   | 65/100 [00:27<00:15,  2.33it/s]
  1%|          | 52/10000 [00:00<00:19, 513.55it/s]

Loss :  4.674528665539941


 18%|█▊        | 1845/10000 [00:03<00:14, 551.36it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:04<00:39,  2.25it/s]
  1%|          | 53/10000 [00:00<00:18, 526.86it/s]

Loss :  4.588916240897539


 18%|█▊        | 1845/10000 [00:04<00:21, 385.40it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 26%|██▌       | 26/100 [00:14<00:40,  1.82it/s]
  0%|          | 25/10000 [00:00<00:40, 246.07it/s]

Loss :  4.659496441592007


 18%|█▊        | 1845/10000 [00:05<00:23, 345.68it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  7%|▋         | 7/100 [00:03<00:42,  2.19it/s]
  0%|          | 33/10000 [00:00<00:30, 327.41it/s]

Loss :  4.565743905249313


 18%|█▊        | 1845/10000 [00:05<00:22, 366.16it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:42<00:00,  2.33it/s]
  0%|          | 27/10000 [00:00<00:37, 265.40it/s]

Loss :  4.532594889862267


 18%|█▊        | 1845/10000 [00:04<00:18, 431.78it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:45<00:00,  2.21it/s]
  1%|          | 66/10000 [00:00<00:15, 650.80it/s]

Loss :  4.531720977985748


 18%|█▊        | 1845/10000 [00:04<00:21, 383.80it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 82%|████████▏ | 82/100 [00:44<00:09,  1.85it/s]
  0%|          | 47/10000 [00:00<00:21, 468.18it/s]

Loss :  4.67510823640753


 18%|█▊        | 1845/10000 [00:03<00:15, 542.91it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 11%|█         | 11/100 [00:06<00:54,  1.64it/s]
  0%|          | 45/10000 [00:00<00:22, 448.17it/s]

Loss :  4.59325414806432


 18%|█▊        | 1845/10000 [00:05<00:23, 347.77it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 53%|█████▎    | 53/100 [00:31<00:28,  1.66it/s]
  1%|          | 61/10000 [00:00<00:16, 603.92it/s]

Loss :  4.663408918330405


 18%|█▊        | 1845/10000 [00:03<00:14, 550.58it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  8%|▊         | 8/100 [00:04<00:47,  1.92it/s]

Loss :  4.567481604190257





In [13]:
# Pick the grid entry with the smallest validation NLL. Each entry of `loss`
# has the shape [[nll, param]], so the NLL is entry[0][0]. Comparing on the NLL
# alone avoids plain min(loss) falling through to comparing the parameter
# dicts when two NLL values tie, which raises TypeError in Python 3.
best_param = min(loss, key = lambda entry: entry[0][0])

### Training

In [22]:
# Refit a model on the training split using the selected hyperparameters;
# best_param[0][1] is the parameter dict of the chosen grid entry.
best_config = best_param[0][1]
model = DeepSurvivalMachines(k = best_config['k'],
                             distribution = best_config['distribution'],
                             layers = best_config['layers'])
model.fit(x_train, t_train, e_train,
          iters = 100,
          learning_rate = best_config['learning_rate'])

 12%|█▏        | 1242/10000 [00:02<00:15, 572.97it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:25<00:00,  3.87it/s]


<dsm.dsm_api.DeepSurvivalMachines at 0x7f556d346400>

### Inference

In [23]:
# Predict on the test split at each evaluation horizon in `times`; both
# outputs are later indexed with one column per horizon.
out_risk, out_survival = (
    model.predict_risk(x_test, times),
    model.predict_survival(x_test, times),
)

### Evaluation Metrics

In [31]:
# Test-set metrics at each evaluation horizon. The training structured array
# is passed alongside the test one to every sksurv metric call.
cis = []
cdauc = []
for i, horizon in enumerate(times):
    # [0] of concordance_index_ipcw's result is the concordance index itself.
    cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], horizon)[0])
    # [0] of cumulative_dynamic_auc's result is the per-time AUC array
    # (a single entry here, because one horizon is passed per call).
    cdauc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], horizon)[0])

# brier_score returns (times, scores); keep the scores wrapped in a list so
# that brs[0][i] is the score at the i-th horizon.
brs = [brier_score(et_train, et_test, out_survival, times)[1]]

# enumerate() yields (index, value) pairs, so unpack them: using the pair
# itself as an index raises TypeError.
for i, quantile in enumerate(req_quantiles):
    print(f"For {quantile} quantile,")
    print("TD Concordance Index:", cis[i])
    print("Brier Score:", brs[0][i])
    print("ROC AUC ", cdauc[i][0], "\n")

For 0.25 quantile,
TD Concordance Index: 0.7558259663556339
Brier Score: 0.11231144600979545
ROC AUC  0.7636623594102698 

For 0.5 quantile,
TD Concordance Index: 0.6938908829155738
Brier Score: 0.1852553403507094
ROC AUC  0.7127964681618975 

For 0.75 quantile,
TD Concordance Index: 0.6624767327831064
Brier Score: 0.2224485369397085
ROC AUC  0.7197007330223542 

