Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Displaying the top_k features with largest shap value magnitudes #1374

Merged
merged 7 commits into from
Nov 6, 2020
7 changes: 7 additions & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Release Notes
* Enhancements
* Added ability to freeze hyperparameters for AutoMLSearch :pr:`1284`
* Added the index id to the ``explain_predictions_best_worst`` output to help users identify which rows in their data are included :pr:`1365`
* The ``top_k`` features displayed in ``explain_predictions_*`` functions are now determined by the magnitude of the shap values, as opposed to selecting the ``top_k`` largest and the ``top_k`` smallest shap values separately. :pr:`1374`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

* Fixes
* Updated enum classes to show possible enum values as attributes :pr:`1391`
* Changes
Expand All @@ -17,6 +18,12 @@ Release Notes
* Removed ``category_encoders`` from test-requirements.txt :pr:`1373`
* Tweak codecov.io settings again to avoid flakes :pr:`1413`

.. warning::

**Breaking Changes**
* The ``top_k`` and ``top_k_features`` parameters in ``explain_predictions_*`` functions now return ``k`` features as opposed to ``2 * k`` features :pr:`1374`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍



**v0.15.0 Oct. 29, 2020**
* Enhancements
* Added stacked ensemble component classes (``StackedEnsembleClassifier``, ``StackedEnsembleRegressor``) :pr:`1134`
Expand Down
7 changes: 4 additions & 3 deletions docs/source/user_guide/model_understanding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@
"from evalml.model_understanding.prediction_explanations import explain_prediction\n",
"\n",
"table = explain_prediction(pipeline=pipeline, input_features=X.iloc[3:4],\n",
" training_data=X, include_shap_values=True)\n",
" training_data=X, top_k=6, include_shap_values=True)\n",
"print(table)"
]
},
Expand Down Expand Up @@ -367,7 +367,7 @@
"from evalml.model_understanding.prediction_explanations import explain_predictions_best_worst\n",
"\n",
"report = explain_predictions_best_worst(pipeline=pipeline, input_features=X, y_true=y,\n",
" include_shap_values=True, num_to_explain=2)\n",
" include_shap_values=True, top_k_features=6, num_to_explain=2)\n",
"\n",
"print(report)"
]
Expand Down Expand Up @@ -431,7 +431,8 @@
"source": [
"import json\n",
"report = explain_predictions_best_worst(pipeline=pipeline, input_features=X, y_true=y,\n",
" num_to_explain=1, include_shap_values=True, output_format=\"dict\")\n",
" num_to_explain=1, top_k_features=6,\n",
" include_shap_values=True, output_format=\"dict\")\n",
"print(json.dumps(report, indent=2))"
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ def _make_rows(shap_values, normalized_values, pipeline_features, top_k, include
list(str)
"""
tuples = [(value[0], feature_name) for feature_name, value in normalized_values.items()]
tuples = sorted(tuples)

if len(tuples) <= 2 * top_k:
features_to_display = reversed(tuples)
else:
features_to_display = tuples[-top_k:][::-1] + tuples[:top_k][::-1]
# Sort the features so that the top_k features with the largest shap value
# magnitudes are the first top_k elements
tuples = sorted(tuples, key=lambda x: abs(x[0]), reverse=True)

# Then sort such that the SHAP values go from most positive to most negative
features_to_display = reversed(sorted(tuples[:top_k]))

rows = []
for value, feature_name in features_to_display:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,25 @@ def test_explain_prediction_value_error(test_features):

explain_prediction_answer = """Feature Name Feature Value Contribution to Prediction
=========================================================
d 40.00 ++++
a 10.00 +++
c 30.00 --
b 20.00 ----""".splitlines()
d 40.00 +++++
b 20.00 -----""".splitlines()

explain_prediction_regression_dict_answer = {
"explanations": [{
"feature_names": ["d", "a", "c", "b"],
"feature_values": [40, 10, 30, 20],
"qualitative_explanation": ["++++", "+++", "--", "----"],
"quantitative_explanation": [None, None, None, None],
"feature_names": ["d", "b"],
"feature_values": [40, 20],
"qualitative_explanation": ["+++++", "-----"],
"quantitative_explanation": [None, None],
"class_name": None
}]
}

explain_prediction_binary_dict_answer = {
"explanations": [{
"feature_names": ["d", "a", "c", "b"],
"feature_values": [40, 10, 30, 20],
"qualitative_explanation": ["++++", "+++", "--", "----"],
"quantitative_explanation": [None, None, None, None],
"feature_names": ["d", "b"],
"feature_values": [40, 20],
"qualitative_explanation": ["+++++", "-----"],
"quantitative_explanation": [None, None],
"class_name": "class_1"
}]
}
Expand All @@ -63,10 +61,8 @@ def test_explain_prediction_value_error(test_features):

Feature Name Feature Value Contribution to Prediction
=========================================================
a 10.00 +
b 20.00 +
c 30.00 -
d 40.00 -
a 10.00 +++++
c 30.00 ---


Class: class_1
Expand All @@ -75,36 +71,32 @@ def test_explain_prediction_value_error(test_features):
=========================================================
a 10.00 +++
b 20.00 ++
c 30.00 -
d 40.00 --


Class: class_2

Feature Name Feature Value Contribution to Prediction
=========================================================
a 10.00 +
b 20.00 +
c 30.00 ---
d 40.00 ---
""".splitlines()

explain_prediction_multiclass_dict_answer = {
"explanations": [
{"feature_names": ["a", "b", "c", "d"],
"feature_values": [10, 20, 30, 40],
"qualitative_explanation": ["+", "+", "-", "-"],
"quantitative_explanation": [None] * 4,
{"feature_names": ["a", "c"],
"feature_values": [10, 30],
"qualitative_explanation": ["+++++", "---"],
"quantitative_explanation": [None, None],
"class_name": "class_0"},
{"feature_names": ["a", "b", "c", "d"],
"feature_values": [10, 20, 30, 40],
"qualitative_explanation": ["+++", "++", "-", "--"],
"quantitative_explanation": [None] * 4,
{"feature_names": ["a", "b"],
"feature_values": [10, 20],
"qualitative_explanation": ["+++", "++"],
"quantitative_explanation": [None, None],
"class_name": "class_1"},
{"feature_names": ["a", "b", "c", "d"],
"feature_values": [10, 20, 30, 40],
"qualitative_explanation": ["+", "+", "---", "---"],
"quantitative_explanation": [None] * 4,
{"feature_names": ["c", "d"],
"feature_values": [30, 40],
"qualitative_explanation": ["---", "---"],
"quantitative_explanation": [None, None],
"class_name": "class_2"},
]
}
Expand All @@ -113,36 +105,36 @@ def test_explain_prediction_value_error(test_features):
@pytest.mark.parametrize("problem_type,output_format,shap_values,normalized_shap_values,answer",
[(ProblemTypes.REGRESSION,
"text",
{"a": [1], "b": [-2], "c": [-0.25], "d": [2]},
{"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]},
{"a": [1], "b": [-2.1], "c": [-0.25], "d": [2.3]},
{"a": [0.5], "b": [-2.1], "c": [-0.25], "d": [2.3]},
explain_prediction_answer),
(ProblemTypes.REGRESSION,
"dict",
{"a": [1], "b": [-2], "c": [-0.25], "d": [2]},
{"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]},
{"a": [1], "b": [-2.1], "c": [-0.25], "d": [2.3]},
{"a": [0.5], "b": [-2.1], "c": [-0.25], "d": [2.3]},
explain_prediction_regression_dict_answer
),
(ProblemTypes.BINARY,
"text",
[{}, {"a": [1], "b": [-2], "c": [-0.25], "d": [2]}],
[{}, {"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [0.33], "d": [0.89]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [-0.25], "d": [0.89]}],
explain_prediction_answer),
(ProblemTypes.BINARY,
"dict",
[{}, {"a": [1], "b": [-2], "c": [-0.25], "d": [2]}],
[{}, {"a": [0.5], "b": [-0.75], "c": [-0.25], "d": [0.75]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [0.33], "d": [0.89]}],
[{}, {"a": [0.5], "b": [-0.89], "c": [-0.25], "d": [0.89]}],
explain_prediction_binary_dict_answer),
(ProblemTypes.MULTICLASS,
"text",
[{}, {}, {}],
[{"a": [0.1], "b": [0.09], "c": [-0.04], "d": [-0.06]},
[{"a": [1.1], "b": [0.09], "c": [-0.53], "d": [-0.06]},
{"a": [0.53], "b": [0.24], "c": [-0.15], "d": [-0.22]},
{"a": [0.03], "b": [0.02], "c": [-0.42], "d": [-0.47]}],
explain_prediction_multiclass_answer),
(ProblemTypes.MULTICLASS,
"dict",
[{}, {}, {}],
[{"a": [0.1], "b": [0.09], "c": [-0.04], "d": [-0.06]},
[{"a": [1.1], "b": [0.09], "c": [-0.53], "d": [-0.06]},
{"a": [0.53], "b": [0.24], "c": [-0.15], "d": [-0.22]},
{"a": [0.03], "b": [0.02], "c": [-0.42], "d": [-0.47]}],
explain_prediction_multiclass_dict_answer)
Expand Down
Loading