diff --git a/core/src/main/java/org/apache/datafusion/ScalarUdf.java b/core/src/main/java/org/apache/datafusion/ScalarUdf.java
new file mode 100644
index 0000000..c546ee1
--- /dev/null
+++ b/core/src/main/java/org/apache/datafusion/ScalarUdf.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import java.util.List;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+
+/**
+ * A Java-implemented scalar SQL function. Register an instance with {@link
+ * SessionContext#registerUdf} to make it callable from SQL or DataFrame plans.
+ *
+ * <p>Implementations may be invoked concurrently by DataFusion on multiple worker threads. If the
+ * implementation carries mutable state, the implementation must synchronize it.
+ */
+@FunctionalInterface
+public interface ScalarUdf {
+  /**
+   * Compute the function result for one input batch.
+   *
+   * @param allocator the {@link BufferAllocator} that MUST be used for any new {@link FieldVector}
+   *     allocation, including the result. Buffers allocated from other allocators will not survive
+   *     the JNI handoff.
+   * @param args one {@link FieldVector} per declared argument, all of the same length. These are
+   *     read-only views; the implementation must NOT close them.
+   * @return a {@link FieldVector} of the declared return type and the same length as the inputs.
+   *     Ownership transfers to the framework on return; the implementation must NOT close the
+   *     returned vector.
+   */
+  FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args);
+}
diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java
index 3cea058..6cba42c 100644
--- a/core/src/main/java/org/apache/datafusion/SessionContext.java
+++ b/core/src/main/java/org/apache/datafusion/SessionContext.java
@@ -23,6 +23,8 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
+import java.util.ArrayList;
+import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@@ -30,6 +32,9 @@
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
/**
@@ -204,6 +209,41 @@ public DataFrame readParquet(String path, ParquetReadOptions options) {
return new DataFrame(dfHandle);
}
+  /**
+   * Register a Java-implemented scalar UDF. After registration, the function can be invoked by SQL
+   * via its {@code name} or referenced in DataFusion plans deserialised with {@link #fromProto}.
+   *
+   * <p>Argument and return types are declared at registration time. The UDF is registered with an
+   * exact signature: the runtime will reject calls whose argument types do not match {@code
+   * argTypes} exactly.
+   *
+   * @param name the function name used to call the UDF
+   * @param udf the implementation invoked for each batch
+   * @param returnType the Arrow type of the result column
+   * @param argTypes the Arrow types of the arguments, in call order
+   * @param volatility the volatility classification passed to the native side
+   * @throws IllegalStateException if this context has already been closed
+   * @throws NullPointerException if any argument is {@code null}
+   * @throws RuntimeException if registration fails (e.g., name already registered with an
+   *     incompatible signature, schema serialisation failure).
+   */
+  public void registerUdf(
+      String name,
+      ScalarUdf udf,
+      ArrowType returnType,
+      List<ArrowType> argTypes,
+      Volatility volatility) {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("SessionContext is closed");
+    }
+    java.util.Objects.requireNonNull(name, "name");
+    java.util.Objects.requireNonNull(udf, "udf");
+    java.util.Objects.requireNonNull(returnType, "returnType");
+    java.util.Objects.requireNonNull(argTypes, "argTypes");
+    java.util.Objects.requireNonNull(volatility, "volatility");
+    // The signature crosses JNI as one schema: field 0 is the return type, followed by one field
+    // per argument in declaration order.
+    List<Field> fields = new ArrayList<>(argTypes.size() + 1);
+    fields.add(new Field("return", FieldType.nullable(returnType), null));
+    for (int i = 0; i < argTypes.size(); i++) {
+      fields.add(new Field("arg" + i, FieldType.nullable(argTypes.get(i)), null));
+    }
+    Schema signatureSchema = new Schema(fields);
+    byte[] signatureBytes = serializeSchemaIpc(signatureSchema);
+    registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), udf);
+  }
+
private static byte[] serializeSchemaIpc(Schema schema) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (BufferAllocator allocator = new RootAllocator();
@@ -248,4 +288,7 @@ private static native long readCsvWithOptions(
long handle, String path, byte[] optionsBytes, byte[] schemaIpcBytes);
private static native void closeSessionContext(long handle);
+
+ // Native half of registerUdf. signatureSchemaBytes is the IPC-serialised signature schema built
+ // by registerUdf (return field first, then one field per argument); volatility is the byte from
+ // Volatility.code().
+ private static native void registerScalarUdf(
+ long handle, String name, byte[] signatureSchemaBytes, byte volatility, ScalarUdf udf);
}
diff --git a/core/src/main/java/org/apache/datafusion/Volatility.java b/core/src/main/java/org/apache/datafusion/Volatility.java
new file mode 100644
index 0000000..eea2fa3
--- /dev/null
+++ b/core/src/main/java/org/apache/datafusion/Volatility.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+/**
+ * Volatility classification for a UDF. Mirrors DataFusion's {@code Volatility} enum.
+ *
+ * <ul>
+ *   <li>{@link #IMMUTABLE} — pure function: same inputs always produce the same output; safe to
+ *       constant-fold and common-subexpression-eliminate.
+ *   <li>{@link #STABLE} — deterministic within a single query but not across queries (e.g., {@code
+ *       now()}).
+ *   <li>{@link #VOLATILE} — may return a different value on every call (e.g., {@code random()}).
+ * </ul>
+ */
+public enum Volatility {
+  IMMUTABLE((byte) 0),
+  STABLE((byte) 1),
+  VOLATILE((byte) 2);
+
+  private final byte code;
+
+  Volatility(byte code) {
+    this.code = code;
+  }
+
+  /** Stable byte code for FFI. */
+  public byte code() {
+    return code;
+  }
+}
diff --git a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java
new file mode 100644
index 0000000..1dd4389
--- /dev/null
+++ b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion.internal;
+
+import java.util.List;
+
+import org.apache.arrow.c.ArrowArray;
+import org.apache.arrow.c.ArrowSchema;
+import org.apache.arrow.c.Data;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.datafusion.ScalarUdf;
+
+/** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */
+public final class JniBridge {
+ /**
+ * Shared allocator for UDF inputs/outputs. Created once at class-load time; never closed.
+ * Outstanding allocations are released by the FFI structs' release callbacks when the native side
+ * drops them after {@code from_ffi}.
+ */
+ private static final RootAllocator ALLOCATOR = new RootAllocator();
+
+ private JniBridge() {}
+
+ /**
+ * Invoke a scalar UDF for one batch. Called from native code; not for application use.
+ *
+ * @param udf the registered {@link ScalarUdf} instance
+ * @param argsArrayAddr address of a populated {@code FFI_ArrowArray} struct holding the input
+ * batch as a struct array (one field per UDF argument)
+ * @param argsSchemaAddr address of the matching {@code FFI_ArrowSchema}
+ * @param resultArrayAddr address of an empty {@code FFI_ArrowArray} the bridge writes into
+ * @param resultSchemaAddr address of an empty {@code FFI_ArrowSchema} the bridge writes into
+ * @param expectedRowCount the row count the result vector must have
+ */
+ public static void invokeScalarUdf(
+ ScalarUdf udf,
+ long argsArrayAddr,
+ long argsSchemaAddr,
+ long resultArrayAddr,
+ long resultSchemaAddr,
+ int expectedRowCount) {
+ ArrowArray argsArr = ArrowArray.wrap(argsArrayAddr);
+ ArrowSchema argsSch = ArrowSchema.wrap(argsSchemaAddr);
+ ArrowArray resultArr = ArrowArray.wrap(resultArrayAddr);
+ ArrowSchema resultSch = ArrowSchema.wrap(resultSchemaAddr);
+
+ try (VectorSchemaRoot root = Data.importVectorSchemaRoot(ALLOCATOR, argsArr, argsSch, null)) {
+ List argVectors = root.getFieldVectors();
+
+ FieldVector result = udf.evaluate(ALLOCATOR, argVectors);
+
+ if (result == null) {
+ throw new IllegalStateException("ScalarUdf.evaluate returned null");
+ }
+ if (result.getValueCount() != expectedRowCount) {
+ try {
+ throw new IllegalStateException(
+ "ScalarUdf.evaluate returned vector with "
+ + result.getValueCount()
+ + " rows; expected "
+ + expectedRowCount);
+ } finally {
+ result.close();
+ }
+ }
+
+ try {
+ Data.exportVector(ALLOCATOR, result, null, resultArr, resultSch);
+ } finally {
+ result.close();
+ }
+ }
+ }
+}
diff --git a/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java b/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java
index cb5c9ef..6b1ed20 100644
--- a/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java
+++ b/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java
@@ -246,8 +246,7 @@ void limitRejectsNegativeArgs() {
@Test
void distinctRemovesDuplicates() {
try (SessionContext ctx = new SessionContext();
- DataFrame source =
- ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)");
+ DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)");
DataFrame deduped = source.distinct()) {
assertEquals(3L, deduped.count());
}
diff --git a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
new file mode 100644
index 0000000..05f263d
--- /dev/null
+++ b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import java.util.List;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.junit.jupiter.api.Test;
+
+class ScalarUdfTest {
+
+  /** Adds 1 to each row of an Int32 column; null rows stay null. */
+  static final class AddOne implements ScalarUdf {
+    @Override
+    public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+      IntVector in = (IntVector) args.get(0);
+      int n = in.getValueCount();
+      IntVector out = new IntVector("add_one_out", allocator);
+      out.allocateNew(n);
+      for (int i = 0; i < n; i++) {
+        if (in.isNull(i)) {
+          out.setNull(i);
+        } else {
+          out.set(i, in.get(i) + 1);
+        }
+      }
+      out.setValueCount(n);
+      return out;
+    }
+  }
+
+  /** Runs {@code add_one} over a three-row literal table and checks every incremented value. */
+  @Test
+  void addOne_overConstantTable_returnsIncrementedValues() throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ArrowType int32 = new ArrowType.Int(32, true);
+      ctx.registerUdf("add_one", new AddOne(), int32, List.of(int32), Volatility.IMMUTABLE);
+
+      String query =
+          "SELECT add_one(x) AS y"
+              + " FROM (VALUES (CAST(1 AS INT)), (CAST(2 AS INT)), (CAST(3 AS INT)))"
+              + " AS t(x)";
+      try (DataFrame df = ctx.sql(query);
+          ArrowReader reader = df.collect(allocator)) {
+        org.junit.jupiter.api.Assertions.assertTrue(reader.loadNextBatch());
+        IntVector y = (IntVector) reader.getVectorSchemaRoot().getVector("y");
+        assertEquals(3, y.getValueCount());
+        int[] expected = {2, 3, 4};
+        for (int row = 0; row < expected.length; row++) {
+          assertEquals(expected[row], y.get(row));
+        }
+      }
+    }
+  }
+
+  /** Concatenates two Utf8 columns row by row; the result is null when either input is null. */
+  static final class Concat implements ScalarUdf {
+    @Override
+    public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+      org.apache.arrow.vector.VarCharVector left =
+          (org.apache.arrow.vector.VarCharVector) args.get(0);
+      org.apache.arrow.vector.VarCharVector right =
+          (org.apache.arrow.vector.VarCharVector) args.get(1);
+      org.apache.arrow.vector.VarCharVector out =
+          new org.apache.arrow.vector.VarCharVector("concat_out", allocator);
+      int n = left.getValueCount();
+      out.allocateNew(n);
+      for (int i = 0; i < n; i++) {
+        if (left.isNull(i) || right.isNull(i)) {
+          out.setNull(i);
+        } else {
+          byte[] l = left.get(i);
+          byte[] r = right.get(i);
+          byte[] both = new byte[l.length + r.length];
+          System.arraycopy(l, 0, both, 0, l.length);
+          System.arraycopy(r, 0, both, l.length, r.length);
+          // setSafe grows the data buffer as needed; allocateNew(n) only sizes it by row count.
+          out.setSafe(i, both);
+        }
+      }
+      out.setValueCount(n);
+      return out;
+    }
+  }
+
+  /** Registers {@code java_concat} and checks the concatenated strings of a two-row table. */
+  @Test
+  void concat_overVarCharColumns_concatenatesValues() throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ctx.registerUdf(
+          "java_concat",
+          new Concat(),
+          new ArrowType.Utf8(),
+          List.of(new ArrowType.Utf8(), new ArrowType.Utf8()),
+          Volatility.IMMUTABLE);
+
+      try (DataFrame df =
+              ctx.sql(
+                  "SELECT java_concat(a, b) AS c FROM (VALUES ('foo','bar'),('hello','!')) AS t(a, b)");
+          ArrowReader r = df.collect(allocator)) {
+        org.junit.jupiter.api.Assertions.assertTrue(r.loadNextBatch());
+        VectorSchemaRoot root = r.getVectorSchemaRoot();
+        org.apache.arrow.vector.VarCharVector c =
+            (org.apache.arrow.vector.VarCharVector) root.getVector("c");
+        assertEquals(2, c.getValueCount());
+        // Arrow Utf8 values are UTF-8 bytes; decode explicitly rather than with the platform
+        // default charset.
+        assertEquals("foobar", new String(c.get(0), java.nio.charset.StandardCharsets.UTF_8));
+        assertEquals("hello!", new String(c.get(1), java.nio.charset.StandardCharsets.UTF_8));
+      }
+    }
+  }
+
+  /** Squares each row of a Float64 column; null rows stay null. */
+  static final class Square implements ScalarUdf {
+    @Override
+    public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+      org.apache.arrow.vector.Float8Vector in =
+          (org.apache.arrow.vector.Float8Vector) args.get(0);
+      org.apache.arrow.vector.Float8Vector out =
+          new org.apache.arrow.vector.Float8Vector("square_out", allocator);
+      int n = in.getValueCount();
+      out.allocateNew(n);
+      for (int i = 0; i < n; i++) {
+        if (in.isNull(i)) {
+          out.setNull(i);
+        } else {
+          double v = in.get(i);
+          out.set(i, v * v);
+        }
+      }
+      out.setValueCount(n);
+      return out;
+    }
+  }
+
+  /** Registers {@code java_square} for Float64 and checks the squares of a two-row table. */
+  @Test
+  void square_overFloat64Column_squaresValues() throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ArrowType float64 =
+          new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE);
+      ctx.registerUdf(
+          "java_square", new Square(), float64, List.of(float64), Volatility.IMMUTABLE);
+
+      try (DataFrame df =
+              ctx.sql("SELECT java_square(x) AS y FROM (VALUES (2.0),(3.5)) AS t(x)");
+          ArrowReader reader = df.collect(allocator)) {
+        org.junit.jupiter.api.Assertions.assertTrue(reader.loadNextBatch());
+        org.apache.arrow.vector.Float8Vector y =
+            (org.apache.arrow.vector.Float8Vector) reader.getVectorSchemaRoot().getVector("y");
+        assertEquals(2, y.getValueCount());
+        assertEquals(4.0, y.get(0), 0.0);
+        assertEquals(12.25, y.get(1), 0.0);
+      }
+    }
+  }
+
+  /**
+   * Runs the same UDF query twice on one context. Re-running exercises that GlobalRefs and JNI
+   * state don't accumulate across invocations within a session.
+   */
+  @Test
+  void addOne_invokedTwiceInOneSession_executesIndependently() throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ArrowType int32 = new ArrowType.Int(32, true);
+      ctx.registerUdf("add_one", new AddOne(), int32, List.of(int32), Volatility.IMMUTABLE);
+
+      int remainingRuns = 2;
+      while (remainingRuns-- > 0) {
+        try (DataFrame df = ctx.sql("SELECT add_one(CAST(5 AS INT)) AS y");
+            ArrowReader reader = df.collect(allocator)) {
+          org.junit.jupiter.api.Assertions.assertTrue(reader.loadNextBatch());
+          IntVector y = (IntVector) reader.getVectorSchemaRoot().getVector("y");
+          assertEquals(1, y.getValueCount());
+          assertEquals(6, y.get(0));
+        }
+      }
+    }
+  }
+
+  /** Misbehaving UDF that returns {@code null} instead of a result vector. */
+  static final class ReturnsNull implements ScalarUdf {
+    @Override
+    public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+      return null;
+    }
+  }
+
+ @Test
+ void udfReturningNull_surfacesIllegalStateException() {
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(
+ "bad_null",
+ new ReturnsNull(),
+ new ArrowType.Int(32, true),
+ List.of(new ArrowType.Int(32, true)),
+ Volatility.IMMUTABLE);
+ RuntimeException ex =
+ org.junit.jupiter.api.Assertions.assertThrows(
+ RuntimeException.class,
+ () -> {
+ try (DataFrame df = ctx.sql("SELECT bad_null(CAST(1 AS INT))");
+ ArrowReader r = df.collect(allocator)) {
+ while (r.loadNextBatch()) {}
+ }
+ });
+ org.junit.jupiter.api.Assertions.assertTrue(
+ ex.getMessage().contains("returned null"),
+ "expected error to mention 'returned null', got: " + ex.getMessage());
+ }
+ }
+
+ static final class WrongRowCount implements ScalarUdf {
+ @Override
+ public FieldVector evaluate(BufferAllocator allocator, List args) {
+ IntVector in = (IntVector) args.get(0);
+ IntVector out = new IntVector("out", allocator);
+ out.allocateNew(in.getValueCount() + 1); // off by one
+ for (int i = 0; i < in.getValueCount() + 1; i++) out.set(i, 0);
+ out.setValueCount(in.getValueCount() + 1);
+ return out;
+ }
+ }
+
+ @Test
+ void udfReturningWrongRowCount_surfacesIllegalStateException() {
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(
+ "bad_rows",
+ new WrongRowCount(),
+ new ArrowType.Int(32, true),
+ List.of(new ArrowType.Int(32, true)),
+ Volatility.IMMUTABLE);
+ RuntimeException ex =
+ org.junit.jupiter.api.Assertions.assertThrows(
+ RuntimeException.class,
+ () -> {
+ try (DataFrame df = ctx.sql("SELECT bad_rows(CAST(1 AS INT))");
+ ArrowReader r = df.collect(allocator)) {
+ while (r.loadNextBatch()) {}
+ }
+ });
+ org.junit.jupiter.api.Assertions.assertTrue(
+ ex.getMessage().contains("expected") && ex.getMessage().contains("rows"),
+ "expected error to mention row mismatch, got: " + ex.getMessage());
+ }
+ }
+
+ static final class WrongType implements ScalarUdf {
+ @Override
+ public FieldVector evaluate(BufferAllocator allocator, List args) {
+ // Declared return type is Int32; return Float64.
+ org.apache.arrow.vector.Float8Vector out =
+ new org.apache.arrow.vector.Float8Vector("out", allocator);
+ out.allocateNew(args.get(0).getValueCount());
+ for (int i = 0; i < args.get(0).getValueCount(); i++) out.set(i, 0.0);
+ out.setValueCount(args.get(0).getValueCount());
+ return out;
+ }
+ }
+
+ @Test
+ void udfReturningWrongType_surfacesTypeMismatch() {
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(
+ "bad_type",
+ new WrongType(),
+ new ArrowType.Int(32, true),
+ List.of(new ArrowType.Int(32, true)),
+ Volatility.IMMUTABLE);
+ RuntimeException ex =
+ org.junit.jupiter.api.Assertions.assertThrows(
+ RuntimeException.class,
+ () -> {
+ try (DataFrame df = ctx.sql("SELECT bad_type(CAST(1 AS INT))");
+ ArrowReader r = df.collect(allocator)) {
+ while (r.loadNextBatch()) {}
+ }
+ });
+ org.junit.jupiter.api.Assertions.assertTrue(
+ ex.getMessage().toLowerCase().contains("type"),
+ "expected error to mention type mismatch, got: " + ex.getMessage());
+ }
+ }
+
+  /** Misbehaving UDF that always throws {@link IllegalArgumentException}. */
+  static final class ThrowsIAE implements ScalarUdf {
+    @Override
+    public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+      throw new IllegalArgumentException("custom boom from UDF");
+    }
+  }
+
+ @Test
+ void udfThrowingException_propagatesClassAndMessage() {
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(
+ "boom",
+ new ThrowsIAE(),
+ new ArrowType.Int(32, true),
+ List.of(new ArrowType.Int(32, true)),
+ Volatility.IMMUTABLE);
+ RuntimeException ex =
+ org.junit.jupiter.api.Assertions.assertThrows(
+ RuntimeException.class,
+ () -> {
+ try (DataFrame df = ctx.sql("SELECT boom(CAST(1 AS INT))");
+ ArrowReader r = df.collect(allocator)) {
+ while (r.loadNextBatch()) {}
+ }
+ });
+ String msg = ex.getMessage();
+ org.junit.jupiter.api.Assertions.assertTrue(
+ msg.contains("IllegalArgumentException"),
+ "expected class name in error, got: " + msg);
+ org.junit.jupiter.api.Assertions.assertTrue(
+ msg.contains("custom boom from UDF"),
+ "expected user message in error, got: " + msg);
+ }
+ }
+
+  /** Registers two different UDFs on one context and calls both from a single query. */
+  @Test
+  void twoUdfsInOneSession_bothCallable() throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ArrowType int32 = new ArrowType.Int(32, true);
+      ctx.registerUdf("add_one", new AddOne(), int32, List.of(int32), Volatility.IMMUTABLE);
+
+      ArrowType float64 =
+          new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE);
+      ctx.registerUdf(
+          "java_square", new Square(), float64, List.of(float64), Volatility.IMMUTABLE);
+
+      String query = "SELECT add_one(CAST(10 AS INT)) AS a, java_square(CAST(3 AS DOUBLE)) AS b";
+      try (DataFrame df = ctx.sql(query);
+          ArrowReader reader = df.collect(allocator)) {
+        org.junit.jupiter.api.Assertions.assertTrue(reader.loadNextBatch());
+        VectorSchemaRoot batch = reader.getVectorSchemaRoot();
+        IntVector a = (IntVector) batch.getVector("a");
+        org.apache.arrow.vector.Float8Vector b =
+            (org.apache.arrow.vector.Float8Vector) batch.getVector("b");
+        assertEquals(11, a.get(0));
+        assertEquals(9.0, b.get(0), 0.0);
+      }
+    }
+  }
+
+  /** Registers the same UDF name in two consecutive, separately closed contexts. */
+  @Test
+  void registerSameNameAfterCloseInNewSession_works() throws Exception {
+    int roundsLeft = 2;
+    while (roundsLeft-- > 0) {
+      try (SessionContext ctx = new SessionContext();
+          BufferAllocator allocator = new RootAllocator()) {
+        ArrowType int32 = new ArrowType.Int(32, true);
+        ctx.registerUdf("add_one", new AddOne(), int32, List.of(int32), Volatility.IMMUTABLE);
+        try (DataFrame df = ctx.sql("SELECT add_one(CAST(7 AS INT))");
+            ArrowReader reader = df.collect(allocator)) {
+          org.junit.jupiter.api.Assertions.assertTrue(reader.loadNextBatch());
+          IntVector firstColumn = (IntVector) reader.getVectorSchemaRoot().getVector(0);
+          assertEquals(8, firstColumn.get(0));
+        }
+      }
+    }
+  }
+
+  /** Applies {@code add_one} to the values 1..100 and checks the row count and the sum. */
+  @Test
+  void udfAppliedToMultiRowQuery_processesAllRows() throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ArrowType int32 = new ArrowType.Int(32, true);
+      ctx.registerUdf("add_one", new AddOne(), int32, List.of(int32), Volatility.IMMUTABLE);
+
+      // Build "(CAST(1 AS INT)), (CAST(2 AS INT)), ..., (CAST(100 AS INT))".
+      java.util.StringJoiner values = new java.util.StringJoiner(", ");
+      for (int i = 1; i <= 100; i++) {
+        values.add("(CAST(" + i + " AS INT))");
+      }
+
+      try (DataFrame df =
+              ctx.sql("SELECT add_one(x) AS y FROM (VALUES " + values + ") AS t(x) ORDER BY y");
+          ArrowReader reader = df.collect(allocator)) {
+        long sum = 0;
+        long rowCount = 0;
+        while (reader.loadNextBatch()) {
+          IntVector y = (IntVector) reader.getVectorSchemaRoot().getVector("y");
+          int batchRows = y.getValueCount();
+          for (int i = 0; i < batchRows; i++) {
+            sum += y.get(i);
+          }
+          rowCount += batchRows;
+        }
+        assertEquals(100, rowCount);
+        // Sum of 2..101 = (2+101)*100/2 = 5150
+        assertEquals(5150L, sum);
+      }
+    }
+  }
+
+  /** Registers and calls {@code add_one} once per {@link Volatility} value. */
+  @Test
+  void volatilityBytesRoundTrip_forAllThreeKinds() throws Exception {
+    for (Volatility v : Volatility.values()) {
+      // Locale.ROOT keeps the generated function name ASCII; the default-locale overload would
+      // turn "IMMUTABLE" into a name with a dotless i under a Turkish locale.
+      String functionName = "add_one_" + v.name().toLowerCase(java.util.Locale.ROOT);
+      try (SessionContext ctx = new SessionContext();
+          BufferAllocator allocator = new RootAllocator()) {
+        ctx.registerUdf(
+            functionName,
+            new AddOne(),
+            new ArrowType.Int(32, true),
+            List.of(new ArrowType.Int(32, true)),
+            v);
+        try (DataFrame df = ctx.sql("SELECT " + functionName + "(CAST(0 AS INT))");
+            ArrowReader r = df.collect(allocator)) {
+          org.junit.jupiter.api.Assertions.assertTrue(r.loadNextBatch());
+          IntVector y = (IntVector) r.getVectorSchemaRoot().getVector(0);
+          assertEquals(1, y.get(0));
+        }
+      }
+    }
+  }
+}
diff --git a/docs/source/user-guide/index.md b/docs/source/user-guide/index.md
index 2f32499..599340c 100644
--- a/docs/source/user-guide/index.md
+++ b/docs/source/user-guide/index.md
@@ -37,6 +37,7 @@ sessioncontext
dataframe
parquet
proto-plans
+scalar-udf
api-reference
```
diff --git a/docs/source/user-guide/scalar-udf.md b/docs/source/user-guide/scalar-udf.md
new file mode 100644
index 0000000..0d748f2
--- /dev/null
+++ b/docs/source/user-guide/scalar-udf.md
@@ -0,0 +1,97 @@
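+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->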
+
+# Scalar UDFs
+
+A scalar UDF is a Java-implemented SQL function that logically maps each input
+row to one output value. It is invoked in vectorised form: each call receives a
+batch of input columns and returns a single output column of the same length.
+
+## Implement
+
+Implement the `ScalarUdf` interface:
+
+```java
+import java.util.List;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.datafusion.ScalarUdf;
+
+public final class AddOne implements ScalarUdf {
+ @Override
+  public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+ IntVector in = (IntVector) args.get(0);
+ IntVector out = new IntVector("add_one", allocator);
+ out.allocateNew(in.getValueCount());
+ for (int i = 0; i < in.getValueCount(); i++) {
+ if (in.isNull(i)) out.setNull(i);
+ else out.set(i, in.get(i) + 1);
+ }
+ out.setValueCount(in.getValueCount());
+ return out;
+ }
+}
+```
+
+Allocate any new vectors — including the result — from the supplied
+`BufferAllocator`. The input vectors are read-only views; do not close them.
+Ownership of the returned vector transfers to the framework on return.
+
+## Register
+
+```java
+try (SessionContext ctx = new SessionContext()) {
+ ctx.registerUdf(
+ "add_one",
+ new AddOne(),
+ new ArrowType.Int(32, true),
+ List.of(new ArrowType.Int(32, true)),
+ Volatility.IMMUTABLE);
+
+ try (DataFrame df = ctx.sql("SELECT add_one(x) FROM t");
+ ArrowReader r = df.collect(allocator)) {
+ // ...
+ }
+}
+```
+
+The signature is exact: a call must match the declared `argTypes` exactly. Use
+`Volatility.IMMUTABLE` for pure functions, `STABLE` for functions that are
+deterministic within a single query, and `VOLATILE` for non-deterministic
+functions.
+
+## Errors
+
+If the UDF throws, the exception class and message surface in the
+`RuntimeException` raised from `collect()`. If the returned vector is `null`,
+has the wrong row count, or the wrong type, the runtime raises a
+`RuntimeException` with a descriptive message.
+
+## Threading
+
+DataFusion may invoke a UDF concurrently from multiple worker threads. If the
+implementation carries mutable state, the implementation must synchronize it.
+
+## Limitations (v1)
+
+- Scalar UDFs only — no aggregates, window functions, or table functions.
+- Exact-signature only — no variadic or polymorphic argument lists.
+- No nullable-argument short-circuiting; null inputs are passed through to the
+ UDF as nulls in the input vector.
diff --git a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
new file mode 100644
index 0000000..e1ef1c5
--- /dev/null
+++ b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datafusion.examples;
+
+import java.util.List;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.datafusion.DataFrame;
+import org.apache.datafusion.ScalarUdf;
+import org.apache.datafusion.SessionContext;
+import org.apache.datafusion.Volatility;
+
+/** Demonstrates registering a Java scalar UDF and invoking it from SQL. */
+public final class AddOneExample {
+
+  /** Adds 1 to each value of an Int32 column; null rows stay null. */
+  public static final class AddOne implements ScalarUdf {
+    @Override
+    public FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args) {
+      IntVector in = (IntVector) args.get(0);
+      // Allocate the result from the allocator supplied by the framework.
+      IntVector out = new IntVector("add_one_out", allocator);
+      int n = in.getValueCount();
+      out.allocateNew(n);
+      for (int i = 0; i < n; i++) {
+        if (in.isNull(i)) {
+          out.setNull(i);
+        } else {
+          out.set(i, in.get(i) + 1);
+        }
+      }
+      out.setValueCount(n);
+      return out;
+    }
+  }
+
+  /** Registers {@code add_one}, runs it over a three-row literal table, and prints each result. */
+  public static void main(String[] args) throws Exception {
+    try (SessionContext ctx = new SessionContext();
+        BufferAllocator allocator = new RootAllocator()) {
+      ctx.registerUdf(
+          "add_one",
+          new AddOne(),
+          new ArrowType.Int(32, true),
+          List.of(new ArrowType.Int(32, true)),
+          Volatility.IMMUTABLE);
+
+      try (DataFrame df =
+              ctx.sql(
+                  "SELECT add_one(x) AS y FROM (VALUES (CAST(1 AS INT)),(CAST(2 AS INT)),(CAST(3 AS INT))) AS t(x)");
+          ArrowReader reader = df.collect(allocator)) {
+        while (reader.loadNextBatch()) {
+          VectorSchemaRoot root = reader.getVectorSchemaRoot();
+          IntVector y = (IntVector) root.getVector("y");
+          for (int i = 0; i < y.getValueCount(); i++) {
+            System.out.println("y = " + y.get(i));
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/native/src/lib.rs b/native/src/lib.rs
index 08a919b..81efd9a 100644
--- a/native/src/lib.rs
+++ b/native/src/lib.rs
@@ -19,6 +19,7 @@ mod csv;
mod errors;
mod proto;
mod schema;
+mod udf;
pub(crate) mod proto_gen {
include!(concat!(env!("OUT_DIR"), "/datafusion_java.rs"));
@@ -35,10 +36,12 @@ use datafusion::dataframe::DataFrame;
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::error::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
+use datafusion::logical_expr::{ScalarUDF, Signature};
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
-use jni::objects::{JByteArray, JClass, JObjectArray, JString};
+use jni::objects::{JByteArray, JClass, JObject, JObjectArray, JString};
use jni::sys::{jboolean, jint, jlong};
use jni::JNIEnv;
+use jni::JavaVM;
use prost::Message;
use tokio::runtime::Runtime;
@@ -47,6 +50,21 @@ use crate::proto_gen::ParquetReadOptionsProto;
use crate::proto_gen::SessionOptions;
use crate::schema::decode_optional_schema;
+static JAVA_VM: OnceLock = OnceLock::new(); // Process-wide JavaVM handle: stored once by JNI_OnLoad and read through jvm().
+
+/// Library load hook: stores the `JavaVM` in `JAVA_VM` so native code can reach the JVM
+/// later, then reports JNI version 1.8 to the caller.
+#[no_mangle]
+pub extern "system" fn JNI_OnLoad(vm: JavaVM, _reserved: *mut std::ffi::c_void) -> jni::sys::jint {
+    // If a VM was stored earlier, that first one is kept and this one is discarded.
+    JAVA_VM.set(vm).ok();
+    jni::sys::JNI_VERSION_1_8
+}
+
+/// Returns the `JavaVM` captured by `JNI_OnLoad`.
+///
+/// # Panics
+/// Panics if `JNI_OnLoad` has not stored a VM yet.
+#[allow(dead_code)]
+pub(crate) fn jvm() -> &'static JavaVM {
+    match JAVA_VM.get() {
+        Some(vm) => vm,
+        None => panic!("JNI_OnLoad has not been called; JavaVM unavailable"),
+    }
+}
+
pub(crate) fn runtime() -> &'static Runtime {
static RT: OnceLock = OnceLock::new();
RT.get_or_init(|| Runtime::new().expect("failed to create Tokio runtime"))
@@ -475,3 +493,61 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_readParquetWith
})
})
}
+
+#[no_mangle]
+pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarUdf<'local>( // JNI entry point: wraps a Java UDF object in a JavaScalarUdf and registers it on the SessionContext behind `handle`.
+ mut env: JNIEnv<'local>,
+ _class: JClass<'local>,
+ handle: jlong, // address of the native SessionContext; 0 is rejected below
+ name: JString<'local>, // function name, stored in JavaScalarUdf::name
+ signature_schema_bytes: JByteArray<'local>, // encoded schema: field 0 = return type, fields 1..N = argument types
+ volatility: jni::sys::jbyte, // decoded by crate::udf::volatility_from_byte
+ udf: JObject<'local>, // the Java UDF instance; pinned with a global ref below
+) {
+ try_unwrap_or_throw(&mut env, (), |env| -> JniResult<()> { // any Err below goes to try_unwrap_or_throw (presumably rethrown as a Java exception — confirm)
+ if handle == 0 {
+ return Err("SessionContext handle is null".into());
+ }
+ // SAFETY: a non-zero handle is assumed to point at a live SessionContext boxed by createSessionContext; the null case was rejected above.
+ let ctx = unsafe { &*(handle as *const SessionContext) };
+ let name: String = env.get_string(&name)?.into();
+
+ // Decode the signature schema (field 0 = return type, fields 1..N = arg types).
+ let signature_schema = crate::schema::decode_optional_schema(env, signature_schema_bytes)?
+ .ok_or("signature schema bytes were null")?;
+ let fields = signature_schema.fields();
+ if fields.is_empty() {
+ return Err("signature schema must have at least a return-type field".into());
+ }
+ let return_type = fields[0].data_type().clone(); // index 0 is safe: emptiness was checked just above
+ let arg_types: Vec = fields
+ .iter()
+ .skip(1) // skip the return-type field
+ .map(|f| f.data_type().clone())
+ .collect();
+
+ let volatility = crate::udf::volatility_from_byte(volatility as u8)?; // jbyte is signed; negatives become values > 127, which volatility_from_byte rejects
+ let signature = Signature::exact(arg_types, volatility);
+
+ // Hold references that survive the JNI call: local refs are only valid during this native call, global refs are not.
+ let udf_global_ref = env.new_global_ref(&udf)?;
+ let bridge_class_local = env.find_class("org/apache/datafusion/internal/JniBridge")?;
+ let bridge_class = env.new_global_ref(&bridge_class_local)?;
+ let invoke_method = env.get_static_method_id( // looked up once here; stored in the UDF and reused by invoke_with_args
+ &bridge_class_local,
+ "invokeScalarUdf",
+ "(Lorg/apache/datafusion/ScalarUdf;JJJJI)V", // (ScalarUdf, long, long, long, long, int) -> void
+ )?;
+
+ let java_udf = crate::udf::JavaScalarUdf {
+ name: name.clone(),
+ signature,
+ return_type,
+ udf_global_ref,
+ bridge_class,
+ invoke_method,
+ };
+ ctx.register_udf(ScalarUDF::new_from_impl(java_udf));
+ Ok(())
+ })
+}
diff --git a/native/src/udf.rs b/native/src/udf.rs
new file mode 100644
index 0000000..cf19b0b
--- /dev/null
+++ b/native/src/udf.rs
@@ -0,0 +1,293 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Java-backed scalar UDF support.
+
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray};
+use datafusion::arrow::datatypes::{DataType, Field, Fields};
+use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema};
+use datafusion::error::DataFusionError;
+use datafusion::logical_expr::{
+ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use jni::objects::{GlobalRef, JStaticMethodID, JThrowable};
+use jni::signature::{Primitive, ReturnType};
+use jni::sys::{jlong, jvalue};
+use jni::JNIEnv;
+
+pub(crate) struct JavaScalarUdf { // DataFusion scalar UDF whose evaluation is delegated to a Java object through JNI.
+ pub(crate) name: String, // returned by ScalarUDFImpl::name; also the only field used for Eq/Hash
+ pub(crate) signature: Signature, // invoke_with_args only accepts a TypeSignature::Exact signature
+ pub(crate) return_type: DataType, // declared return type; every Java result is checked against it
+ /// Global ref to the user's `org.apache.datafusion.ScalarUdf` instance.
+ pub(crate) udf_global_ref: GlobalRef,
+ /// Global ref to the `org.apache.datafusion.internal.JniBridge` class.
+ pub(crate) bridge_class: GlobalRef,
+ /// Method ID for `JniBridge.invokeScalarUdf`.
+ pub(crate) invoke_method: JStaticMethodID,
+}
+
+// SAFETY: JStaticMethodID is a JNI handle that's safe to share because the
+// class it points to is held alive by `bridge_class`. We never mutate
+// `invoke_method` after construction; DataFusion requires `Send + Sync` on
+// `ScalarUDFImpl`.
+unsafe impl Send for JavaScalarUdf {}
+unsafe impl Sync for JavaScalarUdf {}
+
+impl fmt::Debug for JavaScalarUdf {
+    /// Prints only the name and the declared return type; the JNI handles are omitted.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut builder = f.debug_struct("JavaScalarUdf");
+        builder.field("name", &self.name);
+        builder.field("return_type", &self.return_type);
+        builder.finish()
+    }
+}
+
+impl PartialEq for JavaScalarUdf {
+    /// Equality is decided by the registered name alone; no other field is compared.
+    fn eq(&self, other: &Self) -> bool {
+        let (mine, theirs) = (self.name.as_str(), other.name.as_str());
+        mine == theirs
+    }
+}
+
+impl Eq for JavaScalarUdf {}
+
+impl std::hash::Hash for JavaScalarUdf { // Hashes the name only, in line with the name-only PartialEq impl.
+ fn hash(&self, state: &mut H) {
+ self.name.hash(state);
+ }
+}
+
+impl ScalarUDFImpl for JavaScalarUdf { // Bridges DataFusion's scalar-UDF trait to the Java object held in udf_global_ref.
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result { // always the declared type; the argument types are ignored
+ Ok(self.return_type.clone())
+ }
+
+ fn invoke_with_args( // Evaluates one batch: exports the arguments over Arrow FFI, calls JniBridge.invokeScalarUdf, then imports and type-checks the result.
+ &self,
+ args: ScalarFunctionArgs,
+ ) -> datafusion::error::Result {
+ let number_rows = args.number_rows;
+
+ // 1. Materialise scalars to arrays so all columns are length-N.
+ let arrays: Vec = args
+ .args
+ .iter()
+ .map(|cv| cv.clone().into_array(number_rows))
+ .collect::>>()?;
+
+ // 2. Build a single struct array carrying all arg columns. Field names/types come
+ // from the signature's Exact type list (matches what the Java caller declared).
+ let signature_fields: Vec> = match &self.signature.type_signature {
+ TypeSignature::Exact(types) => types
+ .iter()
+ .enumerate()
+ .map(|(i, ty)| Arc::new(Field::new(format!("arg{}", i), ty.clone(), true))) // fields are named arg0, arg1, ... and marked nullable
+ .collect(),
+ _ => {
+ return Err(DataFusionError::Internal(
+ "JavaScalarUdf signature is not Exact; only Signature::exact is supported"
+ .to_string(),
+ ))
+ }
+ };
+
+ let fields = Fields::from(
+ signature_fields
+ .iter()
+ .map(|f| f.as_ref().clone())
+ .collect::>(),
+ );
+ let struct_array = StructArray::try_new_with_length(fields, arrays, None, number_rows) // no null buffer; length given explicitly as number_rows
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+ let args_data = struct_array.into_data();
+ let (args_ffi_array, args_ffi_schema) =
+ to_ffi(&args_data).map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+
+ // 3. Pre-allocate empty FFI structs for the result.
+ let result_ffi_array = FFI_ArrowArray::empty();
+ let result_ffi_schema = FFI_ArrowSchema::empty();
+
+ // 4. Box for stable addresses across the JNI call.
+ let mut args_array_box = Box::new(args_ffi_array);
+ let mut args_schema_box = Box::new(args_ffi_schema);
+ let mut result_array_box = Box::new(result_ffi_array);
+ let mut result_schema_box = Box::new(result_ffi_schema);
+
+ let args_array_addr = args_array_box.as_mut() as *mut _ as jlong; // raw addresses travel to Java as jlong values
+ let args_schema_addr = args_schema_box.as_mut() as *mut _ as jlong;
+ let result_array_addr = result_array_box.as_mut() as *mut _ as jlong;
+ let result_schema_addr = result_schema_box.as_mut() as *mut _ as jlong;
+
+ // 5. Attach JNI to current thread.
+ let mut env = crate::jvm()
+ .attach_current_thread()
+ .map_err(|e| DataFusionError::Execution(format!("JNI attach failed: {}", e)))?;
+
+ // 6. Call JniBridge.invokeScalarUdf(udf, args*, result*, expectedRowCount).
+ //
+ // Build the jvalue argument array for call_static_method_unchecked.
+ // SAFETY: we build the args inline and pass them immediately; the JObject
+ // pointed to by udf_global_ref is alive for the duration of this call.
+ let expected_rows = i32::try_from(number_rows).map_err(|_| { // the Java-side parameter is a 32-bit int, so larger batches are rejected here
+ DataFusionError::Execution(format!(
+ "batch row count {} exceeds i32::MAX; UDFs cannot handle batches larger than 2^31 - 1 rows",
+ number_rows
+ ))
+ })?;
+
+ let udf_jobject = self.udf_global_ref.as_obj();
+ // SAFETY: udf_jobject is derived from a GlobalRef alive for the duration of this
+ // function. The raw pointer is only read by the JNI call below, which happens
+ // before any code that could drop udf_global_ref.
+ let call_args: [jvalue; 6] = [ // order must match the "(Lorg/apache/datafusion/ScalarUdf;JJJJI)V" descriptor used at registration
+ // ScalarUdf instance
+ jvalue {
+ l: udf_jobject.as_raw(),
+ },
+ // argsArrayAddr
+ jvalue { j: args_array_addr },
+ // argsSchemaAddr
+ jvalue {
+ j: args_schema_addr,
+ },
+ // resultArrayAddr
+ jvalue {
+ j: result_array_addr,
+ },
+ // resultSchemaAddr
+ jvalue {
+ j: result_schema_addr,
+ },
+ // expectedRowCount
+ jvalue { i: expected_rows },
+ ];
+
+ let call_result = unsafe {
+ env.call_static_method_unchecked(
+ &self.bridge_class,
+ self.invoke_method,
+ ReturnType::Primitive(Primitive::Void),
+ &call_args,
+ )
+ };
+
+ // 7. If Java threw, translate to DataFusionError. Always check exception_check first.
+ if env.exception_check().unwrap_or(false) { // a failed check is treated as "no exception pending"
+ let throwable = env.exception_occurred().map_err(|e| {
+ DataFusionError::Execution(format!("exception_occurred failed: {}", e))
+ })?;
+ env.exception_clear().ok(); // cleared before jthrowable_to_string makes further JNI calls
+ let message = jthrowable_to_string(&mut env, &throwable, &self.name);
+ return Err(DataFusionError::Execution(message));
+ }
+ call_result.map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?; // JNI-level failure with no pending Java exception
+
+ // 8. Import result. from_ffi consumes the FFI_ArrowArray.
+ let result_array = *result_array_box; // move the structs out of their boxes
+ let result_schema = *result_schema_box;
+ // SAFETY: Java's `Data.exportVector` populated `result_array_box` and
+ // `result_schema_box` in place via the C Data Interface, and the
+ // exception check above guarantees the call succeeded without
+ // throwing — so the FFI structs are fully initialized.
+ let result_data = unsafe { from_ffi(result_array, &result_schema) }
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+
+ // 9. Validate type. NOTE(review): only the data type is checked here; the result length is not compared with number_rows on the Rust side — presumably JniBridge enforces expectedRowCount; confirm.
+ if result_data.data_type() != &self.return_type {
+ return Err(DataFusionError::Execution(format!(
+ "Java UDF '{}' returned vector of type {:?}; declared return type was {:?}",
+ self.name,
+ result_data.data_type(),
+ self.return_type
+ )));
+ }
+
+ let array: ArrayRef = make_array(result_data);
+ Ok(ColumnarValue::Array(array))
+ }
+}
+
+pub(crate) fn volatility_from_byte(byte: u8) -> datafusion::error::Result { // Maps the byte received over JNI to a DataFusion Volatility: 0 = Immutable, 1 = Stable, 2 = Volatile.
+ match byte {
+ 0 => Ok(Volatility::Immutable),
+ 1 => Ok(Volatility::Stable),
+ 2 => Ok(Volatility::Volatile),
+ other => Err(DataFusionError::Execution(format!( // any other value is an error rather than a silent default
+ "unknown volatility byte: {}",
+ other
+ ))),
+ }
+}
+
+/// Best-effort: extract class name and getMessage() from a Java throwable.
+/// Anything that goes wrong collapses to a generic message so we don't
+/// double-throw inside an error path.
+fn jthrowable_to_string(env: &mut JNIEnv, throwable: &JThrowable, udf_name: &str) -> String { // never fails: each lookup has its own fallback string
+ let class_name_result = (|| -> jni::errors::Result { // throwable.getClass().getName()
+ let class = env.call_method(throwable, "getClass", "()Ljava/lang/Class;", &[])?;
+ let class_obj = class.l()?;
+ let name = env.call_method(&class_obj, "getName", "()Ljava/lang/String;", &[])?;
+ let name_obj = name.l()?;
+ let name_str: String = env.get_string(&name_obj.into())?.into();
+ Ok(name_str)
+ })();
+ let class_name = match class_name_result {
+ Ok(s) => s,
+ Err(_) => {
+ // A reflective call itself threw — clear that secondary exception so the
+ // thread is in a clean state when we return to the JVM.
+ env.exception_clear().ok();
+ "".to_string()
+ }
+ };
+
+ let message_result = (|| -> jni::errors::Result { // throwable.getMessage(), which may be null
+ let msg = env.call_method(throwable, "getMessage", "()Ljava/lang/String;", &[])?;
+ let msg_obj = msg.l()?;
+ if msg_obj.is_null() {
+ return Ok("".to_string());
+ }
+ let s: String = env.get_string(&msg_obj.into())?.into();
+ Ok(s)
+ })();
+ let message = match message_result {
+ Ok(s) => s,
+ Err(_) => {
+ env.exception_clear().ok(); // same clean-up as for the class-name lookup
+ "".to_string()
+ }
+ };
+
+ format!("Java UDF '{}' threw {}: {}", udf_name, class_name, message)
+}