Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add waterfall color config for waterfall plot #3377

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion shap/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._scatter import scatter
from ._text import text
from ._violin import violin
from ._waterfall import waterfall
from ._waterfall import WaterfallColorConfig, waterfall

__all__ = [
"bar",
Expand All @@ -37,4 +37,5 @@
"text",
"violin",
"waterfall",
"WaterfallColorConfig"
]
82 changes: 62 additions & 20 deletions shap/plots/_waterfall.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -9,11 +12,29 @@
from ._labels import labels


@dataclass
class WaterfallColorConfig:
"""Color configuration for waterfall plots.
"""
positive_arrow: Union[np.ndarray, str, List[float]] = field(default_factory=lambda: colors.red_rgb)
negative_arrow: Union[np.ndarray, str, List[float]] = field(default_factory=lambda: colors.blue_rgb)
default_positive_color: Union[np.ndarray, str, List[float]] = field(default_factory=lambda: colors.light_red_rgb)
default_negative_color: Union[np.ndarray, str, List[float]] = field(default_factory=lambda: colors.light_blue_rgb)
hlines: Union[np.ndarray, str, List[float]] = "#cccccc"
vlines: Union[np.ndarray, str, List[float]] = "#bbbbbb"
text: Union[np.ndarray, str, List[float]] = "white"
tick_labels: Union[np.ndarray, str, List[float]] = "#999999"


# TODO: If we make a JS version of this plot then we could let users click on a bar and then see the dependence
# plot that is associated with that feature get overlaid on the plot...it would quickly allow users to answer
# why a feature is pushing down or up. Perhaps the best way to do this would be with an ICE plot hanging off
# of the bar...
def waterfall(shap_values, max_display=10, show=True):
def waterfall(shap_values: Explanation,
max_display: int = 10,
show: bool = True,
plot_cmap: Union[WaterfallColorConfig, Dict[str, Union[str, List[float], np.ndarray]], None, str, List[str]] = None
):
"""Plots an explanation of a single prediction as a waterfall plot.

The SHAP value of a feature represents the impact of the evidence provided by that feature on the model's
Expand All @@ -30,20 +51,42 @@
shap_values : Explanation
A one-dimensional :class:`.Explanation` object that contains the feature values and SHAP values to plot.

max_display : str
max_display : int
The maximum number of features to plot (default is 10).

show : bool
Whether ``matplotlib.pyplot.show()`` is called before returning.
Setting this to ``False`` allows the plot to be customized further after it
has been created.
plot_cmap: shap.plots.WaterfallColorConfig, dict[str, Union[list[float], np.ndarray, str]], str, list[str] or None
Colormap to plot. This is either a dictionary with the keys can be either a numpy array or a list (with 3 float entries between 0 and 1)
a `matplotlib color name <https://matplotlib.org/cheatsheets/_images/cheatsheets-2.png>`_ (see section Color names) or a hex code.
Configurable keys: ``positive_arrow, negative_arrow, default_positive_color, default_negative_color, hlines, vlines, text, tick_labels``.
Missing keys will be filled with default values.
Furthermore one can pass a list of strings directly, e.g. ``["white", "blue", "yellow", "black", "beige"]`` to this argument,
which will set the first ``len(plot_cmaps)`` colors according to the list elements. Usage of a single string is also possible, e.g.
"kmcmrb" which works correspondingly by converting the string to list before applying the same logic. If the list is shorter than the number of configuration
option the latter options will filled with default values.

Examples
--------

See `waterfall plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/waterfall.html>`_.

"""
if plot_cmap is None:
color_config = WaterfallColorConfig()
elif isinstance(plot_cmap, dict):
color_config = WaterfallColorConfig(**plot_cmap)
elif isinstance(plot_cmap, WaterfallColorConfig):
color_config = plot_cmap
elif isinstance(plot_cmap, str):
color_config = WaterfallColorConfig(*list(plot_cmap))
elif isinstance(plot_cmap, list):
color_config = WaterfallColorConfig(*plot_cmap)

if not isinstance(color_config, WaterfallColorConfig):
raise TypeError(f"Expected color_config to be of type shap.plots.WaterfallColorConfig, dict, str or list. Received {type(color_config)} instead.")

Check warning on line 89 in shap/plots/_waterfall.py

View check run for this annotation

Codecov / codecov/patch

shap/plots/_waterfall.py#L89

Added line #L89 was not covered by tests

