In [1]:
import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np
from model import DeepHit
from losses import *
from utils import *
import tensorflow as tf

# Switch TensorFlow 1.x into eager mode so the training cells below can run
# ops immediately. The top-level `enable_eager_execution` attribute is absent
# in TensorFlow 2.x (eager is the default there), so only call it when it
# exists instead of failing with AttributeError.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()

In [2]:
# Location of the raw table, relative to the notebook's working directory.
metabric_csv = "./data/metabric.csv"
df = pd.read_csv(metabric_csv)

In [3]:
# Columns that carry the outcome (event indicator and event time) rather than
# covariates.
outcome_columns = ['label', 'event_time']

# Selecting with a one-element list keeps each outcome as a 2-D array with a
# single column instead of a flat vector.
event = np.asarray(df[['label']])
time = np.asarray(df[['event_time']])

# Drop the outcome columns by name rather than by position: slicing off "the
# last two columns" silently leaks the label/time into the features whenever
# the CSV column order differs from what was assumed.
dhdata = np.asarray(df.drop(columns=outcome_columns))

In [4]:
# Learn per-feature scaling statistics from the covariate matrix, then apply
# them to that same matrix. `scaler` stays available for reuse on new data.
scaler = StandardScaler().fit(dhdata)
df_trans = scaler.transform(dhdata)

In [5]:
# Length of the discrete time grid: the largest observed time stretched by
# 20% so there is enough time horizon.
num_category = int(1.2 * time.max())

# Number of distinct event types; one distinct label value is subtracted
# because censoring is not counted as an event.
num_event = int(np.unique(event).size) - 1

# Number of input covariates (columns of the scaled feature matrix).
x_dim = np.shape(df_trans)[1]

# Masks consumed by the loss functions in the training cells.
# NOTE(review): the -1 passed to compute_mask2 is an opaque flag from here;
# confirm its meaning against compute_mask2's definition.
mask1 = compute_mask1(time, event, num_event, num_category)
mask2 = compute_mask2(time, -1, num_category)

In [6]:
# Hyperparameter dictionary used by the rest of the notebook. The cells below
# read these keys from it: 'num_layers_shared', 'h_dim_shared', 'active_fn',
# 'dropout', 'num_layers_CS', 'h_dim_CS', 'lr_train', 'mb_size', 'iteration',
# 'alpha', 'beta' and 'gamma'.
# NOTE(review): the helper name suggests the values are sampled randomly and no
# seed is set in this notebook, so re-running may give a different
# configuration — confirm and seed if reproducibility matters.
parameters = get_random_hyperparameters()

In [7]:
# Gather the network's constructor arguments in one place: the shared trunk
# and cause-specific settings come from the hyperparameter dictionary, the
# output sizes from the data.
deephit_config = dict(
    num_layers_shared=parameters['num_layers_shared'],
    h_dim_shared=parameters['h_dim_shared'],
    activation=parameters['active_fn'],
    dropout_rate=parameters['dropout'],
    num_layers_cs=parameters['num_layers_CS'],
    h_dim_cs=parameters['h_dim_CS'],
    num_event=num_event,
    num_category=num_category,
)
deephit = DeepHit(**deephit_config)

# Adam optimizer with the learning rate taken from the hyperparameters.
optimizer = tf.keras.optimizers.Adam(parameters['lr_train'])

