# Neural Fine Gray on SUPPORT Dataset

The SUPPORT dataset comes from the Vanderbilt University study
to estimate survival for seriously ill hospitalized adults.
(Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc
for the original data source.)

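In this notebook, we apply Neural Fine Gray to competing-risks data. Note that the loading cell below requests the `FRAMINGHAM` dataset with `competing = True` rather than SUPPORT.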

In [1]:
import sys
sys.path.append('../')
sys.path.append('../DeepSurvivalMachines/')

### Load the SUPPORT Dataset

The package includes helper functions to load the dataset.

`x` is an np.array of features (covariates),
`t` holds the event/censoring times, and
`e` is the event indicator (`0` for censored observations, positive integers for the competing risks).
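
As a minimal sketch of this layout, the toy arrays below mimic the loader's outputs for a competing-risks problem. All values are invented for illustration and are not taken from the dataset. The coding is an assumption that matches how the rest of this notebook filters `e`: `0` marks censoring and positive integers label the competing events.

```python
import numpy as np

# Toy stand-ins for the loader's outputs (values are made up for illustration).
x = np.array([[63.0, 1.0], [48.0, 0.0], [71.0, 1.0], [55.0, 0.0]])  # covariates
t = np.array([120.0, 340.0, 90.0, 400.0])  # event or censoring times
e = np.array([1, 0, 2, 0])                 # assumed coding: 0 = censored, 1/2 = competing events

# Count how many subjects fall under each label.
labels, counts = np.unique(e, return_counts=True)
summary = dict(zip(labels.tolist(), counts.tolist()))
print(summary)
```

The three arrays share their first dimension: row `i` of `x` belongs to the same subject as `t[i]` and `e[i]`.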

In [2]:
from nfg import datasets
x, t, e, columns = datasets.load_dataset('FRAMINGHAM', competing = True)

### Compute horizons at which we evaluate the performance of Neural Fine Gray

Survival predictions are issued at certain time horizons. Here we evaluate the predictions of
Neural Fine Gray at the 25th, 50th and 75th quantiles of the observed event times, as is standard practice in survival analysis.
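
To make the quantile step concrete, here is a small sketch on made-up data. It assumes the same coding as the notebook (`0` for censoring) and shows that censored rows are dropped before the quantiles are taken.

```python
import numpy as np

# Toy data: 0 = censored, 1 and 2 = competing events (made-up values).
t = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
e = np.array([0, 1, 1, 2, 0, 1, 2, 1])

horizons = [0.25, 0.5, 0.75]
# Quantiles are computed over observed event times only; censored rows are excluded.
event_times = t[e != 0]
times = np.quantile(event_times, horizons).tolist()
print(times)
```

With these toy values the event times are 2, 3, 4, 6, 7 and 8, so the horizons land at 3.25, 5.0 and 6.75 under NumPy's default linear interpolation.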

In [3]:
import numpy as np
import torch
np.random.seed(42)
torch.random.manual_seed(42)

horizons = [0.25, 0.5, 0.75]
times = np.quantile(t[e!=0], horizons).tolist()

In [4]:
# Display the percentage of subjects with each observed event before each time horizon
for time in times:
    print('At time {:.2f}'.format(time))
    for risk in np.unique(e):
        print('\t {:.2f} % observed risk {}'.format(100 * ((e == risk) & (t < time)).mean(), risk))

At time 2153.75
	 0.00 % observed risk 0
	 8.30 % observed risk 1
	 2.66 % observed risk 2
At time 4589.50
	 0.00 % observed risk 0
	 14.52 % observed risk 1
	 7.40 % observed risk 2
At time 6620.75
	 0.00 % observed risk 0
	 20.32 % observed risk 1
	 12.56 % observed risk 2


### Splitting the data into train, test and validation sets

We will train Neural Fine Gray on 80% of the data (10% of which is used as the stopping criterion and 10% for model selection) and report performance on the remaining 20% held-out test set.
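
The arithmetic behind these proportions can be sketched as follows. The fractions mirror the three successive `train_test_split` calls in the next cell. The actual row counts may differ by one because scikit-learn rounds the split sizes.

```python
from fractions import Fraction

# Fractions mirror the three successive train_test_split calls in the next cell.
total = Fraction(1)
test = total * Fraction(1, 5)      # test_size = 0.2 of all data
remaining = total - test           # 80% kept for model development
held = remaining * Fraction(1, 5)  # test_size = 0.2 of the 80%
train = remaining - held           # used to fit the network
dev = held * Fraction(1, 2)        # early stopping (passed as val_data to fit)
val = held - dev                   # model selection via the NLL

shares = {"train": train, "dev": dev, "val": val, "test": test}
for name, share in shares.items():
    print(f"{name}: {float(share):.0%} of all data")
```

Expressed against the full dataset, the split is 64% training, 8% for early stopping, 8% for model selection and 20% for testing.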

In [5]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

x_train, x_test, t_train, t_test, e_train, e_test = train_test_split(x, t, e, test_size = 0.2, random_state = 42)
x_train, x_val, t_train, t_val, e_train, e_val = train_test_split(x_train, t_train, e_train, test_size = 0.2, random_state = 42)
x_dev, x_val, t_dev, t_val, e_dev, e_val = train_test_split(x_val, t_val, e_val, test_size = 0.5, random_state = 42)

ss = MinMaxScaler().fit(t_train.reshape(-1, 1))
t_train_ddh = ss.transform(t_train.reshape(-1, 1)).flatten()
t_dev_ddh = ss.transform(t_dev.reshape(-1, 1)).flatten()
t_val_ddh = ss.transform(t_val.reshape(-1, 1)).flatten()
times_ddh = ss.transform(np.array(times).reshape(-1, 1)).flatten()
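
The scaler above is fitted on the training times only and then reused for the other splits and for the horizons. A minimal NumPy sketch of the same arithmetic follows, assuming `MinMaxScaler`'s default target range of 0 to 1. The helper names `fit_min_max` and `apply_min_max` are hypothetical and the values are made up.

```python
import numpy as np

def fit_min_max(train_values):
    """Return (minimum, range) learned from the training values only."""
    lo = float(np.min(train_values))
    hi = float(np.max(train_values))
    return lo, hi - lo

def apply_min_max(values, lo, span):
    """Rescale values with the training minimum and range."""
    return (np.asarray(values, dtype=float) - lo) / span

t_train_toy = np.array([10.0, 20.0, 50.0])  # made-up training times
t_val_toy = np.array([5.0, 60.0])           # made-up validation times

lo, span = fit_min_max(t_train_toy)
scaled_train = apply_min_max(t_train_toy, lo, span)
scaled_val = apply_min_max(t_val_toy, lo, span)
print(scaled_train, scaled_val)
```

Times outside the training range map outside the unit interval (here -0.125 and 1.25), which is expected when the scaler is not refitted on the other splits.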

### Setting the parameter grid

Let's set up the parameter grid to tune hyperparameters. We will tune the learning rate for the Adam optimizer ($1\times10^{-3}$ or $1\times10^{-4}$), the batch size ($100$ or $250$), and the hidden-layer configurations passed as `layers` and `layers_surv` (one to three layers of $50$ or $100$ units each), with `Tanh` as the only activation.
`ParameterSampler` then draws 5 configurations at random from this grid.
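
To get a feel for the size of the search space, the grid defined in the cells below can be enumerated by hand. This sketch uses only the standard library. The random draw illustrates the idea of sampling a few configurations and will not reproduce the exact draws of scikit-learn's `ParameterSampler`.

```python
import itertools
import random

layers = [[50], [50, 50], [50, 50, 50], [100], [100, 100], [100, 100, 100]]
param_grid = {
    'learning_rate': [1e-3, 1e-4],
    'layers_surv': layers,
    'layers': layers,
    'act': ['Tanh'],
    'batch': [100, 250],
}

# Enumerate every combination in the grid.
keys = list(param_grid)
combos = [dict(zip(keys, values)) for values in itertools.product(*param_grid.values())]

# Draw 5 distinct configurations; the seed is arbitrary and the draws will not
# match those of sklearn's ParameterSampler.
rng = random.Random(42)
sampled = rng.sample(combos, 5)
print(len(combos), len(sampled))
```

The full grid holds 2 x 6 x 6 x 1 x 2 = 144 combinations, so sampling 5 of them explores only a small part of it.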

In [6]:
from sklearn.model_selection import ParameterSampler

In [7]:
layers = [[50], [50, 50], [50, 50, 50], [100], [100, 100], [100, 100, 100]]
param_grid = {
            'learning_rate' : [1e-3, 1e-4],
            'layers_surv': layers,
            'layers' : layers,
            'act': ['Tanh'],
            'batch': [100, 250],
            }
params = ParameterSampler(param_grid, 5, random_state = 42)

### Model Training and Selection

In [8]:
from nfg import NeuralFineGray

In [9]:
models = []
for param in params:
    model = NeuralFineGray(layers = param['layers'], act = param['act'], layers_surv = param['layers_surv'])
    # The fit method is called to train the model
    model.fit(x_train, t_train_ddh, e_train, n_iter = 500, bs = param['batch'], 
            lr = param['learning_rate'], val_data = (x_dev, t_dev_ddh, e_dev))
    nll = model.compute_nll(x_val, t_val_ddh, e_val)
    if not(np.isnan(nll)) and 0 < nll:
        models.append([nll, model])
    else:
        print("WARNING: Nan Value Observed")

Loss: 3.719:  41%|████▏     | 207/500 [00:16<00:23, 12.21it/s]
Loss: 3.713:  12%|█▏        | 59/500 [00:08<00:59,  7.37it/s]


KeyboardInterrupt: 

In [10]:
best_model = min(models, key = lambda x: x[0])
model = best_model[1]
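
Note that `min` raises a `ValueError` on an empty list, which happens if every run was rejected or training stopped before any model was stored. A defensive variant could look like the sketch below. The helper `select_best` is hypothetical, and strings stand in for fitted models.

```python
import math

def select_best(candidates):
    """Return the [nll, model] pair with the smallest finite NLL, or None if there is none."""
    finite = [pair for pair in candidates if math.isfinite(pair[0])]
    if not finite:
        return None
    return min(finite, key=lambda pair: pair[0])

# Strings stand in for fitted models; the NLL values are made up.
toy_models = [[1.42, "model_a"], [float("nan"), "model_b"], [1.17, "model_c"]]
best = select_best(toy_models)
print(best)
```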

### Inference

We now predict risk and survival for the test patients at the selected horizons, using the horizons rescaled with the same scaler as the training times.

In [11]:
out_risk = model.predict_risk(x_test, times_ddh.tolist())
out_survival = model.predict_survival(x_test, times_ddh.tolist())

### Evaluation

We evaluate Neural Fine Gray on its discriminative ability (time-dependent concordance index and cumulative/dynamic AUC) as well as its Brier score, for one competing risk at a time (`risk = 2` below).
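
The scikit-survival metrics take the outcome as a structured array with a boolean event field and a float time field. The sketch below builds such an array in a vectorized way for one risk of interest. The helper `to_structured` is hypothetical and the data are made up. Events of any other type come out as `False`, so the metrics treat them as censored, which is also what the list comprehensions in the evaluation cell do.

```python
import numpy as np

def to_structured(e, t, risk):
    """Build an (event, time) structured array; events other than `risk` become False."""
    out = np.empty(len(e), dtype=[('e', bool), ('t', float)])
    out['e'] = (np.asarray(e) == risk)
    out['t'] = np.asarray(t, dtype=float)
    return out

# Made-up labels and times: 0 = censored, 1 and 2 = competing events.
e_toy = np.array([0, 1, 2, 2, 1])
t_toy = np.array([5.0, 3.0, 8.0, 2.0, 9.0])
et_toy = to_structured(e_toy, t_toy, risk=2)
print(et_toy)
```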

In [12]:
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

In [17]:
risk = 2

In [18]:
et_train = np.array([(e_train[i] == risk, t_train[i]) for i in range(len(e_train))],
                 dtype = [('e', bool), ('t', float)])
et_test = np.array([(e_test[i] == risk, t_test[i]) for i in range(len(e_test))],
                 dtype = [('e', bool), ('t', float)])
selection = (t_test < t_train.max()) | (e_test == 0)

cis = []
for i, _ in enumerate(times):
    cis.append(concordance_index_ipcw(et_train, et_test[selection], out_risk[:, i][selection], times[i])[0])
brs = brier_score(et_train, et_test[selection], out_survival[selection], times)[1]
roc_auc = []
for i, _ in enumerate(times):
    roc_auc.append(cumulative_dynamic_auc(et_train, et_test[selection], out_risk[:, i][selection], times[i])[0])
for i, horizon in enumerate(horizons):
    print(f"For {horizon} quantile,")
    print("TD Concordance Index:", cis[i])
    print("Brier Score:", brs[i])
    print("ROC AUC ", roc_auc[i][0], "\n")

For 0.25 quantile,
TD Concordance Index: 0.6738837188386445
Brier Score: 0.03318377171458241
ROC AUC  0.6824688563278923 

For 0.5 quantile,
TD Concordance Index: 0.7267959507792772
Brier Score: 0.060932342459839015
ROC AUC  0.751689690754543 

For 0.75 quantile,
TD Concordance Index: 0.6750281517154372
Brier Score: 0.11688161648993128
ROC AUC  0.7009334828845468 