# Turn off interactive plot
if show is False:
Expand Down Expand Up @@ -132,7 +175,7 @@
neg_lefts.append(loc)
if num_individual != num_features or i + 4 < num_individual:
plt.plot([loc, loc], [rng[i] - 1 - 0.4, rng[i] + 0.4],
color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1)
color=color_config.vlines, linestyle="--", linewidth=0.5, zorder=-1)
if features is None:
yticklabels[rng[i]] = feature_names[order[i]]
else:
Expand Down Expand Up @@ -161,10 +204,10 @@
# draw invisible bars just for sizing the axes
label_padding = np.array([0.1*dataw if w < 1 else 0 for w in pos_widths])
plt.barh(pos_inds, np.array(pos_widths) + label_padding + 0.02*dataw,
left=np.array(pos_lefts) - 0.01*dataw, color=colors.red_rgb, alpha=0)
left=np.array(pos_lefts) - 0.01*dataw, color=color_config.positive_arrow, alpha=0)
label_padding = np.array([-0.1*dataw if -w < 1 else 0 for w in neg_widths])
plt.barh(neg_inds, np.array(neg_widths) + label_padding - 0.02*dataw,
left=np.array(neg_lefts) + 0.01*dataw, color=colors.blue_rgb, alpha=0)
left=np.array(neg_lefts) + 0.01*dataw, color=color_config.negative_arrow, alpha=0)

# define variable we need for plotting the arrows
head_length = 0.08
Expand All @@ -184,20 +227,20 @@
arrow_obj = plt.arrow(
pos_lefts[i], pos_inds[i], max(dist-hl_scaled, 0.000001), 0,
head_length=min(dist, hl_scaled),
color=colors.red_rgb, width=bar_width,
color=color_config.positive_arrow, width=bar_width,
head_width=bar_width,
)

if pos_low is not None and i < len(pos_low):
plt.errorbar(
pos_lefts[i] + pos_widths[i], pos_inds[i],
xerr=np.array([[pos_widths[i] - pos_low[i]], [pos_high[i] - pos_widths[i]]]),
ecolor=colors.light_red_rgb,
ecolor=color_config.default_positive_color,
)

txt_obj = plt.text(
pos_lefts[i] + 0.5*dist, pos_inds[i], format_value(pos_widths[i], '%+0.02f'),
horizontalalignment='center', verticalalignment='center', color="white",
horizontalalignment='center', verticalalignment='center', color=color_config.text,
fontsize=12,
)
text_bbox = txt_obj.get_window_extent(renderer=renderer)
Expand All @@ -209,7 +252,7 @@

txt_obj = plt.text(
pos_lefts[i] + (5/72)*bbox_to_xscale + dist, pos_inds[i], format_value(pos_widths[i], '%+0.02f'),
horizontalalignment='left', verticalalignment='center', color=colors.red_rgb,
horizontalalignment='left', verticalalignment='center', color=color_config.positive_arrow,
fontsize=12,
)

Expand All @@ -220,20 +263,20 @@
arrow_obj = plt.arrow(
neg_lefts[i], neg_inds[i], -max(-dist-hl_scaled, 0.000001), 0,
head_length=min(-dist, hl_scaled),
color=colors.blue_rgb, width=bar_width,
color=color_config.negative_arrow, width=bar_width,
head_width=bar_width,
)

if neg_low is not None and i < len(neg_low):
plt.errorbar(
neg_lefts[i] + neg_widths[i], neg_inds[i],
xerr=np.array([[neg_widths[i] - neg_low[i]], [neg_high[i] - neg_widths[i]]]),
ecolor=colors.light_blue_rgb,
ecolor=color_config.default_negative_color,
)

txt_obj = plt.text(
neg_lefts[i] + 0.5*dist, neg_inds[i], format_value(neg_widths[i], '%+0.02f'),
horizontalalignment='center', verticalalignment='center', color="white",
horizontalalignment='center', verticalalignment='center', color=color_config.text,
fontsize=12,
)
text_bbox = txt_obj.get_window_extent(renderer=renderer)
Expand All @@ -245,7 +288,7 @@

txt_obj = plt.text(
neg_lefts[i] - (5/72)*bbox_to_xscale + dist, neg_inds[i], format_value(neg_widths[i], '%+0.02f'),
horizontalalignment='right', verticalalignment='center', color=colors.blue_rgb,
horizontalalignment='right', verticalalignment='center', color=color_config.negative_arrow,
fontsize=12,
)

