# Cox Proportional Hazards Model on Clinical Covariates

In [3]:
import os
import pandas as pd
import numpy as np
from lifelines import CoxPHFitter
import seaborn as sns
import matplotlib.pyplot as plt
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw
from sksurv.util import Surv

In [4]:
import warnings
# Suppress every warning for the rest of the session (presumably to keep
# lifelines / sksurv / pandas output readable in the notebook).
# NOTE(review): a blanket "ignore" also hides pandas SettingWithCopyWarning and
# any model-fitting warnings; consider narrowing the filter by category/module.
warnings.filterwarnings("ignore")

In [38]:
def coxPF_fit(data, data_test, label="survival_time", event="event"):
    """ Fit a Cox PH model on ``data`` and evaluate it on ``data_test``.

    Every column other than ``label`` and ``event`` is used as a covariate.
    Rows with any missing value are dropped from both frames (complete-case
    analysis) before fitting and scoring.

    Args:
        data: Training DataFrame with covariates, duration and event columns.
        data_test: Test DataFrame with the same columns as ``data``.
        label: Name of the duration column.
        event: Name of the event-indicator column (truthy = event observed).

    Returns:
        dict with keys:
            ``c_index``: Harrell's C from ``concordance_index_censored``.
            ``c_index_ipcw``: IPCW C-index from ``concordance_index_ipcw``,
                with the censoring distribution estimated on the training
                outcomes.
            ``actual_times`` / ``actual_events``: test outcome columns after
                ``dropna``.
            ``predicted_scores``: test partial hazards (higher = higher
                predicted risk).
    """
    cph = CoxPHFitter()

    # Complete-case analysis: drop rows with any missing covariate/outcome.
    data = data.dropna()
    data_test = data_test.dropna()

    cph.fit(data, duration_col=label, event_col=event)

    # Honour the caller-supplied column names instead of hard-coded ones.
    actual_times = data_test[label]
    actual_events = data_test[event]
    # sksurv's concordance metrics expect a *risk* score (higher = earlier
    # event). The partial hazard already is one; negating it would report
    # roughly 1 - C instead of C.
    predicted_scores = cph.predict_partial_hazard(data_test)

    event_indicator = actual_events.astype(bool)
    c_index = concordance_index_censored(
        event_indicator, actual_times, predicted_scores
    )[0]

    # IPCW C-index (requires sksurv structured arrays).
    # NOTE(review): sksurv may raise if test times exceed the range where the
    # training censoring distribution is positive; consider passing `tau`.
    y_test_struct = Surv.from_arrays(event=event_indicator, time=actual_times)
    y_train_struct = Surv.from_arrays(event=data[event].astype(bool), time=data[label])
    c_index_ipcw = concordance_index_ipcw(y_train_struct, y_test_struct, predicted_scores)[0]

    return {
        "c_index": c_index,
        "c_index_ipcw": c_index_ipcw,
        "actual_times": actual_times,
        "actual_events": actual_events,
        "predicted_scores": predicted_scores
    }


# Roman-numeral part of an AJCC stage label -> ordinal stage.
_STAGE_NUMERALS = {'I': 1, 'II': 2, 'III': 3, 'IV': 4}


def _stage_to_int(value):
    """Map an AJCC stage label such as ``'Stage IIIC'`` to 1-4, else NaN.

    Sub-stage suffixes after the numeral (letters A-C, digits) are ignored, so
    e.g. ``'Stage IIIC'`` or ``'Stage IVA'`` are kept instead of silently
    becoming NaN. Placeholders such as ``'Stage X'``, ``'[Discrepancy]'`` or
    ``'[Not Available]'``, and non-string values, map to NaN.
    """
    if not isinstance(value, str) or not value.startswith('Stage '):
        return np.nan
    numeral = value[len('Stage '):].strip().rstrip('ABC0123456789')
    return _STAGE_NUMERALS.get(numeral, np.nan)


