<a href="https://colab.research.google.com/github/Mahdiye-Bayat/firstDive/blob/master/CRESA/RNN_Surv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook setup cell: mount Google Drive and make the data folder the
# working directory.
from google.colab import drive
import os
drive.mount('/content/gdrive')
# Relative path: it is resolved against the current working directory
# (presumably /content on a fresh Colab runtime — TODO confirm).
root_path = 'gdrive/My Drive/Data_Prep'
# The data loaders below open their files through relative paths, so they
# depend on this chdir having succeeded.
os.chdir(root_path)

Mounted at /content/gdrive


In [None]:
from pandas.core.arrays import boolean
#data_handler
#%% 
import pandas as pd
import scipy.io as sio
import numpy as np

from sklearn import preprocessing
from sklearn.utils import shuffle
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score


#%% 

"""
Input: string (dataset name)
Output: imported data
"""

def import_small_dataset(dataset):
    """Read ``<dataset>.csv`` and return it without its first column.

    Args:
        dataset: Dataset name. It is converted to a string and suffixed with
            ``.csv`` to build the path handed to ``pd.read_csv``.

    Returns:
        DataFrame with the first column removed. For ``'nwtco'`` the
        ``in.subcohort`` column is cast to float; for ``'myData'`` the
        ``event`` column is cast to bool and then to float.
    """
    frame = pd.read_csv(str(dataset) + '.csv')

    # The leading column is dropped unconditionally (presumably a row index
    # written when the CSV was exported — TODO confirm).
    frame = frame.drop(columns=frame.columns[0])

    if dataset == 'nwtco':
        subcohort = frame['in.subcohort']
        frame['in.subcohort'] = subcohort.astype(float)
    if dataset == 'myData':
        as_flag = frame['event'].astype(bool)
        frame['event'] = as_flag.astype(float)

    return frame



"""
Input: string (dataset name)
Output: imported data
"""

def import_UNOS_dataset(dataset):
    """Load a preprocessed UNOS ``.mat`` file into a single DataFrame.

    Args:
        dataset: ``'Transplant'`` loads ``Transplant_Data``; any other value
            loads ``Waitlist_Data`` (kept as in the original behaviour).

    Returns:
        DataFrame with one column per feature followed by the three survival
        columns. The second survival column is rewritten from a "censored"
        flag into an "eventful" flag (1 where the stored value was 0, else 0).
    """
    if (dataset == 'Transplant'):
        data = sio.loadmat('UNOS_Data_Share/Preprocessed_Data/Transplant_Data')
        # Assign more meaningful names to features
        data['Feature_Name'] = ['Age', 'Gender', 'IschemicTime','Diabetes', 'Infection', 'Transfusion','PreviousTransplant','NumberOfPreviousTransplant','VentilatorAssist','ECMOAssist','VentSupport', 'Creatinine','Bilirubin', 'PRA','HLAMismatch','BloodTypeA','BloodTypeB','BloodTypeO','BloodTypeAB','Dialysis','IABP','DonorAge','DonorGender','DonorBloodTypeA','DonorBloodTypeB','DonorBloodTypeO','DonorBloodTypeAB','ABOEqual','ABO_Compatible','ABO_Incompatible','HEP_C_Antigen','DonorDiabetes', 'Distance','DaysInState1','DaysInState1A','DaysInState2','DaysInState1B','BMI','DonorBMI','VADAssist','TotalArtificialHeart','Inotropic','A_Mismatch','B_Mismatch','DR_Mismatch','Height_Difference','Weight_Difference','LVAD_at_listing','LVAD_while_listed','LVAD']
        data['Surv_Name'] = ['Survival_Time', 'Eventful', 'Tranplant_Year']
    else:
        data = sio.loadmat('UNOS_Data_Share/Preprocessed_Data/Waitlist_Data')
        data['Feature_Name'] = ['Age','Gender', 'Diabetes','PreviousTransplant', 'NumberOfPreviousTransplant', 'VentilatorAssist','ECMOAssist', 'Creatinine', 'BloodTypeA','BloodTypeB','BloodTypeO','BloodTypeAB', 'IABP','DaysInState1', 'DaysInState1A','DaysInState2','DaysInState1B', 'BMI','VADAssist', 'TotalArtificialHeart','Inotropic','LVAD_at_listing','LVAD_while_listed','LVAD']
        data['Surv_Name'] = ['Survival_Time', 'Eventful', 'Listed_Year']

    # Take the row count from the data instead of hard-coding it, so the
    # loader also works on files of a different size.
    num_rows = len(data['Feature'])

    # Instead of the column "censored" we keep the column "eventful": 1 if the
    # event has been registered and 0 otherwise. Done in one vectorised step.
    surv = data['Surv']
    surv[:, 1] = np.where(surv[:, 1] == 0, 1, 0)

    row_index = np.arange(num_rows)
    pd_data = pd.DataFrame(data=data['Feature'], index=row_index, columns=data['Feature_Name'])
    surv_data = pd.DataFrame(data=surv, index=row_index, columns=np.transpose(data['Surv_Name']))

    return pd.concat([pd_data, surv_data], axis=1)



"""
Input: Dataframe
Output: np array, standard scaler (the np array is already scaled)
"""
def normalize_data(data):
    """Standardise every column of ``data`` with a freshly fitted scaler.

    Args:
        data: DataFrame whose values can all be cast to float.

    Returns:
        ``(scaled, scaler)``: the scaled values as a float array, and the
        fitted ``StandardScaler`` so other splits can reuse the same
        statistics.
    """
    values = data.values.astype(float)
    scaler = preprocessing.StandardScaler()
    scaler.fit(values)
    return scaler.transform(values), scaler




"""

"""
def delete_features(npdata, features_to_delete, dataset):
    """Remove the given columns from a 2-D array.

    Args:
        npdata: 2-D array (rows = samples, columns = features).
        features_to_delete: Iterable of column indices, all expressed relative
            to ``npdata`` as passed in.
        dataset: Dataset name. Kept for interface compatibility; the deletion
            no longer depends on it.

    Returns:
        ``npdata`` without the listed columns. When nothing is listed the
        input array itself is returned.
    """
    columns = list(features_to_delete)
    if not columns:
        return npdata
    # A single np.delete call interprets every index against the original
    # array, so the outcome does not depend on the order of the indices.
    return np.delete(npdata, columns, axis=1)




"""
Input: string (dataset name)
Output: tuple of two integers: column index of the survival-time feature, then of the censoring feature
"""

def get_surv_cens_numbers(dataset):
    """Return the column positions of the survival and censoring features.

    Args:
        dataset: One of ``'myData'``, ``'nwtco'``, ``'aids2'``, ``'flchain'``,
            ``'Transplant'`` or ``'Waitlist'``.

    Returns:
        Tuple ``(surv, cens)`` of integer column indices.

    Raises:
        ValueError: If ``dataset`` is not one of the known names.
    """
    positions = {
        'myData': (0, 1),
        'nwtco': (6, 5),
        'aids2': (1, 2),
        'flchain': (8, 9),
        'Transplant': (50, 51),
        'Waitlist': (24, 25),
    }
    try:
        return positions[dataset]
    except (KeyError, TypeError):
        # Previously an unknown name fell through every branch and surfaced
        # as an UnboundLocalError on the return statement.
        raise ValueError("Unknown dataset: %r" % (dataset,)) from None




"""
Input: string(dataset name)
Output: two strings, names of the columns containing surv and cens
"""
def get_surv_cens_names(dataset):
    """Return the column names of the survival and censoring features.

    Args:
        dataset: One of ``'myData'``, ``'nwtco'``, ``'aids2'``, ``'flchain'``,
            ``'Transplant'`` or ``'Waitlist'``.

    Returns:
        Tuple ``(surv, cens)`` of column names.

    Raises:
        ValueError: If ``dataset`` is not one of the known names.
    """
    names = {
        'myData': ('tte0', 'event'),
        'nwtco': ('edrel', 'rel'),
        'aids2': ('death', 'status'),
        'flchain': ('futime', 'death'),
        'Transplant': ('Survival_Time', 'Eventful'),
        'Waitlist': ('Survival_Time', 'Eventful'),
    }
    try:
        return names[dataset]
    except (KeyError, TypeError):
        # Previously an unknown name fell through every branch and surfaced
        # as an UnboundLocalError on the return statement.
        raise ValueError("Unknown dataset: %r" % (dataset,)) from None



"""
Input: array with all the survival times 
       array with the event indicators (1 = event observed) and the list of indexes used to look them up
       float with the interval length desired (granularity of the discretization)
       first day and maximum day considered
Output: matrix with dimensions: (num_patients, total number of intervals considered),
       each cell (i, j) in the matrix will have value 1 if the event of death for the 
       patient i happens in the time interval j
"""
def get_event_uncensored_array(survival_times_array, uncensored_array, index_list, intervals_length,min_days,  max_days):
    """Mark, per patient, the interval in which an observed event falls.

    Args:
        survival_times_array: Survival time of each patient.
        uncensored_array: Event indicators, looked up as
            ``uncensored_array[index_list[patient]]``; 1 means the event was
            observed.
        index_list: Lookup key into ``uncensored_array`` for each patient
            position.
        intervals_length: Length of one time interval (int or float).
        min_days: First day considered.
        max_days: Last day considered.

    Returns:
        Float array of shape ``(num_patients, total_intervals)`` where cell
        ``(i, j)`` is 1 if patient ``i`` is uncensored and their survival
        time falls in interval ``j``; rows stay all-zero when the patient is
        censored or the event lies outside the considered range.
    """
    # Floor division on floats yields a float; np.zeros needs an int.
    total_intervals = int((max_days - min_days) // intervals_length)
    y_event = np.zeros((len(survival_times_array), total_intervals))
    first_interval = min_days // intervals_length

    for patient, surv_time in enumerate(survival_times_array):
        # // truncates (e.g. 3.9 // 2 == 1), so this is the zero-based index
        # of the interval that contains the event.
        event_interval = surv_time // intervals_length - first_interval
        # Only one cell per row can ever match, so index it directly instead
        # of scanning every interval.
        if 0 <= event_interval < total_intervals and uncensored_array[index_list[patient]] == 1:
            y_event[patient, int(event_interval)] = 1
    return y_event




"""
Input: array with all the survival times 
       float with the interval length desired (granularity of the discretization)
Output: matrix with dimensions: (num_patients, total number of intervals considered),
       each cell (i, j) in the matrix will have value 1 if the 
       patient i is still alive in the time interval j
"""
def get_survivor_array(survival_times_array, intervals_length, min_days, max_days):
    """Build the per-interval "still alive" indicator matrix.

    Args:
        survival_times_array: Survival time of each patient.
        intervals_length: Length of one time interval (int or float).
        min_days: First day considered.
        max_days: Last day considered.

    Returns:
        Float array of shape ``(num_patients, total_intervals)`` where cell
        ``(i, j)`` is 1 for every interval ``j`` strictly before the interval
        that contains patient ``i``'s survival time, and 0 from there on.
    """
    # Floor division on floats yields a float; np.zeros needs an int.
    total_intervals = int((max_days - min_days) // intervals_length)
    times = np.asarray(survival_times_array, dtype=float)
    y_survivor = np.zeros((len(times), total_intervals))

    # // truncates (e.g. 3.9 // 2 == 1), so this is the zero-based index of
    # the interval in which each patient's event/censoring happens.
    event_intervals = times // intervals_length - (min_days // intervals_length)

    # Broadcast comparison: interval index (columns) against each patient's
    # event interval (rows).
    y_survivor[np.arange(total_intervals) < event_intervals[:, None]] = 1
    return y_survivor




"""
Input: array with all the survival times 
       float with the interval length desired (granularity of the discretization)
Output: matrix with dimensions: (num_patients, total number of intervals considered),
       each cell (i, j) in the matrix will have value 1 if the 
       patient i is uncensored or if the patient is censored and still alive in the time interval j
"""

def get_uncensored_or_survivor_array(survival_times_array, uncensored_array, index_list, intervals_length, min_days,  max_days):
    """Build the mask of intervals whose outcome is known.

    Args:
        survival_times_array: Survival time of each patient.
        uncensored_array: Event indicators, looked up as
            ``uncensored_array[index_list[patient]]``; 1 means uncensored.
        index_list: Lookup key into ``uncensored_array`` for each patient
            position.
        intervals_length: Length of one time interval (int or float).
        min_days: First day considered.
        max_days: Last day considered.

    Returns:
        Float array of shape ``(num_patients, total_intervals)``. A row is 1
        for every interval before the patient's event interval; for
        uncensored patients the remaining intervals are 1 as well, for
        censored patients they stay 0.
    """
    # Floor division on floats yields a float; np.zeros needs an int.
    total_intervals = int((max_days - min_days) // intervals_length)
    y_uncensored_or_survivor = np.zeros((len(survival_times_array), total_intervals))
    first_interval = min_days // intervals_length

    for patient, surv_time in enumerate(survival_times_array):
        # // truncates (e.g. 3.9 // 2 == 1), so this is the zero-based index
        # of the interval that contains the event.
        event_interval = surv_time // intervals_length - first_interval
        # Number of leading intervals in which the patient is known alive,
        # clipped to the matrix width (a NaN time gives 0).
        alive_until = int(min(event_interval, total_intervals)) if event_interval > 0 else 0

        row = y_uncensored_or_survivor[patient]
        row[:alive_until] = 1
        # The tail is known only for uncensored patients; the flag is looked
        # up only when there is a tail to fill.
        if alive_until < total_intervals and uncensored_array[index_list[patient]] == 1:
            row[alive_until:] = 1

    return y_uncensored_or_survivor



"""
Input: array with all the survival times 
       float with the interval length desired (granularity of the discretization)
Output: matrix with dimensions: (num_patients, total number of intervals considered),
       each cell (i, j) in the matrix will have value 1 if the 
       patient i is censored (lost to follow-up) in the time interval j
"""
def get_event_censored_array(survival_times_array, uncensored_array, index_list, intervals_length, min_days, max_days):
    """Mark, per patient, the interval in which censoring happens.

    Args:
        survival_times_array: Recorded time of each patient.
        uncensored_array: Event indicators, looked up as
            ``uncensored_array[index_list[patient]]``; 0 means censored.
        index_list: Lookup key into ``uncensored_array`` for each patient
            position.
        intervals_length: Length of one time interval (int or float).
        min_days: First day considered.
        max_days: Last day considered.

    Returns:
        Float array of shape ``(num_patients, total_intervals)`` where cell
        ``(i, j)`` is 1 if patient ``i`` is censored and their recorded time
        falls in interval ``j``; rows stay all-zero for uncensored patients
        or times outside the considered range.
    """
    # Floor division on floats yields a float; np.zeros needs an int.
    total_intervals = int((max_days - min_days) // intervals_length)
    y_event = np.zeros((len(survival_times_array), total_intervals))
    first_interval = min_days // intervals_length

    for patient, surv_time in enumerate(survival_times_array):
        # // truncates (e.g. 3.9 // 2 == 1), so this is the zero-based index
        # of the interval that contains the recorded time.
        event_interval = surv_time // intervals_length - first_interval
        # Only one cell per row can ever match, so index it directly instead
        # of scanning every interval.
        if 0 <= event_interval < total_intervals and uncensored_array[index_list[patient]] == 0:
            y_event[patient, int(event_interval)] = 1
    return y_event



"""
Input: pandas dataframe with data
       integer with length of the intervals
       integer with the first day considered
       integer with the maximum time considered
Output: np array transformed to be fed into the RNN
"""

def create_RNN_input(data, intervals_length, min_days, max_days):
    """Replicate the feature matrix once per time interval.

    Each copy gets one extra trailing column holding the normalised interval
    position ``step / number_of_intervals``.

    Args:
        data: 2-D array of shape ``(num_rows, num_features)``.
        intervals_length: Integer length of one interval.
        min_days: First day considered.
        max_days: Maximum day considered.

    Returns:
        Array of shape ``(num_rows, num_features + 1, number_of_intervals)``.
    """
    n_steps = (max_days - min_days) // intervals_length
    n_rows = len(data)

    per_step = [
        # The time marker grows linearly from 0 towards (but never reaching) 1.
        np.concatenate((data, np.full((n_rows, 1), float(step) / n_steps)), axis=1)
        for step in range(n_steps)
    ]

    # Stacked as (step, row, feature); reorder to (row, feature, step).
    return np.transpose(np.asarray(per_step), (1, 2, 0))




"""
Input: Dataframe with the data
Output: Dataframe with preprocessed data
"""

def aids2_preprocessing(data):
    """Encode the aids2 frame for modelling.

    ``sex`` (M/F) and ``status`` (D/A) are replaced by numeric codes, ``death``
    is made relative to ``diag`` (which is then dropped), ``T.categ`` is
    renamed to ``categ`` and the frame is passed through ``pd.get_dummies``.

    Args:
        data: Raw aids2 DataFrame.

    Returns:
        The encoded DataFrame (a new object; it is built from the result of
        ``DataFrame.replace``).
    """
    sex_codes = {'sex': {'M': -1.0, 'F': 1.0}}
    status_codes = {'status': {'D': 1.0, 'A': 0.0}}
    encoded = data.replace(sex_codes).replace(status_codes)

    # Turn the death value into an offset from the diagnosis value.
    encoded['death'] = encoded['death'] - encoded['diag']
    encoded = encoded.drop(columns='diag')
    encoded = encoded.rename(columns={'T.categ': 'categ'})

    return pd.get_dummies(encoded)



"""
Input: Dataframe with missing data
Output: Dataframe with completed data
"""
def flchain_fillMissing(data):
    """Fill missing values of the flchain frame in place.

    ``sample.yr`` and ``flc.grp`` are renamed to ``sample_year`` and
    ``flc_grp``. Missing values are then filled with the column mean for
    age, sample_year, kappa, lambda, flc_grp, creatinine and futime, and with
    the most frequent value for sex, death, chapter and mgus.

    Args:
        data: flchain DataFrame; it is modified in place.

    Returns:
        The same DataFrame object, completed.
    """
    data.rename(columns={'sample.yr': 'sample_year', 'flc.grp': 'flc_grp'}, inplace=True)

    mean_filled = ('age', 'sample_year', 'kappa', 'lambda', 'flc_grp', 'creatinine', 'futime')
    mode_filled = ('sex', 'death', 'chapter', 'mgus')

    for column in mean_filled:
        # Per-column mean: a frame-wide data.mean() raises on recent pandas
        # because of the string columns. Assigning the result back avoids the
        # chained inplace fillna, which does not update the frame under
        # copy-on-write.
        data[column] = data[column].fillna(data[column].mean())

    for column in mode_filled:
        most_frequent = data[column].value_counts().index[0]
        data[column] = data[column].fillna(most_frequent)

    return data



"""
Input: Dataframe
Output preprocessed Dataframe
"""
def flchain_preprocessing(data):
    """Encode the flchain frame: numeric ``sex`` codes, then ``pd.get_dummies``.

    Args:
        data: flchain DataFrame.

    Returns:
        New DataFrame with ``sex`` values M/F replaced by -1.0/1.0 and the
        result passed through ``pd.get_dummies``.
    """
    sex_codes = {'sex': {'M': -1.0, 'F': 1.0}}
    return pd.get_dummies(data.replace(sex_codes))




def get_shuffled_data(dataset):
    """Load a dataset by name and return its rows in random order.

    Args:
        dataset: ``'Transplant'`` / ``'Waitlist'`` go through the UNOS loader,
            every other name through the small-dataset CSV loader.

    Returns:
        The shuffled DataFrame. Rows are first permuted and re-indexed from 0,
        then shuffled a second time, so the returned index is not sorted.
    """
    is_unos = dataset in ('Transplant', 'Waitlist')
    loader = import_UNOS_dataset if is_unos else import_small_dataset

    frame = loader(dataset)
    frame = frame.sample(frac=1).reset_index(drop=True)
    return shuffle(frame)
    

"""
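Input: Dataframe (data, already shuffled)
       int (intervals length)
       int (first days considered)
       int (max days considered)
       string (dataset name)
       int (number of 20% rotations applied before splitting, i.e. the cross-validation fold)
Output: Dictionary with all the data ready to be used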
"""

def elaborate_data(data, intervals_length, min_days, max_days_considered, dataset, cross_validation_number):
    """Preprocess, split and package a dataset for training.

    Steps: dataset-specific preprocessing (aids2 / flchain), rotation of the
    rows by 20% per cross-validation step, a 60/20/20 train/validation/test
    split, construction of the label matrices, standard scaling fitted on the
    training split, removal of the survival/censoring columns and expansion
    into the per-interval RNN input.

    Args:
        data: DataFrame with the features and the survival/censoring columns.
        intervals_length: Length of one time interval.
        min_days: First day considered.
        max_days_considered: Maximum day considered.
        dataset: Dataset name, used to pick preprocessing and column names.
        cross_validation_number: How many times the last 20% of the rows is
            moved to the front before splitting.

    Returns:
        Dict with the raw splits (``X_data``, ``X_data_test``, ``X_data_val``),
        the label matrices (``y_survivor*``, ``y_event``,
        ``y_event_censoring``, ``y_uncensored_or_survivor*``) and the RNN
        inputs (``rnn_input``, ``rnn_val``, ``rnn_test``).
    """
    if dataset == 'aids2':
        data = aids2_preprocessing(data)
    elif dataset == 'flchain':
        data = flchain_fillMissing(data)
        data = flchain_preprocessing(data)

    # Rotate the rows: each step moves the trailing 20% to the front, so a
    # different fifth of the data ends up in the test split.
    for _ in range(cross_validation_number):
        cut = int(0.80 * len(data))
        data = pd.concat([data.iloc[cut:], data.iloc[:cut]])

    # 60% train / 20% validation / 20% test. Positional slicing replaces
    # np.split, whose use on DataFrames is deprecated.
    n_rows = len(data)
    train_end = int(0.60 * n_rows)
    val_end = int(0.80 * n_rows)
    Val = data.iloc[train_end:val_end]
    Test = data.iloc[val_end:]
    data = data.iloc[:train_end]

    surv_feature, cens_feature = get_surv_cens_names(dataset)
    idx_features_to_delete = get_surv_cens_numbers(dataset)

    grid = dict(intervals_length=intervals_length, min_days=min_days, max_days=max_days_considered)

    def survivor_labels(frame):
        return get_survivor_array(survival_times_array=frame[surv_feature], **grid)

    def indicator_labels(builder, frame):
        # The index list is passed as well because the frames keep their
        # original (shuffled) index labels after slicing.
        return builder(
            survival_times_array=frame[surv_feature],
            uncensored_array=frame[cens_feature],
            index_list=frame.index.values,
            **grid
        )

    y_survivor = survivor_labels(data)
    y_Val_survivor = survivor_labels(Val)
    y_Test_survivor = survivor_labels(Test)

    y_event = indicator_labels(get_event_uncensored_array, data)
    # Fix: this used get_event_uncensored_array and was therefore an exact
    # duplicate of y_event; the censoring labels come from the censored
    # variant.
    y_event_censoring = indicator_labels(get_event_censored_array, data)

    y_uncensored_or_survivor = indicator_labels(get_uncensored_or_survivor_array, data)
    y_uncensored_or_survivor_val = indicator_labels(get_uncensored_or_survivor_array, Val)
    y_uncensored_or_survivor_test = indicator_labels(get_uncensored_or_survivor_array, Test)

    # The scaler is fitted on the training split only and reused for the
    # validation and test splits.
    npdata, scaler = normalize_data(data)
    npdata = delete_features(npdata, idx_features_to_delete, dataset)
    np_x_input = create_RNN_input(data=npdata, **grid)

    def scaled_rnn_input(frame):
        scaled = scaler.transform(frame.values.astype(float))
        scaled = delete_features(scaled, idx_features_to_delete, dataset)
        return create_RNN_input(data=scaled, **grid)

    np_val_input = scaled_rnn_input(Val)
    np_test_input = scaled_rnn_input(Test)

    return dict(
            X_data = np.asarray(data),
            X_data_test = np.asarray(Test),
            X_data_val = np.asarray(Val),
            y_survivor = y_survivor,
            y_Val_survivor = y_Val_survivor,
            y_Test_survivor = y_Test_survivor,
            y_event = y_event,
            y_event_censoring = y_event_censoring,
            y_uncensored_or_survivor = y_uncensored_or_survivor,
            y_uncensored_or_survivor_val = y_uncensored_or_survivor_val,
            y_uncensored_or_survivor_test = y_uncensored_or_survivor_test,
            rnn_input = np_x_input,
            rnn_val = np_val_input,
            rnn_test = np_test_input
            )
 
    
"""
Input: numpy matrix
       string (dataset name)
Output: mask of the acceptable pairs
"""
def get_accettablePairsMsk(data, dataset):
    """Build the mask of acceptable (i, j) patient pairs.

    Row ``i`` is all zeros unless patient ``i`` has its censoring column equal
    to 1. For such a patient, entry ``(i, j)`` is 1 when patient ``j`` also
    has that column equal to 1, or when ``j``'s survival value is at least
    ``i``'s.

    Args:
        data: 2-D numpy array holding the survival and censoring columns at
            the positions given by ``get_surv_cens_numbers(dataset)``.
        dataset: Dataset name.

    Returns:
        Array with one row per patient (square for non-empty input).
    """
    surv_feat, cens_feat = get_surv_cens_numbers(dataset)
    times = data[:, surv_feat]
    completed = data[:, cens_feat]
    n_patients = len(times)

    rows = []
    for time_i, completed_i in zip(times, completed):
        if completed_i == 1:
            # Acceptable partners: every completed patient, plus anyone whose
            # survival value is not smaller than this patient's.
            row = np.where((completed == 1) | (times >= time_i), 1.0, 0.0)
        else:
            row = np.zeros(n_patients)
        rows.append(row)

    return np.asarray(rows)




In [None]:
pip install lifelines

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting lifelines
  Downloading lifelines-0.27.4-py3-none-any.whl (349 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 349.7/349.7 KB 19.0 MB/s eta 0:00:00
Collecting formulaic>=0.2.2
  Downloading formulaic-0.5.2-py3-none-any.whl (77 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 77.1/77.1 KB 9.7 MB/s eta 0:00:00
Collecting autograd-gamma>=0.3
  Downloading autograd-gamma-0.5.0.tar.gz (4.0 kB)
  Preparing metadata (setup.py) ... done
Collecting astor>=0.8
  Downloading astor-0.8.1-py2.py3-none-any.whl (27 kB)
Collecting interface-meta>=1.2.0
  Downloading interface_meta-1.3.0-py3-none-any.whl (14 kB)
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (setup.py) ... done
  Created wheel for autograd-gamma: filename=autograd_gamma-0.5.0

In [None]:
#%% %pip install 'lifelines'
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tensorflow as tf2
import numpy as np
#import data_handler
from lifelines.utils import concordance_index

#%%CLASS RNN


class RNN_SURV():
    
    
    def __init__(self,
                 bidirectional,
                 cell_type,
                 constant_loss,
                 constant_raykar,
                 learning_rate,
                 num_features, 
                 num_rnn_layers,
                 num_nodes_h1, 
                 num_nodes_h2,
                 num_nodes_input_rnn_layer,
                 num_steps,
                 state_size
                 ): 
        self.bidirectional = bidirectional
        self.cell_type = cell_type
        
        self.constant_loss = constant_loss
        self.constant_raykar = constant_raykar
        
        self.learning_rate = learning_rate
        self.num_rnn_layers = num_rnn_layers
        self.num_nodes_input = num_features
        self.num_nodes_h1 = num_nodes_h1
        self.num_nodes_h2 = num_nodes_h2
        self.num_nodes_input_rnn_layer = num_nodes_input_rnn_layer
        self.num_nodes_output = num_steps
        self.state_size = state_size
        
        self.X = 0
        self.Y_true = 0
        self.dropout_hid_state = 1
        self.dropout_keep_prob_output = 1
        self.dropout_keep_prob_input = 1
        self.dropout_keep_prob_state = 1
        
        return
    
    
    
    """
    Input: - y_pred: survival predictions made by the neural network
           - survival_vector: true labels (each element of the vector is equal to 1 if the patient is still alive, and 0 if the patient is dead/lost)
           - uncensored_or_survivor: censored labels (each element of the vector is equal to 1 for all those intervals of which we know the real outcome, and 0 for those intervals in which we lost track of the patient)
           - batch_size: size of the input batch
    Output: sum over the values of the loss functions in the different time intervals
    """
    def PartialLikelihood(self, y_pred, survival_vector, uncensored_or_survivor, batch_size): 
        """Masked sigmoid cross-entropy summed over every interval.

        Args:
            y_pred: Logits; the first slice along axis 1 is used.
            survival_vector: Flattened survival labels.
            uncensored_or_survivor: Flattened mask, multiplied element-wise
                with the per-interval loss so masked-out intervals add 0.
            batch_size: Not used by this method.

        Returns:
            Scalar tensor with the summed masked loss.
        """
        flat_logits = tf.unstack(y_pred, axis = 1)[0]
        per_interval = tf.nn.sigmoid_cross_entropy_with_logits(labels=survival_vector, logits = flat_logits)
        masked = tf.multiply(uncensored_or_survivor, per_interval)
        return tf.reduce_sum(masked)
    
    
    
    """
    Input: - mask_acc_pairs: mask to take into consideration only acceptable pairs
           - output: risk score (final neural network output)
    Output: value of the raykar loss function (defined in equation (6) in the paper)
    """ 
    def RaykarLikelihood(self, msk_acc_pairs, output): 
        """Pairwise ranking loss over the acceptable patient pairs.

        Args:
            msk_acc_pairs: Square mask; entry ``(r, c)`` weights the pair
                formed by patient ``r`` and patient ``c``.
            output: Risk scores as a column tensor, one row per patient.

        Returns:
            Scalar tensor: minus the masked mean of
            ``1 + log2(sigmoid(output[r] - output[c]))``.
        """
        # Broadcasting the column against its own transpose gives
        # diff[r, c] = output[r] - output[c]. This is the matrix the previous
        # repeated-concat construction produced, but it no longer requires the
        # fed batch to match self.input_length.
        pairwise_diff = tf.subtract(output, tf.transpose(output))

        # 1e-30 keeps the logarithm finite when the sigmoid underflows to 0.
        log2_sigmoid = tf.log(tf.nn.sigmoid(pairwise_diff) + 1e-30) / tf.log(2.0)

        loss = tf.reduce_sum(tf.multiply(msk_acc_pairs, (1 + log2_sigmoid)))
        # NOTE(review): an all-zero mask makes this a division by zero.
        loss = loss / tf.reduce_sum(msk_acc_pairs)
        return -loss
    
    
    
    """
    Input: - batch_size
    Output: the object itself
    
    This function defines the graph of RNN_SURV 
    """
    def build_model(self, batch_size): 
        """Define the TF1 graph of RNN_SURV and attach its tensors to ``self``.

        The default graph is reset and re-seeded, placeholders are created,
        every time step goes through three ReLU embedding layers with dropout,
        the result is fed to the (optionally bidirectional) recurrent layers,
        and the losses, train step and prediction tensors are defined.

        Args:
            batch_size: Used for the initial recurrent state and stored as
                ``self.input_length``.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``self.cell_type`` is not supported for the chosen
                direction ("LSTM"/"GRU"/"RNN" unidirectional, "LSTM"/"GRU"
                bidirectional).
        """
        tf.reset_default_graph()
        tf.set_random_seed(1)
        
        #Placeholders########################
        self.raw_X = tf.placeholder(tf.float32, [None, None, self.num_nodes_input], name='input_placeholder')
        self.survivor = tf.placeholder(tf.float32, [None, None], name='survivor_placeholder')
        self.event = tf.placeholder(tf.float32, [None, None], name='event_placeholder')
        self.event_censoring = tf.placeholder(tf.float32, [None, None], name='event_placeholder_censoring')
        self.uncensored_or_survivor=tf.placeholder(tf.float32, [None, None], name='uncensored_or_survivor_placeholder')
        self.dropout_hid_state = tf.placeholder(tf.float32, name = 'dropout_hidden_layer')
        self.dropout_keep_prob_output = tf.placeholder(tf.float32, name = 'dropout_keep_prob_output')
        self.dropout_keep_prob_input = tf.placeholder(tf.float32, name = 'dropout_keep_prob_input')
        self.dropout_keep_prob_state = tf.placeholder(tf.float32, name='dropout_keep_prob_state')
        
        self.msk_acc_pairs = tf.placeholder(tf.float32, [None, None], name = 'msk_acc_pairs_placeholder')
        self.input_length = batch_size
                
        # Define weights. The variable names are part of saved checkpoints and
        # are therefore left untouched (including the odd "W_ois_trainingut").
        if (self.bidirectional == False):
            self.weights = {
                'embed_1': tf.Variable(tf.truncated_normal([self.num_nodes_input, self.num_nodes_h1]), name='W_embed_1'), 
                'embed_2': tf.Variable(tf.truncated_normal([self.num_nodes_h1, self.num_nodes_h2]), name='W_embed_2'), 
                'input_rnn' : tf.Variable(tf.truncated_normal([self.num_nodes_h2, self.num_nodes_input_rnn_layer]), name='W_input_rnn'), 
                'out': tf.Variable(tf.truncated_normal([self.state_size, 1]), name="W_ois_trainingut"),
                'cox_out': tf.Variable(tf.truncated_normal([self.num_nodes_output, 1]), name="W_cox_output")
            }
        else:
            # The output layer sees forward and backward states concatenated,
            # hence state_size*2.
            self.weights = {
                'embed_1': tf.Variable(tf.truncated_normal([self.num_nodes_input, self.num_nodes_h1]), name='W_bid_embed_1'), 
                'embed_2': tf.Variable(tf.truncated_normal([self.num_nodes_h1, self.num_nodes_h2]), name='W_bid_embed_2'), 
                'input_rnn': tf.Variable(tf.truncated_normal([self.num_nodes_h2, self.num_nodes_input_rnn_layer]), name='W_bid_input_rnn'), 
                'out': tf.Variable(tf.truncated_normal([self.state_size*2, 1]), name="W_out"),
                'cox_out': tf.Variable(tf.truncated_normal([self.num_nodes_output, 1]), name="W_cox_output")

            }
        biases = {
    
            'embed_1' : tf.Variable(tf.zeros([self.num_nodes_h1]), name="b_embed_1"),
            'embed_2' : tf.Variable(tf.zeros([self.num_nodes_h2]), name="b_embed_2"),
            'input_rnn' : tf.Variable(tf.zeros([self.num_nodes_input_rnn_layer]), name="b_input_rnn"),
            'out': tf.Variable(tf.zeros([1]), name="b_out"),
            'cox_out': tf.Variable(tf.zeros([1]), name="b_cox_out")
        }
        ##################################
        # Cell factories: each returns one cell wrapped with the dropout
        # placeholders.
        def get_RNN_cell(state_size):
            # Fix: now takes state_size like the other factories; it used to
            # be called with an argument it did not accept.
            cell = tf.nn.rnn_cell.BasicRNNCell(num_units = state_size)
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob = self.dropout_keep_prob_input, output_keep_prob=self.dropout_keep_prob_output, state_keep_prob = self.dropout_keep_prob_state)
            return cell
        
        def get_LSTM_cell(state_size): 
            cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob = self.dropout_keep_prob_input, output_keep_prob=self.dropout_keep_prob_output, state_keep_prob = self.dropout_keep_prob_state)
            print("LSTM cell")
            return cell
        
        def get_GRU_cell(state_size): 
            cell =  tf.nn.rnn_cell.GRUCell(state_size)
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob = self.dropout_keep_prob_input, output_keep_prob=self.dropout_keep_prob_output, state_keep_prob = self.dropout_keep_prob_state)
            return cell

        def stacked(cell_factory):
            # Fix: MultiRNNCell lives in tf.nn.rnn_cell; tf.nn.rnn has no such
            # attribute.
            return tf.nn.rnn_cell.MultiRNNCell([cell_factory(self.state_size) for _ in range(self.num_rnn_layers)], state_is_tuple=True)
        
        def get_RNN_layers(X, weights, biases, cell_type="LSTM"):
            # Returns one stacked cell (unidirectional) or a (forward,
            # backward) pair (bidirectional). X, weights and biases are not
            # used here.
            if (self.bidirectional == False):
                if (cell_type=="LSTM"):
                    cell = stacked(get_LSTM_cell)
                elif (cell_type=="GRU"):
                    cell = stacked(get_GRU_cell)
                elif(cell_type == "RNN"):
                    cell = stacked(get_RNN_cell)
                else: 
                    # Fix: this used to print and then fail with an
                    # UnboundLocalError on the return below.
                    raise ValueError("Invalid argument: cell_type must be 'LSTM', 'GRU' or 'RNN', got %r" % (cell_type,))
                return cell
            else:
                if (cell_type == "LSTM"):
                    # Single (non-stacked) cells in the bidirectional LSTM case.
                    fwd_cell = get_LSTM_cell(self.state_size)
                    bwd_cell = get_LSTM_cell(self.state_size)
                elif (cell_type == "GRU"):
                    fwd_cell = stacked(get_GRU_cell)
                    bwd_cell = stacked(get_GRU_cell)
                else: 
                    raise ValueError("Invalid argument: bidirectional cell_type must be 'LSTM' or 'GRU', got %r" % (cell_type,))
                return fwd_cell, bwd_cell
        ########################
        unstacked_raw_X = tf.unstack(self.raw_X, self.num_nodes_output, axis =1)
    
        X = []
        
        for step_X in unstacked_raw_X: 
            linear_h_1 = tf.matmul(step_X, self.weights['embed_1']) + biases['embed_1']
            h_1 = tf.nn.relu(linear_h_1)
            d_1 = tf.nn.dropout(h_1, keep_prob = self.dropout_hid_state)
            # NOTE(review): tf.nn.dropout already rescales the kept units by
            # 1/keep_prob, so this division rescales a second time. Left
            # as-is to keep trained models comparable — confirm it is intended.
            d_1 =  d_1 / self.dropout_hid_state 
            linear_h_2 = tf.matmul(d_1, self.weights['embed_2']) + biases['embed_2']
            h_2 = tf.nn.relu(linear_h_2)
            d_2 = tf.nn.dropout(h_2, keep_prob = self.dropout_hid_state)
            d_2 =  d_2 / self.dropout_hid_state 

            # ReLU is applied twice here; the second application is a no-op.
            linear_h_3 = tf.nn.relu(tf.matmul(d_2, self.weights['input_rnn']) + biases['input_rnn'])
            h_3 = tf.nn.relu(linear_h_3)
            d_3 =  tf.nn.dropout(h_3, keep_prob = self.dropout_hid_state)
            d_3 =  d_3 / self.dropout_hid_state 

            X.append(d_3)
    
        X = tf.stack(X, axis = 0)
        
        X.set_shape([None, None, self.num_nodes_input_rnn_layer])
        # Swap the first two axes so the step axis created by tf.stack comes
        # second.
        X = tf.transpose(X, [1, 0, 2])
        
        cell = get_RNN_layers(X, self.weights, biases, cell_type=self.cell_type)
    
        def execute_rnn(cell, X_input, init_state_fw, init_state_bw=None):
            if (self.bidirectional == False):
                rnn_outputs, final_state =  tf.nn.dynamic_rnn(cell, X_input, initial_state = init_state_fw, dtype=tf.float32)
            else:
                rnn_outputs, final_state = tf.nn.bidirectional_dynamic_rnn(cell_fw = cell[0], cell_bw = cell[1], inputs = X_input, initial_state_fw = init_state_fw, initial_state_bw=init_state_bw, dtype=tf.float32)
            return rnn_outputs, final_state
        
            
        if(self.bidirectional == False):
            init_state_fw = cell.zero_state(batch_size, tf.float32)
            init_state_bw = None
            rnn_outputs, final_state = execute_rnn(cell, X, init_state_fw)
        else:
            init_state_fw = cell[0].zero_state(batch_size, tf.float32)
            init_state_bw = cell[1].zero_state(batch_size, tf.float32)
            rnn_outputs, final_state = execute_rnn(cell, X, init_state_fw, init_state_bw)
        
        
        if (self.bidirectional == True):
            rnn_outputs = tf.concat([rnn_outputs[0], rnn_outputs[1]], axis = 2)
            rnn_outputs = tf.reshape(rnn_outputs, [-1, self.state_size*2])
        else:
            rnn_outputs = tf.reshape(rnn_outputs, [-1, self.state_size])
            
        # One logit per row of the flattened recurrent outputs.
        logits = tf.matmul(rnn_outputs, self.weights['out']) + biases['out']
        self.line_predictions = tf.nn.sigmoid(logits)
        
        # Regroup the logits into rows of num_nodes_output values; a linear
        # layer on top of each row gives the single risk score.
        matrix_logits = tf.reshape(logits, [-1, self.num_nodes_output])
        self.cox_output = (tf.matmul(matrix_logits, self.weights['cox_out']) + biases['cox_out'])
        
        line_survivor = tf.reshape(self.survivor, [-1])
        line_uncensored_or_survivor = tf.reshape(self.uncensored_or_survivor, [-1])
        
        self.loss = RNN_SURV.PartialLikelihood(self, y_pred = logits, survival_vector = line_survivor, uncensored_or_survivor = line_uncensored_or_survivor, batch_size = batch_size)
        self.raykar_loss = RNN_SURV.RaykarLikelihood(self, msk_acc_pairs = self.msk_acc_pairs, output = self.cox_output)
        
        # The L2 penalty covers the embedding and output layers; the cox_out
        # parameters are not included.
        L2_regularizer_weights = tf.nn.l2_loss(self.weights['embed_1']) + tf.nn.l2_loss(self.weights['embed_2'])+ tf.nn.l2_loss(self.weights['input_rnn'])+ tf.nn.l2_loss(self.weights['out'])
        L2_regularizer_biases = tf.nn.l2_loss(biases['embed_1']) + tf.nn.l2_loss(biases['embed_2'])+ tf.nn.l2_loss(biases['input_rnn']) + tf.nn.l2_loss(biases['out'])
    
        self.total_loss = tf.reduce_sum(self.constant_raykar*self.raykar_loss + self.constant_loss*self.loss + 0.1* L2_regularizer_weights + 0.1 * L2_regularizer_biases)
        self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.total_loss)
        
        self.predictions = tf.nn.sigmoid(matrix_logits)
        self.true_labels = tf.cast(self.survivor, tf.int8)
        
        return self
    
    
    
    
    def train_RNN( 
              self,
              batch_size,
              data,
              dataset,
              dropout_hid_state,
              dropout_keep_prob_input,
              dropout_keep_prob_output,
              dropout_keep_prob_state,
              num_epochs_per_time,
              patience_max, 
              saved_model = False, 
              verbose=True):
        """
        Train the network with early stopping driven by the validation c-index.

        Training proceeds in rounds of `num_epochs_per_time` epochs. After every
        round the model is scored with `validate`. A score strictly above the
        best one seen so far saves a checkpoint and resets the patience counter
        to `patience_max`; any other score decreases the counter by one.
        Training stops when the counter reaches zero.

        Input: - batch_size: number of samples per training batch. Samples left
                 over after the last full batch of an epoch's permutation are
                 not used in that epoch.
               - data: mapping holding the arrays 'rnn_input', 'y_survivor',
                 'y_event', 'y_event_censoring', 'y_uncensored_or_survivor' and
                 'X_data' (training), plus the validation entries read by
                 `validate`.
               - dataset: string containing the name of the dataset; it is part
                 of the checkpoint path and is forwarded to
                 `get_accettablePairsMsk` and `validate`.
               - dropout_hid_state, dropout_keep_prob_input,
                 dropout_keep_prob_output, dropout_keep_prob_state: values fed
                 to the placeholders of the same name.
               - num_epochs_per_time: epochs run between two validations.
               - patience_max: number of consecutive non-improving validations
                 tolerated before stopping.
               - saved_model: if False the variables are freshly initialised,
                 otherwise they are restored from the checkpoint path.
               - verbose: accepted for compatibility; not used by this method.

        Output: the best validation c-index reached (0 if no round scored
                above 0). The best score always starts at 0, also when
                `saved_model` is True.
        """
        import os

        path = "./checkpoints/"
        path += str(dataset)
        path += "/state_size"
        path += str(self.state_size)
        if (self.bidirectional == True):
            path += "Bid"
        path += ".ckpt"

        print("path: ", path)

        # Saver.save does not reliably create missing parent directories; make
        # sure the target exists up front instead of failing after a full
        # round of training.
        os.makedirs(os.path.dirname(path), exist_ok=True)

        num_samples = len(data['rnn_input'])
        num_batches = num_samples // batch_size

        patience = patience_max
        max_c_index = 0

        # A single Saver serves both restore and save. Building a new one on
        # every improvement would keep adding save/restore ops to the graph.
        saver = tf.train.Saver()

        with tf.Session() as sess:

            if (saved_model == False):
                print("NO SAVED MODEL")
                sess.run(tf.global_variables_initializer())
            else:
                saver.restore(sess, path)

            while (patience > 0):

                for _ in range(num_epochs_per_time):

                    # New random sample order for every epoch.
                    shuffler = np.random.permutation(num_samples)

                    for num_batch in range(num_batches):

                        start = num_batch * batch_size
                        batch_idx = shuffler[start: start + batch_size]

                        # Swap the last two axes of the stored rnn input
                        # before feeding it to the raw_X placeholder.
                        batch_X = np.transpose(data['rnn_input'][batch_idx], [0, 2, 1])
                        batch_data = data['X_data'][batch_idx]

                        feed_dict = {
                            self.raw_X: batch_X,
                            self.survivor: data['y_survivor'][batch_idx],
                            self.event: data['y_event'][batch_idx],
                            self.event_censoring: data['y_event_censoring'][batch_idx],
                            self.uncensored_or_survivor: data['y_uncensored_or_survivor'][batch_idx],
                            self.dropout_hid_state: dropout_hid_state,
                            self.dropout_keep_prob_output: dropout_keep_prob_output,
                            self.dropout_keep_prob_input: dropout_keep_prob_input,
                            self.dropout_keep_prob_state: dropout_keep_prob_state,
                            self.msk_acc_pairs: get_accettablePairsMsk(batch_data, dataset = dataset),
                        }

                        # Only the optimisation step is needed; the loss and
                        # prediction tensors were fetched but never used.
                        sess.run(self.train_step, feed_dict)

                c_index = RNN_SURV.validate(self, batch_size, data, sess, dataset)

                if (max_c_index >= c_index):
                    patience = patience - 1
                else:
                    patience = patience_max
                    saver.save(sess, path)
                    print("UPDATED")
                    max_c_index = c_index

        return max_c_index
    
    
    
    
    """
    Input: - data: actual data
           - dataset: string containing the name of the dataset
           - phase: it can either be 'validation' or 'test'
           
    Output: the two columns containing the true values for the time-to-event and the censoring indicator respectively.
    To be noted: for every i cens[i] will be equal to 0 if censored and equal to 1 if eventful
    """
    def get_truths_cens(self, data, dataset, phase):
        """
        Select the time-to-event and censoring-indicator columns of a data split.

        Input: - data: mapping holding the 2-D arrays 'X_data_val' and/or
                 'X_data_test'
               - dataset: string containing the name of the dataset; one of
                 'nwtco', 'myData', 'aids2', 'flchain', 'Transplant', 'Waitlist'
               - phase: it can either be 'validation' or 'test'

        Output: (truths, cens), the two columns of the selected array that hold
                the time-to-event and the censoring indicator respectively.

        Raises: ValueError if `phase` or `dataset` is not one of the supported
                values (previously this surfaced as an UnboundLocalError).
        """
        # Key of `data` that holds the rows of each phase.
        phase_labels = {
            'validation': 'X_data_val',
            'test': 'X_data_test',
        }
        # (time-to-event column, censoring-indicator column) for each dataset.
        dataset_columns = {
            'nwtco': (6, 5),
            'myData': (0, 1),
            'aids2': (1, 2),
            'flchain': (8, 9),
            'Transplant': (50, 51),
            'Waitlist': (24, 25),
        }

        if phase not in phase_labels:
            raise ValueError(
                "unknown phase %r; expected one of %s"
                % (phase, sorted(phase_labels)))
        if dataset not in dataset_columns:
            raise ValueError(
                "unknown dataset %r; expected one of %s"
                % (dataset, sorted(dataset_columns)))

        rows = data[phase_labels[phase]]
        truth_col, cens_col = dataset_columns[dataset]

        truths = rows[:, truth_col]
        cens = rows[:, cens_col]

        return truths, cens



    
    def validate(self, 
                      batch_size, 
                      data, val_sess, 
                      dataset
                      ):
        """
        Score the current model on the validation split.

        Input: - batch_size: number of validation rows fed per session run
               - data: mapping holding 'y_Val_survivor' and 'rnn_val' (and the
                 array read by `get_truths_cens` for the 'validation' phase)
               - val_sess: open session in which `cox_output` is evaluated
               - dataset: string containing the name of the dataset

        Output: the concordance index between the true times and the negated
                model outputs. Only full batches are evaluated; rows beyond the
                last full batch are left out of both predictions and truths.
        """
        full_batches = len(data['y_Val_survivor']) // batch_size

        # Start from an empty float array so the concatenated result is the
        # same as accumulating onto np.zeros(0), even with zero batches.
        score_chunks = [np.zeros(0)]

        for batch_index in range(full_batches):
            lo = batch_size * batch_index
            hi = lo + batch_size

            survivor_batch = data['y_Val_survivor'][lo:hi]
            input_batch = np.transpose(data['rnn_val'][lo:hi], [0, 2, 1])

            # All dropout placeholders are fed 1 for evaluation.
            feed_dict = {
                self.raw_X: input_batch,
                self.survivor: survivor_batch,
                self.dropout_hid_state: 1,
                self.dropout_keep_prob_output: 1,
                self.dropout_keep_prob_input: 1,
                self.dropout_keep_prob_state: 1,
            }
            batch_scores = val_sess.run(self.cox_output, feed_dict)
            # Keep the first output column as a flat vector.
            score_chunks.append(np.asarray(batch_scores[:, 0]))

        validation_predictions = np.concatenate(score_chunks)
        num_scored = len(validation_predictions)

        truths, cens = RNN_SURV.get_truths_cens(self, data, dataset, phase = 'validation')

        C_index_val = concordance_index(
            truths[:num_scored],
            -validation_predictions,
            cens[:num_scored],
        )

        print("c-index_validation: ", C_index_val)

        return C_index_val
    
    
    
    
    
    def test(
         self,
         data,
         dataset, 
         input_size,
         is_training
         ):
        """
        Restore the saved checkpoint and score the model on the test split.

        Input: - data: mapping holding 'rnn_test' and 'y_Test_survivor' (and the
                 array read by `get_truths_cens` for the 'test' phase)
               - dataset: string containing the name of the dataset; it is part
                 of the checkpoint path
               - input_size: passed to `build_model` as its batch_size
               - is_training: not referenced anywhere in this method

        Output: the concordance index between the true times and the negated
                model outputs on the test split.

        NOTE(review): the graph is rebuilt with batch_size = input_size and the
        whole test split is then fed in a single run; presumably input_size has
        to equal the number of test rows -- confirm against build_model.
        """
        RNN_SURV.build_model(self, batch_size = input_size)

        # Same checkpoint naming scheme as the one used when training.
        direction_tag = "Bid" if self.bidirectional == True else ""
        path = "".join([
            "./checkpoints/",
            str(dataset),
            "/state_size",
            str(self.state_size),
            direction_tag,
            ".ckpt",
        ])

        with tf.Session() as test_sess:
            tf.train.Saver().restore(test_sess, path)

            test_inputs = np.transpose(data['rnn_test'], [0, 2, 1])
            test_survivor = data['y_Test_survivor']

            truths, cens = RNN_SURV.get_truths_cens(self, data, dataset, phase = 'test')

            # All dropout placeholders are fed 1 for evaluation.
            feed_dict = {
                self.raw_X: test_inputs,
                self.survivor: test_survivor,
                self.dropout_hid_state: 1,
                self.dropout_keep_prob_output: 1,
                self.dropout_keep_prob_input: 1,
                self.dropout_keep_prob_state: 1,
            }
            fetches = [self.survivor, self.predictions, self.cox_output]
            _labels, _probabilities, cox_output_ = test_sess.run(fetches, feed_dict)

            C_index_test = concordance_index(truths, -cox_output_, cens)

        return C_index_test


In [None]:
# Load the 'myData' dataset through get_shuffled_data, then run elaborate_data on
# it. The result (`elData`) is the `data` mapping later handed to train_RNN/test.
dat=get_shuffled_data('myData')
# NOTE(review): intervals_length=100 with max_days_considered=1100 presumably
# yields the 11 time steps the model is built with (num_steps=11) -- confirm
# against elaborate_data.
elData=elaborate_data(data= dat, intervals_length=100, min_days=0, max_days_considered=1100, dataset='myData', cross_validation_number=5)

In [None]:
# Create the model object with its hyper-parameters (unidirectional LSTM).
# state_size and bidirectional also determine the checkpoint file name
# ("state_size3.ckpt" for this configuration).
# NOTE(review): num_features=12 and num_steps=11 presumably have to match the
# arrays produced by elaborate_data -- confirm.
rnnNet=RNN_SURV(bidirectional=False,
          cell_type="LSTM",
          constant_loss=0.5,
          constant_raykar=0.5,
          learning_rate=0.01,
          num_features=12, 
          num_rnn_layers=2,
          num_nodes_h1=15, 
          num_nodes_h2=15,
          num_nodes_input_rnn_layer=18,
          num_steps=11 ,
          state_size=3
                 )

In [None]:
# Display the model object (shows the default object repr).
rnnNet

<__main__.RNN_SURV at 0x7f3747079400>

In [None]:
# Build the TensorFlow graph for batches of 30 samples. build_model returns the
# model object itself, which is why the cell output is the object repr.
rnnNet.build_model( batch_size = 30)

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
  cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


LSTM cell
LSTM cell


<__main__.RNN_SURV at 0x7f3747079400>

In [None]:

# Train with early stopping: validate after every 5 epochs and stop after 3
# consecutive validations without improvement. The cell output is the best
# validation c-index returned by train_RNN.
# NOTE(review): the build step warns about `keep_prob`, so the dropout_keep_prob_*
# values look like keep probabilities; 0.1 would then drop about 90% of the
# units, and the validation c-index stays near 0.5 -- confirm these values are
# intended.
rnnNet.train_RNN( 
              batch_size=30,
              data=elData,
              dataset='myData',
              dropout_hid_state=0.1,
              dropout_keep_prob_input=0.1,
              dropout_keep_prob_output=0.1,
              dropout_keep_prob_state=0.1,
              num_epochs_per_time=5,
              patience_max=3) 

path:  ./checkpoints/myData/state_size3.ckpt
NO SAVED MODEL
c-index_validation:  0.5045944866160608
UPDATED
c-index_validation:  0.44459981355706485
c-index_validation:  0.47001376126426064
c-index_validation:  0.5


0.5045944866160608

In [None]:
# Evaluate the saved checkpoint on the test split.
# NOTE(review): test() rebuilds the graph with batch_size=input_size and feeds the
# whole of elData['rnn_test'] in one run; presumably input_size must equal the
# number of test rows rather than the training batch size -- confirm (this run
# ended in an InvalidArgumentError).
rnnNet.test(data=elData,
         dataset='myData', 
         input_size=30,
         is_training=False)

  cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)


LSTM cell
LSTM cell


InvalidArgumentError: ignored