From fd660761739c8346dbc29d2f0d9de6f38386c659 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Mon, 18 May 2026 07:54:17 +0000 Subject: [PATCH] fix(udf): declare scalar UDF signatures with Field, not ArrowType Change ScalarFunction.argTypes() / returnType() (List / ArrowType) to argFields() / returnField() (List / Field). SessionContext.registerUdf forwards the Fields straight through. JavaScalarUdf stores the full return FieldRef and overrides ScalarUDFImpl::return_field_from_args, so declared nullability and metadata round-trip into the result schema. ArrowType is a leaf marker in Java Arrow: ArrowType.List has no fields, and child element / member / key / value types live on the parent Field's children list. The previous registration code reconstructed the schema with `new Field(..., FieldType.nullable( type), null)`, dropping nested-type metadata; the previous Rust impl only stored a DataType, so the default return_field_from_args wrapped results in a fresh always-nullable Field. Both are fixed by storing and forwarding the user's Fields verbatim. --- .../org/apache/datafusion/ScalarFunction.java | 21 +- .../java/org/apache/datafusion/ScalarUdf.java | 22 +- .../org/apache/datafusion/SessionContext.java | 14 +- .../org/apache/datafusion/ScalarUdfTest.java | 248 ++++++++++++++++-- .../datafusion/examples/AddOneExample.java | 9 +- native/src/lib.rs | 4 +- native/src/udf.rs | 28 +- 7 files changed, 291 insertions(+), 55 deletions(-) diff --git a/core/src/main/java/org/apache/datafusion/ScalarFunction.java b/core/src/main/java/org/apache/datafusion/ScalarFunction.java index 676154e..db4b5f9 100644 --- a/core/src/main/java/org/apache/datafusion/ScalarFunction.java +++ b/core/src/main/java/org/apache/datafusion/ScalarFunction.java @@ -23,7 +23,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; /** * A Java-implemented scalar SQL function. Implementations declare their own name, signature, and @@ -41,13 +41,24 @@ public interface ScalarFunction { String name(); /** - * Declared argument types, in positional order. The function is registered with an exact + * Declared argument fields, in positional order. The function is registered with an exact * signature; calls whose argument types do not match exactly are rejected. + * + *

Each entry is an Arrow {@link Field} -- a name plus a {@code FieldType} plus an optional + * list of child fields. Use {@link Field#nullable(String, + * org.apache.arrow.vector.types.pojo.ArrowType)} for primitive types (e.g. {@code + * Field.nullable("arg0", new ArrowType.Int(32, true))}). Nested types like {@code List}, {@code + * Struct}, and {@code Map} require the children list to carry element / member / key / value type + * information; constructing a {@code Field} via {@code new Field(name, FieldType, children)} is + * the canonical Arrow way to do that. */ - List argTypes(); + List argFields(); - /** Declared return type. The returned {@link FieldVector} must have this exact type. */ - ArrowType returnType(); + /** + * Declared return field. The returned {@link FieldVector} must have this exact type, including + * any nested children. Same construction rules as {@link #argFields()}. + */ + Field returnField(); /** * Volatility classification. Use {@link Volatility#IMMUTABLE} for pure functions, {@link diff --git a/core/src/main/java/org/apache/datafusion/ScalarUdf.java b/core/src/main/java/org/apache/datafusion/ScalarUdf.java index 59cbd07..4fda894 100644 --- a/core/src/main/java/org/apache/datafusion/ScalarUdf.java +++ b/core/src/main/java/org/apache/datafusion/ScalarUdf.java @@ -22,7 +22,7 @@ import java.util.List; import java.util.Objects; -import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; /** * A scalar UDF registration handle: pairs a {@link ScalarFunction} implementation with the metadata @@ -35,8 +35,8 @@ public final class ScalarUdf { private final ScalarFunction impl; private final String name; - private final List argTypes; - private final ArrowType returnType; + private final List argFields; + private final Field returnField; private final Volatility volatility; /** @@ -48,8 +48,8 @@ public final class ScalarUdf { public ScalarUdf(ScalarFunction impl) { this.impl = Objects.requireNonNull(impl, "impl"); this.name = Objects.requireNonNull(impl.name(), "impl.name()"); - this.argTypes = Objects.requireNonNull(impl.argTypes(), "impl.argTypes()"); - this.returnType = Objects.requireNonNull(impl.returnType(), "impl.returnType()"); + this.argFields = Objects.requireNonNull(impl.argFields(), "impl.argFields()"); + this.returnField = Objects.requireNonNull(impl.returnField(), "impl.returnField()"); this.volatility = Objects.requireNonNull(impl.volatility(), "impl.volatility()"); } @@ -68,14 +68,14 @@ public String name() { return name; } - /** Declared argument types; cached from {@link ScalarFunction#argTypes()}. */ - public List argTypes() { - return argTypes; + /** Declared argument fields; cached from {@link ScalarFunction#argFields()}. */ + public List argFields() { + return argFields; } - /** Declared return type; cached from {@link ScalarFunction#returnType()}. */ - public ArrowType returnType() { - return returnType; + /** Declared return field; cached from {@link ScalarFunction#returnField()}. */ + public Field returnField() { + return returnField; } /** Volatility classification; cached from {@link ScalarFunction#volatility()}. */ diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index 328eb6d..a2fc77a 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -32,9 +32,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; -import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; /** @@ -367,7 +365,7 @@ public DataFrame readArrow(String path, ArrowReadOptions options) { * via the UDF's name or referenced in DataFusion plans deserialised with {@link #fromProto}. * *

The UDF is registered with an exact signature: the runtime will reject calls whose argument - * types do not match the declared {@link ScalarFunction#argTypes()} exactly. + * types do not match the declared {@link ScalarFunction#argFields()} exactly. * * @throws RuntimeException if registration fails (e.g., name already registered with an * incompatible signature, schema serialisation failure). @@ -379,14 +377,10 @@ public void registerUdf(ScalarUdf udf) { java.util.Objects.requireNonNull(udf, "udf"); ScalarFunction impl = udf.impl(); String name = udf.name(); - ArrowType returnType = udf.returnType(); - List argTypes = udf.argTypes(); Volatility volatility = udf.volatility(); - List fields = new ArrayList<>(argTypes.size() + 1); - fields.add(new Field("return", FieldType.nullable(returnType), null)); - for (int i = 0; i < argTypes.size(); i++) { - fields.add(new Field("arg" + i, FieldType.nullable(argTypes.get(i)), null)); - } + List fields = new ArrayList<>(udf.argFields().size() + 1); + fields.add(udf.returnField()); + fields.addAll(udf.argFields()); Schema signatureSchema = new Schema(fields); byte[] signatureBytes = serializeSchemaIpc(signatureSchema); registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), impl); diff --git a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java index 97f5f52..7723317 100644 --- a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java +++ b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java @@ -20,6 +20,8 @@ package org.apache.datafusion; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; @@ -28,13 +30,18 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.junit.jupiter.api.Test; class ScalarUdfTest { private static final ArrowType INT32 = new ArrowType.Int(32, true); + private static final ArrowType INT64 = new ArrowType.Int(64, true); private static final ArrowType FLOAT64 = new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE); private static final ArrowType UTF8 = new ArrowType.Utf8(); @@ -42,15 +49,15 @@ class ScalarUdfTest { /** Test-only base that supplies the four metadata getters from constructor args. */ abstract static class AbstractScalarFunction implements ScalarFunction { private final String name; - private final List argTypes; - private final ArrowType returnType; + private final List argFields; + private final Field returnField; private final Volatility volatility; AbstractScalarFunction( - String name, List argTypes, ArrowType returnType, Volatility volatility) { + String name, List argFields, Field returnField, Volatility volatility) { this.name = name; - this.argTypes = argTypes; - this.returnType = returnType; + this.argFields = argFields; + this.returnField = returnField; this.volatility = volatility; } @@ -60,13 +67,13 @@ public final String name() { } @Override - public final List argTypes() { - return argTypes; + public final List argFields() { + return argFields; } @Override - public final ArrowType returnType() { - return returnType; + public final Field returnField() { + return returnField; } @Override @@ -82,7 +89,7 @@ static final class AddOne extends AbstractScalarFunction { } AddOne(String name, Volatility volatility) { - super(name, List.of(INT32), INT32, volatility); + super(name, List.of(Field.nullable("x", INT32)), Field.nullable("y", INT32), volatility); } @Override @@ -129,7 +136,11 @@ void addOne_overConstantTable_returnsIncrementedValues() throws Exception { /** Concatenates two Utf8 columns. */ static final class Concat extends AbstractScalarFunction { Concat() { - super("java_concat", List.of(UTF8, UTF8), UTF8, Volatility.IMMUTABLE); + super( + "java_concat", + List.of(Field.nullable("a", UTF8), Field.nullable("b", UTF8)), + Field.nullable("c", UTF8), + Volatility.IMMUTABLE); } @Override @@ -184,7 +195,11 @@ void concat_overVarCharColumns_concatenatesValues() throws Exception { /** Squares a Float64 column. */ static final class Square extends AbstractScalarFunction { Square() { - super("java_square", List.of(FLOAT64), FLOAT64, Volatility.IMMUTABLE); + super( + "java_square", + List.of(Field.nullable("x", FLOAT64)), + Field.nullable("y", FLOAT64), + Volatility.IMMUTABLE); } @Override @@ -248,7 +263,11 @@ void addOne_invokedTwiceInOneSession_executesIndependently() throws Exception { static final class ReturnsNull extends AbstractScalarFunction { ReturnsNull() { - super("bad_null", List.of(INT32), INT32, Volatility.IMMUTABLE); + super( + "bad_null", + List.of(Field.nullable("x", INT32)), + Field.nullable("y", INT32), + Volatility.IMMUTABLE); } @Override @@ -279,7 +298,11 @@ void udfReturningNull_surfacesIllegalStateException() { static final class WrongRowCount extends AbstractScalarFunction { WrongRowCount() { - super("bad_rows", List.of(INT32), INT32, Volatility.IMMUTABLE); + super( + "bad_rows", + List.of(Field.nullable("x", INT32)), + Field.nullable("y", INT32), + Volatility.IMMUTABLE); } @Override @@ -315,7 +338,11 @@ void udfReturningWrongRowCount_surfacesIllegalStateException() { static final class WrongType extends AbstractScalarFunction { WrongType() { - super("bad_type", List.of(INT32), INT32, Volatility.IMMUTABLE); + super( + "bad_type", + List.of(Field.nullable("x", INT32)), + Field.nullable("y", INT32), + Volatility.IMMUTABLE); } @Override @@ -352,7 +379,11 @@ void udfReturningWrongType_surfacesTypeMismatch() { static final class ThrowsIAE extends AbstractScalarFunction { ThrowsIAE() { - super("boom", List.of(INT32), INT32, Volatility.IMMUTABLE); + super( + "boom", + List.of(Field.nullable("x", INT32)), + Field.nullable("y", INT32), + Volatility.IMMUTABLE); } @Override @@ -464,4 +495,189 @@ void volatilityBytesRoundTrip_forAllThreeKinds() throws Exception { } } } + + // --------------------------------------------------------------------- + // Nested-type UDF tests. These pinpoint the regression #58 fixed. + // --------------------------------------------------------------------- + + /** + * UDF taking a {@code List} argument and returning its length as Int32. Exercises that + * nested Arrow types -- whose element / member types live on the parent {@link Field}'s {@code + * children} list, not inside {@link ArrowType} -- can be declared and registered. + */ + static final class ListLength extends AbstractScalarFunction { + ListLength() { + super( + "java_list_length", + List.of( + new Field( + "vals", + FieldType.nullable(new ArrowType.List()), + List.of(Field.nullable("item", INT32)))), + Field.nullable("len", INT32), + Volatility.IMMUTABLE); + } + + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + ListVector in = (ListVector) args.get(0); + IntVector out = new IntVector("len_out", allocator); + int n = in.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (in.isNull(i)) { + out.setNull(i); + } else { + out.set(i, in.getElementEndIndex(i) - in.getElementStartIndex(i)); + } + } + out.setValueCount(n); + return out; + } + } + + @Test + void udfWithListArg_canBeRegistered() { + // Smoke test: registration alone must succeed for nested-type UDFs. Without #58's fix the + // schema-IPC writer rejects List with no children. + try (SessionContext ctx = new SessionContext()) { + ctx.registerUdf(new ScalarUdf(new ListLength())); + } + } + + @Test + void udfWithListArg_canBeInvokedFromSql() throws Exception { + // End-to-end: the registered UDF is callable from SQL with literal list arguments and the + // body sees the right element type. SELECT java_list_length([10, 20, 30]) -> 3. + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf(new ScalarUdf(new ListLength())); + + try (DataFrame df = + ctx.sql( + "SELECT java_list_length(make_array(CAST(10 AS INT), CAST(20 AS INT)," + + " CAST(30 AS INT))) AS n"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + IntVector n = (IntVector) r.getVectorSchemaRoot().getVector("n"); + assertEquals(1, n.getValueCount()); + assertEquals(3, n.get(0)); + } + } + } + + /** + * UDF taking a {@code Struct} and returning the sum {@code a + b} as Int64. + * Confirms the fix is structural rather than List-specific. + */ + static final class SumStructFields extends AbstractScalarFunction { + SumStructFields() { + super( + "java_sum_struct", + List.of( + new Field( + "ab", + FieldType.nullable(new ArrowType.Struct()), + List.of(Field.nullable("a", INT32), Field.nullable("b", INT32)))), + Field.nullable("s", INT64), + Volatility.IMMUTABLE); + } + + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + StructVector in = (StructVector) args.get(0); + IntVector a = (IntVector) in.getChild("a"); + IntVector b = (IntVector) in.getChild("b"); + org.apache.arrow.vector.BigIntVector out = + new org.apache.arrow.vector.BigIntVector("sum_out", allocator); + int n = in.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (in.isNull(i) || a.isNull(i) || b.isNull(i)) { + out.setNull(i); + } else { + out.set(i, (long) a.get(i) + (long) b.get(i)); + } + } + out.setValueCount(n); + return out; + } + } + + @Test + void udfWithStructArg_canBeRegistered() { + // Struct child fields ride through the same Field children list as List elements; if the fix + // works for List it should work for Struct. This pins that. + try (SessionContext ctx = new SessionContext()) { + ctx.registerUdf(new ScalarUdf(new SumStructFields())); + } + } + + @Test + void udfWithStructArg_canBeInvokedFromSql() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf(new ScalarUdf(new SumStructFields())); + + try (DataFrame df = + ctx.sql( + "SELECT java_sum_struct(named_struct('a', CAST(3 AS INT), 'b', CAST(4 AS INT)))" + + " AS s"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + org.apache.arrow.vector.BigIntVector s = + (org.apache.arrow.vector.BigIntVector) r.getVectorSchemaRoot().getVector("s"); + assertEquals(1, s.getValueCount()); + assertEquals(7L, s.get(0)); + } + } + } + + /** + * UDF declaring a non-nullable return Field. DataFusion's default {@code return_field_from_args} + * wraps the return type in a fresh always-nullable Field, so without the {@code + * JavaScalarUdf::return_field_from_args} override the planner sees this UDF's output as nullable + * even though the Java caller said otherwise. + */ + static final class NonNullableConstOne extends AbstractScalarFunction { + NonNullableConstOne() { + super( + "java_const_one_nn", + List.of(), + new Field("v", new FieldType(false, INT32, null), null), + Volatility.IMMUTABLE); + } + + @Override + public FieldVector evaluate(BufferAllocator allocator, List args) { + // 'args' is empty; we'll be invoked once per batch with rowCount=1 here since the call + // sites use a 1-row table. Sized to match. + IntVector out = new IntVector("out", allocator); + out.allocateNew(1); + out.set(0, 1); + out.setValueCount(1); + return out; + } + } + + @Test + void udfWithNonNullableReturnField_preservesNullabilityInResultSchema() throws Exception { + // The result column's schema must reflect the declared non-nullable return Field. Before the + // fix the JavaScalarUdf only stored a DataType; DataFusion's default return_field_from_args + // synthesised a fresh always-nullable Field, so the column came back as nullable. + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf(new ScalarUdf(new NonNullableConstOne())); + + try (DataFrame df = ctx.sql("SELECT java_const_one_nn() AS v"); + ArrowReader r = df.collect(allocator)) { + assertTrue(r.loadNextBatch()); + Field resultField = r.getVectorSchemaRoot().getSchema().findField("v"); + assertFalse( + resultField.isNullable(), + "expected declared non-nullable return Field to round-trip through registration," + + " got nullable=true"); + } + } + } } diff --git a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java index c27bff0..2fd47ff 100644 --- a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java +++ b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java @@ -28,6 +28,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.datafusion.DataFrame; import org.apache.datafusion.ScalarFunction; import org.apache.datafusion.ScalarUdf; @@ -47,13 +48,13 @@ public String name() { } @Override - public List argTypes() { - return List.of(INT32); + public List argFields() { + return List.of(Field.nullable("x", INT32)); } @Override - public ArrowType returnType() { - return INT32; + public Field returnField() { + return Field.nullable("y", INT32); } @Override diff --git a/native/src/lib.rs b/native/src/lib.rs index f6f16d3..7e6a850 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -646,7 +646,7 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU if fields.is_empty() { return Err("signature schema must have at least a return-type field".into()); } - let return_type = fields[0].data_type().clone(); + let return_field = fields[0].clone(); let arg_types: Vec = fields .iter() .skip(1) @@ -669,7 +669,7 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU let java_udf = crate::udf::JavaScalarUdf { name: name.clone(), signature, - return_type, + return_field, udf_global_ref, bridge_class, invoke_method, diff --git a/native/src/udf.rs b/native/src/udf.rs index 62d0e24..1e18751 100644 --- a/native/src/udf.rs +++ b/native/src/udf.rs @@ -22,11 +22,12 @@ use std::fmt; use std::sync::Arc; use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray}; -use datafusion::arrow::datatypes::{DataType, Field, Fields}; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}; use datafusion::error::DataFusionError; use datafusion::logical_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, }; use jni::objects::{GlobalRef, JStaticMethodID, JThrowable}; use jni::signature::{Primitive, ReturnType}; @@ -36,7 +37,10 @@ use jni::JNIEnv; pub(crate) struct JavaScalarUdf { pub(crate) name: String, pub(crate) signature: Signature, - pub(crate) return_type: DataType, + /// The full return Field as the Java caller declared it. Carries the data type plus + /// nullability and any metadata; reused as both `return_type()` and the result of + /// `return_field_from_args()` so callers see the user's declaration verbatim. + pub(crate) return_field: FieldRef, /// Global ref to the user's `org.apache.datafusion.ScalarFunction` instance. pub(crate) udf_global_ref: GlobalRef, /// Global ref to the `org.apache.datafusion.internal.JniBridge` class. @@ -56,7 +60,7 @@ impl fmt::Debug for JavaScalarUdf { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("JavaScalarUdf") .field("name", &self.name) - .field("return_type", &self.return_type) + .field("return_field", &self.return_field) .finish() } } @@ -90,7 +94,17 @@ impl ScalarUDFImpl for JavaScalarUdf { } fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result { - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) + } + + fn return_field_from_args( + &self, + _args: ReturnFieldArgs, + ) -> datafusion::error::Result { + // The default impl wraps return_type() in a fresh Field that's always nullable and + // carries no metadata. We hold the user's declared Field verbatim, so return it -- this + // preserves the declared nullability and any metadata they attached. + Ok(self.return_field.clone()) } fn invoke_with_args( @@ -224,12 +238,12 @@ impl ScalarUDFImpl for JavaScalarUdf { .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; // 9. Validate type. - if result_data.data_type() != &self.return_type { + if result_data.data_type() != self.return_field.data_type() { return Err(DataFusionError::Execution(format!( "Java UDF '{}' returned vector of type {:?}; declared return type was {:?}", self.name, result_data.data_type(), - self.return_type + self.return_field.data_type() ))); }