def prepare_df_clinical_variables(data_df, with_grade):
    """ Select, rename and numerically encode the clinical covariates.

    Args:
        data_df: Raw split DataFrame containing at least ``birth_days_to``,
            ``dss_survival_days``, ``dss_censorship``, ``sex`` and
            ``ajcc_pathologic_tumor_stage`` (plus ``histological_grade`` when
            ``with_grade`` is True). It is not modified.
        with_grade: Whether to also keep and encode ``histological_grade``.

    Returns:
        A new DataFrame with columns ``age``, ``survival_time``, ``event``,
        ``sex``, ``stage`` (and ``grade``):
            ``age``: ``-birth_days_to`` (age in days, assuming birth_days_to
                is the usual negative offset -- TODO confirm).
            ``event``: 0 where ``dss_censorship == 1`` (censored), else 1.
            ``sex``: 0 for ``'F'``/``'FEMALE'`` (case-insensitive), else 1.
            ``stage``: 1-4 parsed from the AJCC label.
            ``grade``: 1-4 from G1-G4 / Low (1) / High (2) grade.
        Missing or unrecognised values become NaN, so that the later
        ``dropna`` excludes those rows instead of mis-coding them.
    """
    list_covariates = ['birth_days_to', 'dss_survival_days', 'dss_censorship', 'sex', 'ajcc_pathologic_tumor_stage']
    column_dict = {"birth_days_to": "age", "dss_survival_days": "survival_time", "dss_censorship": "event", "ajcc_pathologic_tumor_stage": "stage"}

    # If grade is available
    if with_grade:
        list_covariates.append('histological_grade')
        column_dict['histological_grade'] = 'grade'

    # Explicit copy so the assignments below never touch the caller's frame.
    data_df = data_df[list_covariates].copy().rename(columns=column_dict)

    data_df['stage'] = data_df['stage'].map(_stage_to_int)
    # Event is the opposite of censorship; a missing censorship stays missing
    # instead of being counted as an observed event.
    data_df['event'] = data_df['event'].apply(
        lambda x: np.nan if pd.isna(x) else (0 if x == 1 else 1)
    )
    data_df['age'] = data_df['age'] * -1
    # Missing sex stays missing instead of being coded as male (1).
    data_df['sex'] = data_df['sex'].apply(
        lambda x: np.nan if pd.isna(x)
        else (0 if str(x).strip().upper() in ('F', 'FEMALE') else 1)
    )

    if with_grade:
        grade_mapping = {
            'Low Grade': 1,
            'High Grade': 2,
            'G1': 1,
            'G2': 2,
            'G3': 3,
            'G4': 4,
            'GX': np.nan,
            '[Not Available]': np.nan,
            '[Unknown]': np.nan,
        }
        data_df['grade'] = data_df['grade'].map(grade_mapping)

    return data_df
    
def visualize_df(clinical_df, name, with_grade=False):
    """ Plot the distribution of each clinical variable, one figure per column.

    Columns with at most 10 distinct values are drawn as count plots; all
    others as histograms with a KDE overlay.

    Args:
        clinical_df: DataFrame as produced by ``prepare_df_clinical_variables``.
        name: Label for the data being plotted; used in the printout and titles.
        with_grade: Also plot the ``grade`` column.
    """
    print("Visualizing " + name)
    covariates = ['sex', 'age', 'survival_time', 'event', 'stage']
    if with_grade:
        covariates.append('grade')
    for column in covariates:
        plt.figure(figsize=(8, 4))
        if clinical_df[column].nunique() <= 10:  # Categorical or ordinal variable
            sns.countplot(x=column, hue=column, data=clinical_df, palette='coolwarm', legend=False)
        else:  # Continuous variable
            sns.histplot(clinical_df[column], kde=True, color='blue')
        # Same title for both plot kinds so every figure names its dataset.
        plt.title(f"Distribution of {column} for {name}")
        plt.xlabel(column)
        # histplot's default stat is "count", so both branches show counts
        # (the histogram axis was previously mislabelled "Density").
        plt.ylabel("Count")
        plt.show()


In [39]:
def CoxPH_per_fold(fold, data_source, keys, visualize=False,
                   train_file='train_old.csv', test_file='test_old.csv'):
    """ Fit and evaluate one Cox PH model per covariate set on a single fold.

    Args:
        fold: Fold identifier; the fold's files live in ``<data_source>/<fold>/``.
        data_source: Directory containing one sub-directory per fold.
        keys: Model names, each an underscore-joined list of covariate
            columns (e.g. ``"sex_age_stage"``).
        visualize: If True, plot the train and test covariate distributions.
        train_file: File name of the training split inside the fold folder.
        test_file: File name of the test split inside the fold folder.

    Returns:
        dict mapping each model name to the result dict of ``coxPF_fit``.
    """
    fold_folder = os.path.join(data_source, str(fold))

    model_types = [item.split('_') for item in keys]
    # Load and encode the grade column only if some model actually uses it.
    with_grade = any("grade" in covariates for covariates in model_types)

    clinical_df_train = prepare_df_clinical_variables(
        pd.read_csv(os.path.join(fold_folder, train_file)), with_grade)
    clinical_df_test = prepare_df_clinical_variables(
        pd.read_csv(os.path.join(fold_folder, test_file)), with_grade)

    if visualize:
        # Forward with_grade so a loaded grade column is plotted too.
        visualize_df(clinical_df_train, f"Train data on fold {fold}", with_grade=with_grade)
        visualize_df(clinical_df_test, f"Test data on fold {fold}", with_grade=with_grade)

    # Run model for each model type
    final_results = {}
    for covariates in model_types:
        columns = covariates + ['survival_time', 'event']
        final_results['_'.join(covariates)] = coxPF_fit(
            clinical_df_train[columns], clinical_df_test[columns])

    return final_results