In [8]:
def run_optimization(x, event, time, mask1, mask2, alpha, beta, gamma):
    """Perform one gradient step on the module-level `deephit` model.

    The objective is the weighted sum
    ``alpha * log-likelihood + beta * ranking + gamma * calibration``
    evaluated on a single mini-batch.

    Args:
        x: Mini-batch of input features fed to `deephit`.
        event: Event labels for the batch.
        time: Event/censoring times for the batch.
        mask1: Mask passed to `loss_log_likelihood`.
        mask2: Mask passed to `loss_ranking` and `loss_calibration`.
        alpha: Weight of the log-likelihood loss.
        beta: Weight of the ranking loss.
        gamma: Weight of the calibration loss.

    Returns:
        The total loss computed on the batch before the parameter update.
    """
    # Only the forward pass and the loss need to be recorded on the tape.
    # NOTE(review): `deephit(x)` is called without a training flag; whether
    # dropout is active here depends on DeepHit's implementation — confirm.
    with tf.GradientTape() as tape:
        pred = deephit(x)
        loss1 = loss_log_likelihood(pred, mask1, event)
        loss2 = loss_ranking(pred, mask2, time, event, num_event, num_category)
        loss3 = loss_calibration(pred, mask2, time, event, num_event, num_category)
        total_loss = alpha * loss1 + beta * loss2 + gamma * loss3

    # Differentiate and update outside the tape context: there is no need to
    # trace the gradient computation or the optimizer step itself.
    variables = deephit.trainable_variables
    gradients = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return total_loss

In [9]:
# Use tf.data API to shuffle and batch data.
train_data = tf.data.Dataset.from_tensor_slices((df_trans, event, time, mask1, mask2))

# Shuffle before repeating, with a buffer that holds the whole dataset: every
# sample is then seen exactly once per pass in a fresh random order. Repeating
# first lets the shuffle buffer mix samples from different passes, so a sample
# can recur several times before others have been seen at all.
shuffle_buffer = len(df_trans)
train_data = (
    train_data
    .shuffle(shuffle_buffer)
    .repeat()
    .batch(parameters["mb_size"])
    .prefetch(1)
)

In [11]:
# Run training for the given number of steps.
# Loop-invariant settings are read once up front instead of on every step.
display_step = 100  # report the loss every `display_step` iterations
alpha, beta, gamma = parameters["alpha"], parameters["beta"], parameters["gamma"]

for step, (batch_x, batch_event, batch_time, batch_mask1, batch_mask2) in enumerate(
        train_data.take(parameters['iteration']), 1):
    # Run the optimization to update W and b values.
    run_optimization(batch_x, batch_event, batch_time, batch_mask1, batch_mask2,
                     alpha, beta, gamma)

    if step % display_step == 0:
        # Re-evaluate the objective on the current mini-batch after the update.
        # This is a single-batch estimate, so it fluctuates from report to report.
        pred = deephit(batch_x)
        loss1 = loss_log_likelihood(pred, batch_mask1, batch_event)
        loss2 = loss_ranking(pred, batch_mask2, batch_time, batch_event, num_event, num_category)
        loss3 = loss_calibration(pred, batch_mask2, batch_time, batch_event, num_event, num_category)
        total_loss = alpha * loss1 + beta * loss2 + gamma * loss3
        print("step: %i, loss: %f" % (step, total_loss))

step: 100, loss: 35.687234
step: 200, loss: 24.152375
step: 300, loss: 39.046767
step: 400, loss: 35.061722
step: 500, loss: 29.683415
step: 600, loss: 41.539700
step: 700, loss: 43.571146
step: 800, loss: 34.950887
step: 900, loss: 33.125334
step: 1000, loss: 40.911422
step: 1100, loss: 24.935297
step: 1200, loss: 33.998720
step: 1300, loss: 38.100834
step: 1400, loss: 26.929088
step: 1500, loss: 37.150350
step: 1600, loss: 27.311292
step: 1700, loss: 30.535820
step: 1800, loss: 31.768735
step: 1900, loss: 44.855879
step: 2000, loss: 35.592344
step: 2100, loss: 30.032639
step: 2200, loss: 29.772686
step: 2300, loss: 28.914038
step: 2400, loss: 28.251952
step: 2500, loss: 31.934848
step: 2600, loss: 35.229896
step: 2700, loss: 26.797281
step: 2800, loss: 32.201593
step: 2900, loss: 35.312597
step: 3000, loss: 29.140213
step: 3100, loss: 33.017831
step: 3200, loss: 26.810098
step: 3300, loss: 26.661194
step: 3400, loss: 29.152359
step: 3500, loss: 25.136743


KeyboardInterrupt: 