GH-38325: [Python] Implement PyCapsule interface for Device data in PyArrow (#40717)

### Rationale for this change

PyArrow implementation for the specification additions being proposed in
#40708

### What changes are included in this PR?

New `__arrow_c_device_array__` method to `pyarrow.Array` and
`pyarrow.RecordBatch`, and support in the `pyarrow.array(..)`,
`pyarrow.record_batch(..)` and `pyarrow.table(..)` functions to consume
objects that have those methods.
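
As a consumer-side illustration (a minimal sketch, assuming CPU-backed data; the `DeviceWrapper` class is a hypothetical stand-in for a third-party producer object):

```python
import pyarrow as pa


class DeviceWrapper:
    # Hypothetical third-party object that only exposes the device protocol
    def __init__(self, data):
        self.data = data

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        return self.data.__arrow_c_device_array__(requested_schema, **kwargs)


# pa.array() now detects __arrow_c_device_array__ and imports through it
result = pa.array(DeviceWrapper(pa.array([1, 2, 3])))
assert result == pa.array([1, 2, 3])
```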

### Are these changes tested?

Yes (for CPU only for now, #40385 is
a prerequisite to test this for CUDA)


* GitHub Issue: #38325
jorisvandenbossche committed Jun 26, 2024
1 parent 5d58fc6 commit 1815a67
Showing 7 changed files with 364 additions and 32 deletions.
98 changes: 97 additions & 1 deletion python/pyarrow/array.pxi
@@ -129,7 +129,8 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
If both type and size are specified, it may be a single-use iterable. If
not strongly-typed, the Arrow type will be inferred for the resulting array.
Any Arrow-compatible array that implements the Arrow PyCapsule Protocol
(has an ``__arrow_c_array__`` method) can be passed as well.
(has an ``__arrow_c_array__`` or ``__arrow_c_device_array__`` method)
can be passed as well.
type : pyarrow.DataType
Explicit type to attempt to coerce to, otherwise will be inferred from
the data.
@@ -245,6 +246,18 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,

if hasattr(obj, '__arrow_array__'):
return _handle_arrow_array_protocol(obj, type, mask, size)
elif hasattr(obj, '__arrow_c_device_array__'):
if type is not None:
requested_type = type.__arrow_c_schema__()
else:
requested_type = None
schema_capsule, array_capsule = obj.__arrow_c_device_array__(requested_type)
out_array = Array._import_from_c_device_capsule(schema_capsule, array_capsule)
if type is not None and out_array.type != type:
# PyCapsule interface type coercion is best effort, so we need to
# check the type of the returned array and cast if necessary
out_array = out_array.cast(type, safe=safe, memory_pool=memory_pool)
return out_array
elif hasattr(obj, '__arrow_c_array__'):
if type is not None:
requested_type = type.__arrow_c_schema__()
@@ -1880,6 +1893,89 @@ cdef class Array(_PandasConvertible):
)
return pyarrow_wrap_array(c_array)

def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
"""
Get a pair of PyCapsules containing a C ArrowDeviceArray representation
of the object.

Parameters
----------
requested_schema : PyCapsule | None
A PyCapsule containing a C ArrowSchema representation of a requested
schema. PyArrow will attempt to cast the array to this data type.
If None, the array will be returned as-is, with a type matching the
one returned by :meth:`__arrow_c_schema__()`.
kwargs
Currently no additional keyword arguments are supported, but
this method will accept any keyword with a value of ``None``
for compatibility with future keywords.

Returns
-------
Tuple[PyCapsule, PyCapsule]
A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
respectively.
"""
cdef:
ArrowDeviceArray* c_array
ArrowSchema* c_schema
shared_ptr[CArray] inner_array

non_default_kwargs = [
name for name, value in kwargs.items() if value is not None
]
if non_default_kwargs:
raise NotImplementedError(
f"Received unsupported keyword argument(s): {non_default_kwargs}"
)

if requested_schema is not None:
target_type = DataType._import_from_c_capsule(requested_schema)

if target_type != self.type:
if not self.is_cpu:
raise NotImplementedError(
"Casting to a requested schema is only supported for CPU data"
)
try:
casted_array = _pc().cast(self, target_type, safe=True)
inner_array = pyarrow_unwrap_array(casted_array)
except ArrowInvalid as e:
raise ValueError(
f"Could not cast {self.type} to requested type {target_type}: {e}"
)
else:
inner_array = self.sp_array
else:
inner_array = self.sp_array

schema_capsule = alloc_c_schema(&c_schema)
array_capsule = alloc_c_device_array(&c_array)

with nogil:
check_status(ExportDeviceArray(
deref(inner_array), <shared_ptr[CSyncEvent]>NULL,
c_array, c_schema))

return schema_capsule, array_capsule

@staticmethod
def _import_from_c_device_capsule(schema_capsule, array_capsule):
cdef:
ArrowSchema* c_schema
ArrowDeviceArray* c_array
shared_ptr[CArray] array

c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
array_capsule, 'arrow_device_array'
)

with nogil:
array = GetResultValue(ImportDeviceArray(c_array, c_schema))

return pyarrow_wrap_array(array)

def __dlpack__(self, stream=None):
"""Export a primitive array as a DLPack capsule.
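
A usage sketch for the export path added above (CPU data; note that each capsule can be consumed only once):

```python
import pyarrow as pa

arr = pa.array([1.0, 2.0, None])

# Export as an ("arrow_schema", "arrow_device_array") PyCapsule pair
schema_capsule, array_capsule = arr.__arrow_c_device_array__()

# Re-import through the private helper added in this PR
roundtripped = pa.Array._import_from_c_device_capsule(schema_capsule, array_capsule)
assert roundtripped.equals(arr)
```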
2 changes: 2 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -1016,6 +1016,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
int num_columns()
int64_t num_rows()

CDeviceAllocationType device_type()

CStatus Validate() const
CStatus ValidateFull() const

146 changes: 136 additions & 10 deletions python/pyarrow/table.pxi
@@ -3639,7 +3639,7 @@ cdef class RecordBatch(_Tabular):
requested_schema : PyCapsule | None
A PyCapsule containing a C ArrowSchema representation of a requested
schema. PyArrow will attempt to cast the batch to this schema.
If None, the schema will be returned as-is, with a schema matching the
If None, the batch will be returned as-is, with a schema matching the
one returned by :meth:`__arrow_c_schema__()`.
Returns
@@ -3657,9 +3657,7 @@ cdef class RecordBatch(_Tabular):

if target_schema != self.schema:
try:
# We don't expose .cast() on RecordBatch, only on Table.
casted_batch = Table.from_batches([self]).cast(
target_schema, safe=True).to_batches()[0]
casted_batch = self.cast(target_schema, safe=True)
inner_batch = pyarrow_unwrap_batch(casted_batch)
except ArrowInvalid as e:
raise ValueError(
@@ -3700,8 +3698,8 @@ cdef class RecordBatch(_Tabular):
@staticmethod
def _import_from_c_capsule(schema_capsule, array_capsule):
"""
Import RecordBatch from a pair of PyCapsules containing a C ArrowArray
and ArrowSchema, respectively.
Import RecordBatch from a pair of PyCapsules containing a C ArrowSchema
and ArrowArray, respectively.

Parameters
----------
@@ -3792,6 +3790,121 @@ cdef class RecordBatch(_Tabular):
c_device_array, c_schema))
return pyarrow_wrap_batch(c_batch)

def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
"""
Get a pair of PyCapsules containing a C ArrowDeviceArray representation
of the object.

Parameters
----------
requested_schema : PyCapsule | None
A PyCapsule containing a C ArrowSchema representation of a requested
schema. PyArrow will attempt to cast the batch to this data type.
If None, the batch will be returned as-is, with a type matching the
one returned by :meth:`__arrow_c_schema__()`.
kwargs
Currently no additional keyword arguments are supported, but
this method will accept any keyword with a value of ``None``
for compatibility with future keywords.

Returns
-------
Tuple[PyCapsule, PyCapsule]
A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
respectively.
"""
cdef:
ArrowDeviceArray* c_array
ArrowSchema* c_schema
shared_ptr[CRecordBatch] inner_batch

non_default_kwargs = [
name for name, value in kwargs.items() if value is not None
]
if non_default_kwargs:
raise NotImplementedError(
f"Received unsupported keyword argument(s): {non_default_kwargs}"
)

if requested_schema is not None:
target_schema = Schema._import_from_c_capsule(requested_schema)

if target_schema != self.schema:
if not self.is_cpu:
raise NotImplementedError(
"Casting to a requested schema is only supported for CPU data"
)
try:
casted_batch = self.cast(target_schema, safe=True)
inner_batch = pyarrow_unwrap_batch(casted_batch)
except ArrowInvalid as e:
raise ValueError(
f"Could not cast {self.schema} to requested schema {target_schema}: {e}"
)
else:
inner_batch = self.sp_batch
else:
inner_batch = self.sp_batch

schema_capsule = alloc_c_schema(&c_schema)
array_capsule = alloc_c_device_array(&c_array)

with nogil:
check_status(ExportDeviceRecordBatch(
deref(inner_batch), <shared_ptr[CSyncEvent]>NULL, c_array, c_schema))

return schema_capsule, array_capsule

@staticmethod
def _import_from_c_device_capsule(schema_capsule, array_capsule):
"""
Import RecordBatch from a pair of PyCapsules containing a
C ArrowSchema and ArrowDeviceArray, respectively.

Parameters
----------
schema_capsule : PyCapsule
A PyCapsule containing a C ArrowSchema representation of the schema.
array_capsule : PyCapsule
A PyCapsule containing a C ArrowDeviceArray representation of the array.

Returns
-------
pyarrow.RecordBatch
"""
cdef:
ArrowSchema* c_schema
ArrowDeviceArray* c_array
shared_ptr[CRecordBatch] batch

c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
array_capsule, 'arrow_device_array'
)

with nogil:
batch = GetResultValue(ImportDeviceRecordBatch(c_array, c_schema))

return pyarrow_wrap_batch(batch)

@property
def device_type(self):
"""
The device type where the arrays in the RecordBatch reside.

Returns
-------
DeviceAllocationType
"""
return _wrap_device_allocation_type(self.sp_batch.get().device_type())

@property
def is_cpu(self):
"""
Whether the RecordBatch's arrays are CPU-accessible.
"""
return self.device_type == DeviceAllocationType.CPU


def _reconstruct_record_batch(columns, schema):
"""
@@ -5584,7 +5697,8 @@ def record_batch(data, names=None, schema=None, metadata=None):
data : dict, list, pandas.DataFrame, Arrow-compatible table
A mapping of strings to Arrays or Python lists, a list of Arrays,
a pandas DataFrame, or any tabular object implementing the
Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` method).
Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
``__arrow_c_device_array__`` method).
names : list, default None
Column names if list of arrays passed as data. Mutually exclusive with
'schema' argument.
@@ -5718,6 +5832,18 @@ def record_batch(data, names=None, schema=None, metadata=None):
raise ValueError(
"The 'names' argument is not valid when passing a dictionary")
return RecordBatch.from_pydict(data, schema=schema, metadata=metadata)
elif hasattr(data, "__arrow_c_device_array__"):
if schema is not None:
requested_schema = schema.__arrow_c_schema__()
else:
requested_schema = None
schema_capsule, array_capsule = data.__arrow_c_device_array__(requested_schema)
batch = RecordBatch._import_from_c_device_capsule(schema_capsule, array_capsule)
if schema is not None and batch.schema != schema:
# __arrow_c_device_array__ coerces schema with best effort, so we might
# need to cast it if the producer wasn't able to cast to exact schema.
batch = batch.cast(schema)
return batch
elif hasattr(data, "__arrow_c_array__"):
if schema is not None:
requested_schema = schema.__arrow_c_schema__()
@@ -5747,8 +5873,8 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
data : dict, list, pandas.DataFrame, Arrow-compatible table
A mapping of strings to Arrays or Python lists, a list of arrays or
chunked arrays, a pandas DataFrame, or any tabular object implementing
the Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
``__arrow_c_stream__`` method).
the Arrow PyCapsule Protocol (has an ``__arrow_c_array__``,
``__arrow_c_device_array__`` or ``__arrow_c_stream__`` method).
names : list, default None
Column names if list of arrays passed as data. Mutually exclusive with
'schema' argument.
@@ -5888,7 +6014,7 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
# need to cast it if the producer wasn't able to cast to exact schema.
table = table.cast(schema)
return table
elif hasattr(data, "__arrow_c_array__"):
elif hasattr(data, "__arrow_c_array__") or hasattr(data, "__arrow_c_device_array__"):
if names is not None or metadata is not None:
raise ValueError(
"The 'names' and 'metadata' arguments are not valid when "
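
Putting the `table.pxi` changes together, a consumer-side sketch for `record_batch(..)` plus the new device introspection properties (CPU-only here; `BatchDeviceWrapper` is a hypothetical producer):

```python
import pyarrow as pa


class BatchDeviceWrapper:
    # Hypothetical tabular producer exposing only the device protocol
    def __init__(self, batch):
        self.batch = batch

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        return self.batch.__arrow_c_device_array__(requested_schema, **kwargs)


data = pa.record_batch({"a": [1, 2, 3]})
result = pa.record_batch(BatchDeviceWrapper(data))

assert result.device_type == pa.DeviceAllocationType.CPU
assert result.is_cpu
```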
39 changes: 32 additions & 7 deletions python/pyarrow/tests/test_array.py
@@ -3505,16 +3505,27 @@ def __arrow_array__(self, type=None):
assert result.equals(expected)


def test_c_array_protocol():
class ArrayWrapper:
def __init__(self, data):
self.data = data
class ArrayWrapper:
def __init__(self, data):
self.data = data

def __arrow_c_array__(self, requested_schema=None):
return self.data.__arrow_c_array__(requested_schema)


class ArrayDeviceWrapper:
def __init__(self, data):
self.data = data

def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
return self.data.__arrow_c_device_array__(requested_schema, **kwargs)

def __arrow_c_array__(self, requested_schema=None):
return self.data.__arrow_c_array__(requested_schema)

@pytest.mark.parametrize("wrapper_class", [ArrayWrapper, ArrayDeviceWrapper])
def test_c_array_protocol(wrapper_class):

# Can roundtrip through the C array protocol
arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
arr = wrapper_class(pa.array([1, 2, 3], type=pa.int64()))
result = pa.array(arr)
assert result == arr.data

Expand All @@ -3523,6 +3534,20 @@ def __arrow_c_array__(self, requested_schema=None):
assert result == pa.array([1, 2, 3], type=pa.int32())


def test_c_array_protocol_device_unsupported_keyword():
# For the device-aware version, we raise a specific error for unsupported keywords
arr = pa.array([1, 2, 3], type=pa.int64())

with pytest.raises(
NotImplementedError,
match=r"Received unsupported keyword argument\(s\): \['other'\]"
):
arr.__arrow_c_device_array__(other="not-none")

# but with None value it is ignored
_ = arr.__arrow_c_device_array__(other=None)


def test_concat_array():
concatenated = pa.concat_arrays(
[pa.array([1, 2]), pa.array([3, 4])])
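
The ``requested_schema`` negotiation exercised by these tests can also be driven directly; a sketch asking the producer to cast during export (CPU data):

```python
import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())

# Pass the requested type as an ArrowSchema capsule; the producer casts
capsules = arr.__arrow_c_device_array__(pa.int32().__arrow_c_schema__())
result = pa.Array._import_from_c_device_capsule(*capsules)
assert result.type == pa.int32()
```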