In [None]:
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import font_manager as fm, rcParams
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

geo_path = "./dataset/GEO_clinical_genes.xlsx"
tcga_path = "./dataset/TCGA_clinical_genes.xlsx"  # NOTE(review): not read in the visible cells
df = pd.read_excel(geo_path)

# Binary event indicator for the Cox model: 1 where CSS == 'Dead', 0 for every other value.
# NOTE(review): anything other than 'Dead' (including a missing CSS) ends up as censored —
# confirm the column only holds the expected labels.
df['event'] = (df['CSS'] == 'Dead').astype('int64')
df = df.drop(columns=['CSS'])

# Candidate covariates: all raw (pre-encoding) column names except the two outcome columns.
covariates = df.columns.tolist()
for outcome_col in ('Survival_months', 'event'):
    covariates.remove(outcome_col)

# One-hot encode text/categorical columns; drop_first removes the redundant reference level.
categorical_vars = list(df.select_dtypes(include=['object', 'category']).columns)
df = pd.get_dummies(df, columns=categorical_vars, drop_first=True)

In [None]:
# Hold out 20% of the rows as a test set; fixed seed so the split is reproducible.
train_df, test_df = train_test_split(df, random_state=42, test_size=0.2)
for split_frame in (train_df, test_df):
    print(split_frame.shape)

In [None]:
# Grid-search the elastic-net penalty of the Cox model.
# Each (penalizer, l1_ratio) candidate is scored by its mean C-index on held-out folds of the
# TRAINING split. The previous version ranked candidates by `cph.concordance_index_`, i.e. the
# C-index on the same data each model was fit to; that in-sample score favours the weakest
# penalty and says nothing about generalisation. The test split is not used here.
penalizer_values = [0.0001, 0.001, 0.01, 0.1]
l1_ratios = [0, 0.01, 0.05, 0.1, 0.2]
n_cv_folds = 5

# Stratify on the event indicator so every validation fold contains events (the C-index needs
# comparable pairs), and build the folds once so all candidates see identical splits.
cv_folds = list(
    StratifiedKFold(n_splits=n_cv_folds, shuffle=True, random_state=42)
    .split(train_df, train_df['event'])
)

best_c_index = -np.inf
best_params = None

for pen in penalizer_values:
    for l1 in l1_ratios:
        fold_c_indices = []
        for fit_idx, val_idx in cv_folds:
            fit_part = train_df.iloc[fit_idx]
            val_part = train_df.iloc[val_idx]
            cv_model = CoxPHFitter(penalizer=pen, l1_ratio=l1)
            cv_model.fit(fit_part, duration_col='Survival_months', event_col='event')
            # Higher partial hazard = higher risk = shorter survival, hence the minus sign.
            fold_c_indices.append(concordance_index(
                val_part['Survival_months'],
                -cv_model.predict_partial_hazard(val_part),
                val_part['event'],
            ))
        c_index = float(np.mean(fold_c_indices))

        if c_index > best_c_index:
            best_c_index = c_index
            best_params = (pen, l1)

if best_params is None:
    raise RuntimeError("No penalizer/l1_ratio combination produced a valid cross-validated C-index.")

print(f"penalizer={best_params[0]}, l1_ratio={best_params[1]}, CV C-index={best_c_index:.4f}")

In [None]:
# Refit on the whole training split with the selected penalty settings.
pen_best, l1_best = best_params
cph = CoxPHFitter(penalizer=pen_best, l1_ratio=l1_best)
cph.fit(train_df, duration_col='Survival_months', event_col='event')

cph.print_summary()


def _harrell_c(frame):
    """C-index of the fitted `cph` on `frame`; partial hazard is negated because higher hazard means shorter survival."""
    risk = cph.predict_partial_hazard(frame)
    return concordance_index(frame['Survival_months'], -risk, frame['event'])


train_c_index = _harrell_c(train_df)
print(f"train_c_index: {train_c_index:.4f}")

test_c_index = _harrell_c(test_df)
print(f"c_index_test: {test_c_index:.4f}")

In [None]:
# Coefficients ranked by p-value, smallest first.
# NOTE(review): these come from a penalized fit; treat the p-values as descriptive rather than
# as formal significance tests.
feature_importance = cph.summary.loc[:, ['p', 'coef']].sort_values('p')
print("feature_importance:")
print(feature_importance)

In [None]:
# Time-dependent (cumulative/dynamic) ROC curves on the test set at fixed horizons.
# Cases at horizon t: event observed at or before t. Controls: still under follow-up after t.
# Patients censored at or before t without an event have unknown status at t and are excluded;
# the previous version labelled them as controls, which biases the AUC.
# NOTE(review): under heavy censoring an IPCW estimator (e.g. scikit-survival's
# cumulative_dynamic_auc) is less biased than simple exclusion.
time_points = [12, 36, 60]

