Random survival forest on simulated data: plot survival curves, predict using the median survival time, plot feature importance by permutation, and test a reduced number of features.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, glob, inspect, sys
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import load_gbsg2
from sksurv.metrics import concordance_index_ipcw
import sksurv
import eli5
from eli5.sklearn import PermutationImportance
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

# Make the parent directory importable so the shared helper module
# (epri_mc_lib_3) can be found.
# NOTE: inspect.getfile(inspect.currentframe()) is unreliable inside Jupyter:
# recent ipykernel versions compile each cell to a temporary file
# (e.g. /tmp/ipykernel_<pid>/<hash>.py), so the directory derived from it is
# not the notebook's directory. Use __file__ when running as a script and fall
# back to the working directory in a notebook, which is also what the
# os.getcwd()-based data paths in this notebook assume.
try:
    currentdir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import epri_mc_lib_3 as mc
from importlib import reload
reload(mc)


Run a random survival forest analysis on the simulated data.

In [None]:
# Load the CopulaGAN-simulated survival dataset. The first CSV column is read
# as the index and then turned back into an ordinary column by reset_index().
simulated_csv_path = os.path.join(
    os.path.dirname(os.getcwd()),
    '../Data/Merged_data/CopulaGAN_simulated_data_survival_2.csv',
)
data = pd.read_csv(simulated_csv_path, index_col=0).reset_index()
#data = data.drop(columns=['NDE_cycle'])

## Format data

In [None]:
# The first two columns form the survival target (later read as Observed and
# F_Time); every column after them is a model feature.
data_y, data_x = data.iloc[:, :2], data.iloc[:, 2:]


In [None]:

# Hold out a random quarter of the simulated samples for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    data_x,
    data_y,
    test_size=0.25,
)

In [None]:
# scikit-survival expects the target as a structured array of
# (event indicator, time) records, one per sample.
y_train_num, y_test_num = (
    target_frame.to_records(index=False) for target_frame in (y_train, y_test)
)


## Random survival forest

In [None]:
# Fit a random survival forest on the simulated training split.
# max_features="sqrt": in scikit-survival the former "auto" option meant
# sqrt(n_features); "auto" was deprecated and later removed, so newer
# versions reject it.
rsf = RandomSurvivalForest(
    n_estimators=5000,
    max_features="sqrt",
    oob_score=True,  # required for rsf.oob_score_ to be populated
)
rsf.fit(X_train, y_train_num)


## Out of bag score (Concordance index)

The out-of-bag score for a random survival forest is the concordance index, a measure of whether samples are correctly ordered relative to one another. A value of 0.5 indicates random ordering and 1 indicates perfect ordering.

In [None]:
# Out-of-bag concordance index of the fitted forest; only available because the
# forest was constructed with oob_score=True.
rsf.oob_score_

## Making predictions

The predictions are risk scores: the expected number of deaths (events) for a terminal node. On their own they are not especially useful. `rsf.score` gives the concordance index for the test data.

In [None]:
# Risk scores for the held-out simulated samples. Wrapping the array in a
# Series gives it a fresh 0..n-1 index; it does not carry X_test's index.
y_pred = pd.Series(rsf.predict(X_test))


A more conservative version of the concordance index (CI). It uses inverse probability of censoring weights (IPCW) to handle high levels of censored data.

In [None]:
# IPCW concordance index on the test split. The training targets are passed as
# the first argument; scikit-survival uses them to estimate the censoring
# distribution.
# NOTE(review): this may raise if test times exceed the largest training time;
# consider passing tau= — confirm against the scikit-survival docs.
concordance_index_ipcw(y_train_num, y_test_num, y_pred)

CI score for the test data.

In [None]:
# Concordance index of the forest on the held-out simulated split.
rsf.score(X_test,y_test_num)

Predict survival functions.

In [None]:
# Predicted survival functions for the test samples: one step-function object
# per row (evaluated elsewhere via fn.x and fn(fn.x)), rather than a dense array.
surv = rsf.predict_survival_function(X_test, return_array=False)

Calculate the median survival time so that we can predict a single number for each sample.

In [None]:
# Median predicted survival versus observed survival on the simulated test
# split: red points have Observed == True (failure seen), blue points do not.
median_survival_list = mc.calc_median_survival(surv)
colordict = {False: 'blue', True: 'red'}
plt.scatter(
    x=y_test.F_Time,
    y=median_survival_list,
    c=[colordict[failed] for failed in y_test.Observed],
    alpha=0.2,
)
plt.xlabel("Observed survival time")
plt.ylabel("Median predicted survival time")
plt.title("Random survival forest, simulated data")
plt.show()

