Use gradient boosting for survival prediction with accelerated failure time (AFT) models

Test on regression trees:

Grid search, plot performance, test changing some hyperparameters, use early stopping to prevent overfitting, fit and predict on simulated data and predict on the original data, plot feature importances (impurity-based) and permutation feature importances (eli5)

Test on least squares:
Grid search, plot performance, test changing some hyperparameters, fit and predict on simulated data and predict on the original data, plot coefficients and permutation feature importances (eli5)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, glob, inspect, sys

from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw
from sksurv.ensemble import ComponentwiseGradientBoostingSurvivalAnalysis
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
import eli5
from eli5.sklearn import PermutationImportance

# Put the parent of this code's directory on sys.path so that the
# project-local module `epri_mc_lib_3` can be imported from there.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir) 
import epri_mc_lib_3 as mc
from importlib import reload
# Reload so that edits to the helper module are picked up without a restart.
reload(mc)

In [None]:
class EarlyStoppingMonitor:
    """Fit-time callback that requests a stop once out-of-bag gains stall.

    An instance is called as ``monitor(iteration, estimator, args)`` and
    returns ``True`` when fitting should stop, ``False`` otherwise.

    Parameters
    ----------
    window_size : int
        Number of most recent iterations whose ``oob_improvement_`` values
        are averaged.
    max_iter_without_improvement : int
        Number of iterations the windowed average may stay at or below
        ``1e-6`` before a stop is requested.
    """

    def __init__(self, window_size, max_iter_without_improvement):
        # Last iteration at which the windowed average was still above the
        # threshold; -1 means "none seen yet".
        self._best_step = -1
        self.max_iter_without_improvement = max_iter_without_improvement
        self.window_size = window_size

    def __call__(self, iteration, estimator, args):
        """Return True if fitting should stop at ``iteration``.

        ``estimator`` must expose ``oob_improvement_``; ``args`` is unused.
        """
        window = self.window_size

        # Never stop before a full window of iterations is available.
        if iteration < window:
            return False

        # Average the per-iteration out-of-bag improvement (the difference
        # between the previous and current iteration) over the last
        # `window` iterations, current iteration included.
        recent = estimator.oob_improvement_[iteration - window + 1:iteration + 1]
        if np.mean(recent) > 1e-6:
            # Still improving: remember this step and keep fitting.
            self._best_step = iteration
            return False

        # Stop only after enough iterations have passed since the last
        # improving step.
        stalled_for = iteration - self._best_step
        return stalled_for >= self.max_iter_without_improvement

In [None]:
# Load the simulated survival data; the path is built from the parent of the
# current working directory.
simulated_csv = os.path.join(
    os.path.dirname(os.getcwd()),
    '../Data/Merged_data/CopulaGAN_simulated_data_survival_2.csv',
)
data = pd.read_csv(simulated_csv)

In [None]:
# Quick structural summary of the simulated data set.
data.info()

In [None]:
# The first two columns are used as the survival outcome; every remaining
# column is used as a feature.
data_y, data_x = data.iloc[:, :2], data.iloc[:, 2:]

In [None]:
# Load the original survival data; the first CSV column is used as the index.
real_csv = os.path.join(
    os.path.dirname(os.getcwd()),
    '../Data/Merged_data/Survival_df.csv',
)
data_real = pd.read_csv(real_csv, index_col=0)

# Same column layout as for the simulated data: two outcome columns first,
# then the features.
real_y_pre = data_real.iloc[:, :2]
real_x = data_real.iloc[:, 2:]
# Record array of the outcome columns, passed to the estimators' score().
real_y = real_y_pre.to_records(index=False)

## Train test split

In [None]:
# Hold out 20% of the simulated data for testing; the seed is fixed so the
# split is repeatable.
X_train, X_test, y_train_pre, y_test_pre = train_test_split(
    data_x,
    data_y,
    random_state=42,
    test_size=0.2,
)

In [None]:
# Convert both outcome frames to record arrays (index dropped) for fitting
# and scoring.
y_train, y_test = (
    frame.to_records(index=False) for frame in (y_train_pre, y_test_pre)
)

# Accelerated Failure Time models

