953 implement table ui for shap #974
Conversation
Codecov Report
@@ Coverage Diff @@
## main #974 +/- ##
==========================================
+ Coverage 99.67% 99.87% +0.20%
==========================================
Files 174 178 +4
Lines 9043 9163 +120
==========================================
+ Hits 9014 9152 +138
+ Misses 29 11 -18
Continue to review full report at Codecov.
…elines/explanations.
@@ -11,3 +11,4 @@ click>=7.0.0
psutil>=5.6.3
requirements-parser>=0.2.0
shap>=0.35.0
texttable>=1.6.2
The wheel is 10kb and it's implemented in base python so I don't think it will cause any dependency issues down the line.
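For reference, a minimal sketch of rendering a plain-text table with texttable; the column names are taken from the table output shown later in this thread, and the exact setup in the PR may differ.

```python
# Minimal texttable sketch -- illustrative only, not the PR's implementation.
from texttable import Texttable

table = Texttable()
table.set_deco(Texttable.HEADER)       # only draw a rule under the header
table.set_cols_align(["l", "c", "r"])  # left, center, right alignment per column
table.header(["Feature Name", "Contribution to Prediction", "SHAP Value"])
table.add_row(["a", "+", 0.680])
table.add_row(["f", "--", -2.680])
print(table.draw())
```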
    return _make_single_prediction_table(shap_values, normalized_shap_values, top_k, include_shap_values)


def _explain_prediction(pipeline, features, training_data=None):
We would make this function public once we go live. I don't think we'd make _explain_with_shap_values public, but if someone ever needed to access the SHAP values, we can point them to it.
@freddyaboulton I have a suggestion on the organization (rough sketch below):
- Either delete this method, or name it print_prediction_explanation. Personally I think it's okay to delete, since users can call print whenever they want.
- Rename _explain_with_shap_values to _explain_prediction.
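A rough sketch of what that layout could look like. The signature and the _normalize_values helper name are assumptions based on this thread, not the actual code.

```python
# Sketch of the suggested re-organization; names and signature are assumed
# from this discussion rather than taken from the PR.
def _explain_prediction(pipeline, features, top_k=3, include_shap_values=False, training_data=None):
    """Return a formatted table (as a string) explaining one prediction with SHAP."""
    shap_values = _compute_shap_values(pipeline, features, training_data)
    normalized_shap_values = _normalize_values(shap_values)  # hypothetical helper name
    return _make_single_prediction_table(shap_values, normalized_shap_values, top_k, include_shap_values)

# Users who want the table printed can simply call:
# print(_explain_prediction(pipeline, features))
```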
Yeah, deleting the one which calls print will make the testing easier too.
+1 to this!
Done!
@@ -52,7 +52,9 @@ def _compute_shap_values(pipeline, features, training_data=None):

    # This is to make sure all dtypes are numeric - SHAP algorithms will complain otherwise.
    # Sklearn components do this under-the-hood so we're not changing the data the model was trained on.
    pipeline_features = check_array(pipeline_features.values)
    # Catboost can naturally handle string-enconded categorical features so we don't need to convert to numeric.
Ah, makes sense.
Typo: string-enconded --> string-encoded
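To spell out the branching being discussed, a sketch of the dtype handling. The CatBoost check shown here is an assumption for illustration, not necessarily the PR's exact condition.

```python
# Sketch of the dtype handling discussed above; the CatBoost detection is illustrative.
from sklearn.utils import check_array

if "CatBoost" not in type(pipeline.estimator).__name__:
    # SHAP algorithms expect numeric dtypes; check_array coerces to a numeric ndarray,
    # mirroring what sklearn components do under the hood.
    pipeline_features = check_array(pipeline_features.values)
# else: CatBoost handles string-encoded categorical features natively, so leave them as-is.
```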
Great stuff! 👏 The table looks really nice! I like that the logic to compute it is pretty easy to follow. I'm excited to try this out.
Nothing blocking per se, but two thoughts which would be good to resolve:
- I do think it would be a good idea to resolve the discussion about method names in _explainers.py. I suggested deleting the method which just calls print and renaming the other one to _explain_prediction.
- I think your test coverage of _make_table is solid, but it would be great to add coverage of what's currently called _explain_with_shap_values, assuming we delete _explain_prediction. Could be as simple as mocking _compute_shap_values and asserting the returned string matches on one example (sketch below).
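Something along these lines, for example. The patch target path, import path, mocked return value, and assertions are placeholders and would need to match the real module and table format.

```python
# Placeholder test sketch for the suggestion above; adjust paths and expected text.
import pandas as pd
from unittest.mock import patch

from evalml.pipelines.prediction_explanations._explainers import _explain_prediction  # placeholder path


@patch("evalml.pipelines.prediction_explanations._explainers._compute_shap_values")  # placeholder path
def test_explain_prediction(mock_compute_shap_values):
    mock_compute_shap_values.return_value = {"a": [0.68], "b": [-2.68]}
    features = pd.DataFrame({"a": [1.0], "b": [2.0]})
    table = _explain_prediction(pipeline=None, features=features)
    assert "Feature Name" in table
    assert "Contribution to Prediction" in table
```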
There are a lot of standalone methods at this point for prediction explanation. Makes me wonder if a class would help with encapsulation. But nothing jumped out at me... fundamentally, having everything wrapped up in explain_prediction means users don't need to worry about the rest.
You're gonna add docs and example usage in a separate PR, right?
        None: displays a table to std out
    """
    if not (isinstance(features, pd.DataFrame) and features.shape[0] == 1):
        raise ValueError("features must be stored in a dataframe of one row.")
Perhaps if we delete this method, this validation can be moved into the other method
I'm on-board with the re-org you suggested. The only reason I did it this way was to not directly expose the ability to display SHAP values as a column but I don't think that's necessary anymore.
        features (pd.DataFrame): Dataframe of features - needs to correspond to data the pipeline was fit on.
        top_k (int): How many of the highest/lowest features to include in the table.
        training_data (pd.DataFrame): Training data the pipeline was fit on.
            For non-tree estimators, we need a sample of training data for the KernelSHAP algorithm.
Nice. Suggest "Training data the pipeline was fit on. This is required for non-tree estimators because we need a sample of training data for the KernelSHAP algorithm."
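For background on why the training sample matters, here is a small runnable sketch with SHAP's KernelExplainer. It uses sklearn directly rather than an evalml pipeline, purely for illustration.

```python
# Why non-tree estimators need training data: KernelSHAP integrates features out
# against a background sample drawn from the training set.
import shap
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000).fit(X, y)

background = shap.sample(X, 50)                                 # sample of training data
explainer = shap.KernelExplainer(model.predict_proba, background)
shap_values = explainer.shap_values(X[:1])                      # explain one row
```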

    Arguments:
        pipeline (PipelineBase): Fitted pipeline whose predictions we want to explain with SHAP.
        features (pd.DataFrame): Dataframe of features - needs to correspond to data the pipeline was fit on.
This is one row of input data, right? Why not expect pd.Series? Perhaps we can say "input features" to make it clear we're talking about the raw input.
>>> data = pd.read_csv("/Users/dylan.sherry/Downloads/dataset_61_iris.csv")
>>> row0 = data.loc[0]
>>> print(row0)
>>> print()
>>> print(type(row0))
>>> print()
>>> print(row0.index)
>>> print()
>>> print(row0.to_numpy())
outputs
sepallength 5.1
sepalwidth 3.5
petallength 1.4
petalwidth 0.2
class Iris-setosa
Name: 0, dtype: object
<class 'pandas.core.series.Series'>
Index(['sepallength', 'sepalwidth', 'petallength', 'petalwidth', 'class'], dtype='object')
[5.1 3.5 1.4 0.2 'Iris-setosa']
The problem is the call to pipeline._transform. It can't handle pd.Series objects, so we would have to convert to a dataframe internally. The problem with doing the conversion internally is that the dtype information is lost when you do df.iloc[index] (see the example below). This would cause some of the components to freak out (like OneHotEncoder) because there are more object dtype columns than in the data it was fit on.
We would have to require the user to always pass in the dtypes and/or the training data? This would only impact explaining tree-based models, which are the only estimators for which SHAP doesn't require training data. It's doable but I'm curious what your thoughts are.
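To illustrate the dtype point with a small standalone example (not from the PR):

```python
# dtype information is lost when a single row is pulled out of a mixed-dtype frame.
import pandas as pd

df = pd.DataFrame({"num": [1.5, 2.5], "cat": ["a", "b"]})
print(df.dtypes)              # num: float64, cat: object

row = df.iloc[0]              # pd.Series with a single dtype
print(row.dtype)              # object

# Converting the row back to a one-row frame keeps "object" everywhere, so a
# component like OneHotEncoder would see more object columns than at fit time.
print(row.to_frame().T.dtypes)
```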
Got it @freddyaboulton. So then a pandas DataFrame is required in this method?
    rows = []
    for value, feature_name in features_to_display:
        symbol = "+" if value >= 0 else "-"
        display_text = symbol * min(int(abs(value) // 0.2) + 1, 5)
Nice!! Breaking out the integer division operator, love it 😁
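For anyone reading along, the line above maps each normalized SHAP value to one to five symbols, one per 0.2 of magnitude, capped at five. For example:

```python
# Same formula as the diff above: one "+" or "-" per 0.2 of the normalized value,
# capped at five symbols.
for value in (0.05, -0.30, 0.55, -1.20):
    symbol = "+" if value >= 0 else "-"
    display_text = symbol * min(int(abs(value) // 0.2) + 1, 5)
    print(f"{value:>6}: {display_text}")
# 0.05 -> "+", -0.3 -> "--", 0.55 -> "+++", -1.2 -> "-----"
```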
test_cases = [5, [1], np.ones((1, 15)), pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).iloc[0]]


@pytest.mark.parametrize("test_case", test_cases)
Nit-pick: one's singular and one's plural
    'c + 0.000', 'd -- -2.560', 'e -- -2.800', 'f -- -2.900', '', '', 'Class 2', '',
    'Feature Name Contribution to Prediction SHAP Value',
    '======================================================', 'a + 0.680', 'c + 0.000',
    'b + 0.000', 'd -- -1.840', 'e -- -2.040', 'f -- -2.680', '']
Idk if this makes things easier but you could do
multiclass_table_shap = """Class 0
Feature Name Contribution to Prediction SHAP Value
======================================================
...
""".split("\n")
Much more legible!

    # Making sure the content is the same, regardless of formatting.
    for row_table, row_answer in zip(table.splitlines(), answer):
        assert row_table.strip().split() == row_answer.strip().split()
This is great!
    table = _make_table(dtypes, alignment, values, values, top_k, include_shap_values).splitlines()
    if include_shap_values:
        assert "SHAP Value" in table[0]
    # Subtracting two because a header and a line under the header are included in the table.
👍
This looks really good! 👍
@dsherry Thanks for the feedback! Yes, I plan on adding a tutorial and updating the docs in the next PR once this is merged. I think using classes might make more sense when we have multiple interpretation algorithms, but I think the code layout here is straightforward and there isn't a lot of "state" we need to keep track of yet. I renamed _explain_with_shap_values to _explain_prediction as suggested.
Pull Request Description
Closes #953.
Demo of what is displayed to user
To see more examples, see https://github.com/FeatureLabs/shap-reports/pull/1 (but those results include the SHAP value).
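(Rough idea of the intended usage, based on the discussion above; the import path and public name are assumptions, since the function is still private in this PR.)

```python
# Hypothetical usage sketch -- import path, public name, and variables are assumptions.
from evalml.pipelines.prediction_explanations import explain_prediction

table = explain_prediction(pipeline, features=one_row_df, top_k=3, training_data=X_train)
print(table)
```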
After creating the pull request: in order to pass the release_notes_updated check, you will need to update the "Future Release" section of docs/source/release_notes.rst to include this pull request by adding :pr:`123`.