Skip to content

Commit

Permalink
added axis property settings test
Browse files Browse the repository at this point in the history
  • Loading branch information
CDonnerer committed Dec 31, 2021
1 parent cbdb883 commit 7fb7d36
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 38 deletions.
48 changes: 25 additions & 23 deletions src/shellplot/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@


class Axis:
"""Enables mapping from data to display coordinates.
"""Enables mapping from data to display / plot coordinates.
We loosely follow the sklearn transform api:
We loosely follow the sklearn transformer api:
>>> axis = Axis()
>>> axis = x_axis.fit(x_data)
Expand All @@ -49,17 +49,17 @@ def __init__(
Parameters
----------
display_length : int, optional
Length of axis, in characters (default 20)
Length of axis, in characters, default 20
label : Optional[str], optional
Axis label, by default None
Axis label, default None
limits : Optional[array_like], optional
Axis limits, by default None (auto-generated)
Axis limits, default None (auto-generated)
ticklabels : Optional[array_like], optional
Labels for axis ticks, by default None (auto-generated, as ticks)
Labels for axis ticks, default None (auto-generated, as ticks)
ticks : Optional[array_like], optional
Where the axis ticks should be. Default None (auto-generated)
Where the axis ticks should be, default None (auto-generated)
nticks : Optional[int], optional
Number of axis ticks. Default None (auto-generated)
Number of axis ticks, default None (auto-generated)
"""
self.display_max = display_length - 1
self._is_datetime = False # whether or not we are a datetime axis
Expand Down Expand Up @@ -90,8 +90,8 @@ def limits(self):
@limits.setter
def limits(self, limits):
self._limits = limits
if limits is not None:
self._limits = to_numeric(np.array(limits))
if limits is not None: # new limits need to update scale and ticks
self._limits = to_numeric(limits)
self._set_scale()
self._reset_ticks()

Expand Down Expand Up @@ -124,16 +124,16 @@ def ticklabels(self):
return self._ticklabels

@ticklabels.setter
def ticklabels(self, labels):
if labels is not None:
if len(labels) != len(self.ticks):
def ticklabels(self, ticklabels):
if ticklabels is not None:
if len(ticklabels) != len(self.ticks):
raise ValueError("Len of tick labels must equal len of ticks!")
if self._is_datetime:
labels = np.datetime_as_string(labels)
self._ticklabels = numpy_1d(labels)
if is_datetime(ticklabels):
ticklabels = np.datetime_as_string(ticklabels)
self._ticklabels = numpy_1d(ticklabels)

# -------------------------------------------------------------------------
# Fit & transform
# Public methods: fit, transform and generate ticks
# -------------------------------------------------------------------------

def fit(self, x):
Expand All @@ -145,21 +145,22 @@ def fit(self, x):
self._limits = self._auto_limits(x)

self._set_scale()

return self

def transform(self, x):
"""Transform data to the plot coordinates"""
x = to_numeric(x)
x_scaled = self._scale * (x - self.limits[0]).astype(float)
x_display = np.around(x_scaled).astype(int)
return np.ma.masked_outside(x_display, 0, self.display_max)

def fit_transform(self, x):
"""Fit axis and transform data to the plot coordinates"""
self = self.fit(x)
return self.transform(x)

def gen_tick_labels(self):
"""Generate display tick location and labels"""
def generate_display_ticks(self):
"""Generate display tick locations and labels"""
display_ticks = self.transform(self.ticks)
within_display = np.logical_and(
display_ticks >= 0, display_ticks <= self.display_max
Expand All @@ -170,14 +171,14 @@ def gen_tick_labels(self):
return zip(display_ticks, display_labels)

# -------------------------------------------------------------------------
# Auto scaling & ticks
# Private methods: Auto scaling & ticks
# -------------------------------------------------------------------------

def _set_scale(self):
self._scale = self.display_max / float(self.limits[1] - self.limits[0])

def _auto_limits(self, x, margin=0.25):
"""Automatically find reasonable axis limits"""
"""Automatically find good axis limits"""
x_max, x_min = x.max(), x.min()

max_difference = margin * (x_max - x_min)
Expand All @@ -187,13 +188,14 @@ def _auto_limits(self, x, margin=0.25):
return ax_min, ax_max

def _auto_nticks(self):
"""Automatically find reasonable number of ticks that fit display"""
"""Automatically find number of ticks that fit display"""
max_ticks = int(1.5 * self.display_max ** 0.3) + 1
ticks = np.arange(max_ticks, max_ticks - 2, -1)
remainders = np.remainder(self.display_max, ticks)
return ticks[np.argmin(remainders)] + 1

def _auto_ticks(self):
"""Automatically find good axis ticks"""
if self.limits is None:
return None
elif not self._is_datetime:
Expand Down
6 changes: 3 additions & 3 deletions src/shellplot/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def draw(canvas, x_axis, y_axis, legend=None) -> str:
"""
canvas_lines = _draw_canvas(canvas)

left_pad = max([len(str(val)) for (t, val) in y_axis.gen_tick_labels()]) + 1
left_pad = max([len(str(val)) for (t, val) in y_axis.generate_display_ticks()]) + 1
y_lines = _draw_y_axis(y_axis, left_pad)
x_lines = _draw_x_axis(x_axis, left_pad)

Expand Down Expand Up @@ -82,7 +82,7 @@ def _draw_canvas(canvas) -> List[str]:
def _draw_y_axis(y_axis, left_pad) -> List[str]:
y_lines = list()

y_ticks = list(y_axis.gen_tick_labels())
y_ticks = list(y_axis.generate_display_ticks())

for i in reversed(range(y_axis.display_max + 1)):
ax_line = ""
Expand All @@ -101,7 +101,7 @@ def _draw_y_axis(y_axis, left_pad) -> List[str]:


def _draw_x_axis(x_axis, left_pad) -> List[str]:
x_ticks = list(x_axis.gen_tick_labels())
x_ticks = list(x_axis.generate_display_ticks())

upper_ax = " " * left_pad + "└"
lower_ax = " " * left_pad + " "
Expand Down
59 changes: 47 additions & 12 deletions tests/test_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def test_axis_datetime_ticks(limits, n_ticks, expected_labels):
axis = Axis(display_length=79)
axis.nticks = n_ticks
axis.fit(np.array(limits))
labels = axis.ticklabels
ticklabels = axis.ticklabels

assert list(labels) == list(expected_labels)
assert list(ticklabels) == list(expected_labels)


@pytest.mark.parametrize(
Expand All @@ -109,12 +109,12 @@ def test_axis_datetime_ticks(limits, n_ticks, expected_labels):
((10, 12), np.array([10, 11, 12]), [(0, 10), (40, 11), (79, 12)]),
],
)
def test_axis_tick_labels(limits, ticks, expected_tick_labels):
def test_axis_display_ticks(limits, ticks, expected_tick_labels):
"""Test axis ticks generation"""
axis = Axis(display_length=80)
axis.limits = limits
axis.ticks = ticks
tick_labels = list(axis.gen_tick_labels())
tick_labels = list(axis.generate_display_ticks())