pred_surv = cph.predict_survival_function(test_df)
test_times = test_df['Survival_months'].to_numpy()
test_events = test_df['event'].to_numpy() == 1

# Predicted S(t) per (patient, horizon): one row per column of pred_surv (assumed to follow
# test_df row order, as the original alignment also assumed), one column per horizon.
# Values are linearly interpolated between the fitted time grid points.
surv_at_t = np.array([
    np.interp(time_points, pred_surv.index, pred_surv[col].values)
    for col in pred_surv.columns
])

plt.figure(figsize=(8, 6), dpi=500)

for j, t in enumerate(time_points):
    risk_scores = 1 - surv_at_t[:, j]

    is_case = (test_times <= t) & test_events
    known_status = is_case | (test_times > t)
    y_true = is_case[known_status].astype(int)

    n_cases = int(y_true.sum())
    if n_cases == 0 or n_cases == len(y_true):
        # ROC is undefined without both cases and controls at this horizon.
        print(f"{t} months: skipped ({n_cases} cases, {len(y_true) - n_cases} controls)")
        continue

    fpr, tpr, _ = roc_curve(y_true, risk_scores[known_status])
    roc_auc = auc(fpr, tpr)

    plt.plot(fpr, tpr, label=f"{t} months (AUC = {roc_auc:.2f})")

plt.plot([0, 1], [0, 1], 'k--', lw=1)
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Cox Survival Analysis: 1-year, 3-year, and 5-year ROC Curves')
plt.legend(loc='lower right')
plt.show()

In [None]:
# Test-set AUC of the time-dependent ROC for every month from t_min to t_max.
# Cases at month t: event observed at or before t. Controls: still under follow-up after t.
# Patients censored at or before t without an event have unknown status at t and are excluded;
# the previous version labelled them as controls, which biases the AUC.
# NOTE(review): under heavy censoring an IPCW estimator (e.g. scikit-survival's
# cumulative_dynamic_auc) is less biased than simple exclusion.
t_min = 1
t_max = 60
t_points = np.arange(t_min, t_max + 1)
auc_values = []

pred_surv = cph.predict_survival_function(test_df)
test_times = test_df['Survival_months'].to_numpy()
test_events = test_df['event'].to_numpy() == 1

# Interpolate S(t) for all months at once (hoisted out of the month loop): one row per column
# of pred_surv (assumed to follow test_df row order), one column per month in t_points.
surv_at_t = np.array([
    np.interp(t_points, pred_surv.index, pred_surv[col].values)
    for col in pred_surv.columns
])

for j, t in enumerate(t_points):
    risk_scores = 1 - surv_at_t[:, j]

    is_case = (test_times <= t) & test_events
    known_status = is_case | (test_times > t)
    y_true = is_case[known_status].astype(int)

    n_cases = int(y_true.sum())
    if n_cases == 0 or n_cases == len(y_true):
        # AUC is undefined without both cases and controls at this month.
        auc_values.append(np.nan)
    else:
        fpr, tpr, _ = roc_curve(y_true, risk_scores[known_status])
        auc_values.append(auc(fpr, tpr))

finite_auc = np.asarray(auc_values, dtype=float)
finite_auc = finite_auc[np.isfinite(finite_auc)]

plt.figure(figsize=(8, 6), dpi=500)
# Keep the original 0.6-0.9 window, but widen it if any AUC falls outside it; a fixed window
# silently hid such points.
if finite_auc.size:
    plt.ylim([min(0.6, finite_auc.min() - 0.02), max(0.9, finite_auc.max() + 0.02)])
else:
    plt.ylim([0.6, 0.9])
plt.plot(t_points, auc_values, marker='o', linestyle='-')
plt.xlabel('Time (months)')
plt.ylabel('AUC')
plt.title('Cox Survival Analysis: AUC Over Time')
plt.grid(True)
plt.show()

In [None]:
# Persist the per-month test-set AUC values computed above (NaN where AUC was undefined).
# NOTE(review): the `rsf_` prefix looks copied from a random-survival-forest notebook; these are
# Cox AUCs. The name is kept in case later cells reference it.
rsf_auc_df = pd.DataFrame({
    'Month': t_points,
    'Cox_AUC': auc_values
})

output_dir = Path("./Log")
output_dir.mkdir(parents=True, exist_ok=True)  # to_excel does not create missing directories
rsf_auc_df.to_excel(output_dir / "Cox_AUC_by_month.xlsx", index=False)