# DSM on SUPPORT Dataset

SUPPORT: This dataset comes from the Vanderbilt University study
to estimate survival for seriously ill hospitalized adults
(see http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc
for the original data source).
In this notebook, we demonstrate the application of
Deep Survival Machines (DSM) for survival prediction on the SUPPORT dataset.

### Load the SUPPORT Dataset

In [1]:
from dsm import datasets

In [2]:
# Covariates (x), observed times (t) and event indicators (e) for SUPPORT.
(x, t, e) = datasets.load_dataset('SUPPORT')

### Compute horizons at which we evaluate the performance of DSM

In [3]:
import numpy as np

In [17]:
horizons = [0.25, 0.5, 0.75]

# Evaluation times: the requested quantiles of the times of *observed events*
# only (rows with e == 1); one evaluation time per entry of `horizons`.
event_times = t[e == 1]
times = np.quantile(event_times, q=horizons).tolist()

### Splitting the data into train, test and validation sets

In [5]:
# 70% train / 10% validation / 20% test split, taken in the dataset's stored
# row order (no shuffling). Each size is floored independently, so a few rows
# between the validation and test slices may end up unused.
data_size = len(x)
tr_size = int(data_size * 0.70)
vl_size = int(data_size * 0.10)
te_size = int(data_size * 0.20)

# Split boundaries. The test slice starts at an explicit index instead of
# `-te_size`: when te_size == 0, `x[-0:]` would be the WHOLE array rather
# than an empty one.
vl_end = tr_size + vl_size
te_start = data_size - te_size
assert vl_end <= te_start, "validation and test splits overlap"

x_train, t_train, e_train = x[:tr_size], t[:tr_size], e[:tr_size]

x_test, t_test, e_test = x[te_start:], t[te_start:], e[te_start:]

x_val, t_val, e_val = x[tr_size:vl_end], t[tr_size:vl_end], e[tr_size:vl_end]

### Setting the parameter grid

In [6]:
from sklearn.model_selection import ParameterGrid

In [7]:
# Hyper-parameter search space. ParameterGrid expands it into every
# combination of the listed values (3 * 2 * 2 * 3 = 36 configurations).
param_grid = dict(
    k=[3, 4, 6],                            # NOTE(review): presumably the number of mixture components in DSM
    distribution=['LogNormal', 'Weibull'],
    learning_rate=[1e-4, 1e-3],
    layers=[[], [100], [100, 100]],         # NOTE(review): presumably hidden-layer widths; [] = no hidden layer
)
params = ParameterGrid(param_grid)

### Model Training and Selection

In [8]:
from dsm import DeepSurvivalMachines

In [9]:
# Fit one DSM per hyper-parameter combination and keep the model with the
# lowest negative log-likelihood (NLL) on the validation split.
#
# Each entry of `models` is [[val_nll, model]] (same shape as before).
# Selection uses an explicit key because a bare `min(models)`:
#   * on an NLL tie, falls back to comparing DeepSurvivalMachines objects,
#     which raises TypeError;
#   * with a NaN NLL (diverged run) first in the list, keeps that NaN entry,
#     since every comparison against NaN is False.
# Non-finite NLLs are therefore ranked last.
def _selection_key(entry):
    """Return the validation NLL of a `[[nll, model]]` entry, or +inf if non-finite."""
    val_nll = float(entry[0][0])
    return val_nll if np.isfinite(val_nll) else float('inf')


models = []
for param in params:
    model = DeepSurvivalMachines(k=param['k'],
                                 distribution=param['distribution'],
                                 layers=param['layers'])
    model.fit(x_train, t_train, e_train, iters=100,
              learning_rate=param['learning_rate'])
    models.append([[model.compute_nll(x_val, t_val, e_val), model]])

