diff --git a/core/src/main/java/org/apache/datafusion/ScalarUdf.java b/core/src/main/java/org/apache/datafusion/ScalarUdf.java new file mode 100644 index 0000000..c546ee1 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/ScalarUdf.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; + +/** + * A Java-implemented scalar SQL function. Register an instance with {@link + * SessionContext#registerUdf} to make it callable from SQL or DataFrame plans. + * + *
<p>
Implementations may be invoked concurrently by DataFusion on multiple worker threads. If the + * implementation carries mutable state, the implementation must synchronize it. + */ +@FunctionalInterface +public interface ScalarUdf { + /** + * Compute the function result for one input batch. + * + * @param allocator the {@link BufferAllocator} that MUST be used for any new {@link FieldVector} + * allocation, including the result. Buffers allocated from other allocators will not survive + * the JNI handoff. + * @param args one {@link FieldVector} per declared argument, all of the same length. These are + * read-only views; the implementation must NOT close them. + * @return a {@link FieldVector} of the declared return type and the same length as the inputs. + * Ownership transfers to the framework on return; the implementation must NOT close the + * returned vector. + */ + FieldVector evaluate(BufferAllocator allocator, List args); +} diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index 3cea058..6cba42c 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -23,6 +23,8 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.List; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -30,6 +32,9 @@ import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; /** @@ -204,6 +209,41 @@ public DataFrame readParquet(String path, ParquetReadOptions options) { 
return new DataFrame(dfHandle); } + /** + * Register a Java-implemented scalar UDF. After registration, the function can be invoked by SQL + * via its {@code name} or referenced in DataFusion plans deserialised with {@link #fromProto}. + * + *
<p>
Argument and return types are declared at registration time. The UDF is registered with an + * exact signature: the runtime will reject calls whose argument types do not match {@code + * argTypes} exactly. + * + * @throws RuntimeException if registration fails (e.g., name already registered with an + * incompatible signature, schema serialisation failure). + */ + public void registerUdf( + String name, + ScalarUdf udf, + ArrowType returnType, + List argTypes, + Volatility volatility) { + if (nativeHandle == 0) { + throw new IllegalStateException("SessionContext is closed"); + } + java.util.Objects.requireNonNull(name, "name"); + java.util.Objects.requireNonNull(udf, "udf"); + java.util.Objects.requireNonNull(returnType, "returnType"); + java.util.Objects.requireNonNull(argTypes, "argTypes"); + java.util.Objects.requireNonNull(volatility, "volatility"); + List fields = new ArrayList<>(argTypes.size() + 1); + fields.add(new Field("return", FieldType.nullable(returnType), null)); + for (int i = 0; i < argTypes.size(); i++) { + fields.add(new Field("arg" + i, FieldType.nullable(argTypes.get(i)), null)); + } + Schema signatureSchema = new Schema(fields); + byte[] signatureBytes = serializeSchemaIpc(signatureSchema); + registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), udf); + } + private static byte[] serializeSchemaIpc(Schema schema) { ByteArrayOutputStream baos = new ByteArrayOutputStream(); try (BufferAllocator allocator = new RootAllocator(); @@ -248,4 +288,7 @@ private static native long readCsvWithOptions( long handle, String path, byte[] optionsBytes, byte[] schemaIpcBytes); private static native void closeSessionContext(long handle); + + private static native void registerScalarUdf( + long handle, String name, byte[] signatureSchemaBytes, byte volatility, ScalarUdf udf); } diff --git a/core/src/main/java/org/apache/datafusion/Volatility.java b/core/src/main/java/org/apache/datafusion/Volatility.java new file mode 100644 index 
0000000..eea2fa3 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/Volatility.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +/** + * Volatility classification for a UDF. Mirrors DataFusion's {@code Volatility} enum. + * + *

+ */ +public enum Volatility { + IMMUTABLE((byte) 0), + STABLE((byte) 1), + VOLATILE((byte) 2); + + private final byte code; + + Volatility(byte code) { + this.code = code; + } + + /** Stable byte code for FFI. */ + public byte code() { + return code; + } +} diff --git a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java new file mode 100644 index 0000000..1dd4389 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.internal; + +import java.util.List; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.datafusion.ScalarUdf; + +/** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */ +public final class JniBridge { + /** + * Shared allocator for UDF inputs/outputs. Created once at class-load time; never closed. 
+ * Outstanding allocations are released by the FFI structs' release callbacks when the native side + * drops them after {@code from_ffi}. + */ + private static final RootAllocator ALLOCATOR = new RootAllocator(); + + private JniBridge() {} + + /** + * Invoke a scalar UDF for one batch. Called from native code; not for application use. + * + * @param udf the registered {@link ScalarUdf} instance + * @param argsArrayAddr address of a populated {@code FFI_ArrowArray} struct holding the input + * batch as a struct array (one field per UDF argument) + * @param argsSchemaAddr address of the matching {@code FFI_ArrowSchema} + * @param resultArrayAddr address of an empty {@code FFI_ArrowArray} the bridge writes into + * @param resultSchemaAddr address of an empty {@code FFI_ArrowSchema} the bridge writes into + * @param expectedRowCount the row count the result vector must have + */ + public static void invokeScalarUdf( + ScalarUdf udf, + long argsArrayAddr, + long argsSchemaAddr, + long resultArrayAddr, + long resultSchemaAddr, + int expectedRowCount) { + ArrowArray argsArr = ArrowArray.wrap(argsArrayAddr); + ArrowSchema argsSch = ArrowSchema.wrap(argsSchemaAddr); + ArrowArray resultArr = ArrowArray.wrap(resultArrayAddr); + ArrowSchema resultSch = ArrowSchema.wrap(resultSchemaAddr); + + try (VectorSchemaRoot root = Data.importVectorSchemaRoot(ALLOCATOR, argsArr, argsSch, null)) { + List argVectors = root.getFieldVectors(); + + FieldVector result = udf.evaluate(ALLOCATOR, argVectors); + + if (result == null) { + throw new IllegalStateException("ScalarUdf.evaluate returned null"); + } + if (result.getValueCount() != expectedRowCount) { + try { + throw new IllegalStateException( + "ScalarUdf.evaluate returned vector with " + + result.getValueCount() + + " rows; expected " + + expectedRowCount); + } finally { + result.close(); + } + } + + try { + Data.exportVector(ALLOCATOR, result, null, resultArr, resultSch); + } finally { + result.close(); + } + } + } +} diff --git 
a/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java b/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java index cb5c9ef..6b1ed20 100644 --- a/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java +++ b/core/src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java @@ -246,8 +246,7 @@ void limitRejectsNegativeArgs() { @Test void distinctRemovesDuplicates() { try (SessionContext ctx = new SessionContext(); - DataFrame source = - ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)"); + DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)"); DataFrame deduped = source.distinct()) { assertEquals(3L, deduped.count()); } diff --git a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java new file mode 100644 index 0000000..05f263d --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java @@ -0,0 +1,464 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.jupiter.api.Test; + +class ScalarUdfTest { + + /** Adds 1 to each row of an Int32 column. */ + static final class AddOne implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + IntVector in = (IntVector) args.get(0); + IntVector out = new IntVector("add_one_out", allocator); + int n = in.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.set(i, in.get(i) + 1); + } + } + out.setValueCount(n); + return out; + } + } + + @Test + void addOne_overConstantTable_returnsIncrementedValues() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + + try (DataFrame df = + ctx.sql( + "SELECT add_one(x) AS y" + + " FROM (VALUES (CAST(1 AS INT)), (CAST(2 AS INT)), (CAST(3 AS INT)))" + + " AS t(x)"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + IntVector y = (IntVector) root.getVector("y"); + assertEquals(3, y.getValueCount()); + assertEquals(2, y.get(0)); + assertEquals(3, y.get(1)); + assertEquals(4, y.get(2)); + } + } + } + + /** Concatenates two Utf8 columns. 
*/ + static final class Concat implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + org.apache.arrow.vector.VarCharVector left = (org.apache.arrow.vector.VarCharVector) args.get(0); + org.apache.arrow.vector.VarCharVector right = (org.apache.arrow.vector.VarCharVector) args.get(1); + org.apache.arrow.vector.VarCharVector out = + new org.apache.arrow.vector.VarCharVector("concat_out", allocator); + int n = left.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (left.isNull(i) || right.isNull(i)) { + out.setNull(i); + } else { + byte[] l = left.get(i); + byte[] r = right.get(i); + byte[] both = new byte[l.length + r.length]; + System.arraycopy(l, 0, both, 0, l.length); + System.arraycopy(r, 0, both, l.length, r.length); + out.setSafe(i, both); + } + } + out.setValueCount(n); + return out; + } + } + + @Test + void concat_overVarCharColumns_concatenatesValues() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "java_concat", + new Concat(), + new ArrowType.Utf8(), + List.of(new ArrowType.Utf8(), new ArrowType.Utf8()), + Volatility.IMMUTABLE); + + try (DataFrame df = + ctx.sql( + "SELECT java_concat(a, b) AS c FROM (VALUES ('foo','bar'),('hello','!')) AS t(a, b)"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + org.apache.arrow.vector.VarCharVector c = + (org.apache.arrow.vector.VarCharVector) root.getVector("c"); + assertEquals(2, c.getValueCount()); + assertEquals("foobar", new String(c.get(0))); + assertEquals("hello!", new String(c.get(1))); + } + } + } + + /** Squares a Float64 column. 
*/ + static final class Square implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + org.apache.arrow.vector.Float8Vector in = (org.apache.arrow.vector.Float8Vector) args.get(0); + org.apache.arrow.vector.Float8Vector out = + new org.apache.arrow.vector.Float8Vector("square_out", allocator); + int n = in.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + double v = in.get(i); + out.set(i, v * v); + } + } + out.setValueCount(n); + return out; + } + } + + @Test + void square_overFloat64Column_squaresValues() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "java_square", + new Square(), + new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE), + List.of( + new ArrowType.FloatingPoint( + org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE)), + Volatility.IMMUTABLE); + + try (DataFrame df = + ctx.sql("SELECT java_square(x) AS y FROM (VALUES (2.0),(3.5)) AS t(x)"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + org.apache.arrow.vector.Float8Vector y = + (org.apache.arrow.vector.Float8Vector) root.getVector("y"); + assertEquals(2, y.getValueCount()); + assertEquals(4.0, y.get(0), 0.0); + assertEquals(12.25, y.get(1), 0.0); + } + } + } + + @Test + void addOne_invokedTwiceInOneSession_executesIndependently() throws Exception { + // Re-running the same UDF query twice exercises that GlobalRefs and JNI state + // don't accumulate across invocations within a session. 
+ try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + + for (int run = 0; run < 2; run++) { + try (DataFrame df = ctx.sql("SELECT add_one(CAST(5 AS INT)) AS y"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + IntVector y = (IntVector) r.getVectorSchemaRoot().getVector("y"); + assertEquals(1, y.getValueCount()); + assertEquals(6, y.get(0)); + } + } + } + } + + static final class ReturnsNull implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + return null; + } + } + + @Test + void udfReturningNull_surfacesIllegalStateException() { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "bad_null", + new ReturnsNull(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT bad_null(CAST(1 AS INT))"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + org.junit.jupiter.api.Assertions.assertTrue( + ex.getMessage().contains("returned null"), + "expected error to mention 'returned null', got: " + ex.getMessage()); + } + } + + static final class WrongRowCount implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + IntVector in = (IntVector) args.get(0); + IntVector out = new IntVector("out", allocator); + out.allocateNew(in.getValueCount() + 1); // off by one + for (int i = 0; i < in.getValueCount() + 1; i++) out.set(i, 0); + out.setValueCount(in.getValueCount() + 1); + return out; + } + } + + @Test + void 
udfReturningWrongRowCount_surfacesIllegalStateException() { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "bad_rows", + new WrongRowCount(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT bad_rows(CAST(1 AS INT))"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + org.junit.jupiter.api.Assertions.assertTrue( + ex.getMessage().contains("expected") && ex.getMessage().contains("rows"), + "expected error to mention row mismatch, got: " + ex.getMessage()); + } + } + + static final class WrongType implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + // Declared return type is Int32; return Float64. + org.apache.arrow.vector.Float8Vector out = + new org.apache.arrow.vector.Float8Vector("out", allocator); + out.allocateNew(args.get(0).getValueCount()); + for (int i = 0; i < args.get(0).getValueCount(); i++) out.set(i, 0.0); + out.setValueCount(args.get(0).getValueCount()); + return out; + } + } + + @Test + void udfReturningWrongType_surfacesTypeMismatch() { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "bad_type", + new WrongType(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT bad_type(CAST(1 AS INT))"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + org.junit.jupiter.api.Assertions.assertTrue( + ex.getMessage().toLowerCase().contains("type"), + "expected error to mention type mismatch, got: " + 
ex.getMessage()); + } + } + + static final class ThrowsIAE implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + throw new IllegalArgumentException("custom boom from UDF"); + } + } + + @Test + void udfThrowingException_propagatesClassAndMessage() { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "boom", + new ThrowsIAE(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + RuntimeException ex = + org.junit.jupiter.api.Assertions.assertThrows( + RuntimeException.class, + () -> { + try (DataFrame df = ctx.sql("SELECT boom(CAST(1 AS INT))"); + ArrowReader r = df.collect(allocator)) { + while (r.loadNextBatch()) {} + } + }); + String msg = ex.getMessage(); + org.junit.jupiter.api.Assertions.assertTrue( + msg.contains("IllegalArgumentException"), + "expected class name in error, got: " + msg); + org.junit.jupiter.api.Assertions.assertTrue( + msg.contains("custom boom from UDF"), + "expected user message in error, got: " + msg); + } + } + + @Test + void twoUdfsInOneSession_bothCallable() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + ctx.registerUdf( + "java_square", + new Square(), + new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE), + List.of( + new ArrowType.FloatingPoint( + org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE)), + Volatility.IMMUTABLE); + + try (DataFrame df = + ctx.sql( + "SELECT add_one(CAST(10 AS INT)) AS a, java_square(CAST(3 AS DOUBLE)) AS b"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + IntVector a = (IntVector) 
root.getVector("a"); + org.apache.arrow.vector.Float8Vector b = + (org.apache.arrow.vector.Float8Vector) root.getVector("b"); + assertEquals(11, a.get(0)); + assertEquals(9.0, b.get(0), 0.0); + } + } + } + + @Test + void registerSameNameAfterCloseInNewSession_works() throws Exception { + for (int round = 0; round < 2; round++) { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + try (DataFrame df = ctx.sql("SELECT add_one(CAST(7 AS INT))"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + IntVector v = (IntVector) r.getVectorSchemaRoot().getVector(0); + assertEquals(8, v.get(0)); + } + } + } + } + + @Test + void udfAppliedToMultiRowQuery_processesAllRows() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + String values = + java.util.stream.IntStream.rangeClosed(1, 100) + .mapToObj(i -> "(CAST(" + i + " AS INT))") + .collect(java.util.stream.Collectors.joining(", ")); + try (DataFrame df = + ctx.sql("SELECT add_one(x) AS y FROM (VALUES " + values + ") AS t(x) ORDER BY y"); + ArrowReader r = df.collect(allocator)) { + long total = 0; + long rows = 0; + while (r.loadNextBatch()) { + IntVector y = (IntVector) r.getVectorSchemaRoot().getVector("y"); + for (int i = 0; i < y.getValueCount(); i++) { + total += y.get(i); + rows++; + } + } + assertEquals(100, rows); + // Sum of 2..101 = (2+101)*100/2 = 5150 + assertEquals(5150L, total); + } + } + } + + @Test + void volatilityBytesRoundTrip_forAllThreeKinds() throws Exception { + for (Volatility v : Volatility.values()) { + try (SessionContext ctx = new SessionContext(); + 
BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one_" + v.name().toLowerCase(), + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + v); + try (DataFrame df = + ctx.sql("SELECT add_one_" + v.name().toLowerCase() + "(CAST(0 AS INT))"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + IntVector y = (IntVector) r.getVectorSchemaRoot().getVector(0); + assertEquals(1, y.get(0)); + } + } + } + } +} diff --git a/docs/source/user-guide/index.md b/docs/source/user-guide/index.md index 2f32499..599340c 100644 --- a/docs/source/user-guide/index.md +++ b/docs/source/user-guide/index.md @@ -37,6 +37,7 @@ sessioncontext dataframe parquet proto-plans +scalar-udf api-reference ``` diff --git a/docs/source/user-guide/scalar-udf.md b/docs/source/user-guide/scalar-udf.md new file mode 100644 index 0000000..0d748f2 --- /dev/null +++ b/docs/source/user-guide/scalar-udf.md @@ -0,0 +1,97 @@ + + +# Scalar UDFs + +A scalar UDF is a Java-implemented SQL function that operates on one row at a +time, expressed in vectorised form: each invocation receives a batch of input +columns and returns a single output column of the same length. 
+ +## Implement + +Implement the `ScalarUdf` interface: + +```java +import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.datafusion.ScalarUdf; + +public final class AddOne implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + IntVector in = (IntVector) args.get(0); + IntVector out = new IntVector("add_one", allocator); + out.allocateNew(in.getValueCount()); + for (int i = 0; i < in.getValueCount(); i++) { + if (in.isNull(i)) out.setNull(i); + else out.set(i, in.get(i) + 1); + } + out.setValueCount(in.getValueCount()); + return out; + } +} +``` + +Allocate any new vectors — including the result — from the supplied +`BufferAllocator`. The input vectors are read-only views; do not close them. +Ownership of the returned vector transfers to the framework on return. + +## Register + +```java +try (SessionContext ctx = new SessionContext()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + + try (DataFrame df = ctx.sql("SELECT add_one(x) FROM t"); + ArrowReader r = df.collect(allocator)) { + // ... + } +} +``` + +The signature is exact: a call must match the declared `argTypes` exactly. Use +`Volatility.IMMUTABLE` for pure functions, `STABLE` for functions that are +deterministic within a single query, and `VOLATILE` for non-deterministic +functions. + +## Errors + +If the UDF throws, the exception class and message surface in the +`RuntimeException` raised from `collect()`. If the returned vector is `null`, +has the wrong row count, or the wrong type, the runtime raises a +`RuntimeException` with a descriptive message. + +## Threading + +DataFusion may invoke a UDF concurrently from multiple worker threads. If the +implementation carries mutable state, the implementation must synchronize it. 
+ +## Limitations (v1) + +- Scalar UDFs only — no aggregates, window functions, or table functions. +- Exact-signature only — no variadic or polymorphic argument lists. +- No nullable-argument short-circuiting; null inputs are passed through to the + UDF as nulls in the input vector. diff --git a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java new file mode 100644 index 0000000..e1ef1c5 --- /dev/null +++ b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion.examples; + +import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.datafusion.DataFrame; +import org.apache.datafusion.ScalarUdf; +import org.apache.datafusion.SessionContext; +import org.apache.datafusion.Volatility; + +/** Demonstrates registering a Java scalar UDF and invoking it from SQL. */ +public final class AddOneExample { + + /** Adds 1 to each value of an Int32 column. */ + public static final class AddOne implements ScalarUdf { + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + IntVector in = (IntVector) args.get(0); + IntVector out = new IntVector("add_one_out", allocator); + int n = in.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.set(i, in.get(i) + 1); + } + } + out.setValueCount(n); + return out; + } + } + + public static void main(String[] args) throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf( + "add_one", + new AddOne(), + new ArrowType.Int(32, true), + List.of(new ArrowType.Int(32, true)), + Volatility.IMMUTABLE); + + try (DataFrame df = + ctx.sql( + "SELECT add_one(x) AS y FROM (VALUES (CAST(1 AS INT)),(CAST(2 AS INT)),(CAST(3 AS INT))) AS t(x)"); + ArrowReader reader = df.collect(allocator)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + IntVector y = (IntVector) root.getVector("y"); + for (int i = 0; i < y.getValueCount(); i++) { + System.out.println("y = " + y.get(i)); + } + } + } + } + } +} diff --git a/native/src/lib.rs 
b/native/src/lib.rs index 08a919b..81efd9a 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -19,6 +19,7 @@ mod csv; mod errors; mod proto; mod schema; +mod udf; pub(crate) mod proto_gen { include!(concat!(env!("OUT_DIR"), "/datafusion_java.rs")); @@ -35,10 +36,12 @@ use datafusion::dataframe::DataFrame; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::DataFusionError; use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::logical_expr::{ScalarUDF, Signature}; use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; -use jni::objects::{JByteArray, JClass, JObjectArray, JString}; +use jni::objects::{JByteArray, JClass, JObject, JObjectArray, JString}; use jni::sys::{jboolean, jint, jlong}; use jni::JNIEnv; +use jni::JavaVM; use prost::Message; use tokio::runtime::Runtime; @@ -47,6 +50,21 @@ use crate::proto_gen::ParquetReadOptionsProto; use crate::proto_gen::SessionOptions; use crate::schema::decode_optional_schema; +static JAVA_VM: OnceLock = OnceLock::new(); + +#[no_mangle] +pub extern "system" fn JNI_OnLoad(vm: JavaVM, _reserved: *mut std::ffi::c_void) -> jni::sys::jint { + let _ = JAVA_VM.set(vm); + jni::sys::JNI_VERSION_1_8 +} + +#[allow(dead_code)] +pub(crate) fn jvm() -> &'static JavaVM { + JAVA_VM + .get() + .expect("JNI_OnLoad has not been called; JavaVM unavailable") +} + pub(crate) fn runtime() -> &'static Runtime { static RT: OnceLock = OnceLock::new(); RT.get_or_init(|| Runtime::new().expect("failed to create Tokio runtime")) @@ -475,3 +493,61 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_readParquetWith }) }) } + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarUdf<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + name: JString<'local>, + signature_schema_bytes: JByteArray<'local>, + volatility: jni::sys::jbyte, + udf: JObject<'local>, +) { + try_unwrap_or_throw(&mut env, (), |env| 
-> JniResult<()> { + if handle == 0 { + return Err("SessionContext handle is null".into()); + } + // SAFETY: handle is a valid Box allocated by createSessionContext. + let ctx = unsafe { &*(handle as *const SessionContext) }; + let name: String = env.get_string(&name)?.into(); + + // Decode the signature schema (field 0 = return type, fields 1..N = arg types). + let signature_schema = crate::schema::decode_optional_schema(env, signature_schema_bytes)? + .ok_or("signature schema bytes were null")?; + let fields = signature_schema.fields(); + if fields.is_empty() { + return Err("signature schema must have at least a return-type field".into()); + } + let return_type = fields[0].data_type().clone(); + let arg_types: Vec = fields + .iter() + .skip(1) + .map(|f| f.data_type().clone()) + .collect(); + + let volatility = crate::udf::volatility_from_byte(volatility as u8)?; + let signature = Signature::exact(arg_types, volatility); + + // Hold references that survive the JNI call. + let udf_global_ref = env.new_global_ref(&udf)?; + let bridge_class_local = env.find_class("org/apache/datafusion/internal/JniBridge")?; + let bridge_class = env.new_global_ref(&bridge_class_local)?; + let invoke_method = env.get_static_method_id( + &bridge_class_local, + "invokeScalarUdf", + "(Lorg/apache/datafusion/ScalarUdf;JJJJI)V", + )?; + + let java_udf = crate::udf::JavaScalarUdf { + name: name.clone(), + signature, + return_type, + udf_global_ref, + bridge_class, + invoke_method, + }; + ctx.register_udf(ScalarUDF::new_from_impl(java_udf)); + Ok(()) + }) +} diff --git a/native/src/udf.rs b/native/src/udf.rs new file mode 100644 index 0000000..cf19b0b --- /dev/null +++ b/native/src/udf.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Java-backed scalar UDF support. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray}; +use datafusion::arrow::datatypes::{DataType, Field, Fields}; +use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use jni::objects::{GlobalRef, JStaticMethodID, JThrowable}; +use jni::signature::{Primitive, ReturnType}; +use jni::sys::{jlong, jvalue}; +use jni::JNIEnv; + +pub(crate) struct JavaScalarUdf { + pub(crate) name: String, + pub(crate) signature: Signature, + pub(crate) return_type: DataType, + /// Global ref to the user's `org.apache.datafusion.ScalarUdf` instance. + pub(crate) udf_global_ref: GlobalRef, + /// Global ref to the `org.apache.datafusion.internal.JniBridge` class. + pub(crate) bridge_class: GlobalRef, + /// Method ID for `JniBridge.invokeScalarUdf`. + pub(crate) invoke_method: JStaticMethodID, +} + +// SAFETY: JStaticMethodID is a JNI handle that's safe to share because the +// class it points to is held alive by `bridge_class`. We never mutate +// `invoke_method` after construction; DataFusion requires `Send + Sync` on +// `ScalarUDFImpl`. 
+unsafe impl Send for JavaScalarUdf {} +unsafe impl Sync for JavaScalarUdf {} + +impl fmt::Debug for JavaScalarUdf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JavaScalarUdf") + .field("name", &self.name) + .field("return_type", &self.return_type) + .finish() + } +} + +impl PartialEq for JavaScalarUdf { + fn eq(&self, other: &Self) -> bool { + // Two Java UDFs are equal iff they wrap the same registered name. + self.name == other.name + } +} + +impl Eq for JavaScalarUdf {} + +impl std::hash::Hash for JavaScalarUdf { + fn hash(&self, state: &mut H) { + self.name.hash(state); + } +} + +impl ScalarUDFImpl for JavaScalarUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result { + Ok(self.return_type.clone()) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion::error::Result { + let number_rows = args.number_rows; + + // 1. Materialise scalars to arrays so all columns are length-N. + let arrays: Vec = args + .args + .iter() + .map(|cv| cv.clone().into_array(number_rows)) + .collect::>>()?; + + // 2. Build a single struct array carrying all arg columns. Field names/types come + // from the signature's Exact type list (matches what the Java caller declared). 
+ let signature_fields: Vec> = match &self.signature.type_signature { + TypeSignature::Exact(types) => types + .iter() + .enumerate() + .map(|(i, ty)| Arc::new(Field::new(format!("arg{}", i), ty.clone(), true))) + .collect(), + _ => { + return Err(DataFusionError::Internal( + "JavaScalarUdf signature is not Exact; only Signature::exact is supported" + .to_string(), + )) + } + }; + + let fields = Fields::from( + signature_fields + .iter() + .map(|f| f.as_ref().clone()) + .collect::>(), + ); + let struct_array = StructArray::try_new_with_length(fields, arrays, None, number_rows) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + let args_data = struct_array.into_data(); + let (args_ffi_array, args_ffi_schema) = + to_ffi(&args_data).map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + + // 3. Pre-allocate empty FFI structs for the result. + let result_ffi_array = FFI_ArrowArray::empty(); + let result_ffi_schema = FFI_ArrowSchema::empty(); + + // 4. Box for stable addresses across the JNI call. + let mut args_array_box = Box::new(args_ffi_array); + let mut args_schema_box = Box::new(args_ffi_schema); + let mut result_array_box = Box::new(result_ffi_array); + let mut result_schema_box = Box::new(result_ffi_schema); + + let args_array_addr = args_array_box.as_mut() as *mut _ as jlong; + let args_schema_addr = args_schema_box.as_mut() as *mut _ as jlong; + let result_array_addr = result_array_box.as_mut() as *mut _ as jlong; + let result_schema_addr = result_schema_box.as_mut() as *mut _ as jlong; + + // 5. Attach JNI to current thread. + let mut env = crate::jvm() + .attach_current_thread() + .map_err(|e| DataFusionError::Execution(format!("JNI attach failed: {}", e)))?; + + // 6. Call JniBridge.invokeScalarUdf(udf, args*, result*, expectedRowCount). + // + // Build the jvalue argument array for call_static_method_unchecked. 
+ // SAFETY: we build the args inline and pass them immediately; the JObject + // pointed to by udf_global_ref is alive for the duration of this call. + let expected_rows = i32::try_from(number_rows).map_err(|_| { + DataFusionError::Execution(format!( + "batch row count {} exceeds i32::MAX; UDFs cannot handle batches larger than 2^31 - 1 rows", + number_rows + )) + })?; + + let udf_jobject = self.udf_global_ref.as_obj(); + // SAFETY: udf_jobject is derived from a GlobalRef alive for the duration of this + // function. The raw pointer is only read by the JNI call below, which happens + // before any code that could drop udf_global_ref. + let call_args: [jvalue; 6] = [ + // ScalarUdf instance + jvalue { + l: udf_jobject.as_raw(), + }, + // argsArrayAddr + jvalue { j: args_array_addr }, + // argsSchemaAddr + jvalue { + j: args_schema_addr, + }, + // resultArrayAddr + jvalue { + j: result_array_addr, + }, + // resultSchemaAddr + jvalue { + j: result_schema_addr, + }, + // expectedRowCount + jvalue { i: expected_rows }, + ]; + + let call_result = unsafe { + env.call_static_method_unchecked( + &self.bridge_class, + self.invoke_method, + ReturnType::Primitive(Primitive::Void), + &call_args, + ) + }; + + // 7. If Java threw, translate to DataFusionError. Always check exception_check first. + if env.exception_check().unwrap_or(false) { + let throwable = env.exception_occurred().map_err(|e| { + DataFusionError::Execution(format!("exception_occurred failed: {}", e)) + })?; + env.exception_clear().ok(); + let message = jthrowable_to_string(&mut env, &throwable, &self.name); + return Err(DataFusionError::Execution(message)); + } + call_result.map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?; + + // 8. Import result. from_ffi consumes the FFI_ArrowArray. 
+ let result_array = *result_array_box; + let result_schema = *result_schema_box; + // SAFETY: Java's `Data.exportVector` populated `result_array_box` and + // `result_schema_box` in place via the C Data Interface, and the + // exception check above guarantees the call succeeded without + // throwing — so the FFI structs are fully initialized. + let result_data = unsafe { from_ffi(result_array, &result_schema) } + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + + // 9. Validate type. + if result_data.data_type() != &self.return_type { + return Err(DataFusionError::Execution(format!( + "Java UDF '{}' returned vector of type {:?}; declared return type was {:?}", + self.name, + result_data.data_type(), + self.return_type + ))); + } + + let array: ArrayRef = make_array(result_data); + Ok(ColumnarValue::Array(array)) + } +} + +pub(crate) fn volatility_from_byte(byte: u8) -> datafusion::error::Result { + match byte { + 0 => Ok(Volatility::Immutable), + 1 => Ok(Volatility::Stable), + 2 => Ok(Volatility::Volatile), + other => Err(DataFusionError::Execution(format!( + "unknown volatility byte: {}", + other + ))), + } +} + +/// Best-effort: extract class name and getMessage() from a Java throwable. +/// Anything that goes wrong collapses to a generic message so we don't +/// double-throw inside an error path. 
+fn jthrowable_to_string(env: &mut JNIEnv, throwable: &JThrowable, udf_name: &str) -> String { + let class_name_result = (|| -> jni::errors::Result { + let class = env.call_method(throwable, "getClass", "()Ljava/lang/Class;", &[])?; + let class_obj = class.l()?; + let name = env.call_method(&class_obj, "getName", "()Ljava/lang/String;", &[])?; + let name_obj = name.l()?; + let name_str: String = env.get_string(&name_obj.into())?.into(); + Ok(name_str) + })(); + let class_name = match class_name_result { + Ok(s) => s, + Err(_) => { + // A reflective call itself threw — clear that secondary exception so the + // thread is in a clean state when we return to the JVM. + env.exception_clear().ok(); + "".to_string() + } + }; + + let message_result = (|| -> jni::errors::Result { + let msg = env.call_method(throwable, "getMessage", "()Ljava/lang/String;", &[])?; + let msg_obj = msg.l()?; + if msg_obj.is_null() { + return Ok("".to_string()); + } + let s: String = env.get_string(&msg_obj.into())?.into(); + Ok(s) + })(); + let message = match message_result { + Ok(s) => s, + Err(_) => { + env.exception_clear().ok(); + "".to_string() + } + }; + + format!("Java UDF '{}' threw {}: {}", udf_name, class_name, message) +}