In [None]:
! pip install scikit-learn==1.1.3
! pip install scikit-survival
! pip install lifelines
! pip install scikit-optimize

In [None]:
import warnings
import matplotlib.pyplot as plt
import numpy as np
# @title Import libs
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
from sklearn.decomposition import PCA
from sklearn.model_selection import (GridSearchCV, cross_val_predict,cross_validate)
from sksurv.ensemble import (ComponentwiseGradientBoostingSurvivalAnalysis,
                             RandomSurvivalForest)
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastSurvivalSVM
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [None]:
#@title Survival Algorithms & Params
# Hyper-parameter grids searched by GridSearchCV, one per survival model.

# Random survival forest: number of trees (1..9) and minimum samples to split a node (6..9).
rsf_param_grid = {
    'n_estimators': np.arange(1, 10),
    'min_samples_split': np.arange(6, 10),
}
param_grid = rsf_param_grid  # legacy alias: same dict object as rsf_param_grid

# Component-wise gradient boosting: loss, shrinkage and dropout, with 10 boosting stages.
# NOTE(review): learning rates above 1 are unusual for boosting — confirm this grid is intended.
cwgb_param_grid = dict(
    loss=['coxph', 'squared'],
    learning_rate=[1, 10, 100],
    dropout_rate=[0.1, 0.5, 0.9],
    n_estimators=[10],
)

# Fast survival SVM: regularisation alpha in {2^-10, 2^-5, 2^0, 2^5} and the solver to use.
fsvm_param_grid = dict(
    alpha=np.power(2.0, np.arange(-10, 10, 5)),
    optimizer=['avltree', 'direct-count', 'PRSVM', 'rbtree', 'simple'],
)

# Model name -> (unfitted estimator with a fixed seed, its parameter grid).
# Insertion order is the order in which the analysis loop visits the models.
survival_algs = dict([
    ('RandomSurvivalForest',
     (RandomSurvivalForest(random_state=42), rsf_param_grid)),
    ('ComponentwiseGradientBoostingSurvivalAnalysis',
     (ComponentwiseGradientBoostingSurvivalAnalysis(random_state=42), cwgb_param_grid)),
    ('FastSurvivalSVM',
     (FastSurvivalSVM(max_iter=512, tol=1e-6, random_state=42), fsvm_param_grid)),
])

In [None]:
# Feature-count setting (not referenced by the cells in this section).
n_feats = [10]

# Directory placeholders — set before running. feature_path is concatenated directly with
# file names ("*" globs, "<name>.xlsx"), so it must end with a slash.
feature_path = "/"
result_path = "/"

def append_row(df, row):
    """Return a new DataFrame equal to ``df`` with the Series ``row`` appended as its last row.

    ``row``'s index supplies the column labels. The result gets a fresh 0..n-1
    RangeIndex; ``df`` itself is not modified.
    """
    single_row_frame = pd.DataFrame([row], columns=row.index)
    combined = pd.concat([df, single_row_frame])
    return combined.reset_index(drop=True)
def score_survival_model(model, X, y):
    """Harrell's concordance index of ``model`` on (X, y); usable as a scikit-learn scorer.

    Parameters
    ----------
    model : fitted sksurv estimator exposing ``predict``.
    X : feature matrix accepted by ``model.predict``.
    y : structured array with a boolean 'Event' field and a numeric 'Duration' field.

    Returns
    -------
    float
        The C-index (first element of ``concordance_index_censored``'s result).
    """
    prediction = np.asarray(model.predict(X), dtype=float)
    # concordance_index_censored expects risk scores (higher = earlier event). sksurv models
    # whose ``_predict_risk_score`` is False (e.g. gradient boosting with loss='squared')
    # predict on the time scale instead, where higher = longer survival. Flip those, exactly
    # as sksurv's own ``score`` method does, so the C-index is not inverted.
    if not getattr(model, "_predict_risk_score", True):
        prediction = -prediction
    result = concordance_index_censored(y['Event'], y['Duration'], prediction)
    return result[0]