In [None]:
# Plot the predicted survival curve for (up to) the first 20 test samples,
# annotated with each sample's observed outcome.
surv = rsf.predict_survival_function(X_test, return_array=False)

# Cap at 20 curves, but never index past the end of a smaller test split.
# Replace the cap with len(surv) to plot every test sample.
n_curves = min(20, len(surv))
for i in range(n_curves):
    fn = surv[i]
    plt.step(fn.x, fn(fn.x), where="post")
    plt.ylabel("Survival probability")
    plt.xlabel("Time in cycles")
    plt.ylim(0, 1)
    # Column 1 of y_test holds the survival time, column 0 the failure flag.
    plt.text(1500000, 0.95, 'Actual survived cycles: ' + str(np.round(y_test.iloc[i, 1], 0)))
    plt.text(1500000, 0.90, 'Actual failed: ' + str(y_test.iloc[i, 0]))
    plt.show()

## Test on the real data

Calculate the concordance index (CI) of the model on the real data.

In [None]:
# Score the simulated-data model on the real measurements. index_col=0 makes the
# first CSV column the row index, which later cells use as per-specimen titles.
real_csv_path = os.path.join(os.path.dirname(os.getcwd()), '../Data/Merged_data/Survival_df.csv')
data_real = pd.read_csv(real_csv_path, index_col=0)
#data_real.drop(columns=['NLE_ratio_119_17'],inplace=True)
#data_real = data_real.drop(columns=['NDE_cycle'])

# Same layout as the simulated data: two target columns, then the features.
real_y, real_x = data_real.iloc[:, :2], data_real.iloc[:, 2:]
real_y_num = real_y.to_records(index=False)

print(rsf.score(real_x, real_y_num))

surv = rsf.predict_survival_function(real_x, return_array=False)
y_pred_real = pd.Series(rsf.predict(real_x))


Conservative IPCW version of the CI score.

In [None]:
# Conservative IPCW concordance index on the real data. The censoring
# distribution is estimated from the simulated training targets (y_train_num).
# NOTE(review): this may raise if real survival times exceed the largest
# training time; consider passing tau= — confirm.
concordance_index_ipcw(y_train_num, real_y_num, y_pred_real)

Convert the survival functions to median expected survival times.

In [None]:

# Median predicted survival versus observed survival for the real specimens:
# red points have Observed == True (failure seen), blue points do not.
median_survival_list = mc.calc_median_survival(surv)
colordict = {False: 'blue', True: 'red'}
plt.scatter(
    x=real_y.F_Time,
    y=median_survival_list,
    c=[colordict[failed] for failed in real_y.Observed],
    alpha=0.4,
)
plt.xlabel("Observed survival time")
plt.ylabel("Median predicted survival time")
plt.title("Random survival forest, real data")
plt.show()


In [None]:

# One survival-curve figure per real specimen, titled with its index label and
# annotated with the observed outcome.
for i, specimen in enumerate(real_x.index):
    print()
    fn = surv[i]
    plt.step(fn.x, fn(fn.x), where="post")
    plt.ylabel("Survival probability")
    plt.xlabel("Time in cycles")
    plt.ylim(0, 1)
    plt.text(1500000, 0.95, 'Actual survived cycles: ' + str(real_y.iloc[i, 1]))
    plt.text(1500000, 0.90, 'Actual failed: ' + str(real_y.iloc[i, 0]))
    plt.title(specimen)
    plt.show()


## Feature importance by permutation

This estimates the importance of each feature by permuting it and measuring the effect on the model's score. This is unreliable when features are correlated, as they are here, so we test on a subset of the features.

In [None]:
# Restrict the simulated data to the feature subset mc.feature_selection and
# draw a fresh train/test split with structured-array targets.
data_y, data_x = data.iloc[:, 0:2], data.iloc[:, 2:]
df_features = data_x[mc.feature_selection]

X_train, X_test, y_train, y_test = train_test_split(df_features, data_y, test_size=0.25)
y_train_num, y_test_num = (part.to_records(index=False) for part in (y_train, y_test))


In [None]:

# Fit a forest on the mc.feature_selection subset and display its OOB C-index.
# max_features="sqrt": scikit-survival's former "auto" meant sqrt(n_features)
# and is rejected by newer versions.
rsf = RandomSurvivalForest(
    n_estimators=1000,
    max_features="sqrt",
    oob_score=True,  # required for rsf.oob_score_
)
rsf.fit(X_train, y_train_num)
rsf.oob_score_

