From 0c0eaed2c669123193fec845a750f9598d9f944c Mon Sep 17 00:00:00 2001 From: Christian Donnerer Date: Fri, 25 Dec 2020 16:54:14 +0000 Subject: [PATCH] Plot options (#18) * initial stab at global figure options * added figure options * pandas api needs overhaul * added some docstrings * fixed bug in drawing, needs tests * added tests for axis drawing * fixed extra args passed in pandas api * fixed extra args passed in pandas api * added test for drawing canvas * added some docs in drawing module * updated docs for API ref * added custom api docs * revert makefile change * updated docstrings * updated docs with new options for plotting * fixed the drawing tests * updated examples and changelog --- CHANGELOG.rst | 7 ++ docs/api.rst | 21 ++++ docs/examples/basic_usage.rst | 49 ++++++++ docs/index.rst | 3 +- docs/installation.rst | 14 +++ examples/basic_usage.py | 28 ++++- examples/penguins_eda.py | 14 ++- src/shellplot/axis.py | 27 ++-- src/shellplot/drawing.py | 150 +++++++++++++--------- src/shellplot/pandas_api.py | 54 ++++---- src/shellplot/plots.py | 230 +++++++++++++++++++++++++++------- src/shellplot/utils.py | 39 ++++-- tests/test_drawing.py | 90 ++++++++++++- tests/test_plots.py | 12 ++ 14 files changed, 575 insertions(+), 163 deletions(-) create mode 100644 docs/api.rst diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2220d2b..c7d72a3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,13 @@ Changelog ========= +Version 0.1.4 +------------- +- Fixed bug in x-axis drawing, spurious rounding +- Added options to modify global figure properties, e.g. xlim, figsize, etc +- Updated docs for manually curated API reference as opposed to sphinx api-doc + + Version 0.1.3 ------------- - Fixed bug in x-axis drawing, now tick marks are aligned with plot diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..058d985 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,21 @@ +.. 
_api_reference: + +API Reference +=================== + +Plotting functions +------------------- + +.. autofunction:: shellplot.plot + +.. autofunction:: shellplot.hist + +.. autofunction:: shellplot.barh + +.. autofunction:: shellplot.boxplot + + +Data loading +------------------- + +.. autofunction:: shellplot.load_dataset diff --git a/docs/examples/basic_usage.rst b/docs/examples/basic_usage.rst index 2e7093b..a6f3b7c 100644 --- a/docs/examples/basic_usage.rst +++ b/docs/examples/basic_usage.rst @@ -45,6 +45,54 @@ Scatter plots can be created via the ``plot`` function:: -4 -2 0 2 4 +It is possible to modify the appearance of the plot by passing keyword args, +using a similar syntax to `matplotlib`_. E.g. we could modify the above call to +``plot`` like so:: + + + >>> import numpy as np + >>> import shellplot as plt + >>> x = np.arange(-4, 4, 0.01) + >>> y = np.cos(x) + >>> plt_str = plt.plot( + x, y, + figsize=(40, 21), + xlim=(0, 3), + ylim=(-1, 1), + xlabel="x", + ylabel="cos(x)", + return_type="str", + ) + >>> print(plt_str) + + cos(x) + 1.0┤+++++ + | ++++ + | +++ + | +++ + | +++ + 0.5┤ ++ + | +++ + | ++ + | ++ + | ++ + 0.0┤ ++ + | ++ + | +++ + | ++ + | ++ + -0.5┤ ++ + | +++ + | ++ + | ++++ + | ++++ + -1.0┤ +++ + └┬------------┬------------┬------------┬ + 0 1 2 3 + x + +Please refer to :ref:`api_reference` for the full list of possible options. + Histogram ------------------- @@ -212,3 +260,4 @@ parameter:: .. _pandas: https://pandas.pydata.org/ +.. 
_matplotlib: https://matplotlib.org/contents.html# diff --git a/docs/index.rst b/docs/index.rst index b8ec96c..3b79599 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -53,10 +53,11 @@ Contents Installation Examples + API Reference License Authors Changelog - Module Reference + Indices and tables diff --git a/docs/installation.rst b/docs/installation.rst index c842c07..3e11eee 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -11,3 +11,17 @@ Install shellplot via ``pip``:: This will install the latest stable version, as well as the required dependencies. + + +Dependencies +------------- + +Shellplot has the following dependencies, all of which are installed automatically +with the above installation command: + +- python 3.6 or newer +- `Numpy`_ +- `Pandas`_ + +.. _NumPy: http://www.numpy.org/ +.. _Pandas: http://pandas.pydata.org diff --git a/examples/basic_usage.py b/examples/basic_usage.py index d4c164f..1be6940 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -4,13 +4,31 @@ x = np.arange(-4, 4, 0.01) y = np.cos(x) -plt.plot(x, y) +plt.plot(x, y, xlabel="x", ylabel="f(x)", figsize=(60, 25)) + +plt_str = plt.plot( + x, + y, + figsize=(40, 21), + xlim=(0, 3), + ylim=(-1, 1), + xlabel="x", + ylabel="cos(x)", + return_type="str", +) +print(plt_str) + x = [np.random.randn(100) for i in range(3)] -plt.boxplot(x, labels=np.array(["dist_1", "dist_2", "dist_3"])) +plt.boxplot(x, labels=np.array(["dist_1", "dist_2", "dist_3"]), figsize=(40, 25)) -x = np.random.randn(100) -plt.hist(x) +x = np.random.randn(1000) +plt.hist(x, bins=12, figsize=(40, 20), xlabel="normal distribution") x = np.logspace(0, 1, 3) -plt.barh(x, labels=np.array(["bar_1", "bar_b", "bar_3"])) +plt.barh( + x, + labels=np.array(["bar_1", "bar_b", "bar_3"]), + xlabel="my_fun_bars", + figsize=(40, 20), +) diff --git a/examples/penguins_eda.py b/examples/penguins_eda.py index 2cb5177..b8a7a98 100644 --- a/examples/penguins_eda.py +++ b/examples/penguins_eda.py @@ -8,16 
+8,20 @@ df = plt.load_dataset("penguins") -df["body_mass_g"].hist() +df["body_mass_g"].hist(figsize=(60, 20)) -df[["species", "island"]].value_counts().plot.barh() +df["species"].value_counts().plot.barh(figsize=(30, 13)) -df.boxplot(column=["bill_length_mm", "bill_depth_mm"]) -df.boxplot(column=["bill_length_mm"], by="species") +# df[["island", "species"]].value_counts().plot.barh(figsize=(30, 30)) -df.dropna().plot("bill_length_mm", "flipper_length_mm", color="species") +df.boxplot(column=["bill_length_mm", "bill_depth_mm"], figsize=(80, 13)) +df.boxplot(column=["bill_length_mm"], by="species", figsize=(30, 13)) + +df.dropna().plot( + "bill_length_mm", "flipper_length_mm", color="species", figsize=(40, 23) +) # df.loc[df["species"] == "Adelie"].dropna().plot( # "bill_depth_mm", "body_mass_g", color="island" diff --git a/src/shellplot/axis.py b/src/shellplot/axis.py index 822a229..bea9297 100644 --- a/src/shellplot/axis.py +++ b/src/shellplot/axis.py @@ -18,22 +18,26 @@ class Axis: - def __init__(self, display_length, title=None, limits=None): + def __init__(self, display_length, label=None, limits=None): self.display_max = display_length - 1 - self._title = title - self._limits = limits + self.label = label + self.limits = limits + + # reverted setting ticks and labels - need to think about the logic here + # self.ticks = ticks + # self.labels = labels # ------------------------------------------------------------------------- # Public properties that can be set by the user # ------------------------------------------------------------------------- @property - def title(self): - return self._title + def label(self): + return self._label - @title.setter - def title(self, title): - self._title = title + @label.setter + def label(self, label): + self._label = label @property def limits(self): @@ -42,7 +46,8 @@ def limits(self): @limits.setter def limits(self, limits): self._limits = limits - self.fit() # setting axis limits automatically fits the axis + if limits 
is not None: + self.fit() # setting axis limits automatically fits the axis @property def n_ticks(self): @@ -89,7 +94,9 @@ def fit(self, x=None): return self def transform(self, x): - return np.around(self.scale * (x - self.limits[0])).astype(int) + x_scaled = np.around(self.scale * (x - self.limits[0])).astype(int) + within_display = np.logical_and(x_scaled >= 0, x_scaled <= self.display_max) + return np.ma.masked_where(~within_display, x_scaled) def fit_transform(self, x): self = self.fit(x) diff --git a/src/shellplot/drawing.py b/src/shellplot/drawing.py index 4a9df41..f02eaab 100644 --- a/src/shellplot/drawing.py +++ b/src/shellplot/drawing.py @@ -5,6 +5,7 @@ type of plot. """ +from typing import List PALETTE = { # empty space @@ -14,8 +15,8 @@ 2: "*", 3: "o", 4: "x", - 5: "_", - 6: "|", + 5: "@", + 6: ".", # bar drawing 20: "|", 21: "_", @@ -24,57 +25,49 @@ } -def draw(canvas, x_axis, y_axis, legend=None): - plt_lines = _draw_canvas(canvas) +def draw(canvas, x_axis, y_axis, legend=None) -> str: + """Draw figure from plot elements (i.e. canvas, x-axis, y-axis, legend) - label_len = max([len(str(val)) for (t, val) in y_axis.tick_labels()]) - l_pad = label_len + 1 + Internally, this function draws all elements as lists of strings, and then + joins them into a single string. 
- y_lines = _draw_y_axis(canvas, y_axis, l_pad) - x_lines = _draw_x_axis(canvas, x_axis, l_pad) + Parameters + ---------- + canvas : np.ndarray + The data to be drawn + x_axis : shellplot.axis.Axis + Fitted x-axis + y_axis : shellplot.axis.Axis + Fitted y-axis + legend : dict[str, str], optional + Legend of the plot + + Returns + ------- + str + The drawn figure + + """ + canvas_lines = _draw_canvas(canvas) + + left_pad = max([len(str(val)) for (t, val) in y_axis.tick_labels()]) + 1 + y_lines = _draw_y_axis(y_axis, left_pad) + x_lines = _draw_x_axis(x_axis, left_pad) if legend is not None: legend_lines = _draw_legend(legend) else: legend_lines = None - return _join_plot_lines(plt_lines, y_lines, x_lines, legend_lines) - - -def _draw_legend(legend): - legend_lines = list() - - for marker, name in legend.items(): - legend_str = f"{PALETTE[marker]} {name}" - legend_lines.append(legend_str) - return legend_lines - - -def _pad_lines(lines, ref_lines): - if lines is None: - lines = list() - - empty_pad = len(ref_lines) - len(lines) - return [""] * empty_pad + lines + return _join_plot_lines(canvas_lines, y_lines, x_lines, legend_lines) -def _join_plot_lines(plt_lines, y_lines, x_lines, legend_lines): - plt_str = "\n" +# ------------------------------------------------------------------------------ +# Drawing functions for individual plot elements (canvas, x-axis, y-axis, legend) +# ------------------------------------------------------------------------------ - plt_lines = _pad_lines(plt_lines, y_lines) - legend_lines = _pad_lines(legend_lines, y_lines) - - for ax, plt, leg in zip(y_lines, plt_lines, legend_lines): - plt_str += ax + plt + leg + "\n" - - for ax in x_lines: - plt_str += ax - - return plt_str - - -def _draw_canvas(canvas): +def _draw_canvas(canvas) -> List[str]: plt_lines = list() for i in reversed(range(canvas.shape[1])): @@ -86,50 +79,85 @@ def _draw_canvas(canvas): return plt_lines -def _draw_y_axis(canvas, y_axis, l_pad): +def _draw_y_axis(y_axis, 
left_pad) -> List[str]: y_lines = list() y_ticks = y_axis.tick_labels() - for i in reversed(range(canvas.shape[1])): + for i in reversed(range(y_axis.display_max + 1)): ax_line = "" if len(y_ticks) > 0 and i == y_ticks[-1][0]: - ax_line += f"{str(y_ticks[-1][1]).rjust(l_pad)}┤" + ax_line += f"{str(y_ticks[-1][1]).rjust(left_pad)}┤" y_ticks.pop(-1) else: - ax_line += " " * l_pad + "|" + ax_line += " " * left_pad + "|" y_lines.append(ax_line) - if y_axis.title is not None: - title_pad = l_pad - len(y_axis.title) // 2 - title_str = " " * title_pad + y_axis.title - y_lines.insert(0, title_str) + if y_axis.label is not None: + label_pad = left_pad - len(y_axis.label) // 2 + label_str = " " * label_pad + y_axis.label + y_lines.insert(0, label_str) return y_lines -def _draw_x_axis(canvas, x_axis, l_pad): +def _draw_x_axis(x_axis, left_pad) -> List[str]: x_ticks = x_axis.tick_labels() - upper_ax = " " * l_pad + "└" - lower_ax = " " * l_pad + " " + upper_ax = " " * left_pad + "└" + lower_ax = " " * left_pad + " " marker = "┬" + overpad = 50 - for j in range(canvas.shape[0]): + for j in range(x_axis.display_max + 1): if len(x_ticks) > 0 and j == x_ticks[0][0]: lower_ax = lower_ax[: len(upper_ax)] - label = str(round(x_ticks[0][1], 2)) - lower_ax += label + " " * 20 - + lower_ax += str(x_ticks[0][1]) + " " * overpad upper_ax += marker x_ticks.pop(0) else: upper_ax += "-" - ax_lines = [upper_ax + "\n", lower_ax + "\n"] + ax_lines = [upper_ax + "\n", lower_ax[: len(lower_ax) - overpad] + "\n"] - if x_axis.title is not None: - title_pad = int(canvas.shape[0] / 2 - len(x_axis.title) / 2) - title_str = " " * (l_pad + title_pad) + x_axis.title - ax_lines.append(title_str) + if x_axis.label is not None: + label_pad = (x_axis.display_max + 1) // 2 - len(x_axis.label) // 2 + label_str = " " * (left_pad + 1 + label_pad) + x_axis.label + ax_lines.append(label_str) return ax_lines + + +def _draw_legend(legend) -> List[str]: + legend_lines = list() + + for marker, name in legend.items(): 
+ legend_str = f"{PALETTE[marker]} {name}" + legend_lines.append(legend_str) + return legend_lines + + +# ------------------------------------------------------------------------------ +# Helper functions +# ------------------------------------------------------------------------------ + + +def _join_plot_lines(plt_lines, y_lines, x_lines, legend_lines): + plt_str = "\n" + plt_lines = _pad_lines(plt_lines, y_lines) + legend_lines = _pad_lines(legend_lines, y_lines) + + for ax, plt, leg in zip(y_lines, plt_lines, legend_lines): + plt_str += ax + plt + leg + "\n" + + for ax in x_lines: + plt_str += ax + + return plt_str + + +def _pad_lines(lines, ref_lines): + if lines is None: + lines = list() + + empty_pad = len(ref_lines) - len(lines) + return [""] * empty_pad + lines diff --git a/src/shellplot/pandas_api.py b/src/shellplot/pandas_api.py index 24af654..3082095 100644 --- a/src/shellplot/pandas_api.py +++ b/src/shellplot/pandas_api.py @@ -22,26 +22,25 @@ def plot(data, kind, **kwargs): # TODO: check kind if isinstance(data, pd.Series): - return _plot_series(data, kind) + return _plot_series(data, kind, **kwargs) else: return _plot_frame(data, **kwargs) def hist_series(data, **kwargs): - return plt.hist(x=data.values, x_title=data.name, **kwargs) + return plt.hist(x=data, **kwargs) def boxplot_frame(data, *args, **kwargs): - column = kwargs.get("column", data.columns) - by = kwargs.get("by") + column = kwargs.pop("column", data.columns) + by = kwargs.pop("by") if by is not None: df = data.pivot(columns=by, values=column) - x_title = df.columns.get_level_values(0)[0] - y_title = by + xlabel = df.columns.get_level_values(0)[0] labels = df.columns.get_level_values(1) - kwargs.update({"x_title": x_title, "y_title": y_title, "labels": labels}) + kwargs.update({"xlabel": xlabel, "ylabel": by, "labels": labels}) else: df = data[column] kwargs.update({"labels": df.columns}) @@ -79,28 +78,35 @@ def _plot_series(data, kind, *args, **kwargs): def _series_barh(data, **kwargs): 
- return plt.barh( - x=data.values, labels=data.index, x_title=data.name, y_title=data.index.name - ) + x_col = kwargs.pop("x") + + if x_col is not None: + data = data[x_col] + + return plt.barh(x=data, labels=data.index, **kwargs) def _series_line(data, **kwargs): - return plt.plot( - x=data.index.values, - y=data.values, - x_title=data.index.name, - y_title=data.name, - ) + x_col = kwargs.pop("x") + y_col = kwargs.pop("y") + + # why do we get both x and y here? + if x_col is not None: + data = data[x_col] + if y_col is not None: + data = data[y_col] + + return plt.plot(x=data.index, y=data, **kwargs) def _series_boxplot(data, *args, **kwargs): - return plt.boxplot(data, labels=np.array([data.name])) + return plt.boxplot(data, labels=np.array([data.name]), **kwargs) def _plot_frame(data, **kwargs): - x_col = kwargs.get("x") - y_col = kwargs.get("y") - color = kwargs.get("color", None) + x_col = kwargs.pop("x") + y_col = kwargs.pop("y") + color = kwargs.pop("color", None) if x_col is None or y_col is None: raise ValueError("Please provide both x, y column names") @@ -111,10 +117,4 @@ def _plot_frame(data, **kwargs): s_x = data[x_col] s_y = data[y_col] - return plt.plot( - x=s_x.values, - y=s_y.values, - x_title=s_x.name, - y_title=s_y.name, - color=color, - ) + return plt.plot(x=s_x, y=s_y, color=color, **kwargs) diff --git a/src/shellplot/plots.py b/src/shellplot/plots.py index c9d69bd..6da7209 100644 --- a/src/shellplot/plots.py +++ b/src/shellplot/plots.py @@ -1,17 +1,13 @@ """Shellplot plots """ import numpy as np -import pandas as pd from shellplot.axis import Axis from shellplot.drawing import draw -from shellplot.utils import numpy_2d, remove_any_nan +from shellplot.utils import get_label, numpy_2d, remove_any_nan __all__ = ["plot", "hist", "barh", "boxplot"] -DISPLAY_X = 70 -DISPLAY_Y = 25 - # ----------------------------------------------------------------------------- # Exposed functions that directly print the plot @@ -19,23 +15,152 @@ def plot(*args, 
**kwargs): + """Plot x versus y as scatter. + + Parameters + ---------- + x : array-like + The horizontal coordinates of the data points. + Should be 1d np.ndarray or pandas series + y : array-like + The vertical coordinates of the data points. + Should be 1d np.ndarray or pandas series + color : array, optional + Color of scatter. Needs to be of same dimension as x, y + Should be 1-d np.ndarray or pandas series + figsize : a tuple (width, height) in ascii characters, optional + Size of the figure. + xlim : 2-tuple/list, optional + Set the x limits. + ylim : 2-tuple/list, optional + Set the y limits. + xlabel : str, optional + Name to use for the xlabel on x-axis. + ylabel : str, optional + Name to use for the ylabel on y-axis. + return_type : str, optional + If `'str'`, returns the plot as a string. Otherwise, the plot will be + directly printed to stdout. + + Returns + ------- + result + The plot as a string if `return_type` is `'str'`; otherwise `None`. + + """ plt_str = _plot(*args, **kwargs) - print(plt_str) + return return_plt(plt_str, **kwargs) def hist(*args, **kwargs): + """Plot a histogram of x + + Parameters + ---------- + x : array-like + The array of points to plot a histogram of. Should be 1d np.ndarray or + pandas series. + bins : int, optional + Number of bins in histogram. Default is 10 bins. + figsize : a tuple (width, height) in ascii characters, optional + Size of the figure. + xlim : 2-tuple/list, optional + Set the x limits. + ylim : 2-tuple/list, optional + Set the y limits. + xlabel : str, optional + Name to use for the xlabel on x-axis. + ylabel : str, optional + Name to use for the ylabel on y-axis. + return_type : str, optional + If `'str'`, returns the plot as a string. Otherwise, the plot will be + directly printed to stdout. + + Returns + ------- + result + The plot as a string if `return_type` is `'str'`; otherwise `None`. 
+ + """ plt_str = _hist(*args, **kwargs) - print(plt_str) + return return_plt(plt_str, **kwargs) def barh(*args, **kwargs): + """Plot horizontal bars + + Parameters + ---------- + x : array-like + The width of the horizontal bars. Should be 1d np.ndarray or pandas + series. + labels : array-like + Array that is used to label the bars. Needs to have the same dim as x. + figsize : a tuple (width, height) in ascii characters, optional + Size of the figure. + xlim : 2-tuple/list, optional + Set the x limits. + ylim : 2-tuple/list, optional + Set the y limits. + xlabel : str, optional + Name to use for the xlabel on x-axis. + ylabel : str, optional + Name to use for the ylabel on y-axis. + return_type : str, optional + If `'str'`, returns the plot as a string. Otherwise, the plot will be + directly printed to stdout. + + Returns + ------- + result + The plot as a string if `return_type` is `'str'`; otherwise `None`. + + """ plt_str = _barh(*args, **kwargs) - print(plt_str) + return return_plt(plt_str, **kwargs) def boxplot(*args, **kwargs): + """Plot a boxplot of x + + Note that currently this makes a boxplot using the quantiles: + [0, 0.25, 0.5, 0.75, 1.0] - i.e. the whiskers will not exclude outliers + + Parameters + ---------- + x : array-like + The horizontal coordinates of the data points. + Can be 1d or 2d np.ndarray/ pandas series/ dataframe. If 2d, each 1d + slice will be plotted as a separate boxplot. + figsize : a tuple (width, height) in ascii characters, optional + Size of the figure. + xlim : 2-tuple/list, optional + Set the x limits. + ylim : 2-tuple/list, optional + Set the y limits. + xlabel : str, optional + Name to use for the xlabel on x-axis. + ylabel : str, optional + Name to use for the ylabel on y-axis. + return_type : str, optional + If `'str'`, returns the plot as a string. Otherwise, the plot will be + directly printed to stdout. + + Returns + ------- + result + The plot as a string if `return_type` is `'str'`; otherwise `None`. 
+ + """ plt_str = _boxplot(*args, **kwargs) - print(plt_str) + return return_plt(plt_str, **kwargs) + + +def return_plt(plt_str, **kwargs): + if kwargs.get("return_type") == "str": + return plt_str + else: + print(plt_str) # ----------------------------------------------------------------------------- @@ -43,34 +168,49 @@ def boxplot(*args, **kwargs): # ----------------------------------------------------------------------------- -def _plot(x, y, color=None, x_title=None, y_title=None): - """Scatterplot""" - x, y = remove_any_nan(x, y) +def _init_figure( + figsize=None, xlim=None, ylim=None, xlabel=None, ylabel=None, **kwargs +): + """Initialise a new figure. - def get_name(x): - if isinstance(x, pd.Series): - return x.name - else: - return None + TODO: + - This could be a class to hold a figure state? + - add ticks + - add tick labels + """ + if figsize is None: + figsize = (70, 25) # this should go somewhere else + + x_axis = Axis(figsize[0], label=xlabel, limits=xlim) + y_axis = Axis(figsize[1], label=ylabel, limits=ylim) + canvas = np.zeros(shape=(figsize[0], figsize[1]), dtype=int) + + return x_axis, y_axis, canvas + + +def _plot(x, y, color=None, **kwargs): + """Scatterplot""" - if x_title is None: - x_title = get_name(x) - if y_title is None: - y_title = get_name(y) + if kwargs.get("xlabel") is None: + kwargs.update({"xlabel": get_label(x)}) + if kwargs.get("ylabel") is None: + kwargs.update({"ylabel": get_label(y)}) - x_axis = Axis(DISPLAY_X, title=x_title) - y_axis = Axis(DISPLAY_Y, title=y_title) + x_axis, y_axis, canvas = _init_figure(**kwargs) + x, y = remove_any_nan(x, y) x_scaled = x_axis.fit_transform(x) y_scaled = y_axis.fit_transform(y) - canvas = np.zeros(shape=(DISPLAY_X, DISPLAY_Y), dtype=int) + within_display = np.logical_and(x_scaled.mask == 0, y_scaled.mask == 0) + x_scaled, y_scaled = x_scaled[within_display], y_scaled[within_display] if color is not None: - values = np.unique(color) + color_scaled = color.to_numpy()[within_display] + values 
= np.unique(color_scaled) for ii, val in enumerate(values): - mask = val == color + mask = val == color_scaled canvas[x_scaled[mask], y_scaled[mask]] = ii + 1 legend = {ii + 1: val for ii, val in enumerate(values)} @@ -81,23 +221,25 @@ def get_name(x): return draw(canvas=canvas, y_axis=y_axis, x_axis=x_axis, legend=legend) -def _hist(x, bins=10, x_title=None, **kwargs): +def _hist(x, bins=10, **kwargs): """Histogram""" x = x[~np.isnan(x)] counts, bin_edges = np.histogram(x, bins) - y_axis = Axis(DISPLAY_Y, title="counts") - x_axis = Axis(DISPLAY_X, title=x_title) + if kwargs.get("xlabel") is None: + kwargs.update({"xlabel": get_label(x)}) + if kwargs.get("ylabel") is None: + kwargs.update({"ylabel": "counts"}) + + x_axis, y_axis, canvas = _init_figure(**kwargs) y_axis.limits = (0, max(counts)) counts_scaled = y_axis.transform(counts) x_axis = x_axis.fit(bin_edges) - canvas = np.zeros(shape=(DISPLAY_X, DISPLAY_Y), dtype=int) - bin = 0 - bin_width = int((DISPLAY_X - 1) / len(counts)) - 1 + bin_width = x_axis.display_max // len(counts) - 1 for count in counts_scaled: canvas = _add_vbar(canvas, bin, bin_width, count) @@ -109,10 +251,11 @@ def _hist(x, bins=10, x_title=None, **kwargs): return draw(canvas=canvas, y_axis=y_axis, x_axis=x_axis) -def _barh(x, labels=None, x_title=None, y_title=None): +def _barh(x, labels=None, **kwargs): """Horizontal bar plot""" - y_axis = Axis(DISPLAY_Y, title=y_title) - x_axis = Axis(DISPLAY_X, title=x_title) + + kwargs.update({"xlabel": get_label(x)}) + x_axis, y_axis, canvas = _init_figure(**kwargs) x_axis.limits = (0, max(x)) x_scaled = x_axis.fit_transform(x) @@ -123,10 +266,8 @@ def _barh(x, labels=None, x_title=None, y_title=None): if labels is not None: y_axis.labels = labels - canvas = np.zeros(shape=(DISPLAY_X, DISPLAY_Y), dtype=int) - bin = 0 - bin_width = int((DISPLAY_Y - 1) / len(x)) - 1 + bin_width = y_axis.display_max // len(x) - 1 for val in x_scaled: canvas = _add_hbar(canvas, bin, bin_width, val) @@ -138,18 +279,17 @@ 
def _barh(x, labels=None, x_title=None, y_title=None): return draw(canvas=canvas, y_axis=y_axis, x_axis=x_axis) -def _boxplot(x, labels=None, x_title=None, y_title=None, **kwargs): +def _boxplot(x, labels=None, **kwargs): """Box plot""" + + x_axis, y_axis, canvas = _init_figure(**kwargs) + x = numpy_2d(x) x = np.ma.masked_where(np.isnan(x), x) quantiles = np.array( [np.quantile(dist[dist.mask == 0], q=[0, 0.25, 0.5, 0.75, 1.0]) for dist in x] ) - - x_axis = Axis(DISPLAY_X, x_title) - y_axis = Axis(DISPLAY_Y, y_title) - quantiles_scaled = x_axis.fit_transform(quantiles) y_axis = y_axis.fit(np.array([0, len(x)])) @@ -160,8 +300,6 @@ def _boxplot(x, labels=None, x_title=None, y_title=None, **kwargs): if labels is not None: y_axis.labels = labels - canvas = np.zeros(shape=(DISPLAY_X, DISPLAY_Y), dtype=int) - for ii in range(len(x)): quants = quantiles_scaled[ii, :] lims = y_lims[ii, :] @@ -193,7 +331,7 @@ def _add_hbar(canvas, start, width, height): def _add_box_and_whiskers(canvas, quantiles, limits): """Add a box and whiskers to the canvas""" - for jj in [0, 1, 2, 3, 4]: + for jj in range(5): canvas[quantiles[jj], limits[0] + 1 : limits[2]] = 20 canvas[quantiles[0] + 1 : quantiles[1], limits[1]] = 22 diff --git a/src/shellplot/utils.py b/src/shellplot/utils.py index d2613c4..7dc60ec 100644 --- a/src/shellplot/utils.py +++ b/src/shellplot/utils.py @@ -6,6 +6,28 @@ import numpy as np import pandas as pd +__all__ = ["load_dataset"] + + +def load_dataset(name: str) -> pd.DataFrame: + """Load standard dataset from shellplot library + + Parameters + ---------- + name : str + Name of the dataset. 
Available options are `penguins` + + Returns + ------- + pd.DataFrame + Pandas dataframe of dataset + + """ + module_path = os.path.dirname(__file__) + dataset_path = os.path.join(module_path, "datasets", f"{name}.csv") + + return pd.read_csv(dataset_path) + def tolerance_round(x, tol=1e-3): error = 1.0 @@ -45,13 +67,6 @@ def remove_any_nan(x, y): return x[~is_any_nan], y[~is_any_nan] -def load_dataset(name): - module_path = os.path.dirname(__file__) - dataset_path = os.path.join(module_path, "datasets", f"{name}.csv") - - return pd.read_csv(dataset_path) - - def numpy_2d(x): """Reshape and transform various array-like inputs to 2d np arrays""" if isinstance(x, np.ndarray): @@ -65,3 +80,13 @@ def numpy_2d(x): return x else: return None + + +def get_label(x): + """Try to get names out of array-like inputs""" + if isinstance(x, pd.Series): + return x.name + elif isinstance(x, pd.DataFrame): + return x.columns + else: + return None diff --git a/tests/test_drawing.py b/tests/test_drawing.py index 12dc1a3..9d0db2a 100644 --- a/tests/test_drawing.py +++ b/tests/test_drawing.py @@ -2,7 +2,16 @@ """ import pytest -from shellplot.drawing import _draw_legend, _pad_lines +import numpy as np + +from shellplot.axis import Axis +from shellplot.drawing import ( + _draw_canvas, + _draw_legend, + _draw_x_axis, + _draw_y_axis, + _pad_lines, +) def test_draw_legend(): @@ -21,3 +30,82 @@ def test_draw_legend(): def test_pad_lines(lines, ref_lines, expecte_padded_lines): padded_lines = _pad_lines(lines, ref_lines) assert padded_lines == expecte_padded_lines + + +@pytest.mark.parametrize( + "axis,expected_axis_lines", + [ + ( + Axis(display_length=50, label="my_fun_label", limits=(0, 1)), + [ + "└┬---------┬---------┬--------┬---------┬---------┬\n", + " 0.0 0.2 0.4 0.6 0.8 1.0\n", + " my_fun_label", + ], + ), + ( + Axis(display_length=50, label="my_fun_label", limits=(0, 0.01)), + [ + "└┬---------┬---------┬--------┬---------┬---------┬\n", + " 0.0 0.002 0.004 0.006 0.008 0.01\n", + " 
my_fun_label", + ], + ), + ], +) +def test_draw_x_axis(axis, expected_axis_lines): + x_lines = _draw_x_axis(x_axis=axis, left_pad=0) + assert x_lines == expected_axis_lines + + +@pytest.mark.parametrize( + "axis,expected_axis_lines", + [ + ( + Axis(display_length=15, label="my_fun_label", limits=(0, 1)), + [ + " my_fun_label", + " 1.0┤", + " |", + " |", + " |", + " 0.75┤", + " |", + " |", + " 0.5┤", + " |", + " |", + " 0.25┤", + " |", + " |", + " |", + " 0.0┤", + ], + ), + ], +) +def test_draw_y_axis(axis, expected_axis_lines): + y_lines = _draw_y_axis(y_axis=axis, left_pad=10) + assert y_lines == expected_axis_lines + + +@pytest.mark.parametrize( + "canvas,expected_canvas_lines", + [ + ( + np.array( + [ + [0, 0, 0, 0, 5], + [0, 0, 0, 4, 0], + [0, 0, 3, 0, 0], + [0, 2, 0, 0, 0], + [1, 0, 0, 0, 0], + ] + ), + ["@ ", " x ", " o ", " * ", " +"], + ), + ], +) +def test_draw_canvas(canvas, expected_canvas_lines): + canvas_lines = _draw_canvas(canvas) + assert canvas_lines == expected_canvas_lines diff --git a/tests/test_plots.py b/tests/test_plots.py index 71305af..32e52ba 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -42,6 +42,18 @@ def test_boxplot(): # ----------------------------------------------------------------------------- +def test_plot_figsize(): + x = np.arange(-3, 3, 0.01) + y = np.cos(x) ** 2 + plt_str = _plot(x, y, figsize=(60, 15)) + assert isinstance(plt_str, str) + + +# ----------------------------------------------------------------------------- +# Unit tests +# ----------------------------------------------------------------------------- + + @pytest.fixture def expected_canvas_vbar(): return np.array(