Skip to content

Commit

Permalink
ENH: ExtensionArray.fillna
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Feb 26, 2018
1 parent 1e4c50a commit 74614cb
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 65 deletions.
84 changes: 84 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""An interface for extending pandas with custom arrays."""
import itertools

import numpy as np

from pandas.errors import AbstractMethodError
Expand Down Expand Up @@ -216,6 +218,88 @@ def isna(self):
"""
raise AbstractMethodError(self)

def tolist(self):
    # type: () -> list
    """Return the elements of this array as a plain Python list."""
    scalars = list(self)
    return scalars

def fillna(self, value=None, method=None, limit=None):
    """Fill NA/NaN values with a value or by propagating neighbours.

    Parameters
    ----------
    value : scalar, array-like
        A scalar is used to fill every missing entry. An array-like
        must have the same length as 'self'; each missing entry is
        filled from the matching position.
    method : {'backfill', 'bfill', 'pad', 'ffill', None}, default None
        pad / ffill: propagate the last valid observation forward.
        backfill / bfill: use the next valid observation to fill the gap.
    limit : int, default None
        Not implemented yet for ExtensionArray; anything other than
        None raises NotImplementedError.

    Returns
    -------
    filled : ExtensionArray with NA/NaN filled

    Raises
    ------
    ValueError
        If an array-like 'value' differs in length from 'self'.
    NotImplementedError
        If 'limit' is not None.
    """
    from pandas.api.types import is_scalar
    from pandas.util._validators import validate_fillna_kwargs

    value, method = validate_fillna_kwargs(value, method)

    if is_scalar(value):
        # An endless stream of the scalar lets the value-fill zip below
        # treat scalars and array-likes the same way.
        value = itertools.cycle([value])
    elif len(value) != len(self):
        raise ValueError("Length of 'value' does not match. Got ({}) "
                         " expected {}".format(len(value), len(self)))

    if limit is not None:
        msg = ("Specifying 'limit' for 'fillna' has not been implemented "
               "yet for {} typed data".format(self.dtype))
        raise NotImplementedError(msg)

    mask = self.isna()

    if not mask.any():
        # Nothing is missing: rebuild the array from itself.
        return type(self)(self)

    if method is None:
        # fill with value
        replaced = [
            fill if missing else current
            for missing, current, fill in zip(mask, self, value)
        ]
        return type(self)(replaced)

    # ffill / bfill
    # NOTE(review): assumes validate_fillna_kwargs has already mapped
    # 'bfill' onto 'backfill' -- confirm.
    if method == 'backfill':
        # Walk from the end so the "next" valid value is seen first.
        source = reversed(self)
        mask = reversed(mask)
        carry = self[len(self) - 1]
    else:
        carry = self[0]
        source = self

    propagated = []
    for missing, current in zip(mask, source):
        if not missing:
            carry = current
        propagated.append(carry)

    if method in {'bfill', 'backfill'}:
        # Undo the reversed walk so positions line up with 'self'.
        propagated = propagated[::-1]
    return type(self)(propagated)

# ------------------------------------------------------------------------
# Indexing methods
# ------------------------------------------------------------------------
Expand Down
38 changes: 17 additions & 21 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,23 @@ def concat_same_type(self, to_concat, placement=None):
return self.make_block_same_class(values, ndim=self.ndim,
placement=placement)

def fillna(self, value, limit=None, inplace=False, downcast=None,
           mgr=None):
    """Fill missing entries of this block's array and wrap the result.

    `inplace` only decides whether the array is copied before its own
    ``fillna`` is called. `downcast` and `mgr` are accepted but not
    used here.

    Returns a one-element list holding a same-class block built from
    the filled values, this block's ``mgr_locs`` and ``ndim``.
    """
    if inplace:
        target = self.values
    else:
        target = self.values.copy()

    filled = target.fillna(value=value, limit=limit)
    new_block = self.make_block_same_class(values=filled,
                                           placement=self.mgr_locs,
                                           ndim=self.ndim)
    return [new_block]

def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
                fill_value=None, **kwargs):
    """Fill this block's array via its ``fillna`` and return a new block.

    `fill_value`, `method` and `limit` are forwarded to the array's
    ``fillna`` (as ``value``, ``method`` and ``limit``). `inplace` only
    decides whether the array is copied first. `axis` and any extra
    keyword arguments are accepted but not used here.
    """
    if inplace:
        target = self.values
    else:
        target = self.values.copy()

    filled = target.fillna(value=fill_value, method=method, limit=limit)
    return self.make_block_same_class(values=filled,
                                      placement=self.mgr_locs)


class NumericBlock(Block):
__slots__ = ()
Expand Down Expand Up @@ -2522,27 +2539,6 @@ def _try_coerce_result(self, result):

return result

