Skip to content

Commit

Permalink
ARROW-6176: [Python] Basic implementation of __arrow_ext_class__, in …
Browse files Browse the repository at this point in the history
…pure Python

This is a very basic implementation of what could be `__arrow_ext_class__`, i.e. allowing extension types in PyArrow to have a custom Extension Array class (useful for adding additional logic). It is an alternative to adding dynamic attributes to `ExtensionArray` (see ARROW-8131),

Of interest @kszucs @jorisvandenbossche

Closes #6653 from balancap/ARROW-6176-allow-sub-class-extension-array-to-attach-custom-extension-type

Authored-by: Paul Balanca <paul.balanca@gmail.com>
Signed-off-by: Wes McKinney <wesm+git@apache.org>
  • Loading branch information
balancap authored and wesm committed Apr 8, 2020
1 parent 90a47c0 commit 9422f4d
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 18 deletions.
62 changes: 62 additions & 0 deletions docs/source/python/extending_types.rst
Expand Up @@ -226,6 +226,68 @@ data type from above would look like::

Also the storage type does not need to be fixed but can be parametrized.

Custom extension array class
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

By default, all arrays with an extension type are constructed or deserialized into
a built-in :class:`ExtensionArray` object. Nevertheless, one could want to subclass
:class:`ExtensionArray` in order to add some custom logic specific to the extension
type. Arrow allows to do so by adding a special method ``__arrow_ext_class__`` to the
definition of the extension type.

For instance, let us consider the example from the `Numpy Quickstart <https://docs.scipy.org/doc/numpy-1.13.0/user/quickstart.html>`_ of points in 3D space.
We can store these as a fixed-size list, where we wish to be able to extract
the data as a 2-D Numpy array ``(N, 3)`` without any copy::

class Point3DArray(pa.ExtensionArray):
def to_numpy_array(self):
return arr.storage.flatten().to_numpy().reshape((-1, 3))


class Point3DType(pa.PyExtensionType):
def __init__(self):
pa.PyExtensionType.__init__(self, pa.list_(pa.float32(), 3))

def __reduce__(self):
return Point3DType, ()

def __arrow_ext_class__(self):
return Point3DArray

Arrays built using this extension type now have the expected custom array class::

>>> storage = pa.array([[1, 2, 3], [4, 5, 6]], pa.list_(pa.float32(), 3))
>>> arr = pa.ExtensionArray.from_storage(Point3DType(), storage)
>>> arr
<__main__.Point3DArray object at 0x7f40dea80670>
[
[
1,
2,
3
],
[
4,
5,
6
]
]

The additional methods in the extension class are then available to the user::

>>> arr.to_numpy_array()
array([[1., 2., 3.],
[4., 5., 6.]], dtype=float32)


This array can be sent over IPC, received in another Python process, and the custom
extension array class will be preserved (as long as the definitions of the classes above
are available).

The same ``__arrow_ext_class__`` specialization can be used with custom types defined
by subclassing :class:`ExtensionType`.


Conversion to pandas
~~~~~~~~~~~~~~~~~~~~

Expand Down
13 changes: 13 additions & 0 deletions python/pyarrow/array.pxi
Expand Up @@ -2173,6 +2173,19 @@ cdef dict _array_classes = {
}


cdef object get_array_class_from_type(
const shared_ptr[CDataType]& sp_data_type):
cdef CDataType* data_type = sp_data_type.get()
if data_type == NULL:
raise ValueError('Array data type was NULL')

if data_type.id() == _Type_EXTENSION:
py_ext_data_type = pyarrow_wrap_data_type(sp_data_type)
return py_ext_data_type.__arrow_ext_class__()
else:
return _array_classes[data_type.id()]


cdef object get_values(object obj, bint* is_series):
if pandas_api.is_series(obj) or pandas_api.is_index(obj):
result = pandas_api.get_values(obj)
Expand Down
7 changes: 1 addition & 6 deletions python/pyarrow/public-api.pxi
Expand Up @@ -193,12 +193,7 @@ cdef api object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array):
if sp_array.get() == NULL:
raise ValueError('Array was NULL')

cdef CDataType* data_type = sp_array.get().type().get()

if data_type == NULL:
raise ValueError('Array data type was NULL')

klass = _array_classes[data_type.id()]
klass = get_array_class_from_type(sp_array.get().type())

