In [1]:
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored
from sksurv.nonparametric import kaplan_meier_estimator


In [2]:
class CoxPipeline:
    """Fit and evaluate a Cox proportional-hazards model with scikit-survival.

    Workflow (see ``run_pipeline``): stratified train / validation / test split,
    stratified K-fold concordance on the training portion, refit on the training
    portion and score the held-out test portion.

    Parameters
    ----------
    alpha : float
        Regularisation strength passed to ``CoxPHSurvivalAnalysis``.
    n_splits : int
        Number of folds used by ``StratifiedKFold`` in ``cross_validate``.
    random_state : int
        Seed shared by both ``train_test_split`` calls and the fold shuffling.
    """

    def __init__(self, alpha=0.5, n_splits=5, random_state=42):
        self.alpha = alpha
        self.n_splits = n_splits
        self.random_state = random_state
        self.model = None
        # NOTE(review): this silences *every* warning in the process, not only
        # those raised by this pipeline; kept for backward compatibility.
        warnings.filterwarnings("ignore")

    @staticmethod
    def _to_survival_array(frame, target, duration):
        """Build the structured ``(event, time)`` array scikit-survival expects.

        The event field ``target`` is cast to ``bool`` and the time field
        ``duration`` to ``float`` -- an ``int`` cast would silently truncate
        fractional durations and introduce artificial ties.
        """
        y = np.empty(len(frame), dtype=[(target, bool), (duration, float)])
        y[target] = frame[target].to_numpy(dtype=bool)
        y[duration] = frame[duration].to_numpy(dtype=float)
        return y

    @staticmethod
    def _time_field(y, target):
        """Return the name of the field of structured array ``y`` that is not ``target``."""
        for name in y.dtype.names:
            if name != target:
                return name
        raise ValueError(f"structured array has no time field besides {target!r}")

    def split_data(self, data, features, target, duration):
        """Split ``data`` into stratified train / validation / test sets.

        20 % of rows are held out as the test set; the remaining rows are split
        again 80/20 into train and validation. Both splits are stratified on the
        ``target`` (event) column.

        Returns
        -------
        tuple
            ``(X_train, y_train, X_val, y_val, X_test, y_test)``: the ``X_*`` are
            ``data[features]`` slices, the ``y_*`` structured arrays with fields
            ``target`` (bool) and ``duration`` (float).
        """
        train_val_data, test_data = train_test_split(
            data, test_size=0.2, stratify=data[target], random_state=self.random_state
        )
        train_data, val_data = train_test_split(
            train_val_data,
            test_size=0.2,
            stratify=train_val_data[target],
            random_state=self.random_state,
        )

        return (
            train_data[features],
            self._to_survival_array(train_data, target, duration),
            val_data[features],
            self._to_survival_array(val_data, target, duration),
            test_data[features],
            self._to_survival_array(test_data, target, duration),
        )

    def cross_validate(self, X_train, y_train, target):
        """Print and return the mean stratified K-fold concordance index.

        ``target`` names the event field of ``y_train``; the time field is the
        other field of the structured array, so no notebook-global variable is
        needed.

        Returns
        -------
        float
            Mean of the per-fold ``concordance_index_censored`` values.
        """
        duration = self._time_field(y_train, target)
        skf = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
        concordance_indices = []

        for fold, (train_idx, val_idx) in enumerate(skf.split(X_train, y_train[target]), start=1):
            estimator = CoxPHSurvivalAnalysis(alpha=self.alpha)
            estimator.fit(X_train.iloc[train_idx], y_train[train_idx])

            y_val_fold = y_train[val_idx]
            prediction = estimator.predict(X_train.iloc[val_idx])
            c_index = concordance_index_censored(y_val_fold[target], y_val_fold[duration], prediction)[0]
            concordance_indices.append(c_index)

            print(f"Fold {fold} Concordance Index: {c_index}")

        average_concordance_index = np.mean(concordance_indices)
        print("Average Concordance Index:", average_concordance_index)
        return average_concordance_index

    def train_and_evaluate(self, X_train, y_train, X_test, y_test, target=None, duration=None):
        """Fit ``self.model`` on the training data and score it on the test data.

        ``target`` / ``duration`` name the event and time fields of ``y_test``.
        When omitted, the first field is taken as the event and the remaining
        field as the time, matching the layout produced by ``split_data``.

        Returns
        -------
        tuple
            ``(test_concordance_index, coefficients)`` where ``coefficients`` is
            a Series of fitted coefficients indexed by feature name.
        """
        if target is None:
            target = y_test.dtype.names[0]
        if duration is None:
            duration = self._time_field(y_test, target)

        self.model = CoxPHSurvivalAnalysis(alpha=self.alpha)
        self.model.fit(X_train, y_train)

        test_predictions = self.model.predict(X_test)
        test_c_index = concordance_index_censored(y_test[target], y_test[duration], test_predictions)[0]
        print("Test Data Concordance Index:", test_c_index)

        coefficients = pd.Series(self.model.coef_, index=X_test.columns)
        print(coefficients)

        return test_c_index, coefficients

    def plot_kaplan_meier(self, data, target, duration):
        """Plot the Kaplan-Meier survival estimate of ``data`` with a 95 % CI band."""
        time = data[duration].to_numpy(dtype=float)
        events = data[target].to_numpy(dtype=bool)

        # kaplan_meier_estimator returns (time, survival, conf_int) -- time first --
        # and conf_int has shape (2, n_times): row 0 lower bound, row 1 upper bound.
        time_points, survival_probabilities, conf_int = kaplan_meier_estimator(
            events, time, conf_type="log-log"
        )
        lower, upper = conf_int

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.step(time_points, survival_probabilities, where='post', label='Kaplan-Meier Estimate')
        # step='post' keeps the band aligned with the step curve above.
        ax.fill_between(
            time_points, lower, upper, step='post', color='grey', alpha=0.2, label='95% Confidence Interval'
        )

        ax.set_xlabel('Time')
        ax.set_ylabel('Survival Probability')
        ax.set_title('Kaplan-Meier Survival Estimate')
        ax.legend()
        ax.grid(True)
        plt.show()

    def run_pipeline(self, data, features, target, duration, plot_km=False):
        """Split, cross-validate, fit and test; optionally plot the KM curve.

        Returns ``(test_concordance_index, coefficients)``.
        """
        X_train, y_train, _X_val, _y_val, X_test, y_test = self.split_data(data, features, target, duration)
        # NOTE(review): the validation split is not used here, so CV and the final
        # fit see only ~64 % of the rows; consider folding it back into training.
        self.cross_validate(X_train, y_train, target)
        test_concordance, coefficients = self.train_and_evaluate(
            X_train, y_train, X_test, y_test, target=target, duration=duration
        )
        if plot_km:
            self.plot_kaplan_meier(data, target, duration)
        return test_concordance, coefficients


In [3]:
def analyze_file(data):
    """Show a quick overview of a DataFrame.

    Renders the frame with the notebook's ``display`` and then prints three
    summaries, each under its own heading: the shape, the per-column dtypes
    and the per-column count of missing values.
    """
    print("Dataset:")
    display(data)

    overview = [
        ("\nShape of the dataset (rows, columns):", data.shape),
        ("\nData types of each column:", data.dtypes),
        ("\nCount of missing values in each column:", data.isnull().sum()),
    ]
    for heading, summary in overview:
        print(heading)
        print(summary)
def get_high_correlation_features(df, target_columns, threshold=0.9):
    """
    Get features whose absolute Pearson correlation with any target exceeds ``threshold``.

    Parameters:
    - df (pd.DataFrame): The input DataFrame with features and target variables.
      Non-numeric columns (e.g. identifiers) are ignored.
    - target_columns (str or list of str): Target variable column name(s).
    - threshold (float): Strict lower bound on ``abs(correlation)``.

    Returns:
    - dict: Maps each qualifying feature to its correlation with the *first*
      target (in ``target_columns`` order) that crosses the threshold. The sign
      of the correlation is preserved. Keys follow the DataFrame's column order.

    Raises:
    - KeyError: if a target column is missing from ``df`` or is not numeric.
    """
    # A bare string would otherwise be treated as a sequence of characters.
    if isinstance(target_columns, str):
        target_columns = [target_columns]
    targets = list(target_columns)

    # numeric_only=True: pandas >= 2.0 raises on object columns such as IDs.
    corr_matrix = df.corr(numeric_only=True)

    missing = [target for target in targets if target not in corr_matrix.columns]
    if missing:
        raise KeyError(f"Target column(s) missing or non-numeric: {missing}")

    high_corr_features = {}
    for column in corr_matrix.columns:
        if column in targets:
            continue
        for target in targets:
            correlation_value = corr_matrix.loc[column, target]
            # NaN correlations (e.g. constant columns) compare False and are skipped.
            if abs(correlation_value) > threshold:
                high_corr_features[column] = correlation_value
                break  # first target over the threshold wins

    return high_corr_features

def show_correlation_heatmap(df, top_n_columns):
    """
    Displays a correlation matrix heatmap for the specified top_n_columns along with 'OS' and 'OS.time'.
    
    Parameters:
    df (pd.DataFrame): The DataFrame containing the data.
    top_n_columns (list): A list of column names for which the correlation matrix will be displayed.
    """
    # Ensure 'OS' and 'OS.time' are included in the correlation matrix
    columns_to_include = top_n_columns + ['OS', 'OS.time']
    
    # Select only the specified columns, including 'OS' and 'OS.time'
    selected_data = df[columns_to_include]
    
    # Compute the correlation matrix
    correlation_matrix = selected_data.corr()
    
    # Plot heatmap
    plt.figure(figsize=(24, 22))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', linewidths=0.5)
    plt.title(f'Correlation Matrix Heatmap for Top {len(top_n_columns)} Features + OS and OS.time')
    plt.show()



In [4]:
# Set the TCGA_BRAIN_CSV environment variable to load the dataset from another
# location; the default is the author's local copy of the cleaned CSV.
DATA_PATH = Path(os.environ.get(
    "TCGA_BRAIN_CSV",
    "/Users/simrantanwar/Desktop/College/DDP/survival_analysis/data/tcga_brain_cleaned_26_10_2024.csv",
))
data = pd.read_csv(DATA_PATH)
analyze_file(data)

Dataset:


Unnamed: 0.1,Unnamed: 0,Patient ID,age_at_initial_pathologic_diagnosis,initial_pathologic_dx_year,gender_MALE,race_WHITE,histological_type_Oligoastrocytoma,histological_type_Oligodendroglioma,histological_type_Untreated primary (de novo) GBM,treatment_outcome_first_course_Complete Remission/Response,treatment_outcome_first_course_Partial Remission/Response,treatment_outcome_first_course_Progressive Disease,OS,OS.time
0,0,TCGA-02-0047,28489.50,2005.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,448.0
1,1,TCGA-02-0055,22645.50,2005.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,76.0
2,2,TCGA-02-2483,15705.75,2008.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,466.0
3,3,TCGA-02-2485,19358.25,2009.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,470.0
4,4,TCGA-02-2486,23376.00,2008.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,618.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
661,661,TCGA-WY-A85A,7305.00,2010.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1320.0
662,662,TCGA-WY-A85B,8766.00,2010.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1393.0
663,663,TCGA-WY-A85C,13149.00,2010.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1426.0
664,664,TCGA-WY-A85D,21915.00,2010.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1147.0



Shape of the dataset (rows, columns):
(666, 14)

Data types of each column:
Unnamed: 0                                                      int64
Patient ID                                                     object
age_at_initial_pathologic_diagnosis                           float64
initial_pathologic_dx_year                                    float64
gender_MALE                                                   float64
race_WHITE                                                    float64
histological_type_Oligoastrocytoma                            float64
histological_type_Oligodendroglioma                           float64
histological_type_Untreated primary (de novo) GBM             float64
treatment_outcome_first_course_Complete Remission/Response    float64
treatment_outcome_first_course_Partial Remission/Response     float64
treatment_outcome_first_course_Progressive Disease            float64
OS                                                            float64
OS.time      

In [5]:
# Drop the leftover CSV index column and the patient identifier -- neither is a
# model feature. errors='ignore' keeps the cell idempotent: re-running it after
# the columns are gone no longer raises KeyError.
data = data.drop(columns=['Unnamed: 0', 'Patient ID'], errors='ignore')
analyze_file(data)

Dataset:


Unnamed: 0,age_at_initial_pathologic_diagnosis,initial_pathologic_dx_year,gender_MALE,race_WHITE,histological_type_Oligoastrocytoma,histological_type_Oligodendroglioma,histological_type_Untreated primary (de novo) GBM,treatment_outcome_first_course_Complete Remission/Response,treatment_outcome_first_course_Partial Remission/Response,treatment_outcome_first_course_Progressive Disease,OS,OS.time
0,28489.50,2005.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,448.0
1,22645.50,2005.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,76.0
2,15705.75,2008.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,466.0
3,19358.25,2009.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,470.0
4,23376.00,2008.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,618.0
...,...,...,...,...,...,...,...,...,...,...,...,...
661,7305.00,2010.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1320.0
662,8766.00,2010.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1393.0
663,13149.00,2010.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1426.0
664,21915.00,2010.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1147.0



Shape of the dataset (rows, columns):
(666, 12)

Data types of each column:
age_at_initial_pathologic_diagnosis                           float64
initial_pathologic_dx_year                                    float64
gender_MALE                                                   float64
race_WHITE                                                    float64
histological_type_Oligoastrocytoma                            float64
histological_type_Oligodendroglioma                           float64
histological_type_Untreated primary (de novo) GBM             float64
treatment_outcome_first_course_Complete Remission/Response    float64
treatment_outcome_first_course_Partial Remission/Response     float64
treatment_outcome_first_course_Progressive Disease            float64
OS                                                            float64
OS.time                                                       float64
dtype: object

Count of missing values in each column:
age_at_initial_pathologic_di

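# With first-course treatment outcome features

All remaining columns except `OS` (event indicator) and `OS.time` (follow-up time) are used as Cox model features.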

In [6]:
# Select the event and time columns by name rather than by position, so a
# column reordering cannot silently turn a feature into the target.
target = 'OS'
duration = 'OS.time'
column_names = data.columns.tolist()
missing_outcomes = {target, duration} - set(column_names)
assert not missing_outcomes, f"missing outcome columns: {missing_outcomes}"
features = [column for column in column_names if column not in (target, duration)]
print(target)
print(duration)

# Initialize CoxPipeline with default parameters
pipeline = CoxPipeline(alpha=0.5, n_splits=5, random_state=42)

# Run the pipeline (it prints the fold, test concordance and coefficients itself)
test_concordance, coefficients = pipeline.run_pipeline(data, features, target, duration)


OS
OS.time
Fold 1 Concordance Index: 0.8231907894736842
Fold 2 Concordance Index: 0.7923076923076923
Fold 3 Concordance Index: 0.8671140939597315
Fold 4 Concordance Index: 0.8241525423728814
Fold 5 Concordance Index: 0.8551344743276283
Average Concordance Index: 0.8323799184883235
Test Data Concordance Index: 0.8797845194216047
age_at_initial_pathologic_diagnosis                           0.000105
initial_pathologic_dx_year                                   -0.007203
gender_MALE                                                   0.215550
race_WHITE                                                   -0.250820
histological_type_Oligoastrocytoma                            0.028246
histological_type_Oligodendroglioma                          -0.847827
histological_type_Untreated primary (de novo) GBM             1.298000
treatment_outcome_first_course_Complete Remission/Response   -1.394913
treatment_outcome_first_course_Partial Remission/Response    -0.965117
treatment_outcome_first_course_

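# Without first-course treatment outcome features

The three `treatment_outcome_first_course_*` dummy columns are dropped, and the same Cox pipeline is re-run so its concordance can be compared with the model above.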

In [7]:
# First-course treatment-outcome dummies excluded from this model.
columns_to_drop = [
    'treatment_outcome_first_course_Complete Remission/Response',
    'treatment_outcome_first_course_Partial Remission/Response',
    'treatment_outcome_first_course_Progressive Disease',
]
# errors='ignore' keeps the cell idempotent: re-running it after the columns
# are already gone no longer raises KeyError.
data = data.drop(columns=columns_to_drop, errors='ignore')

# Select the event and time columns by name rather than by position.
target = 'OS'
duration = 'OS.time'
column_names = data.columns.tolist()
missing_outcomes = {target, duration} - set(column_names)
assert not missing_outcomes, f"missing outcome columns: {missing_outcomes}"
features = [column for column in column_names if column not in (target, duration)]
print(target)
print(duration)

# Initialize CoxPipeline with default parameters
pipeline = CoxPipeline(alpha=0.5, n_splits=5, random_state=42)

# Run the pipeline (it prints the fold, test concordance and coefficients itself)
test_concordance, coefficients = pipeline.run_pipeline(data, features, target, duration)


OS
OS.time
Fold 1 Concordance Index: 0.7574013157894737
Fold 2 Concordance Index: 0.7713286713286713
Fold 3 Concordance Index: 0.8476510067114094
Fold 4 Concordance Index: 0.8248587570621468
Fold 5 Concordance Index: 0.8508557457212714
Average Concordance Index: 0.8104190993225945
Test Data Concordance Index: 0.8848880068046499
age_at_initial_pathologic_diagnosis                  0.000117
initial_pathologic_dx_year                          -0.038319
gender_MALE                                          0.223296
race_WHITE                                          -0.241569
histological_type_Oligoastrocytoma                  -0.182878
histological_type_Oligodendroglioma                 -0.752344
histological_type_Untreated primary (de novo) GBM    1.138768
dtype: float64
