Skip to content

Commit

Permalink
API: ExtensionDtype Equality and Hashability
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Oct 4, 2018
1 parent d430195 commit ed56aa3
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 17 deletions.
9 changes: 9 additions & 0 deletions doc/source/whatsnew/v0.24.0.txt
Expand Up @@ -492,6 +492,15 @@ Previous Behavior:
ExtensionType Changes
^^^^^^^^^^^^^^^^^^^^^

**:class:`pandas.api.extensions.ExtensionDtype` Equality and Hashability**

Pandas now requires that extension dtypes be hashable. The base class implements
a default ``__eq__`` and ``__hash__``. If you have a parametrized dtype, you should
update the ``ExtensionDtype._metadata`` tuple to match the signature of your
``__init__`` method. See :class:`pandas.api.extensions.ExtensionDtype` for more.

**Other changes**

- ``ExtensionArray`` has gained the abstract methods ``.dropna()`` (:issue:`21185`)
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)
Expand Down
45 changes: 38 additions & 7 deletions pandas/core/dtypes/base.py
Expand Up @@ -22,14 +22,17 @@ class _DtypeOpsMixin(object):
# of the NA value, not the physical NA vaalue for storage.
# e.g. for JSONArray, this is an empty dictionary.
na_value = np.nan
_metadata = ()

def __eq__(self, other):
"""Check whether 'other' is equal to self.
By default, 'other' is considered equal if
By default, 'other' is considered equal if either
* it's a string matching 'self.name'.
* it's an instance of this type.
* it's an instance of this type and all of the
the attributes in ``self._metadata`` are equal between
`self` and `other`.
Parameters
----------
Expand All @@ -40,11 +43,19 @@ def __eq__(self, other):
bool
"""
if isinstance(other, compat.string_types):
return other == self.name
elif isinstance(other, type(self)):
return True
else:
return False
try:
other = self.construct_from_string(other)
except TypeError:
return False
if isinstance(other, type(self)):
return all(
getattr(self, attr) == getattr(other, attr)
for attr in self._metadata
)
return False

def __hash__(self):
return hash(tuple(getattr(self, attr) for attr in self._metadata))

def __ne__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -161,6 +172,26 @@ class ExtensionDtype(_DtypeOpsMixin):
The `na_value` class attribute can be used to set the default NA value
for this type. :attr:`numpy.nan` is used by default.
ExtensionDtypes are required to be hashable. The base class provides
a default implementation, which relies on the ``_metadata`` class
attribute. ``_metadata`` should be a tuple containing the strings
that define your data type. For example, with ``PeriodDtype`` that's
the ``freq`` attribute.
**If you have a parametrized dtype you should set the ``_metadata``
class property**.
Ideally, the attributes in ``_metadata`` will match the
parameters to your ``ExtensionDtype.__init__`` (if any). If any of
the attributes in ``_metadata`` don't implement the standard
``__eq__`` or ``__hash__``, the default implementations here will not
work.
.. versionchanged:: 0.24.0
Added ``_metadata``, ``__hash__``, and changed the default definition
of ``__eq__``.
This class does not inherit from 'abc.ABCMeta' for performance reasons.
Methods and properties required by the interface raise
``pandas.errors.AbstractMethodError`` and no ``register`` method is
Expand Down
9 changes: 4 additions & 5 deletions pandas/core/dtypes/dtypes.py
Expand Up @@ -101,7 +101,6 @@ class PandasExtensionDtype(_DtypeOpsMixin):
base = None
isbuiltin = 0
isnative = 0
_metadata = []
_cache = {}

def __unicode__(self):
Expand Down Expand Up @@ -209,7 +208,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
kind = 'O'
str = '|O08'
base = np.dtype('O')
_metadata = ['categories', 'ordered']
_metadata = ('categories', 'ordered')
_cache = {}

def __init__(self, categories=None, ordered=None):
Expand Down Expand Up @@ -485,7 +484,7 @@ class DatetimeTZDtype(PandasExtensionDtype):
str = '|M8[ns]'
num = 101
base = np.dtype('M8[ns]')
_metadata = ['unit', 'tz']
_metadata = ('unit', 'tz')
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
_cache = {}

Expand Down Expand Up @@ -589,7 +588,7 @@ class PeriodDtype(PandasExtensionDtype):
str = '|O08'
base = np.dtype('O')
num = 102
_metadata = ['freq']
_metadata = ('freq',)
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
_cache = {}

Expand Down Expand Up @@ -709,7 +708,7 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
str = '|O08'
base = np.dtype('O')
num = 103
_metadata = ['subtype']
_metadata = ('subtype',)
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
_cache = {}

Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/extension/base/dtype.py
Expand Up @@ -49,6 +49,10 @@ def test_eq_with_str(self, dtype):
def test_eq_with_numpy_object(self, dtype):
assert dtype != np.dtype('object')

def test_eq_with_self(self, dtype):
assert dtype == dtype
assert dtype != object()

def test_array_type(self, data, dtype):
assert dtype.construct_array_type() is type(data)

Expand Down Expand Up @@ -81,3 +85,6 @@ def test_check_dtype(self, data):
index=list('ABCD'))
result = df.dtypes.apply(str) == str(dtype)
self.assert_series_equal(result, expected)

def test_hashable(self, dtype):
hash(dtype) # no error
6 changes: 1 addition & 5 deletions pandas/tests/extension/decimal/array.py
Expand Up @@ -15,15 +15,11 @@ class DecimalDtype(ExtensionDtype):
type = decimal.Decimal
name = 'decimal'
na_value = decimal.Decimal('NaN')
_metadata = ('context',)

def __init__(self, context=None):
self.context = context or decimal.getcontext()

def __eq__(self, other):
if isinstance(other, type(self)):
return self.context == other.context
return super(DecimalDtype, self).__eq__(other)

def __repr__(self):
return 'DecimalDtype(context={})'.format(self.context)

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/extension/json/array.py
Expand Up @@ -27,6 +27,7 @@
class JSONDtype(ExtensionDtype):
type = compat.Mapping
name = 'json'

try:
na_value = collections.UserDict()
except AttributeError:
Expand Down

0 comments on commit ed56aa3

Please sign in to comment.