In [29]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance

from scipy.stats import gaussian_kde

from sksurv.datasets import load_gbsg2, load_whas500, load_aids
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder

import shap

from scipy.stats import pearsonr

In [16]:
def split_and_train(X, y, model_name, seed=20):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=seed)

    model = get_model(model_name, seed)
    model.fit(X_train, y_train)

    return model, (X_train, y_train), (X_test, y_test)

def get_model(name, seed=20):
  if name == 'rf':
    model = RandomSurvivalForest(n_estimators=1000, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, random_state=seed)
  elif name == 'cox':
    model = CoxPHSurvivalAnalysis()
  else:
    raise ValueError(f"Unrecognised model name '{name}'!")

  return model

def get_explanations(model, X_test, eps):
    # compute the reference point (average patient)
    X_mean = X_test.mean().to_frame().T

    # find average risk
    y_mean = model.predict(X_mean)
    # predict risk for all
    y_pred = model.predict(X_test)

    # select only those individuals whose change in risk with respect to the reference point is smaller than epsilon (i.e. close to 0)
    sel_mask = np.abs(y_pred - y_mean) < eps

    # use the reference point
    ex = shap.KernelExplainer(model.predict, X_mean)
    explanation = ex(X_test[sel_mask])
    df_shap = pd.DataFrame(explanation.values, columns=X_test.columns)


    sel_data = X_test[sel_mask]

    # get SHAP values of the selected individuals (delta_R close to 0)
    return df_shap, sel_data  

### GBSG2

In [40]:
X, y = load_gbsg2()

grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)

X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)
Xt.loc[:, "tgrade"] = grade_num

Xt.rename(columns={'horTh=yes': 'horTh', 'menostat=Post': 'menostat'}, inplace=True)

for model_name in ['rf', 'cox']:
    model, (X_train, y_train), (X_test, y_test) = split_and_train(Xt, y, model_name)
    score = model.score(X_test, y_test)
    print(f"{model_name}: {score:.3f}")

rf: 0.675
cox: 0.665


In [42]:
model, (X_train, y_train), (X_test, y_test) = split_and_train(Xt, y, 'rf')
eps = model.predict(X_train).std()/10.0
ex_result, org_df = get_explanations(model, X_test, eps)

100%|██████████| 27/27 [00:20<00:00,  1.30it/s]


In [43]:
for c in ex_result.columns:
    corr = pearsonr(ex_result[c].to_numpy(), org_df[c].to_numpy())
    print(f"{c}: {corr[0]:.2f} ({corr[1]:.2f})")

age: 0.18 (0.37)
estrec: -0.66 (0.00)
horTh: -0.98 (0.00)
menostat: -0.61 (0.00)
pnodes: 0.78 (0.00)
progrec: -0.82 (0.00)
tsize: 0.92 (0.00)
tgrade: 0.98 (0.00)


### ACT

In [17]:
X, y = load_aids()

cat_cols = ['hemophil', 'ivdrug', 'karnof', 'raceth', 'sex', 'strat2', 'tx', 'txgrp']
X[cat_cols] = OrdinalEncoder().fit_transform(X[cat_cols])

for model_name in ['rf', 'cox']:
    model, (X_train, y_train), (X_test, y_test) = split_and_train(X, y, model_name)
    score = model.score(X_test, y_test)
    print(f"{model_name}: {score:.3f}")

rf: 0.732
cox: 0.725


In [18]:
model, (X_train, y_train), (X_test, y_test) = split_and_train(X, y, 'rf')
eps = model.predict(X_train).std()/10.0
ex_result, org_df = get_explanations(model, X_test, eps)

100%|██████████| 34/34 [00:20<00:00,  1.65it/s]


In [39]:
for c in ex_result.columns:
    corr = pearsonr(ex_result[c].to_numpy(), org_df[c].to_numpy())
    print(f"{c}: {corr[0]:.2f} ({corr[1]:.2f})")

age: 0.27 (0.12)
cd4: -0.69 (0.00)
hemophil: -0.92 (0.00)
ivdrug: -0.97 (0.00)
karnof: -0.07 (0.69)
priorzdv: -0.08 (0.64)
raceth: -0.84 (0.00)
sex: 0.93 (0.00)
strat2: -1.00 (0.00)
tx: -0.91 (0.00)
txgrp: -0.91 (0.00)


### WHAS

In [44]:
X, y = load_whas500()

cat_cols = ['afb', 'av3', 'chf', 'cvd', 'gender', 'miord', 'mitype', 'sho']
X[cat_cols] = OrdinalEncoder().fit_transform(X[cat_cols])

for model_name in ['rf', 'cox']:
    model, (X_train, y_train), (X_test, y_test) = split_and_train(X, y, model_name)
    score = model.score(X_test, y_test)
    print(f"{model_name}: {score:.3f}")

rf: 0.802
cox: 0.785


In [45]:
model, (X_train, y_train), (X_test, y_test) = split_and_train(X, y, 'rf')
eps = model.predict(X_train).std()/10.0
ex_result, org_df = get_explanations(model, X_test, eps)

100%|██████████| 12/12 [00:09<00:00,  1.31it/s]


In [46]:
for c in ex_result.columns:
    corr = pearsonr(ex_result[c].to_numpy(), org_df[c].to_numpy())
    print(f"{c}: {corr[0]:.2f} ({corr[1]:.2f})")

afb: nan (nan)
age: 0.78 (0.00)
av3: nan (nan)
bmi: -0.69 (0.01)
chf: nan (nan)
cvd: 0.26 (0.41)
diasbp: -0.92 (0.00)
gender: 0.84 (0.00)
hr: 0.89 (0.00)
los: 0.86 (0.00)
miord: 0.55 (0.06)
mitype: -0.82 (0.00)
sho: nan (nan)
sysbp: -0.87 (0.00)


  corr = pearsonr(ex_result[c].to_numpy(), org_df[c].to_numpy())