In [None]:
import glob
import os

# Only .xlsx workbooks can be analysed: the survival-analysis cell rebuilds each path as
# feature_path + name + ".xlsx" and opens it with pd.read_excel, so other files would crash it.
# Sorting makes the run order (and printed progress) deterministic across machines.
files = sorted(glob.glob(feature_path + "*.xlsx"))
# Strip only the final extension so dotted names ("cohort.v2.xlsx" -> "cohort.v2") rebuild
# to the same path later; split(".")[0] would have truncated them to "cohort".
datasets = [os.path.splitext(os.path.basename(path))[0] for path in files]
datasets

In [None]:
#@title Survival analysis
# For each feature workbook and each survival model:
#   1. hold out 20 % of the rows;
#   2. standardise and PCA-reduce the features, fitted on the training split only;
#   3. tune the model with 5-fold GridSearchCV;
#   4. report the cross-validated C-index, the held-out C-index, and log-rank p-values of a
#      median split into high/low-risk groups;
#   5. save a Kaplan-Meier plot of the training-set risk groups.
import os

from lifelines.plotting import add_at_risk_counts

OUTCOME_PATH = "COX_OUTCOME.csv"
N_PCA_COMPONENTS = 5
FIG_DIR = os.path.join(result_path, "figs")


def _load_outcome(path=OUTCOME_PATH):
    """Read the outcome CSV as a structured array with fields Event (bool) and Duration (float).

    Duration is divided by 365 — presumably days converted to years, since the plots label
    the axis in years; TODO confirm the unit of the raw column.
    """
    outcome = (
        pd.read_csv(path, header=0)[['Event', 'Duration']]
        .astype({'Event': bool, 'Duration': float})
    )
    outcome['Duration'] = outcome['Duration'] / 365  # new frame, so no chained-assignment issue
    return outcome.to_records(index=False)


def _load_features(dataset):
    """Read the 'Data' sheet of <feature_path><dataset>.xlsx with string column names."""
    features = pd.read_excel(feature_path + dataset + ".xlsx", sheet_name='Data',
                             engine='openpyxl', header=0)
    features.columns = features.columns.astype(str)
    return features


def _as_risk(model, predictions):
    """Return predictions oriented so that larger values always mean higher risk.

    sksurv models whose ``_predict_risk_score`` is False (e.g. gradient boosting with
    loss='squared') predict on the time scale, where larger means longer survival.
    """
    predictions = np.asarray(predictions, dtype=float)
    return predictions if getattr(model, "_predict_risk_score", True) else -predictions


def _logrank(group_a, group_b):
    """Log-rank test comparing the survival of two structured-array groups."""
    return logrank_test(group_a["Duration"], group_b["Duration"],
                        event_observed_A=group_a["Event"],
                        event_observed_B=group_b["Event"])


def _plot_km(high_risk, low_risk, title, p_value, c_index, out_path):
    """Plot Kaplan-Meier curves of both risk groups with at-risk counts, save and show them."""
    fig, ax = plt.subplots()
    kmf = KaplanMeierFitter().fit(high_risk["Duration"], high_risk["Event"], label='High Risk')
    kmf2 = KaplanMeierFitter().fit(low_risk["Duration"], low_risk["Event"], label='Low Risk')
    censor_styles = {'ms': 6, 'marker': '|'}
    kmf.plot_survival_function(ax=ax, color='Gold', show_censors=True, censor_styles=censor_styles)
    kmf2.plot_survival_function(ax=ax, color='Teal', show_censors=True, censor_styles=censor_styles)
    ax.set_title(title)
    ax.set_xlabel("Time (Years)")
    ax.set_ylabel("Survival probability")
    ax.grid(axis='both', which='both', color='lightgray', linestyle='-', linewidth=0.5, zorder=-1000)
    # Axes-fraction coordinates keep the labels in the lower-left corner whatever the data range
    # (the default data coordinates would place them at t=0.05 years).
    ax.text(0.05, 0.05, "Log Rank P-value : " + str(round(p_value, 4)),
            transform=ax.transAxes, bbox={'facecolor': 'lightgray'})
    ax.text(0.05, 0.15, "C-index : " + str(round(c_index, 2)),
            transform=ax.transAxes, bbox={'facecolor': 'lightgray'})
    add_at_risk_counts(kmf, kmf2, ax=ax)
    fig.savefig(out_path)
    plt.show()
    plt.close(fig)  # do not accumulate open figures across the loop