The concordance index is defined as the proportion of all comparable pairs in which the predictions and outcomes are concordant.
We also report concordance_index_ipcw: the difference between concordance_index_ipcw and concordance_index_censored is negligible when there is little censoring, but with moderate to high amounts of censoring concordance_index_censored is overconfident.

We chose to run gradient boosting with the 'ipcwls' loss (inverse-probability-of-censoring weighted least squares error) because this method adjusts for censoring and returns a *time to event* rather than only a log hazard ratio.

## Gradient boosting with regression trees

First we test boosting with single-split trees ('stumps') and print the resulting concordance index.

In [None]:
# Baseline: 100 depth-1 trees with no shrinkage, fitted on the training split.
stump = GradientBoostingSurvivalAnalysis(
    loss='ipcwls',
    n_estimators=100,
    learning_rate=1.0,
    max_depth=1,
    random_state=42,
).fit(X_train, y_train)

# Score returned by the estimator on the held-out split.
cindex = stump.score(X_test, y_test)
print(round(cindex, 3))

# Project helper; its value is shown as the cell output.
mc.score_survival_model_ipcw(stump, X_test, y_train, y_test)

### Gridsearch CV

In [None]:
# Hyper-parameter grid explored for the tree-based booster.
param_grid = dict(
    learning_rate=[0.01, 0.05, 0.1],
    n_estimators=[250, 500, 750, 1000, 1250],
    max_depth=[2, 3, 4],
    min_impurity_decrease=[0, 0.01],
    subsample=[0.4, 0.5, 0.6],
)

# Ten random 50/50 splits; to use first to refine the search.
cv = ShuffleSplit(n_splits=10, test_size=0.5, random_state=0)

# refit=False: only the search results are kept, no final model is refitted.
GSCV_tree = GridSearchCV(
    estimator=GradientBoostingSurvivalAnalysis(loss='ipcwls', random_state=42),
    param_grid=param_grid,
    scoring=mc.score_survival_model,
    n_jobs=4,
    refit=False,
    cv=cv,
    verbose=1,
)

In [None]:
# Run the grid search on the training split only.
GSCV_tree.fit(X_train, y_train)

In [None]:
# Parameter combination with the best cross-validated score.
GSCV_tree.best_params_

### Model performance

In [None]:
# Tree-based booster with fixed hyper-parameters. n_estimators is not set
# here; it is varied by the sweep that follows.
GB_tree = GradientBoostingSurvivalAnalysis(
    loss='ipcwls',
    subsample=0.4,
    min_impurity_decrease=0.01,
    learning_rate=0.01, 
    max_depth=2, 
    random_state=42
)

Testing several n_estimators

In [None]:
# Refit the model for a range of ensemble sizes and record the train/test
# concordance index for each. Rows are gathered in a list and turned into a
# DataFrame once at the end: DataFrame.append was removed in pandas 2.0 and
# copied the whole frame on every iteration.
score_rows = []
for n_estimators in range(1, 3000, 100):
    GB_tree.set_params(n_estimators=n_estimators)
    GB_tree.fit(X_train, y_train)
    score_rows.append({
        'n_estimators': n_estimators,
        'ci_train': GB_tree.score(X_train, y_train),
        'ci_test': GB_tree.score(X_test, y_test),
    })
scores = pd.DataFrame(score_rows, columns=['n_estimators', 'ci_train', 'ci_test'])

Plotting vs n_estimators

In [None]:
# Plot train and test concordance index against the number of estimators.
sns.set(style='whitegrid')
scores.set_index('n_estimators').plot()
plt.xlabel('n_estimator')
plt.ylabel('concordance index')
plt.title('Gradient boosting with regression trees')

In [None]:
# Ensemble sizes to compare: 5, 10, ..., 100.
n_estimators = [i * 5 for i in range(1, 21)]

