In [None]:
# # Step 1: Install the package

!pip install scikit-survival
!pip install lifelines
!pip install shap

In [None]:
# Step 2: Load packages

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sklearn.model_selection import GridSearchCV, KFold
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import Pipeline

from lifelines import CoxPHFitter 
from lifelines.utils import k_fold_cross_validation
from lifelines.statistics import logrank_test

from lifelines import KaplanMeierFitter 
from lifelines.plotting import add_at_risk_counts


from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastSurvivalSVM
from sksurv.ensemble import RandomSurvivalForest
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

import shap

# Use seaborn's "whitegrid" style for every figure in the notebook
sns.set_style("whitegrid")

import warnings
# NOTE: this silences *all* warnings for the rest of the session, including
# any raised by the libraries used below during model fitting.
warnings.filterwarnings("ignore")

# Connect to Google Drive (Colab only); the data files are read from the
# mount point '/content/drive'
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Step 3: Load data
# file1: clinical table (patient id, overall-survival time and status)
# file2: gene feature table keyed by patient id
file1, file2 = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/content/drive/MyDrive/Data_Tutorial_BC/data_clinical_patient.csv',
        '/content/drive/MyDrive/Data_Tutorial_BC/GeneID_MRMR_50.csv',
    )
)


In [None]:
# 3.1 Merge gene data with OS time and status
# Inner join on PATIENT_ID: only patients present in both tables are kept.
survival_columns = file1[['PATIENT_ID', 'OS_MONTHS', 'OS_STATUS']]
data = survival_columns.merge(file2, how="inner", on=["PATIENT_ID"])

In [None]:
# Take a quick look at the first rows of the merged data
data.head()

In [None]:
# Step 4: Preprocess data & Explore data

# 4.1 Drop unused columns (the patient identifier is not a feature)
drop_list = ['PATIENT_ID']
df = data.drop(columns=drop_list)

print('After the first preprocessing, the shape of data is', df.shape)

# 4.2 Count missing values over the whole frame
n_missing = df.isna().sum().sum()
print('Missing value number:', n_missing)

# 4.3 Encode OS status as a 0/1 indicator: 1 = '1:DECEASED', 0 = anything else
is_deceased = df['OS_STATUS'] == '1:DECEASED'
df['OS_STATUS'] = np.where(is_deceased, 1, 0)

In [None]:
# 4.6 Correlation analysis
# Pairwise correlation of every column in df, drawn as a heatmap.
colormap = plt.cm.Reds
corr_matrix = df.corr()

fig, ax = plt.subplots(figsize=(15, 15))
sns.heatmap(
    corr_matrix,
    linewidths=0.1,
    vmax=0.8,
    square=True,
    cmap=colormap,
    linecolor='white',
    ax=ax,
)
ax.set_title('Correlation matrix', fontsize=14)
plt.show()


In [None]:
# 4.7 Time Distribution of Death and Censor
n_records = df.shape[0]
num_censored = n_records - df["OS_STATUS"].sum()
print("%.1f%% of records are censored" % (num_censored / n_records * 100))

# Follow-up times split by outcome; stacked so both share the same bins
death_times = df.loc[df['OS_STATUS'] == 1, 'OS_MONTHS']
censor_times = df.loc[df['OS_STATUS'] == 0, 'OS_MONTHS']

fig, ax = plt.subplots(figsize=(9, 6))
val, bins, patches = ax.hist((death_times, censor_times), bins=30, stacked=True)
ax.legend(patches, ["Time of Death", "Time of Censoring"])
ax.set_title("Time Distribution for Censored and Death Patients")
ax.set_xlabel('Time (months)', fontsize='large')
ax.set_ylabel('Frequency', fontsize='large')

In [None]:
# Step 5: Cox survival analysis
# 5.1: Normalise data
# Scale every gene feature to [0, 1]; survival time and status are not scaled.
ss = MinMaxScaler()
gene_features = data.drop(['OS_STATUS', 'OS_MONTHS', 'PATIENT_ID'], axis=1)
df_norm = pd.DataFrame(ss.fit_transform(gene_features),
                       columns=gene_features.columns,
                       index=gene_features.index)  # keep data's index so the columns below align

