diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index bbbdddd41..760c37610 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -312,3 +312,14 @@ def test_except_all(): df_a_e_b = df_a.except_all(df_b).sort(column("a").sort(ascending=True)) assert df_c.collect() == df_a_e_b.collect() + + +def test_collect_partitioned(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + + assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() diff --git a/src/dataframe.rs b/src/dataframe.rs index e491c3d9d..4ae0160a9 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -129,6 +129,17 @@ impl PyDataFrame { batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() } + /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch + /// maintaining the input partitioning. + fn collect_partitioned(&self, py: Python) -> PyResult>> { + let batches = wait_for_future(py, self.df.collect_partitioned())?; + + batches + .into_iter() + .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect()) + .collect() + } + /// Print the result, 20 lines by default #[args(num = "20")] fn show(&self, py: Python, num: usize) -> PyResult<()> {