
Prediction explanations: multiclass catboost support #2224

Merged
dsherry merged 6 commits into main from
ds_2136_shap_catboost_multiclass
May 4, 2021
@dsherry dsherry commented May 4, 2021

Fixes #2136

The following code works after this PR:

import evalml
X, y = evalml.demos.load_wine()
assert len(y.to_series().unique()) == 3
pipeline = evalml.pipelines.MulticlassClassificationPipeline(['Imputer', 'CatBoost Classifier'])
pipeline.fit(X, y)
print(evalml.model_understanding.prediction_explanations.explain_predictions(
    pipeline=pipeline, input_features=X, y=y, indices_to_explain=[0]))

which outputs

CatBoost Classifier w/ Imputer

{'Imputer': {'categorical_impute_strategy': 'most_frequent', 'numeric_impute_strategy': 'mean', 'categorical_fill_value': None, 'numeric_fill_value': None}, 'CatBoost Classifier': {'n_estimators': 10, 'eta': 0.03, 'max_depth': 6, 'bootstrap_type': None, 'silent': True, 'allow_writing_files': False}}

	1 of 1

		Class: class_0
		
		 Feature Name     Feature Value   Contribution to Prediction
		============================================================
		    proline          1065.00                  ++            
		    alcohol           14.23                   +             
		color_intensity       5.64                    +             
		
		
		Class: class_1
		
		 Feature Name     Feature Value   Contribution to Prediction
		============================================================
		    alcohol           14.23                   --            
		    proline          1065.00                  --            
		color_intensity       5.64                    --            
		
		
		Class: class_2
		
		Feature Name    Feature Value   Contribution to Prediction
		==========================================================
		   proline         1065.00                  -             
		 flavanoids         3.06                    -             
		total_phenols       2.80                    -             
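
Each per-class table above lists three features, apparently ranked by how strongly they push the prediction for that class. A minimal sketch of that selection step follows, assuming per-class contribution scores are already available as plain dictionaries. The function name `top_k_by_class` and all numbers are hypothetical illustrations, not evalml internals or real SHAP values for this dataset.

```python
from typing import Dict, List, Tuple


def top_k_by_class(
    contributions: Dict[str, Dict[str, float]], k: int = 3
) -> Dict[str, List[Tuple[str, float]]]:
    """For each class, keep the k features with the largest absolute contribution."""
    result = {}
    for class_name, per_feature in contributions.items():
        # Rank by magnitude so strong negative pushes rank as high as strong positive ones.
        ranked = sorted(per_feature.items(), key=lambda item: abs(item[1]), reverse=True)
        result[class_name] = ranked[:k]
    return result


# Made-up scores, for illustration only.
example = {
    "class_0": {"proline": 0.40, "alcohol": 0.15, "color_intensity": 0.10, "hue": 0.01},
    "class_1": {"proline": -0.30, "alcohol": -0.35, "color_intensity": -0.25, "hue": 0.02},
}

top = top_k_by_class(example, k=3)
for class_name, rows in top.items():
    print(class_name, [name for name, _ in rows])
```

Ranking by absolute value is a design choice in this sketch: it keeps features that argue strongly against a class, which is what the `--` rows for class_1 suggest the real report does as well.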

codecov bot commented May 4, 2021

Codecov Report

Merging #2224 (3c35a5e) into main (bc8460a) will decrease coverage by 0.1%.
The diff coverage is 100.0%.

@@            Coverage Diff            @@
##             main    #2224     +/-   ##
=========================================
- Coverage   100.0%   100.0%   -0.0%     
=========================================
  Files         288      288             
  Lines       24495    24489      -6     
=========================================
- Hits        24477    24470      -7     
- Misses         18       19      +1     

Impacted Files                                           Coverage Δ
...s/prediction_explanations_tests/test_algorithms.py   98.3% <ø> (-0.9%) ⬇️
...derstanding/prediction_explanations/_algorithms.py   98.9% <100.0%> (-<0.1%) ⬇️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@dsherry dsherry marked this pull request as ready for review May 4, 2021 13:03
    pytest.skip("Skipping because estimator and pipeline are not compatible.")

if problem_type == ProblemTypes.MULTICLASS and estimator.model_family == ModelFamily.CATBOOST:
    pytest.skip("Skipping Catboost for multiclass problems.")
Contributor Author

After deleting this check, this test does verify that multiclass catboost works. I verified this by changing the multiclass catboost impl to error, and saw that this test errored out.
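
The check described here (break the implementation on purpose and confirm the test notices) can be sketched in isolation. Everything below is a hypothetical stand-in: `compute_explanations`, `broken_compute_explanations`, and `check_multiclass` are illustrative names, not functions from evalml or its test suite.

```python
def compute_explanations(n_classes):
    """Hypothetical stand-in for the multiclass code path under test."""
    return [{"class": f"class_{i}"} for i in range(n_classes)]


def broken_compute_explanations(n_classes):
    """Deliberately broken variant, used only to prove the test exercises this path."""
    raise RuntimeError("multiclass path intentionally broken")


def check_multiclass(impl):
    """The 'test': True if impl yields one explanation per class, False on any failure."""
    try:
        return len(impl(3)) == 3
    except Exception:
        return False


passes_with_real_impl = check_multiclass(compute_explanations)
passes_with_broken_impl = check_multiclass(broken_compute_explanations)
print(passes_with_real_impl, passes_with_broken_impl)
```

If the second result were also a pass, the test would not be reaching the multiclass path at all, which is the situation the deleted `pytest.skip` used to create.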



baseline_message = "You passed in a baseline pipeline. These are simple enough that SHAP values are not needed."
xg_boost_message = "SHAP values cannot currently be computed for xgboost models."
Contributor Author

I didn't see this message being used anywhere. I think it's left over from a previous change.

# Because of this issue: https://github.com/slundberg/shap/issues/1215
if estimator.model_family == ModelFamily.CATBOOST and is_multiclass(pipeline.problem_type):
    # Will randomly segfault
    raise NotImplementedError("SHAP values cannot currently be computed for catboost models for multiclass problems.")
Contributor Author

I don't think this is true anymore with shap >= 0.36.0!
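
This PR deletes the guard outright. For comparison, if a project instead wanted to keep such a guard for older installs, a version gate might look like the sketch below. The helper names are hypothetical, the `0.36.0` threshold is taken only from the comment above, and real code would more likely rely on `packaging.version` than on hand-rolled parsing.

```python
def version_tuple(version: str) -> tuple:
    """Parse the leading numeric components of a dotted version string into a 3-tuple."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad so that "0.36" compares equal to "0.36.0".
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


MIN_SHAP = (0, 36, 0)  # assumption: threshold quoted in the review comment


def multiclass_catboost_supported(installed_version: str) -> bool:
    """Hypothetical gate: True when the installed shap version meets the threshold."""
    return version_tuple(installed_version) >= MIN_SHAP


for candidate in ("0.35.0", "0.36", "0.36.0", "0.39.0"):
    print(candidate, multiclass_catboost_supported(candidate))
```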

features = check_array(features.values)

if estimator.model_family.is_tree_estimator():
    # Because of this issue: https://github.com/slundberg/shap/issues/1215
Contributor Author

I think this was the wrong issue. This is a discussion about xgboost.

@chukarsten chukarsten left a comment

Looks good!

@freddyaboulton freddyaboulton left a comment

Looks good, @dsherry! Can you update the model understanding docs? There's a sentence saying we don't support prediction explanations for catboost multiclass.

@dsherry dsherry force-pushed the ds_2136_shap_catboost_multiclass branch from 5294e17 to 3c35a5e Compare May 4, 2021 14:45
@bchen1116 bchen1116 left a comment

Nice!

@dsherry dsherry merged commit 614ff06 into main May 4, 2021
@dsherry dsherry deleted the ds_2136_shap_catboost_multiclass branch May 4, 2021 16:08
Development

Successfully merging this pull request may close these issues.

Prediction explanations: support catboost

4 participants