# `data` still holds the raw status labels (only `df` was encoded in step 4.3).
# Encode them here as 0/1: passing the raw strings to the Cox fitter would make
# every record look like an observed event.
df_norm['OS_STATUS'] = np.where(data['OS_STATUS'] == '1:DECEASED', 1, 0)
df_norm['OS_MONTHS'] = data['OS_MONTHS']

In [None]:
# 5.2: Build model
# Cox Proportional Hazards Model fitted on the normalised gene features
cph = CoxPHFitter()
cph.fit(df_norm, duration_col='OS_MONTHS', event_col='OS_STATUS')

# Plot the fitted coefficients on a dedicated axes
fig, ax = plt.subplots(figsize=(8, 10))
ax.set_title('Cox Proportional Hazards Model for 50 Gene Features')
cph.plot(ax=ax)

In [None]:
# Report: coefficient, hazard ratio with 95% bounds, z statistic and p-value
summary_columns = [
    "coef",
    "exp(coef)",
    "exp(coef) lower 95%",
    "exp(coef) upper 95%",
    "z",
    "p",
]
cph.print_summary(columns=summary_columns, decimals=3)

In [None]:
# Cross validation (optional)
# 5-fold cross-validation of the Cox model, scored by concordance index
scores = k_fold_cross_validation(
    cph,
    df_norm,
    duration_col='OS_MONTHS',
    event_col='OS_STATUS',
    k=5,
    scoring_method="concordance_index",
    seed=18,
)

mean_score = np.mean(scores)
print("Average score", mean_score)

In [None]:
# Step 6: Machine Learning Methods for Survival Analysis

# 6.1: Set up seed and the options for the cross-validation approach
SEED = 5
CV = KFold(n_splits=5, shuffle=True, random_state=0)

# 6.2 Split data to prepare for ML
# Features: every column except the survival time and status
X = df.drop(columns=['OS_MONTHS', 'OS_STATUS'])

# Target: record array with a boolean event flag and the survival time.
# The boolean flag is derived on a temporary frame so that `df` itself is not
# modified and this cell can be re-run without changing earlier results.
y = (
    df[['OS_STATUS', 'OS_MONTHS']]
    .assign(OS_STATUS=lambda frame: frame['OS_STATUS'] == 1)
    .to_records(index=False)
)

# Split the data set into training and testing sets (80/20),
# stratified on the event flag so both sets keep the same event rate
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    stratify=y['OS_STATUS'],
                                                    random_state=SEED)

In [None]:
# 6.3: Build model

# Helper: tune an estimator by grid search, then score it on the hold-out set
def grid_search(estimator, param, X_train, y_train, X_test, y_test, CV):
    """Tune `estimator` with GridSearchCV and evaluate the best model.

    Parameters
    ----------
    estimator : estimator or pipeline to tune.
    param : parameter grid passed to GridSearchCV as `param_grid`.
    X_train, y_train : data used for the grid search.
    X_test, y_test : hold-out data; `y_test` must expose the fields
        "OS_STATUS" and "OS_MONTHS".
    CV : cross-validation splitter passed to GridSearchCV as `cv`.

    Returns
    -------
    list
        `[best_model, prediction]`, where `prediction` is the output of
        `best_model.predict(X_test)`.

    The best model and its hold-out C-index are printed as a side effect.
    """
    # No explicit scoring is given, so GridSearchCV uses its default scoring
    search = GridSearchCV(estimator, param_grid=param, cv=CV, n_jobs=-1)
    search.fit(X_train, y_train)

    # Best configuration found by the search
    model = search.best_estimator_
    print(model)

    # Score the best model on the hold-out set
    prediction = model.predict(X_test)
    holdout_cindex = concordance_index_censored(
        y_test["OS_STATUS"], y_test["OS_MONTHS"], prediction
    )[0]
    print('C-index for test set (Hold out):', holdout_cindex)

    return [model, prediction]


