In [25]:
import torch

We will use PyCox to import the METABRIC Dataset

In [26]:
from pycox import datasets
df = datasets.metabric.read_df()

Preprocessing, setting Random folds and computing Event Quantiles of Interest 

In [43]:
import numpy as np

dat1  = df[['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']]
times = (df['duration'].values+1)
events =  df['event'].values
data = dat1.to_numpy()
folds = np.array([1]*381 + [2]*381 + [3]*381 + [4]*381 + [5]*380 )
np.random.seed(0)
np.random.shuffle(folds)
np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()

[43.68333435058594, 86.86666870117188, 146.33333587646484, 283.5426806640625]

In [44]:
#This is a flag that is used to artificially increase the amount of censoring in the 
#dataset to determine robustness of DSM to increased censoring levels.
INCREASE_CENSORING = False

In [45]:
from sklearn.preprocessing import StandardScaler

In [74]:
import importlib
import dsm
import dsm_utilites
importlib.reload(dsm)
importlib.reload(dsm_utilites)

<module 'dsm_utilites' from '/Users/chiragn/Research/ICML2020/DeepSurvivalMachines/dsm_utilites.py'>

In [None]:
#set parameter grid
params = [{'G':6, 'mlptyp':2,'HIDDEN':[100], 'n_iter':int(1000), 'lr':1e-3, 'ELBO':True, 'mean':False, \
           'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25)}]


#set val data size
vsize = int(0.15*1523)

torch.manual_seed(0)

for param in params:

    outs = []

    for f in range(1,6,1):

        x_train = data[folds!=f]
        x_test  = data[folds==f]
        x_valid = x_train[-vsize:, :]
        x_train = x_train[:-vsize, :]

        t_train = times[folds!=f]
        t_test  = times[folds==f]
        t_valid = t_train[-vsize:]
        t_train = t_train[:-vsize]


        e_train = events[folds!=f]
        e_test  = events[folds==f]
        e_valid = e_train[-vsize:]
        e_train = e_train[:-vsize]


        print ("val len:", len(x_valid))

        print ("tr  len:", len(x_train))


        #normalize the feature set using standard scaling

        scl = StandardScaler()
        x_train = scl.fit_transform(x_train)
        x_valid = scl.transform(x_valid)
        x_test = scl.transform(x_test)


        print ("Censoring in Fold:", np.mean(e_train))

        if INCREASE_CENSORING:
            e_train, t_train = increaseCensoring(e_train, t_train, .50)

        print ("Censoring in Fold:", np.mean(e_train))

        #Convert the train, test and validation data torch

        x_train = torch.from_numpy(x_train).double() 
        e_train = torch.from_numpy(e_train).double() 
        t_train = torch.from_numpy(t_train).double() 

        x_valid = torch.from_numpy(x_valid).double() 
        e_valid = torch.from_numpy(e_valid).double() 
        t_valid = torch.from_numpy(t_valid).double() 

        x_test = torch.from_numpy(x_test).double() 
        e_test = torch.from_numpy(e_test).double() 
        t_test = torch.from_numpy(t_test).double() 


        K, mlptyp, HIDDEN, n_iter, lr, ELBO, mean, lambd, alpha, thres, bs = \
        param['G'], param['mlptyp'], param['HIDDEN'], param['n_iter'], param['lr'], \
        param['ELBO'], param['mean'], param['lambd'], param['alpha'], param['thres'], param['bs']

        D = x_train.shape[1]
        
        model = dsm.DeepSurvivalMachines(D, K, mlptyp, HIDDEN, dist='Weibull')
        model.double()
        
        model, i = dsm_utilites.trainDSM(model,quantiles,x_train, t_train, e_train, x_valid, t_valid, e_valid,lr=lr,bs=bs,alpha=alpha )
        
    
        print ("TEST PERFORMANCE")

        out =  (dsm_utilites.computeCIScores(model, quantiles, x_test, t_test, e_test, t_train, e_train))

        print (out)

        outs.append(out)

val len: 228
tr  len: 1295
Censoring in Fold: 0.5814671814671815
Censoring in Fold: 0.5814671814671815
Pretraining the Underlying Distributions...


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


202.8469264171793 1.2680954189977467


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

3.5295747584993196 (0.7436527098215582, 0.6811342904030737, 0.5970621901813846, 0.611855947999837)
3.515131386970598 (0.7386963678694428, 0.6716928435166729, 0.5849140201392863, 0.62534764825785)
3.5158226769628036 (0.7318455746593371, 0.6699128823791801, 0.5911467483533817, 0.6331405037220301)
3.515479804334142 (0.7297725821056623, 0.672073403791732, 0.5936461114676398, 0.6369502882527434)
3.5152090747964597 (0.7299993252087804, 0.6723811713623161, 0.5948084129361593, 0.6397806508689314)
3.515229667328146 (0.727796010086639, 0.6730755262613353, 0.5972531310617756, 0.6420898707648605)
3.515459661315446 (0.7282370222892353, 0.6703104281959925, 0.5979750789683245, 0.6429366444808753)
3.5152011552733837 (0.7284817501550502, 0.6689068376100378, 0.5977576124523662, 0.6441363144872003)
3.5156224070684705 (0.7298338654512917, 0.6682243802957398, 0.5976506274847705, 0.645576700159252)
3.516007568690427 (0.7302225119493331, 0.6672430764040999, 0.5983333250469652, 0.6398588975503352)
3.516402141

In [8]:
quantiles

[43.68333435058594, 86.86666870117188, 146.33333587646484, 283.5426806640625]