# One model per regularization strategy; they differ only in learning rate,
# dropout and subsampling.
# NOTE(review): no `loss=` is passed here, unlike the 'ipcwls' models used
# elsewhere in this notebook, so the library default loss applies — confirm
# this is intended.
estimators = {
    "no regularization": GradientBoostingSurvivalAnalysis(
    min_impurity_decrease=0.01,
        learning_rate=1.0, max_depth=2, random_state=0
    ),
    "learning rate": GradientBoostingSurvivalAnalysis(
    min_impurity_decrease=0.01,
        learning_rate=0.1, max_depth=2, random_state=0
    ),
    "dropout": GradientBoostingSurvivalAnalysis(
    min_impurity_decrease=0.01,
        learning_rate=1.0, dropout_rate=0.1, max_depth=2, random_state=0
    ),
    "subsample": GradientBoostingSurvivalAnalysis(
    min_impurity_decrease=0.01,
        learning_rate=1.0, subsample=0.5, max_depth=2, random_state=0
    ),
}

# Per-strategy score lists, filled in the same order as n_estimators.
scores_reg = {k: [] for k in estimators.keys()}
scores_train_reg = {k: [] for k in estimators.keys()}

# Refit every strategy from scratch at each ensemble size and record the
# test and train scores.
for n in n_estimators:
    for name, est in estimators.items():
        est.set_params(n_estimators=n)
        est.fit(X_train, y_train)
        cindex = est.score(X_test, y_test)
        cindex_train = est.score(X_train, y_train)
        scores_reg[name].append(cindex)
        scores_train_reg[name].append(cindex_train)
        
# Tabulate: one row per ensemble size, one column per strategy.
scores_res = pd.DataFrame(scores_reg, index=n_estimators)
scores_train_reg = pd.DataFrame(scores_train_reg, index=n_estimators)



In [None]:
# Test-set concordance index for each regularization strategy.
ax = scores_res.plot(xlabel="n_estimators", ylabel="concordance index")
ax.grid(True)
plt.title('Test')
# Fixed y-range, identical to the one used for the 'Train' plot.
plt.ylim(0.84, 0.98)

In [None]:
# Train-set concordance index for each regularization strategy.
ax = scores_train_reg.plot(xlabel="n_estimators", ylabel="concordance index")
ax.grid(True)
plt.title('Train')
# Fixed y-range, identical to the one used for the 'Test' plot.
plt.ylim(0.84, 0.98)

### Early stopping

In [None]:
# Tree-based booster fitted with an early-stopping monitor. n_estimators is
# the upper bound on boosting iterations; the monitor may stop sooner.
GB_tree_ES = GradientBoostingSurvivalAnalysis(
    loss='ipcwls',
    n_estimators=1250,
    subsample=0.4,
    min_impurity_decrease=0.01,
    learning_rate=0.01, 
    max_depth=2, 
    random_state=42
)

# Average the out-of-bag improvement over 25 iterations; stop after 100
# iterations without improvement.
monitor = EarlyStoppingMonitor(25, 100)

GB_tree_ES.fit(X_train, y_train, monitor=monitor)

print("Fitted base learners:", GB_tree_ES.n_estimators_)

# Keep the test and train scores in separate variables. Previously the train
# score overwrote the test score and a stale `cindex_train` from an earlier
# cell was printed as the train performance.
cindex = GB_tree_ES.score(X_test, y_test)
cindex_train = GB_tree_ES.score(X_train, y_train)

print("Performance on test set", round(cindex, 3))
print("Performance on train set", round(cindex_train, 3))
print('CI_ipcw', mc.score_survival_model_ipcw(GB_tree_ES, X_test, y_train, y_test))


# Per-iteration out-of-bag improvement, indexed from iteration 1.
improvement = pd.Series(
    GB_tree_ES.oob_improvement_,
    index=np.arange(1, 1 + len(GB_tree_ES.oob_improvement_))
)
ax = improvement.plot(xlabel="iteration", ylabel="oob improvement")
ax.axhline(0.0, linestyle="--", color="gray")
# Vertical marker placed max_iter_without_improvement iterations before the
# last fitted iteration.
cutoff = len(improvement) - monitor.max_iter_without_improvement
ax.axvline(cutoff, linestyle="--", color="C3")

# Rolling mean over the monitor's window, drawn on the same axes.
_ = improvement.rolling(monitor.window_size).mean().plot(ax=ax, linestyle=":")


## Prediction

We use the early stopping model to prevent overfitting

### Simulated data

In [None]:
# Predictions of the early-stopped model on the simulated test split.
pred_Xtest = GB_tree_ES.predict(X_test)