In [None]:
# Re-run the experiment n times (20 by default) on different random splits
def c_index(model, X, y, n=20):
    """Estimate the hold-out C-index of `model` over `n` random 80/20 splits.

    For each split a fresh, unfitted copy of `model` is trained on the 80%
    part and scored on the remaining 20%, so the estimator passed in by the
    caller is left untouched.

    Parameters
    ----------
    model : estimator or pipeline with `fit` and `predict`.
    X : feature table.
    y : structured array with the fields "OS_STATUS" and "OS_MONTHS";
        "OS_STATUS" is also used to stratify each split.
    n : number of splits to run (at most 1000 distinct seeds are available).

    Returns
    -------
    list
        `[cindex_score, predict_list]`: the C-index of every run (rounded to
        3 decimals) and the hold-out predictions of every run.

    The average C-index is printed as a side effect.
    """
    # Imported here so this cell does not depend on an edit to the import cell
    from sklearn.base import clone

    # A private generator seeded with 1 gives the same split seeds as
    # np.random.seed(1) + np.random.permutation(1000), without resetting the
    # global NumPy random state for the rest of the notebook.
    seeds = np.random.RandomState(1).permutation(1000)[:n]

    cindex_score = []
    predict_list = []

    for s in seeds:
        X_trn, X_tst, y_trn, y_tst = train_test_split(X, y, test_size=0.2,
                                                      stratify=y['OS_STATUS'],
                                                      random_state=s)
        # Fit a clone so repeated runs never overwrite the caller's fitted model
        fitted = clone(model).fit(X_trn, y_trn)
        prediction = fitted.predict(X_tst)
        predict_list.append(prediction)

        result = concordance_index_censored(y_tst["OS_STATUS"], y_tst["OS_MONTHS"], prediction)
        cindex_score.append(round(result[0], 3))

    print('Average C-index for {} runs'.format(n), np.mean(cindex_score))

    return [cindex_score, predict_list]

In [None]:
# Define the pipelines and their hyperparameter grids

def _scaled_pipeline(model):
    """Return a two-step pipeline: MinMaxScaler ('scaler') then `model` ('model')."""
    return Pipeline([('scaler', MinMaxScaler()), ('model', model)])


# CoxPHSurvivalAnalysis
pipe_cox = _scaled_pipeline(CoxPHSurvivalAnalysis())
param_cox = {
    'scaler': [MinMaxScaler()],
    "model__alpha": [0.001, 0.01, 0.1, 1, 10, 100],
}

# Random Survival Forest
pipe_rsf = _scaled_pipeline(RandomSurvivalForest())
param_rsf = {
    'scaler': [MinMaxScaler()],
    'model__random_state': [SEED],
    'model__max_features': ['sqrt'],
    'model__max_depth': [8],
    'model__min_samples_leaf': [50, 100],
    'model__min_samples_split': [100],
    'model__n_estimators': [500],
}

# Gradient Boost Survival
pipe_gbs = _scaled_pipeline(GradientBoostingSurvivalAnalysis())
param_gbs = {
    'scaler': [MinMaxScaler()],
    'model__random_state': [SEED],
    'model__learning_rate': [0.01, 0.1, 1],
    'model__n_estimators': [200, 500, 800, 1000],
}

# Survival SVM
pipe_svm = _scaled_pipeline(FastSurvivalSVM())
param_svm = {
    'scaler': [MinMaxScaler()],
    'model__random_state': [SEED],
    'model__max_iter': [500, 5000],
    'model__optimizer': ['avltree', 'rbtree', 'simple'],
}

# Estimator list: display name -> [pipeline, parameter grid]
estimator_list = {
    'Cox Regression': [pipe_cox, param_cox],
    'Random Forest Survival': [pipe_rsf, param_rsf],
    'Gradient Boosting Survival': [pipe_gbs, param_gbs],
    'SVM Survival': [pipe_svm, param_svm],
}

In [None]:
# Tune every estimator, keep its hold-out predictions, then repeat the
# evaluation on 20 random splits.
model_list, pred_list, c_index_list, pred_list_n = [], [], [], []

