36 changes: 36 additions & 0 deletions core/src/main/java/org/apache/datafusion/DataFrame.java
@@ -53,6 +53,10 @@ public final class DataFrame implements AutoCloseable {
   * <p>Consumes this DataFrame: the native plan is released as soon as the stream is established.
   * The caller is responsible for closing the returned reader, and the supplied allocator must
   * outlive it.
   *
   * <p>This method materializes every batch on the native heap before the first batch crosses the
   * FFI boundary, which can exhaust native memory on the Rust side for unbounded or very large
   * result sets. Prefer {@link #executeStream(BufferAllocator)} for queries whose results may not
   * fit comfortably in memory.
   */
  public ArrowReader collect(BufferAllocator allocator) {
    if (nativeHandle == 0) {
@@ -70,6 +74,36 @@ public ArrowReader collect(BufferAllocator allocator) {
    }
  }

  /**
   * Execute the plan and return its record batches as a streaming {@link ArrowReader}. Each call
   * to {@link ArrowReader#loadNextBatch} drives one async {@code stream.next()} on the native
   * side, so memory pressure stays bounded by the executor pipeline plus one in-flight batch
   * instead of the full result set.
   *
   * <p>Consumes this DataFrame with the same lifecycle rules as {@link #collect(BufferAllocator)}:
   * the native plan is released as soon as the stream is established, the caller closes the
   * returned reader, and the supplied allocator must outlive it.
   *
   * <p>For result sets that fit comfortably in native memory and are read in their entirety,
   * {@link #collect(BufferAllocator)} remains a reasonable choice. For TB-scale or unbounded
   * result sets, use this method.
   */
  public ArrowReader executeStream(BufferAllocator allocator) {
    if (nativeHandle == 0) {
      throw new IllegalStateException("DataFrame is closed or already collected");
    }
    ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator);
    long handle = nativeHandle;
    nativeHandle = 0;
    try {
      executeStreamDataFrame(handle, stream.memoryAddress());
      return Data.importArrayStream(allocator, stream);
    } catch (Throwable e) {
      stream.close();
      throw e;
    }
  }

  /** Execute the plan and return the number of rows. */
  public long count() {
    if (nativeHandle == 0) {
@@ -211,6 +245,8 @@ public void close() {

  private static native void collectDataFrame(long handle, long ffiStreamAddr);

  private static native void executeStreamDataFrame(long handle, long ffiStreamAddr);

  private static native void closeDataFrame(long handle);

  private static native long countRows(long handle);
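A minimal usage sketch of the new streaming path, under the same lifecycle rules the Javadoc above describes (the reader is closed by the caller, the allocator outlives it). The `StreamingExample` class name and the Parquet path are illustrative, not part of this PR; `registerParquet`, `sql`, and `executeStream` are the APIs exercised by the tests below.

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.datafusion.DataFrame;
import org.apache.datafusion.SessionContext;

public final class StreamingExample {
  public static void main(String[] args) throws Exception {
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = new SessionContext()) {
      // Illustrative path; any registered table works the same way.
      ctx.registerParquet("lineitem", "/path/to/lineitem.parquet");
      try (DataFrame df = ctx.sql("SELECT l_orderkey FROM lineitem");
          ArrowReader reader = df.executeStream(allocator)) {
        long rows = 0;
        while (reader.loadNextBatch()) {
          // One batch is in flight at a time: memory stays bounded no matter
          // how large the full result set is.
          rows += reader.getVectorSchemaRoot().getRowCount();
        }
        System.out.println("rows=" + rows);
      }
    }
  }
}

Compared with collect(BufferAllocator), the only change for callers is the method name; the consume-once semantics are identical.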
193 changes: 193 additions & 0 deletions core/src/test/java/org/apache/datafusion/DataFrameExecuteStreamTest.java
@@ -0,0 +1,193 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.datafusion;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class DataFrameExecuteStreamTest {

  /**
   * Write a CSV with {@code rows} integer rows in a single column {@code x}. Used in tests that
   * need a real file scan so that DataFusion's batching honors {@code batch_size}; in-memory
   * {@code VALUES} plans get coalesced into a single batch in some DataFusion versions, which
   * would make those tests brittle.
   */
  private static Path writeRowsCsv(Path dir, int rows) throws IOException {
    StringBuilder sb = new StringBuilder("x\n");
    for (int i = 1; i <= rows; i++) {
      sb.append(i).append('\n');
    }
    Path file = dir.resolve("rows.csv");
    Files.writeString(file, sb.toString());
    return file;
  }

  @Test
  void executeStreamYieldsTheSameRowsAsCollect() throws Exception {
    String sql = "SELECT * FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)";

    long collected = 0;
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = new SessionContext();
        DataFrame df = ctx.sql(sql);
        ArrowReader reader = df.collect(allocator)) {
      while (reader.loadNextBatch()) {
        collected += reader.getVectorSchemaRoot().getRowCount();
      }
    }

    long streamed = 0;
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = new SessionContext();
        DataFrame df = ctx.sql(sql);
        ArrowReader reader = df.executeStream(allocator)) {
      while (reader.loadNextBatch()) {
        streamed += reader.getVectorSchemaRoot().getRowCount();
      }
    }

    assertEquals(5L, collected);
    assertEquals(collected, streamed);
  }

  @Test
  void executeStreamConsumesTheDataFrame() throws Exception {
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = new SessionContext()) {
      DataFrame df = ctx.sql("SELECT 1");
      try (ArrowReader reader = df.executeStream(allocator)) {
        assertTrue(reader.loadNextBatch());
      }
      // After a successful executeStream, the DataFrame's native handle is
      // released. A second collect/executeStream/count must throw.
      assertThrows(IllegalStateException.class, () -> df.executeStream(allocator));
      assertThrows(IllegalStateException.class, () -> df.collect(allocator));
      assertThrows(IllegalStateException.class, df::count);
      // close() on an already-streamed DataFrame is a no-op (no double-free).
      df.close();
    }
  }

  @Test
  void executeStreamReadsBatchByBatch(@TempDir Path tempDir) throws Exception {
    // A CSV with 5 rows scanned at batch_size=2 reliably yields multiple batches
    // across DataFusion versions, whereas an in-memory VALUES plan can be
    // coalesced into a single batch by the planner. The point of this test is
    // to pin "executeStream actually streams" without coupling to planner
    // batching behavior that may shift in upstream releases.
    Path csv = writeRowsCsv(tempDir, 5);
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = SessionContext.builder().batchSize(2).build()) {
      ctx.registerCsv("rows", csv.toAbsolutePath().toString());
      try (DataFrame df = ctx.sql("SELECT x FROM rows");
          ArrowReader reader = df.executeStream(allocator)) {
        int batches = 0;
        long total = 0;
        int maxBatchSize = 0;
        while (reader.loadNextBatch()) {
          batches++;
          VectorSchemaRoot root = reader.getVectorSchemaRoot();
          total += root.getRowCount();
          maxBatchSize = Math.max(maxBatchSize, root.getRowCount());
        }
        assertEquals(5L, total);
        assertTrue(batches >= 2, "expected multiple batches with batchSize=2, got " + batches);
        assertTrue(maxBatchSize <= 2, "expected each batch <= 2 rows, got " + maxBatchSize);
      }
    }
  }

  @Test
  void executeStreamSurvivesEarlyClose() throws Exception {
    // Close the reader after the first batch and confirm no native panic /
    // resource leak. The DataFrame is already consumed; explicit close on it
    // must remain a no-op.
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = SessionContext.builder().batchSize(1).build();
        DataFrame df = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)")) {
      ArrowReader reader = df.executeStream(allocator);
      assertTrue(reader.loadNextBatch());
      reader.close();
    }
  }

  @Test
  void executeStreamOverParquetMatchesCollectRowCount() throws Exception {
    Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet");
    Assumptions.assumeTrue(
        Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first");

    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = new SessionContext()) {
      ctx.registerParquet("lineitem", lineitem.toAbsolutePath().toString());

      long collected;
      try (DataFrame df = ctx.sql("SELECT COUNT(*) FROM lineitem");
          ArrowReader reader = df.collect(allocator)) {
        assertTrue(reader.loadNextBatch());
        BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0);
        collected = v.get(0);
      }

      long streamed = 0;
      try (DataFrame df = ctx.sql("SELECT l_orderkey FROM lineitem");
          ArrowReader reader = df.executeStream(allocator)) {
        while (reader.loadNextBatch()) {
          streamed += reader.getVectorSchemaRoot().getRowCount();
        }
      }
      assertEquals(collected, streamed);
    }
  }

  @Test
  void executeStreamColumnValuesAreCorrect() throws Exception {
    // Pin actual cell values, not just row counts: a regression that ships the
    // right number of rows but wrong values per batch must still be caught.
    try (BufferAllocator allocator = new RootAllocator();
        SessionContext ctx = SessionContext.builder().batchSize(2).build();
        DataFrame df = ctx.sql("SELECT * FROM (VALUES (10), (20), (30), (40)) AS t(x) ORDER BY x");
        ArrowReader reader = df.executeStream(allocator)) {
      List<Long> seen = new ArrayList<>();
      while (reader.loadNextBatch()) {
        BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0);
        for (int i = 0; i < v.getValueCount(); i++) {
          seen.add(v.get(i));
        }
      }
      assertEquals(List.of(10L, 20L, 30L, 40L), seen);
    }
  }
}
1 change: 1 addition & 0 deletions native/Cargo.lock


1 change: 1 addition & 0 deletions native/Cargo.toml
@@ -28,6 +28,7 @@ crate-type = ["cdylib"]
arrow = { version = "58", features = ["ffi"] }
datafusion = "53.1.0"
datafusion-proto = "53.1.0"
futures = "0.3"
jni = "0.21"
prost = "0.14"
tokio = { version = "1", features = ["rt-multi-thread"] }
62 changes: 61 additions & 1 deletion native/src/lib.rs
@@ -27,15 +27,19 @@ pub(crate) mod proto_gen {
use std::path::PathBuf;
use std::sync::{Arc, OnceLock};

use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream;
use datafusion::arrow::record_batch::{RecordBatchIterator, RecordBatchReader};
use datafusion::config::TableParquetOptions;
use datafusion::dataframe::DataFrame;
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::error::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
use futures::StreamExt;
use jni::objects::{JByteArray, JClass, JObjectArray, JString};
use jni::sys::{jboolean, jint, jlong};
use jni::JNIEnv;
@@ -152,6 +156,62 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo
    })
}

/// Bridges DataFusion's async [`SendableRecordBatchStream`] to the synchronous
/// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore
/// the Java `ArrowReader`) consumes. Each call to `next()` drives one
/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the
/// executor pipeline plus a single in-flight batch.
struct StreamingReader {
    schema: SchemaRef,
    stream: SendableRecordBatchStream,
}

impl Iterator for StreamingReader {
    type Item = Result<RecordBatch, ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        runtime()
            .block_on(self.stream.next())
            .map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e))))
    }
}

impl RecordBatchReader for StreamingReader {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFrame<'local>(
    mut env: JNIEnv<'local>,
    _class: JClass<'local>,
    handle: jlong,
    ffi_stream_addr: jlong,
) {
    try_unwrap_or_throw(&mut env, (), |_env| -> JniResult<()> {
        if handle == 0 {
            return Err("DataFrame handle is null".into());
        }
        if ffi_stream_addr == 0 {
            return Err("ffi stream address is null".into());
        }
        let df = unsafe { *Box::from_raw(handle as *mut DataFrame) };

        let ffi: FFI_ArrowArrayStream = runtime().block_on(async {
            let schema: SchemaRef = Arc::new(df.schema().as_arrow().clone());
            let stream = df.execute_stream().await?;
            let reader = StreamingReader { schema, stream };
            Ok::<_, DataFusionError>(FFI_ArrowArrayStream::new(Box::new(reader)))
        })?;

        unsafe {
            std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi);
        }
        Ok(())
    })
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_countRows<'local>(
    mut env: JNIEnv<'local>,