# Random Survival Forests

In [1]:
pip install scikit-survival

Collecting scikit-survival
  Downloading scikit_survival-0.22.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.7 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/3.7 MB[0m [31m11.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.9/3.7 MB[0m [31m13.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/3.7 MB[0m [31m16.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━[0m [32m2.7/3.7 MB[0m [31m19.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.7/3.7 MB[0m [31m21.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/3.7 MB[0m 

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import set_config
from sklearn.base import clone
from sklearn.model_selection import RepeatedKFold, GridSearchCV

from sksurv.column import encode_categorical
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastSurvivalSVM

# Notebook-wide display settings.
sns.set_style("whitegrid")  # seaborn plot style used by every figure below
set_config(display="text")  # show estimators as plain text instead of the HTML diagram

In [4]:
# Load the Veterans' lung cancer dataset: `data_x` holds the features and `y` the
# survival outcome (its "Status" and "Survival_in_days" fields are read further below).
data_x, y = load_veterans_lung_cancer()
# Pass the features through sksurv's categorical encoder before modelling.
x = encode_categorical(data_x)

In [5]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for testing; the fixed seed keeps the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=42,
)

In [6]:
# Configure the random survival forest, then train it on the training split.
rsf = RandomSurvivalForest(
    n_estimators=1000,
    min_samples_split=10,
    min_samples_leaf=15,
    n_jobs=-1,
    random_state=42,
)
rsf.fit(x_train, y_train)

RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10,
                     n_estimators=1000, n_jobs=-1, random_state=42)

In [None]:
def score_survival_model(model, X, y):
    """Scorer for ``GridSearchCV``: concordance index of ``model`` on ``(X, y)``.

    Parameters
    ----------
    model : fitted estimator
        Its ``predict(X)`` must return risk scores, i.e. higher values mean a
        higher risk / shorter survival (this is what ``RandomSurvivalForest``
        returns). Models that predict survival *time* must be negated by the
        caller before scoring.
    X : array-like
        Feature matrix passed to ``model.predict``.
    y : structured array
        Must provide the fields ``"Status"`` (event indicator) and
        ``"Survival_in_days"`` (observed time).

    Returns
    -------
    float
        The concordance index (first element of the tuple returned by
        ``concordance_index_censored``).
    """
    risk_scores = model.predict(X)
    # Do NOT flip the sign: the predictions are already risk scores, and negating
    # them would make the search prefer the worst setting.
    result = concordance_index_censored(y["Status"], y["Survival_in_days"], risk_scores)
    return result[0]

In [None]:
# Hyper-parameter grid for the random survival forest.
# `alpha` is not a RandomSurvivalForest parameter, so searching over it makes
# GridSearchCV fail with "Invalid parameter"; tune a forest parameter instead.
param_grid = {"min_samples_leaf": [5, 10, 15, 20, 25]}

# 5 repetitions of 5-fold cross-validation
cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=42)
# refit=False: only the cross-validation results are kept, no final model is trained here.
gcv = GridSearchCV(rsf, param_grid, scoring=score_survival_model, n_jobs=1, refit=False, cv=cv)

In [None]:
import warnings

# Silence UserWarnings only while the grid search runs, so the filter does not
# leak into later cells and hide unrelated warnings there.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=UserWarning)
    gcv = gcv.fit(x_train, y_train)

In [None]:
# Report the winning hyper-parameters and their mean cross-validated score.
best_cindex = round(gcv.best_score_, 3)
print("Best parameters:", gcv.best_params_)
print("Best C-index:", best_cindex)

In [None]:
def plot_performance(gcv, param=None):
    """Plot the mean cross-validated score for each value of one tuned hyper-parameter.

    Parameters
    ----------
    gcv : fitted GridSearchCV
        Search object whose ``cv_results_`` are plotted.
    param : str, optional
        Name of the hyper-parameter shown on the x-axis. When omitted, the
        first key of ``gcv.param_grid`` is used, which requires ``param_grid``
        to be a single dict.

    Raises
    ------
    ValueError
        If ``param`` is omitted and ``gcv.param_grid`` is not a dict.
    """
    if param is None:
        if not isinstance(gcv.param_grid, dict):
            raise ValueError("param_grid is not a single dict; pass `param` explicitly")
        param = next(iter(gcv.param_grid))

    cv_results = pd.DataFrame(gcv.cv_results_)
    # Strip the "param_" prefix from column names so the column is called `param`.
    cv_results = cv_results.rename(columns=lambda col: col.replace("param_", ""))

    # Human-readable axis label, e.g. "alpha" -> "Alpha".
    label = param.replace("_", " ").capitalize()

    fig, ax = plt.subplots(figsize=(6, 3))
    sns.stripplot(x=param, y="mean_test_score", data=cv_results, jitter=True, dodge=True, ax=ax)
    ax.set_title(f"Cross-validated Performance vs. {label}")
    ax.set_xlabel(label)
    ax.set_ylabel("Mean Evaluation Score")
    ax.tick_params(axis="x", rotation=45)
    plt.show()

plot_performance(gcv)

In [None]:
estimator.set_params(**gcv.best_params_)
estimator.fit(x_train, y_train)

In [None]:
test_pred = estimator.predict(x_test)
#print(np.round(test_pred, 3))
#print(y_test)

# Evaluate the model on the test set
test_cindex = concordance_index_censored(
    y_test["Status"],
    y_test["Survival_in_days"],
    -test_pred)           # flip sign to obtain risk scores
print("C-index on test set:", round(test_cindex[0], 3))


In [None]:
# Reference estimator (regression) and result.
# FastSurvivalSVM comes from sksurv.svm and must be imported with the other imports at the top.

ref_estimator = FastSurvivalSVM(rank_ratio=0.0, max_iter=1000, tol=1e-5, random_state=42)
ref_estimator.fit(x_train, y_train)

# Same evaluation on both splits; `cindex` ends up holding the test-set result.
for split_name, split_x, split_y in (("train", x_train, y_train), ("test", x_test, y_test)):
    cindex = concordance_index_censored(
        split_y["Status"],
        split_y["Survival_in_days"],
        -ref_estimator.predict(split_x),  # flip sign to obtain risk scores
    )
    print(f"C-index ({split_name}):", round(cindex[0], 3))

In [None]:
# Show the reference model's predictions for the first two test samples.
first_two_test = x_test.iloc[:2]
pred = ref_estimator.predict(first_two_test)
print(np.round(pred, 3))