In [None]:
# Risk scores on the held-out split (not used in this cell), followed by the
# concordance index of the subset model on that split.
y_pred = pd.Series(rsf.predict(X_test))
rsf.score(X_test,y_test_num)

In [None]:
# Permutation importance: each column of X_test is shuffled n_iter=15 times and
# the drop in score is recorded. No scoring function is passed, so eli5 uses the
# estimator's own score method (rsf.score, the concordance index).
# NOTE(review): no random_state is set, so importances vary between runs.
perm = PermutationImportance(rsf, n_iter=15)
perm.fit(X_test, y_test_num)

In [None]:
# Display the permutation importances, labelled with the test-set column names.
feature_names = X_test.columns.tolist()
eli5.show_weights(perm, feature_names=feature_names)

Calculate the concordance index and feature importances using a subset of features that excludes the number of cycles.

In [None]:
# Restrict the simulated data to the feature subset mc.feature_selection2 and
# draw a fresh train/test split with structured-array targets.
data_y, data_x = data.iloc[:, 0:2], data.iloc[:, 2:]
df_features = data_x[mc.feature_selection2]

X_train, X_test, y_train, y_test = train_test_split(df_features, data_y, test_size=0.25)
y_train_num, y_test_num = (part.to_records(index=False) for part in (y_train, y_test))


In [None]:

# Fit a forest on the mc.feature_selection2 subset and display its OOB C-index.
# max_features="sqrt": scikit-survival's former "auto" meant sqrt(n_features)
# and is rejected by newer versions.
rsf = RandomSurvivalForest(
    n_estimators=1000,
    max_features="sqrt",
    oob_score=True,  # required for rsf.oob_score_
)
rsf.fit(X_train, y_train_num)
rsf.oob_score_

In [None]:
# Survival curves from the reduced-feature model for every real specimen. The
# real data is restricted to the same columns the model was trained on.
surv = rsf.predict_survival_function(real_x[mc.feature_selection2], return_array=False)

for i, specimen in enumerate(real_x.index):
    print()
    fn = surv[i]
    plt.step(fn.x, fn(fn.x), where="post")
    plt.ylabel("Survival probability")
    plt.xlabel("Time in cycles")
    plt.ylim(0, 1)
    plt.text(1500000, 0.95, 'Actual survived cycles: ' + str(real_y.iloc[i, 1]))
    plt.text(1500000, 0.90, 'Actual failed: ' + str(real_y.iloc[i, 0]))
    plt.title(specimen)
    plt.show()


In [None]:
# Risk scores on the held-out split (not used in this cell), followed by the
# concordance index of the feature_selection2 model on that split.
y_pred = pd.Series(rsf.predict(X_test))
rsf.score(X_test,y_test_num)

In [None]:
# Permutation importance for the feature_selection2 model (15 shuffles per
# feature). With no scorer passed, eli5 uses rsf.score (the concordance index).
# NOTE(review): no random_state is set, so importances vary between runs.
perm = PermutationImportance(rsf, n_iter=15)
perm.fit(X_test, y_test_num)

In [None]:
# Display the permutation importances, labelled with the test-set column names.
feature_names = X_test.columns.tolist()
eli5.show_weights(perm, feature_names=feature_names)

For comparison, calculate the concordance index and feature importance using only the NLO feature (`NLO_avg`).

In [None]:
# Baseline for comparison: split, fit, score and explain a forest that sees
# only the NLO_avg feature.
data_x = data.iloc[:, 2:]
data_y = data.iloc[:, 0:2]
df_features = pd.DataFrame(data_x['NLO_avg'])

X_train, X_test, y_train, y_test = train_test_split(
    df_features, data_y, test_size=0.25)
y_train_num = y_train.to_records(index=False)
y_test_num = y_test.to_records(index=False)

# max_features="sqrt": scikit-survival's former "auto" meant sqrt(n_features)
# and is rejected by newer versions. With a single feature, both select it.
rsf = RandomSurvivalForest(
    n_estimators=1000,
    max_features="sqrt",
    oob_score=True,  # required for rsf.oob_score_
)
rsf.fit(X_train, y_train_num)
print(rsf.oob_score_)

y_pred = pd.Series(rsf.predict(X_test))
print(rsf.score(X_test, y_test_num))

# No scorer is passed, so eli5 scores permutations with rsf.score (the C-index).
perm = PermutationImportance(rsf, n_iter=15)
perm.fit(X_test, y_test_num)

feature_names = X_test.columns.tolist()
eli5.show_weights(perm, feature_names=feature_names)