# Benchmark test on SUPPORT

Import data set

In [9]:
import numpy as np
import pandas as pd
import torch

import sys
from sklearn.model_selection import ParameterGrid


sys.path.append('../')
from auton_survival import datasets
outcomes, features = datasets.load_support()

Preprocess data

In [10]:
from auton_survival.preprocessing import Preprocessor

cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
             'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
             'glucose', 'bun', 'urine', 'adlp', 'adls']


features = Preprocessor().fit_transform(features, cat_feats=cat_feats, num_feats=num_feats)


horizons = [0.25, 0.5, 0.75]
times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()

x, t, e = features.values, outcomes.time.values, outcomes.event.values

n = len(x)

tr_size = int(n * 0.70)
vl_size = int(n * 0.10)
te_size = int(n * 0.20)

x_train, x_test, x_val = x[:tr_size], x[-te_size:], x[tr_size:tr_size+vl_size]
t_train, t_test, t_val = t[:tr_size], t[-te_size:], t[tr_size:tr_size+vl_size]
e_train, e_test, e_val = e[:tr_size], e[-te_size:], e[tr_size:tr_size+vl_size]

t = outcomes["time"]
e = outcomes["event"]
quantiles = [0.25, 0.5, 0.75]
quantiles = np.quantile(t[e == 1], quantiles)

def dataframe_to_tensor(data):
    """Function that converts a pandas dataframe into a tensor"""
    if isinstance(data, (pd.Series, pd.DataFrame)):
        return data.to_numpy()
    else:
        return torch.from_numpy(data).float()

x_val_tensor = dataframe_to_tensor(x_val)
t_val_tensor = dataframe_to_tensor(t_val)
e_val_tensor = dataframe_to_tensor(e_val)

In [11]:
train_data = (x_train, t_train, e_train)
val_data_tensor = (x_val_tensor, t_val_tensor, e_val_tensor)
val_data = (x_val, t_val, e_val)
test_data = (x_test, t_test, e_test)

## DCM model

In [12]:
from auton_survival.models.dcm import DeepCoxMixtures
from auton_survival.models.dcm.dcm_utilities import test_step

# hyperparameters according to the paper
DCM_param_grid = {"k" : [3, 4, 6],
                  "learning_rate" : [1e-3],
                  "layers" : [[50], [100], [50, 50], [100, 100]],
                  "batch_size": [128]
                  }

DCM_params = ParameterGrid(DCM_param_grid)

In [13]:
class DCM_Wrapper(object):
    def __init__(self, params_grid):
        self.params_grid = params_grid
        self.model = None

    def fit(self, train_set, val_set):

        x_train, t_train, e_train = train_set
        x_val, t_val, e_val = val_set
        x_val_tensor = dataframe_to_tensor(x_val)
        t_val_tensor = dataframe_to_tensor(t_val)
        e_val_tensor = dataframe_to_tensor(e_val)

        models = []
        for param in self.params_grid:
            model = DeepCoxMixtures(k=param["k"],
                                    layers=param["layers"])
            # The fit method is called to train the model
            model.fit(x_train, t_train, e_train,
                      iters=100,
                      learning_rate=param["learning_rate"],
                      batch_size=param["batch_size"])

            # store the performance on the validation set
            breslow_splines = model.torch_model[1]
            val_result = test_step(model.torch_model[0], x_val_tensor, t_val_tensor, e_val_tensor, breslow_splines)
            models.append([[val_result, model]])

        best_model = min(models)
        self.model = best_model[0][1]

    def predict(self, test_set):

        x_test, t_test, e_test = test_set

        out_survival = self.model.predict_survival(x_test, times)
        out_risk = 1 - out_survival

        return out_survival, out_risk


## DSM model

In [14]:
from auton_survival.models.dsm import DeepSurvivalMachines

DSM_param_grid = {"distribution": ['Weibull'],
                  "k": [3, 4],
                  "layers": [[], [50], [50, 50], [100], [100, 100]],
                  "batch_size": [128, 256],
                  "learning_rate": [ 1e-4, 1e-3],
                  "activation": ["SeLu"]
             }
DSM_params = ParameterGrid(DSM_param_grid)

In [24]:
class DSM_Wrapper(object):
    def __init__(self, params_grid):
        self.params_grid = params_grid
        self.model = None

    def fit(self, train_set, val_set):

        models = []
        x_train, t_train, e_train = train_set
        x_val, t_val, e_val = val_set
        for param in self.params_grid:
            model = DeepSurvivalMachines(k=param['k'],
                                 distribution=param['distribution'],
                                 layers=param['layers'])

            model.fit(x_train, t_train, e_train, iters=100, learning_rate=param['learning_rate'])
            models.append([model.compute_nll(x_val, t_val, e_val), model])

        best_model_entry = min(models, key=lambda x: x[0])

        # Extract the model
        self.model = best_model_entry[1]

    def predict(self, test_set):

        x_test, t_test, e_test = test_set

        out_survival = self.model.predict_survival(x_test, times)
        out_risk = 1 - self.model.predict_risk(x_test, times)

        return out_survival, out_risk


