In [None]:
import tensorflow as tf
from tensorflow.keras import Model, layers

In [None]:
class FCLayer(tf.keras.layers.Layer):
    """Stack of identically configured fully connected layers with optional dropout.

    Args:
        num_layers: Number of Dense layers applied in sequence.
        h_dim: Number of units in every Dense layer.
        dropout_rate: Dropout rate applied after each Dense layer.
            ``None`` (the default) disables dropout entirely.
        activation: Activation passed to every Dense layer.
        kernel_regularizer: Kernel regularizer passed to every Dense layer.
    """

    def __init__(self, num_layers, h_dim, dropout_rate=None, activation=tf.nn.relu, kernel_regularizer=None):
        super(FCLayer, self).__init__()
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        # One Dense instance per depth level. Re-applying a single instance would
        # share its weights across levels and fail as soon as the input width
        # differs from h_dim (the kernel is built for the first input it sees).
        self.fcn = [
            layers.Dense(h_dim, activation=activation,
                         kernel_initializer='glorot_uniform', kernel_regularizer=kernel_regularizer)
            for _ in range(num_layers)
        ]
        # Only build the Dropout layer when a rate was actually requested.
        self.dropout = layers.Dropout(rate=dropout_rate) if dropout_rate is not None else None

    def call(self, x):
        """Apply every Dense layer in order, each followed by dropout when enabled."""
        for dense in self.fcn:
            x = dense(x)
            if self.dropout is not None:
                x = self.dropout(x)
        return x

    

In [None]:
class DeepHit(Model):
    """DeepHit-style network: a shared sub-network followed by one
    cause-specific sub-network per event and a joint softmax output.

    Args:
        num_layers_shared: Depth of the shared sub-network.
        h_dim_shared: Width of the shared sub-network.
        activation: Activation used inside the sub-networks.
        dropout_rate: Dropout rate used in the sub-networks and before the
            output layer. ``None`` disables dropout.
        kernel_regularizer: Kernel regularizer for all Dense layers.
        num_layers_cs: Depth of each cause-specific sub-network.
        h_dim_cs: Width of each cause-specific sub-network.
        num_event: Number of competing events.
        num_category: Number of output categories per event.
    """

    def __init__(self, num_layers_shared, h_dim_shared, activation, dropout_rate, kernel_regularizer,
                num_layers_cs, h_dim_cs, num_event, num_category):
        super(DeepHit, self).__init__()
        self.num_event = num_event
        self.num_category = num_category
        # Needed by call() to flatten the stacked cause-specific outputs.
        self.h_dim_cs = h_dim_cs
        self.shared_net = FCLayer(num_layers_shared, h_dim_shared, dropout_rate, activation, kernel_regularizer)
        # One independent sub-network per event; a single shared instance would
        # produce the same output for every event.
        self.cs_nets = [
            FCLayer(num_layers_cs, h_dim_cs, dropout_rate, activation, kernel_regularizer)
            for _ in range(num_event)
        ]
        # Built once here instead of creating a new layer on every forward pass.
        self.concat = layers.Concatenate(axis=1)
        self.dropout = layers.Dropout(rate=dropout_rate) if dropout_rate is not None else None
        self.out_net = layers.Dense(num_event * num_category, activation=tf.nn.softmax,
                                 kernel_initializer='glorot_uniform', kernel_regularizer=kernel_regularizer)

    def call(self, inputs):
        """Return a tensor of shape (batch, num_event, num_category).

        The softmax is taken over the flattened ``num_event * num_category``
        axis, so the values of each sample sum to one across all events and
        categories. Assumes ``inputs`` is 2-D (batch, features), as implied by
        the concatenation along axis 1 -- TODO confirm against callers.
        """
        shared = self.shared_net(inputs)
        # Residual-style connection: the raw inputs are fed to the
        # cause-specific sub-networks alongside the shared representation.
        x = self.concat([inputs, shared])
        cs_outputs = [cs_net(x) for cs_net in self.cs_nets]
        out = tf.stack(cs_outputs, axis=1)  # (batch, num_event, h_dim_cs)
        out = tf.reshape(out, [-1, self.num_event * self.h_dim_cs])
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.out_net(out)
        out = tf.reshape(out, [-1, self.num_event, self.num_category])
        return out

In [None]:
# Negative log-likelihood loss. It expects probabilities, not logits: the model's output layer already applies softmax.
def loss_Log_likelihood(x, y, event):
    """Negative log-likelihood over a batch.

    Args:
        x, y: Tensors multiplied element-wise and summed over axes 2 and 1.
            Presumably one is a mask and the other the model's
            (batch, num_event, num_category) probabilities -- confirm against
            the caller.
        event: Event label per sample; 0 is treated as censored and any
            positive value as an observed event. Assumed to have shape
            (batch, 1) so it lines up with the reduced probabilities --
            TODO confirm.

    Returns:
        Scalar tensor: the negated batch mean of the log-likelihood terms.
    """
    eps = 1e-8  # keeps log() finite when the masked probability is exactly 0

    # Masked probability mass per sample, shape (batch, 1).
    masked_prob = tf.math.reduce_sum(tf.math.reduce_sum(x * y, axis=2), axis=1, keepdims=True)
    log_prob = tf.math.log(masked_prob + eps)

    # 1 for samples with an observed event, 0 for censored samples. Cast so
    # integer event labels can be multiplied with float log-probabilities.
    I_1 = tf.cast(tf.math.sign(event), log_prob.dtype)

    # for uncensored: log P(T=t,K=k|x)
    tmp1 = I_1 * log_prob

    # for censored: log \sum P(T>t|x)
    tmp2 = (1. - I_1) * log_prob

    return - tf.math.reduce_mean(tmp1 + 1.0*tmp2)

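# Ranking loss: compares event-specific risks between pairs of subjects. This is a training loss, not an accuracy metric.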
def loss_ranking(time, event, num_event, num_category, x, y, sigma1=0.1):
    """Pairwise ranking loss summed over the batch.

    Args:
        time: Per-sample time; used as a column vector in the matrix
            products below, so assumed to be (batch, 1) -- TODO confirm.
        event: Per-sample event label; event ``e`` (0-based loop index) is
            matched against label ``e + 1``.
        num_event: Number of competing events.
        num_category: Size of the last axis of ``y``.
        x: Tensor multiplied against the event-specific slice of ``y``;
            presumably a (batch, num_category) time mask -- confirm against
            the caller.
        y: Tensor sliced as (batch, num_event, num_category).
        sigma1: Scale of the exponential penalty. Defaults to 0.1, the value
            that was previously hard-coded.

    Returns:
        Scalar tensor.
    """
    sigma1 = tf.cast(sigma1, tf.float32)
    time = tf.cast(time, tf.float32)  # matmul below requires matching float dtypes
    one_vector = tf.ones_like(time, dtype=tf.float32)

    # Does not depend on the event, so it is computed once outside the loop.
    # later[i, j] = 1 if t_i < t_j and 0 if t_i >= t_j
    later = tf.nn.relu(tf.math.sign(tf.linalg.matmul(one_vector, tf.transpose(time)) -
                                    tf.linalg.matmul(time, tf.transpose(one_vector))))

    eta = []
    for e in range(num_event):
        I_2 = tf.cast(tf.math.equal(event, e+1), dtype=tf.float32)  # indicator for event
        # reshape (not squeeze) so a batch of one still yields a 1-D vector for diag()
        I_2 = tf.linalg.diag(tf.reshape(I_2, [-1]))
        tmp_e = tf.reshape(tf.slice(y, [0, e, 0], [-1, 1, -1]), [-1, num_category])  # event specific joint prob.

        R = tf.linalg.matmul(tmp_e, tf.transpose(x))  # no need to divide by each individual denominator
        # r_{ij} = risk of i-th pat based on j-th time-condition (last meas. time ~ event time) , i.e. r_i(T_{j})

        diag_R = tf.reshape(tf.linalg.diag_part(R), [-1, 1])
        R = tf.linalg.matmul(one_vector, tf.transpose(diag_R)) - R  # R_{ij} = r_{j}(T_{j}) - r_{i}(T_{j})
        R = tf.transpose(R)  # Now, R_{ij} (i-th row j-th column) = r_{i}(T_{i}) - r_{j}(T_{i})

        T = tf.linalg.matmul(I_2, later)  # only remains T_{ij}=1 when event occurred for subject i

        tmp_eta = tf.math.reduce_mean(T * tf.math.exp(-R/sigma1), axis=1, keepdims=True)

        eta.append(tmp_eta)
    eta = tf.stack(eta, axis=1)  # stack referenced on subjects
    eta = tf.math.reduce_mean(tf.reshape(eta, [-1, num_event]), axis=1, keepdims=True)

    return tf.math.reduce_sum(eta)

def loss_calibration(time, event, num_event, num_category, x, y):
    """Calibration loss: squared gap between a masked risk and the event indicator.

    Args:
        time: Unused; kept so the signature parallels ``loss_ranking``.
        event: Per-sample event label; event ``e`` (0-based loop index) is
            matched against label ``e + 1``. Assumed to be (batch, 1) --
            TODO confirm.
        num_event: Number of competing events.
        num_category: Size of the last axis of ``y``.
        x: Tensor multiplied element-wise with the event-specific slice of
            ``y``; presumably a (batch, num_category) mask -- confirm against
            the caller.
        y: Tensor sliced as (batch, num_event, num_category).

    Returns:
        Scalar tensor.
    """
    eta = []
    for e in range(num_event):
        I_2 = tf.cast(tf.math.equal(event, e+1), dtype=tf.float32)  # indicator for event
        tmp_e = tf.reshape(tf.slice(y, [0, e, 0], [-1, 1, -1]), [-1, num_category])  # event specific joint prob.

        # NOTE(review): this sums over axis 0 (the batch axis of tmp_e * x), so r
        # has one entry per category and is broadcast against I_2 below. Confirm
        # that a per-sample sum (axis=1) was not intended.
        r = tf.math.reduce_sum(tmp_e * x, axis=0)  # no need to divide by each individual denominator
        tmp_eta = tf.math.reduce_mean((r - I_2)**2, axis=1, keepdims=True)

        eta.append(tmp_eta)
    eta = tf.stack(eta, axis=1)  # stack referenced on subjects
    eta = tf.math.reduce_mean(tf.reshape(eta, [-1, num_event]), axis=1, keepdims=True)

    return tf.math.reduce_sum(eta)  # sum over num_Events


# Adam optimizer used for training.
# NOTE(review): `learning_rate` is not defined in the visible cells; confirm it
# is assigned before this line runs, otherwise this raises NameError.
optimizer = tf.optimizers.Adam(learning_rate)