def fillna(self, value, limit=None, inplace=False, downcast=None,
           mgr=None):
    """Fill missing entries of this block's values and wrap the result.

    Raises NotImplementedError when `limit` is given. `inplace` only
    decides whether the values are copied before their own ``fillna``
    is called. `downcast` and `mgr` are accepted but not used here.

    Returns a one-element list holding the block made from the filled
    values after they pass through ``_try_coerce_result``.
    """
    if limit is not None:
        raise NotImplementedError("specifying a limit for 'fillna' has "
                                  "not been implemented yet")

    if inplace:
        target = self.values
    else:
        target = self.values.copy()

    # we may need to upcast our fill to match our dtype
    filled = target.fillna(value=value, limit=limit)
    coerced = self._try_coerce_result(filled)
    return [self.make_block(values=coerced)]

def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
                fill_value=None, **kwargs):
    """Fill this block's values via their ``fillna`` and return a new block.

    `fill_value`, `method` and `limit` are forwarded to the values'
    ``fillna``. `inplace` only decides whether the values are copied
    first. `axis` and any extra keyword arguments are accepted but not
    used here.
    """
    values = self.values if inplace else self.values.copy()
    # The fill is passed under the keyword 'value', not 'fill_value',
    # so it works with a fillna(value, method, limit) signature.
    filled = values.fillna(value=fill_value, method=method, limit=limit)
    return self.make_block_same_class(values=filled,
                                      placement=self.mgr_locs)

def shift(self, periods, axis=0, mgr=None):
    """Return a same-class block holding the values shifted by `periods`.

    `axis` and `mgr` are accepted but not used here.
    """
    shifted = self.values.shift(periods)
    return self.make_block_same_class(values=shifted,
                                      placement=self.mgr_locs)
Expand Down
80 changes: 80 additions & 0 deletions pandas/tests/extension/base/missing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

import pandas as pd
import pandas.util.testing as tm
Expand Down Expand Up @@ -45,3 +46,82 @@ def test_dropna_frame(self, data_missing):
result = df.dropna()
expected = df.iloc[:0]
self.assert_frame_equal(result, expected)

def test_fillna_limit_raises(self, data_missing):
    """fillna with `limit` raises NotImplementedError mentioning the dtype."""
    pattern = "Specifying 'limit' for 'fillna'.*{}".format(data_missing.dtype)
    series = pd.Series(data_missing)
    filler = data_missing[1]

    with tm.assert_raises_regex(NotImplementedError, pattern):
        series.fillna(filler, limit=2)

def test_fillna_series(self, data_missing):
    """Series.fillna works with a scalar and with another Series."""
    # presumably data_missing is [NA, valid] -- confirm against the fixture
    series = pd.Series(data_missing)
    filler = data_missing[1]
    all_filled = pd.Series(type(data_missing)([filler, filler]))

    # Fill with a scalar
    self.assert_series_equal(series.fillna(filler), all_filled)

    # Fill with a series
    self.assert_series_equal(series.fillna(all_filled), all_filled)

    # Fill with a series not affecting the missing values
    self.assert_series_equal(series.fillna(series), series)

@pytest.mark.xfail(reason="Too magical?")
def test_fillna_series_with_dict(self, data_missing):
    """Series.fillna with a dict fills only the labels it names."""
    # presumably data_missing is [NA, valid] -- confirm against the fixture
    filler = data_missing[1]
    series = pd.Series(data_missing)
    all_filled = pd.Series(type(data_missing)([filler, filler]))

    # Fill with a dict
    filled = series.fillna({0: filler})
    self.assert_series_equal(filled, all_filled)

    # Fill with a dict not affecting the missing values
    untouched = series.fillna({1: filler})
    series = pd.Series(data_missing)
    self.assert_series_equal(untouched, series)

@pytest.mark.parametrize('method', ['ffill', 'bfill'])
def test_fillna_series_method(self, data_missing, method):
    """Series.fillna(method=...) propagates the one valid value."""
    array_type = type(data_missing)
    filler = data_missing[1]

    if method == 'ffill':
        # Reverse the data so the valid value precedes the missing one.
        data_missing = array_type(data_missing[::-1])

    filled = pd.Series(data_missing).fillna(method=method)
    all_filled = pd.Series(array_type([filler, filler]))

    self.assert_series_equal(filled, all_filled)

def test_fillna_frame(self, data_missing):
    """DataFrame.fillna with a scalar fills the extension-array column."""
    filler = data_missing[1]

    frame = pd.DataFrame({
        "A": data_missing,
        "B": [1, 2]
    })
    filled = frame.fillna(filler)

    all_filled = type(data_missing)([filler, filler])
    wanted = pd.DataFrame({
        "A": all_filled,
        "B": [1, 2],
    })

    self.assert_frame_equal(filled, wanted)

def test_fillna_fill_other(self, data):
    """Filling only column "B" leaves the extension-array column "A" as is."""
    frame = pd.DataFrame({
        "A": data,
        "B": [np.nan] * len(data)
    })
    filled = frame.fillna({"B": 0.0})

    wanted = pd.DataFrame({
        "A": data,
        "B": [0.0] * len(filled),
    })

    self.assert_frame_equal(filled, wanted)