In [None]:
# Test outcomes alongside the model's predictions (built on a copy, so
# y_test_pre itself is left untouched).
prediction = y_test_pre.assign(pred_X_test=pred_Xtest)

In [None]:
# Predicted vs. observed time on the simulated test split, colored by the
# 'Observed' column.
sns.set(style='white')
sns.scatterplot(x='F_Time', y='pred_X_test', hue='Observed', data=prediction,
               alpha=0.6, palette=sns.xkcd_palette(['marine blue', 'deep red'])
               )
# Identity line: points on it have prediction equal to the observed time.
plt.plot([0, 3500000], [0, 3500000], 'darkgray', lw=0.8)
plt.xlabel('Observed survival time from NDE measurement')
plt.ylabel('Predicted survival time')
plt.title('Gradient boosting with regression trees')

### Original data

In [None]:
# Predictions of the early-stopped model (trained on simulated data) on the
# original data.
pred_real = GB_tree_ES.predict(real_x)

In [None]:
# Report both scores of the early-stopped model on the original data.
print(
    'CI:', GB_tree_ES.score(real_x, real_y),
    '\nCI_ipcw:', mc.score_survival_model_ipcw(GB_tree_ES, real_x, y_train, real_y),
)

In [None]:
# Original outcomes alongside the model's predictions (built on a copy, so
# real_y_pre itself is left untouched).
prediction_real = real_y_pre.assign(prediction=pred_real)

In [None]:
# Show floats in scientific notation with three decimals. This changes a
# global pandas display option, so it affects later outputs too.
pd.options.display.float_format = '{:.3e}'.format
prediction_real

In [None]:
# Predicted vs. observed time on the original data, colored by the
# 'Observed' column.
sns.set(style='white')
sns.scatterplot(x='F_Time', y='prediction', hue='Observed', data=prediction_real,
               alpha=0.6, palette=sns.xkcd_palette(['marine blue', 'deep red'])
               )
# Identity line: points on it have prediction equal to the observed time.
plt.plot([0, 3500000], [0, 3500000], 'darkgray', lw=0.8)
plt.xlabel('Observed survival time from NDE measurement')
plt.ylabel('Predicted survival time')
plt.title('Gradient boosting with RT - original data')

## Feature importance (impurity-based and permutation)

In [None]:
# Horizontal bar chart of the model's feature_importances_, sorted ascending.
# NOTE(review): the palette size 13 is hard-coded — presumably the number of
# feature columns; confirm it matches X_test.shape[1].
pd.DataFrame(GB_tree_ES.feature_importances_, index=X_test.columns.tolist())\
.sort_values(0,ascending=True).plot.barh(color=[sns.color_palette('PuBu', 13, desat=0.9)], width=0.6, figsize=(6,6), legend=False)
plt.xlabel('Feature importance', fontsize = 12)

In [None]:
# Permutation importance of the already-fitted early-stopped model, measured
# on the test split with 15 shuffles per feature. The seed is fixed so the
# reported weights are reproducible, like every other randomized step in
# this notebook.
perm = PermutationImportance(GB_tree_ES, n_iter=15, random_state=42)
perm.fit(X_test, y_test)
feature_names = X_test.columns.tolist()
eli5.explain_weights(perm, feature_names=feature_names)

# Gradient boosting with component-wise least squares

### Gridsearch CV

In [None]:
# Hyper-parameter grid explored for the component-wise booster.
param_grid = dict(
    learning_rate=[0.01, 0.1, 0.5, 1],
    n_estimators=[4000, 5000, 6000],
    subsample=[0.1, 0.2, 0.3],
)

# 5-fold search; refit=False keeps only the search results.
GSCV_IPCWLS = GridSearchCV(
    estimator=ComponentwiseGradientBoostingSurvivalAnalysis(loss='ipcwls', random_state=42),
    param_grid=param_grid,
    scoring=mc.score_survival_model,
    n_jobs=4,
    refit=False,
    cv=5,
    verbose=1,
)

In [None]:
# Run the grid search on the training split only.
GSCV_IPCWLS.fit(X_train, y_train)

In [None]:
# Best cross-validated score (rounded) and the parameters that achieved it.
round(GSCV_IPCWLS.best_score_, 3), GSCV_IPCWLS.best_params_

