In [30]:
import numpy as np
from easydict import EasyDict
from collections import defaultdict
from sksurv.metrics import concordance_index_ipcw

import torch # For building the networks 
import torchtuples as tt # Some useful functions

from pycox.models import PCHazard


In [31]:
np.random.seed(1234)
_ = torch.manual_seed(123)

In [32]:
from survtrace.dataset import load_data

# define the setup parameters
pc_hazard_config = EasyDict({
    'data': 'metabric',
    'horizons': [.25, .5, .75],
    'batch_size': 64,
    'learning_rate': 0.01,
    'epochs': 50,
    'hidden_size': 32
})
pc_hazard_config = EasyDict({
    'data': 'support',
    'horizons': [.25, .5, .75],
    'batch_size': 128,
    'learning_rate': 0.01,
    'epochs': 50,
    'hidden_size': 32
})


# load data
df, df_train, df_y_train, df_test, df_y_test, df_val, df_y_val = load_data(pc_hazard_config)

x_train = np.array(df_train, dtype='float32')
x_val = np.array(df_val, dtype='float32')
x_test = np.array(df_test, dtype='float32')

y_df_to_tuple = lambda df: tuple([np.array(df['duration'], dtype='int64'), np.array(df['event'], dtype='float32'), np.array(df['proportion'], dtype='float32')])

y_train = y_df_to_tuple(df_y_train)
y_val = y_df_to_tuple(df_y_val)

In [33]:
pc_hazard_config

{'data': 'support',
 'horizons': [0.25, 0.5, 0.75],
 'batch_size': 128,
 'learning_rate': 0.01,
 'epochs': 50,
 'hidden_size': 32,
 'labtrans': <survtrace.utils.LabelTransform at 0x7f97186872e8>,
 'num_numerical_feature': 8,
 'num_categorical_feature': 6,
 'num_feature': 14,
 'vocab_size': 25,
 'duration_index': array([   0.  ,   14.  ,   57.  ,  250.25, 2029.  ]),
 'out_feature': 4}

In [34]:
hidden_size = pc_hazard_config.hidden_size
batch_norm = True
dropout = 0.1

net = torch.nn.Sequential(
    torch.nn.Linear(pc_hazard_config.num_feature, hidden_size),
    torch.nn.ReLU(),
    torch.nn.BatchNorm1d(hidden_size),
    torch.nn.Dropout(0.1),
    
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.BatchNorm1d(hidden_size),
    torch.nn.Dropout(0.1),
    
    torch.nn.Linear(hidden_size, pc_hazard_config.out_feature)
)

In [35]:
model = PCHazard(net, tt.optim.Adam, duration_index=np.array(pc_hazard_config['duration_index'], dtype='float32'))
model.optimizer.set_lr(pc_hazard_config.learning_rate)

In [36]:
callbacks = [tt.callbacks.EarlyStopping()]
log = model.fit(x_train, y_train, pc_hazard_config.batch_size, pc_hazard_config.epochs, callbacks, val_data=tuple([x_val, y_val]))


0:	[0s / 0s],		train_loss: 1.6182,	val_loss: 1.5877
1:	[0s / 0s],		train_loss: 1.3642,	val_loss: 1.4273
2:	[0s / 1s],		train_loss: 1.3511,	val_loss: 1.3700
3:	[0s / 1s],		train_loss: 1.3372,	val_loss: 1.3812
4:	[0s / 1s],		train_loss: 1.3380,	val_loss: 1.3655
5:	[0s / 2s],		train_loss: 1.3348,	val_loss: 1.3437
6:	[0s / 2s],		train_loss: 1.3275,	val_loss: 1.3512
7:	[0s / 3s],		train_loss: 1.3289,	val_loss: 1.3573
8:	[0s / 3s],		train_loss: 1.3304,	val_loss: 1.3643
9:	[0s / 3s],		train_loss: 1.3257,	val_loss: 1.3794
10:	[0s / 3s],		train_loss: 1.3257,	val_loss: 1.5841
11:	[0s / 4s],		train_loss: 1.3235,	val_loss: 1.3900
12:	[0s / 4s],		train_loss: 1.3193,	val_loss: 1.3515
13:	[0s / 4s],		train_loss: 1.3189,	val_loss: 1.3436
14:	[0s / 5s],		train_loss: 1.3214,	val_loss: 1.3449
15:	[0s / 5s],		train_loss: 1.3251,	val_loss: 1.3619
16:	[0s / 5s],		train_loss: 1.3187,	val_loss: 1.3492
17:	[0s / 6s],		train_loss: 1.3187,	val_loss: 1.4312
18:	[0s / 6s],		train_loss: 1.3197,	val_loss: 1.3921
19:

