In [148]:
import pandas as pd
import numpy as np
from sksurv.ensemble import RandomSurvivalForest, GradientBoostingSurvivalAnalysis
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, concordance_index_censored, integrated_brier_score
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sksurv.util import Surv
import matplotlib.pyplot as plt
from sklearn.inspection import permutation_importance

import random

<p>All 5 of the imputed datasets</p>

In [66]:
train1 = pd.read_csv("../data/imputations/final_w_imps_1.csv")
train2 = pd.read_csv("../data/imputations/final_w_imps_2.csv")
train3 = pd.read_csv("../data/imputations/final_w_imps_3.csv")
train4 = pd.read_csv("../data/imputations/final_w_imps_4.csv")
train5 = pd.read_csv("../data/imputations/final_w_imps_5.csv")

train = [train1, train2, train3, train4, train5]

test1 = pd.read_csv("../data/imputations/final_w_imps_1_test.csv")
test2 = pd.read_csv("../data/imputations/final_w_imps_2_test.csv")
test3 = pd.read_csv("../data/imputations/final_w_imps_3_test.csv")
test4 = pd.read_csv("../data/imputations/final_w_imps_4_test.csv")
test5 = pd.read_csv("../data/imputations/final_w_imps_5_test.csv")

test = [test1, test2, test3, test4, test5]

<p>Removing the Pseudo-observation columns</p>

In [67]:
for idx in range(len(train)):
    train[idx] = train[idx].loc[:, ~train[idx].columns.str.startswith("PO_")]

for idx in range(len(test)):
    test[idx] = test[idx].loc[:, ~test[idx].columns.str.startswith("PO_")]

<p>Grouping ethnicities</p>

In [68]:
def collapse_ethnicity(val):
    val = str(val).upper()
    if "ASIAN" in val:
        return "Asian"
    elif "WHITE" in val or "MIDDLE" in val:
        return "White"
    elif "BLACK" in val:
        return "Black or African American"
    elif any(word in val for word in ["OTHER", "PATIENT", "UNABLE", "UNKNOWN"]):
        return "Unknown"
    elif "AMERICAN" in val:
        return "American Indian or Alaska Native"
    elif "HISPANIC" in val:
        return "Hispanic"
    elif "MULTI" in val:
        return "More than one race"
    else:
        return val

# Apply ethnicity collapsing to each dataframe in ds
for i in range(len(train)):
    if 'ethnicity' in train[i].columns:
        train[i]['ethnicity'] = train[i]['ethnicity'].apply(collapse_ethnicity)

for i in range(len(test)):
    if 'ethnicity' in test[i].columns:
        test[i]['ethnicity'] = test[i]['ethnicity'].apply(collapse_ethnicity)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train[i]['ethnicity'] = train[i]['ethnicity'].apply(collapse_ethnicity)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test[i]['ethnicity'] = test[i]['ethnicity'].apply(collapse_ethnicity)


<p>Splitting into X and y for each dataset</p>
<p>Combines into one dataset for onehot encoding before resplitting</p>

In [124]:
def splitXY(train_dfs, test_dfs):
    # Track lengths for re-splitting
    train_lengths = [len(df) for df in train_dfs]
    test_lengths = [len(df) for df in test_dfs]

    # Combine all into one DataFrame (with origin label)
    combined = pd.concat(train_dfs + test_dfs, ignore_index=True)

    # Clip negative times
    combined["survival_days"] = combined["survival_days"].clip(lower=0)
    combined["surv_days_w_mean_imps"] = combined["surv_days_w_mean_imps"].clip(lower=0)

    # Create binary outcome
    combined["survived_90"] = combined["survival_days"] > 90

    # Do one-hot encoding across the entire dataset
    X_all = pd.get_dummies(combined.drop(columns=["subject_id", "survival_days", "survived_90", "event", "surv_days_w_mean_imps", "delta"]))
    y_all = combined["surv_days_w_mean_imps"]
    y_event = combined["event"]

    # Rebuild structured target
    y_struct_all = Surv.from_arrays(event=y_event.astype(bool), time=y_all)

    # Re-split into original pieces
    X_train, y_train, X_test, y_test = [], [], [], []

    train_cumsum = [0] + list(pd.Series(train_lengths).cumsum())
    test_cumsum = [train_cumsum[-1]] + list((pd.Series(test_lengths).cumsum() + train_cumsum[-1]))

    # Slice train
    for i in range(len(train_dfs)):
        X_train.append(X_all.iloc[train_cumsum[i]:train_cumsum[i+1]].reset_index(drop=True))
        y_train.append(y_struct_all[train_cumsum[i]:train_cumsum[i+1]])

    # Slice test
    for i in range(len(test_dfs)):
        X_test.append(X_all.iloc[test_cumsum[i]:test_cumsum[i+1]].reset_index(drop=True))
        y_test.append(y_struct_all[test_cumsum[i]:test_cumsum[i+1]])

    return X_train, y_train, X_test, y_test

X_train, y_train, X_test, y_test = splitXY(train, test)

