In [None]:
from Dataset.dataset import Flchain
import tensorflow as tf
from Models.model import CoxSE, CoxSENAM, SurvivalModelBase
from lifelines.utils.concordance import concordance_index
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def plot_explanations(ds, feature_importance, title='Model Explanations'):
    """Draw a horizontal bar chart of feature importances in ascending order.

    Args:
        ds: Object exposing ``features_names``, a sequence of feature labels
            with one entry per importance value.
        feature_importance: One importance score per feature. Any 1-D
            array-like (list, tuple or ndarray) is accepted.
        title: Title placed above the chart.
    """
    # Coerce to ndarrays so plain Python sequences also support the
    # index-array lookups below (a bare list would raise TypeError).
    importance = np.asarray(feature_importance)
    names = np.asarray(ds.features_names)
    order = np.argsort(importance)  # indices that sort the scores ascending

    plt.figure(figsize=(10, 10))
    plt.barh(y=names[order], width=importance[order])
    plt.title(title)

# Loading Data

In [None]:
# Build the dataset wrapper from the FLCHAIN CSV file.
# NOTE(review): `test_fract=0.3` presumably reserves 30% of the rows for the
# test split -- confirm against Dataset.dataset.Flchain.
ds = Flchain(
    'Dataset/flchain.csv',
    test_fract=0.3,
    verbose=True,
)

In [None]:
# Identifier of the validation split requested from the dataset.
val_id = 0

# For each of train / val / test the dataset returns four arrays:
#   x_*  -- model inputs
#   ye_* -- training targets passed to `fit` (presumably time and event combined -- confirm)
#   y_*  -- event times (used as such by concordance_index further down)
#   e_*  -- event-observed indicators
(
    x_train, ye_train, y_train, e_train,
    x_val, ye_val, y_val, e_val,
    x_test, ye_test, y_test, e_test,
) = ds.get_train_val_test_final_eval(val_id=val_id)

# Training CoxSE Model

In [None]:
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation, dropout and
# batch shuffling repeat across "Restart & Run All" (device-level
# nondeterminism may still remain).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Hyper-parameters forwarded to the CoxSE constructor.
# NOTE(review): alpha/beta look like regularisation weights, l2w an L2 weight
# penalty and dropoutp a dropout probability -- confirm in Models.model.
alpha = 0.01
beta = 0.01
num_layers = 2
num_nodes = 16
act = 'relu'
l2w = 0.0001
dropoutp = 0.2
learning_rate = 0.001


optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Stop once val_loss has not improved for 100 epochs and roll the model back
# to the weights of its best epoch.
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100, restore_best_weights=True)

callbacks = [es]

mdl_coxse = CoxSE(input_shape=ds.input_shape, alpha=alpha, beta=beta,
                  num_layers=num_layers, num_nodes=num_nodes, act=act, l2w=l2w, dropoutp=dropoutp)
# Project-specific compile step: registers the optimizer and reports the
# concordance index under the metric name 'CI'.
mdl_coxse.special_compile(custom_optimizer=optimizer, custom_metric_func=SurvivalModelBase.cindex, custom_metric_name='CI')

# Up to 1000 epochs; early stopping normally ends training sooner.
mdl_coxse_history = mdl_coxse.fit(x_train, ye_train, epochs=1000, batch_size=512,
                                  validation_data=(x_val, ye_val), callbacks=callbacks, verbose=True)