In [37]:
class Evaluator:
    def __init__(self, df, train_index):
        '''the input duration_train should be the raw durations (continuous),
        NOT the discrete index of duration.
        '''
        self.df_train_all = df.loc[train_index]

    def eval_single(self, model, test_set, config, val_batch_size=None):
        df_train_all = self.df_train_all
        get_target = lambda df: (df['duration'].values, df['event'].values)
        durations_train, events_train = get_target(df_train_all)
        print('durations_train', durations_train)
        et_train = np.array([(events_train[i], durations_train[i]) for i in range(len(events_train))],
                        dtype = [('e', bool), ('t', float)])
        print('et_train', et_train)
        times = config['duration_index'][1:-1]
        print('times', times)
        horizons = config['horizons']

        df_test, df_y_test = test_set
        surv = model.predict_surv_df(df_test, batch_size=val_batch_size)
        risk = np.array((1 - surv).transpose())
        print('risk', risk)
        
        durations_test, events_test = get_target(df_y_test)
        print('durations_test', durations_test)
        print('events_test', events_test)
        et_test = np.array([(events_test[i], durations_test[i]) for i in range(len(events_test))],
                    dtype = [('e', bool), ('t', float)])
        print('et_test', et_test)
        metric_dict = defaultdict(list)
        cis = []
        for i, _ in enumerate(times):
            print('iteration', i)
            print('risk', risk[:, i+1])
            print(times)
            cis.append(
                concordance_index_ipcw(et_train, et_test, estimate=risk[:, i+1], tau=times[i])[0]
                )
            metric_dict[f'{horizons[i]}_ipcw'] = cis[i]


        for horizon in enumerate(horizons):
            print(f"For {horizon[1]} quantile,")
            print("TD Concordance Index - IPCW:", cis[horizon[0]])
        
        return metric_dict

In [38]:
evaluator = Evaluator(df, df_train.index)
evaluator.eval_single(model, (x_test, df_y_test), config=pc_hazard_config)

durations_train [ 30. 892.   7. ...  36.   6. 879.]
et_train [( True,  30.) (False, 892.) ( True,   7.) ... ( True,  36.) ( True,   6.)
 (False, 879.)]
times [ 14.    57.   250.25]
risk [[0.         0.31434375 0.563365   0.7165142  0.85991776]
 [0.         0.22974098 0.41427994 0.50047255 0.6402138 ]
 [0.         0.19131464 0.34098452 0.46686047 0.71914124]
 ...
 [0.         0.20016217 0.3747046  0.51264817 0.86809427]
 [0.         0.1685158  0.35801423 0.733208   0.9536221 ]
 [0.         0.10061115 0.25156212 0.4210931  0.8292403 ]]
durations_test [  31.  827.   79. ...  640.   51. 1388.]
events_test [1 0 1 ... 0 1 1]
et_test [( True,   31.) (False,  827.) ( True,   79.) ... (False,  640.)
 ( True,   51.) ( True, 1388.)]
iteration 0
risk [0.31434375 0.22974098 0.19131464 ... 0.20016217 0.1685158  0.10061115]
[ 14.    57.   250.25]
iteration 1
risk [0.563365   0.41427994 0.34098452 ... 0.3747046  0.35801423 0.25156212]
[ 14.    57.   250.25]
iteration 2
risk [0.7165142  0.50047255 0.46

defaultdict(list,
            {'0.25_ipcw': 0.6537429619072447,
             '0.5_ipcw': 0.6216117828439424,
             '0.75_ipcw': 0.6113766266416105})