In [210]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, concatenate, Lambda, Multiply
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, RMSprop,SGD,Nadam, Adagrad, Adadelta
from tensorflow.keras.regularizers import l1,l2,l1_l2
from tensorflow.keras.initializers import Constant ,Orthogonal, RandomNormal, VarianceScaling, Ones, Zeros
from tensorflow.keras.constraints import Constraint, UnitNorm
from keras.callbacks import Callback, TerminateOnNaN, ModelCheckpoint
from sksurv.metrics import concordance_index_censored as concordance
import math
import time
from sklearn.preprocessing import StandardScaler



In [211]:
#This is the TFCox class

class TFCox():
    def __init__(self, seed=42,batch_norm=False,l1_ratio=1,lbda=0.0001,
                 max_it=50,learn_rate=0.001,stop_if_nan=True,stop_at_value=False, cscore_metric=False,suppress_warnings=True,verbose=0):
        
        self.max_it = max_it
        self.tnan = stop_if_nan
        self.tcscore = stop_at_value
        self.lr=learn_rate
        self.cscore=cscore_metric
        np.random.seed(seed)
        tf.random.set_seed(seed)
        
        self.l1r = l1_ratio
        self.lbda=lbda
        self.bnorm = batch_norm
        self.verbose=verbose
        if suppress_warnings == True:
            import warnings
            warnings.filterwarnings('ignore')
       
    def coxloss(self, state):
        
        def loss(y_true, y_pred):  

                return -K.mean((y_pred - K.log(tf.math.cumsum(K.exp(y_pred),reverse=True,axis=0)+0.0001))*state,axis=0)

        return loss

    def cscore_metric(self, state):
        def loss(y_true,y_pred):
            con = 0
            dis = 0
            for a in range(len(y_pred)):
                for b in range(a+1,len(y_pred)):                                       
                        if (y_pred[a]>y_pred[b])  & (y_pred[a]*state[a]!=0):
                            con+=1
                            
                        elif (y_pred[a]<y_pred[b])  & (y_pred[a]*state[a]!=0):
                            dis+=1
            return     con/(con+dis)
        return loss
 
    
    def fit(self, X,state,time):
        from tensorflow.python.framework.ops import disable_eager_execution
        disable_eager_execution()
        K.clear_session()
       
        
        
        self.time = np.array(time)  
        self.newindex = pd.DataFrame(self.time).sort_values(0).index
        self.X = (pd.DataFrame(np.array(X)).reindex(self.newindex))                      
        self.state = np.array(pd.DataFrame(np.array(state)).reindex(self.newindex))
        self.time  = np.array(pd.DataFrame(np.array(time)).reindex(self.newindex))                       
        inputsx = Input(shape=(self.X.shape[1],)) 
        state = Input(shape=(1,))
        
        if self.bnorm==True:
            out = BatchNormalization()(inputsx)
            out = Dense(1,activation='linear',
                    kernel_regularizer=l1_l2(self.lbda*self.l1r,self.lbda*(1-self.l1r)),
                   use_bias=False)(out)
        else:
            out = Dense(1,activation='linear',
                    kernel_regularizer=l1_l2(self.lbda*self.l1r,self.lbda*(1-self.l1r)),
                   use_bias=False)(inputsx)

        
        model = Model(inputs=[inputsx, state], outputs=out)
        if (self.tcscore != False) or (self.cscore==True) :
            model.compile(optimizer=Adam(self.lr) ,
                          loss=self.coxloss(state) , metrics=[self.cscore_metric(state)],
                          experimental_run_tf_function=False)
        else:
            model.compile(optimizer=Adam(self.lr) ,
                          loss=self.coxloss(state) ,
                          experimental_run_tf_function=False)
        
        self.model=model
        if self.verbose==1:
            print(self.model.summary())

        self.loss_history_ = []
        for its in range(self.max_it):
            self.temp_weights = self.model.get_weights()
           
            tr = self.model.train_on_batch([self.X, self.state],np.zeros(self.state.shape))
           
            self.loss_history_.append(tr) 
            
            if self.verbose == 1:
                if (self.tcscore != False) or (self.cscore==True) :
                    print('loss:', self.loss_history_[-1][0],' C-score: ',self.loss_history_[-1][1] )
                else:
                    print('loss:', self.loss_history_[-1] )
            
            if self.tcscore != False:
                if self.loss_history_[-1][1]>=self.tcscore:
                    print('Terminated early because concordance >=' +str(self.tcscore)+ ' as set by stop_at_value flag.')
                    break
            if (self.tcscore != False) or (self.cscore==True) :
                if (math.isnan(self.loss_history_[-1][0]) or math.isinf(self.loss_history_[-1][0])) and self.tnan:
                    self.model.set_weights(self.temp_weights)
                    print('Terminated because weights == nan or inf, reverted to last valid weight set')
                    break
            else:
                if (math.isnan(self.loss_history_[-1]) or math.isinf(self.loss_history_[-1])) and self.tnan:
                    self.model.set_weights(self.temp_weights)
                    print('Terminated because weights == nan or inf, reverted to last valid weight set')
                    break
            
        self.beta_ = self.model.get_weights()[-1]

    def predict(self,X):
        preds = self.model.predict([X,np.zeros(len(X))])

        return preds

