Skip to content

Commit

Permalink
TST/CLN: Use more frame_or_series fixture (pandas-dev#48926)
Browse files Browse the repository at this point in the history
* TST/CLN: Use more frame_or_series fixture

* Revert for base ext tests
  • Loading branch information
mroeschke committed Oct 4, 2022
1 parent ff9a1dc commit e25aa9d
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 57 deletions.
5 changes: 2 additions & 3 deletions pandas/tests/apply/test_invalid_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,11 @@ def test_map_datetimetz_na_action():
s.map(lambda x: x, na_action="ignore")


@pytest.mark.parametrize("box", [DataFrame, Series])
@pytest.mark.parametrize("method", ["apply", "agg", "transform"])
@pytest.mark.parametrize("func", [{"A": {"B": "sum"}}, {"A": {"B": ["sum"]}}])
def test_nested_renamer(box, method, func):
def test_nested_renamer(frame_or_series, method, func):
# GH 35964
obj = box({"A": [1]})
obj = frame_or_series({"A": [1]})
match = "nested renamer is not supported"
with pytest.raises(SpecificationError, match=match):
getattr(obj, method)(func)
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
result = data.__eq__(other)
assert result is NotImplemented
else:
raise pytest.skip(f"{type(data).__name__} does not implement __eq__")
pytest.skip(f"{type(data).__name__} does not implement __eq__")

if hasattr(data, "__ne__"):
result = data.__ne__(other)
assert result is NotImplemented
else:
raise pytest.skip(f"{type(data).__name__} does not implement __ne__")
pytest.skip(f"{type(data).__name__} does not implement __ne__")


class BaseUnaryOpsTests(BaseOpsUtil):
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,12 @@ def test_add_series_with_extension_array(self, data):
with pytest.raises(TypeError, match=msg):
s + data

@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
def test_direct_arith_with_ndframe_returns_not_implemented(
self, data, frame_or_series
):
# Override to use __sub__ instead of __add__
other = pd.Series(data)
if box is pd.DataFrame:
if frame_or_series is pd.DataFrame:
other = other.to_frame()

result = data.__sub__(other)
Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/frame/indexing/test_xs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,7 @@ def test_xs_loc_equality(self, multiindex_dataframe_random_data):
expected = df.loc[("bar", "two")]
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_xs_IndexSlice_argument_not_implemented(self, klass):
def test_xs_IndexSlice_argument_not_implemented(self, frame_or_series):
# GH#35301

index = MultiIndex(
Expand All @@ -334,7 +333,7 @@ def test_xs_IndexSlice_argument_not_implemented(self, klass):
)

obj = DataFrame(np.random.randn(6, 4), index=index)
if klass is Series:
if frame_or_series is Series:
obj = obj[0]

expected = obj.iloc[-2:].droplevel(0)
Expand All @@ -345,10 +344,9 @@ def test_xs_IndexSlice_argument_not_implemented(self, klass):
result = obj.loc[IndexSlice[("foo", "qux", 0), :]]
tm.assert_equal(result, expected)

@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_xs_levels_raises(self, klass):
def test_xs_levels_raises(self, frame_or_series):
obj = DataFrame({"A": [1, 2, 3]})
if klass is Series:
if frame_or_series is Series:
obj = obj["A"]

msg = "Index must be a MultiIndex"
Expand Down
7 changes: 3 additions & 4 deletions pandas/tests/frame/methods/test_drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,16 @@ def test_drop_level_nonunique_datetime(self):
expected = df.loc[idx != 4]
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("box", [Series, DataFrame])
def test_drop_tz_aware_timestamp_across_dst(self, box):
def test_drop_tz_aware_timestamp_across_dst(self, frame_or_series):
# GH#21761
start = Timestamp("2017-10-29", tz="Europe/Berlin")
end = Timestamp("2017-10-29 04:00:00", tz="Europe/Berlin")
index = pd.date_range(start, end, freq="15min")
data = box(data=[1] * len(index), index=index)
data = frame_or_series(data=[1] * len(index), index=index)
result = data.drop(start)
expected_start = Timestamp("2017-10-29 00:15:00", tz="Europe/Berlin")
expected_idx = pd.date_range(expected_start, end, freq="15min")
expected = box(data=[1] * len(expected_idx), index=expected_idx)
expected = frame_or_series(data=[1] * len(expected_idx), index=expected_idx)
tm.assert_equal(result, expected)

def test_drop_preserve_names(self):
Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/frame/methods/test_pct_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ class TestDataFramePctChange:
(-1, "bfill", 1, [np.nan, 0, -0.5, -0.5, -0.6, np.nan, np.nan, np.nan]),
],
)
@pytest.mark.parametrize("klass", [DataFrame, Series])
def test_pct_change_with_nas(self, periods, fill_method, limit, exp, klass):
def test_pct_change_with_nas(
self, periods, fill_method, limit, exp, frame_or_series
):
vals = [np.nan, np.nan, 1, 2, 4, 10, np.nan, np.nan]
obj = klass(vals)
obj = frame_or_series(vals)

