
Scalability over larger datasets #5

Open
Addicted-to-coding opened this issue Nov 25, 2022 · 5 comments

@Addicted-to-coding

Hi,

Is your method scalable to larger datasets? I tried running this method on a dataset of size (10000, 8) and got the estimated run time shown below. This should not be the case, since your own test dataset is of size (300, 8) and its time per iteration is low. Are you retraining the model when computing the SHAP values for each example? It is not clear to me why the time per iteration has increased so much given that the number of features is the same.

[screenshot: estimated run time per iteration]

@krzyzinskim
Collaborator

Hi,
Currently, the exact values of Kernel SHAP are calculated. However, with 8 variables it shouldn't be a problem (as you can see in the experiments). We're not retraining the model, so that's not the cause.
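
As a back-of-the-envelope check (a sketch assuming the standard exact Kernel SHAP scheme, which enumerates all feature coalitions):

# exact Kernel SHAP evaluates all 2**p feature coalitions per explained observation
p = 8
print(2 ** p)  # 256 coalitions: cheap for p = 8, but exponential in p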

I'm interested in where this problem might come from but I can't reproduce it. Can you share a minimal reproducible code example with this dataset or any other causing this problem?

It's also related to #4, I believe.

@Addicted-to-coding
Author

Addicted-to-coding commented Nov 25, 2022

Thanks for the prompt reply. Yes, this is related to issue #4, which I posted earlier. We can reproduce this by creating a random array of size (1000, 8). Here's a simple example of how to reproduce it:

# create random data
import numpy as np
import pandas as pd

X = np.random.rand(1000, 8)
y = np.random.rand(1000, 1)
boo = np.random.choice(a=[True, False], size=(1000, 1), p=[0.5, 0.5])
out = np.empty(1000, dtype=[('event', '?'), ('time', '<f8')])  # structured array for sksurv
out['event'] = boo.reshape(1000,)
out['time'] = y.reshape(1000,)
X = pd.DataFrame(X, columns=['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8'])

# run random survival forest
from sksurv.ensemble import RandomSurvivalForest
rsf = RandomSurvivalForest(random_state=42, n_estimators=120, max_depth=8,
                           min_samples_leaf=4, max_features=3)
rsf.fit(X, out)
rsf.score(X, out)

# run survshap
from survshap import SurvivalModelExplainer, ModelSurvSHAP
exp_rsf = SurvivalModelExplainer(rsf, X, out)

pnd_survshap_global_rsf = ModelSurvSHAP(random_state=42)
pnd_survshap_global_rsf.fit(exp_rsf)

Produces the following output: [screenshot: estimated run time per iteration]

@Addicted-to-coding
Author

Hi,
I was wondering if you were able to reproduce this on your end and whether you have any solutions?

@solidate

solidate commented Feb 9, 2023

Hi,
I am also facing the same issue mentioned by @Addicted-to-coding.
The fit method seems to be awfully slow.

@krzyzinskim Were you able to reproduce this?

@hbaniecki
Member

Hi @Addicted-to-coding @solidate, it's expected to be slow. The implemented (default) algorithm aims to "exactly" approximate Shapley values and is therefore useful for relatively small (background) datasets. So you can probably compute SurvSHAP(t) for 1000+ samples, but only when using 100-200 samples as the background for estimation.
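
A minimal sketch of that idea (reusing X, out, and rsf from the reproduction code above, and assuming the data passed to SurvivalModelExplainer serves as the background for estimation):

import numpy as np

# draw a 200-row background sample from the full (1000, 8) dataset
rng = np.random.default_rng(42)
idx = rng.choice(len(X), size=200, replace=False)

exp_bg = SurvivalModelExplainer(rsf, X.iloc[idx], out[idx])  # exp_bg/ms_bg are hypothetical names
ms_bg = ModelSurvSHAP(random_state=42)
ms_bg.fit(exp_bg)  # far fewer background rows, so the fit finishes much faster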

Another way to speed up calculations is to reduce the number of timestamps (the timestamps parameter of the fit() method) at which the survival function is predicted. By default, values are predicted for each unique event time, which in the case of 1000 observations can be a lot of timestamps.
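
For example, a coarse grid of time points can be used instead (a sketch reusing the objects from the reproduction code above; it assumes the timestamps parameter accepts an array of time points):

import numpy as np

# ~50 evenly spaced time points instead of every unique event time
ts = np.linspace(out['time'].min(), out['time'].max(), num=50)

ms_ts = ModelSurvSHAP(random_state=42)  # ms_ts is a hypothetical name
ms_ts.fit(exp_rsf, timestamps=ts)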

Also, RSF has slow inference, which adds to the time. See the comparison with a simpler CPH model below.


import numpy as np
import pandas as pd
from survshap import SurvivalModelExplainer, ModelSurvSHAP

# create random data
X = np.random.rand(1000, 8)
y = np.random.rand(1000, 1)
boo = np.random.choice(a=[True, False], size=(1000, 1), p=[0.5, 0.5])
out = np.empty(1000, dtype=[('event', '?'), ('time', '<f8')])
out['event'] = boo.reshape(-1)
out['time'] = y.reshape(-1)
X = pd.DataFrame(X, columns=['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8'])

# Cox proportional hazards: fast inference
from sksurv.linear_model import CoxPHSurvivalAnalysis
cph = CoxPHSurvivalAnalysis()
cph.fit(X, out)
cph.score(X, out)

# random survival forest: slow inference
from sksurv.ensemble import RandomSurvivalForest
rsf = RandomSurvivalForest(random_state=42, n_estimators=120, max_depth=8, max_features=3)
rsf.fit(X, out)
rsf.score(X, out)

# explaining the CPH model is noticeably faster than explaining the RSF
exp_cph = SurvivalModelExplainer(cph, X, out)
ms_cph = ModelSurvSHAP(random_state=42)
ms_cph.fit(exp_cph)

exp_rsf = SurvivalModelExplainer(rsf, X, out)
ms_rsf = ModelSurvSHAP(random_state=42)
ms_rsf.fit(exp_rsf)
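
To make the comparison concrete, one can time both fits directly (a minimal sketch using the objects defined above):

import time

for name, exp in [('CPH', exp_cph), ('RSF', exp_rsf)]:
    ms = ModelSurvSHAP(random_state=42)
    start = time.perf_counter()
    ms.fit(exp)
    print(f'{name}: {time.perf_counter() - start:.1f} s')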
