Skip to content

Commit

Permalink
StringArray comparisions return BooleanArray
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Dec 12, 2019
1 parent daa3158 commit c548c67
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 9 deletions.
4 changes: 4 additions & 0 deletions doc/source/user_guide/text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ l. For ``StringDtype``, :ref:`string accessor methods<api.series.str>`
2. Some string methods, like :meth:`Series.str.decode` are not available
on ``StringArray`` because ``StringArray`` only holds strings, not
bytes.
3. In comparision operations, :class:`StringArray` and ``Series`` backed
by a ``StringArray`` will return a :class:`BooleanArray`, rather than
a ``bool`` or ``object`` dtype array, depending on whether there are
missing values.


Everything else that follows in the rest of this document applies equally to
Expand Down
31 changes: 24 additions & 7 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ class StringArray(PandasArray):
copy : bool, default False
Whether to copy the array of data.
Notes
-----
StringArray returns a BooleanArray for comparison methods.
Attributes
----------
None
Expand Down Expand Up @@ -148,6 +152,13 @@ class StringArray(PandasArray):
Traceback (most recent call last):
...
ValueError: StringArray requires an object-dtype ndarray of strings.
For comparision methods, this returns a :class:`pandas.BooleanArray`
>>> pd.array(["a", None, "c"], dtype="string") == "a"
<BooleanArray>
[True, NA, False]
Length: 3, dtype: boolean
"""

# undo the PandasArray hack
Expand Down Expand Up @@ -255,7 +266,12 @@ def value_counts(self, dropna=False):
# Overrride parent because we have different return types.
@classmethod
def _create_arithmetic_method(cls, op):
# Note: this handles both arithmetic and comparison methods.
def method(self, other):
from pandas.arrays import BooleanArray

assert op.__name__ in ops.ARITHMETIC_BINOPS | ops.COMPARISON_BINOPS

if isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame)):
return NotImplemented

Expand All @@ -275,15 +291,16 @@ def method(self, other):
other = np.asarray(other)
other = other[valid]

result = np.empty_like(self._ndarray, dtype="object")
result[mask] = StringDtype.na_value
result[valid] = op(self._ndarray[valid], other)

if op.__name__ in {"add", "radd", "mul", "rmul"}:
if op.__name__ in ops.ARITHMETIC_BINOPS:
result = np.empty_like(self._ndarray, dtype="object")
result[mask] = StringDtype.na_value
result[valid] = op(self._ndarray[valid], other)
return StringArray(result)
else:
dtype = "object" if mask.any() else "bool"
return np.asarray(result, dtype=dtype)
# logical
result = np.zeros(len(self._ndarray), dtype="bool")
result[valid] = op(self._ndarray[valid], other)
return BooleanArray(result, mask)

return compat.set_function_name(method, f"__{op.__name__}__", cls)

Expand Down
33 changes: 32 additions & 1 deletion pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import datetime
import operator
from typing import Tuple, Union
from typing import Set, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -59,6 +59,37 @@
rxor,
)

# -----------------------------------------------------------------------------
# constants
ARITHMETIC_BINOPS: Set[str] = {
"add",
"sub",
"mul",
"pow",
"mod",
"floordiv",
"truediv",
"divmod",
"radd",
"rsub",
"rmul",
"rpow",
"rmod",
"rfloordiv",
"rtruediv",
"rdivmod",
}


COMPARISON_BINOPS: Set[str] = {
"eq",
"ne",
"lt",
"gt",
"le",
"ge",
}

# -----------------------------------------------------------------------------
# Ops Wrapping Utilities

Expand Down
33 changes: 33 additions & 0 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,39 @@ def test_add_frame():
tm.assert_frame_equal(result, expected)


def test_comparison_methods_scalar(all_compare_operators):
op_name = all_compare_operators

a = pd.array(["a", None, "c"], dtype="string")
other = "a"
result = getattr(a, op_name)(other)
expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object)
expected[1] = None
expected = pd.array(expected, dtype="boolean")
tm.assert_extension_array_equal(result, expected)

result = getattr(a, op_name)(pd.NA)
expected = pd.array([None, None, None], dtype="boolean")
tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_array(all_compare_operators):
op_name = all_compare_operators

a = pd.array(["a", None, "c"], dtype="string")
other = [None, None, "c"]
result = getattr(a, op_name)(other)
expected = np.empty_like(a, dtype="object")
expected[:2] = None
expected[-1] = getattr(other[-1], op_name)(a[-1])
expected = pd.array(expected, dtype="boolean")
tm.assert_extension_array_equal(result, expected)

result = getattr(a, op_name)(pd.NA)
expected = pd.array([None, None, None], dtype="boolean")
tm.assert_extension_array_equal(result, expected)


def test_constructor_raises():
with pytest.raises(ValueError, match="sequence of strings"):
pd.arrays.StringArray(np.array(["a", "b"], dtype="S1"))
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TestCasting(base.BaseCastingTests):
class TestComparisonOps(base.BaseComparisonOpsTests):
def _compare_other(self, s, data, op_name, other):
result = getattr(s, op_name)(other)
expected = getattr(s.astype(object), op_name)(other)
expected = getattr(s.astype(object), op_name)(other).astype("boolean")
self.assert_series_equal(result, expected)

def test_compare_scalar(self, data, all_compare_operators):
Expand Down

0 comments on commit c548c67

Please sign in to comment.