Skip to content

Commit

Permalink
Add styles parameter to manual_legend
Browse files Browse the repository at this point in the history
Added styles parameter to manual_legend per issue #566. With tests and bug fix for existing manual_legend test.
  • Loading branch information
tktran committed Aug 16, 2020
1 parent bc2c2fc commit bee1c3b
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 33 deletions.
Binary file modified tests/baseline_images/test_draw/test_manual_legend.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 62 additions & 2 deletions tests/test_draw.py
Expand Up @@ -31,9 +31,31 @@ def test_manual_legend_uneven_colors():
"""
Raise exception when colors and labels are mismatched in manual_legend
"""
with pytest.raises(YellowbrickValueError, match="same number of colors as labels"):
with pytest.raises(YellowbrickValueError,
match="list of length equal to the number of labels"):
manual_legend(None, ("a", "b", "c"), ("r", "g"))

def test_manual_legend_styles_malformed_input():
"""
Raise exception when styles and/or colors are not lists of same length
as labels
"""

# styles should be a list of strings
with pytest.raises(YellowbrickValueError,
match="Please specify the styles parameter as a list of strings"):
manual_legend(None, ("a", "b", "c"), styles="ro")

# styles should be a list of same len() as labels
with pytest.raises(YellowbrickValueError,
match="list of length equal to the number of labels"):
manual_legend(None, ("a", "b", "c"), styles=("ro", "--"))

# if colors is passed in alongside styles, it should be of same length
with pytest.raises(YellowbrickValueError,
match="list of length equal to the number of labels"):
manual_legend(None, ("a", "b", "c"), ("r", "g"), styles=("ro", "b--", "--"))


@pytest.fixture(scope="class")
def data(request):
Expand Down Expand Up @@ -83,7 +105,45 @@ def test_manual_legend(self):
)

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.5)
self.assert_images_similar(ax=ax, tol=0.5, remove_legend=False)

def test_manual_legend_styles(self):
"""
Check that the styles argument to manual_legend is correctly
processed, including its being overridden by the colors argument
"""

# Draw a random scatter plot
random = np.random.RandomState(42)

Ax, Ay = random.normal(50, 2, 100), random.normal(50, 3, 100)
Bx, By = random.normal(42, 3, 100), random.normal(44, 1, 100)
Cx, Cy = random.normal(20, 10, 100), random.normal(30, 1, 100)
Dx, Dy = random.normal(33, 5, 100), random.normal(22, 2, 100)

_, ax = plt.subplots()
ax.scatter(Ax, Ay, c="r", alpha=0.35, label="a")
ax.scatter(Bx, By, c="g", alpha=0.35, label="b")
ax.scatter(Cx, Cy, c="b", alpha=0.35, label="c")
ax.scatter(Dx, Dy, c="y", alpha=0.35, label="d")

# Four style/color combinations are tested here:
# (1) "blue" color should override the "r" of "ro" style
# (2) blank color should, of course, be overriden by the "g" of "-g"
# (3) None color should also be overridden by the third style, but
# since a color is not specified there either, the entry should
# default to black.
# (4) Linestyle, marker, and color are all unspecified. The entry should
# default to a solid black line.
styles = ["ro", "-g", "--", ""]
labels = ("a", "b", "c", "d")
colors = ("blue", "", None, None)
manual_legend(
ax, labels, colors, styles=styles, frameon=True, loc="upper left"
)

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.5, remove_legend=False)

