# Random Survival Forest

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

%matplotlib inline

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder

set_config(display="text")  # show estimators as plain-text reprs instead of HTML diagrams

In [2]:
# Load the dataset from CSV into a DataFrame.
# The path is relative, so it is resolved against the current working directory.
df = pd.read_csv("../data/cleaned_data.csv")

In [3]:
# One-hot encode categorical columns in 0/1 format.
# With no `columns` argument, pd.get_dummies expands only object/string/category
# columns and passes all other columns through unchanged; dtype=int makes the
# generated indicator columns integer-valued.
df = pd.get_dummies(df, dtype=int)

In [4]:
# Preview the first five rows of the encoded frame.
df.head()

Unnamed: 0,ID,hla_match_c_high,hla_high_res_8,hla_low_res_6,hla_high_res_6,hla_high_res_10,hla_match_dqb1_high,hla_nmdp_6,hla_match_c_low,hla_match_drb1_low,...,donor_related_Related,donor_related_Unrelated,melphalan_dose_MEL,"melphalan_dose_N/A, Mel not given",cardiac_No,cardiac_Not done,cardiac_Yes,pulm_moderate_No,pulm_moderate_Not done,pulm_moderate_Yes
0,0,2.0,8.0,6.0,6.0,10.0,2.0,6.0,2.0,2.0,...,0,1,0,1,1,0,0,1,0,0
1,1,2.0,8.0,6.0,6.0,10.0,2.0,6.0,2.0,2.0,...,1,0,0,1,1,0,0,0,0,1
2,2,2.0,8.0,6.0,6.0,10.0,2.0,6.0,2.0,2.0,...,1,0,0,1,1,0,0,1,0,0
3,3,2.0,8.0,6.0,6.0,10.0,2.0,6.0,2.0,2.0,...,0,1,0,1,1,0,0,1,0,0
4,4,2.0,8.0,6.0,6.0,10.0,2.0,5.0,2.0,2.0,...,1,0,1,0,1,0,0,1,0,0


In [5]:
# Feature matrix: every column except the row identifier and the two outcome columns.
X = df.drop(columns=['ID', 'efs', 'efs_time'])

# A missing 'efs' value would be cast to True below and silently counted as an
# event, so fail loudly before building the target.
if df[['efs', 'efs_time']].isna().any().any():
    raise ValueError("'efs' and 'efs_time' must not contain missing values")

# Survival target as a structured array with a boolean field followed by a
# float64 time field. The fields are filled column-wise rather than by building
# one Python tuple per row.
# scikit-survival reads the FIRST field as the event indicator (True = event
# observed) regardless of its name.
# NOTE(review): assumes 'efs' is non-zero exactly when the event occurred - confirm.
y = np.empty(len(df), dtype=[('cens', '?'), ('time', '<f8')])
y['cens'] = df['efs'].to_numpy().astype(bool)
y['time'] = df['efs_time'].to_numpy()

In [6]:
# One seed, reused for the data split here and for the forest further down,
# so the whole run is reproducible.
random_state = 20

# Hold out 25% of the rows as the evaluation set.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    random_state=random_state,
)

In [None]:
# Configure the random survival forest with the same seed as the data split.
# NOTE(review): min_samples_leaf=15 already requires at least 30 samples in a
# node before it can be split, so min_samples_split=10 looks redundant - confirm.
# NOTE(review): a 1000-tree forest can need a lot of memory; scikit-survival
# documents a low_memory option for when only predict()/score() are used.
rsf = RandomSurvivalForest(
    n_estimators=1000,
    min_samples_split=10,
    min_samples_leaf=15,
    n_jobs=-1,  # use all available processors
    random_state=random_state,
)

# Keep fit() as the cell's last expression so the fitted estimator is echoed.
rsf.fit(X_train, y_train)

In [None]:
# Check how well the model performs on the test data.
# Per the scikit-survival documentation, score() returns the concordance index
# of the predictions on the given data (0.5 is a random ranking, 1.0 a perfect one).
rsf.score(X_test, y_test)