for model_name, (estimator, param) in estimator_list.items():
    print('\n', model_name)

    # Grid search on the training set, prediction on the hold-out set
    model, holdout_prediction = grid_search(estimator, param,
                                            X_train, y_train, X_test, y_test, CV)
    model_list.append(model)
    pred_list.append(holdout_prediction)

    # Run model n times to check c-index
    run_scores, run_predictions = c_index(model, X, y, n=20)
    c_index_list.append(run_scores)
    pred_list_n.append(run_predictions)

    

In [None]:
# Visualise results

# Short labels, in the same order as the models were trained
name = ['CPH', 'RSF', 'GBS', 'SSVM']

# One [label, score] row per model and per run
cv_res = [
    [label, score]
    for i, label in enumerate(name)
    for score in c_index_list[i]
]

c_plot = pd.DataFrame(cv_res, columns=['Model Name', 'C-index'])

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(x="Model Name", y="C-index", data=c_plot, ax=ax)
ax.set_title('C-index for 20 runs')
 

In [None]:
# KM curves: split the hold-out patients at the median predicted score
# (score >= median -> high risk "HR", otherwise low risk "LR") and compare
# the two groups with Kaplan-Meier curves and a log-rank test, one panel per model.
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

# Observed outcomes of the hold-out set, with the event flag recoded to 0/1
df_y = pd.DataFrame(y_test)
df_y['OS_STATUS'] = np.where(df_y['OS_STATUS'] == True, 1, 0)

for model_label, ax, y_pred in zip(name, axes.flat, pred_list):
    # Hold-out features plus risk group and observed outcome
    df1 = X_test.reset_index(drop=True)
    df1['Risk'] = np.where(y_pred >= np.median(y_pred), 1, 0)
    df1['OS_STATUS'] = df_y['OS_STATUS']
    df1['OS_MONTHS'] = df_y['OS_MONTHS']
    print(df1.shape)

    ix = df1['Risk'] == 1
    T_hr, E_hr = df1.loc[ix, 'OS_MONTHS'], df1.loc[ix, 'OS_STATUS']
    T_lr, E_lr = df1.loc[~ix, 'OS_MONTHS'], df1.loc[~ix, 'OS_STATUS']

    # Fit and draw both survival curves on this model's panel
    kmf_hr = KaplanMeierFitter()
    kmf_hr.fit(T_hr, E_hr, label='HR').plot_survival_function(ax=ax)

    kmf_lr = KaplanMeierFitter()
    kmf_lr.fit(T_lr, E_lr, label='LR').plot_survival_function(ax=ax)

    # At-risk table for BOTH groups (high risk first, then low risk)
    add_at_risk_counts(kmf_hr, kmf_lr, ax=ax)

    # Format graph
    ax.set_ylim(0, 1)
    ax.set_xlabel('Time (months)', fontsize='large')
    ax.set_ylabel('Est. probability of survival', fontsize='large')

    # Log-rank test between the two risk groups
    res = logrank_test(T_hr, T_lr, event_observed_A=E_hr, event_observed_B=E_lr, alpha=.95)
    print('\nModel', model_label)
    res.print_summary()

    # Place the p-value label at one tenth of the longest follow-up time
    xloc = max(np.max(T_hr), np.max(T_lr)) / 10
    ax.text(xloc, .2, 'p-value = {}'.format(round(res.p_value, 3)), fontsize=12)
    ax.set_title('KM Curves {}'.format(model_label))

fig.tight_layout()
    

In [None]:
# Step 7: Interpret data
# Initialize JS for the SHAP plots

shap.initjs()

for model_label, m in zip(name, model_list):
    print('\nModel', model_label)

    # Explain the complete pipeline (scaler + model), i.e. the same object
    # that was tuned and evaluated above. Taking only the model step and
    # refitting it on unscaled data would explain a different model and leave
    # the pipeline's own step fitted on data it never sees at predict time.
    m.fit(X_train, y_train)

    # Model-agnostic explainer: X_train is the background data and the
    # explained function is the pipeline's predict on raw features
    explainer = shap.Explainer(m.predict, X_train, feature_names=list(X_train.columns))
    shaps = explainer(X_test)
    shap.summary_plot(shaps, X_test)