cdef Array arr = klass.__new__(klass)
arr.init(sp_array)
Expand Down
50 changes: 38 additions & 12 deletions python/pyarrow/tests/test_extension_type.py
Expand Up @@ -269,8 +269,11 @@ def test_ipc_unknown_type():
assert arr.type == ParamExtType(3)


class PeriodType(pa.ExtensionType):
class PeriodArray(pa.ExtensionArray):
pass


class PeriodType(pa.ExtensionType):
def __init__(self, freq):
# attributes need to be set first before calling
# super init (as that calls serialize)
Expand Down Expand Up @@ -299,12 +302,27 @@ def __eq__(self, other):
return NotImplemented


@pytest.fixture
def registered_period_type():
class PeriodTypeWithClass(PeriodType):
def __init__(self, freq):
PeriodType.__init__(self, freq)

def __arrow_ext_class__(self):
return PeriodArray

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
freq = PeriodType.__arrow_ext_deserialize__(
storage_type, serialized).freq
return PeriodTypeWithClass(freq)


@pytest.fixture(params=[PeriodType('D'), PeriodTypeWithClass('D')])
def registered_period_type(request):
# setup
period_type = PeriodType('D')
period_type = request.param
period_class = period_type.__arrow_ext_class__()
pa.register_extension_type(period_type)
yield
yield period_type, period_class
# teardown
try:
pa.unregister_extension_type('pandas.period')
Expand All @@ -316,30 +334,35 @@ def test_generic_ext_type():
period_type = PeriodType('D')
assert period_type.extension_name == "pandas.period"
assert period_type.storage_type == pa.int64()
# default ext_class expected.
assert period_type.__arrow_ext_class__() == pa.ExtensionArray


def test_generic_ext_type_ipc(registered_period_type):
period_type = PeriodType('D')
period_type, period_class = registered_period_type
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
# check the built array has exactly the expected clss
assert type(arr) == period_class

buf = ipc_write_batch(batch)
del batch
batch = ipc_read_batch(buf)

result = batch.column(0)
assert isinstance(result, pa.ExtensionArray)
# check the deserialized array class is the expected one
assert type(result) == period_class
assert result.type.extension_name == "pandas.period"
assert arr.storage.to_pylist() == [1, 2, 3, 4]

# we get back an actual PeriodType
assert isinstance(result.type, PeriodType)
assert result.type.freq == 'D'
assert result.type == PeriodType('D')
assert result.type == period_type

# using different parametrization as how it was registered
period_type_H = PeriodType('H')
period_type_H = period_type.__class__('H')
assert period_type_H.extension_name == "pandas.period"
assert period_type_H.freq == 'H'

Expand All @@ -352,11 +375,11 @@ def test_generic_ext_type_ipc(registered_period_type):
result = batch.column(0)
assert isinstance(result.type, PeriodType)
assert result.type.freq == 'H'
assert result.type == PeriodType('H')
assert type(result) == period_class


def test_generic_ext_type_ipc_unknown(registered_period_type):
period_type = PeriodType('D')
period_type, _ = registered_period_type
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
batch = pa.RecordBatch.from_arrays([arr], ["ext"])
Expand Down Expand Up @@ -403,7 +426,7 @@ def test_generic_ext_type_register(registered_period_type):
@pytest.mark.parquet
def test_parquet(tmpdir, registered_period_type):
# parquet support for extension types
period_type = PeriodType('D')
period_type, period_class = registered_period_type
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
table = pa.table([arr], names=["ext"])
Expand All @@ -429,6 +452,9 @@ def test_parquet(tmpdir, registered_period_type):
# when reading in, properly create extension type if it is registered
result = pq.read_table(filename)
assert result.column("ext").type == period_type
# get the exact array class defined by the registered type.
result_array = result.column("ext").chunk(0)
assert type(result_array) == period_class

# when the type is not registered, read in as storage type
pa.unregister_extension_type(period_type.extension_name)
Expand Down
10 changes: 10 additions & 0 deletions python/pyarrow/types.pxi
Expand Up @@ -718,6 +718,16 @@ cdef class ExtensionType(BaseExtensionType):
"""
return NotImplementedError

def __arrow_ext_class__(self):
"""Return an extension array class to be used for building or
deserializing arrays with this extension type.
This method should return a subclass of the ExtensionArray class. By
default, if not specialized in the extension implementation, an
extension type array will be a built-in ExtensionArray instance.
"""
return ExtensionArray


cdef class PyExtensionType(ExtensionType):
"""
Expand Down

0 comments on commit 9422f4d

Please sign in to comment.