res = obj.pct_change(periods=periods, fill_method=fill_method, limit=limit)
tm.assert_equal(res, klass(exp))
tm.assert_equal(res, frame_or_series(exp))

def test_pct_change_numeric(self):
# GH#11150
Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/frame/methods/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
DataFrame,
Index,
MultiIndex,
Series,
merge,
)
import pandas._testing as tm
Expand All @@ -32,9 +31,8 @@ def test_rename_signature(self):
"errors",
}

@pytest.mark.parametrize("klass", [Series, DataFrame])
def test_rename_mi(self, klass):
obj = klass(
def test_rename_mi(self, frame_or_series):
obj = frame_or_series(
[11, 21, 31],
index=MultiIndex.from_tuples([("A", x) for x in ["a", "B", "c"]]),
)
Expand Down
9 changes: 4 additions & 5 deletions pandas/tests/frame/methods/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@


class TestSample:
@pytest.fixture(params=[Series, DataFrame])
def obj(self, request):
klass = request.param
if klass is Series:
@pytest.fixture
def obj(self, frame_or_series):
if frame_or_series is Series:
arr = np.random.randn(10)
else:
arr = np.random.randn(10, 10)
return klass(arr, dtype=None)
return frame_or_series(arr, dtype=None)

@pytest.mark.parametrize("test", list(range(10)))
def test_sample(self, test, obj):
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/io/formats/test_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,9 @@ def test_to_csv_multi_index(self):
),
],
)
@pytest.mark.parametrize("klass", [DataFrame, pd.Series])
def test_to_csv_single_level_multi_index(self, ind, expected, klass):
def test_to_csv_single_level_multi_index(self, ind, expected, frame_or_series):
# see gh-19589
obj = klass(pd.Series([1], ind, name="data"))
obj = frame_or_series(pd.Series([1], ind, name="data"))

with tm.assert_produces_warning(FutureWarning, match="lineterminator"):
# GH#9568 standardize on lineterminator matching stdlib
Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/resample/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def empty_frame_dti(series):
return DataFrame(index=index)


@pytest.fixture(params=[Series, DataFrame])
def series_and_frame(request, series, frame):
@pytest.fixture
def series_and_frame(frame_or_series, series, frame):
"""
Fixture for parametrization of Series and DataFrame with date_range,
period_range and timedelta_range indexes
"""
if request.param == Series:
if frame_or_series == Series:
return series
if request.param == DataFrame:
if frame_or_series == DataFrame:
return frame
11 changes: 5 additions & 6 deletions pandas/tests/reshape/concat/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,15 @@ def test_concat_duplicate_indices_raise(self):
concat([df1, df2], axis=1)


@pytest.mark.parametrize("pdt", [Series, DataFrame])
@pytest.mark.parametrize("dt", np.sctypes["float"])
def test_concat_no_unnecessary_upcast(dt, pdt):
def test_concat_no_unnecessary_upcast(dt, frame_or_series):
# GH 13247
dims = pdt(dtype=object).ndim
dims = frame_or_series(dtype=object).ndim