In [162]:
def run_rsf(X_train_set, y_train_set, X_test_set, y_test_set, plot_fp=False):
    rsf = RandomSurvivalForest(n_estimators=100, min_samples_split=10,
                               min_samples_leaf=15, random_state=7)
    rsf.fit(X_train_set, y_train_set)

    # Predict survival functions
    surv_funcs = rsf.predict_survival_function(X_test_set)

    # Define evaluation time points
    time_buckets = np.array([7, 30, 60, 180, 365, 730])
    surv_probs = np.row_stack([fn(time_buckets) for fn in surv_funcs])
    risk_scores = 1 - surv_probs  # Higher risk = lower survival

    c_rsf = []
    
    print("  RSF C-index per time bucket:")
    for i, t in enumerate(time_buckets):
        c_index, _, _, _, _ = concordance_index_censored(
            y_test_set["event"], [x[1] for x in y_test_set], risk_scores[:, i]
        )
        print(f"    Day {t:4d}: C-index = {c_index:.4f}")
        c_rsf.append(c_index)
    print(f"  Mean RSF C-index: {sum(c_rsf) / len(c_rsf):4f}")

    # Full IBS over a grid
    surv_probs_grid = np.row_stack([fn(time_buckets) for fn in surv_funcs])
    ibs = integrated_brier_score(y_train_set, y_test_set, surv_probs_grid, time_buckets)
    print(f"  RSF IBS:     {ibs:.4f}")

    # Permutation importance
    if plot_fp:
        perm_result = permutation_importance(
            rsf, X_test_set, y_test_set, n_repeats=10, random_state=42, n_jobs=-1
        )
        sorted_idx = perm_result.importances_mean.argsort()[::-1][:20]

        plt.figure(figsize=(10, 6))
        plt.barh(range(len(sorted_idx)), perm_result.importances_mean[sorted_idx][::-1], align='center')
        plt.yticks(range(len(sorted_idx)), X_test_set.columns[sorted_idx][::-1])
        plt.xlabel("Permutation Importance (Mean decrease in score)")
        plt.title("RSF Permutation Feature Importance (Top 20)")
        plt.tight_layout()
        plt.show()

def run_gb(X_train_set, y_train_set, X_test_set, y_test_set, plot_fp=False):
    gb = GradientBoostingSurvivalAnalysis(n_estimators=100, random_state=7)
    gb.fit(X_train_set, y_train_set)

    # Predict survival functions
    surv_funcs = gb.predict_survival_function(X_test_set)

    # Define time buckets
    time_buckets = np.array([7, 30, 60, 180, 365, 730])
    surv_probs = np.row_stack([fn(time_buckets) for fn in surv_funcs])
    risk_scores = 1 - surv_probs  # Higher risk = lower survival

    c_gb = []
    
    print("  GB C-index per time bucket:")
    for i, t in enumerate(time_buckets):
        c_index, _, _, _, _ = concordance_index_censored(
            y_test_set["event"], [x[1] for x in y_test_set], risk_scores[:, i]
        )
        print(f"    Day {t:4d}: C-index = {c_index:.4f}")
        c_gb.append(c_index)
    print(f"  Mean GB C-index: {sum(c_gb) / len(c_gb):4f}")
    
    # IBS
    surv_probs_grid = np.row_stack([fn(time_buckets) for fn in surv_funcs])
    ibs = integrated_brier_score(y_train_set, y_test_set, surv_probs_grid, time_buckets)
    print(f"  GB IBS:      {ibs:.4f}")

    # Feature importance
    if plot_fp:
        fi_gb = gb.feature_importances_
        top_idx = np.argsort(fi_gb)[::-1][:20]

        plt.figure(figsize=(10, 6))
        plt.barh(range(len(top_idx)), fi_gb[top_idx][::-1], align='center')
        plt.yticks(range(len(top_idx)), X_test_set.columns[top_idx][::-1])
        plt.xlabel("Feature Importance")
        plt.title("Gradient Boosting - Top 20 Feature Importances")
        plt.tight_layout()
        plt.show()



In [163]:
indices = [0, 1, 2, 3, 4]
random.shuffle(indices)

for i, j in enumerate(indices):
    print(f"Train Set {i}, Test Set {j}:")
    run_rsf(X_train[i], y_train[i], X_test[j], y_test[j])
    print()
    run_gb(X_train[i], y_train[i], X_test[j], y_test[j])
    print()
    print()

Train Set 0, Test Set 3:
  RSF C-index per time bucket:
    Day    7: C-index = 0.7930
    Day   30: C-index = 0.7882
    Day   60: C-index = 0.8027
    Day  180: C-index = 0.8296
    Day  365: C-index = 0.8305
    Day  730: C-index = 0.8321
  Mean RSF C-index: 0.812691
  RSF IBS:     0.0800

  GB C-index per time bucket:
    Day    7: C-index = 0.8016
    Day   30: C-index = 0.8016
    Day   60: C-index = 0.8016
    Day  180: C-index = 0.8016
    Day  365: C-index = 0.8016
    Day  730: C-index = 0.8016
  Mean GB C-index: 0.801637
  GB IBS:      0.0773


Train Set 1, Test Set 4:
  RSF C-index per time bucket:
    Day    7: C-index = 0.7848
    Day   30: C-index = 0.7843
    Day   60: C-index = 0.7990
    Day  180: C-index = 0.8284
    Day  365: C-index = 0.8288
    Day  730: C-index = 0.8381
  Mean RSF C-index: 0.810561
  RSF IBS:     0.0800

  GB C-index per time bucket:
    Day    7: C-index = 0.7994
    Day   30: C-index = 0.7994
    Day   60: C-index = 0.7994
    Day  180: C-index