From 1815a679e47a47f2f5f4cd003ffbb18b6db2f952 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 26 Jun 2024 17:41:17 +0200
Subject: [PATCH] GH-38325: [Python] Implement PyCapsule interface for Device
 data in PyArrow (#40717)

### Rationale for this change

PyArrow implementation for the specification additions being proposed in
https://github.com/apache/arrow/pull/40708

### What changes are included in this PR?

A new `__arrow_c_device_array__` method on `pyarrow.Array` and
`pyarrow.RecordBatch`, and support in the `pyarrow.array(..)`,
`pyarrow.record_batch(..)` and `pyarrow.table(..)` functions for consuming
objects that implement those methods.

### Are these changes tested?

Yes (for CPU only for now, https://github.com/apache/arrow/pull/40385 is a
prerequisite to test this for CUDA)

* GitHub Issue: #38325
---
 python/pyarrow/array.pxi             |  98 +++++++++++++++++-
 python/pyarrow/includes/libarrow.pxd |   2 +
 python/pyarrow/table.pxi             | 146 +++++++++++++++++++++++++--
 python/pyarrow/tests/test_array.py   |  39 +++++--
 python/pyarrow/tests/test_cffi.py    |  42 ++++++++
 python/pyarrow/tests/test_table.py   |  48 ++++++---
 python/pyarrow/types.pxi             |  21 ++++
 7 files changed, 364 insertions(+), 32 deletions(-)

diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index daf4adc33e558..b1f90cd16537b 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -129,7 +129,8 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
         If both type and size are specified may be a single use iterable. If
         not strongly-typed, Arrow type will be inferred for resulting array.
         Any Arrow-compatible array that implements the Arrow PyCapsule Protocol
-        (has an ``__arrow_c_array__`` method) can be passed as well.
+        (has an ``__arrow_c_array__`` or ``__arrow_c_device_array__`` method)
+        can be passed as well.
     type : pyarrow.DataType
         Explicit type to attempt to coerce to, otherwise will be inferred from
         the data.
@@ -245,6 +246,18 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
 
     if hasattr(obj, '__arrow_array__'):
         return _handle_arrow_array_protocol(obj, type, mask, size)
+    elif hasattr(obj, '__arrow_c_device_array__'):
+        if type is not None:
+            requested_type = type.__arrow_c_schema__()
+        else:
+            requested_type = None
+        schema_capsule, array_capsule = obj.__arrow_c_device_array__(requested_type)
+        out_array = Array._import_from_c_device_capsule(schema_capsule, array_capsule)
+        if type is not None and out_array.type != type:
+            # PyCapsule interface type coercion is best effort, so we need to
+            # check the type of the returned array and cast if necessary
+            out_array = out_array.cast(type, safe=safe, memory_pool=memory_pool)
+        return out_array
     elif hasattr(obj, '__arrow_c_array__'):
         if type is not None:
             requested_type = type.__arrow_c_schema__()
@@ -1880,6 +1893,89 @@ cdef class Array(_PandasConvertible):
         )
         return pyarrow_wrap_array(c_array)
 
+    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
+        """
+        Get a pair of PyCapsules containing a C ArrowDeviceArray representation
+        of the object.
+
+        Parameters
+        ----------
+        requested_schema : PyCapsule | None
+            A PyCapsule containing a C ArrowSchema representation of a requested
+            schema. PyArrow will attempt to cast the array to this data type.
+            If None, the array will be returned as-is, with a type matching the
+            one returned by :meth:`__arrow_c_schema__()`.
+        kwargs
+            Currently no additional keyword arguments are supported, but
+            this method will accept any keyword with a value of ``None``
+            for compatibility with future keywords.
+
+        Returns
+        -------
+        Tuple[PyCapsule, PyCapsule]
+            A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
+            respectively.
+        """
+        cdef:
+            ArrowDeviceArray* c_array
+            ArrowSchema* c_schema
+            shared_ptr[CArray] inner_array
+
+        non_default_kwargs = [
+            name for name, value in kwargs.items() if value is not None
+        ]
+        if non_default_kwargs:
+            raise NotImplementedError(
+                f"Received unsupported keyword argument(s): {non_default_kwargs}"
+            )
+
+        if requested_schema is not None:
+            target_type = DataType._import_from_c_capsule(requested_schema)
+
+            if target_type != self.type:
+                if not self.is_cpu:
+                    raise NotImplementedError(
+                        "Casting to a requested schema is only supported for CPU data"
+                    )
+                try:
+                    casted_array = _pc().cast(self, target_type, safe=True)
+                    inner_array = pyarrow_unwrap_array(casted_array)
+                except ArrowInvalid as e:
+                    raise ValueError(
+                        f"Could not cast {self.type} to requested type {target_type}: {e}"
+                    )
+            else:
+                inner_array = self.sp_array
+        else:
+            inner_array = self.sp_array
+
+        schema_capsule = alloc_c_schema(&c_schema)
+        array_capsule = alloc_c_device_array(&c_array)
+
+        with nogil:
+            check_status(ExportDeviceArray(
+                deref(inner_array), <shared_ptr[CSyncEvent]>NULL,
+                c_array, c_schema))
+
+        return schema_capsule, array_capsule
+
+    @staticmethod
+    def _import_from_c_device_capsule(schema_capsule, array_capsule):
+        cdef:
+            ArrowSchema* c_schema
+            ArrowDeviceArray* c_array
+            shared_ptr[CArray] array
+
+        c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
+        c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
+            array_capsule, 'arrow_device_array'
+        )
+
+        with nogil:
+            array = GetResultValue(ImportDeviceArray(c_array, c_schema))
+
+        return pyarrow_wrap_array(array)
+
     def __dlpack__(self, stream=None):
         """Export a primitive array as a DLPack capsule.
 
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 3dee463118442..0d871f411b11b 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1016,6 +1016,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         int num_columns()
         int64_t num_rows()
 
+        CDeviceAllocationType device_type()
+
         CStatus Validate() const
         CStatus ValidateFull() const
 
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 03018162694b7..eb9ba650dbf60 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -3639,7 +3639,7 @@ cdef class RecordBatch(_Tabular):
         requested_schema : PyCapsule | None
             A PyCapsule containing a C ArrowSchema representation of a requested
             schema. PyArrow will attempt to cast the batch to this schema.
-            If None, the schema will be returned as-is, with a schema matching the
+            If None, the batch will be returned as-is, with a schema matching the
             one returned by :meth:`__arrow_c_schema__()`.
 
         Returns
         -------
@@ -3657,9 +3657,7 @@ cdef class RecordBatch(_Tabular):
 
         if target_schema != self.schema:
             try:
-                # We don't expose .cast() on RecordBatch, only on Table.
-                casted_batch = Table.from_batches([self]).cast(
-                    target_schema, safe=True).to_batches()[0]
+                casted_batch = self.cast(target_schema, safe=True)
                 inner_batch = pyarrow_unwrap_batch(casted_batch)
             except ArrowInvalid as e:
                 raise ValueError(
@@ -3700,8 +3698,8 @@ cdef class RecordBatch(_Tabular):
     @staticmethod
     def _import_from_c_capsule(schema_capsule, array_capsule):
         """
-        Import RecordBatch from a pair of PyCapsules containing a C ArrowArray
-        and ArrowSchema, respectively.
+        Import RecordBatch from a pair of PyCapsules containing a C ArrowSchema
+        and ArrowArray, respectively.
 
         Parameters
         ----------
@@ -3792,6 +3790,121 @@ cdef class RecordBatch(_Tabular):
                 c_device_array, c_schema))
         return pyarrow_wrap_batch(c_batch)
 
+    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
+        """
+        Get a pair of PyCapsules containing a C ArrowDeviceArray representation
+        of the object.
+
+        Parameters
+        ----------
+        requested_schema : PyCapsule | None
+            A PyCapsule containing a C ArrowSchema representation of a requested
+            schema. PyArrow will attempt to cast the batch to this data type.
+            If None, the batch will be returned as-is, with a type matching the
+            one returned by :meth:`__arrow_c_schema__()`.
+        kwargs
+            Currently no additional keyword arguments are supported, but
+            this method will accept any keyword with a value of ``None``
+            for compatibility with future keywords.
+
+        Returns
+        -------
+        Tuple[PyCapsule, PyCapsule]
+            A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray,
+            respectively.
+        """
+        cdef:
+            ArrowDeviceArray* c_array
+            ArrowSchema* c_schema
+            shared_ptr[CRecordBatch] inner_batch
+
+        non_default_kwargs = [
+            name for name, value in kwargs.items() if value is not None
+        ]
+        if non_default_kwargs:
+            raise NotImplementedError(
+                f"Received unsupported keyword argument(s): {non_default_kwargs}"
+            )
+
+        if requested_schema is not None:
+            target_schema = Schema._import_from_c_capsule(requested_schema)
+
+            if target_schema != self.schema:
+                if not self.is_cpu:
+                    raise NotImplementedError(
+                        "Casting to a requested schema is only supported for CPU data"
+                    )
+                try:
+                    casted_batch = self.cast(target_schema, safe=True)
+                    inner_batch = pyarrow_unwrap_batch(casted_batch)
+                except ArrowInvalid as e:
+                    raise ValueError(
+                        f"Could not cast {self.schema} to requested schema {target_schema}: {e}"
+                    )
+            else:
+                inner_batch = self.sp_batch
+        else:
+            inner_batch = self.sp_batch
+
+        schema_capsule = alloc_c_schema(&c_schema)
+        array_capsule = alloc_c_device_array(&c_array)
+
+        with nogil:
+            check_status(ExportDeviceRecordBatch(
+                deref(inner_batch), <shared_ptr[CSyncEvent]>NULL, c_array, c_schema))
+
+        return schema_capsule, array_capsule
+
+    @staticmethod
+    def _import_from_c_device_capsule(schema_capsule, array_capsule):
+        """
+        Import RecordBatch from a pair of PyCapsules containing a
+        C ArrowSchema and ArrowDeviceArray, respectively.
+
+        Parameters
+        ----------
+        schema_capsule : PyCapsule
+            A PyCapsule containing a C ArrowSchema representation of the schema.
+        array_capsule : PyCapsule
+            A PyCapsule containing a C ArrowDeviceArray representation of the array.
+
+        Returns
+        -------
+        pyarrow.RecordBatch
+        """
+        cdef:
+            ArrowSchema* c_schema
+            ArrowDeviceArray* c_array
+            shared_ptr[CRecordBatch] batch
+
+        c_schema = <ArrowSchema*> PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
+        c_array = <ArrowDeviceArray*> PyCapsule_GetPointer(
+            array_capsule, 'arrow_device_array'
+        )
+
+        with nogil:
+            batch = GetResultValue(ImportDeviceRecordBatch(c_array, c_schema))
+
+        return pyarrow_wrap_batch(batch)
+
+    @property
+    def device_type(self):
+        """
+        The device type where the arrays in the RecordBatch reside.
+
+        Returns
+        -------
+        DeviceAllocationType
+        """
+        return _wrap_device_allocation_type(self.sp_batch.get().device_type())
+
+    @property
+    def is_cpu(self):
+        """
+        Whether the RecordBatch's arrays are CPU-accessible.
+        """
+        return self.device_type == DeviceAllocationType.CPU
+
 
 def _reconstruct_record_batch(columns, schema):
     """
@@ -5584,7 +5697,8 @@ def record_batch(data, names=None, schema=None, metadata=None):
     data : dict, list, pandas.DataFrame, Arrow-compatible table
         A mapping of strings to Arrays or Python lists, a list of Arrays,
         a pandas DataFame, or any tabular object implementing the
-        Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` method).
+        Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
+        ``__arrow_c_device_array__`` method).
     names : list, default None
         Column names if list of arrays passed as data. Mutually exclusive with
         'schema' argument.
@@ -5718,6 +5832,18 @@ def record_batch(data, names=None, schema=None, metadata=None):
             raise ValueError(
                 "The 'names' argument is not valid when passing a dictionary")
         return RecordBatch.from_pydict(data, schema=schema, metadata=metadata)
+    elif hasattr(data, "__arrow_c_device_array__"):
+        if schema is not None:
+            requested_schema = schema.__arrow_c_schema__()
+        else:
+            requested_schema = None
+        schema_capsule, array_capsule = data.__arrow_c_device_array__(requested_schema)
+        batch = RecordBatch._import_from_c_device_capsule(schema_capsule, array_capsule)
+        if schema is not None and batch.schema != schema:
+            # __arrow_c_device_array__ coerces schema with best effort, so we might
+            # need to cast it if the producer wasn't able to cast to exact schema.
+            batch = batch.cast(schema)
+        return batch
     elif hasattr(data, "__arrow_c_array__"):
         if schema is not None:
             requested_schema = schema.__arrow_c_schema__()
@@ -5747,8 +5873,8 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
     data : dict, list, pandas.DataFrame, Arrow-compatible table
         A mapping of strings to Arrays or Python lists, a list of arrays or
         chunked arrays, a pandas DataFame, or any tabular object implementing
-        the Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or
-        ``__arrow_c_stream__`` method).
+        the Arrow PyCapsule Protocol (has an ``__arrow_c_array__``,
+        ``__arrow_c_device_array__`` or ``__arrow_c_stream__`` method).
     names : list, default None
         Column names if list of arrays passed as data. Mutually exclusive with
         'schema' argument.
@@ -5888,7 +6014,7 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None):
             # need to cast it if the producer wasn't able to cast to exact schema.
             table = table.cast(schema)
         return table
-    elif hasattr(data, "__arrow_c_array__"):
+    elif hasattr(data, "__arrow_c_array__") or hasattr(data, "__arrow_c_device_array__"):
         if names is not None or metadata is not None:
             raise ValueError(
                 "The 'names' and 'metadata' arguments are not valid when "
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index 27e78e3396ec8..78d06b26e3622 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -3505,16 +3505,27 @@ def __arrow_array__(self, type=None):
     assert result.equals(expected)
 
 
-def test_c_array_protocol():
-    class ArrayWrapper:
-        def __init__(self, data):
-            self.data = data
+class ArrayWrapper:
+    def __init__(self, data):
+        self.data = data
+
+    def __arrow_c_array__(self, requested_schema=None):
+        return self.data.__arrow_c_array__(requested_schema)
+
+
+class ArrayDeviceWrapper:
+    def __init__(self, data):
+        self.data = data
+
+    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
+        return self.data.__arrow_c_device_array__(requested_schema, **kwargs)
 
-        def __arrow_c_array__(self, requested_schema=None):
-            return self.data.__arrow_c_array__(requested_schema)
 
+@pytest.mark.parametrize("wrapper_class", [ArrayWrapper, ArrayDeviceWrapper])
+def test_c_array_protocol(wrapper_class):
     # Can roundtrip through the C array protocol
-    arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
+    arr = wrapper_class(pa.array([1, 2, 3], type=pa.int64()))
     result = pa.array(arr)
     assert result == arr.data
 
@@ -3523,6 +3534,20 @@ def __arrow_c_array__(self, requested_schema=None):
     assert result == pa.array([1, 2, 3], type=pa.int32())
 
 
+def test_c_array_protocol_device_unsupported_keyword():
+    # For the device-aware version, we raise a specific error for unsupported keywords
+    arr = pa.array([1, 2, 3], type=pa.int64())
+
+    with pytest.raises(
+        NotImplementedError,
+        match=r"Received unsupported keyword argument\(s\): \['other'\]"
+    ):
+        arr.__arrow_c_device_array__(other="not-none")
+
+    # but with None value it is ignored
+    _ = arr.__arrow_c_device_array__(other=None)
+
+
 def test_concat_array():
     concatenated = pa.concat_arrays(
         [pa.array([1, 2]), pa.array([3, 4])])
diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py
index 369ed9142824d..70841eeb0619a 100644
--- a/python/pyarrow/tests/test_cffi.py
+++ b/python/pyarrow/tests/test_cffi.py
@@ -603,6 +603,48 @@ def test_roundtrip_array_capsule(arr, schema_accessor, bad_type, good_type):
     assert schema_accessor(arr_out) == good_type
 
 
+@pytest.mark.parametrize('arr,schema_accessor,bad_type,good_type', [
+    (pa.array(['a', 'b', 'c']), lambda x: x.type, pa.int32(), pa.string()),
+    (
+        pa.record_batch([pa.array(['a', 'b', 'c'])], names=['x']),
+        lambda x: x.schema,
+        pa.schema({'x': pa.int32()}),
+        pa.schema({'x': pa.string()})
+    ),
+], ids=['array', 'record_batch'])
+def test_roundtrip_device_array_capsule(arr, schema_accessor, bad_type, good_type):
+    gc.collect()  # Make sure no Arrow data dangles in a ref cycle
+    old_allocated = pa.total_allocated_bytes()
+
+    import_array = type(arr)._import_from_c_device_capsule
+
+    schema_capsule, capsule = arr.__arrow_c_device_array__()
+    assert PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
+    assert PyCapsule_IsValid(capsule, b"arrow_device_array") == 1
+    arr_out = import_array(schema_capsule, capsule)
+    assert arr_out.equals(arr)
+
+    assert pa.total_allocated_bytes() > old_allocated
+    del arr_out
+
+    assert pa.total_allocated_bytes() == old_allocated
+
+    capsule = arr.__arrow_c_device_array__()
+
+    assert pa.total_allocated_bytes() > old_allocated
+    del capsule
+    assert pa.total_allocated_bytes() == old_allocated
+
+    with pytest.raises(ValueError,
+                       match=r"Could not cast.* string to requested .* int32"):
+        arr.__arrow_c_device_array__(bad_type.__arrow_c_schema__())
+
+    schema_capsule, array_capsule = arr.__arrow_c_device_array__(
+        good_type.__arrow_c_schema__())
+    arr_out = import_array(schema_capsule, array_capsule)
+    assert schema_accessor(arr_out) == good_type
+
+
 # TODO: implement requested_schema for stream
 @pytest.mark.parametrize('constructor', [
     pa.RecordBatchReader.from_batches,
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 3d5a8a9829e56..30c687b0d94df 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -535,18 +535,28 @@ def __arrow_c_stream__(self, requested_schema=None):
     assert result == data.cast(pa.int16())
 
 
-def test_recordbatch_c_array_interface():
-    class BatchWrapper:
-        def __init__(self, batch):
-            self.batch = batch
+class BatchWrapper:
+    def __init__(self, batch):
+        self.batch = batch
+
+    def __arrow_c_array__(self, requested_schema=None):
+        return self.batch.__arrow_c_array__(requested_schema)
+
+
+class BatchDeviceWrapper:
+    def __init__(self, batch):
+        self.batch = batch
 
-        def __arrow_c_array__(self, requested_schema=None):
-            return self.batch.__arrow_c_array__(requested_schema)
+    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
+        return self.batch.__arrow_c_device_array__(requested_schema, **kwargs)
 
+
+@pytest.mark.parametrize("wrapper_class", [BatchWrapper, BatchDeviceWrapper])
+def test_recordbatch_c_array_interface(wrapper_class):
     data = pa.record_batch([
         pa.array([1, 2, 3], type=pa.int64())
     ], names=['a'])
-    wrapper = BatchWrapper(data)
+    wrapper = wrapper_class(data)
 
     # Can roundtrip through the wrapper.
     result = pa.record_batch(wrapper)
@@ -563,18 +573,28 @@ def __arrow_c_array__(self, requested_schema=None):
     assert result == expected
 
 
-def test_table_c_array_interface():
-    class BatchWrapper:
-        def __init__(self, batch):
-            self.batch = batch
+def test_recordbatch_c_array_interface_device_unsupported_keyword():
+    # For the device-aware version, we raise a specific error for unsupported keywords
+    data = pa.record_batch(
+        [pa.array([1, 2, 3], type=pa.int64())], names=['a']
+    )
+
+    with pytest.raises(
+        NotImplementedError,
+        match=r"Received unsupported keyword argument\(s\): \['other'\]"
+    ):
+        data.__arrow_c_device_array__(other="not-none")
+
+    # but with None value it is ignored
+    _ = data.__arrow_c_device_array__(other=None)
 
-        def __arrow_c_array__(self, requested_schema=None):
-            return self.batch.__arrow_c_array__(requested_schema)
 
+@pytest.mark.parametrize("wrapper_class", [BatchWrapper, BatchDeviceWrapper])
+def test_table_c_array_interface(wrapper_class):
     data = pa.record_batch([
         pa.array([1, 2, 3], type=pa.int64())
     ], names=['a'])
-    wrapper = BatchWrapper(data)
+    wrapper = wrapper_class(data)
 
     # Can roundtrip through the wrapper.
     result = pa.table(wrapper)
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index ab7bba3779281..4343d7ea300b0 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -5542,3 +5542,24 @@ cdef object alloc_c_stream(ArrowArrayStream** c_stream):
     # Ensure the capsule destructor doesn't call a random release pointer
     c_stream[0].release = NULL
     return PyCapsule_New(c_stream[0], 'arrow_array_stream', &pycapsule_stream_deleter)
+
+
+cdef void pycapsule_device_array_deleter(object array_capsule) noexcept:
+    cdef:
+        ArrowDeviceArray* device_array
+    # Do not invoke the deleter on a used/moved capsule
+    device_array = <ArrowDeviceArray*>cpython.PyCapsule_GetPointer(
+        array_capsule, 'arrow_device_array'
+    )
+    if device_array.array.release != NULL:
+        device_array.array.release(&device_array.array)
+
+    free(device_array)
+
+
+cdef object alloc_c_device_array(ArrowDeviceArray** c_array):
+    c_array[0] = <ArrowDeviceArray*> malloc(sizeof(ArrowDeviceArray))
+    # Ensure the capsule destructor doesn't call a random release pointer
+    c_array[0].array.release = NULL
+    return PyCapsule_New(
+        c_array[0], 'arrow_device_array', &pycapsule_device_array_deleter)
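---

Usage sketch (not part of the patch): a minimal illustration of the consumer-side behavior added above, assuming a pyarrow build that includes this change. The `DeviceDataWrapper` class is hypothetical, modeled on the test wrappers in this PR, and simply delegates to pyarrow's own producer implementation.

```python
import pyarrow as pa


class DeviceDataWrapper:
    # Stand-in for a third-party object that exposes Arrow data through
    # the device-aware PyCapsule protocol added in this patch.
    def __init__(self, data):
        self.data = data

    def __arrow_c_device_array__(self, requested_schema=None, **kwargs):
        # Delegate to pyarrow's producer implementation, which returns a
        # (ArrowSchema, ArrowDeviceArray) pair of PyCapsules.
        return self.data.__arrow_c_device_array__(requested_schema, **kwargs)


wrapped = DeviceDataWrapper(pa.array([1, 2, 3], type=pa.int64()))

# pa.array() now detects __arrow_c_device_array__ and imports the data
# through the capsule pair via Array._import_from_c_device_capsule
arr = pa.array(wrapped)
assert arr == pa.array([1, 2, 3], type=pa.int64())

# A requested type is forwarded to the producer as an ArrowSchema capsule;
# coercion is best effort, with a consumer-side cast as fallback
arr32 = pa.array(wrapped, type=pa.int32())
assert arr32.type == pa.int32()

# RecordBatch gains device introspection alongside the new method
batch = pa.record_batch([pa.array(["a", "b"])], names=["x"])
assert batch.is_cpu
assert batch.device_type == pa.DeviceAllocationType.CPU
```

Note that the capsules only transfer ownership of the C structures; a consumer that can only handle CPU memory should check `device_type` (or `is_cpu`) before dereferencing any buffers, since for non-CPU data the pointers are not host-accessible.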