# SurvSHAP(t): Time-Dependent Explanations Of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 2: Comparison to SurvNAM

#### Imports

In [None]:
import pandas as pd
import numpy as np
import pickle
from sksurv.util import Surv

#### Preparing data and models 

In [None]:
# Dataset 0, training split: the first five columns are taken as features and
# the "event"/"time" columns are packed into a survival outcome array.
dataset0_train = pd.read_csv("data/exp2_dataset0_train.csv")
X_train0 = dataset0_train.iloc[:, :5]
y_train0 = Surv.from_dataframe(event="event", time="time", data=dataset0_train)

# Dataset 0, test split: processed exactly like the training split.
dataset0_test = pd.read_csv("data/exp2_dataset0_test.csv")
X_test0 = dataset0_test.iloc[:, :5]
y_test0 = Surv.from_dataframe(event="event", time="time", data=dataset0_test)

In [None]:
# Dataset 1, training split: the first five columns are taken as features and
# the "event"/"time" columns are packed into a survival outcome array.
dataset1_train = pd.read_csv("data/exp2_dataset1_train.csv")
X_train1 = dataset1_train.iloc[:, :5]
y_train1 = Surv.from_dataframe(event="event", time="time", data=dataset1_train)

# Dataset 1, test split: processed exactly like the training split.
dataset1_test = pd.read_csv("data/exp2_dataset1_test.csv")
X_test1 = dataset1_test.iloc[:, :5]
y_test1 = Surv.from_dataframe(event="event", time="time", data=dataset1_test)

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
# Fit a Cox proportional hazards model (default settings) on the dataset0
# training split; its coefficients are used further down in this notebook.
cph_dataset0 = CoxPHSurvivalAnalysis()
cph_dataset0.fit(X_train0, y_train0)

In [None]:
# Fit a Cox proportional hazards model (default settings) on the dataset1
# training split; its coefficients are used further down in this notebook.
cph_dataset1 = CoxPHSurvivalAnalysis()
cph_dataset1.fit(X_train1, y_train1)

#### Reading explanations
##### SurvNAM dataset0

In [None]:
# Load the stored SurvNAM explanations for dataset0, one CSV per explained model.
# Later cells treat every row as one observation's per-feature values.
# NOTE(review): "rsf"/"cph" presumably denote a random survival forest and the
# Cox model — confirm against the script that produced these files.
survnam_dataset0_rsf = pd.read_csv("results/survnam_explanations_dataset0_rsf.csv")
survnam_dataset0_cph = pd.read_csv("results/survnam_explanations_dataset0_cph.csv")

##### SurvNAM dataset1 

In [None]:
# Load the stored SurvNAM explanations for dataset1, one CSV per explained model.
# Later cells treat every row as one observation's per-feature values.
# NOTE(review): "rsf"/"cph" presumably denote a random survival forest and the
# Cox model — confirm against the script that produced these files.
survnam_dataset1_rsf = pd.read_csv("results/survnam_explanations_dataset1_rsf.csv")
survnam_dataset1_cph = pd.read_csv("results/survnam_explanations_dataset1_cph.csv")

#### Importance rankings

In [None]:
def get_orderings_and_ranks_survnam(explanations):
    """Turn per-observation explanation values into orderings and ranks.

    Each row of ``explanations`` is handled independently and features are
    compared by the absolute value of their entry in that row.

    Returns a pair of DataFrames with default integer row/column labels:

    * orderings -- column ``k`` holds the label of the feature with the
      ``k``-th largest absolute value (column 0 = largest);
    * ranks -- column ``j`` holds the rank of the ``j``-th feature, where
      rank 1 is the largest absolute value and ties share an average rank.
    """
    per_observation = [values for _, values in explanations.iterrows()]

    # Feature labels sorted by decreasing magnitude.
    orderings = [
        values.sort_values(key=lambda v: -abs(v)).index.to_list()
        for values in per_observation
    ]
    # Rank of every feature, kept in the original column order.
    ranks = [
        values.abs().rank(ascending=False).to_list()
        for values in per_observation
    ]
    return pd.DataFrame(orderings), pd.DataFrame(ranks)

from scipy.stats import weightedtau
def mean_weighted_tau(ranks1, ranks2, n_rows=None):
    """Average the weighted Kendall tau over positionally paired rows.

    Row ``i`` of ``ranks1`` is compared with row ``i`` of ``ranks2`` using
    ``scipy.stats.weightedtau`` with its default settings. A NaN statistic
    for a row is counted as 0 in the mean.

    Parameters
    ----------
    ranks1, ranks2 : pandas.DataFrame
        One ranking per row; rows are matched by position (``iloc``).
    n_rows : int, optional
        Number of leading rows to compare. By default every row that both
        frames have is used. The row count used to be hard-coded to 100,
        which raised ``IndexError`` on shorter inputs and ignored any rows
        past the first 100.

    Returns
    -------
    float
        Mean of the per-row weighted tau statistics.

    Raises
    ------
    ValueError
        If there is no row to compare or ``n_rows`` exceeds the number of
        rows available in both frames.

    NOTE(review): with its defaults ``weightedtau`` gives more weight to
    elements with larger values; if the inputs encode "rank 1 = most
    important", confirm that this weighting is the intended one.
    """
    available = min(len(ranks1), len(ranks2))
    if n_rows is None:
        n_rows = available
    if n_rows <= 0 or n_rows > available:
        raise ValueError(
            f"cannot compare {n_rows} rows: the inputs share {available} rows"
        )

    taus = []
    for i in range(n_rows):
        tau, _ = weightedtau(ranks1.iloc[i], ranks2.iloc[i])
        # An undefined correlation contributes 0 to the mean.
        taus.append(0 if np.isnan(tau) else tau)
    return np.mean(taus)

