Explaining Predictions with SHAP #958

Merged
merged 11 commits into from Jul 24, 2020

@freddyaboulton freddyaboulton commented Jul 21, 2020

Pull Request Description

Closes #952 (task 1 of the prediction explanations epic). The only difference is that I did not implement a compute_features function in PipelineBase since the _transform method already did what I needed.


After creating the pull request: in order to pass the release_notes_updated check you will need to update the "Future Release" section of docs/source/release_notes.rst to include this pull request by adding :pr:123.

@freddyaboulton freddyaboulton marked this pull request as draft Jul 21, 2020

codecov bot commented Jul 21, 2020

Codecov Report

Merging #958 into main will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff            @@
##             main     #958    +/-   ##
========================================
  Coverage   99.87%   99.87%            
========================================
  Files         172      174     +2     
  Lines        8824     8987   +163     
========================================
+ Hits         8813     8976   +163     
  Misses         11       11            
Impacted Files Coverage Δ
evalml/pipelines/pipeline_base.py 100.00% <ø> (ø)
evalml/model_family/model_family.py 100.00% <100.00%> (ø)
...l/pipelines/prediction_explanations/_algorithms.py 100.00% <100.00%> (ø)
...peline_tests/explanations_tests/test_algorithms.py 100.00% <100.00%> (ø)
evalml/tests/utils_tests/test_cli_utils.py 100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 2c440a6...56f853b. Read the comment docs.

@freddyaboulton freddyaboulton force-pushed the 952-explain-predictions-with-shap branch from 12da727 to ca19d39 Compare Jul 21, 2020
@@ -10,3 +10,4 @@ cloudpickle>=0.2.2
click>=7.0.0
psutil>=5.6.3
requirements-parser>=0.2.0
shap>=0.35.0

@freddyaboulton freddyaboulton Jul 21, 2020

I think this is a reasonable core requirement, but I can move it to requirements.txt if needed (although this would require delaying the shap import in _algorithms to get the testing to work).

@angela97lin angela97lin Jul 22, 2020

Looks like the wheel is just ~309.4 kB, but curious about AM and what they'd think 🤔

@dsherry dsherry Jul 24, 2020

I'm fine to add this to core requirements. I just checked the installer size and it's under 1 MB. As long as there aren't any hairy pip dependency clashes, we should be good to go.

    raise NotImplementedError("SHAP values cannot currently be computed for xgboost models.")
if estimator.model_family == ModelFamily.CATBOOST and pipeline.problem_type == ProblemTypes.MULTICLASS:
    # Will randomly segfault
    raise NotImplementedError("SHAP values cannot currently be computed for catboost models for multiclass problems.")

@freddyaboulton freddyaboulton Jul 21, 2020

Not sure why the segfaults are happening but it is definitely related to the shap library. I think it is ok to not support catboost multiclass estimators for now while I investigate.

@freddyaboulton freddyaboulton marked this pull request as ready for review Jul 21, 2020
return {feature_name: scaled_values[:, i].tolist() for i, feature_name in enumerate(sorted_feature_names)}


def _normalize_shap_values(values):

@freddyaboulton freddyaboulton Jul 22, 2020

I am planning on using _normalize_shap_values for the implementation of explain_prediction (which will be added in the next PR in the epic, along with display_table):

def explain_prediction(pipeline, features, training_data):
    shap_values = _compute_shap_values(pipeline, features, training_data)
    normalized_values = _normalize_shap_values(shap_values)
    display_table(normalized_values)
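
For readers wondering what a normalization step of this kind could look like, here is a hedged sketch. It assumes the values are stored as a dictionary mapping each feature name to a list of SHAP values (one per data point) and that normalizing means scaling each data point's values so their absolute values sum to 1. Both the rule and the name `normalize_by_abs_sum` are assumptions for illustration; the thread does not spell out the actual scaling.

```python
def normalize_by_abs_sum(values):
    """Scale each data point's SHAP values so their absolute values sum to 1.

    `values` maps feature name -> list of SHAP values, one per data point.
    Assumption: this scaling rule is illustrative, not confirmed by the PR.
    """
    feature_names = sorted(values)
    n_points = len(values[feature_names[0]]) if feature_names else 0
    normalized = {name: [] for name in feature_names}
    for i in range(n_points):
        total = sum(abs(values[name][i]) for name in feature_names)
        for name in feature_names:
            # Map an all-zero row to zeros to avoid dividing by zero.
            scaled = values[name][i] / total if total else 0.0
            normalized[name].append(scaled)
    return normalized
```

Under this rule the sign of each value is preserved, so a table built from the result still shows whether a feature pushed the prediction up or down.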

@freddyaboulton freddyaboulton force-pushed the 952-explain-predictions-with-shap branch 2 times, most recently from 97d439f to 89ad79c Compare Jul 22, 2020

@jeremyliweishih jeremyliweishih left a comment

LGTM. Are SHAP values deterministic? Might be useful to include a test that checks values for a dummy dataset (maybe certain edge cases)

from evalml.problem_types.problem_types import ProblemTypes


def _is_tree_estimator(model_family):

@jeremyliweishih jeremyliweishih Jul 22, 2020

Might be nice to include this as part of the enum implementation but not blocking ofc

@freddyaboulton freddyaboulton Jul 23, 2020

Great idea!

Returns:
dictionary
"""
assert isinstance(shap_values, np.ndarray), "SHAP values must be stored in a numpy array!"

@jeremyliweishih jeremyliweishih Jul 22, 2020

Maybe raise a ValueError instead here to fit our usual convention
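
A sketch of what this suggestion amounts to: an explicit check that raises `ValueError`, which, unlike `assert`, is not stripped when Python runs with `-O`. The function name `create_dictionary` is hypothetical, and a plain list of rows stands in for the numpy array used in the PR so the example has no third-party imports.

```python
def create_dictionary(shap_values, feature_names):
    """Map each feature name to its column of SHAP values.

    Assumption: `shap_values` is a list of rows here; the PR uses a numpy array.
    """
    if not isinstance(shap_values, list):
        raise ValueError("SHAP values must be stored in a list of rows!")
    if any(len(row) != len(feature_names) for row in shap_values):
        raise ValueError("Every row must have one SHAP value per feature!")
    # Transpose rows (data points) into per-feature columns.
    return {name: [row[i] for row in shap_values] for i, name in enumerate(feature_names)}
```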

    raise NotImplementedError("SHAP values cannot currently be computed for catboost models for multiclass problems.")
# Use tree_path_dependent to avoid linear runtime with dataset size
with warnings.catch_warnings(record=True):
    explainer = shap.TreeExplainer(estimator._component_obj, feature_perturbation="tree_path_dependent")

@jeremyliweishih jeremyliweishih Jul 22, 2020

What kind of warnings pop up here? If we're setting record=True maybe we can process into a more readable form for users.

@angela97lin angela97lin Jul 22, 2020

Agreed, is there a way we can log these warnings?

@freddyaboulton freddyaboulton Jul 22, 2020

Yea, I think it's a good idea to log these warnings! shap issues the following warning (Setting feature_perturbation = "tree_path_dependent" because no background data was given) when you don't pass training data to the TreeExplainer. Their example code doesn't pass it either, so I didn't want to show the warning to the user and have them worry.
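
One way to keep such warnings away from the user while still logging them is sketched below. It uses only the standard `warnings` and `logging` modules; `call_and_log_warnings` is a hypothetical helper, and `noisy_explainer` is a stand-in that emits the same message as shap rather than a call into shap itself.

```python
import logging
import warnings

logger = logging.getLogger(__name__)


def call_and_log_warnings(func, *args, **kwargs):
    """Run func, keep its warnings off the console, and log them at DEBUG level."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, even repeated ones
        result = func(*args, **kwargs)
    for w in caught:
        logger.debug("%s: %s", w.category.__name__, w.message)
    return result, [str(w.message) for w in caught]


def noisy_explainer():
    # Stand-in for shap.TreeExplainer, which is not imported in this sketch.
    warnings.warn('Setting feature_perturbation = "tree_path_dependent" '
                  'because no background data was given')
    return "explainer"
```

Calling `call_and_log_warnings(noisy_explainer)` returns the result together with the captured messages, so the caller can decide whether any of them deserve to be surfaced.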

@patch("evalml.pipelines.explanations._algorithms.shap.TreeExplainer")
def test_value_errors_raised(mock_tree_explainer, pipeline, exception, match):

    if "xgboost" in pipeline.name.lower():

@jeremyliweishih jeremyliweishih Jul 22, 2020

very cool

@angela97lin angela97lin left a comment

Added nonblocking comments but implementation looks good!

Q: do we need evalml/pipelines/explanations/__init__.py (maybe yes for later when we add public methods) and evalml/tests/pipeline_tests/explanations_tests/__init__.py? Both are empty files 🤔

explainer = shap.KernelExplainer(decision_function, sampled_training_data_features, link_function)
shap_values = explainer.shap_values(pipeline_features)

@angela97lin angela97lin Jul 22, 2020

Same here, what warnings are we trying to catch, and why?



def _normalize_values_dict(values):

@angela97lin angela97lin Jul 22, 2020

Can we add a docstring for this even though it's private since logic is complex enough that it'd be helpful? :D

@freddyaboulton freddyaboulton Jul 22, 2020

Totally.

new_min (float): New minimum value.
new_max (float): New maximum value.

@angela97lin angela97lin Jul 22, 2020

Could just be because this isn't implemented yet, but these are listed in the docstring and not in the code?

@freddyaboulton freddyaboulton Jul 22, 2020

Typo! Thanks for catching!

def not_xgboost_or_baseline(estimator):
    """Filter out xgboost and baselines for next test since they are not supported."""
    return estimator.model_family not in {ModelFamily.XGBOOST, ModelFamily.BASELINE}

@angela97lin angela97lin Jul 22, 2020

Does this need to be a separate method if it's only used once? Could just calculate in line for interpretable_estimators

@freddyaboulton freddyaboulton Jul 22, 2020

I was itching to use filter and I knew pylint wouldn't like a lambda function so I went for a function. I think a simple list comprehension here would do 👍
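
The list comprehension mentioned here might look like the sketch below. The `ModelFamily` enum is a trimmed stand-in with illustrative member values, and `FakeEstimator` and `all_estimators` are hypothetical placeholders for the real estimator classes used in the test file.

```python
from enum import Enum


class ModelFamily(Enum):  # trimmed stand-in for evalml's enum
    RANDOM_FOREST = "random_forest"
    XGBOOST = "xgboost"
    LINEAR_MODEL = "linear_model"
    BASELINE = "baseline"


class FakeEstimator:
    """Hypothetical estimator that only carries a model family."""

    def __init__(self, model_family):
        self.model_family = model_family


all_estimators = [FakeEstimator(family) for family in ModelFamily]
unsupported = {ModelFamily.XGBOOST, ModelFamily.BASELINE}

# Inline replacement for filter(not_xgboost_or_baseline, all_estimators).
interpretable_estimators = [e for e in all_estimators if e.model_family not in unsupported]
```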

all_n_points_to_explain = [1, 5]


@pytest.mark.parametrize("estimator, problem_type,n_points_to_explain",

@angela97lin angela97lin Jul 22, 2020

Spacing ("estimator, problem_type,n_points_to_explain") is inconsistent heh

angela97lin commented Jul 22, 2020

@freddyaboulton Also, this is some really cool stuff!! I had a fun time reading through the DD and this :D

@freddyaboulton freddyaboulton force-pushed the 952-explain-predictions-with-shap branch from 89ad79c to ea948c5 Compare Jul 23, 2020

@dsherry dsherry left a comment

Looks great!! I left some comments and questions but nothing blocking.

Two closing thoughts on the code organization:

  • Let's move explanations to prediction_explanations
  • I think it's fine to name the algo file algorithms.py instead of _algorithms.py. I know we can easily move this in your next PRs too, just throwing that out there. I'm in favor of keeping the methods inside prefixed with _ to indicate they're private, as you have done.

"""Checks whether the estimator's model family uses tree ensembles."""
tree_estimators = {cls.CATBOOST, cls.EXTRA_TREES, cls.RANDOM_FOREST,
                   cls.XGBOOST}
return model_family in tree_estimators

@dsherry dsherry Jul 24, 2020

👍

# Sklearn components do this under-the-hood so we're not changing the data the model was trained on.
pipeline_features = check_array(pipeline_features.values)

if ModelFamily.is_tree_estimator(estimator.model_family):

@dsherry dsherry Jul 24, 2020

Style suggestion: estimator.model_family.is_tree_estimator ?
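
A sketch of this style suggestion: turning the check into an instance method lets callers write `estimator.model_family.is_tree_estimator()`. The enum below is a trimmed stand-in with illustrative member values, not the full evalml `ModelFamily`, and whether the real code uses a method or a property is left open here.

```python
from enum import Enum


class ModelFamily(Enum):  # trimmed stand-in; member values are illustrative
    RANDOM_FOREST = "random_forest"
    XGBOOST = "xgboost"
    CATBOOST = "catboost"
    EXTRA_TREES = "extra_trees"
    LINEAR_MODEL = "linear_model"

    def is_tree_estimator(self):
        """Checks whether this model family uses tree ensembles."""
        tree_estimators = {ModelFamily.CATBOOST, ModelFamily.EXTRA_TREES,
                           ModelFamily.RANDOM_FOREST, ModelFamily.XGBOOST}
        return self in tree_estimators
```

Compared with the classmethod form `ModelFamily.is_tree_estimator(estimator.model_family)`, the instance method avoids naming the enum twice at the call site.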

if "xgboost" in pipeline.name.lower():
    pytest.importorskip("xgboost", reason="Skipping test because xgboost is not installed.")
if "catboost" in pipeline.name.lower():
    pytest.importorskip("catboost", reason="Skipping test because catboost is not installed.")

@dsherry dsherry Jul 24, 2020

Nice

assert all(len(values) == N_FEATURES for values in shap_values), "A SHAP value must be computed for every feature!"
for class_values in shap_values:
    assert all(isinstance(feature, list) for feature in class_values.values()), "Every value in the dict must be a list!"
    assert all(len(v) == n_points_to_explain for v in class_values.values()), "A SHAP value must be computed for every data point to explain!"

@dsherry dsherry Jul 24, 2020

Awesome! Looks good.

This is reminding me we'd benefit from having some sort of fuzzy match checker in our unit tests, because then we could simply say assert shap_values == ...
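
A fuzzy match checker of the kind described could be sketched as below. `approx_equal` is a hypothetical helper built on the standard-library `math.isclose`; the tolerances are illustrative defaults, not values taken from the project.

```python
import math


def approx_equal(actual, expected, rel_tol=1e-9, abs_tol=1e-12):
    """Recursively compare nested dicts/lists, treating floats as approximately equal."""
    if isinstance(expected, dict):
        return (isinstance(actual, dict) and actual.keys() == expected.keys()
                and all(approx_equal(actual[k], expected[k], rel_tol, abs_tol)
                        for k in expected))
    if isinstance(expected, (list, tuple)):
        return (isinstance(actual, (list, tuple)) and len(actual) == len(expected)
                and all(approx_equal(a, e, rel_tol, abs_tol)
                        for a, e in zip(actual, expected)))
    if isinstance(expected, float):
        return (isinstance(actual, (int, float))
                and math.isclose(actual, expected, rel_tol=rel_tol, abs_tol=abs_tol))
    # Anything else (strings, ints, None) must match exactly.
    return actual == expected
```

A test could then read `assert approx_equal(shap_values, expected_values)` for a dictionary of per-feature lists.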

if isinstance(shap_values, dict):
    check_regression(shap_values, n_points_to_explain=n_points_to_explain)
else:
    check_classification(shap_values, True, n_points_to_explain=n_points_to_explain)

@dsherry dsherry Jul 24, 2020

So if the problem type was binary, why would we call check_regression here?

@freddyaboulton freddyaboulton Jul 24, 2020

Yes, but I followed your guidance in your comment below and this is no longer the case!

training_data, y = X_y_regression
pipeline_class = make_pipeline(training_data, y, estimator, problem_type)
shap_values = calculate_shap_for_test(training_data, y, pipeline_class, n_points_to_explain)
check_regression(shap_values, n_points_to_explain)

@dsherry dsherry Jul 24, 2020

I wonder if this test would get simpler if you did

if problem_type == ProblemTypes.BINARY:
    training_data, y = X_y_binary
elif problem_type == ProblemTypes.MULTICLASS:
    training_data, y = X_y_multi
elif problem_type == ProblemTypes.REGRESSION:
    training_data, y = X_y_regression

then

pipeline_class = make_pipeline(training_data, y, estimator, problem_type)

then

if problem_type in [ProblemTypes.BINARY, ProblemTypes.MULTICLASS]:
    check_classification(shap_values, False, n_points_to_explain)
elif problem_type == ProblemTypes.REGRESSION:
    check_regression(shap_values, n_points_to_explain)

And at that point, there'd be no need for the check_classification/check_regression helpers because you could just put that code at the bottom

@freddyaboulton freddyaboulton Jul 24, 2020

Great idea! Thanks for the thoughtful tip!

if estimator.model_family == ModelFamily.BASELINE:
    raise ValueError("You passed in a baseline pipeline. These are simple enough that SHAP values are not needed.")

pipeline_features = pipeline._transform(features)

@dsherry dsherry Jul 24, 2020

@freddyaboulton could you please add a docstring in PipelineBase._transform? I think it's fine to be calling this here, and it also feels fine to not make PipelineBase._transform public for now. But a docstring would help.

@freddyaboulton freddyaboulton Jul 24, 2020

Definitely.

…ow output a list for Catboost binary problems.

freddyaboulton commented Jul 24, 2020

@dsherry Thanks for the feedback! I moved explanations to prediction_explanations and I'll remove the leading underscores from function names in a later PR!

@freddyaboulton freddyaboulton merged commit e591395 into main Jul 24, 2020
2 checks passed
@freddyaboulton freddyaboulton deleted the 952-explain-predictions-with-shap branch Jul 24, 2020
@angela97lin angela97lin mentioned this pull request Jul 31, 2020

Successfully merging this pull request may close these issues.

Explain Predictions with SHAP algorithm
4 participants