Skip to content

Commit

Permalink
added test for get_label func
Browse files Browse the repository at this point in the history
  • Loading branch information
CDonnerer committed Dec 25, 2020
1 parent 6673a88 commit 935529b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/shellplot/pandas_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def plot(data, kind, **kwargs):
return _plot_frame(data, kind, **kwargs)
else:
# we should never get here
return ValueError
raise ValueError


def hist_series(data, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions src/shellplot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@ def remove_any_nan(x, y):
@singledispatch
def numpy_2d(x):
"""Reshape and transform various array-like inputs to 2d np arrays"""
pass


@numpy_2d.register
def _(x: np.ndarray):
if len(x.shape) == 1:
return x[np.newaxis]
else:
elif len(x.shape) == 2:
return x
else:
raise ValueError("Array dimensions need to be <= 2!")


@numpy_2d.register
Expand All @@ -95,7 +96,6 @@ def _(x: list):
@singledispatch
def numpy_1d(x):
"""Reshape and transform various array-like inputs to 1d np arrays"""
pass


@numpy_1d.register
Expand Down
5 changes: 5 additions & 0 deletions tests/test_pandas_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def random_frame():


def test_plot_frame(df_penguins):
set_shellplot_plotting_backend()
df_penguins.dropna().plot("bill_length_mm", "flipper_length_mm")


def test_plot_frame_color(df_penguins):
set_shellplot_plotting_backend()
df_penguins.dropna().plot("bill_length_mm", "flipper_length_mm", color="species")

Expand Down
13 changes: 13 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd

from shellplot.utils import (
get_label,
load_dataset,
numpy_1d,
numpy_2d,
Expand Down Expand Up @@ -110,3 +111,15 @@ def test_numpy_2d(x, expected_np_2d):
def test_numpy_1d(x, expected_np_1d):
np_1d = numpy_1d(x)
np.testing.assert_equal(np_1d, expected_np_1d)


@pytest.mark.parametrize(
"x, expected_label",
[
(pd.Series(data=[0, 1], name="my_series"), "my_series"),
# (pd.DataFrame({"feat_1": [0, 1], "feat_2": [0, 1]}), TODO
],
)
def test_get_label(x, expected_label):
label = get_label(x)
assert label == expected_label

0 comments on commit 935529b

Please sign in to comment.