In [1]:
!pip install survshap

Collecting survshap
  Downloading survshap-0.4.2-py3-none-any.whl (19 kB)
Collecting scikit-survival>=0.17.2 (from survshap)
  Downloading scikit_survival-0.22.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/3.7 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting shap>0.41.0 (from survshap)
  Downloading shap-0.44.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (535 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m535.7/535.7 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
Collecting scikit-learn<1.4,>=1.3.0 (from scikit-survival>=0.17.2->survshap)
  Downloading scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m32.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting slicer==0.0.7 (from shap>0.41.0->sur

In [2]:
from survshap import SurvivalModelExplainer, PredictSurvSHAP, ModelSurvSHAP
from sksurv.ensemble import RandomSurvivalForest
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


%matplotlib inline

In [3]:
# Load the GBSG2 dataset bundled with scikit-survival: covariates X, outcome y.
X, y = load_gbsg2()

# Tumour grade is an ordered category (I < II < III), so it is encoded as a
# single ordinal column rather than one-hot encoded. OrdinalEncoder expects a
# 2-D input, hence the reshape to (n_samples, 1).
grade_str = X["tgrade"].astype(object).to_numpy().reshape(-1, 1)
grade_encoder = OrdinalEncoder(categories=[["I", "II", "III"]])
grade_num = grade_encoder.fit_transform(grade_str)

# One-hot encode the remaining categorical columns, then re-attach the grade.
X = OneHotEncoder().fit_transform(X.drop(columns="tgrade"))
X.loc[:, "tgrade"] = grade_num

# Single seed shared by the split here and by the model fitted later.
random_state = 42

# Hold out 25% of the rows as a test set.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    random_state=random_state,
)

In [4]:
# Hyperparameters of the random survival forest, gathered in one place.
rsf_params = {
    "n_estimators": 1000,
    "min_samples_split": 10,
    "min_samples_leaf": 15,
    "n_jobs": -1,  # -1: use all available cores
    "random_state": random_state,
}

# Fit on the training split only; the fitted estimator is the cell's output.
model = RandomSurvivalForest(**rsf_params)
model.fit(X_train, y_train)

In [5]:
# Wrap the fitted model together with the data and outcome it is explained on.
# NOTE(review): the full X / y (train + test rows) is passed here, not only the
# training split -- confirm this is the intended background data.
explainer = SurvivalModelExplainer(model=model, data=X, y=y)

# Compute SHAP values for a single instance (the first row of X).
observation_A = X.iloc[[0]]  # double brackets keep a one-row DataFrame
survshap_A = PredictSurvSHAP()
survshap_A.fit(explainer=explainer, new_observation=observation_A)

# A bare expression in the middle of a cell is evaluated and then discarded, so
# render the result explicitly. `display` is provided by the IPython/Jupyter
# kernel without an import.
display(survshap_A.result)
survshap_A.plot()

In [8]:
# Compute SHAP values for a group of instances (the first 10 rows of X).
# "treeshap" selects the fast implementation for tree-based models.
model_survshap = ModelSurvSHAP(calculation_method="treeshap")
model_survshap.fit(explainer=explainer, new_observations=X.iloc[:10, :])

# A bare expression in the middle of a cell is evaluated and then discarded, so
# render the result explicitly. `display` is provided by the IPython/Jupyter
# kernel without an import.
display(model_survshap.result)
model_survshap.plot_mean_abs_shap_values()
model_survshap.plot_shap_lines_for_all_individuals(variable="age")

# Keep the explanation of the first of the 10 observations for later inspection.
extracted_survshap = model_survshap.individual_explanations[0]



In [7]:
# Show the feature names of the encoded design matrix X.
X.columns

Index(['age', 'estrec', 'horTh=yes', 'menostat=Post', 'pnodes', 'progrec',
       'tsize', 'tgrade'],
      dtype='object')