# Cox-PH and DeepSurv

In this notebook we will train the [Cox-PH method](http://jmlr.org/papers/volume20/18-424/18-424.pdf), also known as [DeepSurv](https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1).
We will use the METABRIC data set as an example.

A more detailed introduction to the `pycox` package can be found in [this notebook](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb) about the `LogisticHazard` method.

The main benefit Cox-PH (and the other Cox methods) has over Logistic-Hazard is that it is a continuous-time method, meaning we do not need to discretize the time scale.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

import torch
import torchtuples as tt

from pycox.datasets import metabric
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

In [None]:
## Uncomment to install `sklearn-pandas`
# ! pip install sklearn-pandas

In [None]:
np.random.seed(1234)
_ = torch.manual_seed(123)

## Dataset

We load the METABRIC data set and split it into training, validation, and test sets.

In [None]:
df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

In [None]:
# Inspect the column dtypes (df_train must be defined before this cell runs)
for col in df_train.columns:
    print(df_train[col].dtype)

In [None]:
df_train.head()

## MASH dataset

In [None]:
data = pd.read_csv("../clean_data/nafl/combined.large.nafl.csv")

X = data.drop(columns=['DaysUntilFirstProgression', 'Outcome'])
Y = data[['StudyID', 'DaysUntilFirstProgression', 'Outcome']]

X = X.set_index('StudyID')
Y = Y.set_index('StudyID')

In [None]:
print('Outcome' in X.columns)

In [None]:
foo = ['Lab', 'Med', 'Code']
print([x for x in X.columns if not any(f in x for f in foo)])

In [None]:
Y['Outcome'] = Y['Outcome'].astype(int)

In [None]:
from sklearn.model_selection import train_test_split

# X_train, X_test, y_train, y_test = train_test_split(X_torch, Y_torch, test_size=0.3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.3, random_state=42)
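
The two successive `train_test_split` calls above both use `test_size=0.3`, so the second split takes 30% of what is left after the test set has been removed. The short sketch below works out the resulting shares of the full data set; the helper name `nested_split_fractions` is made up for illustration.

```python
def nested_split_fractions(test_size: float, val_size: float) -> dict:
    """Shares of the full data set after two successive hold-out splits."""
    remaining = 1.0 - test_size   # share left after removing the test set
    val = remaining * val_size    # the validation set is carved out of what remains
    train = remaining - val
    return {"train": train, "val": val, "test": test_size}


fractions = nested_split_fractions(test_size=0.3, val_size=0.3)
print({name: round(share, 2) for name, share in fractions.items()})
# {'train': 0.49, 'val': 0.21, 'test': 0.3}
```

So roughly half of the rows end up in the training set.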

## Feature transforms
We have 9 covariates, in addition to the durations and event indicators.

We will standardize the 5 numerical covariates and leave the binary variables as they are. The variables need to be of type `'float32'`, as this is required by PyTorch.

In [None]:
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)

In [None]:
# x_train = x_mapper.fit_transform(df_train).astype('float32')
# x_val = x_mapper.transform(df_val).astype('float32')
# x_test = x_mapper.transform(df_test).astype('float32')

x_train = X_train.values.astype('float32')
x_val = X_val.values.astype('float32')
x_test = X_test.values.astype('float32')

In [None]:
# attempting the dataframe mapper method to standardize
cols_numerical = [x for x in X.columns if any(keyword in x for keyword in ['BMI', 'Age', 'Lab'])]
cols_other = [x for x in X.columns if x not in cols_numerical]

standardize = [([col], StandardScaler()) for col in cols_numerical]
leave = [(col, None) for col in cols_other]

x_mapper = DataFrameMapper(standardize + leave)

In [None]:
x_train = x_mapper.fit_transform(X_train).astype('float32')
x_val = x_mapper.transform(X_val).astype('float32')
x_test = x_mapper.transform(X_test).astype('float32')

In [None]:
lab_feat = [feat for feat in X.columns if 'Lab' in feat]
numerical_feat = ['mean_BMI', 'last_BMI', 'FirstNAFL.Age.90']
numerical_feat.extend(lab_feat)

In [None]:
len(numerical_feat)

In [None]:
# write a function to only standardize the numerical columns and reattach to the rest of the dataframe
scaler = StandardScaler()

def standardize_numerical(dataframe, num_feat=numerical_feat, training_set=True):
    """
    dataframe: Pandas DataFrame

    Returns: a processed DataFrame where the numerical features have been standardized and the categorical features remain the same.
    """
    if training_set:
        scaled = scaler.fit_transform(dataframe[num_feat])
    else:
        scaled = scaler.transform(dataframe[num_feat])
        
    scaled_df = pd.DataFrame(scaled, columns=num_feat, index=dataframe.index)
    cat = dataframe.drop(columns=num_feat)
    processed = pd.concat([scaled_df, cat], axis=1)

    return processed

In [None]:
# standardize our features
X_train_scaled = standardize_numerical(X_train, training_set=True)
X_val_scaled = standardize_numerical(X_val, training_set=False)
X_test_scaled = standardize_numerical(X_test, training_set=False)

In [None]:
foo = pd.DataFrame(x_train, columns=cols_numerical + cols_other, index = X_train_scaled.index)

We need no label transforms.

In [None]:
get_target = lambda df: (df['DaysUntilFirstProgression'].values, df['Outcome'].values)
y_train = get_target(y_train)
y_val = get_target(y_val)
durations_test, events_test = get_target(y_test)
val = x_val, y_val

# get_target = lambda df: (df['duration'].values, df['event'].values)

# y_train = get_target(df_train)
# y_val = get_target(df_val)
# durations_test, events_test = get_target(df_test)
# val = x_val, y_val

In [None]:
x_train.dtype

In [None]:
events_test.dtype

In [None]:
y_train_new = (y_train[0].astype('float32'), y_train[1].astype('int32'))
y_val_new = (y_val[0].astype('float32'), y_val[1].astype('int32'))
durations_test = durations_test.astype('float32')
events_test = events_test.astype('int32')
val = x_val, y_val_new  # keep the validation targets in the same dtypes as the training targets

In [None]:
y_train_new[0]

## Neural net

We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout. 
Here, we just use the `torchtuples.practical.MLPVanilla` net to do this.

Note that we set `out_features` to 1 and that we do not use an output bias (`output_bias = False`).

In [None]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)
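
One way to see why a single output without a bias term is enough: the Cox partial likelihood only depends on differences between the predicted log-risk scores, so a constant added to every output cancels out. The sketch below illustrates this in plain Python on made-up data. The helper `neg_partial_log_likelihood` is a hypothetical textbook-style implementation and is not the loss code used by `pycox`.

```python
import math


def neg_partial_log_likelihood(log_risk, durations, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Illustrative helper (not the pycox implementation). The risk set of
    subject i contains every subject j with durations[j] >= durations[i].
    """
    total, n_events = 0.0, 0
    for i, (t_i, d_i) in enumerate(zip(durations, events)):
        if not d_i:
            continue  # censored subjects only contribute through risk sets
        risk_set = [log_risk[j] for j, t_j in enumerate(durations) if t_j >= t_i]
        log_sum = math.log(sum(math.exp(g) for g in risk_set))
        total += log_risk[i] - log_sum
        n_events += 1
    return -total / n_events


# Toy data, made up for illustration
log_risk = [0.2, -1.0, 0.7, 0.1]
durations = [5.0, 8.0, 3.0, 10.0]
events = [1, 0, 1, 1]

loss = neg_partial_log_likelihood(log_risk, durations, events)
shifted = neg_partial_log_likelihood([g + 3.5 for g in log_risk], durations, events)
print(abs(loss - shifted) < 1e-9)  # a constant shift (a bias) leaves the loss unchanged
```

Because the shift has no effect on the loss, a bias on the output node would be an unidentifiable parameter.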

## Training the model

To train the model we need to define an optimizer. You can choose any `torch.optim` optimizer, but here we instead use one from `tt.optim` as it has some added functionality.
We use the `Adam` optimizer, but instead of choosing a learning rate, we will use the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning rate with `model.lr_finder`. See [this post](https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6) for an explanation.

In [None]:
model = CoxPH(net, tt.optim.Adam)

In [None]:
batch_size = 256
lrfinder = model.lr_finder(x_train, y_train_new, batch_size, tolerance=10)
_ = lrfinder.plot()
# plt.savefig("batch_loss_learning_rate.png", dpi=300)

In [None]:
lrfinder.get_best_lr()
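
To give a feel for what a learning-rate range test does, here is a minimal sketch in plain Python on a toy quadratic objective. It is not the `torchtuples` implementation: the function `lr_range_test`, its `tolerance` stopping rule, and the "divide by 10" heuristic at the end are assumptions made for this illustration.

```python
def lr_range_test(grad_fn, loss_fn, w0, lr_min=1e-4, lr_max=10.0, steps=100, tolerance=10.0):
    """Increase the learning rate exponentially and record the loss at each step."""
    ratio = (lr_max / lr_min) ** (1.0 / (steps - 1))
    w, lr, best = w0, lr_min, float("inf")
    history = []
    for _ in range(steps):
        loss = loss_fn(w)
        history.append((lr, loss))
        best = min(best, loss)
        if loss > tolerance * best:
            break  # the loss has diverged, so the sweep stops early
        w -= lr * grad_fn(w)  # one gradient step at the current learning rate
        lr *= ratio
    return history


# Toy objective (assumed for illustration): f(w) = 0.5 * w**2, with gradient w
history = lr_range_test(grad_fn=lambda w: w, loss_fn=lambda w: 0.5 * w * w, w0=1.0)
lr_at_min_loss = min(history, key=lambda pair: pair[1])[0]
suggested_lr = lr_at_min_loss / 10.0  # heuristic: back off from the point of lowest loss
print(f"lowest loss at lr={lr_at_min_loss:.3g}, suggested lr={suggested_lr:.3g}")
```

For this objective, gradient descent diverges once the learning rate exceeds 2, so the lowest recorded loss sits right at the edge of instability. That is why the suggestion backs off from it.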

Often, this learning rate is a little high, so we instead set it manually to 0.003.

In [None]:
model.optimizer.set_lr(0.003) # 0.01

We include the `EarlyStopping` callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss.

In [None]:
epochs = 512
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True

In [None]:
%%time
log = model.fit(x_train, y_train_new, batch_size, epochs, callbacks, verbose,
                val_data=val, val_batch_size=batch_size)
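
The idea behind early stopping can be sketched without any deep learning library. The helper below is a simplified illustration, not the `tt.callbacks.EarlyStopping` implementation; the name `early_stopping_epochs` and the `patience` and `min_delta` parameters are assumptions chosen for this sketch.

```python
def early_stopping_epochs(val_losses, patience=10, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses."""
    best_loss, best_epoch, waited = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # Improvement: remember this epoch (a real callback would also save the weights)
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch, best_epoch  # stop; the weights from best_epoch would be restored
    return len(val_losses) - 1, best_epoch


# Made-up validation losses: they improve and then plateau
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.74]
stop_epoch, best_epoch = early_stopping_epochs(losses, patience=3)
print(stop_epoch, best_epoch)  # 5 2
```

Training stops at epoch 5, but the model from epoch 2 is the one that is kept.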

In [None]:
_ = log.plot()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Training Epochs')
# plt.savefig("loss_per_epoch.png", dpi=300)

We can compute the mean partial log-likelihood on the validation and training sets.

In [None]:
model.partial_log_likelihood(*val).mean()

In [None]:
train = x_train, y_train_new

In [None]:
model.partial_log_likelihood(*train).mean()

## Prediction

For evaluation we first need to obtain survival estimates for the test set.
This can be done with `model.predict_surv` which returns an array of survival estimates, or with `model.predict_surv_df` which returns the survival estimates as a dataframe.

However, as `CoxPH` is semi-parametric, we first need to get the non-parametric baseline hazard estimates with `compute_baseline_hazards`. 

Note that for large datasets the `sample` argument can be used to estimate the baseline hazard on a subset.

In [None]:
_ = model.compute_baseline_hazards()

In [None]:
surv = model.predict_surv_df(x_test)
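
To make the role of the baseline hazard concrete, the following sketch implements the textbook Breslow estimator in plain Python and turns it into survival curves via S(t | x) = exp(-H0(t) * exp(g(x))). It runs on made-up data, and the helper names are hypothetical; it is meant as an illustration of the idea rather than a description of what `compute_baseline_hazards` does internally.

```python
import math


def breslow_cumulative_baseline(log_risk, durations, events):
    """Breslow estimate of the cumulative baseline hazard H0(t) at each event time."""
    risk = [math.exp(g) for g in log_risk]
    cumulative, h0 = 0.0, {}
    for t in sorted({t for t, d in zip(durations, events) if d}):
        n_events = sum(1 for t_i, d in zip(durations, events) if d and t_i == t)
        at_risk = sum(r for r, t_j in zip(risk, durations) if t_j >= t)
        cumulative += n_events / at_risk  # hazard increment at this event time
        h0[t] = cumulative
    return h0


def survival_curve(h0, log_risk_new):
    """S(t | x) = exp(-H0(t) * exp(g(x))) on the event-time grid of h0."""
    return {t: math.exp(-h * math.exp(log_risk_new)) for t, h in h0.items()}


# Toy data, made up for illustration: equal risk scores, one censored subject
log_risk = [0.0, 0.0, 0.0, 0.0]
durations = [2.0, 4.0, 6.0, 8.0]
events = [1, 0, 1, 1]

h0 = breslow_cumulative_baseline(log_risk, durations, events)
print(h0)  # {2.0: 0.25, 6.0: 0.75, 8.0: 1.75}

low_risk = survival_curve(h0, log_risk_new=-1.0)
high_risk = survival_curve(h0, log_risk_new=1.0)
print(all(high_risk[t] < low_risk[t] for t in h0))  # True
```

A higher predicted log-risk scales the same baseline curve downward, which is the proportional hazards assumption at work.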

In [None]:
y_test_studyid = y_test.reset_index(inplace=False)
pos_outcome = y_test_studyid.index[y_test['Outcome'] == 1].tolist()
neg_outcome = y_test_studyid.index[y_test['Outcome'] == 0].tolist()

In [None]:
# calculate one curve for all patients with positive progression
pos_surv = surv[pos_outcome]
pos_surv_avg = pos_surv.mean(axis=1)

In [None]:
# calculate one curve for all patients with negative progression
neg_surv = surv[neg_outcome]
neg_surv_avg = neg_surv.mean(axis=1)

In [None]:
plt.plot(pos_surv_avg, label='Progressed Patients')
plt.plot(neg_surv_avg, label='Censored Patients')

plt.ylabel('Average S(t | x)')
_ = plt.xlabel('Time')
plt.title('Average Survival Prediction')
plt.legend()
# plt.savefig("average_survival.png", dpi=300)

In [None]:
surv.iloc[:, pos_outcome[0:10]].plot(legend=None)
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')
plt.title('Survival Prediction for 10 Progressed Patients')
# plt.legend(False)
# surv.iloc[:, :].plot()
# plt.ylabel('S(t | x)')
# _ = plt.xlabel('Time')
# plt.savefig("progressed_survival.png", dpi=300)

In [None]:
surv.iloc[:, neg_outcome[0:10]].plot(legend=None)
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')
plt.title('Survival Prediction for 10 Censored Patients')
# plt.savefig("censored_survival.png", dpi=300)

In [None]:
surv.iloc[:, 1000:1010].plot()
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

In [None]:
x_test.shape

In [None]:
surv.shape

In [None]:
surv.head()

## Evaluation

We can use the `EvalSurv` class to evaluate the concordance, Brier score, and binomial log-likelihood. Setting `censor_surv='km'` means that we estimate the censoring distribution by Kaplan-Meier on the test set.

In [None]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

In [None]:
ev.concordance_td()
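
For intuition about what a time-dependent concordance measures, here is a simplified plain-Python sketch in the spirit of Antolini et al.: a pair (i, j) is comparable when subject i has an observed event before the duration of subject j, and it is concordant when the model gives i the lower survival probability at that event time. The function `toy_concordance_td` and its tie handling are assumptions for illustration; `EvalSurv.concordance_td` may treat ties and edge cases differently.

```python
import math


def toy_concordance_td(surv_at, durations, events):
    """Simplified time-dependent concordance.

    surv_at(i, t) must return the predicted survival probability of subject i at time t.
    """
    concordant, comparable = 0.0, 0
    n = len(durations)
    for i in range(n):
        if not events[i]:
            continue  # only subjects with an observed event can anchor a pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                s_i = surv_at(i, durations[i])
                s_j = surv_at(j, durations[i])
                if s_i < s_j:
                    concordant += 1.0
                elif s_i == s_j:
                    concordant += 0.5  # ties count as half (a simplification)
    return concordant / comparable


# Toy data, made up for illustration: exponential survival curves S_i(t) = exp(-rate_i * t)
durations = [2.0, 9.0, 4.0]
events = [1, 0, 1]
good_rates = [0.5, 0.1, 0.3]  # higher rate for subjects with earlier events
bad_rates = [0.1, 0.5, 0.3]   # ordering reversed for the first two subjects

c_good = toy_concordance_td(lambda i, t: math.exp(-good_rates[i] * t), durations, events)
c_bad = toy_concordance_td(lambda i, t: math.exp(-bad_rates[i] * t), durations, events)
print(c_good, c_bad)  # 1.0 0.0
```

A value of 0.5 corresponds to random ordering, and 1.0 means every comparable pair is ranked correctly.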

In [None]:
y_train

In [None]:
surv_train = model.predict_surv_df(x_train)

# Create an EvalSurv object from the training-set predictions and labels
ev_train = EvalSurv(
    surv_train,
    durations=y_train_new[0],
    events=y_train_new[1],
    censor_surv='km'
)

# # Step 3: Compute Concordance Index
# c_index_train = ev_train.concordance_td('antolini')  # or 'adj_antolini', 'uno', etc.
# print(f"C-index (train): {c_index_train:.4f}")

In [None]:
ev_train.concordance_td()

In [None]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
_ = ev.brier_score(time_grid).plot()

# plt.figure(figsize=(6.4, 4.8))
plt.xlabel('Time t (days)')
plt.ylabel('Brier Score')
plt.ylim(0, 0.9)
plt.title('Cox PH DeepSurv Brier Score Over Time')
# plt.savefig("results/deepsurv_brier_score.png", dpi=300)

In [None]:
ev.integrated_brier_score(time_grid)

In [None]:
ev_train.integrated_brier_score(time_grid)

In [None]:
ev.integrated_nbll(time_grid)

In [None]:
ev_train.integrated_nbll(time_grid)

# Adapt to produce SHAP scores

In [None]:
net.eval()

In [None]:
# choose a scalar output for the SHAP
import shap
import torch

net.to('cpu')

# # Background dataset: small sample from training data
# X_background = torch.tensor(x_train[:100], dtype=torch.float32)

# # Test samples for SHAP analysis
# X_explain = torch.tensor(x_test[:50], dtype=torch.float32)

# DeepExplainer works directly with the PyTorch model
explainer = shap.DeepExplainer(net, torch.tensor(x_train, dtype=torch.float32))

In [None]:
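# Compute SHAP values. Here the result carries a trailing output axis of size 1,
# i.e. shape (n_samples, n_features, 1), which is squeezed out in the next cell.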
shap_values = explainer.shap_values(torch.tensor(x_test, dtype=torch.float32))

In [None]:
shap_values_squeezed = shap_values.squeeze(-1)
shap_values_squeezed.shape

In [None]:
X.columns.shape

In [None]:
import pickle
filename = 'results/coxph_shap_values_scaledx.pkl'
with open(filename, 'wb') as file:
    # Use pickle.dump to serialize and write the data
    pickle.dump(shap_values_squeezed, file)

In [None]:
# pickle.load expects an open binary file object, not a path string
file_path = 'results/coxph_shap_values_scaledx.pkl'
with open(file_path, 'rb') as file:
    foo = pickle.load(file)

In [None]:
shap_values_squeezed.shape

In [None]:
X.columns

In [None]:
import matplotlib.pyplot as plt
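# x_test comes from x_mapper, whose output columns are ordered as cols_numerical + cols_other
shap.summary_plot(shap_values_squeezed, x_test, feature_names=cols_numerical + cols_other, show=False)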

In [None]:
shap_values_squeezed.shape

In [None]:
len(features)

In [None]:
X.columns.shape

In [None]:
# plot with human readable names
features = ['Transaminase-SGOT AST Blood Test', 'Iron Binding Capacity Test', 'Chest pain', 'Essential (primary) hypertension', 'Hemoglobin Blood Test', 'Very low-density lipoprotein Blood Test', 'Type 2 diabetes mellitus', 'Mean Corpuscular Hemoglobin Blood Test', 'Abnormal results of liver function studies', 'Vitamin D Blood Test','Hematocrit Blood Test','Mean corpuscular volume Blood Test','Other specified disorders involving the immune mechanism','Most recent BMI before MASLD diagnosis','Platelet Blood Test','Cardiac Risk Ratio','Red blood cell count Blood Test','Vitamin D deficiency, unspecified','Other nonspecific abnormal finding of lung field','Other specified counseling' ]
shap_values_top = shap_values_squeezed[:, :20]
x_test_top = x_test[:, :20]
# features_top = features[:20]
features_top = X.columns[:20]

plt.figure(figsize=(18, 7))
shap.summary_plot(shap_values_top, x_test_top, feature_names=features_top, show=False, max_display=10)
plt.xlabel('SHAP value (impact on model output)')
plt.title('Cox PH DeepSurv Influential Features')
# plt.savefig('results/coxph_shap.png', dpi=300)

In [None]:
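# Plot with human-readable names.
# NOTE: this assumes the first 20 columns of x_test correspond, in order, to the names in `features`.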
features = ['Transaminase-SGOT AST Blood Test', 'Iron Binding Capacity Test', 'Chest pain', 'Essential (primary) hypertension', 'Hemoglobin Blood Test', 'Very low-density lipoprotein Blood Test', 'Type 2 diabetes mellitus', 'Mean Corpuscular Hemoglobin Blood Test', 'Abnormal results of liver function studies', 'Vitamin D Blood Test','Hematocrit Blood Test','Mean corpuscular volume Blood Test','Other specified disorders involving the immune mechanism','Most recent BMI before MASLD diagnosis','Platelet Blood Test','Cardiac Risk Ratio','Red blood cell count Blood Test','Vitamin D deficiency, unspecified','Other nonspecific abnormal finding of lung field','Other specified counseling' ]
shap_values_top = shap_values_squeezed[:, :20]
x_test_top = x_test[:, :20]
features_top = features[:20]
# features_top = X.columns[:20]

plt.figure(figsize=(18, 7))
shap.summary_plot(shap_values_top, x_test_top, feature_names=features_top, show=False, max_display=10)
plt.xlabel('SHAP value (impact on model output)')
plt.gca().tick_params(axis='y', labelsize=11)
plt.title('Cox PH DeepSurv Influential Features')
plt.tight_layout()
plt.savefig('results/coxph_shap.png', dpi=300)

In [None]:
# Compute mean absolute SHAP values per feature
shap_mean_abs = np.abs(shap_values_squeezed).mean(axis=0)

# Create a Series for easy sorting
shap_series = pd.Series(shap_mean_abs, index=cols_numerical + cols_other)  # column order produced by x_mapper

# Get top 20 feature names
top20_features = shap_series.sort_values(ascending=False).head(20).index.tolist()
top20_features

In [None]:
# top features according to shap score
# rank features by mean absolute SHAP value
# Calculate mean absolute SHAP value for each feature
feature_names = cols_numerical + cols_other  # column order of x_test as produced by x_mapper

mean_abs_shap = np.abs(shap_values_squeezed).mean(axis=0)

# Get indices of top features
top_indices = np.argsort(mean_abs_shap)[::-1]  # descending order

# Get corresponding feature names and importance values
top_features = [(feature_names[i], mean_abs_shap[i]) for i in top_indices]

In [None]:
top_features

In [None]:
bottom_10_features = [x[0] for x in top_features[-10:]]

In [None]:
bottom_10_features

In [None]:
top_10_features = [x[0] for x in top_features[:10]]

In [None]:
top_10_features

In [None]:
# 1. Compute the mean SHAP value for each feature (not absolute)
mean_shap = shap_values_squeezed.mean(axis=0)

# 2. Get indices of top 10 positive and top 10 negative impact features
top_positive_indices = np.argsort(mean_shap)[-10:]  # most positive
top_negative_indices = np.argsort(mean_shap)[:10]   # most negative

# 3. Retrieve feature names and their SHAP values
top_positive_features = [(feature_names[i], mean_shap[i]) for i in reversed(top_positive_indices)]
top_negative_features = [(feature_names[i], mean_shap[i]) for i in top_negative_indices]

pos_names = [x[0] for x in top_positive_features]
# pos_names

neg_names = [x[0] for x in top_negative_features]
# neg_names

In [None]:
pos_names

In [None]:
neg_names

# Plot survival curves of populations split on the most descriptive features

In [None]:
# splitting on median for continuous variables
# X['Lab_2091-7'].mean()
np.sum([1 for x in X['Lab_2091-7'].values if x > 5])

In [None]:
np.sum([1 for x in X['Code_R07.9'].values if x != 0])

In [None]:
y_test_studyid.set_index('StudyID', inplace=True)

In [None]:
bar = X[X['Lab_2091-7'] > 0].index

In [None]:
y_test_studyid = y_test.reset_index(inplace=False)
# The lab value is a feature, so it lives in X_test rather than y_test
pos_outcome = y_test_studyid.index[(X_test['Lab_2091-7'] == 1).values].tolist()
neg_outcome = y_test_studyid.index[y_test['Outcome'] == 0].tolist()

pos_surv = surv[pos_outcome]
pos_surv_avg = pos_surv.mean(axis=1)

plt.plot(pos_surv_avg, label='Progressed Patients')
plt.plot(neg_surv_avg, label='Censored Patients')

plt.ylabel('Average S(t | x)')
_ = plt.xlabel('Time')
plt.title('Average Survival Prediction')
plt.legend()
# plt.savefig("average_survival.png", dpi=300)

In [None]:
# Survival estimates are only available for the test set (the columns of `surv`)
y_test_studyid = y_test.reset_index(inplace=False)
pos_outcome = y_test_studyid.index[(X_test_scaled['Lab_2091-7'] > 0).values].tolist()
neg_outcome = y_test_studyid.index[(X_test_scaled['Lab_2091-7'] <= 0).values].tolist()

pos_surv = surv[pos_outcome]
pos_surv_avg = pos_surv.mean(axis=1)
# Recompute the comparison curve so it matches the split on the lab value
neg_surv = surv[neg_outcome]
neg_surv_avg = neg_surv.mean(axis=1)

plt.plot(pos_surv_avg, label='Patients with Standardized Lab_2091-7 > 0')
plt.plot(neg_surv_avg, label='Patients with Standardized Lab_2091-7 <= 0')

plt.ylabel('Average S(t | x)')
_ = plt.xlabel('Time')
plt.title('Average Survival Prediction')
plt.legend()
# plt.savefig("average_survival.png", dpi=300)

In [None]:
X_test_scaled['Lab_2091-7'].median() # mean: -0.05, median: -0.4

In [None]:
# plot diabetes, insulin, and VLDL
# np.sum([1 for x in X_test_scaled['MedType_Code_EPIC-MED_119019'] if not x]) # insulin
# np.sum([1 for x in X_test_scaled['Code_E11.9'] if not x]) # diabetes
np.sum([1 for x in X_test['Lab_2091-7'] if x > 0]) # vldl

In [None]:
X_test_scaled.shape

In [None]:
# type(X_test['MedType_Code_EPIC-MED_119019'])
# target_ids = X_test.index[X_test['Code_E11.9'] == True].tolist()
target_ids = X_test.index[X_test['Code_R94.5'] == True].tolist()

In [None]:
studyid_to_col = {study_id: i for i, study_id in enumerate(X_test.index)}
target_cols = [studyid_to_col[sid] for sid in target_ids if sid in studyid_to_col]
pos_surv = surv.iloc[:, target_cols]

In [None]:
neg_surv = surv.drop(surv.columns[target_cols], axis=1)

In [None]:
pos_surv.shape

In [None]:
neg_surv.shape

In [None]:
# y_test_studyid = y_test.reset_index(inplace=False)
# pos_outcome = y_test_studyid.index[X_test['Lab_2091-7'] == 1].tolist()
# neg_outcome = y_test_studyid.index[y_test['Outcome'] == 0].tolist()

# pos_surv = surv[pos_outcome]
pos_surv_avg = pos_surv.mean(axis=1)
neg_surv_avg = neg_surv.mean(axis=1)

plt.plot(pos_surv_avg, label='Patients Diagnosed with Abnormal Liver Function (n=609)')
plt.plot(neg_surv_avg, label='Patients Diagnosed without Abnormal Liver Function (n=2866)')

plt.ylabel('Average S(t | x)')
_ = plt.xlabel('Time t (days)')
plt.title('Predicted Time to Progression Stratified by Abnormal Liver Diagnosis')
plt.legend()
plt.savefig("results/coxph_abnormalliver.png", dpi=300)

In [None]:
# check common shap features
nn = {'Lab_4679-7',
 'Lab_14338-8',
 'Lab_2132-9',
 'Lab_6768-6',
 'Code_Z23',
 'Lab_6690-2',
 'Lab_2093-3',
 'MedType_Code_EPIC-MED_10328',
 'Lab_13457-7',
 'Lab_2571-8',
 'MedType_Code_EPIC-MED_27698',
 'Lab_19153-6',
 'Lab_2501-5',
 'Lab_786-4',
 'Code_E78.5',
 'Lab_XC5-9',
 'Lab_2502-3',
 'MedType_Code_EPIC-MED_40900',
 'Lab_2089-1',
 'Lab_789-8'}

cox = {'Lab_1920-8',
 'Lab_2500-7',
 'Code_R07.9',
 'Code_I10',
 'Lab_718-7',
 'Lab_2091-7',
 'Code_E11.9',
 'Lab_785-6',
 'Code_R94.5',
 'Lab_62292-8',
 'Lab_4544-3',
 'Lab_787-2',
 'Code_D89.89',
 'last_BMI',
 'Lab_777-3',
 'Lab_9830-1',
 'Lab_789-8',
 'Code_E55.9',
 'Code_R91.8',
 'Code_Z71.89'}

In [None]:
intersect = nn.intersection(cox)

In [None]:
intersect

In [None]:
net.eval()