Skip to content

Commit

Permalink
remove inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Sep 27, 2023
1 parent 04ba15f commit 446606c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 47 deletions.
14 changes: 5 additions & 9 deletions docs/source/format/CDataInterface/PyCapsuleInterface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,8 @@ Export Protocol
The interface is three separate protocols:

* ``ArrowSchemaExportable``, which defines the ``__arrow_c_schema__`` method.
* ``ArrowArrayExportable``, which defines the ``__arrow_c_array__``. It extends
``ArrowSchemaExportable``, so it must also define ``__arrow_c_schema__``.
* ``ArrowStreamExportable``, which defines the ``__arrow_c_stream__``. It extends
``ArrowSchemaExportable``, so it must also define ``__arrow_c_schema__``.
* ``ArrowArrayExportable``, which defines the ``__arrow_c_array__``.
* ``ArrowStreamExportable``, which defines the ``__arrow_c_stream__``.

The protocols are defined below in terms of ``typing.Protocol``. These may be
copied into a library for the purposes of static type checking, but this is not
Expand All @@ -139,7 +137,7 @@ required to implement the protocol.
"""
...
class ArrowArrayExportable(ArrowSchemaExportable, Protocol):
class ArrowArrayExportable(Protocol):
def __arrow_c_array__(
self,
requested_schema: object | None = None
Expand Down Expand Up @@ -173,7 +171,7 @@ required to implement the protocol.
...
class ArrowStreamExportable(ArrowSchemaExportable, Protocol):
class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""
Get a PyCapsule containing a C ArrowArrayStream representation of the object.
Expand Down Expand Up @@ -220,9 +218,7 @@ In order to allow the caller to request a specific representation, the

The callee should attempt to provide the data in the requested schema. However,
if the callee cannot provide the data in the requested schema, they may return
with the schema as provided by the ``__arrow_c_schema__`` method. Similarly, if
no schema is requested, the callee _must_ return with the schema as provided by
the ``__arrow_c_schema__`` method.
with the same schema as if ``None`` were passed to ``requested_schema``.

If the caller requests a schema that is not compatible with the data,
say requesting a schema with a different number of fields, the callee should
Expand Down
6 changes: 0 additions & 6 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1702,12 +1702,6 @@ cdef class Array(_PandasConvertible):
c_type))
return pyarrow_wrap_array(c_array)

def __arrow_c_schema__(self):
    """
    Export the array's data type as an ArrowSchema PyCapsule.

    Delegates to the ``__arrow_c_schema__`` implementation of the
    array's DataType object.
    """
    dtype = self.type
    return dtype.__arrow_c_schema__()

def __arrow_c_array__(self, requested_schema=None):
"""
Get a pair of PyCapsules containing a C ArrowArray representation of the object.
Expand Down
6 changes: 0 additions & 6 deletions python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -817,12 +817,6 @@ cdef class RecordBatchReader(_Weakrefable):
self.reader = c_reader
return self

def __arrow_c_schema__(self):
    """
    Export the schema to a C ArrowSchema PyCapsule.

    Forwards to the reader's Schema object, which owns the actual
    capsule export.
    """
    schema = self.schema
    return schema.__arrow_c_schema__()

def __arrow_c_stream__(self, requested_schema=None):
"""
Export to a C ArrowArrayStream PyCapsule.
Expand Down
22 changes: 16 additions & 6 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1916,12 +1916,6 @@ cdef class _Tabular(_PandasConvertible):
def schema(self):
raise NotImplementedError

def __arrow_c_schema__(self):
    """
    Export an ArrowSchema PyCapsule with the table's schema.

    The heavy lifting is done by the Schema object's own capsule
    export; this wrapper just forwards to it.
    """
    return (self.schema).__arrow_c_schema__()

def sort_by(self, sorting, **kwargs):
"""
Sort the Table or RecordBatch by one or multiple columns.
Expand Down Expand Up @@ -3040,6 +3034,22 @@ cdef class RecordBatch(_Tabular):

return schema_capsule, array_capsule

def __arrow_c_stream__(self, requested_schema=None):
    """
    Export the batch as an Arrow C stream PyCapsule.

    Parameters
    ----------
    requested_schema : pyarrow.lib.Schema, default None
        A schema to attempt to cast the streamed data to. This is currently
        unsupported and will raise an error.

    Returns
    -------
    PyCapsule
    """
    # Wrap this single batch in a one-element Table and reuse the
    # Table's stream export rather than duplicating that logic here.
    as_table = Table.from_batches([self])
    return as_table.__arrow_c_stream__(requested_schema)

@staticmethod
def _import_from_c_capsule(schema_capsule, array_capsule):
"""
Expand Down
32 changes: 12 additions & 20 deletions python/pyarrow/tests/test_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,22 +450,13 @@ def test_roundtrip_array_capsule(arr, schema_accessor, bad_type, good_type):
old_allocated = pa.total_allocated_bytes()

import_array = type(arr)._import_from_c_capsule
import_schema = type(schema_accessor(arr))._import_from_c_capsule

schema_capsule = arr.__arrow_c_schema__()
assert pycapi.PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
schema_out = import_schema(schema_capsule)
assert schema_out == schema_accessor(arr)

schema_capsule, capsule = arr.__arrow_c_array__()
assert pycapi.PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
assert pycapi.PyCapsule_IsValid(capsule, b"arrow_array") == 1
arr_out = import_array(schema_capsule, capsule)
assert arr_out.equals(arr)

schema_out_2 = import_schema(arr.__arrow_c_schema__())
assert schema_out_2 == schema_accessor(arr)

assert pa.total_allocated_bytes() > old_allocated
del arr_out

Expand Down Expand Up @@ -502,11 +493,6 @@ def test_roundtrip_reader_capsule(constructor):

obj = constructor(schema, batches)

schema_capsule = obj.__arrow_c_schema__()
assert pycapi.PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
schema_out = pa.Schema._import_from_c_capsule(schema_capsule)
assert schema_out == obj.schema

capsule = obj.__arrow_c_stream__()
assert pycapi.PyCapsule_IsValid(capsule, b"arrow_array_stream") == 1
imported_reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
Expand All @@ -520,12 +506,6 @@ def test_roundtrip_reader_capsule(constructor):

obj = constructor(schema, batches)

schema_capsule = obj.__arrow_c_schema__()

assert pa.total_allocated_bytes() > old_allocated
del schema_capsule
assert pa.total_allocated_bytes() == old_allocated

# TODO: turn this to ValueError once we implement validation.
bad_schema = pa.schema({'ints': pa.int32()})
with pytest.raises(NotImplementedError):
Expand All @@ -538,3 +518,15 @@ def test_roundtrip_reader_capsule(constructor):
assert imported_reader.schema == matching_schema
for batch, expected in zip(imported_reader, batches):
assert batch.equals(expected)


def test_roundtrip_batch_reader_capsule():
    # A single RecordBatch exported via __arrow_c_stream__ should round-trip
    # through RecordBatchReader._import_from_c_capsule unchanged.
    batch = make_batch()

    stream_capsule = batch.__arrow_c_stream__()
    assert pycapi.PyCapsule_IsValid(stream_capsule, b"arrow_array_stream") == 1

    reader = pa.RecordBatchReader._import_from_c_capsule(stream_capsule)
    assert reader.schema == batch.schema
    assert reader.read_next_batch().equals(batch)
    # The stream wraps exactly one batch, so a second read must exhaust it.
    with pytest.raises(StopIteration):
        reader.read_next_batch()

0 comments on commit 446606c

Please sign in to comment.