GH-33926: [Python] DataFrame Interchange Protocol for pyarrow.RecordBatch (#34294)

### Rationale for this change
Add an implementation of the DataFrame Interchange Protocol for `pyarrow.RecordBatch`. The protocol is already implemented for `pyarrow.Table`, see #14804.
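A minimal usage sketch (assuming a pyarrow build that includes this change; the column names are illustrative only):

```python
import pyarrow as pa
from pyarrow.interchange import from_dataframe

batch = pa.record_batch(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
    names=["ints", "strs"],
)

# New in this change: RecordBatch exposes the interchange protocol directly.
df = batch.__dataframe__()
assert df.num_columns() == 2
assert list(df.column_names()) == ["ints", "strs"]

# Any protocol consumer can now ingest the batch; pyarrow itself
# round-trips it into a Table via from_dataframe.
assert from_dataframe(batch).num_rows == 3
```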

### Are these changes tested?
Yes, tests are added to:

- python/pyarrow/tests/interchange/test_interchange_spec.py
- python/pyarrow/tests/interchange/test_conversion.py
* Closes: #33926

Authored-by: Alenka Frim <frim.alenka@gmail.com>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
AlenkaF committed Feb 28, 2023
1 parent 61c9a74 commit 6cf5e89
Showing 5 changed files with 105 additions and 31 deletions.
51 changes: 33 additions & 18 deletions python/pyarrow/interchange/dataframe.py
@@ -44,11 +44,13 @@ class _PyArrowDataFrame:
"""

def __init__(
self, df: pa.Table, nan_as_null: bool = False, allow_copy: bool = True
self, df: pa.Table | pa.RecordBatch,
nan_as_null: bool = False,
allow_copy: bool = True
) -> None:
"""
Constructor - an instance of this (private) class is returned from
`pa.Table.__dataframe__`.
`pa.Table.__dataframe__` or `pa.RecordBatch.__dataframe__`.
"""
self._df = df
# ``nan_as_null`` is a keyword intended for the consumer to tell the
@@ -114,18 +116,21 @@ def num_chunks(self) -> int:
"""
Return the number of chunks the DataFrame consists of.
"""
# pyarrow.Table can have columns with different number
# of chunks so we take the number of chunks that
# .to_batches() returns as it takes the min chunk size
# of all the columns (to_batches is a zero copy method)
batches = self._df.to_batches()
return len(batches)
if isinstance(self._df, pa.RecordBatch):
return 1
else:
# pyarrow.Table can have columns with different number
# of chunks so we take the number of chunks that
# .to_batches() returns as it takes the min chunk size
# of all the columns (to_batches is a zero copy method)
batches = self._df.to_batches()
return len(batches)

def column_names(self) -> Iterable[str]:
"""
Return an iterator yielding the column names.
"""
return self._df.column_names
return self._df.schema.names

def get_column(self, i: int) -> _PyArrowColumn:
"""
Expand Down Expand Up @@ -182,21 +187,31 @@ def get_chunks(
Note that the producer must ensure that all columns are chunked the
same way.
"""
# Subdivide chunks
if n_chunks and n_chunks > 1:
chunk_size = self.num_rows() // n_chunks
if self.num_rows() % n_chunks != 0:
chunk_size += 1
batches = self._df.to_batches(max_chunksize=chunk_size)
if isinstance(self._df, pa.Table):
batches = self._df.to_batches(max_chunksize=chunk_size)
else:
batches = []
for start in range(0, chunk_size * n_chunks, chunk_size):
batches.append(self._df.slice(start, chunk_size))
# In case when the size of the chunk is such that the resulting
# list is one less chunk then n_chunks -> append an empty chunk
if len(batches) == n_chunks - 1:
batches.append(pa.record_batch([[]], schema=self._df.schema))
# yields the chunks that the data is stored as
else:
batches = self._df.to_batches()

iterator_tables = [_PyArrowDataFrame(
pa.Table.from_batches([batch]), self._nan_as_null, self._allow_copy
)
for batch in batches
]
return iterator_tables
if isinstance(self._df, pa.Table):
batches = self._df.to_batches()
else:
batches = [self._df]

# Create an iterator of RecordBatches
iterator = [_PyArrowDataFrame(batch,
self._nan_as_null,
self._allow_copy)
for batch in batches]
return iterator
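For the `RecordBatch` path above, `get_chunks(n_chunks)` slices the batch into windows of `ceil(num_rows / n_chunks)` rows, padding with an empty batch only when the rounding leaves the list one chunk short. A sketch of the resulting chunking (assuming a build with this change):

```python
import pyarrow as pa

batch = pa.record_batch([pa.array(range(10))], names=["x"])
df = batch.__dataframe__()

# chunk_size = ceil(10 / 3) == 4, so the batch is sliced into 4 + 4 + 2 rows.
chunks = list(df.get_chunks(3))
assert [chunk.num_rows() for chunk in chunks] == [4, 4, 2]

# Without n_chunks, the single RecordBatch is yielded as one chunk.
assert len(list(df.get_chunks())) == 1
```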
8 changes: 4 additions & 4 deletions python/pyarrow/interchange/from_dataframe.py
@@ -60,8 +60,7 @@

def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table:
"""
Build a ``pa.Table`` from any DataFrame supporting the interchange
protocol.
Build a ``pa.Table`` from any DataFrame supporting the interchange protocol.
Parameters
----------
@@ -78,6 +77,8 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table:
"""
if isinstance(df, pa.Table):
return df
elif isinstance(df, pa.RecordBatch):
return pa.Table.from_batches([df])

if not hasattr(df, "__dataframe__"):
raise ValueError("`df` does not support __dataframe__")
@@ -108,8 +109,7 @@ def _from_dataframe(df: DataFrameObject, allow_copy=True):
batch = protocol_df_chunk_to_pyarrow(chunk, allow_copy)
batches.append(batch)

table = pa.Table.from_batches(batches)
return table
return pa.Table.from_batches(batches)


def protocol_df_chunk_to_pyarrow(
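With the `RecordBatch` fast path above, `from_dataframe` returns early for pyarrow inputs instead of walking the protocol column by column: a batch comes back wrapped in a single-batch `Table`. A short sketch (assuming a build with this change):

```python
import pyarrow as pa
from pyarrow.interchange import from_dataframe

batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])

# RecordBatch input is wrapped into a single-batch Table without going
# through the generic column-by-column conversion path.
table = from_dataframe(batch)
assert isinstance(table, pa.Table)
assert table.column("a").num_chunks == 1
```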
33 changes: 33 additions & 0 deletions python/pyarrow/table.pxi
@@ -1517,6 +1517,39 @@ cdef class RecordBatch(_PandasConvertible):
self.sp_batch = batch
self.batch = batch.get()

# ----------------------------------------------------------------------
def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
"""
Return the dataframe interchange object implementing the interchange protocol.
Parameters
----------
nan_as_null : bool, default False
Whether to tell the DataFrame to overwrite null values in the data
with ``NaN`` (or ``NaT``).
allow_copy : bool, default True
Whether to allow memory copying when exporting. If set to False
it would cause non-zero-copy exports to fail.
Returns
-------
DataFrame interchange object
The object which the consuming library can use to ingress the dataframe.
Notes
-----
Details on the interchange protocol:
https://data-apis.org/dataframe-protocol/latest/index.html
`nan_as_null` currently has no effect; once support for nullable extension
dtypes is added, this value should be propagated to columns.
"""

from pyarrow.interchange.dataframe import _PyArrowDataFrame

return _PyArrowDataFrame(self, nan_as_null, allow_copy)

# ----------------------------------------------------------------------

@staticmethod
def from_pydict(mapping, schema=None, metadata=None):
"""
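The new method takes the same keyword arguments as `pa.Table.__dataframe__`: `allow_copy=True` (the default) permits copies where a zero-copy export is impossible, and `nan_as_null` is currently ignored, as the docstring notes. A small sketch of column access through the protocol object (assuming a build with this change):

```python
import pyarrow as pa

batch = pa.record_batch(
    [pa.array(["Mon", "Tue", "Mon", None]).dictionary_encode()],
    names=["weekday"],
)

df = batch.__dataframe__(nan_as_null=False, allow_copy=True)
col = df.get_column_by_name("weekday")

# Dictionary-encoded columns expose categorical metadata via the protocol.
categorical = col.describe_categorical
assert isinstance(categorical["is_dictionary"], bool)
```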
3 changes: 2 additions & 1 deletion python/pyarrow/tests/interchange/test_conversion.py
@@ -108,6 +108,7 @@ def test_categorical_roundtrip():

if Version(pd.__version__) < Version("1.5.0"):
pytest.skip("__dataframe__ added to pandas in 1.5.0")

arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"]
table = pa.table(
{"weekday": pa.array(arr).dictionary_encode()}
@@ -447,7 +448,7 @@ def test_pyarrow_roundtrip_categorical(offset, length):
assert col_result.size() == col_table.size()
assert col_result.offset == col_table.offset

desc_cat_table = col_result.describe_categorical
desc_cat_table = col_table.describe_categorical
desc_cat_result = col_result.describe_categorical

assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
41 changes: 33 additions & 8 deletions python/pyarrow/tests/interchange/test_interchange_spec.py
@@ -76,8 +76,10 @@ def test_dtypes(arr):
)
@pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
@pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
@pytest.mark.parametrize("use_batch", [False, True])
def test_mixed_dtypes(uint, uint_bw, int, int_bw,
float, float_bw, np_float, unit, tz):
float, float_bw, np_float, unit, tz,
use_batch):
from datetime import datetime as dt
arr = [1, 2, 3]
dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
@@ -91,6 +93,8 @@ def test_mixed_dtypes(uint, uint_bw, int, int_bw,
"f": pa.array(dt_arr, type=pa.timestamp(unit, tz=tz))
}
)
if use_batch:
table = table.to_batches()[0]
df = table.__dataframe__()
# 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
# 20 = DtypeKind.BOOL, 21 = DtypeKind.STRING, 22 = DtypeKind.DATETIME
@@ -126,47 +130,62 @@ def test_noncategorical():
col.describe_categorical


def test_categorical():
@pytest.mark.parametrize("use_batch", [False, True])
def test_categorical(use_batch):
import pyarrow as pa
arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
table = pa.table(
{"weekday": pa.array(arr).dictionary_encode()}
)
if use_batch:
table = table.to_batches()[0]

col = table.__dataframe__().get_column_by_name("weekday")
categorical = col.describe_categorical
assert isinstance(categorical["is_ordered"], bool)
assert isinstance(categorical["is_dictionary"], bool)


def test_dataframe():
@pytest.mark.parametrize("use_batch", [False, True])
def test_dataframe(use_batch):
n = pa.chunked_array([[2, 2, 4], [4, 5, 100]])
a = pa.chunked_array([["Flamingo", "Parrot", "Cow"],
["Horse", "Brittle stars", "Centipede"]])
table = pa.table([n, a], names=['n_legs', 'animals'])
if use_batch:
table = table.combine_chunks().to_batches()[0]
df = table.__dataframe__()

assert df.num_columns() == 2
assert df.num_rows() == 6
assert df.num_chunks() == 2
if use_batch:
assert df.num_chunks() == 1
else:
assert df.num_chunks() == 2
assert list(df.column_names()) == ['n_legs', 'animals']
assert list(df.select_columns((1,)).column_names()) == list(
df.select_columns_by_name(("animals",)).column_names()
)


@pytest.mark.parametrize("use_batch", [False, True])
@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
def test_df_get_chunks(size, n_chunks):
def test_df_get_chunks(use_batch, size, n_chunks):
table = pa.table({"x": list(range(size))})
if use_batch:
table = table.to_batches()[0]
df = table.__dataframe__()
chunks = list(df.get_chunks(n_chunks))
assert len(chunks) == n_chunks
assert sum(chunk.num_rows() for chunk in chunks) == size


@pytest.mark.parametrize("use_batch", [False, True])
@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
def test_column_get_chunks(size, n_chunks):
def test_column_get_chunks(use_batch, size, n_chunks):
table = pa.table({"x": list(range(size))})
if use_batch:
table = table.to_batches()[0]
df = table.__dataframe__()
chunks = list(df.get_column(0).get_chunks(n_chunks))
assert len(chunks) == n_chunks
@@ -187,7 +206,8 @@ def test_column_get_chunks(size, n_chunks):
(pa.float64(), np.float64)
]
)
def test_get_columns(uint, int, float, np_float):
@pytest.mark.parametrize("use_batch", [False, True])
def test_get_columns(uint, int, float, np_float, use_batch):
arr = [[1, 2, 3], [4, 5]]
arr_float = np.array([1, 2, 3, 4, 5], dtype=np_float)
table = pa.table(
@@ -197,6 +217,8 @@ def test_get_columns(uint, int, float, np_float):
"c": pa.array(arr_float, type=float)
}
)
if use_batch:
table = table.combine_chunks().to_batches()[0]
df = table.__dataframe__()
for col in df.get_columns():
assert col.size() == 5
@@ -212,9 +234,12 @@
@pytest.mark.parametrize(
"int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
def test_buffer(int):
@pytest.mark.parametrize("use_batch", [False, True])
def test_buffer(int, use_batch):
arr = [0, 1, -1]
table = pa.table({"a": pa.array(arr, type=int)})
if use_batch:
table = table.to_batches()[0]
df = table.__dataframe__()
col = df.get_column(0)
buf = col.get_buffers()

