Skip to content

Commit

Permalink
Fix bug on force plots to show feature values for specific rows in tr…
Browse files Browse the repository at this point in the history
…aining data (#3044)

* Pass the actual row data from the correct index to the shap.force_plot call. Fixes a bug where only the first row's feature values are shown in all row plots.

* Lint edit

* Add release notes for bug #3045

* Forgot pr in release note line

* Lint edit - blank line

* Lint adjustment

* Add tests to check force plot feature values match those of the rows to explain. (#3044)

Co-authored-by: Angela Lin <angela97lin@gmail.com>
Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
  • Loading branch information
3 people committed Dec 1, 2021
1 parent 2598cfa commit 840fc3b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Release Notes
* Added in error message when fit and predict/predict_proba data types are different :pr:`3036`
* Fixed bug where ensembling components could not get converted to JSON format :pr:`3049`
* Fixed bug where components with tuned integer hyperparameters could not get converted to JSON format :pr:`3049`
* Fixed bug where force plots were not displaying correct feature values :pr:`3044`
* Included confusion matrix at the pipeline threshold for ``find_confusion_matrix_per_threshold`` :pr:`3080`
* Fixed bug where One Hot Encoder would error out if a non-categorical feature had a missing value :pr:`3083`
* Fixed bug where features created from categorical columns by ``Delayed Feature Transformer`` would be inferred as categorical :pr:`3083`
Expand Down
15 changes: 4 additions & 11 deletions evalml/model_understanding/force_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,22 @@ def graph_force_plot(pipeline, rows_to_explain, training_data, y, matplotlib=Fal

def gen_force_plot(shap_values, training_data, expected_value, matplotlib):
"""Helper function to generate a single force plot."""
# Ensure there are as many features as shap values.
assert training_data.shape[1] == len(shap_values)

# TODO: Update this to make the force plot display multi-row array visualizer.
training_data_sample = training_data.iloc[0]
shap_plot = shap.force_plot(
expected_value,
np.array(shap_values),
training_data_sample,
matplotlib=matplotlib,
expected_value, np.array(shap_values), training_data, matplotlib=matplotlib
)
return shap_plot

if jupyter_check():
initjs()

shap_plots = force_plot(pipeline, rows_to_explain, training_data, y)
for row in shap_plots:
for ix, row in enumerate(shap_plots):
row_id = rows_to_explain[ix]
for cls in row:
cls_dict = row[cls]
cls_dict["plot"] = gen_force_plot(
cls_dict["shap_values"],
training_data[cls_dict["feature_names"]],
training_data[cls_dict["feature_names"]].iloc[row_id],
cls_dict["expected_value"],
matplotlib=matplotlib,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@
)


def validate_plot_feature_values(results, X, rows_to_explain=None):
    """Check that each force plot shows the feature values of the row it explains.

    Args:
        results (list[dict]): One entry per explained row. Each entry maps a
            label to a dict whose "plot" value exposes a ``data`` mapping
            containing "features" and "featureNames".
        X (pd.DataFrame): Data the explained rows were taken from.
        rows_to_explain (list[int], optional): Positional row index into ``X``
            for each entry of ``results``. Defaults to ``0..len(results) - 1``,
            i.e. the i-th result is checked against the i-th row of ``X``.

    Raises:
        AssertionError: If a plot has no "features" entry, if
            ``rows_to_explain`` and ``results`` differ in length, or if the
            plotted feature values differ from the corresponding row of ``X``.
    """
    if rows_to_explain is None:
        # Backward-compatible default: results line up with the first rows of X.
        rows_to_explain = range(len(results))
    assert len(rows_to_explain) == len(
        results
    ), "rows_to_explain must have one entry per result"

    for row_id, result in zip(rows_to_explain, results):
        for result_label, label_result in result.items():
            # Check feature values in generated force plot correspond to input rows
            plot_data = label_result["plot"].data
            assert "features" in plot_data
            plot_features = plot_data["features"]

            # The plot may carry only a subset of the features (which ones
            # depends on effect size); its keys index into "featureNames".
            feature_names = plot_data["featureNames"]
            effect_features = [feature_names[i] for i in plot_features]
            feature_vals = [plot_features[i]["value"] for i in plot_features]

            # Plain list comparison: a length mismatch fails the assertion
            # cleanly instead of raising an unrelated error.
            expected_vals = X[effect_features].iloc[row_id].tolist()
            assert feature_vals == expected_vals, (
                f"Force plot for row {row_id} ({result_label!r}) shows "
                f"{feature_vals}, expected {expected_vals}"
            )


def test_exceptions():
with pytest.raises(
TypeError,
Expand Down Expand Up @@ -113,6 +132,8 @@ def test_force_plot_binary(
shap.plots._force.AdditiveForceVisualizer,
)

validate_plot_feature_values(results, X)


@pytest.mark.parametrize(
"rows_to_explain, just_data", product([[0], [0, 1, 2, 3, 4]], [False, True])
Expand Down Expand Up @@ -164,6 +185,8 @@ def test_force_plot_multiclass(
shap.plots._force.AdditiveForceVisualizer,
)

validate_plot_feature_values(results, X)


@pytest.mark.parametrize(
"rows_to_explain, just_data", product([[0], [0, 1, 2, 3, 4]], [False, True])
Expand Down Expand Up @@ -211,6 +234,7 @@ def test_force_plot_regression(
assert isinstance(
result["regression"]["plot"], shap.plots._force.AdditiveForceVisualizer
)
validate_plot_feature_values(results, X)


@pytest.mark.parametrize("pipeline_class,estimator", pipeline_test_cases)
Expand Down

0 comments on commit 840fc3b

Please sign in to comment.