In [95]:
import torch

We will use PyCox to import the METABRIC Dataset

In [96]:
from pycox import datasets
df = datasets.metabric.read_df()

Preprocessing, setting Random folds and computing Event Quantiles of Interest 

In [123]:
import numpy as np


dat1  = df[['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']]
times = (df['duration'].values+1)
events =  df['event'].values
data = dat1.to_numpy()
folds = np.array([1]*381 + [2]*381 + [3]*381 + [4]*381 + [5]*380 )
np.random.seed(0)
np.random.shuffle(folds)
quantiles = np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()

In [124]:
#This is a flag that is used to artificially increase the amount of censoring in the 
#dataset to determine robustness of DSM to increased censoring levels.
INCREASE_CENSORING = False

In [125]:
from sklearn.preprocessing import StandardScaler

In [126]:
import importlib
import dsm
import dsm_utilites
importlib.reload(dsm)
importlib.reload(dsm_utilites)

<module 'dsm_utilites' from '/Users/chiragn/Research/ICML2020/DeepSurvivalMachines/dsm_utilites.py'>

In [127]:
float(len(dat1)*9)/10

1713.6

In [128]:
#set parameter grid

params = [{'G':6, 'mlptyp':2,'HIDDEN':[100], 'n_iter':int(1000), 'lr':1e-3, 'ELBO':True, 'mean':False, \
           'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25), 'dist': 'Weibull'}]


#set val data size
vsize = int(0.15*1712)

torch.manual_seed(0)

for param in params:

    outs = []

    for f in range(1,6,1):

        x_train = data[folds!=f]
        x_test  = data[folds==f]
        x_valid = x_train[-vsize:, :]
        x_train = x_train[:-vsize, :]

        t_train = times[folds!=f]
        t_test  = times[folds==f]
        t_valid = t_train[-vsize:]
        t_train = t_train[:-vsize]


        e_train = events[folds!=f]
        e_test  = events[folds==f]
        e_valid = e_train[-vsize:]
        e_train = e_train[:-vsize]


        print ("val len:", len(x_valid))

        print ("tr  len:", len(x_train))


        #normalize the feature set using standard scaling

        scl = StandardScaler()
        x_train = scl.fit_transform(x_train)
        x_valid = scl.transform(x_valid)
        x_test = scl.transform(x_test)


        print ("Censoring in Fold:", np.mean(e_train))

        if INCREASE_CENSORING:
            e_train, t_train = increaseCensoring(e_train, t_train, .50)

        print ("Censoring in Fold:", np.mean(e_train))

        #Convert the train, test and validation data torch

        x_train = torch.from_numpy(x_train).double() 
        e_train = torch.from_numpy(e_train).double() 
        t_train = torch.from_numpy(t_train).double() 

        x_valid = torch.from_numpy(x_valid).double() 
        e_valid = torch.from_numpy(e_valid).double() 
        t_valid = torch.from_numpy(t_valid).double() 

        x_test = torch.from_numpy(x_test).double() 
        e_test = torch.from_numpy(e_test).double() 
        t_test = torch.from_numpy(t_test).double() 


        K, mlptyp, HIDDEN, n_iter, lr, ELBO, mean, lambd, alpha, thres, bs,dist = \
        param['G'], param['mlptyp'], param['HIDDEN'], param['n_iter'], param['lr'], \
        param['ELBO'], param['mean'], param['lambd'], param['alpha'], param['thres'],\
        param['bs'], param['dist'] 

        D = x_train.shape[1]
        
        print (dist)
        
        model = dsm.DeepSurvivalMachines(D, K, mlptyp, HIDDEN, dist=dist)
        model.double()
        
        model, i = dsm_utilites.trainDSM(model,quantiles,x_train, t_train, e_train, x_valid, t_valid, e_valid,lr=lr,bs=bs,alpha=alpha )
        
    
        print ("TEST PERFORMANCE")

        out =  (dsm_utilites.computeCIScores(model, quantiles, x_test, t_test, e_test, t_train, e_train))

        print (out)

        outs.append(out)

val len: 256
tr  len: 1267
Censoring in Fold: 0.585635359116022
Censoring in Fold: 0.585635359116022
Weibull
Pretraining the Underlying Distributions...


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

200.75226152672906 1.2680803184752372


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))


TEST PERFORMANCE
(0.7154317427767943, 0.6613066844703644, 0.6095326277314909, 0.5646184019234619)
val len: 256
tr  len: 1267
Censoring in Fold: 0.5864246250986582
Censoring in Fold: 0.5864246250986582
Weibull
Pretraining the Underlying Distributions...


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

199.7153561845356 1.2632128253408417


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

TEST PERFORMANCE
(0.6007865597948178, 0.5885119373707697, 0.586134286427175, 0.601290434891957)
val len: 256
tr  len: 1267
Censoring in Fold: 0.579321231254933
Censoring in Fold: 0.579321231254933
Weibull
Pretraining the Underlying Distributions...


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

203.71121598683038 1.2771277338630311


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

TEST PERFORMANCE
(0.7121724313147351, 0.6818384881157109, 0.6577998627814042, 0.6497338850905421)
val len: 256
tr  len: 1267
Censoring in Fold: 0.5785319652722968
Censoring in Fold: 0.5785319652722968
Weibull
Pretraining the Underlying Distributions...


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

200.38728541667388 1.2916873727801628


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

TEST PERFORMANCE
(0.7604568637757466, 0.7232963903416784, 0.6617261376642131, 0.646405712483512)
val len: 256
tr  len: 1268
Censoring in Fold: 0.582807570977918
Censoring in Fold: 0.582807570977918
Weibull
Pretraining the Underlying Distributions...


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

200.8431094911515 1.2591288578083222


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

TEST PERFORMANCE
(0.7380987174602022, 0.6644269287843599, 0.6575335218015144, 0.6570588366776975)


In [8]:
quantiles

[43.68333435058594, 86.86666870117188, 146.33333587646484, 283.5426806640625]

In [93]:
model.dist

'LogNormal'