Aggregate prediction explanations for derived features #1901
Conversation
Codecov Report
@@ Coverage Diff @@
## main #1901 +/- ##
=========================================
+ Coverage 100.0% 100.0% +0.1%
=========================================
Files 267 267
Lines 21536 21715 +179
=========================================
+ Hits 21530 21709 +179
Misses 6 6
Continue to review full report at Codecov.
Wow, that's a great addition. Nice job.
provenance (dict): A mapping from a feature in the original data to the names of the features that were created
    from that feature
Returns:
    dict
Love this docstring. Very clear. I think the return just needs a description.
Done!
Arguments:
    values (dict): A mapping of feature names to a list of SHAP values for each data point.
    provenance (dict): A mapping from a feature in the original data to the names of the features that were created
        from that feature
supernit: period?
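The aggregation the docstring describes can be sketched roughly as follows. This is an illustrative helper, not EvalML's actual implementation; the function name and the element-wise summing of SHAP values are assumptions based on the docstring above.

```python
def aggregate_shap_by_provenance(values, provenance):
    """Sum the SHAP values of derived features back onto their original feature.

    Arguments:
        values (dict): A mapping of feature names to a list of SHAP values for
            each data point.
        provenance (dict): A mapping from a feature in the original data to the
            names of the features that were created from that feature.

    Returns:
        dict: Mapping of original feature names to element-wise summed SHAP
            values. Features with unknown provenance pass through unchanged.
    """
    # Invert the provenance map: derived feature -> original feature.
    parent_of = {child: parent
                 for parent, children in provenance.items()
                 for child in children}
    aggregated = {}
    for feature, shap_values in values.items():
        parent = parent_of.get(feature, feature)  # unknown provenance: keep as-is
        current = aggregated.get(parent, [0.0] * len(shap_values))
        aggregated[parent] = [a + b for a, b in zip(current, shap_values)]
    return aggregated
```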
json_rows = _rows_to_dict(rows)
drill_down = self.make_drill_down_dict(self.provenance, shap_values[1], normalized_values[1],
                                       pipeline_features, original_features, self.include_shap_values)
json_rows["drill_down"] = drill_down
Is "json_rows" a carryover copy pasta? It's kinda weird in make_dict()
Changed the name to dict_rows!
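For context, a dict with a drill-down entry attached might look something like this. The exact keys and nesting are assumptions sketched from the snippet above, not EvalML's actual output format.

```python
# Hypothetical shape of the dict returned by make_dict once the drill-down
# entry is attached; the keys under each original feature are assumptions.
dict_rows = {
    "feature_names": ["text", "age"],
    "qualitative_explanation": ["++", "-"],
}
dict_rows["drill_down"] = {
    # original feature -> explanations for the features derived from it
    "text": {
        "feature_names": ["LSA(text)[0]", "LSA(text)[1]"],
        "qualitative_explanation": ["+", "++"],
    },
}
```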
Force-pushed from edf28c1 to 0bebd6a
LGTM! I left a few comments on docstrings and some nitpicks, but nothing blocking.
@pytest.fixture
def fraud_100():
nice! haha
table_maker = table_maker.make_text if output_format == "text" else table_maker.make_dict

table = table_maker(values, normalized_values, pipeline_features, top_k=3, include_shap_values=include_shap)
table = table_maker(values, normalized_values, values, normalized_values, pipeline_features, pipeline_features)
nit-pick: I was really confused when I saw these input params repeated. Any chance you can add the keys, like:
table = table_maker(aggregated_shap_values=values,
                    aggregated_normalized_values=normalized_values,
                    shap_values=values,
                    normalized_values=normalized_values,
                    pipeline_features=pipeline_features,
                    original_features=pipeline_features)
just to make it a little clearer?
@abc.abstractmethod
def make_text(self, shap_values, normalized_values, pipeline_features, top_k, include_shap_values=False):
def make_text(self, aggregated_shap_values, aggregated_normalized_values,
              shap_values, normalized_values, pipeline_features, orignal_features):
typo: original_features
json_output_for_class["class_name"] = _make_json_serializable(class_name)
json_output.append(json_output_for_class)
return {"explanations": json_output}

def _make_single_prediction_shap_table(pipeline, pipeline_features, index_to_explain, top_k=3,
def _make_single_prediction_shap_table(pipeline, pipeline_features, input_features, index_to_explain, top_k=3,
Should we update this docstring to include input_features and index_to_explain?
Yes!
@@ -395,7 +473,7 @@ def __init__(self, top_k_features, include_shap_values):
    self.top_k_features = top_k_features
    self.include_shap_values = include_shap_values

    def make_text(self, index, pipeline, pipeline_features):
    def make_text(self, index, pipeline, pipeline_features, input_features):
update docstring
Force-pushed from f681308 to e497258
Force-pushed from e497258 to f8ffba7
Pull Request Description
Fixes #1347
We will only aggregate the values for the features whose provenance we know; otherwise, no aggregation happens, which is basically the current behavior.
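A tiny sketch of that fallback behavior (names are hypothetical, not EvalML's actual code): with no provenance information, aggregation is the identity, matching the pre-PR behavior.

```python
def aggregate_or_passthrough(values, provenance):
    """Aggregate derived-feature SHAP values onto their parent feature when
    provenance is known; leave everything else untouched (the prior behavior)."""
    # derived feature -> original feature
    parent_of = {child: parent
                 for parent, children in provenance.items()
                 for child in children}
    out = {}
    for name, vals in values.items():
        key = parent_of.get(name, name)  # no known parent: keep the feature as-is
        out[key] = [a + b for a, b in zip(out.get(key, [0.0] * len(vals)), vals)]
    return out
```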
Sample output on titanic dataset
Example of drill_down dict on titanic dataset
Sample output of explain_predictions_best_worst on fraud dataset
After creating the pull request: in order to pass the release_notes_updated check you will need to update the "Future Release" section of docs/source/release_notes.rst to include this pull request by adding :pr:123.