
Compute Shap Values for XGBoost #2162

Merged

Conversation

freddyaboulton (Contributor) commented Apr 20, 2021

Pull Request Description

Fixes #1890

Sample output on the fraud dataset

XGBoost Classifier w/ Imputer + DateTime Featurization Component + One Hot Encoder

{'Imputer': {'categorical_impute_strategy': 'most_frequent', 'numeric_impute_strategy': 'mean', 'categorical_fill_value': None, 'numeric_fill_value': None}, 'DateTime Featurization Component': {'features_to_extract': ['year', 'month', 'day_of_week', 'hour'], 'encode_as_categories': False}, 'One Hot Encoder': {'top_n': 10, 'features_to_encode': None, 'categories': None, 'drop': 'if_binary', 'handle_unknown': 'ignore', 'handle_missing': 'error'}, 'XGBoost Classifier': {'eta': 0.1, 'max_depth': 6, 'min_child_weight': 1, 'n_estimators': 100}}

	Best 1 of 5

		Predicted Probabilities: [False: 0.998, True: 0.002]
		Predicted Value: False
		Target Value: False
		Cross Entropy: 0.002
		Index ID: 159

		Feature Name      Feature Value      Contribution to Prediction   SHAP Value
		============================================================================
		  card_id            670.00                      +                   0.29   
		  store_id           6079.00                     -                  -0.54   
		  provider        VISA 16 digit                  -                  -0.54   
		  datetime     2019-03-27 15:00:00               -                  -0.76   
		   amount           81184.00                     --                 -1.73   


	Best 2 of 5

		Predicted Probabilities: [False: 0.998, True: 0.002]
		Predicted Value: False
		Target Value: False
		Cross Entropy: 0.002
		Index ID: 81

		Feature Name      Feature Value      Contribution to Prediction   SHAP Value
		============================================================================
		  datetime     2019-08-04 19:43:30               -                  -0.33   
		  card_id           29175.00                     -                  -0.43   
		    lat               22.47                      -                  -0.49   
		  store_id           7149.00                     -                  -0.50   
		   amount           76820.00                     --                 -1.35   

Sample output on the iris dataset

XGBoost Classifier w/ Imputer

{'Imputer': {'categorical_impute_strategy': 'most_frequent', 'numeric_impute_strategy': 'mean', 'categorical_fill_value': None, 'numeric_fill_value': None}, 'XGBoost Classifier': {'eta': 0.1, 'max_depth': 6, 'min_child_weight': 1, 'n_estimators': 100}}

	Best 1 of 5

		Predicted Probabilities: [setosa: 0.003, versicolor: 0.01, virginica: 0.987]
		Predicted Value: virginica
		Target Value: virginica
		Cross Entropy: 0.013
		Index ID: 26

		Class: setosa
		
		  Feature Name      Feature Value   Contribution to Prediction   SHAP Value
		===========================================================================
		sepal width (cm)        2.50                    +                   0.00   
		petal width (cm)        1.90                    +                   0.00   
		sepal length (cm)       6.30                    -                  -0.01   
		petal length (cm)       5.00                  -----                -2.86   
		
		
		Class: versicolor
		
		  Feature Name      Feature Value   Contribution to Prediction   SHAP Value
		===========================================================================
		sepal length (cm)       6.30                    +                   0.28   
		sepal width (cm)        2.50                    -                  -0.13   
		petal width (cm)        1.90                    --                 -0.79   
		petal length (cm)       5.00                   ---                 -1.10   
		
		
		Class: virginica
		
		  Feature Name      Feature Value   Contribution to Prediction   SHAP Value
		===========================================================================
		petal width (cm)        1.90                   +++                  1.79   
		petal length (cm)       5.00                    ++                  1.61   
		sepal width (cm)        2.50                    +                   0.73   
		sepal length (cm)       6.30                    -                  -0.02   

Sample output on the Boston housing dataset

XGBoost Regressor w/ Imputer

{'Imputer': {'categorical_impute_strategy': 'most_frequent', 'numeric_impute_strategy': 'mean', 'categorical_fill_value': None, 'numeric_fill_value': None}, 'XGBoost Regressor': {'eta': 0.1, 'max_depth': 6, 'min_child_weight': 1, 'n_estimators': 100}}

	Best 1 of 5

		Predicted Value: 10.815
		Target Value: 10.8
		Absolute Difference: 0.015
		Index ID: 40

		Feature Name   Feature Value   Contribution to Prediction   SHAP Value
		======================================================================
		    TAX           666.00                   -                  -0.34   
		  PTRATIO          20.20                   -                  -0.40   
		    NOX            0.74                    -                  -0.61   
		     RM            5.85                    -                  -2.33   
		   LSTAT           23.79                  ----                -9.58   


	Best 2 of 5

		Predicted Value: 20.619
		Target Value: 20.6
		Absolute Difference: 0.019
		Index ID: 83

		Feature Name   Feature Value   Contribution to Prediction   SHAP Value
		======================================================================
		    CRIM           4.84                    +                   1.64   
		    DIS            3.15                    +                   1.05   
		    TAX           666.00                   -                  -0.31   
		     RM            5.91                    --                 -2.71   
		   LSTAT           11.45                   --                 -3.22   
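For context, here is a minimal standalone sketch of computing SHAP values for an XGBoost model with the shap package directly. This is not evalml's internal `_compute_shap_values`; the toy data, hyperparameters, and variable names are illustrative only:

```python
import numpy as np
import shap
import xgboost as xgb

# Hypothetical toy data: 100 rows, 4 features, binary target.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

model = xgb.XGBClassifier(n_estimators=100, max_depth=6, learning_rate=0.1).fit(X, y)

# TreeExplainer computes exact SHAP values for tree ensembles.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# For a binary XGBoost model, shap returns a single (n_rows, n_features)
# array for the positive class (in log-odds units).
print(np.asarray(shap_values).shape)  # (100, 4)
```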


codecov bot commented Apr 20, 2021

Codecov Report

Merging #2162 (633f16b) into main (04959d4) will increase coverage by 0.1%.
The diff coverage is 100.0%.


@@            Coverage Diff            @@
##             main    #2162     +/-   ##
=========================================
+ Coverage   100.0%   100.0%   +0.1%     
=========================================
  Files         295      295             
  Lines       24362    24376     +14     
=========================================
+ Hits        24352    24366     +14     
  Misses         10       10             
Impacted Files Coverage Δ
...derstanding/prediction_explanations/_algorithms.py 98.9% <100.0%> (+1.1%) ⬆️
...s/prediction_explanations_tests/test_algorithms.py 99.2% <100.0%> (-0.8%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

freddyaboulton force-pushed the 1890-add-support-for-xgboost-prediction-explanations branch from c2cdb5a to 48799a3 on April 20, 2021 16:41
@@ -9,7 +9,7 @@ cloudpickle>=0.2.2
click>=7.0.0
psutil>=5.6.3
requirements-parser>=0.2.0
-shap>=0.35.0
+shap>=0.36.0
freddyaboulton (Contributor Author): Fix was introduced in 0.36.0.

@jeremyliweishih (Contributor) left a comment:

LGTM, just a clarifying question.

@@ -144,6 +140,29 @@ def test_shap(estimator, problem_type, n_points_to_explain, X_y_binary, X_y_mult
shap_values.values()), "A SHAP value must be computed for every data point to explain!"


@patch('evalml.model_understanding.prediction_explanations._algorithms.logger')
@patch('shap.TreeExplainer')
def test_compute_shap_values_catches_shap_tree_warnings(mock_tree_explainer, mock_debug, X_y_binary, caplog):
Contributor: Is this new functionality from 0.36.0, and is this test just making sure we catch it?

freddyaboulton (Contributor Author): Thanks for asking! I meant to post a comment but forgot. Yes, for some reason the warnings stopped being shown when I upgraded to 0.36. I'm not sure of the root cause, but rather than deleting the code in question or merging an uncovered line, I decided to add a test for it.
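For illustration, here is a simplified, hypothetical sketch of the mocking pattern (not the actual test body from test_algorithms.py): the patched TreeExplainer is made to emit a warning, which is what the production code is expected to catch and log.

```python
import warnings
from unittest.mock import MagicMock, patch

@patch('shap.TreeExplainer')
def test_tree_explainer_warning_is_emitted(mock_tree_explainer):
    # Simulate shap emitting a warning during explainer construction.
    def explainer_that_warns(*args, **kwargs):
        warnings.warn("simulated shap TreeExplainer warning", UserWarning)
        return MagicMock()

    mock_tree_explainer.side_effect = explainer_that_warns

    import shap
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        shap.TreeExplainer(MagicMock())  # stand-in for the call inside _compute_shap_values
    assert "simulated shap TreeExplainer warning" in str(caught[0].message)
```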

Contributor: Good call!

@bchen1116 (Contributor) left a comment:

LGTM! Left a nit-pick, but it doesn't block the merge!

@@ -93,7 +89,7 @@ def calculate_shap_for_test(training_data, y, pipeline, n_points_to_explain):
return _compute_shap_values(pipeline, pd.DataFrame(points_to_explain), training_data)


-interpretable_estimators = [e for e in _all_estimators_used_in_search() if e.model_family not in {ModelFamily.XGBOOST, ModelFamily.BASELINE}]
+interpretable_estimators = [e for e in _all_estimators_used_in_search() if e.model_family not in {ModelFamily.BASELINE}]
Contributor: Not necessary and very nit-picky, but why not change it to e.model_family != ModelFamily.BASELINE to make it cleaner, since we're only checking for one object?
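i.e., under this suggestion the comprehension would read (import paths shown are assumptions about evalml's layout):

```python
from evalml.model_family import ModelFamily
from evalml.pipelines.components.utils import _all_estimators_used_in_search

interpretable_estimators = [
    e for e in _all_estimators_used_in_search()
    if e.model_family != ModelFamily.BASELINE
]
```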

freddyaboulton (Contributor Author): Good catch!

@@ -425,7 +425,7 @@
"source": [
"The interpretation of the table is the same for regression problems - but the SHAP value now corresponds to the change in the estimated value of the dependent variable rather than a change in probability. For multiclass classification problems, a table will be output for each possible class.\n",
"\n",
"This functionality is currently **not supported** for **XGBoost** models or **CatBoost multiclass** classifiers.\n",
"This functionality is currently **not supported** for **CatBoost multiclass** classifiers.\n",
Contributor: To confirm, this works for XGBoost multiclass too, then? :D

freddyaboulton (Contributor Author): Yes! 🎉

@angela97lin (Contributor) left a comment:

Just had a clarifying question, but looks good!! Cool that the shap package took care of this and we didn't have to write our own custom implementation 🥳

@dsherry (Contributor) left a comment:

🚢 😁

# this modifies the output to match the output format of other binary estimators.
# Ok to fill values of negative class with zeros since the negative class will get dropped
# in the UI anyways.
-if estimator.model_family == ModelFamily.CATBOOST and pipeline.problem_type == ProblemTypes.BINARY:
+if estimator.model_family in {ModelFamily.CATBOOST, ModelFamily.XGBOOST} and pipeline.problem_type == ProblemTypes.BINARY:
Contributor: I see, so xgboost's output format is the same as catboost's, and therefore the same reshaping is needed here. Sounds good. I think it's strange that shap has this behavior!
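A rough sketch of the reshaping the comment describes (hypothetical helper, not the actual evalml code):

```python
import numpy as np

def _pad_negative_class_with_zeros(shap_values):
    """For CatBoost/XGBoost binary models, shap returns a single
    (n_rows, n_features) array for the positive class. Prepend a zero
    array for the negative class so the output matches estimators that
    return one array per class; the zeros are dropped in the UI anyway."""
    if isinstance(shap_values, np.ndarray):
        return [np.zeros(shap_values.shape), shap_values]
    return shap_values  # already one array per class
```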

dsherry (Contributor), commenting on the same lines: @freddyaboulton No action needed in this PR, but this change reminds me of some work we'll eventually need to do for time series. There are a lot of places where we do

pipeline.problem_type == ProblemTypes.BINARY

instead of

is_binary(pipeline.problem_type)

which will catch time series binary as well.
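A quick sketch of the difference (assuming evalml exposes is_binary and the ProblemTypes enum from evalml.problem_types):

```python
from evalml.problem_types import ProblemTypes, is_binary

# Direct equality misses the time series variant:
assert ProblemTypes.TIME_SERIES_BINARY != ProblemTypes.BINARY

# The helper covers both plain and time series binary:
assert is_binary(ProblemTypes.BINARY)
assert is_binary(ProblemTypes.TIME_SERIES_BINARY)
```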

freddyaboulton (Contributor Author): Great call, @dsherry!

freddyaboulton force-pushed the 1890-add-support-for-xgboost-prediction-explanations branch from 557a412 to 633f16b on April 22, 2021 14:51
freddyaboulton merged commit 3caee9e into main on April 22, 2021.
freddyaboulton deleted the 1890-add-support-for-xgboost-prediction-explanations branch on April 22, 2021 15:44.
Linked issue: Prediction explanations: add support for xgboost (#1890)