##### dataset0
- $\beta^T = [10^{-6}, 0.1, -0.15, 10^{-6}, 10^{-6}]$
- ground-truth importance ranking (feature indices, least to most important): [0/3/4, 1, 2]

In [None]:
# Display the coefficients of the Cox model fitted on dataset0.
cph_dataset0.coef_

##### CPH

In [None]:
# Per-observation feature orderings (labels by decreasing |explanation|) and
# feature ranks (1 = largest |explanation|) from the dataset0 "cph" explanations.
dataset0_cph_survnam_orderings, dataset0_cph_survnam_ranks  = get_orderings_and_ranks_survnam(survnam_dataset0_cph)

In [None]:
# For selected importance positions, count how often each feature occupies
# that position across observations. Position 0 is the largest |explanation|;
# position 4 is the smallest when five features are explained.
for caption, position in (
    ("The least important (0/3/4)", 4),
    ("The second most important (1)", 1),
    ("The most important (2)", 0),
):
    print(caption)
    print(dataset0_cph_survnam_orderings[position].value_counts())

In [None]:
# GT CPH: reference ranks derived from the fitted Cox model. For every test
# observation the impact of a feature is value * coefficient, and rank 1 goes
# to the feature with the largest absolute impact.
importance_ranks = [
    (observation * cph_dataset0.coef_).abs().rank(ascending=False).to_list()
    for _, observation in X_test0.iterrows()
]
dataset0_cph_true_ranks = pd.DataFrame(importance_ranks)

In [None]:
# Agreement between SurvNAM's feature ranks and the ground-truth CPH ranks.
# Both arguments must be rank frames (column j = rank of feature j). The
# orderings frame holds feature labels by position, so correlating it with a
# rank frame does not measure agreement between the two rankings.
mean_weighted_tau(dataset0_cph_survnam_ranks, dataset0_cph_true_ranks)

##### RSF

In [None]:
# Per-observation feature orderings (labels by decreasing |explanation|) and
# feature ranks (1 = largest |explanation|) from the dataset0 "rsf" explanations.
dataset0_rsf_survnam_orderings, dataset0_rsf_survnam_ranks  = get_orderings_and_ranks_survnam(survnam_dataset0_rsf)

In [None]:
# For selected importance positions, count how often each feature occupies
# that position across observations. Position 0 is the largest |explanation|;
# position 4 is the smallest when five features are explained.
for caption, position in (
    ("The least important (0/3/4)", 4),
    ("The second most important (1)", 1),
    ("The most important (2)", 0),
):
    print(caption)
    print(dataset0_rsf_survnam_orderings[position].value_counts())

##### dataset1
- $\beta^T = [10^{-6}, -0.15, 10^{-6}, 10^{-6}, -0.1]$
- ground-truth importance ranking (feature indices, least to most important): [0/2/3, 4, 1]

In [None]:
# Display the coefficients of the Cox model fitted on dataset1.
cph_dataset1.coef_

##### CPH

In [None]:
# Per-observation feature orderings (labels by decreasing |explanation|) and
# feature ranks (1 = largest |explanation|) from the dataset1 "cph" explanations.
dataset1_cph_survnam_orderings, dataset1_cph_survnam_ranks  = get_orderings_and_ranks_survnam(survnam_dataset1_cph)

In [None]:
# For selected importance positions, count how often each feature occupies
# that position across observations. Position 0 is the largest |explanation|;
# position 4 is the smallest when five features are explained.
for caption, position in (
    ("The least important (0/2/3)", 4),
    ("The second most important (4)", 1),
    ("The most important (1)", 0),
):
    print(caption)
    print(dataset1_cph_survnam_orderings[position].value_counts())

In [None]:
# GT CPH: reference ranks derived from the fitted Cox model. For every test
# observation the impact of a feature is value * coefficient, and rank 1 goes
# to the feature with the largest absolute impact.
importance_ranks = [
    (observation * cph_dataset1.coef_).abs().rank(ascending=False).to_list()
    for _, observation in X_test1.iterrows()
]
dataset1_cph_true_ranks = pd.DataFrame(importance_ranks)

In [None]:
# Agreement between SurvNAM's feature ranks and the ground-truth CPH ranks.
# Both arguments must be rank frames (column j = rank of feature j). The
# orderings frame holds feature labels by position, so correlating it with a
# rank frame does not measure agreement between the two rankings.
mean_weighted_tau(dataset1_cph_survnam_ranks, dataset1_cph_true_ranks)

##### RSF

In [None]:
# Per-observation feature orderings (labels by decreasing |explanation|) and
# feature ranks (1 = largest |explanation|) from the dataset1 "rsf" explanations.
dataset1_rsf_survnam_orderings, dataset1_rsf_survnam_ranks  = get_orderings_and_ranks_survnam(survnam_dataset1_rsf)

In [None]:
# For selected importance positions, count how often each feature occupies
# that position across observations. Position 0 is the largest |explanation|;
# position 4 is the smallest when five features are explained.
for caption, position in (
    ("The least important (0/2/3)", 4),
    ("The second most important (4)", 1),
    ("The most important (1)", 0),
):
    print(caption)
    print(dataset1_rsf_survnam_orderings[position].value_counts())