This notebook trains a Cox model to predict survival outcomes in a cross-validation fashion.

In [None]:
import numpy as np
import pandas as pd
import torch

import torchtuples as tt
from pycox.models import CoxPH

from sklearn.model_selection import KFold, train_test_split
from sksurv.nonparametric import kaplan_meier_estimator

In [None]:
# Name of the embedding file (data/<name>.csv) the experiment runs on.
# Candidates include the *_predicted_binary.csv and the *_embedding.csv files.
embedding_type = "BERT_predicted_binary"

In [None]:
# Load data
embedding = pd.read_csv('data/{}.csv'.format(embedding_type), index_col = [0, 1] if 'predicted' in embedding_type else [0])
outcomes = pd.read_csv('data/TGCA_Merged.csv', index_col = 0)

In [None]:
if 'binary' in embedding_type:
    # Store the tumour stage as a categorical column (original note: "avoid nan issue"),
    # presumably so that pd.get_dummies one-hot encodes it later on.
    stage_column = 'ajcc_pathologic_tumor_stage'
    embedding[stage_column] = embedding[stage_column].astype('category')

### Cross prediction

In [None]:
# Fold assignment table: one column per split type, one row per patient (entries may be NaN)
split = pd.read_csv("results/split.csv", index_col=[0])

In [None]:
def train_and_predict(data, index_train, index_val, index_test, prediction_times):
    """
        Function to train a Cox model and predict the outcome

        Args:
            data (DataFrame): Input features, one row per patient; rows are selected with
                `.loc` using the three index arguments and converted to float32.
            index_train (list): index used to train model.
            index_val (list): index used to stop training (passed as validation data).
            index_test (list): index used to test.
            prediction_times (list float): Times to predict survival.

        Returns:
            DataFrame (len(index_test) * len(prediction_times)) - Predicted survival probability
            of each test patient (rows) at each time horizon (columns).

        Notes:
            Durations and event indicators are read from the global `outcomes` DataFrame
            (columns `t` and `e`), which must be indexable by the same labels as `data`.
    """
    trans = lambda x: x.values.astype('float32')
    targets = lambda index: (trans(outcomes.loc[index].t), trans(outcomes.loc[index].e))

    # np.random.seed alone does not make runs reproducible: the torch network's weight
    # initialisation and dropout draw from torch's own random generator.
    np.random.seed(42)
    torch.manual_seed(42)

    ## Define NN connecting embedding to Cox
    # NOTE(review): with an empty hidden-layer list this is presumably a single bias-free linear
    # layer (i.e. a linear Cox model), in which case the batch_norm / dropout arguments have
    # no effect - confirm against torchtuples' MLPVanilla.
    net = tt.practical.MLPVanilla(data.shape[1], [], 1, True, 0.1, output_bias = False)
    model = CoxPH(net, tt.optim.Adam)

    ## Train (the validation patients only drive early stopping)
    model.fit(trans(data.loc[index_train]), targets(index_train),
            batch_size = 100, epochs = 500, callbacks = [tt.callbacks.EarlyStopping()], verbose = False,
            val_data = (trans(data.loc[index_val]), targets(index_val)))
    _ = model.compute_baseline_hazards() # Fit the non-parametric baseline

    ## Predict and interpolate
    embed_test = data.loc[index_test]
    surv = model.predict_surv_df(trans(embed_test))  # index: time grid, columns: test patients
    surv.columns = embed_test.index
    surv = surv[~surv.index.duplicated(keep = 'first')].sort_index()

    # The predicted survival curves are right-continuous step functions: the value at a horizon
    # is the one at the last grid time <= horizon (forward fill). Back-filling would take the
    # next step instead and systematically under-estimate survival.
    pred = surv.reindex(prediction_times, method = 'ffill')
    # Before the first grid time no hazard has accumulated yet, so survival is exactly 1.
    pred.loc[pred.index < surv.index[0]] = 1.0
    return pred.T

In [None]:
# Evaluation grid: 100 evenly spaced horizons from 0 up to the largest observed `t`
prediction_times = np.linspace(0, outcomes['t'].max(), num = 100)

