-
Notifications
You must be signed in to change notification settings - Fork 0
/
explanation.py
31 lines (27 loc) · 1.17 KB
/
explanation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import shap
import os.path
import utils
ROOT = os.path.abspath(os.path.join(__file__, '../'))
def explain_model(model, X, show, feature="post_length"):
    """Generate and save global SHAP explanation plots for a tree model.

    Saves three plots under ROOT/outputs/: a bar summary plot, a
    dependence plot for one feature, and a beeswarm summary plot.

    Parameters
    ----------
    model : a fitted tree-based model supported by shap.TreeExplainer
    X : feature data to explain (e.g. a pandas DataFrame)
    show : bool, forwarded to shap plotting calls (display interactively)
    feature : str, feature name used for the dependence plot
        (defaults to "post_length", preserving prior behavior)
    """
    shap.initjs()
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # Global feature importance as a bar chart.
    shap.summary_plot(shap_values, X, plot_type="bar", show=show)
    utils.save_picture(os.path.join(ROOT, 'outputs/summary_plot_bar.png'))
    utils.clear_plot()

    # Effect of a single feature across the dataset.
    shap.dependence_plot(feature, shap_values, X, show=show)
    utils.save_picture(os.path.join(ROOT, 'outputs/dependence_plot.png'))
    utils.clear_plot()

    # Beeswarm summary of all features.
    shap.summary_plot(shap_values, X, show=show)
    utils.save_picture(os.path.join(ROOT, 'outputs/summary_plot.png'))
    utils.clear_plot()
def explain_class(model, X, show, index=0):
    """Generate and save a SHAP force plot explaining one prediction.

    Saves the plot to ROOT/outputs/force_plot_post.png.

    Parameters
    ----------
    model : a fitted tree-based model supported by shap.TreeExplainer
    X : feature data (must support .iloc row indexing, e.g. a DataFrame)
    show : bool, forwarded to shap.force_plot (display interactively)
    index : int, row of X to explain (defaults to 0, the first
        prediction, preserving prior behavior)
    """
    shap.initjs()
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    # NOTE(review): shap_values[index, :] assumes a single-output model
    # (a 2-D array, not the per-class list TreeExplainer returns for
    # multiclass models) — confirm against the caller's model.
    shap.force_plot(explainer.expected_value, shap_values[index, :],
                    X.iloc[index, :], matplotlib=True, show=show,
                    link="logit")
    utils.save_picture(os.path.join(ROOT, 'outputs/force_plot_post.png'))
    utils.clear_plot()