In [None]:
# Per-epoch training vs. validation loss of the CoxSE run. The curves are
# labelled and the figure titled so the plot can be read on its own.
plt.plot(mdl_coxse_history.history['loss'], label='train loss')
plt.plot(mdl_coxse_history.history['val_loss'], label='validation loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('CoxSE training history')
plt.legend()

In [None]:
# CoxSE predictions on the train and test inputs. `predict` yields a pair:
#   y_pred_* -- model output; negated and used as the predicted score in the
#               concordance-index cell
#   w_pred_* -- values whose absolute mean over axis 0 is plotted as
#               per-feature importance in the explanation cell
y_pred_train, w_pred_train = mdl_coxse.predict(x_train)
y_pred_test, w_pred_test = mdl_coxse.predict(x_test)

In [None]:
# Test-set concordance index for CoxSE. The model output is negated before
# being handed to lifelines as the predicted score.
ci = concordance_index(
    event_times=y_test,
    predicted_scores=-y_pred_test,
    event_observed=e_test,
)
print(ci)

In [None]:
# Aggregate explanation for CoxSE: mean absolute value of the test-set
# `w_pred_test` entries along axis 0, drawn as one bar per feature.
plot_explanations(
    ds=ds,
    feature_importance=np.mean(np.abs(w_pred_test), axis=0),
    title='CoxSE Aggregate Explanations',
)

# Training CoxSENAM Model

In [None]:
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation, dropout and
# batch shuffling repeat across "Restart & Run All" (device-level
# nondeterminism may still remain).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Hyper-parameters forwarded to the CoxSENAM constructor.
# NOTE(review): alpha/beta look like regularisation weights, l2w an L2 weight
# penalty and dropoutp a dropout probability -- confirm in Models.model.
alpha = 0.01
beta = 0.01
num_layers = 2
num_nodes = 16
act = 'relu'
l2w = 0.0001
dropoutp = 0.2
learning_rate = 0.001


optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Stop once val_loss has not improved for 100 epochs and roll the model back
# to the weights of its best epoch.
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100, restore_best_weights=True)

callbacks = [es]

mdl_coxsenam = CoxSENAM(input_shape=ds.input_shape, alpha=alpha, beta=beta,
                        num_layers=num_layers, num_nodes=num_nodes, act=act, l2w=l2w, dropoutp=dropoutp)
# Project-specific compile step: registers the optimizer and reports the
# concordance index under the metric name 'CI'.
mdl_coxsenam.special_compile(custom_optimizer=optimizer, custom_metric_func=SurvivalModelBase.cindex, custom_metric_name='CI')

# Up to 1000 epochs; early stopping normally ends training sooner.
mdl_coxsenam_history = mdl_coxsenam.fit(x_train, ye_train, epochs=1000, batch_size=512,
                                        validation_data=(x_val, ye_val), callbacks=callbacks, verbose=True)


In [None]:
# Per-epoch training vs. validation loss of the CoxSENAM run. The curves are
# labelled and the figure titled so the plot can be read on its own.
plt.plot(mdl_coxsenam_history.history['loss'], label='train loss')
plt.plot(mdl_coxsenam_history.history['val_loss'], label='validation loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('CoxSENAM training history')
plt.legend()

In [None]:
# CoxSENAM predictions on the train and test inputs. `predict` yields a pair:
#   y_pred_* -- model output; negated and used as the predicted score in the
#               concordance-index cell
#   w_pred_* -- values whose absolute mean over axis 0 is plotted as
#               per-feature importance in the explanation cell
# NOTE: these names are the same ones the CoxSE section assigned, so the
# CoxSE predictions are overwritten from this point on.
y_pred_train, w_pred_train = mdl_coxsenam.predict(x_train)
y_pred_test, w_pred_test = mdl_coxsenam.predict(x_test)

In [None]:
# Test-set concordance index for CoxSENAM. The model output is negated before
# being handed to lifelines as the predicted score.
ci = concordance_index(
    event_times=y_test,
    predicted_scores=-y_pred_test,
    event_observed=e_test,
)
print(ci)

In [None]:
# Aggregate explanation for CoxSENAM: mean absolute value of the test-set
# `w_pred_test` entries along axis 0, drawn as one bar per feature.
# The title names CoxSENAM (it previously repeated the CoxSE title).
plot_explanations(ds=ds, feature_importance=np.abs(w_pred_test).mean(axis=0), title='CoxSENAM Aggregate Explanations')