# Random Survival Forest Model

In [2]:
import pandas as pd
import numpy as np
import matplotlib as plt
import seaborn as sns

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold, GridSearchCV, StratifiedKFold

from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored

import time
from datetime import datetime, timedelta

In [3]:
training_data_imputed_df = pd.read_csv("training_data_imputed_simple_TRAIN.csv.gz") 
# Filter to 30 day window
survival_data_30d = training_data_imputed_df.loc[
    (training_data_imputed_df['survival_time'] < 30) &
    (training_data_imputed_df['survival_time'] > 0)
    ]
survival_data_30d = survival_data_30d.astype({"cdiff_survival_flag": bool})

In [4]:
train_df, val_df, = train_test_split(survival_data_30d, test_size=0.2, random_state=0)

In [9]:
X_train = train_df.drop(['cdiff_2d_flag', 'cdiff_7d_flag', 'cdiff_30d_flag', 'cdiff_survival_flag', 'survival_time'], axis=1)
y_train = train_df[['cdiff_survival_flag', 'survival_time']]
y_array_event_time_train = np.array(
    list(zip(y_train['cdiff_survival_flag'], y_train['survival_time'])), 
    dtype=[('status', bool), ('survival_in_days', float)]
)
X_val = val_df.drop(['cdiff_2d_flag', 'cdiff_7d_flag', 'cdiff_30d_flag', 'cdiff_survival_flag', 'survival_time'], axis=1)
y_val = val_df[['cdiff_survival_flag', 'survival_time']]
y_array_event_time_val = np.array(
    list(zip(y_val['cdiff_survival_flag'], y_val['survival_time'])), 
    dtype=[('status', bool), ('survival_in_days', float)]
)

In [11]:
rsf = RandomSurvivalForest(
    n_estimators=5, min_samples_split=5, min_samples_leaf=5, n_jobs=-1, random_state=0
)
rsf.fit(X_train, y_array_event_time_train)

KeyboardInterrupt: 

In [6]:
c_index_val = rsf.score(X_val, y_array_event_time_val)
c_index_val

np.float64(0.9884094743346037)

In [13]:
y_train['cdiff_survival_flag'].value_counts()

cdiff_survival_flag
False    62044
True      4384
Name: count, dtype: int64

In [1]:

# Custom scoring function for GridSearchCV
def custom_c_index(estimator, X, y):
    """Custom scoring function for GridSearchCV"""
    prediction = estimator.predict(X)
    result = concordance_index_censored(y['event'], y['time'], prediction)
    return result[0]  # Return the concordance index
    
# param_grid = {
#     'n_estimators': [2, 5, 10, 20],
#     'min_samples_split': [3, 5, 7, 9],
#     'min_samples_leaf': [4, 5, 6],
# }

param_grid = {
    'n_estimators': [2, 5],
    'min_samples_split': [2],
    'min_samples_leaf': [2],
}

rsf = RandomSurvivalForest(
    n_jobs=-1, random_state=0
)

n_splits = 4
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)

grid_search = GridSearchCV(
    estimator=rsf,
    param_grid=param_grid,
    cv=skf,
    scoring=custom_c_index,
    verbose=1,
    n_jobs=-1  
)

SyntaxError: invalid syntax (465273436.py, line 5)

In [32]:
start_time = time.time()
start_datetime = datetime.now()
print(f"Training started at: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")

print("Starting GridSearchCV to find optimal parameters...")
grid_search.fit(X_train, y_array_event_time_train)

print("\nBest parameters found:")
print(grid_search.best_params_)
print(f"Best cross-validation {scoring} score: {grid_search.best_score_:.4f}")

best_model = grid_search.best_estimator_

end_time = time.time()
end_datetime = datetime.now()

duration_seconds = end_time - start_time
duration = timedelta(seconds=duration_seconds)

print(f"Training finished at: {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total training time: {duration}")

Training started at: 2025-05-23 11:03:19
Starting GridSearchCV to find optimal parameters...
Fitting 4 folds for each of 2 candidates, totalling 8 fits


Traceback (most recent call last):
  File "/project/bios26406/conda/ml4h/lib/python3.13/site-packages/sklearn/model_selection/_validation.py", line 949, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
  File "/scratch/local/jobs/31202343/ipykernel_1136308/3730242851.py", line 4, in custom_c_index
    c_index = rsf.score(X_val, y_array_event_time_val)
  File "/project/bios26406/conda/ml4h/lib/python3.13/site-packages/sksurv/base.py", line 95, in score
    risk_score = self.predict(X)
  File "/project/bios26406/conda/ml4h/lib/python3.13/site-packages/sksurv/ensemble/forest.py", line 292, in predict
    return self._predict("predict", X)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/project/bios26406/conda/ml4h/lib/python3.13/site-packages/sksurv/ensemble/forest.py", line 239, in _predict
    check_is_fitted(self, "estimators_")
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
  File "/project/bios26406/conda/ml4h/lib/python3.13/site-packages/sklearn/utils/validation.py

KeyboardInterrupt: 