def compute_clinical_results(data_source_path, with_grade=False, n_folds=5):
    """ Cross-validate clinical-covariate Cox PH models and print the C-indices.

    For every fold 0..n_folds-1, fits the models "sex_age" and
    "sex_age_stage". When with_grade is set, "sex_age_grade" and
    "sex_age_grade_stage" are fitted as well. Prints mean ± (population)
    standard deviation of the standard and IPCW C-index per model.

    Args:
        data_source_path: Directory containing one sub-directory per fold.
        with_grade: Also evaluate the models that include tumour grade.
        n_folds: Number of folds to evaluate (folders ``0`` .. ``n_folds-1``).

    Returns:
        Tuple ``(c_index_dict, c_index_adj_dict)`` mapping each model name to
        its per-fold C-index and IPCW C-index values.
    """
    model_keys = ["sex_age", "sex_age_stage"]
    if with_grade:
        model_keys += ["sex_age_grade", "sex_age_grade_stage"]

    c_index_dict = {key: [] for key in model_keys}
    c_index_adj_dict = {key: [] for key in model_keys}

    for fold in range(n_folds):
        results = CoxPH_per_fold(fold, data_source_path, model_keys)
        for key in model_keys:
            c_index_dict[key].append(results[key]["c_index"])
            c_index_adj_dict[key].append(results[key]["c_index_ipcw"])

    print('Results of the clinical covariate Cox PH model.')
    for key in model_keys:
        print("Model type: ", key)
        c_index_values = np.array(c_index_dict[key])
        c_index_adj_values = np.array(c_index_adj_dict[key])
        print(f"C-index: {np.mean(c_index_values):.3f}±{np.std(c_index_values):.3f}")
        print(f"C-index IPCW: {np.mean(c_index_adj_values):.3f}±{np.std(c_index_adj_values):.3f}")

    return c_index_dict, c_index_adj_dict



# Compute clinical results on different datasets

In [None]:
# Evaluate the clinical Cox PH models on every TCGA cohort; the flag says
# whether histological grade is available (and therefore modelled) for it.
for data_type, with_grade in (
    ('tcga_brca', False),
    ('tcga_blca', True),
    ('tcga_luad', False),
    ('tcga_kirc', True),
):
    print(f"Data type: {data_type}")
    data_source = f'../data/data_files/{data_type}/splits/'
    compute_clinical_results(data_source, with_grade=with_grade)

# Visualize data

In [41]:
def obtain_all_fold_data(fold_folder, with_grade=True,
                         train_file='train.csv', test_file='test.csv'):
    """ Load one fold's train and test splits and stack them into one frame.

    Args:
        fold_folder: Directory containing the split CSV files.
        with_grade: Passed to ``prepare_df_clinical_variables``. The default
            True requires a ``histological_grade`` column. NOTE(review): pass
            False for cohorts whose CSVs may lack it.
        train_file: File name of the training split.
        test_file: File name of the test split.

    Returns:
        DataFrame of the encoded train rows followed by the test rows, with a
        fresh 0..n-1 index.
    """
    frames = [
        prepare_df_clinical_variables(
            pd.read_csv(os.path.join(fold_folder, file_name)), with_grade)
        for file_name in (train_file, test_file)
    ]
    # ignore_index=True already produces a clean sequential index.
    return pd.concat(frames, ignore_index=True)

In [None]:
# Plot the stacked train+test clinical variables of fold 0 for each cohort;
# grade is only plotted where the flag is True.
for data_type, with_grade in (
    ('tcga_brca', False),
    ('tcga_blca', True),
    ('tcga_luad', False),
    ('tcga_kirc', True),
):
    result_dir_source = f'../data/data_files/{data_type}/splits/0/'
    data_df = obtain_all_fold_data(result_dir_source)
    visualize_df(data_df, data_type, with_grade)