dfs = [
pdt(np.array([1], dtype=dt, ndmin=dims)),
pdt(np.array([np.nan], dtype=dt, ndmin=dims)),
pdt(np.array([5], dtype=dt, ndmin=dims)),
frame_or_series(np.array([1], dtype=dt, ndmin=dims)),
frame_or_series(np.array([np.nan], dtype=dt, ndmin=dims)),
frame_or_series(np.array([5], dtype=dt, ndmin=dims)),
]
x = concat(dfs)
assert x.values.dtype == dt
Expand Down
26 changes: 13 additions & 13 deletions pandas/tests/window/test_base_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def get_window_bounds(self, num_values, min_periods, center, closed, step):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("constructor", [Series, DataFrame])
@pytest.mark.parametrize(
"func,np_func,expected,np_kwargs",
[
Expand Down Expand Up @@ -149,7 +148,9 @@ def get_window_bounds(self, num_values, min_periods, center, closed, step):
],
)
@pytest.mark.filterwarnings("ignore:min_periods:FutureWarning")
def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs, step):
def test_rolling_forward_window(
frame_or_series, func, np_func, expected, np_kwargs, step
):
# GH 32865
values = np.arange(10.0)
values[5] = 100.0
Expand All @@ -158,47 +159,46 @@ def test_rolling_forward_window(constructor, func, np_func, expected, np_kwargs,

match = "Forward-looking windows can't have center=True"
with pytest.raises(ValueError, match=match):
rolling = constructor(values).rolling(window=indexer, center=True)
rolling = frame_or_series(values).rolling(window=indexer, center=True)
getattr(rolling, func)()

match = "Forward-looking windows don't support setting the closed argument"
with pytest.raises(ValueError, match=match):
rolling = constructor(values).rolling(window=indexer, closed="right")
rolling = frame_or_series(values).rolling(window=indexer, closed="right")
getattr(rolling, func)()

rolling = constructor(values).rolling(window=indexer, min_periods=2, step=step)
rolling = frame_or_series(values).rolling(window=indexer, min_periods=2, step=step)
result = getattr(rolling, func)()

# Check that the function output matches the explicitly provided array
expected = constructor(expected)[::step]
expected = frame_or_series(expected)[::step]
tm.assert_equal(result, expected)

# Check that the rolling function output matches applying an alternative
# function to the rolling window object
expected2 = constructor(rolling.apply(lambda x: np_func(x, **np_kwargs)))
expected2 = frame_or_series(rolling.apply(lambda x: np_func(x, **np_kwargs)))
tm.assert_equal(result, expected2)

# Check that the function output matches applying an alternative function
# if min_periods isn't specified
# GH 39604: After count-min_periods deprecation, apply(lambda x: len(x))
# is equivalent to count after setting min_periods=0
min_periods = 0 if func == "count" else None
rolling3 = constructor(values).rolling(window=indexer, min_periods=min_periods)
rolling3 = frame_or_series(values).rolling(window=indexer, min_periods=min_periods)
result3 = getattr(rolling3, func)()
expected3 = constructor(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
expected3 = frame_or_series(rolling3.apply(lambda x: np_func(x, **np_kwargs)))
tm.assert_equal(result3, expected3)


@pytest.mark.parametrize("constructor", [Series, DataFrame])
def test_rolling_forward_skewness(constructor, step):
def test_rolling_forward_skewness(frame_or_series, step):
values = np.arange(10.0)
values[5] = 100.0

indexer = FixedForwardWindowIndexer(window_size=5)
rolling = constructor(values).rolling(window=indexer, min_periods=3, step=step)
rolling = frame_or_series(values).rolling(window=indexer, min_periods=3, step=step)
result = rolling.skew()

expected = constructor(
expected = frame_or_series(
[
0.0,
2.232396,
Expand Down

0 comments on commit e25aa9d

Please sign in to comment.