## Creating the simulated Data

In [224]:
#Uniform ranodmly distributed time data (no censoring)
y_time = np.random.rand(1000)

In [225]:
#Creating simulated X data from the time data 
X = (y_time * (np.random.rand(5000,1)-0.5)).transpose()

In [226]:
#Adding noise to the X data
X = X +( np.random.rand(1000,5000)-0.5)*5

In [227]:
#Randomly censoring 75% of the samples  (censored =0, uncensored = 1)
y_state=np.zeros(1000)
y_state[np.random.choice(1000,250)] =1

In [228]:
#Removing a random amount of time from the censored samples
for a in range(len(y_time)):
    if y_state[a] == 0:
        y_time[a] = y_time[a] - np.random.rand(1)[0]*y_time[a]

## Running the Model

In [229]:
#Creating a simple 80-20 test train split
train_index = np.random.choice(1000,800,replace=False)
test_index = [x for x in range(1000) if x not in train_index]

In [230]:
#Running the TFCox model (default L1_ratio = 1) in a small loop for different values of lambda

for a in [0,0.0001,0.001,0.01,0.1,1,10]:
    cox = TFCox(lbda=a)
    cox.fit(X[train_index],y_state[train_index],y_time[train_index])
    train_pred = cox.predict(X[train_index])
    test_pred = cox.predict(X[test_index])
    
    print('train concordance', 'lambda=',a,':' ,concordance(y_state[train_index].astype(bool),y_time[train_index],train_pred.flatten()))

    print('test concordance:', 'lambda=',a,':'  ,concordance(y_state[test_index].astype(bool),y_time[test_index],test_pred.flatten()))


train concordance lambda= 0 : (0.9994821685878963, 44393, 23, 0, 0)
test concordance: lambda= 0 : (0.7963263101026472, 1474, 377, 0, 0)
train concordance lambda= 0.0001 : (0.9995046829971181, 44394, 22, 0, 0)
test concordance: lambda= 0.0001 : (0.7855213398163156, 1454, 397, 0, 0)
train concordance lambda= 0.001 : (0.9994596541786743, 44392, 24, 0, 0)
test concordance: lambda= 0.001 : (0.8227984873041599, 1523, 328, 0, 0)
train concordance lambda= 0.01 : (0.9978611311239193, 44321, 95, 0, 0)
test concordance: lambda= 0.01 : (0.8854673149648838, 1639, 212, 0, 0)
train concordance lambda= 0.1 : (0.9538904899135446, 42368, 2048, 0, 0)
test concordance: lambda= 0.1 : (0.7277147487844409, 1347, 504, 0, 0)
train concordance lambda= 1 : (0.6059528097982709, 26914, 17502, 0, 0)
test concordance: lambda= 1 : (0.5807671528903295, 1075, 776, 0, 0)
train concordance lambda= 10 : (0.506686779538905, 22505, 21911, 0, 0)
test concordance: lambda= 10 : (0.42409508373851973, 785, 1066, 0, 0)
