Skip to content

Commit

Permalink
Implement Arrow PyCapsule Interface (#5070)
Browse files Browse the repository at this point in the history
* arrow ffi array copy

* remove copy_ffi_array

* docstring

* wip: pycapsule support

* return

* Update arrow/src/pyarrow.rs

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

* remove sync impl

* Update arrow/src/pyarrow.rs

Co-authored-by: Will Jones <willjones127@gmail.com>

* Remove copy()

* Need &mut FFI_ArrowArray for std::mem::replace

* Use std::ptr::replace

* update comments

* Minimize unsafe block

* revert pub release functions

* Add RecordBatch and Stream conversion

* fix returns

* Fix return type

* Fix name

* fix ci

* Add tests

* Add table test

* skip if pre pyarrow 14

* bump python version in CI to use pyarrow 14

* Add record batch test

* Update arrow/src/pyarrow.rs

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>

* run on pyarrow 13 and 14

* Update .github/workflows/integration.yml

Co-authored-by: Will Jones <willjones127@gmail.com>

---------

Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
Co-authored-by: Will Jones <willjones127@gmail.com>
  • Loading branch information
3 people committed Nov 15, 2023
1 parent 4b9d789 commit aff86e7
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 8 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ jobs:
strategy:
matrix:
rust: [ stable ]
# PyArrow 13 was the last version prior to introduction to Arrow PyCapsules
pyarrow: [ "13", "14" ]
steps:
- uses: actions/checkout@v4
with:
Expand All @@ -128,14 +130,14 @@ jobs:
key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
- uses: actions/setup-python@v4
with:
python-version: '3.7'
python-version: '3.8'
- name: Upgrade pip and setuptools
run: pip install --upgrade pip setuptools wheel virtualenv
- name: Create virtualenv and install dependencies
run: |
virtualenv venv
source venv/bin/activate
pip install maturin toml pytest pytz pyarrow>=5.0
pip install maturin toml pytest pytz pyarrow==${{ matrix.pyarrow }}
- name: Run Rust tests
run: |
source venv/bin/activate
Expand Down
2 changes: 2 additions & 0 deletions arrow-pyarrow-integration-testing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Note that this crate uses two languages and an external ABI:
* `Rust`
* `Python`
* C ABI privately exposed by `Pyarrow`.
* PyCapsule ABI publicly exposed by `pyarrow`

## Basic idea

Expand All @@ -36,6 +37,7 @@ we can use pyarrow's interface to move pointers from and to Rust.
## Relevant literature

* [Arrow's CDataInterface](https://arrow.apache.org/docs/format/CDataInterface.html)
* [Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html)
* [Rust's FFI](https://doc.rust-lang.org/nomicon/ffi.html)
* [Pyarrow private binds](https://github.com/apache/arrow/blob/ae1d24efcc3f1ac2a876d8d9f544a34eb04ae874/python/pyarrow/array.pxi#L1226)
* [PyO3](https://docs.rs/pyo3/0.12.1/pyo3/index.html)
Expand Down
138 changes: 134 additions & 4 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

import arrow_pyarrow_integration_testing as rust

PYARROW_PRE_14 = int(pa.__version__.split('.')[0]) < 14


@contextlib.contextmanager
def no_pyarrow_leak():
Expand Down Expand Up @@ -113,13 +115,49 @@ def assert_pyarrow_leak():
_unsupported_pyarrow_types = [
]

# As of pyarrow 14, pyarrow implements the Arrow PyCapsule interface
# (https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
# This defines that Arrow consumers should allow any object that has specific "dunder"
# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
# _any_ class, without pyarrow-specific handling.
class SchemaWrapper:
def __init__(self, schema):
self.schema = schema

def __arrow_c_schema__(self):
return self.schema.__arrow_c_schema__()


class ArrayWrapper:
def __init__(self, array):
self.array = array

def __arrow_c_array__(self):
return self.array.__arrow_c_array__()


class StreamWrapper:
def __init__(self, stream):
self.stream = stream

def __arrow_c_stream__(self):
return self.stream.__arrow_c_stream__()


@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip(pyarrow_type):
restored = rust.round_trip_type(pyarrow_type)
assert restored == pyarrow_type
assert restored is not pyarrow_type

@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip_pycapsule(pyarrow_type):
wrapped = SchemaWrapper(pyarrow_type)
restored = rust.round_trip_type(wrapped)
assert restored == pyarrow_type
assert restored is not pyarrow_type


@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str)
def test_type_roundtrip_raises(pyarrow_type):
Expand All @@ -138,6 +176,20 @@ def test_field_roundtrip(pyarrow_type):
field = rust.round_trip_field(pyarrow_field)
assert field == pyarrow_field

@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
def test_field_roundtrip_pycapsule(pyarrow_type):
pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
wrapped = SchemaWrapper(pyarrow_field)
field = rust.round_trip_field(wrapped)
assert field == wrapped.schema

if pyarrow_type != pa.null():
# A null type field may not be non-nullable
pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
field = rust.round_trip_field(wrapped)
assert field == wrapped.schema

def test_field_metadata_roundtrip():
metadata = {"hello": "World! 😊", "x": "2"}
pyarrow_field = pa.field("test", pa.int32(), metadata=metadata)
Expand All @@ -163,6 +215,17 @@ def test_primitive_python():
del b


@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_primitive_python_pycapsule():
"""
Python -> Rust -> Python
"""
a = pa.array([1, 2, 3])
wrapped = ArrayWrapper(a)
b = rust.double(wrapped)
assert b == pa.array([2, 4, 6])


def test_primitive_rust():
"""
Rust -> Python -> Rust
Expand Down Expand Up @@ -433,6 +496,33 @@ def test_record_batch_reader():
got_batches = list(b)
assert got_batches == batches

@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_record_batch_reader_pycapsule():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
a = pa.RecordBatchReader.from_batches(schema, batches)
wrapped = StreamWrapper(a)
b = rust.round_trip_record_batch_reader(wrapped)

assert b.schema == schema
got_batches = list(b)
assert got_batches == batches

# Also try the boxed reader variant
a = pa.RecordBatchReader.from_batches(schema, batches)
wrapped = StreamWrapper(a)
b = rust.boxed_reader_roundtrip(wrapped)
assert b.schema == schema
got_batches = list(b)
assert got_batches == batches


def test_record_batch_reader_error():
schema = pa.schema([('ints', pa.list_(pa.int32()))])

Expand All @@ -453,24 +543,64 @@ def iter_batches():
with pytest.raises(ValueError, match="invalid utf-8"):
rust.round_trip_record_batch_reader(reader)


@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_record_batch_pycapsule():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batch = pa.record_batch([[[1], [2, 42]]], schema)
wrapped = StreamWrapper(batch)
b = rust.round_trip_record_batch_reader(wrapped)
new_table = b.read_all()
new_batches = new_table.to_batches()

assert len(new_batches) == 1
new_batch = new_batches[0]

assert batch == new_batch
assert batch.schema == new_batch.schema


@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_table_pycapsule():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches)
wrapped = StreamWrapper(table)
b = rust.round_trip_record_batch_reader(wrapped)
new_table = b.read_all()

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Array, got builtins.list"):
rust.round_trip_array(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Schema, got builtins.list"):
rust.round_trip_schema(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Field, got builtins.list"):
rust.round_trip_field(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.DataType, got builtins.list"):
rust.round_trip_type(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatch, got builtins.list"):
rust.round_trip_record_batch(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatchReader, got builtins.list"):
rust.round_trip_record_batch_reader(not_pyarrow)
2 changes: 2 additions & 0 deletions arrow-schema/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ impl Drop for FFI_ArrowSchema {
}
}

unsafe impl Send for FFI_ArrowSchema {}

impl TryFrom<&FFI_ArrowSchema> for DataType {
type Error = ArrowError;

Expand Down

0 comments on commit aff86e7

Please sign in to comment.