### Model performance

In [None]:
# Component-wise booster with fixed hyper-parameters.
# NOTE(review): subsample=1 lies outside the range covered by the grid search
# in this section (0.1-0.3) — confirm these values are intentional.
GB_CWLS = ComponentwiseGradientBoostingSurvivalAnalysis(
    loss='ipcwls',
    subsample=1,
    n_estimators=5000,
    learning_rate=1, 
    dropout_rate=0.0,
    random_state=42
)

In [None]:
# Fit on the training split. The model was previously fitted on the test
# split, which would make any score computed on that split meaningless.
GB_CWLS.fit(X_train, y_train)

Testing several n_estimators

In [None]:
# Refit the model for a range of ensemble sizes and record the train/test
# concordance index for each. Rows are gathered in a list and turned into a
# DataFrame once at the end: DataFrame.append was removed in pandas 2.0 and
# copied the whole frame on every iteration.
score_rows_cwls = []
for n_estimators in range(1, 5000, 100):
    GB_CWLS.set_params(n_estimators=n_estimators)
    GB_CWLS.fit(X_train, y_train)
    score_rows_cwls.append({
        'n_estimators': n_estimators,
        'ci_train': GB_CWLS.score(X_train, y_train),
        'ci_test': GB_CWLS.score(X_test, y_test),
    })
scores_cwls = pd.DataFrame(score_rows_cwls, columns=['n_estimators', 'ci_train', 'ci_test'])
 

Plotting vs n_estimators

In [None]:
# Plot train and test concordance index against the number of estimators.
sns.set(style='whitegrid')
scores_cwls.set_index('n_estimators').plot()
plt.xlabel('n_estimator')
plt.ylabel('concordance index')
plt.title('Gradient boosting with component-wise least squares')

In [None]:
# Ensemble sizes to compare: 5, 10, ..., 200.
n_estimators = [i * 5 for i in range(1, 41)]

# One model per regularization strategy; they differ only in learning rate,
# dropout and subsampling.
# NOTE(review): no `loss=` is passed here, unlike the 'ipcwls' models used
# elsewhere in this notebook, so the library default loss applies — confirm
# this is intended.
estimators = {
    "no regularization": ComponentwiseGradientBoostingSurvivalAnalysis(
        learning_rate=1.0, random_state=0
    ),
    "learning rate": ComponentwiseGradientBoostingSurvivalAnalysis(
        learning_rate=0.1, random_state=0
    ),
    "dropout": ComponentwiseGradientBoostingSurvivalAnalysis(
        learning_rate=1.0, dropout_rate=0.1, random_state=0
    ),
    "subsample": ComponentwiseGradientBoostingSurvivalAnalysis(
        learning_rate=1.0, subsample=0.5, random_state=0
    ),
}

# Per-strategy score lists, filled in the same order as n_estimators.
scores_reg_cwls = {k: [] for k in estimators.keys()}
scores_train_reg_cwls = {k: [] for k in estimators.keys()}

# Refit every strategy from scratch at each ensemble size and record the
# test and train scores.
for n in n_estimators:
    for name, est in estimators.items():
        est.set_params(n_estimators=n)
        est.fit(X_train, y_train)
        cindex_cwls = est.score(X_test, y_test)
        cindex_train_cwls = est.score(X_train, y_train)
        scores_reg_cwls[name].append(cindex_cwls)
        scores_train_reg_cwls[name].append(cindex_train_cwls)
        
# Tabulate: one row per ensemble size, one column per strategy.
scores_res_cwls = pd.DataFrame(scores_reg_cwls, index=n_estimators)
scores_train_reg_cwls = pd.DataFrame(scores_train_reg_cwls, index=n_estimators)


In [None]:
# Test-set concordance index for each regularization strategy.
ax = scores_res_cwls.plot(xlabel="n_estimators", ylabel="concordance index")
ax.grid(True)
plt.title('Test')
#plt.ylim(0.84, 0.98)

In [None]:
# Train-set concordance index for each regularization strategy.
ax = scores_train_reg_cwls.plot(xlabel="n_estimators", ylabel="concordance index")
ax.grid(True)
plt.title('Train')
#plt.ylim(0.84, 0.98)

