Displaying the top_k features with the largest shap value magnitudes (#1374)
* Displaying the top_k features with the largest shap value magnitudes

* Adding PR 1360 to release notes.

* Updating broken tests in test_explainers to reflect top_k rows by magnitude.

* Updating comment.

* Removing linebreak in release notes for pr 1374.

* Sorting shap values in _make_rows via lambda function.

* Not changing notebook python version in model understanding user guide.
freddyaboulton committed Nov 6, 2020
1 parent 4bf2f9b commit 6a80b40
Showing 5 changed files with 114 additions and 144 deletions.
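
In short: before this commit, the explanation tables showed the top_k most positive and the top_k most negative SHAP values (up to 2 * top_k rows); after it, they show the top_k rows with the largest SHAP value magnitudes. Below is a minimal sketch of the old and new selection rules; the feature names and SHAP values are made up for illustration and are not taken from the repository.

# Illustrative sketch only; feature names and SHAP values are made up.
shap_values = {"a": 1.0, "b": -2.1, "c": -0.25, "d": 2.3}
top_k = 2

pairs = [(value, name) for name, value in shap_values.items()]

# Old behavior: the top_k most positive plus the top_k most negative values,
# i.e. up to 2 * top_k rows.
by_value = sorted(pairs)
old_selection = by_value[-top_k:][::-1] + by_value[:top_k][::-1]

# New behavior: only the top_k rows with the largest SHAP value magnitudes.
new_selection = sorted(pairs, key=lambda p: abs(p[0]), reverse=True)[:top_k]

print(old_selection)  # [(2.3, 'd'), (1.0, 'a'), (-0.25, 'c'), (-2.1, 'b')]
print(new_selection)  # [(2.3, 'd'), (-2.1, 'b')]

Because the functions now return exactly k rows instead of 2 * k, the notebook changes below pass top_k=6 / top_k_features=6 explicitly, presumably to keep the user guide's example tables at six features.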
7 changes: 7 additions & 0 deletions docs/source/release_notes.rst
@@ -5,6 +5,7 @@ Release Notes
* Enhancements
* Added ability to freeze hyperparameters for AutoMLSearch :pr:`1284`
* Added the index id to the ``explain_predictions_best_worst`` output to help users identify which rows in their data are included :pr:`1365`
* The top_k features displayed in ``explain_predictions_*`` functions are now determined by the magnitude of shap values as opposed to the ``top_k`` largest and smallest shap values. :pr:`1374`
* Fixes
* Updated enum classes to show possible enum values as attributes :pr:`1391`
* Changes
@@ -17,6 +18,12 @@ Release Notes
* Removed ``category_encoders`` from test-requirements.txt :pr:`1373`
* Tweak codecov.io settings again to avoid flakes :pr:`1413`

.. warning::

**Breaking Changes**
* The ``top_k`` and ``top_k_features`` parameters in ``explain_predictions_*`` functions now return ``k`` features as opposed to ``2 * k`` features :pr:`1374`


**v0.15.0 Oct. 29, 2020**
* Enhancements
* Added stacked ensemble component classes (``StackedEnsembleClassifier``, ``StackedEnsembleRegressor``) :pr:`1134`
7 changes: 4 additions & 3 deletions docs/source/user_guide/model_understanding.ipynb
@@ -330,7 +330,7 @@
"from evalml.model_understanding.prediction_explanations import explain_prediction\n",
"\n",
"table = explain_prediction(pipeline=pipeline, input_features=X.iloc[3:4],\n",
" training_data=X, include_shap_values=True)\n",
" training_data=X, top_k=6, include_shap_values=True)\n",
"print(table)"
]
},
@@ -367,7 +367,7 @@
"from evalml.model_understanding.prediction_explanations import explain_predictions_best_worst\n",
"\n",
"report = explain_predictions_best_worst(pipeline=pipeline, input_features=X, y_true=y,\n",
" include_shap_values=True, num_to_explain=2)\n",
" include_shap_values=True, top_k_features=6, num_to_explain=2)\n",
"\n",
"print(report)"
]
@@ -431,7 +431,8 @@
"source": [
"import json\n",
"report = explain_predictions_best_worst(pipeline=pipeline, input_features=X, y_true=y,\n",
" num_to_explain=1, include_shap_values=True, output_format=\"dict\")\n",
" num_to_explain=1, top_k_features=6,\n",
" include_shap_values=True, output_format=\"dict\")\n",
"print(json.dumps(report, indent=2))"
]
}
@@ -26,12 +26,13 @@ def _make_rows(shap_values, normalized_values, pipeline_features, top_k, include
list(str)
"""
tuples = [(value[0], feature_name) for feature_name, value in normalized_values.items()]
tuples = sorted(tuples)

if len(tuples) <= 2 * top_k:
features_to_display = reversed(tuples)
else:
features_to_display = tuples[-top_k:][::-1] + tuples[:top_k][::-1]
# Sort the features so that the top_k with the largest shap value magnitudes
# are the first top_k elements
tuples = sorted(tuples, key=lambda x: abs(x[0]), reverse=True)

# Then sort such that the SHAP values go from most positive to most negative
features_to_display = reversed(sorted(tuples[:top_k]))

rows = []
for value, feature_name in features_to_display:
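
As a standalone illustration of the new ordering (not code from the repository; the values are made up): features are first filtered to the top_k largest SHAP value magnitudes, then displayed from the most positive to the most negative value.

# Standalone sketch of the new _make_rows ordering; values are illustrative.
normalized_values = {"a": [0.5], "b": [-2.1], "c": [-0.25], "d": [2.3], "e": [0.9]}
top_k = 3

tuples = [(value[0], feature_name) for feature_name, value in normalized_values.items()]

# Keep the top_k features with the largest SHAP value magnitudes ...
tuples = sorted(tuples, key=lambda x: abs(x[0]), reverse=True)

# ... then display them from most positive to most negative SHAP value.
features_to_display = list(reversed(sorted(tuples[:top_k])))

print(features_to_display)  # [(2.3, 'd'), (0.9, 'e'), (-2.1, 'b')]

This positive-to-negative display order is what the updated test answers below (e.g. d +++++ followed by b -----) expect.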
@@ -34,27 +34,25 @@ def test_explain_prediction_value_error(test_features):

explain_prediction_answer = """Feature Name Feature Value Contribution to Prediction
=========================================================
d 40.00 ++++
a 10.00 +++
c 30.00 --
b 20.00 ----""".splitlines()
d 40.00 +++++
b 20.00 -----""".splitlines()

explain_prediction_regression_dict_answer = {
"explanations": [{
"feature_names": ["d", "a", "c", "b"],
"feature_values": [40, 10, 30, 20],
"qualitative_explanation": ["++++", "+++", "--", "----"],
"quantitative_explanation": [None, None, None, None],
"feature_names": ["d", "b"],
"feature_values": [40, 20],
"qualitative_explanation": ["+++++", "-----"],
"quantitative_explanation": [None, None],
"class_name": None
}]
}

explain_prediction_binary_dict_answer = {
"explanations": [{
"feature_names": ["d", "a", "c", "b"],
"feature_values": [40, 10, 30, 20],
"qualitative_explanation": ["++++", "+++", "--", "----"],
"quantitative_explanation": [None, None, None, None],
"feature_names": ["d", "b"],
"feature_values": [40, 20],
"qualitative_explanation": ["+++++", "-----"],
"quantitative_explanation": [None, None],
"class_name": "class_1"
}]
}
@@ -63,10 +61,8 @@ def test_explain_prediction_value_error(test_features):
Feature Name Feature Value Contribution to Prediction
=========================================================
a 10.00 +
b 20.00 +
c 30.00 -
d 40.00 -
a 10.00 +++++
c 30.00 ---
Class: class_1
@@ -75,36 +71,32 @@
=========================================================
a 10.00 +++
b 20.00 ++
c 30.00 -
d 40.00 --
Class: class_2
Feature Name Feature Value Contribution to Prediction
=========================================================
a 10.00 +
b 20.00 +
c 30.00 ---
d 40.00 ---
""".splitlines()

explain_prediction_multiclass_dict_answer = {
"explanations": [
{"feature_names": ["a", "b", "c", "d"],
"feature_values": [10, 20, 30, 40],
"qualitative_explanation": ["+", "+", "-", "-"],
"quantitative_explanation": [None] * 4,
{"feature_names": ["a", "c"],
"feature_values": [10, 30],
"qualitative_explanation": ["+++++", "---"],
"quantitative_explanation": [None, None],
"class_name": "class_0"},
{"feature_names": ["a", "b", "c", "d"],
"feature_values": [10, 20, 30, 40],
"qualitative_explanation": ["+++", "++", "-", "--"],
"quantitative_explanation": [None] * 4,
{"feature_names": ["a", "b"],
"feature_values": [10, 20],
"qualitative_explanation": ["+++", "++"],
"quantitative_explanation": [None, None],
"class_name": "class_1"},
{"feature_names": ["a", "b", "c", "d"],
"feature_values": [10, 20, 30, 40],
"qualitative_explanation": ["+", "+", "---", "---"],
"quantitative_explanation": [None] * 4,
{"feature_names": ["c", "d"],
"feature_values": [30, 40],
"qualitative_explanation": ["---", "---"],
"quantitative_explanation": [None, None],
"class_name": "class_2"},
]
}
@@ -113,36 +105,36 @@
@pytest.mark.parametrize("problem_type,output_format,shap_values,normalized_shap_values,answer",
[(ProblemTypes.REGRESSION,
"text",
{"a": [1], "b": [-2], "c": [-0.25], "d": [2]},
{"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]},
{"a": [1], "b": [-2.1], "c": [-0.25], "d": [2.3]},
{"a": [0.5], "b": [-2.1], "c": [-0.25], "d": [2.3]},
explain_prediction_answer),
(ProblemTypes.REGRESSION,
"dict",
{"a": [1], "b": [-2], "c": [-0.25], "d": [2]},
{"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]},
{"a": [1], "b": [-2.1], "c": [-0.25], "d": [2.3]},
{"a": [0.5], "b": [-2.1], "c": [-0.25], "d": [2.3]},
explain_prediction_regression_dict_answer
),
(ProblemTypes.BINARY,
"text",
[{}, {"a": [1], "b": [-2], "c": [-0.25], "d": [2]}],
[{}, {"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [0.33], "d": [0.89]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [-0.25], "d": [0.89]}],
explain_prediction_answer),
(ProblemTypes.BINARY,
"dict",
[{}, {"a": [1], "b": [-2], "c": [-0.25], "d": [2]}],
[{}, {"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [0.33], "d": [0.89]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [-0.25], "d": [0.89]}],
explain_prediction_binary_dict_answer),
(ProblemTypes.MULTICLASS,
"text",
[{}, {}, {}],
[{"a": [0.1], "b": [0.09], "c": [-0.04], "d": [-0.06]},
[{"a": [1.1], "b": [0.09], "c": [-0.53], "d": [-0.06]},
{"a": [0.53], "b": [0.24], "c": [-0.15], "d": [-0.22]},
{"a": [0.03], "b": [0.02], "c": [-0.42], "d": [-0.47]}],
explain_prediction_multiclass_answer),
(ProblemTypes.MULTICLASS,
"dict",
[{}, {}, {}],
[{"a": [0.1], "b": [0.09], "c": [-0.04], "d": [-0.06]},
[{"a": [1.1], "b": [0.09], "c": [-0.53], "d": [-0.06]},
{"a": [0.53], "b": [0.24], "c": [-0.15], "d": [-0.22]},
{"a": [0.03], "b": [0.02], "c": [-0.42], "d": [-0.47]}],
explain_prediction_multiclass_dict_answer)