os.makedirs(FIG_DIR, exist_ok=True)  # savefig fails if the directory does not exist
Y = _load_outcome()  # identical for every dataset: read once
n_runs = len(datasets) * len(survival_algs)
records = []
n = 0
for dataset in datasets:
    X = _load_features(dataset)

    # Split, scale and reduce once per dataset. The fixed random_state made these steps identical
    # for every model anyway. Scaler and PCA are fitted on the training split only.
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=101)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    # PCA cannot keep more components than min(n_samples, n_features).
    pca = PCA(n_components=min(N_PCA_COMPONENTS, *X_train.shape))
    X_train = pca.fit_transform(X_train)
    X_test = pca.transform(X_test)

    for name, (reg, param_space) in survival_algs.items():
        n += 1

        grid_search = GridSearchCV(estimator=reg, param_grid=param_space, cv=5)
        grid_search.fit(X_train, y_train)
        best_regressor = grid_search.best_estimator_  # refitted on the whole training split

        train_scores = cross_validate(best_regressor, X_train, y_train, cv=5,
                                      scoring=score_survival_model)
        test_test_score = best_regressor.score(X_test, y_test)

        # Training risks are out-of-fold predictions. Held-out risks come from the model fitted on
        # the training split: cross_val_predict on X_test would re-fit on the test labels (leakage).
        train_risk_scores = _as_risk(best_regressor,
                                     cross_val_predict(best_regressor, X_train, y_train, cv=5))
        test_risk_scores = _as_risk(best_regressor, best_regressor.predict(X_test))

        # The median cut-off is learned on the training split and applied unchanged to the test split.
        train_threshold = np.percentile(train_risk_scores, 50)
        train_high_risk = y_train[train_risk_scores >= train_threshold]
        train_low_risk = y_train[train_risk_scores < train_threshold]
        test_high_risk = y_test[test_risk_scores >= train_threshold]
        test_low_risk = y_test[test_risk_scores < train_threshold]

        if len(train_low_risk) == 0 or len(train_high_risk) == 0:
            continue  # e.g. all training risks tied: there are no groups to compare

        results = _logrank(train_low_risk, train_high_risk)
        # The held-out split can end up on one side of the training cut-off; report NaN then.
        if len(test_low_risk) > 0 and len(test_high_risk) > 0:
            test_p_value = _logrank(test_low_risk, test_high_risk).p_value
        else:
            test_p_value = float('nan')

        cv_c_index = train_scores['test_score']
        print(",".join(str(value) for value in (
            n * 100 / n_runs, dataset, name, cv_c_index.mean(), cv_c_index.std(),
            results.p_value, test_test_score, test_p_value, train_scores)))
        records.append({
            'dataset': dataset,
            'model': name,
            'cv_c_index_mean': cv_c_index.mean(),
            'cv_c_index_std': cv_c_index.std(),
            'train_logrank_p': results.p_value,
            'test_c_index': test_test_score,
            'test_logrank_p': test_p_value,
            'best_params': grid_search.best_params_,
        })

        _plot_km(train_high_risk, train_low_risk, dataset, results.p_value, cv_c_index.mean(),
                 os.path.join(FIG_DIR, dataset + "-" + name + ".png"))

# Build the results table once instead of growing it row by row.
runs_df = pd.DataFrame(records)
runs_df