In [25]:
import torch

We will use PyCox to import the METABRIC Dataset

In [26]:
from pycox import datasets
df = datasets.metabric.read_df()

Preprocessing, setting Random folds and computing Event Quantiles of Interest 

In [27]:
import numpy as np


dat1  = df[['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']]
times = (df['duration'].values+1)
events =  df['event'].values
data = dat1.to_numpy()
folds = np.array([1]*191 + [2]*191 + [3]*191 + [4]*191 + [5]*190 + [6]*190 + [7]*190 + [8]*190 + [9]*190 + [10]*190 )
np.random.seed(100)
np.random.shuffle(folds)
quantiles = np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()

In [28]:
#This is a flag that is used to artificially increase the amount of censoring in the 
#dataset to determine robustness of DSM to increased censoring levels.
INCREASE_CENSORING = False

In [29]:
from sklearn.preprocessing import StandardScaler

In [32]:
import importlib
import dsm
import dsm_utilites
importlib.reload(dsm)
importlib.reload(dsm_utilites)

<module 'dsm_utilites' from '/Users/chiragn/Research/ICML2020/DeepSurvivalMachines/dsm_utilites.py'>

In [33]:
#set parameter grid
params = [{'G':4, 'mlptyp':1,'HIDDEN':[], 'n_iter':int(500), 'lr':1e-3, 'ELBO':True, 'mean':False, \
            'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25)}] 

#set val data size
vsize = int(0.15*1523)

torch.manual_seed(0)

for param in params:

    outs = []

    for f in range(1,6,1):

        x_train = data[folds!=f]
        x_test  = data[folds==f]
        x_valid = x_train[-vsize:, :]
        x_train = x_train[:-vsize, :]

        t_train = times[folds!=f]
        t_test  = times[folds==f]
        t_valid = t_train[-vsize:]
        t_train = t_train[:-vsize]


        e_train = events[folds!=f]
        e_test  = events[folds==f]
        e_valid = e_train[-vsize:]
        e_train = e_train[:-vsize]


        print ("val len:", len(x_valid))

        print ("tr  len:", len(x_train))


        #normalize the feature set using standard scaling

        scl = StandardScaler()
        x_train = scl.fit_transform(x_train)
        x_valid = scl.transform(x_valid)
        x_test = scl.transform(x_test)


        print ("Censoring in Fold:", np.mean(e_train))

        if INCREASE_CENSORING:
            e_train, t_train = increaseCensoring(e_train, t_train, .50)

        print ("Censoring in Fold:", np.mean(e_train))

        #Convert the train, test and validation data torch

        x_train = torch.from_numpy(x_train).double() 
        e_train = torch.from_numpy(e_train).double() 
        t_train = torch.from_numpy(t_train).double() 

        x_valid = torch.from_numpy(x_valid).double() 
        e_valid = torch.from_numpy(e_valid).double() 
        t_valid = torch.from_numpy(t_valid).double() 

        x_test = torch.from_numpy(x_test).double() 
        e_test = torch.from_numpy(e_test).double() 
        t_test = torch.from_numpy(t_test).double() 


        K, mlptyp, HIDDEN, n_iter, lr, ELBO, mean, lambd, alpha, thres, bs = \
        param['G'], param['mlptyp'], param['HIDDEN'], param['n_iter'], param['lr'], \
        param['ELBO'], param['mean'], param['lambd'], param['alpha'], param['thres'], param['bs']

        D = x_train.shape[1]
        
        model = dsm.DeepSurvivalMachines(D, K, mlptyp, HIDDEN, dist='Weibull')
        model.double()
        
        dsm_utilites.trainDSM(model,quantiles,x_train, t_train, e_train, x_valid, t_valid, e_valid,lr=lr,bs=bs,alpha=alpha )
        
    
        print ("TEST PERFORMANCE")

        out =  (dsm_utilites.computeCIScores(model, quantiles, x_test, t_test, e_test, t_train, e_train))

        print (out)

        outs.append(out)







  0%|          | 0/10000 [00:00<?, ?it/s][A[A[A[A[A[A






  0%|          | 0/10000 [00:00<?, ?it/s][A[A[A[A[A[A

val len: 228
tr  len: 1485
Censoring in Fold: 0.5723905723905723
Censoring in Fold: 0.5723905723905723
Pretraining the Underlying Distributions...








  0%|          | 1/10000 [00:00<36:58,  4.51it/s][A[A[A[A[A[A





  0%|          | 2/10000 [00:00<36:55,  4.51it/s][A[A[A[A[A[A





  0%|          | 3/10000 [00:00<37:07,  4.49it/s][A[A[A[A[A[A





  0%|          | 4/10000 [00:00<37:15,  4.47it/s][A[A[A[A[A[A





  0%|          | 5/10000 [00:01<37:15,  4.47it/s][A[A[A[A[A[A





  0%|          | 6/10000 [00:01<37:19,  4.46it/s][A[A[A[A[A[A





  0%|          | 7/10000 [00:01<37:52,  4.40it/s][A[A[A[A[A[A





  0%|          | 8/10000 [00:01<38:33,  4.32it/s][A[A[A[A[A[A





  0%|          | 9/10000 [00:02<38:19,  4.34it/s][A[A[A[A[A[A





  0%|          | 10/10000 [00:02<37:54,  4.39it/s][A[A[A[A[A[A





  0%|          | 11/10000 [00:02<37:52,  4.40it/s][A[A[A[A[A[A





  0%|          | 12/10000 [00:02<37:52,  4.40it/s][A[A[A[A[A[A





  0%|          | 13/10000 [00:02<37:44,  4.41it/s][A[A[A[A[A[A





  0%|          | 14/10000 [00:03<37:50,  

KeyboardInterrupt: 

In [8]:
quantiles

[43.68333435058594, 86.86666870117188, 146.33333587646484, 283.5426806640625]