diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index 049761d..674341a 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -456,6 +456,41 @@ public void registerUdf(ScalarUdf udf) { registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), impl); } + /** + * Register a Java-implemented {@link TableProvider} under {@code name}. SQL queries that + * reference {@code name} call back into {@code provider} to fetch batches. + * + *

{@link TableProvider#schema()} is called once here, on the calling thread, and cached on the + * native side. {@link TableProvider#scan(org.apache.arrow.memory.BufferAllocator)} is called once + * per physical scan of the table, on a Tokio worker thread, so a single query may call it more + * than once (self-joins, {@code UNION ALL} over the same table); it must return a fresh, + * independent {@link org.apache.arrow.vector.ipc.ArrowReader} on every call, with its buffers + * allocated from the {@link org.apache.arrow.memory.BufferAllocator} the framework supplies. + * + *

This is the Java counterpart to DataFusion's Rust {@code SessionContext::register_table}. + * + * @throws IllegalArgumentException if {@code name} or {@code provider} is {@code null}. + * @throws IllegalStateException if {@code provider.schema()} returns {@code null}, or this + * context is closed. + * @throws RuntimeException if native registration fails. + */ + public void registerTable(String name, TableProvider provider) { + if (nativeHandle == 0) { + throw new IllegalStateException("SessionContext is closed"); + } + if (name == null) { + throw new IllegalArgumentException("registerTable name must be non-null"); + } + if (provider == null) { + throw new IllegalArgumentException("registerTable provider must be non-null"); + } + Schema schema = provider.schema(); + if (schema == null) { + throw new IllegalStateException("TableProvider.schema returned null"); + } + byte[] schemaIpc = serializeSchemaIpc(schema); + registerTableNative(nativeHandle, name, schemaIpc, provider); + } + private static byte[] serializeSchemaIpc(Schema schema) { ByteArrayOutputStream baos = new ByteArrayOutputStream(); try (BufferAllocator allocator = new RootAllocator(); @@ -523,4 +558,7 @@ private static native long readJsonWithOptions( private static native void registerScalarUdf( long handle, String name, byte[] signatureSchemaBytes, byte volatility, ScalarFunction impl); + + private static native void registerTableNative( + long handle, String name, byte[] schemaIpcBytes, TableProvider provider); } diff --git a/core/src/main/java/org/apache/datafusion/SimpleTableProvider.java b/core/src/main/java/org/apache/datafusion/SimpleTableProvider.java new file mode 100644 index 0000000..bad347d --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/SimpleTableProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import java.util.function.Function; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * A {@link TableProvider} that pairs a fixed {@link Schema} with a function that opens a fresh + * {@link ArrowReader} for each scan. Provided as a convenience for the common case where there is + * no projection / filter pushdown to implement. + * + *

Each call to {@link #scan(BufferAllocator)} invokes the supplied function and returns whatever + * {@link ArrowReader} it produces, so the function MUST return a fresh, independent reader on every + * invocation (see the contract on {@link TableProvider#scan(BufferAllocator)}). + * + *

As {@link TableProvider} grows additional methods in the future, this class will provide + * defaults so existing callers keep working without changes. + */ +public final class SimpleTableProvider implements TableProvider { + + private final Schema schema; + private final Function scanFn; + + /** + * @param schema the table schema; returned as-is from {@link #schema()} + * @param scanFn called on every {@link #scan(BufferAllocator)} with the framework-supplied + * allocator; must return a fresh, independent {@link ArrowReader} each time + */ + public SimpleTableProvider(Schema schema, Function scanFn) { + if (schema == null) { + throw new IllegalArgumentException("schema must be non-null"); + } + if (scanFn == null) { + throw new IllegalArgumentException("scanFn must be non-null"); + } + this.schema = schema; + this.scanFn = scanFn; + } + + @Override + public Schema schema() { + return schema; + } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + return scanFn.apply(allocator); + } +} diff --git a/core/src/main/java/org/apache/datafusion/TableProvider.java b/core/src/main/java/org/apache/datafusion/TableProvider.java new file mode 100644 index 0000000..a50bc00 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/TableProvider.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * A Java-implemented table that can be registered with a {@link SessionContext} via {@link + * SessionContext#registerTable(String, TableProvider)}. Mirrors the role of DataFusion's Rust + * {@code TableProvider} trait, but at present only exposes the methods needed for a full table + * scan; future versions may add filter/projection pushdown and multi-partition support as default + * methods so existing implementations keep working. + * + *

{@link SimpleTableProvider} is a ready-made implementation for the common case of "I have a + * schema and a function that returns an {@link ArrowReader}". + * + *

Each call to {@link #scan(BufferAllocator)} must return a fresh, independent {@link + * ArrowReader} so that queries which touch the table more than once (self-joins, {@code UNION ALL}, + * repeated reads) work correctly. The returned reader is closed by the framework when the stream + * ends. + * + *

The schema returned by {@link #schema()} is captured once at registration time. Every batch + * produced by every {@code ArrowReader} returned from {@link #scan(BufferAllocator)} must conform + * to it; a mismatch fails the query. + */ +public interface TableProvider { + /** The fixed schema of this table. Called once, at registration time. */ + Schema schema(); + + /** + * Open a fresh batch stream for this table. Called once per physical scan of the table — a single + * query may invoke this more than once (self-joins, {@code UNION ALL} over the same table, etc.). + * + *

Each invocation MUST return an independent {@link ArrowReader}. The reader's schema MUST + * equal {@link #schema()}. The reader's buffers MUST be allocated from {@code allocator} (or from + * a child of it) — the framework needs the reader's allocator hierarchy to share a root with the + * one it passes here. The allocator contract mirrors the one on {@link ScalarFunction#evaluate}. + */ + ArrowReader scan(BufferAllocator allocator); +} diff --git a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java index 8248357..5d87c1a 100644 --- a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java +++ b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java @@ -23,14 +23,17 @@ import java.util.List; import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.datafusion.ColumnarValue; import org.apache.datafusion.ScalarFunction; import org.apache.datafusion.ScalarFunctionArgs; +import org.apache.datafusion.TableProvider; /** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */ public final class JniBridge { @@ -139,4 +142,34 @@ public static byte invokeScalarUdf( return resultKind; } } + + /** + * Open a fresh batch stream from a Java {@link TableProvider} and export it through the supplied + * Arrow C Data Interface address. Called from native code; not for application use. + * + *

{@link TableProvider#scan(org.apache.arrow.memory.BufferAllocator)} is called with {@link + * #ALLOCATOR} so that the reader's buffers share the same allocator root required by {@link + * Data#exportArrayStream}. + * + *

On success, ownership of the returned reader transfers to the FFI stream's release callback, + * so the native side closing the stream also closes the reader. On any failure during export, the + * reader is closed here before the exception propagates. + */ + public static void invokeTableScan(TableProvider provider, long ffiStreamAddr) { + ArrowReader reader = provider.scan(ALLOCATOR); + if (reader == null) { + throw new IllegalStateException("TableProvider.scan returned null"); + } + ArrowArrayStream stream = ArrowArrayStream.wrap(ffiStreamAddr); + try { + Data.exportArrayStream(ALLOCATOR, reader, stream); + } catch (Throwable t) { + try { + reader.close(); + } catch (Exception ignored) { + // best-effort cleanup; original error wins + } + throw t; + } + } } diff --git a/core/src/test/java/org/apache/datafusion/TableProviderTest.java b/core/src/test/java/org/apache/datafusion/TableProviderTest.java new file mode 100644 index 0000000..eb58097 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/TableProviderTest.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; + +class TableProviderTest { + + private static final ArrowType INT32 = new ArrowType.Int(32, true); + private static final ArrowType UTF8 = new ArrowType.Utf8(); + + /** + * In-memory {@link TableProvider} fixture. The batches are serialised to Arrow IPC bytes once at + * construction (using a private allocator); each {@link #scan(BufferAllocator)} call returns a + * fresh {@link ArrowStreamReader} backed by those bytes, using the framework-supplied allocator. + */ + static final class InMemoryTableProvider implements TableProvider { + private final Schema schema; + private final byte[] ipcBytes; + private final AtomicInteger scanCount = new AtomicInteger(); + + InMemoryTableProvider(Schema schema, byte[] ipcBytes) { + this.schema = schema; + this.ipcBytes = ipcBytes; + } + + /** + * Build a fixture from one or more vector-schema-root batches. 
The caller's allocator may be a + * temporary RootAllocator; this constructor reads all data into IPC bytes immediately. + */ + static InMemoryTableProvider fromBatches(Schema schema, List batches) { + return new InMemoryTableProvider(schema, serializeBatches(schema, batches)); + } + + static byte[] serializeBatches(Schema schema, List batches) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (BufferAllocator tmp = new RootAllocator(); + VectorSchemaRoot stagingRoot = VectorSchemaRoot.create(schema, tmp); + ArrowStreamWriter writer = + new ArrowStreamWriter(stagingRoot, null, Channels.newChannel(baos))) { + writer.start(); + for (VectorSchemaRoot batch : batches) { + stagingRoot.allocateNew(); + int rowCount = batch.getRowCount(); + stagingRoot.setRowCount(rowCount); + for (int i = 0; i < batch.getFieldVectors().size(); i++) { + org.apache.arrow.vector.FieldVector src = batch.getFieldVectors().get(i); + org.apache.arrow.vector.FieldVector dst = stagingRoot.getFieldVectors().get(i); + for (int r = 0; r < rowCount; r++) { + dst.copyFromSafe(r, r, src); + } + dst.setValueCount(rowCount); + } + writer.writeBatch(); + } + writer.end(); + } catch (IOException e) { + throw new RuntimeException("failed to serialize batches", e); + } + return baos.toByteArray(); + } + + @Override + public Schema schema() { + return schema; + } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + scanCount.incrementAndGet(); + return new ArrowStreamReader(new ByteArrayInputStream(ipcBytes), allocator); + } + + int scanCount() { + return scanCount.get(); + } + } + + /** Build a one-batch in-memory fixture of (id INT, name UTF8) with the given rows. 
*/ + private static InMemoryTableProvider buildTwoColumnTable(int[] ids, String[] names) { + Schema schema = + new Schema( + List.of( + new Field("id", FieldType.nullable(INT32), null), + new Field("name", FieldType.nullable(UTF8), null))); + try (BufferAllocator tmp = new RootAllocator(); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, tmp)) { + IntVector idVec = (IntVector) root.getVector("id"); + VarCharVector nameVec = (VarCharVector) root.getVector("name"); + int n = ids.length; + idVec.allocateNew(n); + nameVec.allocateNew(n); + for (int i = 0; i < n; i++) { + idVec.set(i, ids[i]); + nameVec.setSafe(i, names[i].getBytes()); + } + idVec.setValueCount(n); + nameVec.setValueCount(n); + root.setRowCount(n); + return InMemoryTableProvider.fromBatches(schema, List.of(root)); + } + } + + @Test + void registerTable_selectStar_returnsAllRows() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + InMemoryTableProvider src = + buildTwoColumnTable(new int[] {1, 2, 3}, new String[] {"a", "b", "c"}); + ctx.registerTable("t", src); + + try (DataFrame df = ctx.sql("SELECT id, name FROM t ORDER BY id"); + ArrowReader r = df.collect(allocator)) { + assertTrue(r.loadNextBatch()); + VectorSchemaRoot out = r.getVectorSchemaRoot(); + IntVector id = (IntVector) out.getVector("id"); + VarCharVector name = (VarCharVector) out.getVector("name"); + assertEquals(3, id.getValueCount()); + assertEquals(1, id.get(0)); + assertEquals(2, id.get(1)); + assertEquals(3, id.get(2)); + assertEquals("a", new String(name.get(0))); + assertEquals("b", new String(name.get(1))); + assertEquals("c", new String(name.get(2))); + while (r.loadNextBatch()) {} + } + assertEquals(1, src.scanCount()); + } + } + + @Test + void registerTable_unionAllSelf_callsScanTwice() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + InMemoryTableProvider src = 
buildTwoColumnTable(new int[] {1, 2}, new String[] {"a", "b"}); + ctx.registerTable("t", src); + + try (DataFrame df = ctx.sql("SELECT id FROM t UNION ALL SELECT id FROM t"); + ArrowReader r = df.collect(allocator)) { + long total = 0; + while (r.loadNextBatch()) { + IntVector id = (IntVector) r.getVectorSchemaRoot().getVector("id"); + total += id.getValueCount(); + } + assertEquals(4, total); + } + assertEquals(2, src.scanCount()); + } + } + + @Test + void registerTable_emptyStream_yieldsNoRows() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + Schema schema = new Schema(List.of(new Field("id", FieldType.nullable(INT32), null))); + InMemoryTableProvider src = InMemoryTableProvider.fromBatches(schema, List.of()); + ctx.registerTable("t", src); + + try (DataFrame df = ctx.sql("SELECT id FROM t"); + ArrowReader r = df.collect(allocator)) { + long total = 0; + while (r.loadNextBatch()) { + IntVector id = (IntVector) r.getVectorSchemaRoot().getVector("id"); + total += id.getValueCount(); + } + assertEquals(0, total); + } + } + } + + @Test + void registerTable_projectSingleColumn_returnsOnlyThatColumn() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + InMemoryTableProvider src = + buildTwoColumnTable(new int[] {10, 20, 30}, new String[] {"x", "y", "z"}); + ctx.registerTable("t", src); + + try (DataFrame df = ctx.sql("SELECT name FROM t ORDER BY name"); + ArrowReader r = df.collect(allocator)) { + assertTrue(r.loadNextBatch()); + VectorSchemaRoot out = r.getVectorSchemaRoot(); + assertEquals(1, out.getSchema().getFields().size()); + VarCharVector name = (VarCharVector) out.getVector("name"); + assertEquals(3, name.getValueCount()); + assertEquals("x", new String(name.get(0))); + assertEquals("y", new String(name.get(1))); + assertEquals("z", new String(name.get(2))); + while (r.loadNextBatch()) {} + } + } + } + + static 
final class ThrowingTableProvider implements TableProvider { + @Override + public Schema schema() { + return new Schema(List.of(new Field("id", FieldType.nullable(INT32), null))); + } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + throw new IllegalArgumentException("custom boom from TableProvider"); + } + } + + @Test + void registerTable_scanThrows_propagatesClassAndMessage() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerTable("t", new ThrowingTableProvider()); + + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT id FROM t"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + String msg = ex.getMessage(); + assertTrue( + msg.contains("IllegalArgumentException"), + "expected exception class in error, got: " + msg); + assertTrue( + msg.contains("custom boom from TableProvider"), + "expected user message in error, got: " + msg); + } + } + + static final class NullReturningTableProvider implements TableProvider { + @Override + public Schema schema() { + return new Schema(List.of(new Field("id", FieldType.nullable(INT32), null))); + } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + return null; + } + } + + @Test + void registerTable_scanReturnsNull_failsWithIllegalStateException() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerTable("t", new NullReturningTableProvider()); + + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT id FROM t"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + String msg = ex.getMessage(); + assertTrue( + msg.contains("IllegalStateException"), + "expected IllegalStateException in error, got: " + 
msg); + assertTrue(msg.contains("returned null"), "expected 'returned null' wording, got: " + msg); + } + } + + /** Declares (id INT) but scan() returns (id INT, extra UTF8). */ + static final class SchemaLyingTableProvider implements TableProvider { + @Override + public Schema schema() { + return new Schema(List.of(new Field("id", FieldType.nullable(INT32), null))); + } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + Schema actualSchema = + new Schema( + List.of( + new Field("id", FieldType.nullable(INT32), null), + new Field("extra", FieldType.nullable(UTF8), null))); + try (BufferAllocator tmp = new RootAllocator(); + VectorSchemaRoot root = VectorSchemaRoot.create(actualSchema, tmp)) { + root.setRowCount(0); + byte[] ipc = InMemoryTableProvider.serializeBatches(actualSchema, List.of(root)); + return new ArrowStreamReader(new ByteArrayInputStream(ipc), allocator); + } + } + } + + @Test + void registerTable_schemaMismatch_failsQueryWithReadableError() { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerTable("t", new SchemaLyingTableProvider()); + + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT id FROM t"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + String msg = ex.getMessage(); + assertTrue( + msg.contains("registered schema") || msg.contains("returned schema"), + "expected schema-mismatch wording, got: " + msg); + } + } + + @Test + void registerTable_twoTables_joinable() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + InMemoryTableProvider left = buildTwoColumnTable(new int[] {1, 2}, new String[] {"a", "b"}); + InMemoryTableProvider right = buildTwoColumnTable(new int[] {2, 3}, new String[] {"B", "C"}); + ctx.registerTable("l", left); + ctx.registerTable("r", 
right); + + int totalRows = 0; + int lidVal = -1; + int ridVal = -1; + String lnameVal = null; + String rnameVal = null; + + try (DataFrame df = + ctx.sql( + "SELECT l.id AS lid, r.id AS rid, l.name AS lname, r.name AS rname" + + " FROM l JOIN r ON l.id = r.id"); + ArrowReader rd = df.collect(allocator)) { + while (rd.loadNextBatch()) { + VectorSchemaRoot out = rd.getVectorSchemaRoot(); + IntVector lid = (IntVector) out.getVector("lid"); + IntVector rid = (IntVector) out.getVector("rid"); + VarCharVector lname = (VarCharVector) out.getVector("lname"); + VarCharVector rname = (VarCharVector) out.getVector("rname"); + int n = lid.getValueCount(); + for (int i = 0; i < n; i++) { + if (totalRows == 0) { + lidVal = lid.get(i); + ridVal = rid.get(i); + lnameVal = new String(lname.get(i)); + rnameVal = new String(rname.get(i)); + } + totalRows++; + } + } + } + assertEquals(1, totalRows); + assertEquals(2, lidVal); + assertEquals(2, ridVal); + assertEquals("b", lnameVal); + assertEquals("B", rnameVal); + assertEquals(1, left.scanCount()); + assertEquals(1, right.scanCount()); + } + } + + @Test + void simpleTableProvider_registerAndQuery_returnsRows() throws Exception { + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + InMemoryTableProvider backing = + buildTwoColumnTable(new int[] {1, 2}, new String[] {"a", "b"}); + TableProvider provider = new SimpleTableProvider(backing.schema(), backing::scan); + ctx.registerTable("t", provider); + + try (DataFrame df = ctx.sql("SELECT id, name FROM t ORDER BY id"); + ArrowReader r = df.collect(allocator)) { + assertTrue(r.loadNextBatch()); + VectorSchemaRoot out = r.getVectorSchemaRoot(); + IntVector id = (IntVector) out.getVector("id"); + VarCharVector name = (VarCharVector) out.getVector("name"); + assertEquals(2, id.getValueCount()); + assertEquals(1, id.get(0)); + assertEquals(2, id.get(1)); + assertEquals("a", new String(name.get(0))); + assertEquals("b", new 
String(name.get(1))); + while (r.loadNextBatch()) {} + } + assertEquals(1, backing.scanCount()); + } + } +} diff --git a/docs/source/user-guide/index.md b/docs/source/user-guide/index.md index 599340c..85c2bed 100644 --- a/docs/source/user-guide/index.md +++ b/docs/source/user-guide/index.md @@ -38,6 +38,7 @@ dataframe parquet proto-plans scalar-udf +table-provider api-reference ``` diff --git a/docs/source/user-guide/table-provider.md b/docs/source/user-guide/table-provider.md new file mode 100644 index 0000000..7eed07d --- /dev/null +++ b/docs/source/user-guide/table-provider.md @@ -0,0 +1,118 @@ + + +# Java table providers + +`SessionContext.registerTable(name, provider)` registers a Java-implemented +table. SQL queries that reference `name` call back into your `TableProvider` +to fetch batches. Data flows from Java to native code via the Arrow C Data +Interface, so there are no extra copies in the hot path. This is the Java +counterpart to DataFusion's Rust `SessionContext::register_table`. + +## Implement + +```java +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.datafusion.TableProvider; + +public final class MyTable implements TableProvider { + private final Schema schema; + + public MyTable(Schema schema) { + this.schema = schema; + } + + @Override public Schema schema() { return schema; } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + // Return a fresh ArrowReader. The reader must allocate its buffers + // from `allocator` (or a child of it) — the framework needs the + // allocator hierarchy to share a root. 
+ return openMyReader(allocator); + } +} +``` + +For the common case of "I have a schema and a function that returns an +`ArrowReader`," `SimpleTableProvider` packages those two into a ready-made +`TableProvider` without having to subclass: + +```java +TableProvider t = new SimpleTableProvider(mySchema(), allocator -> openMyReader(allocator)); +ctx.registerTable("t", t); +``` + +## Register and query + +```java +try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerTable("t", new MyTable(mySchema())); + + try (DataFrame df = ctx.sql("SELECT * FROM t WHERE x > 10"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) { + // ... + } + } +} +``` + +## Contract + +- `schema()` is called exactly once, on the caller's thread, at registration + time. Throwing from it aborts registration with the original exception. +- `scan(allocator)` is called once per physical scan of the table, on a + worker thread. A single SQL query may therefore call it more than once + (self-joins, `UNION ALL` over the same table), so it must return a fresh, + independent `ArrowReader` on every call. +- The reader returned by `scan` must allocate its buffers from the supplied + `allocator` (or a child of it). Arrow Java's `Data.exportArrayStream` + requires the reader's allocator and the export allocator to share a root. +- The returned reader's schema must equal the schema returned by `schema()`. + A mismatch fails the query. +- You do not need to close the returned reader yourself. The framework + installs a release callback that closes it when the underlying FFI stream + is dropped. + +## Errors + +Exceptions thrown from `scan()` or from the returned reader surface in the +`RuntimeException` raised by `collect()`. The error message includes the Java +exception class and `getMessage()`, in the same format used for scalar UDF +errors. 
+ +## Threading + +`SessionContext` is single-threaded, but `scan(allocator)` may be invoked from +any DataFusion worker thread. If your implementation maintains mutable state +across scans, synchronise it. + +## Limitations (v1) + +- Single-partition scans only. DataFusion sees the table as one partition; + multi-partition parallelism is a follow-up. +- No projection or filter pushdown. DataFusion applies projection and + filters on top of the batches you return; the Java side always sees the + full schema. The interface is intentionally minimal so it can grow these + capabilities (as default methods) without breaking existing implementations. +- No `deregisterTable`. Tables live until the `SessionContext` is closed. diff --git a/examples/pom.xml b/examples/pom.xml index 885e0d4..97a6b40 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -53,6 +53,17 @@ under the License. arrow-memory-netty runtime + + + org.apache.arrow + arrow-jdbc + ${arrow.version} + + + com.h2database + h2 + 2.3.232 + diff --git a/examples/src/main/java/org/apache/datafusion/examples/JdbcExample.java b/examples/src/main/java/org/apache/datafusion/examples/JdbcExample.java new file mode 100644 index 0000000..a48b203 --- /dev/null +++ b/examples/src/main/java/org/apache/datafusion/examples/JdbcExample.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.examples; + +import java.io.IOException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Calendar; +import java.util.TimeZone; + +import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; +import org.apache.arrow.adapter.jdbc.JdbcToArrow; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig; +import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder; +import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.datafusion.DataFrame; +import org.apache.datafusion.SessionContext; +import org.apache.datafusion.TableProvider; + +/** + * Demonstrates a JDBC-backed {@link TableProvider}. Populates an H2 in-memory table, registers it + * with DataFusion via {@link SessionContext#registerTable}, and runs an aggregation query against + * it. + * + *

Run with: + * + *

+ *   ./mvnw -pl examples exec:exec -Dexec.mainClass=org.apache.datafusion.examples.JdbcExample
+ * 
+ */ +public final class JdbcExample { + + /** + * Read-only DataFusion table backed by a JDBC query. Schema is captured at construction time from + * {@link PreparedStatement#getMetaData()}; each {@link #scan} re-executes the query and streams + * the result through {@code arrow-jdbc}'s {@link ArrowVectorIterator}. + */ + public static final class JdbcTableProvider implements TableProvider { + private final String url; + private final String query; + private final Schema schema; + + public JdbcTableProvider(String url, String query) { + this.url = url; + this.query = query; + this.schema = fetchSchema(); + } + + private Schema fetchSchema() { + try (BufferAllocator tmp = new RootAllocator(); + Connection conn = DriverManager.getConnection(url); + PreparedStatement stmt = conn.prepareStatement(query)) { + JdbcToArrowConfig config = configFor(tmp); + return JdbcToArrowUtils.jdbcToArrowSchema(stmt.getMetaData(), config); + } catch (SQLException e) { + throw new RuntimeException("Failed to fetch JDBC schema for query: " + query, e); + } + } + + @Override + public Schema schema() { + return schema; + } + + @Override + public ArrowReader scan(BufferAllocator allocator) { + // Run the query and stream the result through arrow-jdbc's ArrowVectorIterator, which + // emits VectorSchemaRoots whose buffers are allocated from `allocator`. We wrap that + // iterator in a small ArrowReader subclass so that DataFusion can consume it. The + // JDBC Connection / ResultSet are kept open by the closure created in iteratorOf() + // and closed by JdbcArrowReader.closeReadSource(). 
+ JdbcToArrowConfig config = configFor(allocator); + OpenedQuery opened = openQuery(config); + try { + ArrowVectorIterator iter = JdbcToArrow.sqlToArrowVectorIterator(opened.rs, config); + return new JdbcArrowReader(allocator, schema, iter, opened); + } catch (SQLException | IOException e) { + opened.closeQuietly(); + throw new RuntimeException("Failed to create JDBC iterator for query: " + query, e); + } + } + + private OpenedQuery openQuery(JdbcToArrowConfig config) { + try { + Connection conn = DriverManager.getConnection(url); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery(query); + return new OpenedQuery(conn, stmt, rs); + } catch (SQLException e) { + throw new RuntimeException("Failed to execute JDBC query: " + query, e); + } + } + + private static JdbcToArrowConfig configFor(BufferAllocator allocator) { + return new JdbcToArrowConfigBuilder( + allocator, Calendar.getInstance(TimeZone.getTimeZone("UTC"))) + .build(); + } + } + + /** Bundle of JDBC handles to close together when the scan finishes. */ + private static final class OpenedQuery { + final Connection conn; + final Statement stmt; + final ResultSet rs; + + OpenedQuery(Connection conn, Statement stmt, ResultSet rs) { + this.conn = conn; + this.stmt = stmt; + this.rs = rs; + } + + void closeQuietly() { + try { + rs.close(); + } catch (SQLException ignored) { + // best-effort + } + try { + stmt.close(); + } catch (SQLException ignored) { + // best-effort + } + try { + conn.close(); + } catch (SQLException ignored) { + // best-effort + } + } + } + + /** + * {@link ArrowReader} backed by an {@link ArrowVectorIterator}. Each {@link #loadNextBatch} pulls + * the next {@link VectorSchemaRoot} from the iterator and transfers its data into the reader's + * managed root via {@link VectorUnloader}/{@link VectorLoader}. 
+ */ + private static final class JdbcArrowReader extends ArrowReader { + private final Schema schema; + private final ArrowVectorIterator iter; + private final OpenedQuery opened; + + JdbcArrowReader( + BufferAllocator allocator, Schema schema, ArrowVectorIterator iter, OpenedQuery opened) { + super(allocator); + this.schema = schema; + this.iter = iter; + this.opened = opened; + } + + @Override + protected Schema readSchema() { + return schema; + } + + @Override + public boolean loadNextBatch() throws IOException { + if (!iter.hasNext()) { + return false; + } + try (VectorSchemaRoot batch = iter.next()) { + VectorUnloader unloader = new VectorUnloader(batch); + try (ArrowRecordBatch rb = unloader.getRecordBatch()) { + new VectorLoader(getVectorSchemaRoot()).load(rb); + } + } + return true; + } + + @Override + public long bytesRead() { + return 0; + } + + @Override + protected void closeReadSource() { + iter.close(); + opened.closeQuietly(); + } + } + + public static void main(String[] args) throws Exception { + String url = "jdbc:h2:mem:demo;DB_CLOSE_DELAY=-1"; + + // Populate an H2 in-memory table. Column names are double-quoted so H2 stores them in the + // exact (lowercase) case; arrow-jdbc uses ResultSetMetaData.getColumnName, which returns + // the stored case, so the resulting Arrow schema ends up with lowercase field names that + // DataFusion SQL can refer to without quoting. 
+ try (Connection conn = DriverManager.getConnection(url); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE \"orders\" (" + + "\"id\" INT PRIMARY KEY," + + " \"customer\" VARCHAR(64) NOT NULL," + + " \"total\" DOUBLE NOT NULL)"); + stmt.execute( + "INSERT INTO \"orders\" VALUES" + + " (1, 'alice', 19.99)," + + " (2, 'bob', 7.50)," + + " (3, 'alice', 100.00)"); + } + + JdbcTableProvider src = + new JdbcTableProvider(url, "SELECT \"id\", \"customer\", \"total\" FROM \"orders\""); + + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerTable("orders", src); + + try (DataFrame df = + ctx.sql( + "SELECT customer, SUM(total) AS spend" + + " FROM orders GROUP BY customer ORDER BY customer"); + ArrowReader reader = df.collect(allocator)) { + System.out.printf("%-10s | %s%n", "customer", "spend"); + System.out.println("-----------+--------"); + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + VarCharVector customer = (VarCharVector) root.getVector("customer"); + Float8Vector spend = (Float8Vector) root.getVector("spend"); + for (int i = 0; i < customer.getValueCount(); i++) { + System.out.printf("%-10s | %.2f%n", new String(customer.get(i)), spend.get(i)); + } + } + } + } + } +} diff --git a/native/Cargo.lock b/native/Cargo.lock index 418c359..7171f72 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1250,6 +1250,7 @@ name = "datafusion-jni" version = "0.1.0" dependencies = [ "arrow", + "async-trait", "datafusion", "datafusion-proto", "futures", diff --git a/native/Cargo.toml b/native/Cargo.toml index 983d7eb..28e1e8f 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -26,6 +26,7 @@ crate-type = ["cdylib"] [dependencies] arrow = { version = "58", features = ["ffi"] } +async-trait = "0.1" datafusion = { version = "53.1.0", features = ["avro"] } datafusion-proto = "53.1.0" futures = "0.3" diff --git a/native/src/jni_util.rs 
b/native/src/jni_util.rs new file mode 100644 index 0000000..daa2b63 --- /dev/null +++ b/native/src/jni_util.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Small shared helpers for JNI call sites. + +use jni::objects::JThrowable; +use jni::JNIEnv; + +/// Best-effort: extract class name and `getMessage()` from a Java throwable. +/// Anything that goes wrong collapses to a generic message so we don't +/// double-throw inside an error path. +/// +/// `kind` and `name` are used to build the surfaced error message +/// (e.g., `kind="UDF" name="add_one"` -> `"Java UDF 'add_one' threw ..."`). 
+pub(crate) fn jthrowable_to_string( + env: &mut JNIEnv, + throwable: &JThrowable, + kind: &str, + name: &str, +) -> String { + let class_name_result = (|| -> jni::errors::Result { + let class = env.call_method(throwable, "getClass", "()Ljava/lang/Class;", &[])?; + let class_obj = class.l()?; + let n = env.call_method(&class_obj, "getName", "()Ljava/lang/String;", &[])?; + let n_obj = n.l()?; + let n_str: String = env.get_string(&n_obj.into())?.into(); + Ok(n_str) + })(); + let class_name = match class_name_result { + Ok(s) => s, + Err(_) => { + env.exception_clear().ok(); + "".to_string() + } + }; + + let message_result = (|| -> jni::errors::Result { + let msg = env.call_method(throwable, "getMessage", "()Ljava/lang/String;", &[])?; + let msg_obj = msg.l()?; + if msg_obj.is_null() { + return Ok("".to_string()); + } + let s: String = env.get_string(&msg_obj.into())?.into(); + Ok(s) + })(); + let message = match message_result { + Ok(s) => s, + Err(_) => { + env.exception_clear().ok(); + "".to_string() + } + }; + + format!("Java {} '{}' threw {}: {}", kind, name, class_name, message) +} diff --git a/native/src/lib.rs b/native/src/lib.rs index 1d0f36d..a235cd3 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -19,9 +19,11 @@ mod arrow; mod avro; mod csv; mod errors; +mod jni_util; mod json; mod proto; mod schema; +mod table_provider; mod udf; pub(crate) mod proto_gen { @@ -765,3 +767,45 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU Ok(()) }) } + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerTableNative<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + name: JString<'local>, + schema_ipc_bytes: JByteArray<'local>, + provider: JObject<'local>, +) { + try_unwrap_or_throw(&mut env, (), |env| -> JniResult<()> { + if handle == 0 { + return Err("SessionContext handle is null".into()); + } + // SAFETY: handle is a valid Box allocated by createSessionContext. 
+ let ctx = unsafe { &*(handle as *const SessionContext) }; + let name: String = env.get_string(&name)?.into(); + + let schema = crate::schema::decode_optional_schema(env, schema_ipc_bytes)? + .ok_or("schema bytes were null")?; + let schema = Arc::new(schema); + + let source_global_ref = Arc::new(env.new_global_ref(&provider)?); + let bridge_class_local = env.find_class("org/apache/datafusion/internal/JniBridge")?; + let bridge_class = Arc::new(env.new_global_ref(&bridge_class_local)?); + let invoke_method = env.get_static_method_id( + &bridge_class_local, + "invokeTableScan", + "(Lorg/apache/datafusion/TableProvider;J)V", + )?; + + let java_tp = crate::table_provider::JavaTableProvider { + name: name.clone(), + schema, + source_global_ref, + bridge_class, + invoke_method, + }; + let _ = ctx.register_table(name.as_str(), Arc::new(java_tp))?; + Ok(()) + }) +} diff --git a/native/src/table_provider.rs b/native/src/table_provider.rs new file mode 100644 index 0000000..70eaac2 --- /dev/null +++ b/native/src/table_provider.rs @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Java-backed [`TableProvider`] implementation. +//! +//! 
Used by `SessionContext::registerTable` on the Java side to register user-implemented +//! `TableProvider`s. The internal struct here mirrors the role of DataFusion's Rust +//! `TableProvider` trait; it currently only supports a single-partition, no-pushdown scan, +//! with future pushdown and partitioning support tracked as follow-ups. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; +use datafusion::arrow::record_batch::{RecordBatch, RecordBatchReader}; +use datafusion::catalog::Session; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::context::TaskContext; +use datafusion::logical_expr::Expr; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, +}; +use futures::stream::StreamExt; +use jni::objects::{GlobalRef, JStaticMethodID}; +use jni::signature::{Primitive, ReturnType}; +use jni::sys::{jlong, jvalue}; + +use crate::jni_util::jthrowable_to_string; + +pub(crate) struct JavaTableProvider { + pub(crate) name: String, + pub(crate) schema: SchemaRef, + pub(crate) source_global_ref: Arc, + pub(crate) bridge_class: Arc, + pub(crate) invoke_method: JStaticMethodID, +} + +// SAFETY: see the matching unsafe impls on JavaScalarUdf. The GlobalRefs keep +// the Java objects alive; JStaticMethodID points into the class held by +// bridge_class; nothing is mutated after construction. 
+unsafe impl Send for JavaTableProvider {} +unsafe impl Sync for JavaTableProvider {} + +impl fmt::Debug for JavaTableProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JavaTableProvider") + .field("name", &self.name) + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for JavaTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let projected_schema = match projection { + Some(p) => Arc::new(self.schema.project(p)?), + None => Arc::clone(&self.schema), + }; + let plan_properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&projected_schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Ok(Arc::new(JavaScanExec { + name: self.name.clone(), + full_schema: Arc::clone(&self.schema), + projected_schema, + projection: projection.cloned(), + source_global_ref: Arc::clone(&self.source_global_ref), + bridge_class: Arc::clone(&self.bridge_class), + invoke_method: self.invoke_method, + plan_properties, + })) + } +} + +pub(crate) struct JavaScanExec { + name: String, + full_schema: SchemaRef, + projected_schema: SchemaRef, + projection: Option>, + source_global_ref: Arc, + bridge_class: Arc, + invoke_method: JStaticMethodID, + plan_properties: Arc, +} + +// SAFETY: same reasoning as JavaTableProvider above — GlobalRefs via Arc keep +// Java objects alive; JStaticMethodID is stable; nothing mutated after construction. 
+unsafe impl Send for JavaScanExec {} +unsafe impl Sync for JavaScanExec {} + +impl fmt::Debug for JavaScanExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JavaScanExec") + .field("name", &self.name) + .field("projected_schema", &self.projected_schema) + .finish() + } +} + +impl DisplayAs for JavaScanExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "JavaScanExec: name={}", self.name) + } +} + +impl ExecutionPlan for JavaScanExec { + fn name(&self) -> &str { + "JavaScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + _ctx: Arc, + ) -> Result { + if partition != 0 { + return Err(DataFusionError::Internal(format!( + "JavaScanExec has 1 partition; got {}", + partition + ))); + } + + // 1. Allocate an empty FFI stream and box it for a stable address. + let mut ffi_box = Box::new(FFI_ArrowArrayStream::empty()); + let ffi_addr = ffi_box.as_mut() as *mut FFI_ArrowArrayStream as jlong; + + // 2. Attach the JVM and call the bridge. + // + // The attachment scope is just this function: we need the JVM attached for + // the synchronous `invokeTableScan` call. Subsequent polls of the + // returned stream do not need this attachment, because the FFI release / + // get_next callbacks installed by arrow-java's `Data.exportArrayStream` + // self-attach to the JVM via the global `JavaVM` set in our `JNI_OnLoad`. 
+ let mut env = crate::jvm() + .attach_current_thread() + .map_err(|e| DataFusionError::Execution(format!("JNI attach failed: {}", e)))?; + + let source_jobj = self.source_global_ref.as_obj(); + let call_args: [jvalue; 2] = [ + jvalue { + l: source_jobj.as_raw(), + }, + jvalue { j: ffi_addr }, + ]; + + let call_result = unsafe { + env.call_static_method_unchecked( + self.bridge_class.as_ref(), + self.invoke_method, + ReturnType::Primitive(Primitive::Void), + &call_args, + ) + }; + + // 3. Surface any Java exception. + if env.exception_check().unwrap_or(false) { + let throwable = env.exception_occurred().map_err(|e| { + DataFusionError::Execution(format!("exception_occurred failed: {}", e)) + })?; + env.exception_clear().ok(); + return Err(DataFusionError::Execution(jthrowable_to_string( + &mut env, + &throwable, + "TableProvider", + &self.name, + ))); + } + + call_result.map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?; + + // 4. Reclaim the FFI struct and import as a RecordBatchReader. + let ffi_stream: FFI_ArrowArrayStream = *ffi_box; + let reader = ArrowArrayStreamReader::try_new(ffi_stream) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + + // 5. Verify the producer's declared schema matches our registered schema. + let reader_schema = reader.schema(); + // Schema::PartialEq compares fields AND metadata. If IPC / FFI round-trips + // ever normalise metadata differently between the registration path and the + // scan path, switch to comparing `.fields()` only. + if reader_schema.as_ref() != self.full_schema.as_ref() { + return Err(DataFusionError::Execution(format!( + "Java TableProvider '{}' returned schema {:?}; registered schema was {:?}", + self.name, reader_schema, self.full_schema + ))); + } + + // 6. Wrap as a Stream and (if a projection is set) project each batch. 
+ let projection = self.projection.clone(); + let stream = futures::stream::iter(reader).map(move |batch_result| { + let batch: RecordBatch = + batch_result.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + match &projection { + Some(p) => batch + .project(p) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)), + None => Ok(batch), + } + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.projected_schema), + stream, + ))) + } +} diff --git a/native/src/udf.rs b/native/src/udf.rs index d2b18b4..41258a7 100644 --- a/native/src/udf.rs +++ b/native/src/udf.rs @@ -28,10 +28,9 @@ use datafusion::error::DataFusionError; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use jni::objects::{GlobalRef, JStaticMethodID, JThrowable}; +use jni::objects::{GlobalRef, JStaticMethodID}; use jni::signature::{Primitive, ReturnType}; use jni::sys::{jbyte, jlong, jvalue}; -use jni::JNIEnv; pub(crate) struct JavaScalarUdf { pub(crate) name: String, @@ -230,7 +229,8 @@ impl ScalarUDFImpl for JavaScalarUdf { DataFusionError::Execution(format!("exception_occurred failed: {}", e)) })?; env.exception_clear().ok(); - let message = jthrowable_to_string(&mut env, &throwable, &self.name); + let message = + crate::jni_util::jthrowable_to_string(&mut env, &throwable, "UDF", &self.name); return Err(DataFusionError::Execution(message)); } @@ -292,45 +292,3 @@ pub(crate) fn volatility_from_byte(byte: u8) -> datafusion::error::Result String { - let class_name_result = (|| -> jni::errors::Result { - let class = env.call_method(throwable, "getClass", "()Ljava/lang/Class;", &[])?; - let class_obj = class.l()?; - let name = env.call_method(&class_obj, "getName", "()Ljava/lang/String;", &[])?; - let name_obj = name.l()?; - let name_str: String = env.get_string(&name_obj.into())?.into(); - Ok(name_str) - })(); - let class_name = match class_name_result { - Ok(s) => s, - Err(_) => { - // A 
reflective call itself threw — clear that secondary exception so the - // thread is in a clean state when we return to the JVM. - env.exception_clear().ok(); - "".to_string() - } - }; - - let message_result = (|| -> jni::errors::Result { - let msg = env.call_method(throwable, "getMessage", "()Ljava/lang/String;", &[])?; - let msg_obj = msg.l()?; - if msg_obj.is_null() { - return Ok("".to_string()); - } - let s: String = env.get_string(&msg_obj.into())?.into(); - Ok(s) - })(); - let message = match message_result { - Ok(s) => s, - Err(_) => { - env.exception_clear().ok(); - "".to_string() - } - }; - - format!("Java UDF '{}' threw {}: {}", udf_name, class_name, message) -}