Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addressing issue #525: Added helper for stacked #870

Merged
merged 7 commits into from
Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline_images/test_draw/test_manual_legend.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
93 changes: 93 additions & 0 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,23 @@ def test_manual_legend_uneven_colors():
manual_legend(None, ('a', 'b', 'c'), ('r', 'g'))


@pytest.fixture(scope="class")
def data(request):

data = np.array(
[[4, 8, 7, 6, 5, 2, 1],
[6, 7, 9, 6, 9, 3, 6],
[5, 1, 6, 8, 4, 7, 8],
[6, 8, 1, 5, 6, 7, 4]]
)

request.cls.data = data

##########################################################################
## Visual test cases for high-level drawing utilities
##########################################################################

@pytest.mark.usefixtures("data")
class TestDraw(VisualTestCase):
"""
Visual tests for the high-level drawing utilities
Expand Down Expand Up @@ -67,3 +80,83 @@ def test_manual_legend(self):

# Assert image similarity
self.assert_images_similar(ax=ax)

def test_vertical_bar_stack(self):
"""
Test bar_stack for vertical orientation
"""
_, ax = plt.subplots()

# Plots stacked bar charts
bar_stack(self.data, ax=ax, orientation='v')

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.1)

def test_horizontal_bar_stack(self):
"""
Test bar_stack for horizontal orientation
"""
_, ax = plt.subplots()
# Plots stacked bar charts
bar_stack(self.data, ax=ax, orientation='h')

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.1)

def test_single_row_bar_stack(self):
"""
Test bar_stack for single row
"""
data = np.array([[4, 8, 7, 6, 5, 2, 1]])

_, ax = plt.subplots()

# Plots stacked bar charts
bar_stack(data, ax=ax)

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.1)

def test_labels_vertical(self):
"""
Test labels and ticks for vertical barcharts
"""
labels = ['books', 'cinema', 'cooking', 'gaming']
ticks = ['noun', 'verb', 'adverb', 'pronoun', 'preposition',
'digit', 'other']
_, ax = plt.subplots()

# Plots stacked bar charts
bar_stack(self.data, labels = labels, ticks=ticks,
colors=['r','b','g','y'])

# Extract tick labels from the plot
ticks_ax = [tick.get_text() for tick in ax.xaxis.get_ticklabels()]
#Assert that ticks are set properly
assert ticks_ax==ticks
bbengfort marked this conversation as resolved.
Show resolved Hide resolved

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.05)

def test_labels_horizontal(self):
"""
Test labels and ticks with horizontal barcharts
"""
labels = ['books', 'cinema', 'cooking', 'gaming']
ticks = ['noun', 'verb', 'adverb', 'pronoun', 'preposition',
'digit', 'other']
_, ax = plt.subplots()

# Plots stacked bar charts
bar_stack(self.data, labels = labels, ticks=ticks, orientation='h',
colormap='cool')

# Extract tick labels from the plot
ticks_ax = [tick.get_text() for tick in ax.yaxis.get_ticklabels()]
#Assert that ticks are set properly
assert ticks_ax==ticks

# Assert image similarity
self.assert_images_similar(ax=ax, tol=0.05)

78 changes: 77 additions & 1 deletion yellowbrick/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

from .base import Visualizer
from .exceptions import YellowbrickValueError
from .style.colors import resolve_colors

from matplotlib import patches

import matplotlib.pyplot as plt

import numpy as np

##########################################################################
## Legend Drawing Utilities
Expand Down Expand Up @@ -89,3 +90,78 @@ def manual_legend(g, labels, colors, **legend_kwargs):
# Return the Legend artist
return g.legend(handles=handles, **legend_kwargs)


def bar_stack(data, ax=None, labels=None, ticks=None, colors=None,
orientation='vertical', colormap=None, **kwargs):
"""
An advanced bar chart plotting utility that can draw bar and stacked bar charts from
data, wrapping calls to the specified matplotlib.Axes object.

Parameters
----------
data : 2D array-like
The data associated with the bar chart where the columns represent each bar
and the rows represent each stack in the bar chart. A single bar chart would
be a 2D array with only one row, a bar chart with three stacks per bar would
have a shape of (3, b).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to make this a bit more readable and understandable. @DistrictDataLabs/team-oz-maintainers does this make sense to you? Any thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I picked it up from the #525. It had a basic function signature. We can change this if you suggest.


ax : matplotlib.Axes, default: None
The axes object to draw the barplot on, uses plt.gca() if not specified.

labels : list of str, default: None
The labels for each row in the bar stack, used to create a legend.

ticks : list of str, default: None
The labels for each bar, added to the x-axis for a vertical plot, or the y-axis
for a horizontal plot.

colors : array-like, default: None
Specify the colors of each bar, each row in the stack, or every segment.

colormap : string or matplotlib cmap
Specify a colormap for each bar, each row in the stack, or every segment.

kwargs : dict
Additional keyword arguments to pass to ``ax.bar``.
"""
if ax is None:
ax = plt.gca()

colors = resolve_colors(n_colors=data.shape[0],
colormap=colormap,
colors=colors)

idx = np.arange(data.shape[1])
prev = np.zeros(data.shape[1])
orientation = orientation.lower()
if orientation.startswith('v'):
for rdx,row in enumerate(data):
ax.bar(idx,
row,
bottom = prev,
color = colors[rdx])
prev+=row
ax.set_xticks(idx)
if ticks is not None:
ax.set_xticklabels(ticks, rotation=90)

elif orientation.startswith('h'):
for rdx,row in enumerate(data):
ax.barh(idx,
row,
left = prev,
color = colors[rdx])
prev+=row
ax.set_yticks(idx)
if ticks is not None:
ax.set_yticklabels(ticks)
else:
raise YellowbrickValueError(
"unknown orientation '{}'".format(orientation)
)

# Generates default labels is labels are not specified.
labels = labels or np.arange(data.shape[0])

manual_legend(ax, labels=labels, colors=colors)
return ax