Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: New plotting API for beeswarm so can accept and return Axes #3561

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
70 changes: 46 additions & 24 deletions shap/plots/_beeswarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# TODO: Add support for hclustering based explanations where we sort the leaf order by magnitude and then show the dendrogram to the left
def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
clustering=None, cluster_threshold=0.5, color=None,
axis_color="#333333", alpha=1, show=True, log_scale=False,
axis_color="#333333", alpha=1, ax=None, show=True, log_scale=False,
color_bar=True, s=16, plot_size="auto", color_bar_label=labels["FEATURE_VALUE"]):
"""Create a SHAP beeswarm plot, colored by feature values when they are provided.

Expand All @@ -41,6 +41,9 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
How many top features to include in the plot (default is 10, or 7 for
interaction plots).

ax: matplotlib Axes
Axes object to draw the plot onto, otherwise uses the current Axes.

show : bool
Whether ``matplotlib.pyplot.show()`` is called before returning.
Setting this to ``False`` allows the plot to be customized further
Expand All @@ -57,7 +60,13 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
number of features that are being displayed. Passing a single float will cause
each row to be that many inches high. Passing a pair of floats will scale the
plot by that number of inches. If ``None`` is passed, then the size of the
current figure will be left unchanged.
current figure will be left unchanged. If ax is not ``None``, then passing
plot_size will raise a Value Error.

Returns
-------
ax: matplotlib Axes
Returns the Axes object with the plot drawn onto it. Only returned if ``show=False``.

Examples
--------
Expand Down Expand Up @@ -85,6 +94,14 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
)
raise ValueError(emsg)

if ax and plot_size:
emsg = (
"The beeswarm plot does not support passing an axis and adjusting the plot size. "
"To adjust the size of the plot, set plot_size to None and adjust the size on the original figure the axes was part of"
)
raise ValueError(emsg)


shap_exp = shap_values
# we make a copy here, because later there are places that might modify this array
values = np.copy(shap_exp.values)
Expand Down Expand Up @@ -155,8 +172,12 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
if feature_names is None:
feature_names = np.array([labels['FEATURE'] % str(i) for i in range(num_features)])

fig = pl.gcf()
if ax is None:
ax = pl.gca()

if log_scale:
pl.xscale('symlog')
ax.set_xscale('symlog')

if clustering is None:
partition_tree = getattr(shap_values, "clustering", None)
Expand Down Expand Up @@ -319,16 +340,16 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),

row_height = 0.4
if plot_size == "auto":
pl.gcf().set_size_inches(8, min(len(feature_order), max_display) * row_height + 1.5)
fig.set_size_inches(8, min(len(feature_order), max_display) * row_height + 1.5)
elif type(plot_size) in (list, tuple):
pl.gcf().set_size_inches(plot_size[0], plot_size[1])
fig.set_size_inches(plot_size[0], plot_size[1])
elif plot_size is not None:
pl.gcf().set_size_inches(8, min(len(feature_order), max_display) * plot_size + 1.5)
pl.axvline(x=0, color="#999999", zorder=-1)
fig.set_size_inches(8, min(len(feature_order), max_display) * plot_size + 1.5)
ax.axvline(x=0, color="#999999", zorder=-1)

# make the beeswarm dots
for pos, i in enumerate(reversed(feature_inds)):
pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
shaps = values[:, i]
fvalues = None if features is None else features[:, i]
inds = np.arange(len(shaps))
Expand Down Expand Up @@ -380,7 +401,7 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),

# plot the nan fvalues in the interaction feature as grey
nan_mask = np.isnan(fvalues)
pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777",
ax.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777",
s=s, alpha=alpha, linewidth=0,
zorder=3, rasterized=len(shaps) > 500)

Expand All @@ -390,13 +411,13 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
cvals[cvals_imp > vmax] = vmax
cvals[cvals_imp < vmin] = vmin
pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
ax.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)],
cmap=color, vmin=vmin, vmax=vmax, s=s,
c=cvals, alpha=alpha, linewidth=0,
zorder=3, rasterized=len(shaps) > 500)
else:

pl.scatter(shaps, pos + ys, s=s, alpha=alpha, linewidth=0, zorder=3,
ax.scatter(shaps, pos + ys, s=s, alpha=alpha, linewidth=0, zorder=3,
color=color if colored_feature else "#777777", rasterized=len(shaps) > 500)


Expand All @@ -405,7 +426,7 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
import matplotlib.cm as cm
m = cm.ScalarMappable(cmap=color)
m.set_array([0, 1])
cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=80)
cb = fig.colorbar(m, ax=ax, ticks=[0, 1], aspect=80)
cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
cb.set_label(color_bar_label, size=12, labelpad=0)
cb.ax.tick_params(labelsize=11, length=0)
Expand All @@ -415,21 +436,22 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
# cb.ax.set_aspect((bbox.height - 0.9) * 20)
# cb.draw_all()

pl.gca().xaxis.set_ticks_position('bottom')
pl.gca().yaxis.set_ticks_position('none')
pl.gca().spines['right'].set_visible(False)
pl.gca().spines['top'].set_visible(False)
pl.gca().spines['left'].set_visible(False)
pl.gca().tick_params(color=axis_color, labelcolor=axis_color)
pl.yticks(range(len(feature_inds)), reversed(yticklabels), fontsize=13)
pl.gca().tick_params('y', length=20, width=0.5, which='major')
pl.gca().tick_params('x', labelsize=11)
pl.ylim(-1, len(feature_inds))
pl.xlabel(labels['VALUE'], fontsize=13)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.tick_params(color=axis_color, labelcolor=axis_color)
ax.set_yticks(range(len(feature_inds)), reversed(yticklabels), fontsize=13)
ax.tick_params('y', length=20, width=0.5, which='major')
ax.tick_params('x', labelsize=11)
ax.set_ylim(-1, len(feature_inds))
ax.set_xlabel(labels['VALUE'], fontsize=13)
if show:
pl.show()
else:
return pl.gca()
pl.close(fig)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask about the motivation behind pl.close? I think it may be a breaking change; we may as well decide a suitable pattern and plan to do that same for the other plots.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course. Jupyter will still render the plot even if show = false. Closing the figure will prevent that. But maybe I've misunderstood what show is meant to be doing.

Ah if its a breaking change then perhaps more thought should be put into it. How do you think it's a breaking change?

return ax

def shorten_text(text, length_limit):
if len(text) > length_limit:
Expand Down