best_model = min(models, key=_selection_key)
model = best_model[0][1]

  Variable._execution_engine.run_backward(
 12%|█▏        | 1242/10000 [00:03<00:22, 386.70it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:28<00:00,  3.57it/s]
 12%|█▏        | 1242/10000 [00:02<00:17, 508.96it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 29%|██▉       | 29/100 [00:07<00:19,  3.68it/s]
 12%|█▏        | 1242/10000 [00:02<00:18, 471.82it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 75%|███████▌  | 75/100 [00:31<00:10,  2.36it/s]
 12%|█▏        | 1242/10000 [00:02<00:18, 463.31it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  9%|▉         | 9/100 [00:03<00:33,  2.69it/s]
 12%|█▏        | 1242/10000 [00:03<00:22, 393.72it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 36%|███▌      | 36/100 [00:17<00:31,  2.00it/s]
 12%|█▏        | 1242/10000 [00:02<00:16, 518.61it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  6%|▌         | 6/100 [00:02<00:41,  2.26it/s]
 12%|█▏        | 1242/10000 [00:02<00:20, 437.86it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:39<00:00,  2.52it/s]
 12%|█▏        | 1242/10000 [00:02<00:18, 480.01it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 27%|██▋       | 27/100 [00:09<00:26,  2.75it/s]
 12%|█▏        | 1242/10000 [00:02<00:15, 551.97it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 60%|██████    | 60/100 [00:29<00:19,  2.00it/s]
 12%|█▏        | 1242/10000 [00:02<00:17, 500.31it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 11%|█         | 11/100 [00:06<00:52,  1.71it/s]
 12%|█▏        | 1242/10000 [00:02<00:20, 434.73it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 30%|███       | 30/100 [00:17<00:39,  1.76it/s]
 12%|█▏        | 1242/10000 [00:04<00:29, 296.95it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  6%|▌         | 6/100 [00:04<01:09,  1.34it/s]
 12%|█▏        | 1242/10000 [00:02<00:17, 492.24it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:55<00:00,  1.80it/s]
 12%|█▏        | 1242/10000 [00:02<00:16, 524.59it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 26%|██▌       | 26/100 [00:14<00:41,  1.77it/s]
 12%|█▏        | 1242/10000 [00:02<00:18, 478.92it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 82%|████████▏ | 82/100 [00:51<00:11,  1.61it/s]
 12%|█▏        | 1242/10000 [00:02<00:19, 445.79it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 11%|█         | 11/100 [00:06<00:55,  1.60it/s]
 12%|█▏        | 1242/10000 [00:03<00:26, 336.13it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 27%|██▋       | 27/100 [00:23<01:03,  1.15it/s]
 12%|█▏        | 1242/10000 [00:03<00:27, 314.79it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  5%|▌         | 5/100 [00:04<01:22,  1.15it/s]
 18%|█▊        | 1845/10000 [00:04<00:21, 375.14it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:36<00:00,  2.72it/s]
 18%|█▊        | 1845/10000 [00:06<00:29, 274.87it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:38<00:00,  2.62it/s]
 18%|█▊        | 1845/10000 [00:03<00:16, 503.40it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 85%|████████▌ | 85/100 [00:39<00:06,  2.18it/s]
 18%|█▊        | 1845/10000 [00:04<00:21, 372.67it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 12%|█▏        | 12/100 [00:05<00:38,  2.28it/s]
 18%|█▊        | 1845/10000 [00:03<00:15, 540.56it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 61%|██████    | 61/100 [00:29<00:18,  2.09it/s]
 18%|█▊        | 1845/10000 [00:04<00:18, 442.52it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  7%|▋         | 7/100 [00:04<00:54,  1.72it/s]
 18%|█▊        | 1845/10000 [00:03<00:17, 472.40it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:40<00:00,  2.50it/s]
 18%|█▊        | 1845/10000 [00:05<00:22, 363.28it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:38<00:00,  2.59it/s]
 18%|█▊        | 1845/10000 [00:04<00:21, 388.27it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 81%|████████  | 81/100 [00:37<00:08,  2.14it/s]
 18%|█▊        | 1845/10000 [00:03<00:16, 493.16it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 13%|█▎        | 13/100 [00:06<00:43,  2.01it/s]
 18%|█▊        | 1845/10000 [00:03<00:14, 549.49it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 42%|████▏     | 42/100 [00:23<00:32,  1.78it/s]
 18%|█▊        | 1845/10000 [00:04<00:21, 372.01it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 10%|█         | 10/100 [00:04<00:43,  2.09it/s]
 18%|█▊        | 1845/10000 [00:05<00:22, 367.91it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:47<00:00,  2.11it/s]
 18%|█▊        | 1845/10000 [00:04<00:18, 445.61it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:50<00:00,  2.00it/s]
 18%|█▊        | 1845/10000 [00:04<00:20, 395.85it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


100%|██████████| 100/100 [00:54<00:00,  1.85it/s]
 18%|█▊        | 1845/10000 [00:03<00:16, 496.97it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 16%|█▌        | 16/100 [00:09<00:48,  1.74it/s]
 18%|█▊        | 1845/10000 [00:04<00:18, 431.29it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


 33%|███▎      | 33/100 [00:17<00:36,  1.84it/s]
 18%|█▊        | 1845/10000 [00:04<00:18, 452.19it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1]) ParameterDict(  (1): Parameter containing: [torch.DoubleTensor of size 1])


  7%|▋         | 7/100 [00:04<00:59,  1.55it/s]


### Inference

In [14]:
# Risk and survival predictions for the held-out test split, evaluated at
# each time in `times`.
out_risk, out_survival = (
    model.predict_risk(x_test, times),
    model.predict_survival(x_test, times),
)

### Evaluation

In [15]:
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

In [28]:
def _to_structured(events, durations):
    """Pack event indicators and times into a structured array for sksurv metrics.

    Field 'e' holds the event indicator (cast to bool) and field 't' the
    observed time (cast to float), one record per sample.
    """
    return np.array(list(zip(events, durations)),
                    dtype=[('e', bool), ('t', float)])


et_train = _to_structured(e_train, t_train)
et_test = _to_structured(e_test, t_test)
# Validation targets in the same format (not used by the metrics below).
et_val = _to_structured(e_val, t_val)

# Time-dependent concordance index (IPCW), one value per evaluation time.
# Column i of out_risk holds the predictions at times[i].
cis = [concordance_index_ipcw(et_train, et_test, out_risk[:, i], tau)[0]
       for i, tau in enumerate(times)]

# brier_score scores all evaluation times in one call and returns
# (times, scores); keep only the scores, wrapped in a list as before.
brs = [brier_score(et_train, et_test, out_survival, times)[1]]

# Cumulative/dynamic AUC per evaluation time; [0] keeps the per-time AUC array.
roc_auc = [cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], tau)[0]
           for i, tau in enumerate(times)]

for idx, quantile in enumerate(horizons):
    print(f"For {quantile} quantile,")
    print("TD Concordance Index:", cis[idx])
    print("Brier Score:", brs[0][idx])
    print("ROC AUC ", roc_auc[idx][0], "\n")

For 0.25 quantile,
TD Concordance Index: 0.7531559283335377
Brier Score: 0.11253308894960265
ROC AUC  0.7607097914595871 

For 0.5 quantile,
TD Concordance Index: 0.693495227483385
Brier Score: 0.18531434279970113
ROC AUC  0.7124639668143149 

For 0.75 quantile,
TD Concordance Index: 0.6622243713405473
Brier Score: 0.22235390432697502
ROC AUC  0.7192136669291815 

