47 changes: 47 additions & 0 deletions python/tests/test_dataframe.py
@@ -1867,6 +1867,53 @@ def test_to_arrow_table(df):
assert set(pyarrow_table.column_names) == {"a", "b", "c"}


def test_parquet_non_null_column_to_pyarrow(ctx, tmp_path):
path = tmp_path.joinpath("t.parquet")

ctx.sql("create table t_(a int not null)").collect()
ctx.sql("insert into t_ values (1), (2), (3)").collect()
ctx.sql(f"copy (select * from t_) to '{path}'").collect()

ctx.register_parquet("t", path)
pyarrow_table = ctx.sql("select max(a) as m from t").to_arrow_table()
assert pyarrow_table.to_pydict() == {"m": [3]}


Comment on lines +1870 to +1881
Contributor

I think we should also add

  • a regression test for the empty-batch case, to ensure that when there are zero record batches the DataFrame schema is still applied (since that's the behavior the code preserves);
  • a test covering the edge case where an aggregation yields a NULL (e.g., max over an empty input), to ensure to_arrow_table represents it correctly and we haven't hidden other nullability issues.

Contributor Author

Good call, I've added those two new tests.

def test_parquet_empty_batch_to_pyarrow(ctx, tmp_path):
path = tmp_path.joinpath("t.parquet")

ctx.sql("create table t_(a int not null)").collect()
ctx.sql("insert into t_ values (1), (2), (3)").collect()
ctx.sql(f"copy (select * from t_) to '{path}'").collect()

ctx.register_parquet("t", path)
pyarrow_table = ctx.sql("select * from t limit 0").to_arrow_table()
assert pyarrow_table.schema == pa.schema(
[
pa.field("a", pa.int32(), nullable=False),
]
)


def test_parquet_null_aggregation_to_pyarrow(ctx, tmp_path):
path = tmp_path.joinpath("t.parquet")

ctx.sql("create table t_(a int not null)").collect()
ctx.sql("insert into t_ values (1), (2), (3)").collect()
ctx.sql(f"copy (select * from t_) to '{path}'").collect()

ctx.register_parquet("t", path)
pyarrow_table = ctx.sql(
"select max(a) as m from (select * from t where a < 0)"
).to_arrow_table()
assert pyarrow_table.to_pydict() == {"m": [None]}
assert pyarrow_table.schema == pa.schema(
[
pa.field("m", pa.int32(), nullable=True),
]
)


def test_execute_stream(df):
stream = df.execute_stream()
assert all(batch is not None for batch in stream)
11 changes: 9 additions & 2 deletions src/dataframe.rs
@@ -1024,11 +1024,18 @@ impl PyDataFrame {
/// Collect the batches and pass to Arrow Table
fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
let batches = self.collect(py)?.into_pyobject(py)?;

// only use the DataFrame's schema if there are no batches, otherwise let the schema be
// determined from the batches (avoids some inconsistencies with nullable columns)
let args = if batches.len()? == 0 {
let schema = self.schema().into_pyobject(py)?;
PyTuple::new(py, &[batches, schema])?
} else {
PyTuple::new(py, &[batches])?
};

// Instantiate pyarrow Table object and use its from_batches method
let table_class = py.import("pyarrow")?.getattr("Table")?;
let table: PyObject = table_class.call_method1("from_batches", args)?.into();
Ok(table)
}
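For context, a minimal standalone sketch of the pyarrow behavior this change works around (pyarrow-only, with hypothetical example values; not taken from the PR): Table.from_batches validates an explicitly passed schema against each batch's own schema, so a nullability mismatch between the DataFrame's schema and the collected batches raises an error, whereas omitting the schema lets pyarrow take it from the batches.

import pyarrow as pa

# A batch whose field is non-nullable, like one read from a parquet file
# written from an "a int not null" column.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3], type=pa.int32())],
    schema=pa.schema([pa.field("a", pa.int32(), nullable=False)]),
)

# A schema that disagrees only on nullability, like a DataFrame schema that
# reports the column as nullable.
nullable_schema = pa.schema([pa.field("a", pa.int32(), nullable=True)])

try:
    # Expected to fail: the explicit schema does not match the batch schema.
    pa.Table.from_batches([batch], nullable_schema)
except pa.ArrowInvalid as exc:
    print(exc)

# Omitting the schema uses the batches' own schema and succeeds.
table = pa.Table.from_batches([batch])
assert table.schema.field("a").nullable is False

This mirrors the branch in to_arrow_table above: the DataFrame's schema is only passed when there are no batches to infer one from.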