Explaining Predictions with SHAP #958
Conversation
Codecov Report
@@           Coverage Diff           @@
##             main     #958   +/-   ##
========================================
  Coverage   99.87%   99.87%
========================================
  Files         172      174      +2
  Lines        8824     8987    +163
========================================
+ Hits         8813     8976    +163
  Misses         11       11
Continue to review full report at Codecov.
Force-pushed from 12da727 to ca19d39
@@ -10,3 +10,4 @@ cloudpickle>=0.2.2
 click>=7.0.0
 psutil>=5.6.3
 requirements-parser>=0.2.0
+shap>=0.35.0
I think this is a reasonable core requirement, but I can move it to requirements.txt if needed (although this would require delaying the shap import in _algorithms to get the testing to work).
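For illustration, the delayed import would look roughly like this (sketch only, not the code in this PR):

def _compute_shap_values(pipeline, features, training_data):
    # Import shap here rather than at module level so that importing evalml
    # does not require shap to be installed.
    import shap
    ...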
Looks like the wheel is just ~309.4 kB, but curious about AM and what they'd think 🤔
I'm fine to add this to core requirements. I just checked the installer size and it's under 1MB. As long as there aren't any hairy pip dependency clashes, we should be good to go.
        raise NotImplementedError("SHAP values cannot currently be computed for xgboost models.")
    if estimator.model_family == ModelFamily.CATBOOST and pipeline.problem_type == ProblemTypes.MULTICLASS:
        # Will randomly segfault
        raise NotImplementedError("SHAP values cannot currently be computed for catboost models for multiclass problems.")
Not sure why the segfaults are happening but it is definitely related to the shap library. I think it is ok to not support catboost multiclass estimators for now while I investigate.
    return {feature_name: scaled_values[:, i].tolist() for i, feature_name in enumerate(sorted_feature_names)}


def _normalize_shap_values(values):
I am planning on using _normalize_shap_values for the implementation of compute_prediction (will be added in the next PR in the epic along with display_table):

def explain_prediction(pipeline, features, training_data):
    shap_values = _compute_shap_values(pipeline, features, training_data)
    normalized_values = _normalize_shap_values(shap_values)
    display_table(normalized_values)
Force-pushed from 97d439f to 89ad79c
LGTM. Are SHAP values deterministic? Might be useful to include a test that checks values for a dummy dataset (maybe certain edge cases)
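Something along these lines, for example (the fixture and pipeline names here are placeholders, not necessarily what the test suite uses):

def test_shap_values_are_deterministic(pipeline, X_y_binary):
    # Placeholder fixtures: a fitted tree-based pipeline and a small binary dataset.
    X, y = X_y_binary
    pipeline.fit(X, y)
    # Two runs on identical inputs should produce identical values.
    assert _compute_shap_values(pipeline, X, X) == _compute_shap_values(pipeline, X, X)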
from evalml.problem_types.problem_types import ProblemTypes


def _is_tree_estimator(model_family):
Might be nice to include this as part of the enum implementation but not blocking ofc
Great idea!
    Returns:
        dictionary
    """
    assert isinstance(shap_values, np.ndarray), "SHAP values must be stored in a numpy array!"
Maybe raise a ValueError instead here to fit our usual convention
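i.e. something like:

import numpy as np

if not isinstance(shap_values, np.ndarray):
    raise ValueError("SHAP values must be stored in a numpy array!")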
        raise NotImplementedError("SHAP values cannot currently be computed for catboost models for multiclass problems.")
    # Use tree_path_dependent to avoid linear runtime with dataset size
    with warnings.catch_warnings(record=True):
        explainer = shap.TreeExplainer(estimator._component_obj, feature_perturbation="tree_path_dependent")
What kind of warnings pop up here? If we're setting record=True, maybe we can process them into a more readable form for users.
Agreed, is there a way we can log these warnings?
Yea, I think it's a good idea to log these warnings! shap issues the following warning (Setting feature_perturbation = "tree_path_dependent" because no background data was given) when you don't pass in training data to the TreeExplainer, but their example code doesn't do so either, so I didn't want to show it to the user and have them worry.
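If we do want to keep them somewhere, a rough sketch of logging them (using the stdlib logging module here; evalml's logger could be swapped in, and estimator is the same object as in the snippet above):

import logging
import warnings

import shap

logger = logging.getLogger(__name__)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    explainer = shap.TreeExplainer(estimator._component_obj, feature_perturbation="tree_path_dependent")
for warning in caught:
    # Forward shap's warnings to the log instead of showing them to the user.
    logger.debug("shap warning: %s", warning.message)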
@patch("evalml.pipelines.explanations._algorithms.shap.TreeExplainer")
def test_value_errors_raised(mock_tree_explainer, pipeline, exception, match):

    if "xgboost" in pipeline.name.lower():
very cool
Added nonblocking comments but implementation looks good!
Q: do we need evalml/pipelines/explanations/__init__.py (maybe yes for later when we add public methods) and evalml/tests/pipeline_tests/explanations_tests/__init__.py? Both are empty files 🤔
        explainer = shap.KernelExplainer(decision_function, sampled_training_data_features, link_function)
        shap_values = explainer.shap_values(pipeline_features)
Same here, what warnings are we trying to catch, and why?
def _normalize_values_dict(values):
Can we add a docstring for this even though it's private, since the logic is complex enough that it'd be helpful? :D
Totally.
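Something along these lines, as a starting point (the description of what it does is approximate and should be tightened to match the implementation):

def _normalize_values_dict(values):
    """Normalizes a dictionary mapping each feature name to a list of SHAP values.

    Arguments:
        values (dict): Mapping of feature name to a list of SHAP values, one value per data point.

    Returns:
        dict: Mapping of feature name to a list of normalized SHAP values.
    """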
        new_min (float): New minimum value.
        new_max (float): New maximum value.
Could just be because this isn't implemented yet, but these are listed in the docstring but not in the code?
Typo! Thanks for catching!
def not_xgboost_or_baseline(estimator):
    """Filter out xgboost and baselines for next test since they are not supported."""
    return estimator.model_family not in {ModelFamily.XGBOOST, ModelFamily.BASELINE}
Does this need to be a separate method if it's only used once? Could just calculate it inline for interpretable_estimators.
I was itching to use filter and I knew pylint wouldn't like a lambda function, so I went for a function. I think a simple list comprehension here would do 👍
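e.g. something like (all_estimators here is just a stand-in for however the estimator list is built in this test module):

interpretable_estimators = [e for e in all_estimators
                            if e.model_family not in {ModelFamily.XGBOOST, ModelFamily.BASELINE}]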
all_n_points_to_explain = [1, 5]


@pytest.mark.parametrize("estimator, problem_type,n_points_to_explain",
Spacing ("estimator, problem_type,n_points_to_explain") is inconsistent heh
@freddyaboulton Also, this is some really cool stuff!! I had a fun time reading through the DD and this :D
Force-pushed from 89ad79c to ea948c5
Force-pushed from ee74784 to 88ba8f4
Looks great!! I left some comments and questions but nothing blocking.
Two closing thoughts on the code organization:
- Let's move explanations to prediction_explanations
- I think it's fine to name the algo file algorithms.py instead of _algorithms.py. I know we can easily move this in your next PRs too, just throwing that out there. I'm in favor of keeping the methods inside prefixed with _ to indicate they're private, as you have done.
evalml/model_family/model_family.py (Outdated)

    """Checks whether the estimator's model family uses tree ensembles."""
    tree_estimators = {cls.CATBOOST, cls.EXTRA_TREES, cls.RANDOM_FOREST,
                       cls.XGBOOST}
    return model_family in tree_estimators
👍
    # Sklearn components do this under-the-hood so we're not changing the data the model was trained on.
    pipeline_features = check_array(pipeline_features.values)

    if ModelFamily.is_tree_estimator(estimator.model_family):
Style suggestion: estimator.model_family.is_tree_estimator?
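i.e. roughly this shape (sketch only; the member values here are illustrative, not the real enum definition):

from enum import Enum

class ModelFamily(Enum):
    # Only the members needed for the sketch.
    RANDOM_FOREST = "random_forest"
    XGBOOST = "xgboost"
    CATBOOST = "catboost"
    EXTRA_TREES = "extra_trees"
    BASELINE = "baseline"

    def is_tree_estimator(self):
        """Checks whether this model family uses tree ensembles."""
        return self in {ModelFamily.CATBOOST, ModelFamily.EXTRA_TREES,
                        ModelFamily.RANDOM_FOREST, ModelFamily.XGBOOST}

so the call site would read estimator.model_family.is_tree_estimator().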
if "xgboost" in pipeline.name.lower(): | ||
pytest.importorskip("xgboost", "Skipping test because xgboost is not installed.") | ||
if "catboost" in pipeline.name.lower(): | ||
pytest.importorskip("catboost", "Skipping test because catboost is not installed.") |
Nice
    assert all(len(values) == N_FEATURES for values in shap_values), "A SHAP value must be computed for every feature!"
    for class_values in shap_values:
        assert all(isinstance(feature, list) for feature in class_values.values()), "Every value in the dict must be a list!"
        assert all(len(v) == n_points_to_explain for v in class_values.values()), "A SHAP value must be computed for every data point to explain!"
Awesome! Looks good.
This is reminding me we'd benefit from having some sort of fuzzy match checker in our unit tests, because then we could simply say assert shap_values == ...
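For the numeric part, pytest.approx gets us most of the way there already, e.g. (the expected numbers below are made up):

import pytest

expected = {"feature_a": [0.12, -0.05], "feature_b": [-0.30, 0.18]}
assert all(shap_values[name] == pytest.approx(vals, abs=1e-6)
           for name, vals in expected.items())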
    if isinstance(shap_values, dict):
        check_regression(shap_values, n_points_to_explain=n_points_to_explain)
    else:
        check_classification(shap_values, True, n_points_to_explain=n_points_to_explain)
So if the problem type was binary, why would we call check_regression here?
Yes, but I followed your guidance in your comment below and this is no longer the case!
    training_data, y = X_y_regression
    pipeline_class = make_pipeline(training_data, y, estimator, problem_type)
    shap_values = calculate_shap_for_test(training_data, y, pipeline_class, n_points_to_explain)
    check_regression(shap_values, n_points_to_explain)
I wonder if this test would get simpler if you did

if problem_type == ProblemTypes.BINARY:
    training_data, y = X_y_binary
elif problem_type == ProblemTypes.MULTICLASS:
    training_data, y = X_y_multi
elif problem_type == ProblemTypes.REGRESSION:
    training_data, y = X_y_regression

then

pipeline_class = make_pipeline(training_data, y, estimator, problem_type)

then

if problem_type in [ProblemTypes.BINARY, ProblemTypes.MULTICLASS]:
    check_classification(shap_values, False, n_points_to_explain)
elif problem_type == ProblemTypes.REGRESSION:
    check_regression(shap_values, n_points_to_explain)

And at that point, there'd be no need for the check_classification/check_regression helpers because you could just put that code at the bottom.
Great idea! Thanks for the thoughtful tip!
    if estimator.model_family == ModelFamily.BASELINE:
        raise ValueError("You passed in a baseline pipeline. These are simple enough that SHAP values are not needed.")

    pipeline_features = pipeline._transform(features)
@freddyaboulton could you please add a docstring in PipelineBase._transform? I think it's fine to be calling this here, and it also feels fine to not make PipelineBase._transform public for now. But a docstring would help.
Definitely.
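Something like this, perhaps (the wording of the return description is a guess and should be checked against the implementation):

def _transform(self, X):
    """Applies all of the pipeline's component transformations (everything except the final estimator) to X.

    Arguments:
        X (pd.DataFrame or np.ndarray): Input features, shape [n_samples, n_features].

    Returns:
        Transformed features, ready to be passed to the final estimator.
    """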
…ow output a list for Catboost binary problems.
@dsherry Thanks for the feedback! I moved
Pull Request Description
Closes #952 (task 1 of the prediction explanations epic). The only difference is that I did not implement a compute_features function in PipelineBase since the _transform method already did what I needed.

After creating the pull request: in order to pass the release_notes_updated check you will need to update the "Future Release" section of docs/source/release_notes.rst to include this pull request by adding :pr:123.