def test_vertical_bar_stack(self):
"""
Expand Down
118 changes: 87 additions & 31 deletions yellowbrick/draw.py
Expand Up @@ -21,7 +21,7 @@
from .exceptions import YellowbrickValueError
from .style.colors import resolve_colors

from matplotlib import patches
from matplotlib import axes, patches, lines

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -30,15 +30,17 @@
## Legend Drawing Utilities
##########################################################################


def manual_legend(g, labels, colors, **legend_kwargs):
def manual_legend(g, labels, colors=None, styles=None, **legend_kwargs):
"""
Adds a manual legend for a scatter plot to the visualizer where the labels
and associated colors are drawn with circle patches instead of determining
them from the labels of the artist objects on the axes. This helper is
used either when there are a lot of duplicate labels, no labeled artists,
or when the color of the legend doesn't exactly match the color in the
figure (e.g. because of the use of transparency).
Adds a manual legend for a scatter plot to the visualizer. The legend
entries are drawn according to the ``styles`` parameter if specified, and
with circle patches (colored according to ``colors``) if not specified.
Calling this function overrides the default behavior of drawing the legend
from the labels of the artist objects on the axes.
This helper is used either when there are a lot of duplicate labels,
no labeled artists, or when the color of the legend doesn't exactly
match the color in the figure (e.g. because of the use of transparency).
Parameters
----------
Expand All @@ -51,10 +53,21 @@ def manual_legend(g, labels, colors, **legend_kwargs):
The text labels to associate with the legend. Note that the labels
will be added to the legend in the order specified.
colors : list of colors
A list of any valid matplotlib color reference. The number of colors
specified must be equal to the number of labels.
colors : list of colors, default: None
A list of any valid matplotlib color references. If ``styles``
is provided, colors must be either ``None`` or a list of equal length to
``labels``; in the latter case, this parameter takes predence over any
colors specified in ``styles``. To skip specifying a color for a
particular entry, use an empty string, None, or 'None'.
styles : list of str, default: None
A list of matplotlib-style format strings, each corresponding to a label
and describing its graphical appearance in the legend, e.g., 'ro' for a
red circle. The number of styles specified must be equal to the number
of labels. Either one or both of ``colors`` and ``styles`` must be
specified. Consistent with matplotlib, blank style entries default to
solid, unmarked, black lines.
legend_kwargs : dict
Any additional keyword arguments to pass to the legend.
Expand All @@ -64,36 +77,78 @@ def manual_legend(g, labels, colors, **legend_kwargs):
The artist created by the ax.legend() call, returned for further
manipulation if required by the caller.
Notes
-----
Right now this method simply draws the patches as rectangles and cannot
take into account the line or scatter plot properties (e.g. line style or
marker style). It is possible to add Line2D patches to the artist that do
add manual styles like this, which we can explore in the future.
.. seealso:: https://matplotlib.org/gallery/text_labels_and_annotations/custom_legends.html
.. seealso:: https://matplotlib.org/3.3.0/api/_as_gen/matplotlib.pyplot.plot.html
"""

# Get access to the matplotlib Axes
if isinstance(g, Visualizer):
g = g.ax
elif g is None:
g = plt.gca()

# Ensure that labels and colors are the same length to prevent odd behavior.
if len(colors) != len(labels):
raise YellowbrickValueError(
"please specify the same number of colors as labels!"
)

# Create the legend handles with the associated colors and labels
handles = [
patches.Patch(color=color, label=label) for color, label in zip(colors, labels)
]
if styles:
# Documented the `styles` parameter as being a list when really
# it makes sense to accept it as a list or a tuple
if type(styles) not in (list, tuple):
raise YellowbrickValueError(
"Please specify the styles parameter as a list of strings!"
)

if len(styles) != len(labels):
raise YellowbrickValueError(
"Please specify the styles parameter as a list of length "
"equal to the number of labels!"
)

if colors is not None and len(colors) != len(labels):
raise YellowbrickValueError(
"Please specify the colors parameter either as colors=None or "
"a list of length equal to the number of labels. You can use "
"an empty string or None as a placeholder for colors that "
"are already specified in the corresponding styles entry."
)
else:
if colors is None or len(colors) != len(labels):
raise YellowbrickValueError(
"Please specify the colors parameter as a list of length equal "
"to the number of labels!"
)

# Set legend's artist handles to:
# linestyles/markers/colors specified by `styles` if passed in, or
# patches according to `colors` if it is not
if styles:
if colors is None:
colors = [None] * len(styles)
else:
colors = [None if color in ("", " ", None, 'None') else color
for color in colors]

handles = list()
for style, color, label in zip(styles, colors, labels):
linestyle, marker, style_color = \
axes._base._process_plot_format(style)

# colors parameter should take precedence over styles,
# consistent with matplotlib
color = color or style_color or 'black'
# _process_plot_format() above will have already set linestyle to
# '-' and marker to 'None' if they weren't specified

line_2d = lines.Line2D([0], [0], linestyle=linestyle, marker=marker,
color=color, label=label)
handles.append(line_2d)
else:
handles = [
patches.Patch(color=color, label=label) for
color, label in zip(colors, labels)
]

# Return the Legend artist
return g.legend(handles=handles, **legend_kwargs)


def bar_stack(
data,
ax=None,
Expand Down Expand Up @@ -192,3 +247,4 @@ def bar_stack(
legend_kws = legend_kws or {}
manual_legend(ax, labels=labels, colors=colors, **legend_kws)
return ax

0 comments on commit bee1c3b

Please sign in to comment.