-
-
Notifications
You must be signed in to change notification settings - Fork 554
/
test_draw.py
163 lines (125 loc) · 4.81 KB
/
test_draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# tests.test_draw
# Tests for the high-level drawing utility functions
#
# Author: Benjamin Bengfort <benjamin@bengfort.com>
# Created: Sun Aug 19 11:21:04 2018 -0400
#
# ID: test_draw.py [dd915ad] benjamin@bengfort.com $
"""
Tests for the high-level drawing utility functions
"""
##########################################################################
## Imports
##########################################################################
import pytest
import numpy as np
import matplotlib.pyplot as plt
from yellowbrick.draw import *
from .base import VisualTestCase
##########################################################################
## Simple tests for high-level drawing utilities
##########################################################################
def test_manual_legend_uneven_colors():
    """
    Raise exception when colors and labels are mismatched in manual_legend.

    Three labels are paired with only two colors; ``manual_legend`` must
    reject this with a ``YellowbrickValueError`` whose message mentions the
    colors/labels mismatch. No axes are needed (``None`` is passed) because
    the check is expected to fire before anything is drawn.
    """
    # Import the exception explicitly instead of relying on it leaking
    # through ``from yellowbrick.draw import *``.
    from yellowbrick.exceptions import YellowbrickValueError

    with pytest.raises(YellowbrickValueError, match="same number of colors as labels"):
        manual_legend(None, ("a", "b", "c"), ("r", "g"))
@pytest.fixture(scope="class")
def data(request):
    """
    Attach a fixed 4x7 integer matrix to the requesting test class.

    The array is stored as ``request.cls.data`` (the fixture itself returns
    nothing), so test methods of a class using this fixture read it via
    ``self.data``. Being class-scoped, it is built once per class.
    """
    rows = [
        [4, 8, 7, 6, 5, 2, 1],
        [6, 7, 9, 6, 9, 3, 6],
        [5, 1, 6, 8, 4, 7, 8],
        [6, 8, 1, 5, 6, 7, 4],
    ]
    request.cls.data = np.array(rows)
##########################################################################
## Visual test cases for high-level drawing utilities
##########################################################################
@pytest.mark.usefixtures("data")
class TestDraw(VisualTestCase):
    """
    Visual tests for the high-level drawing utilities.

    ``self.data`` (a 4x7 integer matrix) is attached to the class by the
    ``data`` fixture requested via ``usefixtures``.
    """

    def test_manual_legend(self):
        """
        Check that the manual legend is drawn without axes artists
        """
        # Draw a seeded random scatter plot of three point clouds; the fixed
        # seed and draw order keep the rendered image comparable to the
        # stored baseline.
        rng = np.random.RandomState(42)
        a_x, a_y = rng.normal(50, 2, 100), rng.normal(50, 3, 100)
        b_x, b_y = rng.normal(42, 3, 100), rng.normal(44, 1, 100)
        c_x, c_y = rng.normal(20, 10, 100), rng.normal(30, 1, 100)

        _, ax = plt.subplots()
        ax.scatter(a_x, a_y, c="r", alpha=0.35, label="a")
        ax.scatter(b_x, b_y, c="g", alpha=0.35, label="b")
        ax.scatter(c_x, c_y, c="b", alpha=0.35, label="c")

        # Add the manual legend
        manual_legend(
            ax, ("a", "b", "c"), ("r", "g", "b"), frameon=True, loc="upper left"
        )

        # Assert image similarity
        self.assert_images_similar(ax=ax, tol=0.5)

    def test_vertical_bar_stack(self):
        """
        Test bar_stack for vertical orientation
        """
        _, ax = plt.subplots()

        # Plots stacked bar charts
        bar_stack(self.data, ax=ax, orientation="v")

        # Assert image similarity
        self.assert_images_similar(ax=ax, tol=0.1)

    def test_horizontal_bar_stack(self):
        """
        Test bar_stack for horizontal orientation
        """
        _, ax = plt.subplots()

        # Plots stacked bar charts
        bar_stack(self.data, ax=ax, orientation="h")

        # Assert image similarity
        self.assert_images_similar(ax=ax, tol=0.1)

    def test_single_row_bar_stack(self):
        """
        Test bar_stack for single row
        """
        data = np.array([[4, 8, 7, 6, 5, 2, 1]])
        _, ax = plt.subplots()

        # Plots stacked bar charts (default orientation)
        bar_stack(data, ax=ax)

        # Assert image similarity
        self.assert_images_similar(ax=ax, tol=0.1)

    def test_labels_vertical(self):
        """
        Test labels and ticks for vertical barcharts
        """
        labels = ["books", "cinema", "cooking", "gaming"]
        ticks = ["noun", "verb", "adverb", "pronoun", "preposition", "digit", "other"]
        _, ax = plt.subplots()

        # Plots stacked bar charts; pass ``ax`` explicitly so the tick and
        # image assertions below inspect exactly the axes that was drawn on,
        # rather than whatever bar_stack picks when ``ax`` is omitted.
        bar_stack(
            self.data, ax=ax, labels=labels, ticks=ticks, colors=["r", "b", "g", "y"]
        )

        # Extract tick labels from the plot (x axis for vertical bars)
        ticks_ax = [tick.get_text() for tick in ax.xaxis.get_ticklabels()]

        # Assert that ticks are set properly
        assert ticks_ax == ticks

        # Assert image similarity
        self.assert_images_similar(ax=ax, tol=0.05)

    def test_labels_horizontal(self):
        """
        Test labels and ticks with horizontal barcharts
        """
        labels = ["books", "cinema", "cooking", "gaming"]
        ticks = ["noun", "verb", "adverb", "pronoun", "preposition", "digit", "other"]
        _, ax = plt.subplots()

        # Plots stacked bar charts; pass ``ax`` explicitly so the assertions
        # below inspect exactly the axes that was drawn on.
        bar_stack(
            self.data,
            ax=ax,
            labels=labels,
            ticks=ticks,
            orientation="h",
            colormap="cool",
        )

        # Extract tick labels from the plot (y axis for horizontal bars)
        ticks_ax = [tick.get_text() for tick in ax.yaxis.get_ticklabels()]

        # Assert that ticks are set properly
        assert ticks_ax == ticks

        # Assert image similarity
        self.assert_images_similar(ax=ax, tol=0.05)