Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: New plotting API for beeswarm so can accept and return Axes #3561

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 44 additions & 27 deletions shap/plots/_beeswarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# TODO: Add support for hclustering based explanations where we sort the leaf order by magnitude and then show the dendrogram to the left
def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
clustering=None, cluster_threshold=0.5, color=None,
axis_color="#333333", alpha=1, show=True, log_scale=False,
axis_color="#333333", alpha=1, ax=None, show=True, log_scale=False,
color_bar=True, s=16, plot_size="auto", color_bar_label=labels["FEATURE_VALUE"]):
"""Create a SHAP beeswarm plot, colored by feature values when they are provided.

Expand All @@ -41,6 +41,9 @@
How many top features to include in the plot (default is 10, or 7 for
interaction plots).

ax: matplotlib Axes
Axes object to draw the plot onto, otherwise uses the current Axes.

show : bool
Whether ``matplotlib.pyplot.show()`` is called before returning.
Setting this to ``False`` allows the plot to be customized further
Expand All @@ -57,7 +60,13 @@
number of features that are being displayed. Passing a single float will cause
each row to be that many inches high. Passing a pair of floats will scale the
plot by that number of inches. If ``None`` is passed, then the size of the
current figure will be left unchanged.
current figure will be left unchanged. If ax is not ``None``, then plot_size has
no effect.

Returns
-------
ax: matplotlib Axes
Returns the Axes object with the plot drawn onto it. Only returned if ``show=False``.

Examples
--------
Expand Down Expand Up @@ -155,8 +164,14 @@
if feature_names is None:
feature_names = np.array([labels['FEATURE'] % str(i) for i in range(num_features)])

fig = pl.gcf()
ax_is_None = False
if ax is None:
ax = pl.gca()
ax_is_None = True

if log_scale:
pl.xscale('symlog')
ax.set_xscale('symlog')

Check warning on line 174 in shap/plots/_beeswarm.py

View check run for this annotation

Codecov / codecov/patch

shap/plots/_beeswarm.py#L174

Added line #L174 was not covered by tests

if clustering is None:
partition_tree = getattr(shap_values, "clustering", None)
Expand Down Expand Up @@ -318,17 +333,18 @@
yticklabels[-1] = "Sum of %d other features" % num_cut

row_height = 0.4
if plot_size == "auto":
pl.gcf().set_size_inches(8, min(len(feature_order), max_display) * row_height + 1.5)
elif type(plot_size) in (list, tuple):
pl.gcf().set_size_inches(plot_size[0], plot_size[1])
elif plot_size is not None:
pl.gcf().set_size_inches(8, min(len(feature_order), max_display) * plot_size + 1.5)
pl.axvline(x=0, color="#999999", zorder=-1)
if ax_is_None:
if plot_size == "auto":
fig.set_size_inches(8, min(len(feature_order), max_display) * row_height + 1.5)
elif type(plot_size) in (list, tuple):
fig.set_size_inches(plot_size[0], plot_size[1])
elif plot_size is not None:
fig.set_size_inches(8, min(len(feature_order), max_display) * plot_size + 1.5)

Check warning on line 342 in shap/plots/_beeswarm.py

View check run for this annotation

Codecov / codecov/patch

shap/plots/_beeswarm.py#L339-L342

Added lines #L339 - L342 were not covered by tests
ax.axvline(x=0, color="#999999", zorder=-1)

# make the beeswarm dots
for pos, i in enumerate(reversed(feature_inds)):
pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
shaps = values[:, i]
fvalues = None if features is None else features[:, i]
inds = np.arange(len(shaps))
Expand Down Expand Up @@ -380,7 +396,7 @@

# plot the nan fvalues in the interaction feature as grey
nan_mask = np.isnan(fvalues)
pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777",
ax.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777",
s=s, alpha=alpha, linewidth=0,
zorder=3, rasterized=len(shaps) > 500)

Expand All @@ -390,13 +406,13 @@
cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
cvals[cvals_imp > vmax] = vmax
cvals[cvals_imp < vmin] = vmin
pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
ax.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
cmap=color, vmin=vmin, vmax=vmax, s=s,
c=cvals, alpha=alpha, linewidth=0,
zorder=3, rasterized=len(shaps) > 500)
else:

pl.scatter(shaps, pos + ys, s=s, alpha=alpha, linewidth=0, zorder=3,
ax.scatter(shaps, pos + ys, s=s, alpha=alpha, linewidth=0, zorder=3,

Check warning on line 415 in shap/plots/_beeswarm.py

View check run for this annotation

Codecov / codecov/patch

shap/plots/_beeswarm.py#L415

Added line #L415 was not covered by tests
color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)


Expand All @@ -405,7 +421,7 @@
import matplotlib.cm as cm
m = cm.ScalarMappable(cmap=color)
m.set_array([0, 1])
cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=80)
cb = fig.colorbar(m, ax=ax, ticks=[0, 1], aspect=80)
cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
cb.set_label(color_bar_label, size=12, labelpad=0)
cb.ax.tick_params(labelsize=11, length=0)
Expand All @@ -415,21 +431,22 @@
# cb.ax.set_aspect((bbox.height - 0.9) * 20)
# cb.draw_all()

pl.gca().xaxis.set_ticks_position('bottom')
pl.gca().yaxis.set_ticks_position('none')
pl.gca().spines['right'].set_visible(False)
pl.gca().spines['top'].set_visible(False)
pl.gca().spines['left'].set_visible(False)
pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
pl.yticks(range(len(feature_inds)), reversed(yticklabels), fontsize=13)
pl.gca().tick_params('y', length=20, width=0.5, which='major')
pl.gca().tick_params('x', labelsize=11)
pl.ylim(-1, len(feature_inds))
pl.xlabel(labels['VALUE'], fontsize=13)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.tick_params(color=axis_color, labelcolor=axis_color)
ax.set_yticks(range(len(feature_inds)), reversed(yticklabels), fontsize=13)
ax.tick_params('y', length=20, width=0.5, which='major')
ax.tick_params('x', labelsize=11)
ax.set_ylim(-1, len(feature_inds))
ax.set_xlabel(labels['VALUE'], fontsize=13)
if show:
pl.show()
else:
return pl.gca()
pl.close(fig)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask about the motivation behind pl.close? I think it may be a breaking change; we may as well decide on a suitable pattern and plan to do the same for the other plots.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course. Jupyter will still render the plot even if show=False. Closing the figure will prevent that. But maybe I've misunderstood what show is meant to be doing.

Ah, if it's a breaking change then perhaps more thought should be put into it. How do you think it's a breaking change?

return ax

Check warning on line 449 in shap/plots/_beeswarm.py

View check run for this annotation

Codecov / codecov/patch

shap/plots/_beeswarm.py#L448-L449

Added lines #L448 - L449 were not covered by tests

def shorten_text(text, length_limit):
if len(text) > length_limit:
Expand Down