## Prediction

### Simulated data

In [None]:
# Predictions of the component-wise model on the simulated test split.
# NOTE(review): GB_CWLS is whatever the n_estimators sweep left behind, i.e.
# its last fit (n_estimators=4901), not the 5000 it was configured with —
# confirm this is the intended model.
pred_Xtest_cwls = GB_CWLS.predict(X_test)

In [None]:
# Report both scores of the component-wise model on the simulated test split.
print(
    'CI:', GB_CWLS.score(X_test, y_test),
    '\nCI_ipcw:', mc.score_survival_model_ipcw(GB_CWLS, X_test, y_train, y_test),
)

In [None]:
# Test outcomes alongside the model's predictions (built on a copy, so
# y_test_pre itself is left untouched).
prediction_cwls = y_test_pre.assign(pred_X_test=pred_Xtest_cwls)

In [None]:
# Predicted vs. observed time on the simulated test split, colored by the
# 'Observed' column.
sns.set(style='white')
sns.scatterplot(x='F_Time', y='pred_X_test', hue='Observed', data=prediction_cwls,
               alpha=0.6, palette=sns.xkcd_palette(['marine blue', 'deep red'])
               )
# Identity line: points on it have prediction equal to the observed time.
plt.plot([0, 3500000], [0, 3500000], 'darkgray', lw=0.8)
plt.xlabel('Observed survival time from NDE measurement')
plt.ylabel('Predicted survival time')
plt.title('Gradient boosting with IPCWLS')

### Original data

In [None]:
# Predictions of the component-wise model (trained on simulated data) on the
# original data.
pred_real_cwls = GB_CWLS.predict(real_x)

In [None]:
# Report both scores of the component-wise model on the original data.
print(
    'CI:', GB_CWLS.score(real_x, real_y),
    '\nCI_ipcw:', mc.score_survival_model_ipcw(GB_CWLS, real_x, y_train, real_y),
)

In [None]:
# Original outcomes alongside the model's predictions (built on a copy, so
# real_y_pre itself is left untouched).
prediction_real_cwls = real_y_pre.assign(prediction=pred_real_cwls)

In [None]:
# Show floats in scientific notation with three decimals. This changes a
# global pandas display option, so it affects later outputs too.
pd.options.display.float_format = '{:.3e}'.format
prediction_real_cwls

In [None]:
# Predicted vs. observed time on the original data for the component-wise
# model, colored by the 'Observed' column.
sns.set(style='white')
sns.scatterplot(x='F_Time', y='prediction', hue='Observed', data=prediction_real_cwls,
               alpha=0.6, palette=sns.xkcd_palette(['marine blue', 'deep red'])
               )
# Identity line: points on it have prediction equal to the observed time.
plt.plot([0, 3500000], [0, 3500000], 'darkgray', lw=0.8)
plt.xlabel('Observed survival time from NDE measurement')
plt.ylabel('Predicted survival time')
# The title previously said "RT" (regression trees), copied from the
# tree-based section; this plot shows the component-wise IPCWLS model.
plt.title('Gradient boosting with IPCWLS - original data')

## Feature importance

### Coefficients

In [None]:
# Horizontal bar chart of the model coefficients, sorted ascending. The
# first element of coef_ is skipped — presumably the intercept; confirm
# against the library documentation.
# NOTE(review): the palette size 13 is hard-coded — presumably the number of
# feature columns; confirm it matches X_test.shape[1].
pd.DataFrame(GB_CWLS.coef_[1:], index=X_test.columns.tolist())\
.sort_values(0,ascending=True).plot.barh(color=[sns.color_palette('coolwarm', 13, desat=0.9)], width=0.6, figsize=(6,6), legend=False)
plt.xlabel('Coefficients', fontsize = 12)

### permutation

In [None]:
# Permutation importance of the already-fitted component-wise model,
# measured on the test split with 15 shuffles per feature. The seed is fixed
# so the reported weights are reproducible, like every other randomized step
# in this notebook.
perm = PermutationImportance(GB_CWLS, n_iter=15, random_state=42)
perm.fit(X_test, y_test)
feature_names = X_test.columns.tolist()
eli5.explain_weights(perm, feature_names=feature_names)