## Benchmark Function

In [23]:
import time
import pickle
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

def benchmark_model(name, model_wrap, train_set, val_set, test_set):
    result = {'Model': name}

    try:
        start = time.time()
        model_wrap.fit(train_set=train_set, val_set=val_set)
        result['Train Time'] = time.time() - start
        print("Fit complete!")

        start = time.time()
        survival, risk = model_wrap.predict(test_set)
        result['Predict Time'] = time.time() - start
        print("Predict complete!")

        cis = []
        brs = []

        et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],
                         dtype = [('e', bool), ('t', float)])
        et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],
                         dtype = [('e', bool), ('t', float)])
        et_val = np.array([(e_val[i], t_val[i]) for i in range(len(e_val))],
                         dtype = [('e', bool), ('t', float)])

        for i, _ in enumerate(times):
            cis.append(concordance_index_ipcw(et_train, et_test, risk[:, i], times[i])[0])

        brs.append(brier_score(et_train, et_test, survival, times)[1])
        roc_auc = []

        for i, _ in enumerate(times):
            roc_auc.append(cumulative_dynamic_auc(et_train, et_test, risk[:, i], times[i])[0])

        for horizon in enumerate(horizons):
            result[f"{horizon[1]} quantile TD Concordance Index"] = cis[horizon[0]]
            result[f"{horizon[1]} quantile Brier Score"] = brs[0][horizon[0]]
            result[f"{horizon[1]} quantile ROC AUC"] = roc_auc[horizon[0]][0]

    except Exception as e:
        result['Error'] = str(e)

    return result


In [25]:
dcm_wrap = DCM_Wrapper(DCM_params)
dsm_wrap = DSM_Wrapper(DSM_params)

models_wrap = [
    ("DCM", dcm_wrap),
    ("DSM", dsm_wrap),
]

results = [benchmark_model(name, model_wrap, train_data, val_data, test_data)
           for name, model_wrap in models_wrap]

# save fitted model
for name, model in models_wrap:
    with open(f'{name}.pkl', 'wb') as f:
        pickle.dump(model.model, f)

pd.DataFrame(results)

  probs = gates+np.log(event_probs)
  probs = gates+np.log(event_probs)
  return spl(ts)**risks
  s0ts = (-risks)*(spl(ts)**(risks-1))
 48%|████▊     | 48/100 [00:17<00:18,  2.79it/s]
 67%|██████▋   | 67/100 [00:31<00:15,  2.10it/s]
 30%|███       | 30/100 [00:14<00:33,  2.11it/s]
 61%|██████    | 61/100 [00:32<00:20,  1.89it/s]
 34%|███▍      | 34/100 [00:16<00:31,  2.12it/s]
 61%|██████    | 61/100 [00:27<00:17,  2.25it/s]
 20%|██        | 20/100 [00:10<00:42,  1.89it/s]
 38%|███▊      | 38/100 [00:19<00:31,  1.97it/s]
 44%|████▍     | 44/100 [00:24<00:30,  1.83it/s]
 56%|█████▌    | 56/100 [00:31<00:24,  1.77it/s]
 19%|█▉        | 19/100 [00:11<00:49,  1.62it/s]
 48%|████▊     | 48/100 [00:30<00:33,  1.57it/s]


Fit complete!
Predict complete!


 18%|█▊        | 1797/10000 [00:03<00:16, 510.13it/s]
100%|██████████| 100/100 [00:12<00:00,  7.74it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 510.28it/s]
 92%|█████████▏| 92/100 [00:12<00:01,  7.64it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 507.20it/s]
 93%|█████████▎| 93/100 [00:13<00:01,  6.70it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 505.56it/s]
 27%|██▋       | 27/100 [00:04<00:11,  6.46it/s]
 18%|█▊        | 1797/10000 [00:03<00:15, 512.86it/s]
 67%|██████▋   | 67/100 [00:11<00:05,  6.08it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 506.66it/s]
 23%|██▎       | 23/100 [00:04<00:13,  5.66it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 500.64it/s]
 93%|█████████▎| 93/100 [00:14<00:01,  6.46it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 507.53it/s]
 17%|█▋        | 17/100 [00:02<00:13,  6.03it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 512.42it/s]
 60%|██████    | 60/100 [00:11<00:07,  5.22it/s]
 18%|█▊        | 1797/10000 [00:03<00:16, 511.88it/s]
  9%|▉         | 9

Fit complete!
Predict complete!





Unnamed: 0,Model,Train Time,Predict Time,0.25 quantile TD Concordance Index,0.25 quantile Brier Score,0.25 quantile ROC AUC,0.5 quantile TD Concordance Index,0.5 quantile Brier Score,0.5 quantile ROC AUC,0.75 quantile TD Concordance Index,0.75 quantile Brier Score,0.75 quantile ROC AUC
0,DCM,266.639422,0.000998,0.771127,0.108708,0.781708,0.717221,0.178909,0.741438,0.673002,0.220844,0.716491
1,DSM,509.01891,0.005984,0.23573,0.107729,0.227737,0.298051,0.17895,0.274391,0.330637,0.215506,0.285361
