From 74614cb66a3bc011c57e0befcaa2cd69e6d31a6c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 22 Feb 2018 11:07:35 -0600 Subject: [PATCH] ENH: ExtensionArray.fillna --- pandas/core/arrays/base.py | 84 +++++++++++++++++++ pandas/core/internals.py | 38 ++++----- pandas/tests/extension/base/missing.py | 80 ++++++++++++++++++ .../extension/category/test_categorical.py | 5 +- .../tests/extension/decimal/test_decimal.py | 75 ++++++++--------- pandas/tests/extension/json/test_json.py | 8 +- 6 files changed, 225 insertions(+), 65 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index cec881394a021..8dc4ddbff7d5c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1,4 +1,6 @@ """An interface for extending pandas with custom arrays.""" +import itertools + import numpy as np from pandas.errors import AbstractMethodError @@ -216,6 +218,88 @@ def isna(self): """ raise AbstractMethodError(self) + def tolist(self): + # type: () -> list + """Convert the array to a list of scalars.""" + return list(self) + + def fillna(self, value=None, method=None, limit=None): + """ Fill NA/NaN values using the specified method. + + Parameters + ---------- + method : {'backfill', 'bfill', 'pad', 'ffill', None}, default None + Method to use for filling holes in reindexed Series + pad / ffill: propagate last valid observation forward to next valid + backfill / bfill: use NEXT valid observation to fill gap + value : scalar, array-like + If a scalar value is passed it is used to fill all missing values. + Alternatively, an array-like 'value' can be given. It's expected + that the array-like have the same length as 'self'. + limit : int, default None + (Not implemented yet for ExtensionArray!) + If method is specified, this is the maximum number of consecutive + NaN values to forward/backward fill. In other words, if there is + a gap with more than this number of consecutive NaNs, it will only + be partially filled. If method is not specified, this is the + maximum number of entries along the entire axis where NaNs will be + filled. + + Returns + ------- + filled : ExtensionArray with NA/NaN filled + """ + from pandas.api.types import is_scalar + from pandas.util._validators import validate_fillna_kwargs + + value, method = validate_fillna_kwargs(value, method) + + if not is_scalar(value): + if len(value) != len(self): + raise ValueError("Length of 'value' does not match. Got ({}) " + " expected {}".format(len(value), len(self))) + else: + value = itertools.cycle([value]) + + if limit is not None: + msg = ("Specifying 'limit' for 'fillna' has not been implemented " + "yet for {} typed data".format(self.dtype)) + raise NotImplementedError(msg) + + mask = self.isna() + + if mask.any(): + # ffill / bfill + if method is not None: + if method == 'backfill': + data = reversed(self) + mask = reversed(mask) + last_valid = self[len(self) - 1] + else: + last_valid = self[0] + data = self + + new_values = [] + + for is_na, val in zip(mask, data): + if is_na: + new_values.append(last_valid) + else: + new_values.append(val) + last_valid = val + + if method in {'bfill', 'backfill'}: + new_values = list(reversed(new_values)) + else: + # fill with value + new_values = [ + val if is_na else original + for is_na, original, val in zip(mask, self, value) + ] + else: + new_values = self + return type(self)(new_values) + # ------------------------------------------------------------------------ # Indexing methods # ------------------------------------------------------------------------ diff --git a/pandas/core/internals.py b/pandas/core/internals.py index 00ef8f9cef598..da7329e6ced23 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -1963,6 +1963,23 @@ def concat_same_type(self, to_concat, placement=None): return self.make_block_same_class(values, ndim=self.ndim, placement=placement) + def fillna(self, value, limit=None, inplace=False, downcast=None, + mgr=None): + values = self.values if inplace else self.values.copy() + values = values.fillna(value=value, limit=limit) + return [self.make_block_same_class(values=values, + placement=self.mgr_locs, + ndim=self.ndim)] + + def interpolate(self, method='pad', axis=0, inplace=False, limit=None, + fill_value=None, **kwargs): + + values = self.values if inplace else self.values.copy() + return self.make_block_same_class( + values=values.fillna(value=fill_value, method=method, + limit=limit), + placement=self.mgr_locs) + class NumericBlock(Block): __slots__ = () @@ -2522,27 +2539,6 @@ def _try_coerce_result(self, result): return result - def fillna(self, value, limit=None, inplace=False, downcast=None, - mgr=None): - # we may need to upcast our fill to match our dtype - if limit is not None: - raise NotImplementedError("specifying a limit for 'fillna' has " - "not been implemented yet") - - values = self.values if inplace else self.values.copy() - values = self._try_coerce_result(values.fillna(value=value, - limit=limit)) - return [self.make_block(values=values)] - - def interpolate(self, method='pad', axis=0, inplace=False, limit=None, - fill_value=None, **kwargs): - - values = self.values if inplace else self.values.copy() - return self.make_block_same_class( - values=values.fillna(fill_value=fill_value, method=method, - limit=limit), - placement=self.mgr_locs) - def shift(self, periods, axis=0, mgr=None): return self.make_block_same_class(values=self.values.shift(periods), placement=self.mgr_locs) diff --git a/pandas/tests/extension/base/missing.py b/pandas/tests/extension/base/missing.py index 3ae82fa1ca432..086bd0c3b95fa 100644 --- a/pandas/tests/extension/base/missing.py +++ b/pandas/tests/extension/base/missing.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import pandas as pd import pandas.util.testing as tm @@ -45,3 +46,82 @@ def test_dropna_frame(self, data_missing): result = df.dropna() expected = df.iloc[:0] self.assert_frame_equal(result, expected) + + def test_fillna_limit_raises(self, data_missing): + ser = pd.Series(data_missing) + fill_value = data_missing[1] + xpr = "Specifying 'limit' for 'fillna'.*{}".format(data_missing.dtype) + + with tm.assert_raises_regex(NotImplementedError, xpr): + ser.fillna(fill_value, limit=2) + + def test_fillna_series(self, data_missing): + fill_value = data_missing[1] + ser = pd.Series(data_missing) + + result = ser.fillna(fill_value) + expected = pd.Series(type(data_missing)([fill_value, fill_value])) + self.assert_series_equal(result, expected) + + # Fill with a series + result = ser.fillna(expected) + self.assert_series_equal(result, expected) + + # Fill with a series not affecting the missing values + result = ser.fillna(ser) + self.assert_series_equal(result, ser) + + @pytest.mark.xfail(reason="Too magical?") + def test_fillna_series_with_dict(self, data_missing): + fill_value = data_missing[1] + ser = pd.Series(data_missing) + expected = pd.Series(type(data_missing)([fill_value, fill_value])) + + # Fill with a dict + result = ser.fillna({0: fill_value}) + self.assert_series_equal(result, expected) + + # Fill with a dict not affecting the missing values + result = ser.fillna({1: fill_value}) + ser = pd.Series(data_missing) + self.assert_series_equal(result, ser) + + @pytest.mark.parametrize('method', ['ffill', 'bfill']) + def test_fillna_series_method(self, data_missing, method): + fill_value = data_missing[1] + + if method == 'ffill': + data_missing = type(data_missing)(data_missing[::-1]) + + result = pd.Series(data_missing).fillna(method=method) + expected = pd.Series(type(data_missing)([fill_value, fill_value])) + + self.assert_series_equal(result, expected) + + def test_fillna_frame(self, data_missing): + fill_value = data_missing[1] + + result = pd.DataFrame({ + "A": data_missing, + "B": [1, 2] + }).fillna(fill_value) + + expected = pd.DataFrame({ + "A": type(data_missing)([fill_value, fill_value]), + "B": [1, 2], + }) + + self.assert_frame_equal(result, expected) + + def test_fillna_fill_other(self, data): + result = pd.DataFrame({ + "A": data, + "B": [np.nan] * len(data) + }).fillna({"B": 0.0}) + + expected = pd.DataFrame({ + "A": data, + "B": [0.0] * len(result), + }) + + self.assert_frame_equal(result, expected) diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 8f413b4a19730..ddd8d01b841c7 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ -69,7 +69,10 @@ def test_getitem_scalar(self): class TestMissing(base.BaseMissingTests): - pass + + @pytest.mark.skip(reason="Backwards compatability") + def test_fillna_limit_raises(self): + """Has a different error message.""" class TestMethods(base.BaseMethodsTests): diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 7b4d079ecad87..01ae092bc1521 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -35,68 +35,59 @@ def na_value(): return decimal.Decimal("NaN") -class TestDtype(base.BaseDtypeTests): - pass +class BaseDecimal(object): + @staticmethod + def assert_series_equal(left, right, *args, **kwargs): + # tm.assert_series_equal doesn't handle Decimal('NaN'). + # We will ensure that the NA values match, and then + # drop those values before moving on. + left_na = left.isna() + right_na = right.isna() -class TestInterface(base.BaseInterfaceTests): - pass + tm.assert_series_equal(left_na, right_na) + tm.assert_series_equal(left[~left_na], right[~right_na], + *args, **kwargs) + @staticmethod + def assert_frame_equal(left, right, *args, **kwargs): + # TODO(EA): select_dtypes + decimals = (left.dtypes == 'decimal').index -class TestConstructors(base.BaseConstructorsTests): - pass + for col in decimals: + BaseDecimal.assert_series_equal(left[col], right[col], + *args, **kwargs) + left = left.drop(columns=decimals) + right = right.drop(columns=decimals) + tm.assert_frame_equal(left, right, *args, **kwargs) -class TestReshaping(base.BaseReshapingTests): - def test_align(self, data, na_value): - # Have to override since assert_series_equal doesn't - # compare Decimal(NaN) properly. - a = data[:3] - b = data[2:5] - r1, r2 = pd.Series(a).align(pd.Series(b, index=[1, 2, 3])) +class TestDtype(BaseDecimal, base.BaseDtypeTests): + pass - # NaN handling - e1 = pd.Series(type(data)(list(a) + [na_value])) - e2 = pd.Series(type(data)([na_value] + list(b))) - tm.assert_series_equal(r1.iloc[:3], e1.iloc[:3]) - assert r1[3].is_nan() - assert e1[3].is_nan() - tm.assert_series_equal(r2.iloc[1:], e2.iloc[1:]) - assert r2[0].is_nan() - assert e2[0].is_nan() +class TestInterface(BaseDecimal, base.BaseInterfaceTests): + pass - def test_align_frame(self, data, na_value): - # Override for Decimal(NaN) comparison - a = data[:3] - b = data[2:5] - r1, r2 = pd.DataFrame({'A': a}).align( - pd.DataFrame({'A': b}, index=[1, 2, 3]) - ) - # Assumes that the ctor can take a list of scalars of the type - e1 = pd.DataFrame({'A': type(data)(list(a) + [na_value])}) - e2 = pd.DataFrame({'A': type(data)([na_value] + list(b))}) +class TestConstructors(BaseDecimal, base.BaseConstructorsTests): + pass - tm.assert_frame_equal(r1.iloc[:3], e1.iloc[:3]) - assert r1.loc[3, 'A'].is_nan() - assert e1.loc[3, 'A'].is_nan() - tm.assert_frame_equal(r2.iloc[1:], e2.iloc[1:]) - assert r2.loc[0, 'A'].is_nan() - assert e2.loc[0, 'A'].is_nan() +class TestReshaping(BaseDecimal, base.BaseReshapingTests): + pass -class TestGetitem(base.BaseGetitemTests): +class TestGetitem(BaseDecimal, base.BaseGetitemTests): pass -class TestMissing(base.BaseMissingTests): +class TestMissing(BaseDecimal, base.BaseMissingTests): pass -class TestMethods(base.BaseMethodsTests): +class TestMethods(BaseDecimal, base.BaseMethodsTests): @pytest.mark.parametrize('dropna', [True, False]) @pytest.mark.xfail(reason="value_counts not implemented yet.") def test_value_counts(self, all_data, dropna): @@ -112,7 +103,7 @@ def test_value_counts(self, all_data, dropna): tm.assert_series_equal(result, expected) -class TestCasting(base.BaseCastingTests): +class TestCasting(BaseDecimal, base.BaseCastingTests): pass diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index e0721bb1d8d1a..16d5e4415a79f 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -60,7 +60,13 @@ class TestGetitem(base.BaseGetitemTests): class TestMissing(base.BaseMissingTests): - pass + @pytest.mark.xfail(reason="Setting a dict as a scalar") + def test_fillna_series(self): + """We treat dictionaries as a mapping in fillna, not a scalar.""" + + @pytest.mark.xfail(reason="Setting a dict as a scalar") + def test_fillna_frame(self): + """We treat dictionaries as a mapping in fillna, not a scalar.""" class TestMethods(base.BaseMethodsTests):