assert tick_labels == expected_tick_labels

Expand All @@ -136,21 +136,56 @@ def test_axis_ticklabels_len_error(ticks, labels):
axis.ticklabels = labels


def test_axis_reset():
"""Check that updating limits leads to new axis ticks"""
# -----------------------------------------------------------------------------
# Test axis property setting
# -----------------------------------------------------------------------------


@pytest.mark.parametrize(
# fmt: off
"axis_property, value, expected_value",
[
("label", "my fun label", "my fun label"),
("limits", (4, 5), np.array([4, 5])),
("nticks", 5, 5),
("ticks", (0, 1, 2), np.array([0, 1, 2])),
("ticklabels", ("a", "b", "c"), np.array(["a", "b", "c"])),
(
"ticklabels",
(
np.datetime64("2001-01-01"),
np.datetime64("2001-01-03"),
np.datetime64("2001-01-05"),
),
np.array(["2001-01-01", "2001-01-03", "2001-01-05"])
)
],
)
def test_axis_property_can_be_set(axis_property, value, expected_value):
axis = Axis()
axis = axis.fit((0, 1))
axis.nticks = 3

x = np.array([45, 123])
setattr(axis, axis_property, value)
set_value = getattr(axis, axis_property)

if isinstance(expected_value, np.ndarray):
np.testing.assert_array_equal(set_value, expected_value)
else:
assert set_value == expected_value


def test_axis_limit_update_changes_ticks():
"""Check that updating limits leads to new axis ticks"""
axis = Axis(display_length=80)
axis.fit(x)

axis.limits = (0, 300)
ticks = axis.ticks
np.testing.assert_array_equal(ticks, np.array([0, 50, 100, 150, 200, 250, 300]))
np.testing.assert_array_equal(
axis.ticks, np.array([0, 50, 100, 150, 200, 250, 300])
)

axis.limits = (50, 80)
ticks = axis.ticks
np.testing.assert_array_equal(ticks, np.array([50, 55, 60, 65, 70, 75, 80]))
np.testing.assert_array_equal(axis.ticks, np.array([50, 55, 60, 65, 70, 75, 80]))


def test_axis_properties():
Expand Down

0 comments on commit 7fb7d36

Please sign in to comment.