In [None]:
predictions = {}
for split_type in split.columns:
    # One row per patient, one column per horizon; float dtype rather than the object default
    predictions[split_type] = pd.DataFrame(index = split.index, columns = prediction_times, dtype = float)
    # 'predicted' embeddings are stored per split type (first index level); the others are shared.
    # get_dummies one-hot encodes categorical columns (e.g. the stage column cast to 'category'
    # for binary embeddings) and leaves numeric columns untouched, so every embedding reaches the
    # model as numbers - previously non-'predicted' embeddings skipped this encoding.
    embed = pd.get_dummies(embedding.loc[split_type] if 'predicted' in embedding_type else embedding)
    for fold in split[split_type].dropna().unique():
        # NOTE(review): patients whose assignment is NaN satisfy `!= fold` and therefore always
        # end up in the training set - confirm this is intended.
        train = split.index[split[split_type] != fold]
        train, val = train_test_split(train, test_size = 0.2, random_state = 42)
        # Test patients as labels (like train/val): `.loc` then aligns rows and columns on assignment
        # and avoids "unalignable boolean Series" errors when `embed` has extra rows.
        test = split.index[split[split_type] == fold]
        predictions[split_type].loc[test] = train_and_predict(embed, train, val, test, prediction_times)

In [None]:
# Stack the per-split-type frames into one table (split type as the outer index level).
# A new name keeps this cell re-runnable: `predictions` stays a dict of DataFrames, whereas
# overwriting it made a second run fail inside pd.concat.
all_predictions = pd.concat(predictions)
all_predictions.to_csv('results/{}_predictions.csv'.format(embedding_type))

---------

### Does adding manually extracted features improve performance?
The hypothesis is that the information in the text might be *complementary* to the other features rather than a replacement for them.

To evaluate this, run the following code, which concatenates the binary embedding with the embedding you are considering and uses the result for prediction.

To also investigate whether the simple features are more useful than the embedding, run the previous code with the embedding binary_embedding.csv (just set the `embedding_type` variable to `'binary_embedding'`).

In [None]:
# The section below appends the binary features, which is pointless when the chosen embedding is itself binary.
assert embedding_type.find('binary') == -1, 'Not useful to combine these embeddings'

In [None]:
# Manually extracted (binary) features, indexed by their first column
embedding_binary = pd.read_csv('data/binary_embedding.csv', index_col = [0])
# Apply the same preprocessing the binary embeddings get above: the stage column is made categorical
# and one-hot encoded, so that every feature can be converted to float32 by train_and_predict.
stage = embedding_binary['ajcc_pathologic_tumor_stage'].astype('category')
embedding_binary_encoded = pd.get_dummies(embedding_binary.assign(ajcc_pathologic_tumor_stage = stage))
# NOTE(review): an axis=1 concat is an outer join on the row index; rows present in only one of the
# two tables get NaN features - confirm both files cover the same patients.
concatenated_emb = pd.concat([embedding, embedding_binary_encoded], axis = 1)

In [None]:
predictions = {}
for split_type in split.columns:
    # One row per patient, one column per horizon; float dtype rather than the object default
    predictions[split_type] = pd.DataFrame(index = split.index, columns = prediction_times, dtype = float)
    for fold in split[split_type].dropna().unique():
        # NOTE(review): patients whose assignment is NaN satisfy `!= fold` and therefore always
        # end up in the training set - confirm this is intended.
        train = split.index[split[split_type] != fold]
        train, val = train_test_split(train, test_size = 0.2, random_state = 42)
        # Test patients as labels (like train/val): `.loc` then aligns rows and columns on assignment
        test = split.index[split[split_type] == fold]
        predictions[split_type].loc[test] = train_and_predict(concatenated_emb, train, val, test, prediction_times)

In [None]:
# Stack the per-split-type frames into one table (split type as the outer index level).
# A new name keeps this cell re-runnable: `predictions` stays a dict of DataFrames.
all_concat_predictions = pd.concat(predictions)
all_concat_predictions.to_csv('results/{}_concat_predictions.csv'.format(embedding_type))