5 changes: 4 additions & 1 deletion pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def test_getitem_scalar(self):


class TestMissing(base.BaseMissingTests):
    pass

    # Redefines test_fillna_limit_raises with a docstring-only body and
    # marks it as skipped.
    @pytest.mark.skip(reason="Backwards compatability")
    def test_fillna_limit_raises(self):
        """Has a different error message."""


class TestMethods(base.BaseMethodsTests):
Expand Down
75 changes: 33 additions & 42 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,68 +35,59 @@ def na_value():
return decimal.Decimal("NaN")


class TestDtype(base.BaseDtypeTests):
pass
class BaseDecimal(object):
@staticmethod
def assert_series_equal(left, right, *args, **kwargs):
# tm.assert_series_equal doesn't handle Decimal('NaN').
# We will ensure that the NA values match, and then
# drop those values before moving on.

left_na = left.isna()
right_na = right.isna()

class TestInterface(base.BaseInterfaceTests):
pass
tm.assert_series_equal(left_na, right_na)
tm.assert_series_equal(left[~left_na], right[~right_na],
*args, **kwargs)

@staticmethod
def assert_frame_equal(left, right, *args, **kwargs):
# TODO(EA): select_dtypes
decimals = (left.dtypes == 'decimal').index

class TestConstructors(base.BaseConstructorsTests):
pass
for col in decimals:
BaseDecimal.assert_series_equal(left[col], right[col],
*args, **kwargs)

left = left.drop(columns=decimals)
right = right.drop(columns=decimals)
tm.assert_frame_equal(left, right, *args, **kwargs)

class TestReshaping(base.BaseReshapingTests):

def test_align(self, data, na_value):
# Have to override since assert_series_equal doesn't
# compare Decimal(NaN) properly.
a = data[:3]
b = data[2:5]
r1, r2 = pd.Series(a).align(pd.Series(b, index=[1, 2, 3]))
class TestDtype(BaseDecimal, base.BaseDtypeTests):
pass

# NaN handling
e1 = pd.Series(type(data)(list(a) + [na_value]))
e2 = pd.Series(type(data)([na_value] + list(b)))
tm.assert_series_equal(r1.iloc[:3], e1.iloc[:3])
assert r1[3].is_nan()
assert e1[3].is_nan()

tm.assert_series_equal(r2.iloc[1:], e2.iloc[1:])
assert r2[0].is_nan()
assert e2[0].is_nan()
class TestInterface(BaseDecimal, base.BaseInterfaceTests):
pass

def test_align_frame(self, data, na_value):
# Override for Decimal(NaN) comparison
a = data[:3]
b = data[2:5]
r1, r2 = pd.DataFrame({'A': a}).align(
pd.DataFrame({'A': b}, index=[1, 2, 3])
)

# Assumes that the ctor can take a list of scalars of the type
e1 = pd.DataFrame({'A': type(data)(list(a) + [na_value])})
e2 = pd.DataFrame({'A': type(data)([na_value] + list(b))})
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
pass

tm.assert_frame_equal(r1.iloc[:3], e1.iloc[:3])
assert r1.loc[3, 'A'].is_nan()
assert e1.loc[3, 'A'].is_nan()

tm.assert_frame_equal(r2.iloc[1:], e2.iloc[1:])
assert r2.loc[0, 'A'].is_nan()
assert e2.loc[0, 'A'].is_nan()
class TestReshaping(BaseDecimal, base.BaseReshapingTests):
pass


class TestGetitem(base.BaseGetitemTests):
class TestGetitem(BaseDecimal, base.BaseGetitemTests):
pass


class TestMissing(base.BaseMissingTests):
class TestMissing(BaseDecimal, base.BaseMissingTests):
pass


class TestMethods(base.BaseMethodsTests):
class TestMethods(BaseDecimal, base.BaseMethodsTests):
@pytest.mark.parametrize('dropna', [True, False])
@pytest.mark.xfail(reason="value_counts not implemented yet.")
def test_value_counts(self, all_data, dropna):
Expand All @@ -112,7 +103,7 @@ def test_value_counts(self, all_data, dropna):
tm.assert_series_equal(result, expected)


class TestCasting(base.BaseCastingTests):
class TestCasting(BaseDecimal, base.BaseCastingTests):
pass


Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ class TestGetitem(base.BaseGetitemTests):


class TestMissing(base.BaseMissingTests):
    pass
    # Both tests below are redefined with docstring-only bodies and
    # marked as expected failures.
    @pytest.mark.xfail(reason="Setting a dict as a scalar")
    def test_fillna_series(self):
        """We treat dictionaries as a mapping in fillna, not a scalar."""

    @pytest.mark.xfail(reason="Setting a dict as a scalar")
    def test_fillna_frame(self):
        """We treat dictionaries as a mapping in fillna, not a scalar."""


class TestMethods(base.BaseMethodsTests):
Expand Down

0 comments on commit 74614cb

Please sign in to comment.