From 9743fde3839804fc36121b4ecad44fbc770696b6 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 14 May 2026 09:33:50 +0000 Subject: [PATCH] feat(dataframe): add executeStream(allocator) for incremental batch iteration DataFrame.collect(allocator) materializes every batch into a Vec on the Rust heap before the first batch crosses the FFI boundary into Java. For TB-scale or unbounded result sets, this OOMs the Rust side regardless of how memory accounting is configured downstream. Adds DataFrame.executeStream(BufferAllocator) -> ArrowReader as a peer to collect, sharing the same lifecycle (consumes the DataFrame, caller closes the returned reader, allocator must outlive it) but pulling batches lazily. The native side wraps DataFusion's existing DataFrame::execute_stream output (a SendableRecordBatchStream) in a small StreamingReader adapter that bridges async Stream::next() to Arrow's synchronous RecordBatchReader trait; each call to ArrowReader.loadNextBatch() on the Java side drives one runtime().block_on(stream.next()) on the Rust side. Memory pressure stays bounded by the executor pipeline plus a single in-flight batch instead of the full result set. collect remains on its current code path and is unchanged behaviorally; only its Javadoc gains a forward-pointer to executeStream for analytics-scale queries. A follow-up could consolidate collect onto executeStream + concat (~10 LOC, no API change) but that refactor is out of scope here to keep the diff focused on adding the streaming primitive. 
Tests cover equivalence with collect (same row count over a small VALUES query), the consumes-DataFrame contract (second collect/executeStream/count throws after a successful executeStream), incremental delivery (with batchSize=2 over 5 rows the reader yields multiple batches and no single batch holds the full result), early-close survival (closing the reader mid-stream does not panic), TPC-H integration gated on the SF1 lineitem table, and column-value correctness (pins actual cell values across batches, not just row counts). native/Cargo.toml gains futures = "0.3" for StreamExt::next; it is already pulled in transitively by tokio and datafusion, the addition just makes the import path explicit. --- .../java/org/apache/datafusion/DataFrame.java | 36 ++++ .../DataFrameExecuteStreamTest.java | 193 ++++++++++++++++++ native/Cargo.lock | 1 + native/Cargo.toml | 1 + native/src/lib.rs | 62 +++++- 5 files changed, 292 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java diff --git a/core/src/main/java/org/apache/datafusion/DataFrame.java b/core/src/main/java/org/apache/datafusion/DataFrame.java index dceb497..dd1a46e 100644 --- a/core/src/main/java/org/apache/datafusion/DataFrame.java +++ b/core/src/main/java/org/apache/datafusion/DataFrame.java @@ -53,6 +53,10 @@ public final class DataFrame implements AutoCloseable { *

Consumes this DataFrame: the native plan is released as soon as the stream is established. * The caller is responsible for closing the returned reader, and the supplied allocator must * outlive it. + * + *

This method materializes every batch on the native heap before the first batch crosses the + * FFI boundary, which can OOM the Rust side for unbounded or very large result sets. Prefer + * {@link #executeStream(BufferAllocator)} for analytics-scale queries. */ public ArrowReader collect(BufferAllocator allocator) { if (nativeHandle == 0) { @@ -70,6 +74,36 @@ public ArrowReader collect(BufferAllocator allocator) { } } + /** + * Execute the plan and return its record batches as a streaming {@link ArrowReader}. Each call to + * {@link ArrowReader#loadNextBatch} drives one async {@code stream.next()} on the native side, so + * memory pressure stays bounded by the executor pipeline plus one in-flight batch instead of the + * full result set. + * + *

Consumes this DataFrame with the same lifecycle rules as {@link #collect(BufferAllocator)}: + * the native plan is released as soon as the stream is established, the caller closes the + * returned reader, and the supplied allocator must outlive it. + * + *

For result sets that fit comfortably in native memory and are read in their entirety, {@link + * #collect(BufferAllocator)} remains a reasonable choice. For TB-scale or unbounded result sets, + * use this method. + */ + public ArrowReader executeStream(BufferAllocator allocator) { + if (nativeHandle == 0) { + throw new IllegalStateException("DataFrame is closed or already collected"); + } + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); + long handle = nativeHandle; + nativeHandle = 0; + try { + executeStreamDataFrame(handle, stream.memoryAddress()); + return Data.importArrayStream(allocator, stream); + } catch (Throwable e) { + stream.close(); + throw e; + } + } + /** Execute the plan and return the number of rows. */ public long count() { if (nativeHandle == 0) { @@ -211,6 +245,8 @@ public void close() { private static native void collectDataFrame(long handle, long ffiStreamAddr); + private static native void executeStreamDataFrame(long handle, long ffiStreamAddr); + private static native void closeDataFrame(long handle); private static native long countRows(long handle); diff --git a/core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java b/core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java new file mode 100644 index 0000000..e95d2c5 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class DataFrameExecuteStreamTest { + + /** + * Write a CSV with `rows` integer rows, one column `x`. Used in tests that need a real file scan + * so DataFusion's batching honors {@code batch_size} -- in-memory {@code VALUES} plans get + * coalesced into a single batch in some DataFusion versions, which would make those tests + * brittle. 
+ */ + private static Path writeRowsCsv(Path dir, int rows) throws IOException { + StringBuilder sb = new StringBuilder("x\n"); + for (int i = 1; i <= rows; i++) { + sb.append(i).append('\n'); + } + Path file = dir.resolve("rows.csv"); + Files.writeString(file, sb.toString()); + return file; + } + + @Test + void executeStreamYieldsTheSameRowsAsCollect() throws Exception { + String sql = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)"; + + long collected = 0; + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + DataFrame df = ctx.sql(sql); + ArrowReader reader = df.collect(allocator)) { + while (reader.loadNextBatch()) { + collected += reader.getVectorSchemaRoot().getRowCount(); + } + } + + long streamed = 0; + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + DataFrame df = ctx.sql(sql); + ArrowReader reader = df.executeStream(allocator)) { + while (reader.loadNextBatch()) { + streamed += reader.getVectorSchemaRoot().getRowCount(); + } + } + + assertEquals(5L, collected); + assertEquals(collected, streamed); + } + + @Test + void executeStreamConsumesTheDataFrame() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + DataFrame df = ctx.sql("SELECT 1"); + try (ArrowReader reader = df.executeStream(allocator)) { + assertTrue(reader.loadNextBatch()); + } + // After a successful executeStream, the DataFrame's native handle is + // released. A second collect/executeStream/count must throw. + assertThrows(IllegalStateException.class, () -> df.executeStream(allocator)); + assertThrows(IllegalStateException.class, () -> df.collect(allocator)); + assertThrows(IllegalStateException.class, df::count); + // close() on an already-streamed DataFrame is a no-op (no double-free). 
+ df.close(); + } + } + + @Test + void executeStreamReadsBatchByBatch(@TempDir Path tempDir) throws Exception { + // CSV with 5 rows scanned at batch_size=2 reliably yields multiple batches + // across DataFusion versions, where an in-memory VALUES plan can be + // coalesced into a single batch by the planner. The point of this test is + // to pin "executeStream actually streams" without coupling to planner + // batching behavior that may shift in upstream releases. + Path csv = writeRowsCsv(tempDir, 5); + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = SessionContext.builder().batchSize(2).build()) { + ctx.registerCsv("rows", csv.toAbsolutePath().toString()); + try (DataFrame df = ctx.sql("SELECT x FROM rows"); + ArrowReader reader = df.executeStream(allocator)) { + int batches = 0; + long total = 0; + int maxBatchSize = 0; + while (reader.loadNextBatch()) { + batches++; + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + total += root.getRowCount(); + maxBatchSize = Math.max(maxBatchSize, root.getRowCount()); + } + assertEquals(5L, total); + assertTrue(batches >= 2, "expected multiple batches with batchSize=2, got " + batches); + assertTrue(maxBatchSize <= 2, "expected each batch <= 2 rows, got " + maxBatchSize); + } + } + } + + @Test + void executeStreamSurvivesEarlyClose() throws Exception { + // Close the reader after the first batch and confirm no native panic / + // resource leak. The DataFrame is already consumed; explicit close on it + // must remain a no-op. 
+ try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = SessionContext.builder().batchSize(1).build(); + DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)")) { + ArrowReader reader = df.executeStream(allocator); + assertTrue(reader.loadNextBatch()); + reader.close(); + } + } + + @Test + void executeStreamOverParquetMatchesCollectRowCount() throws Exception { + Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet"); + Assumptions.assumeTrue( + Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first"); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerParquet("lineitem", lineitem.toAbsolutePath().toString()); + + long collected; + try (DataFrame df = ctx.sql("SELECT COUNT(*) FROM lineitem"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + collected = v.get(0); + } + + long streamed = 0; + try (DataFrame df = ctx.sql("SELECT l_orderkey FROM lineitem"); + ArrowReader reader = df.executeStream(allocator)) { + while (reader.loadNextBatch()) { + streamed += reader.getVectorSchemaRoot().getRowCount(); + } + } + assertEquals(collected, streamed); + } + } + + @Test + void executeStreamColumnValuesAreCorrect() throws Exception { + // Pin actual cell values, not just row counts: a regression that + // shipped wrong values per batch must be caught. 
+ try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = SessionContext.builder().batchSize(2).build(); + DataFrame df = ctx.sql("SELECT * FROM (VALUES (10), (20), (30), (40)) AS t(x) ORDER BY x"); + ArrowReader reader = df.executeStream(allocator)) { + java.util.List seen = new java.util.ArrayList<>(); + while (reader.loadNextBatch()) { + BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + for (int i = 0; i < v.getValueCount(); i++) { + seen.add(v.get(i)); + } + } + assertEquals(java.util.List.of(10L, 20L, 30L, 40L), seen); + } + } +} diff --git a/native/Cargo.lock b/native/Cargo.lock index 495cc60..bb9578f 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1141,6 +1141,7 @@ dependencies = [ "arrow", "datafusion", "datafusion-proto", + "futures", "jni", "prost", "prost-build", diff --git a/native/Cargo.toml b/native/Cargo.toml index 01dd002..b9fca20 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -28,6 +28,7 @@ crate-type = ["cdylib"] arrow = { version = "58", features = ["ffi"] } datafusion = "53.1.0" datafusion-proto = "53.1.0" +futures = "0.3" jni = "0.21" prost = "0.14" tokio = { version = "1", features = ["rt-multi-thread"] } diff --git a/native/src/lib.rs b/native/src/lib.rs index 08a919b..c32ca05 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -27,15 +27,19 @@ pub(crate) mod proto_gen { use std::path::PathBuf; use std::sync::{Arc, OnceLock}; +use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::ArrowError; use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream; -use datafusion::arrow::record_batch::RecordBatchIterator; +use datafusion::arrow::record_batch::{RecordBatchIterator, RecordBatchReader}; use datafusion::config::TableParquetOptions; use datafusion::dataframe::DataFrame; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::DataFusionError; use 
datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; +use futures::StreamExt; use jni::objects::{JByteArray, JClass, JObjectArray, JString}; use jni::sys::{jboolean, jint, jlong}; use jni::JNIEnv; @@ -152,6 +156,62 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo }) } +/// Bridges DataFusion's async [`SendableRecordBatchStream`] to the synchronous +/// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore +/// the Java `ArrowReader`) consumes. Each call to `next()` drives one +/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the +/// executor pipeline plus a single in-flight batch. +struct StreamingReader { + schema: SchemaRef, + stream: SendableRecordBatchStream, +} + +impl Iterator for StreamingReader { + type Item = Result; + + fn next(&mut self) -> Option { + runtime() + .block_on(self.stream.next()) + .map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))) + } +} + +impl RecordBatchReader for StreamingReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFrame<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ffi_stream_addr: jlong, +) { + try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> { + if handle == 0 { + return Err("DataFrame handle is null".into()); + } + if ffi_stream_addr == 0 { + return Err("ffi stream address is null".into()); + } + let df = unsafe { *Box::from_raw(handle as *mut DataFrame) }; + + let ffi: FFI_ArrowArrayStream = runtime().block_on(async { + let schema: SchemaRef = Arc::new(df.schema().as_arrow().clone()); + let stream = df.execute_stream().await?; + let reader = StreamingReader { schema, stream }; + Ok::<_, 
DataFusionError>(FFI_ArrowArrayStream::new(Box::new(reader))) + })?; + + unsafe { + std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi); + } + Ok(()) + }) +} + #[no_mangle] pub extern "system" fn Java_org_apache_datafusion_DataFrame_countRows<'local>( mut env: JNIEnv<'local>,