Expand All @@ -256,12 +299,12 @@

# put horizontal lines for each feature row
for i in range(num_features):
plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
plt.axhline(i, color=color_config.hlines, lw=0.5, dashes=(1, 5), zorder=-1)

# mark the prior expected value and the model prediction
plt.axvline(base_values, 0, 1/num_features, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1)
plt.axvline(base_values, 0, 1/num_features, color=color_config.vlines, linestyle="--", linewidth=0.5, zorder=-1)
fx = base_values + values.sum()
plt.axvline(fx, 0, 1, color="#bbbbbb", linestyle="--", linewidth=0.5, zorder=-1)
plt.axvline(fx, 0, 1, color=color_config.vlines, linestyle="--", linewidth=0.5, zorder=-1)

# clean up the main axis
plt.gca().xaxis.set_ticks_position('bottom')
Expand All @@ -270,7 +313,6 @@
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['left'].set_visible(False)
ax.tick_params(labelsize=13)
#plt.xlabel("\nModel output", fontsize=12)

# draw the E[f(X)] tick mark
xmin, xmax = ax.get_xlim()
Expand All @@ -293,7 +335,7 @@
) + matplotlib.transforms.ScaledTranslation(-10/72., 0, fig.dpi_scale_trans))
tick_labels[1].set_transform(tick_labels[1].get_transform(
) + matplotlib.transforms.ScaledTranslation(12/72., 0, fig.dpi_scale_trans))
tick_labels[1].set_color("#999999")
tick_labels[1].set_color(color_config.tick_labels)
ax3.spines['right'].set_visible(False)
ax3.spines['top'].set_visible(False)
ax3.spines['left'].set_visible(False)
Expand All @@ -305,13 +347,13 @@
tick_labels[1].set_transform(tick_labels[1].get_transform(
) + matplotlib.transforms.ScaledTranslation(22/72., -1/72., fig.dpi_scale_trans))

tick_labels[1].set_color("#999999")
tick_labels[1].set_color(color_config.tick_labels)

# color the y tick labels that have the feature values as gray
# (these fall behind the black ones with just the feature name)
tick_labels = ax.yaxis.get_majorticklabels()
for i in range(num_features):
tick_labels[i].set_color("#999999")
tick_labels[i].set_color(color_config.tick_labels)

if show:
plt.show()
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
33 changes: 32 additions & 1 deletion tests/plots/test_waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sklearn.tree import DecisionTreeRegressor

import shap
from shap.plots import WaterfallColorConfig, colors


def test_waterfall_input_is_explanation():
Expand Down Expand Up @@ -45,7 +46,37 @@ def test_waterfall_legacy(explainer):
shap.plots._waterfall.waterfall_legacy(explainer.expected_value, shap_values[0])
plt.tight_layout()
return fig

# todo: use parametrize to use all possible config options
@pytest.mark.parametrize("color_config", [{"positive_arrow": colors.red_rgb,
"negative_arrow": colors.blue_rgb,
"default_positive_color": colors.light_red_rgb,
"default_negative_color": colors.light_blue_rgb
},
{
'positive_arrow': np.array([1., 0., 0.31796406]),
'negative_arrow': np.array([0., 0.54337757, 0.98337906]),
'default_positive_color': np.array([1., 0.49803922, 0.65490196]),
'default_negative_color': np.array([0.49803922, 0.76862745, 0.98823529])
},
WaterfallColorConfig(
positive_arrow=np.array([1., 0., 0.31796406]),
negative_arrow=np.array([0., 0.54337757, 0.98337906]),
default_positive_color=np.array([1., 0.49803922, 0.65490196]),
default_negative_color=np.array([0.49803922, 0.76862745, 0.98823529])
),
[[1., 0., 0.31796406], [0., 0.54337757, 0.98337906], [1., 0.49803922, 0.65490196], [0.49803922, 0.76862745, 0.98823529]],
['#FF0051', '#008BFB', '#FF7FB1', '#7FC4FC'],
"kgry",
["black", "green", "red", "yellow"],
])
@pytest.mark.mpl_image_compare(tolerance=3)
def test_waterfall_color_config_default(explainer, color_config):
"""Test waterfall config options."""
fig = plt.figure()
shap_values = explainer(explainer.data)
shap.plots.waterfall(shap_values[0], plot_cmap=color_config)
plt.tight_layout()
return fig

def test_waterfall_plot_for_decision_tree_explanation():
# Regression tests for GH issue #3129
Expand Down