Skip to content

Commit

Permalink
Merge cc036f4 into e7833c2
Browse files Browse the repository at this point in the history
  • Loading branch information
CDonnerer committed Dec 31, 2021
2 parents e7833c2 + cc036f4 commit ce372e9
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 143 deletions.
8 changes: 4 additions & 4 deletions src/shellplot/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
figure state.
"""
from dataclasses import dataclass
from typing import Dict, List
from typing import Callable, Dict, List

import numpy as np

Expand All @@ -16,7 +16,7 @@
class PlotCall:
"""Class for storing a call to a plot functions."""

func: callable
func: Callable
args: List
kwargs: Dict

Expand Down Expand Up @@ -116,7 +116,7 @@ def _hist(fig, x, bins=10, **kwargs):
counts_scaled = fig.y_axis.transform(counts)
bin_width = fig.x_axis.display_max // len(counts) - 1
display_max = (bin_width + 1) * len(counts)
fig.x_axis.scale = display_max / (fig.x_axis.limits[1] - fig.x_axis.limits[0])
fig.x_axis._scale = display_max / (fig.x_axis.limits[1] - fig.x_axis.limits[0])

bin = 0

Expand Down Expand Up @@ -147,7 +147,7 @@ def _barh(fig, x, labels=None, **kwargs):

bin_width = fig.y_axis.display_max // len(x) - 1
display_max = (bin_width + 1) * len(x)
fig.y_axis.scale = display_max / (fig.y_axis.limits[1] - fig.y_axis.limits[0])
fig.y_axis._scale = display_max / (fig.y_axis.limits[1] - fig.y_axis.limits[0])

if labels is not None:
fig.y_axis.ticklabels = labels
Expand Down
231 changes: 133 additions & 98 deletions src/shellplot/axis.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
"""Module that contains Axis class (usable for both x and y axis)
The main function of an axis is to transform from the data coordinates to the
display coordinates, hence we loosely follow an sklearn transformer api.
It can be used like so:
x_axis = Axis(display_length)
x_axis = x_axis.fit(x)
x_display = x_axis.transform(x)
where x_display is the data in display coordinates
"""
from typing import Any, Optional

import numpy as np

from shellplot.utils import (
difference_round,
is_datetime,
numpy_1d,
round_down,
round_up,
Expand All @@ -23,23 +16,67 @@
tolerance_round,
)

array_like = Any


class Axis:
def __init__(self, display_length, **kwargs):
"""Enables mapping from data to display / plot coordinates.
We loosely follow the sklearn transformer api:
>>> axis = Axis()
>>> axis = x_axis.fit(x_data)
>>> x_display = x_axis.transform(x_data)
where `x_data` and `x_display` correspond to data and display coordinates,
respectively.
When calling `.fit`, we will automatically determine 'reasonable' axis
limits and tick labels. Note that these can also be set by the user.
"""

def __init__(
self,
display_length: Optional[int] = 20,
label: Optional[str] = None,
limits: Optional[array_like] = None,
ticklabels: Optional[array_like] = None,
ticks: Optional[array_like] = None,
nticks: Optional[int] = None,
):
"""Instantiate a new Axis.
Parameters
----------
display_length : int, optional
Length of axis, in characters, default 20
label : Optional[str], optional
Axis label, default None
limits : Optional[array_like], optional
Axis limits, default None (auto-generated)
ticklabels : Optional[array_like], optional
Labels for axis ticks, default None (auto-generated, as ticks)
ticks : Optional[array_like], optional
Where the axis ticks should be, default None (auto-generated)
nticks : Optional[int], optional
Number of axis ticks, default None (auto-generated)
"""
self.display_max = display_length - 1
self._is_datetime = False # datetime axis
self._is_datetime = False # whether or not we are a datetime axis
self._scale = None

for key, value in kwargs.items():
setattr(self, key, value)
self.label = label
self.limits = limits
self.nticks = nticks
self.ticks = ticks
self.ticklabels = ticklabels

# -------------------------------------------------------------------------
# Public properties that can be set by the user
# Properties that can be set / modified by the user
# -------------------------------------------------------------------------

@property
def label(self):
if not hasattr(self, "_label"):
self._label = None
return self._label

@label.setter
Expand All @@ -48,84 +85,82 @@ def label(self, label):

@property
def limits(self):
if not hasattr(self, "_limits"):
self.limits = None
return self._limits

@limits.setter
def limits(self, limits):
self._limits = limits
if limits is not None:
self._limits, _ = to_numeric(np.array(limits))
if limits is not None: # new limits need to update scale and ticks
self._limits = to_numeric(limits)
self._set_scale()
self._reset_ticks()

@property
def n_ticks(self):
if not hasattr(self, "_n_ticks"):
self.n_ticks = self._auto_nticks()
return self._n_ticks
def nticks(self):
if self._nticks is None:
self._nticks = self._auto_nticks()
return self._nticks

@n_ticks.setter
def n_ticks(self, n_ticks):
self._n_ticks = n_ticks
@nticks.setter
def nticks(self, nticks):
self._reset_ticks()
self._nticks = nticks

@property
def ticks(self):
if not hasattr(self, "_ticks"):
if self._is_datetime:
self.ticks = self._get_dt_ticks()
else:
self.ticks = self._get_ticks()
if self._ticks is None:
self._ticks = self._auto_ticks()
return self._ticks

@ticks.setter
def ticks(self, ticks):
self._reset_ticks()
self._ticks = numpy_1d(ticks)
self.ticklabels = self.ticks

@property
def ticklabels(self):
if not hasattr(self, "_ticklabels"):
if self._is_datetime:
self.ticklabels = self._datetime_labels(self.ticks)
else:
self.ticklabels = self.ticks
if self._ticklabels is None:
self._ticklabels = self._auto_ticklabels()
return self._ticklabels

@ticklabels.setter
def ticklabels(self, labels):
if len(labels) != len(self.ticks):
raise ValueError("Len of tick labels must equal len of ticks!")
self._ticklabels = numpy_1d(labels)
def ticklabels(self, ticklabels):
if ticklabels is not None:
if len(ticklabels) != len(self.ticks):
raise ValueError("Len of tick labels must equal len of ticks!")
if is_datetime(ticklabels):
ticklabels = np.datetime_as_string(ticklabels)
self._ticklabels = numpy_1d(ticklabels)

# -------------------------------------------------------------------------
# Methods
# Public methods: fit, transform and generate ticks
# -------------------------------------------------------------------------

def fit(self, x):
"""Fit axis to get conversion from data to plot scale"""
x, self._is_datetime = to_numeric(x)
self._is_datetime = is_datetime(x)
x = to_numeric(x)

if self.limits is None:
self.limits = self._auto_limits(x)
self._limits = self._auto_limits(x)

self._set_scale()

return self

def transform(self, x):
x, _ = to_numeric(x)
x_scaled = self.scale * (x - self.limits[0]).astype(float)
"""Transform data to the plot coordinates"""
x = to_numeric(x)
x_scaled = self._scale * (x - self.limits[0]).astype(float)
x_display = np.around(x_scaled).astype(int)
return np.ma.masked_outside(x_display, 0, self.display_max)

def fit_transform(self, x):
"""Fit axis and transform data to the plot coordinates"""
self = self.fit(x)
return self.transform(x)

def gen_tick_labels(self):
"""Generate display tick location and labels"""
def generate_display_ticks(self):
"""Generate display tick locations and labels"""
display_ticks = self.transform(self.ticks)
within_display = np.logical_and(
display_ticks >= 0, display_ticks <= self.display_max
Expand All @@ -135,67 +170,67 @@ def gen_tick_labels(self):

return zip(display_ticks, display_labels)

# -------------------------------------------------------------------------
# Private methods: Auto scaling & ticks
# -------------------------------------------------------------------------

def _set_scale(self):
self.scale = self.display_max / float(self.limits[1] - self.limits[0])
self._scale = self.display_max / float(self.limits[1] - self.limits[0])

def _auto_limits(self, x, margin=0.25):
"""Automatically find good axis limits"""
x_max, x_min = x.max(), x.min()

max_difference = margin * (x_max - x_min)
ax_min = difference_round(x_min, round_down, max_difference)
ax_max = difference_round(x_max, round_up, max_difference)

return ax_min, ax_max

def _auto_nticks(self):
"""Automatically find number of ticks that fit display"""
max_ticks = int(1.5 * self.display_max ** 0.3) + 1
ticks = np.arange(max_ticks, max_ticks - 2, -1)
remainders = np.remainder(self.display_max, ticks)
return ticks[np.argmin(remainders)] + 1

def _get_ticks(self):
"""Generate sensible axis ticks"""
def _auto_ticks(self):
"""Automatically find good axis ticks"""
if self.limits is None:
raise ValueError("Please fit axis or set limits first!")
elif not self._is_datetime:
return self._auto_numeric_ticks()
else:
return self._auto_datetime_ticks()

def _auto_numeric_ticks(self, tol=0.05):
step, precision = tolerance_round(
(self.limits[1] - self.limits[0]) / (self.n_ticks - 1),
tol=0.05,
(self.limits[1] - self.limits[0]) / (self.nticks - 1),
tol=tol,
)
return np.around(
np.arange(self.limits[0], self.limits[1] + step, step), precision
)
)[: self.nticks]

def _get_dt_ticks(self):
"""Generate sensible axis ticks for datetime"""
def _auto_datetime_ticks(self):
axis_td = to_datetime(np.array(self.limits, dtype="timedelta64[ns]"))
limits_delta = axis_td[1] - axis_td[0]
unit = timedelta_round(limits_delta)
n_units = limits_delta / np.timedelta64(1, unit)
td_step = np.timedelta64(int(n_units / (self.n_ticks - 1)), unit)
td_step = np.timedelta64(int(n_units / (self.nticks - 1)), unit)

return np.arange(
np.datetime64(axis_td[0], unit),
np.datetime64(axis_td[1], unit) + td_step,
td_step,
)

def _datetime_labels(self, ticks):
# TODO: I don't know why the uncommented code existed
# [ns] should not be hardcoded
# dt_ticks = to_datetime(ticks.astype("timedelta64[ns]"))
# delta_ticks = dt_ticks[1] - dt_ticks[0] # TODO: this could fail
# unit = timedelta_round(delta_ticks)
return np.datetime_as_string(ticks) # , unit=unit)

def _auto_limits(self, x, frac=0.25):
"""Automatically find `good` axis limits"""
x_max = x.max()
x_min = x.min()
)[: self.nticks]

max_difference = frac * (x_max - x_min)
ax_min = self._difference_round(x_min, round_down, max_difference)
ax_max = self._difference_round(x_max, round_up, max_difference)

return ax_min, ax_max

def _difference_round(self, val, round_func, max_difference):
for dec in range(10):
rounded = round_func(val, dec)
if abs(rounded - val) <= max_difference:
return rounded

def _auto_nticks(self):
"""Automatically find a `good` number of axis ticks that fits display"""
max_ticks = int(1.5 * self.display_max ** 0.3) + 1
ticks = np.arange(max_ticks, max_ticks - 2, -1)
remainders = np.remainder(self.display_max, ticks)
return ticks[np.argmin(remainders)] + 1
def _auto_ticklabels(self):
if self._is_datetime:
return np.datetime_as_string(self.ticks)
else:
return self.ticks

def _reset_ticks(self):
"""Reset axis ticks and ticklabels"""
attrs = ["_ticks", "_ticklabels"]
for attr in attrs:
if hasattr(self, attr):
delattr(self, attr)
self._ticks = None
self._ticklabels = None
Loading

0 comments on commit ce372e9

Please sign in to comment.