diff --git a/core/src/main/java/org/apache/datafusion/ScalarFunction.java b/core/src/main/java/org/apache/datafusion/ScalarFunction.java
index b83c636..d11c30a 100644
--- a/core/src/main/java/org/apache/datafusion/ScalarFunction.java
+++ b/core/src/main/java/org/apache/datafusion/ScalarFunction.java
@@ -22,7 +22,7 @@
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
/**
* A Java-implemented scalar SQL function. Implementations declare their own name, signature, and
@@ -40,15 +40,24 @@ public interface ScalarFunction {
String name();
/**
- * Declared argument types, in positional order. The function is registered with an exact
+ * Declared argument fields, in positional order. The function is registered with an exact
* signature; calls whose argument types do not match exactly are rejected.
+ *
+ *
Each entry is an Arrow {@link Field} -- a name plus a {@code FieldType} plus an optional
+ * list of child fields. Use {@link Field#nullable(String,
+ * org.apache.arrow.vector.types.pojo.ArrowType)} for primitive types (e.g. {@code
+ * Field.nullable("arg0", new ArrowType.Int(32, true))}). Nested types like {@code List}, {@code
+ * Struct}, and {@code Map} require the children list to carry element / member / key / value type
+ * information; constructing a {@code Field} via {@code new Field(name, FieldType, children)} is
+ * the canonical Arrow way to do that.
*/
- List argTypes();
+ List argFields();
/**
- * Declared return type. The returned {@link ColumnarValue}'s vector must have this exact type.
+ * Declared return field. The returned {@link ColumnarValue}'s vector must have this exact type,
+ * including any nested children. Same construction rules as {@link #argFields()}.
*/
- ArrowType returnType();
+ Field returnField();
/**
* Volatility classification. Use {@link Volatility#IMMUTABLE} for pure functions, {@link
diff --git a/core/src/main/java/org/apache/datafusion/ScalarUdf.java b/core/src/main/java/org/apache/datafusion/ScalarUdf.java
index 59cbd07..4fda894 100644
--- a/core/src/main/java/org/apache/datafusion/ScalarUdf.java
+++ b/core/src/main/java/org/apache/datafusion/ScalarUdf.java
@@ -22,7 +22,7 @@
import java.util.List;
import java.util.Objects;
-import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
/**
* A scalar UDF registration handle: pairs a {@link ScalarFunction} implementation with the metadata
@@ -35,8 +35,8 @@
public final class ScalarUdf {
private final ScalarFunction impl;
private final String name;
- private final List argTypes;
- private final ArrowType returnType;
+ private final List argFields;
+ private final Field returnField;
private final Volatility volatility;
/**
@@ -48,8 +48,8 @@ public final class ScalarUdf {
public ScalarUdf(ScalarFunction impl) {
this.impl = Objects.requireNonNull(impl, "impl");
this.name = Objects.requireNonNull(impl.name(), "impl.name()");
- this.argTypes = Objects.requireNonNull(impl.argTypes(), "impl.argTypes()");
- this.returnType = Objects.requireNonNull(impl.returnType(), "impl.returnType()");
+ this.argFields = Objects.requireNonNull(impl.argFields(), "impl.argFields()");
+ this.returnField = Objects.requireNonNull(impl.returnField(), "impl.returnField()");
this.volatility = Objects.requireNonNull(impl.volatility(), "impl.volatility()");
}
@@ -68,14 +68,14 @@ public String name() {
return name;
}
- /** Declared argument types; cached from {@link ScalarFunction#argTypes()}. */
- public List argTypes() {
- return argTypes;
+ /** Declared argument fields; cached from {@link ScalarFunction#argFields()}. */
+ public List argFields() {
+ return argFields;
}
- /** Declared return type; cached from {@link ScalarFunction#returnType()}. */
- public ArrowType returnType() {
- return returnType;
+ /** Declared return field; cached from {@link ScalarFunction#returnField()}. */
+ public Field returnField() {
+ return returnField;
}
/** Volatility classification; cached from {@link ScalarFunction#volatility()}. */
diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java
index 328eb6d..a2fc77a 100644
--- a/core/src/main/java/org/apache/datafusion/SessionContext.java
+++ b/core/src/main/java/org/apache/datafusion/SessionContext.java
@@ -32,9 +32,7 @@
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
-import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
-import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
/**
@@ -367,7 +365,7 @@ public DataFrame readArrow(String path, ArrowReadOptions options) {
* via the UDF's name or referenced in DataFusion plans deserialised with {@link #fromProto}.
*
* The UDF is registered with an exact signature: the runtime will reject calls whose argument
- * types do not match the declared {@link ScalarFunction#argTypes()} exactly.
+ * types do not match the declared {@link ScalarFunction#argFields()} exactly.
*
* @throws RuntimeException if registration fails (e.g., name already registered with an
* incompatible signature, schema serialisation failure).
@@ -379,14 +377,10 @@ public void registerUdf(ScalarUdf udf) {
java.util.Objects.requireNonNull(udf, "udf");
ScalarFunction impl = udf.impl();
String name = udf.name();
- ArrowType returnType = udf.returnType();
- List argTypes = udf.argTypes();
Volatility volatility = udf.volatility();
- List fields = new ArrayList<>(argTypes.size() + 1);
- fields.add(new Field("return", FieldType.nullable(returnType), null));
- for (int i = 0; i < argTypes.size(); i++) {
- fields.add(new Field("arg" + i, FieldType.nullable(argTypes.get(i)), null));
- }
+ List fields = new ArrayList<>(udf.argFields().size() + 1);
+ fields.add(udf.returnField());
+ fields.addAll(udf.argFields());
Schema signatureSchema = new Schema(fields);
byte[] signatureBytes = serializeSchemaIpc(signatureSchema);
registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), impl);
diff --git a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
index 3e14580..c70bcd9 100644
--- a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
+++ b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java
@@ -20,6 +20,8 @@
package org.apache.datafusion;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List;
@@ -28,13 +30,18 @@
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
import org.junit.jupiter.api.Test;
class ScalarUdfTest {
private static final ArrowType INT32 = new ArrowType.Int(32, true);
+ private static final ArrowType INT64 = new ArrowType.Int(64, true);
private static final ArrowType FLOAT64 =
new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE);
private static final ArrowType UTF8 = new ArrowType.Utf8();
@@ -42,15 +49,15 @@ class ScalarUdfTest {
/** Test-only base that supplies the four metadata getters from constructor args. */
abstract static class AbstractScalarFunction implements ScalarFunction {
private final String name;
- private final List argTypes;
- private final ArrowType returnType;
+ private final List argFields;
+ private final Field returnField;
private final Volatility volatility;
AbstractScalarFunction(
- String name, List argTypes, ArrowType returnType, Volatility volatility) {
+ String name, List argFields, Field returnField, Volatility volatility) {
this.name = name;
- this.argTypes = argTypes;
- this.returnType = returnType;
+ this.argFields = argFields;
+ this.returnField = returnField;
this.volatility = volatility;
}
@@ -60,13 +67,13 @@ public final String name() {
}
@Override
- public final List argTypes() {
- return argTypes;
+ public final List argFields() {
+ return argFields;
}
@Override
- public final ArrowType returnType() {
- return returnType;
+ public final Field returnField() {
+ return returnField;
}
@Override
@@ -82,7 +89,7 @@ static final class AddOne extends AbstractScalarFunction {
}
AddOne(String name, Volatility volatility) {
- super(name, List.of(INT32), INT32, volatility);
+ super(name, List.of(Field.nullable("x", INT32)), Field.nullable("y", INT32), volatility);
}
@Override
@@ -129,7 +136,11 @@ void addOne_overConstantTable_returnsIncrementedValues() throws Exception {
/** Concatenates two Utf8 columns. */
static final class Concat extends AbstractScalarFunction {
Concat() {
- super("java_concat", List.of(UTF8, UTF8), UTF8, Volatility.IMMUTABLE);
+ super(
+ "java_concat",
+ List.of(Field.nullable("a", UTF8), Field.nullable("b", UTF8)),
+ Field.nullable("c", UTF8),
+ Volatility.IMMUTABLE);
}
@Override
@@ -184,7 +195,11 @@ void concat_overVarCharColumns_concatenatesValues() throws Exception {
/** Squares a Float64 column. */
static final class Square extends AbstractScalarFunction {
Square() {
- super("java_square", List.of(FLOAT64), FLOAT64, Volatility.IMMUTABLE);
+ super(
+ "java_square",
+ List.of(Field.nullable("x", FLOAT64)),
+ Field.nullable("y", FLOAT64),
+ Volatility.IMMUTABLE);
}
@Override
@@ -249,7 +264,11 @@ void addOne_invokedTwiceInOneSession_executesIndependently() throws Exception {
static final class ReturnsNull extends AbstractScalarFunction {
ReturnsNull() {
- super("bad_null", List.of(INT32), INT32, Volatility.IMMUTABLE);
+ super(
+ "bad_null",
+ List.of(Field.nullable("x", INT32)),
+ Field.nullable("y", INT32),
+ Volatility.IMMUTABLE);
}
@Override
@@ -280,7 +299,11 @@ void udfReturningNull_surfacesIllegalStateException() {
static final class WrongRowCount extends AbstractScalarFunction {
WrongRowCount() {
- super("bad_rows", List.of(INT32), INT32, Volatility.IMMUTABLE);
+ super(
+ "bad_rows",
+ List.of(Field.nullable("x", INT32)),
+ Field.nullable("y", INT32),
+ Volatility.IMMUTABLE);
}
@Override
@@ -316,7 +339,11 @@ void udfReturningWrongRowCount_surfacesIllegalStateException() {
static final class WrongType extends AbstractScalarFunction {
WrongType() {
- super("bad_type", List.of(INT32), INT32, Volatility.IMMUTABLE);
+ super(
+ "bad_type",
+ List.of(Field.nullable("x", INT32)),
+ Field.nullable("y", INT32),
+ Volatility.IMMUTABLE);
}
@Override
@@ -354,7 +381,11 @@ void udfReturningWrongType_surfacesTypeMismatch() {
static final class ThrowsIAE extends AbstractScalarFunction {
ThrowsIAE() {
- super("boom", List.of(INT32), INT32, Volatility.IMMUTABLE);
+ super(
+ "boom",
+ List.of(Field.nullable("x", INT32)),
+ Field.nullable("y", INT32),
+ Volatility.IMMUTABLE);
}
@Override
@@ -458,7 +489,7 @@ void udfAppliedToMultiRowQuery_processesAllRows() throws Exception {
*/
static final class JavaPi extends AbstractScalarFunction {
JavaPi() {
- super("java_pi", List.of(), FLOAT64, Volatility.VOLATILE);
+ super("java_pi", List.of(), Field.nullable("p", FLOAT64), Volatility.VOLATILE);
}
@Override
@@ -499,7 +530,11 @@ void nullaryScalarReturnUdf_overMultiRowQuery_broadcasts() throws Exception {
*/
static final class AssertSecondArgIsScalar extends AbstractScalarFunction {
AssertSecondArgIsScalar() {
- super("assert_scalar_arg", List.of(INT32, INT32), INT32, Volatility.IMMUTABLE);
+ super(
+ "assert_scalar_arg",
+ List.of(Field.nullable("a", INT32), Field.nullable("b", INT32)),
+ Field.nullable("y", INT32),
+ Volatility.IMMUTABLE);
}
@Override
@@ -560,7 +595,11 @@ void scalarLiteralArg_arrivesAsScalarColumnarValue() throws Exception {
/** UDF that ignores its input and returns a constant Scalar. */
static final class IgnoreInputReturnFortyTwo extends AbstractScalarFunction {
IgnoreInputReturnFortyTwo() {
- super("forty_two", List.of(INT32), INT32, Volatility.IMMUTABLE);
+ super(
+ "forty_two",
+ List.of(Field.nullable("x", INT32)),
+ Field.nullable("y", INT32),
+ Volatility.IMMUTABLE);
}
@Override
@@ -612,4 +651,189 @@ void volatilityBytesRoundTrip_forAllThreeKinds() throws Exception {
}
}
}
+
+ // ---------------------------------------------------------------------
+ // Nested-type UDF tests. These pinpoint the regression #58 fixed.
+ // ---------------------------------------------------------------------
+
+ /**
+ * UDF taking a {@code List} argument and returning its length as Int32. Exercises that
+ * nested Arrow types -- whose element / member types live on the parent {@link Field}'s {@code
+ * children} list, not inside {@link ArrowType} -- can be declared and registered.
+ */
+ static final class ListLength extends AbstractScalarFunction {
+ ListLength() {
+ super(
+ "java_list_length",
+ List.of(
+ new Field(
+ "vals",
+ FieldType.nullable(new ArrowType.List()),
+ List.of(Field.nullable("item", INT32)))),
+ Field.nullable("len", INT32),
+ Volatility.IMMUTABLE);
+ }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ ListVector in = (ListVector) args.args().get(0).vector();
+ IntVector out = new IntVector("len_out", allocator);
+ int n = in.getValueCount();
+ out.allocateNew(n);
+ for (int i = 0; i < n; i++) {
+ if (in.isNull(i)) {
+ out.setNull(i);
+ } else {
+ out.set(i, in.getElementEndIndex(i) - in.getElementStartIndex(i));
+ }
+ }
+ out.setValueCount(n);
+ return ColumnarValue.array(out);
+ }
+ }
+
+ @Test
+ void udfWithListArg_canBeRegistered() {
+ // Smoke test: registration alone must succeed for nested-type UDFs. Without #58's fix the
+ // schema-IPC writer rejects List with no children.
+ try (SessionContext ctx = new SessionContext()) {
+ ctx.registerUdf(new ScalarUdf(new ListLength()));
+ }
+ }
+
+ @Test
+ void udfWithListArg_canBeInvokedFromSql() throws Exception {
+ // End-to-end: the registered UDF is callable from SQL with literal list arguments and the
+ // body sees the right element type. SELECT java_list_length([10, 20, 30]) -> 3.
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(new ScalarUdf(new ListLength()));
+
+ try (DataFrame df =
+ ctx.sql(
+ "SELECT java_list_length(make_array(CAST(10 AS INT), CAST(20 AS INT),"
+ + " CAST(30 AS INT))) AS n");
+ ArrowReader r = df.collect(allocator)) {
+ assertEquals(true, r.loadNextBatch());
+ IntVector n = (IntVector) r.getVectorSchemaRoot().getVector("n");
+ assertEquals(1, n.getValueCount());
+ assertEquals(3, n.get(0));
+ }
+ }
+ }
+
+ /**
+ * UDF taking a {@code Struct} and returning the sum {@code a + b} as Int64.
+ * Confirms the fix is structural rather than List-specific.
+ */
+ static final class SumStructFields extends AbstractScalarFunction {
+ SumStructFields() {
+ super(
+ "java_sum_struct",
+ List.of(
+ new Field(
+ "ab",
+ FieldType.nullable(new ArrowType.Struct()),
+ List.of(Field.nullable("a", INT32), Field.nullable("b", INT32)))),
+ Field.nullable("s", INT64),
+ Volatility.IMMUTABLE);
+ }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ StructVector in = (StructVector) args.args().get(0).vector();
+ IntVector a = (IntVector) in.getChild("a");
+ IntVector b = (IntVector) in.getChild("b");
+ org.apache.arrow.vector.BigIntVector out =
+ new org.apache.arrow.vector.BigIntVector("sum_out", allocator);
+ int n = in.getValueCount();
+ out.allocateNew(n);
+ for (int i = 0; i < n; i++) {
+ if (in.isNull(i) || a.isNull(i) || b.isNull(i)) {
+ out.setNull(i);
+ } else {
+ out.set(i, (long) a.get(i) + (long) b.get(i));
+ }
+ }
+ out.setValueCount(n);
+ return ColumnarValue.array(out);
+ }
+ }
+
+ @Test
+ void udfWithStructArg_canBeRegistered() {
+ // Struct child fields ride through the same Field children list as List elements; if the fix
+ // works for List it should work for Struct. This pins that.
+ try (SessionContext ctx = new SessionContext()) {
+ ctx.registerUdf(new ScalarUdf(new SumStructFields()));
+ }
+ }
+
+ @Test
+ void udfWithStructArg_canBeInvokedFromSql() throws Exception {
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(new ScalarUdf(new SumStructFields()));
+
+ try (DataFrame df =
+ ctx.sql(
+ "SELECT java_sum_struct(named_struct('a', CAST(3 AS INT), 'b', CAST(4 AS INT)))"
+ + " AS s");
+ ArrowReader r = df.collect(allocator)) {
+ assertEquals(true, r.loadNextBatch());
+ org.apache.arrow.vector.BigIntVector s =
+ (org.apache.arrow.vector.BigIntVector) r.getVectorSchemaRoot().getVector("s");
+ assertEquals(1, s.getValueCount());
+ assertEquals(7L, s.get(0));
+ }
+ }
+ }
+
+ /**
+ * UDF declaring a non-nullable return Field. DataFusion's default {@code return_field_from_args}
+ * wraps the return type in a fresh always-nullable Field, so without the {@code
+ * JavaScalarUdf::return_field_from_args} override the planner sees this UDF's output as nullable
+ * even though the Java caller said otherwise.
+ */
+ static final class NonNullableConstOne extends AbstractScalarFunction {
+ NonNullableConstOne() {
+ super(
+ "java_const_one_nn",
+ List.of(),
+ new Field("v", new FieldType(false, INT32, null), null),
+ Volatility.IMMUTABLE);
+ }
+
+ @Override
+ public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) {
+ // Nullary -- broadcast a single value as a Scalar. The framework expands it to
+ // args.rowCount() rows downstream, so we only need a length-1 vector here.
+ IntVector out = new IntVector("out", allocator);
+ out.allocateNew(1);
+ out.set(0, 1);
+ out.setValueCount(1);
+ return ColumnarValue.scalar(out);
+ }
+ }
+
+ @Test
+ void udfWithNonNullableReturnField_preservesNullabilityInResultSchema() throws Exception {
+ // The result column's schema must reflect the declared non-nullable return Field. Before the
+ // fix the JavaScalarUdf only stored a DataType; DataFusion's default return_field_from_args
+ // synthesised a fresh always-nullable Field, so the column came back as nullable.
+ try (SessionContext ctx = new SessionContext();
+ BufferAllocator allocator = new RootAllocator()) {
+ ctx.registerUdf(new ScalarUdf(new NonNullableConstOne()));
+
+ try (DataFrame df = ctx.sql("SELECT java_const_one_nn() AS v");
+ ArrowReader r = df.collect(allocator)) {
+ assertTrue(r.loadNextBatch());
+ Field resultField = r.getVectorSchemaRoot().getSchema().findField("v");
+ assertFalse(
+ resultField.isNullable(),
+ "expected declared non-nullable return Field to round-trip through registration,"
+ + " got nullable=true");
+ }
+ }
+ }
}
diff --git a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
index d9416b1..a6ecdfd 100644
--- a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
+++ b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java
@@ -27,6 +27,7 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
import org.apache.datafusion.ColumnarValue;
import org.apache.datafusion.DataFrame;
import org.apache.datafusion.ScalarFunction;
@@ -48,13 +49,13 @@ public String name() {
}
@Override
- public List argTypes() {
- return List.of(INT32);
+ public List argFields() {
+ return List.of(Field.nullable("x", INT32));
}
@Override
- public ArrowType returnType() {
- return INT32;
+ public Field returnField() {
+ return Field.nullable("y", INT32);
}
@Override
diff --git a/native/src/lib.rs b/native/src/lib.rs
index 1472628..114e070 100644
--- a/native/src/lib.rs
+++ b/native/src/lib.rs
@@ -732,7 +732,7 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU
if fields.is_empty() {
return Err("signature schema must have at least a return-type field".into());
}
- let return_type = fields[0].data_type().clone();
+ let return_field = fields[0].clone();
let arg_types: Vec = fields
.iter()
.skip(1)
@@ -755,7 +755,7 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU
let java_udf = crate::udf::JavaScalarUdf {
name: name.clone(),
signature,
- return_type,
+ return_field,
udf_global_ref,
bridge_class,
invoke_method,
diff --git a/native/src/udf.rs b/native/src/udf.rs
index d2b18b4..daa9ea6 100644
--- a/native/src/udf.rs
+++ b/native/src/udf.rs
@@ -21,12 +21,13 @@ use std::any::Any;
use std::fmt;
use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray};
-use datafusion::arrow::datatypes::{DataType, Field, Fields};
+use datafusion::arrow::datatypes::{DataType, Field, FieldRef, Fields};
use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema};
use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
- ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+ Volatility,
};
use jni::objects::{GlobalRef, JStaticMethodID, JThrowable};
use jni::signature::{Primitive, ReturnType};
@@ -36,7 +37,10 @@ use jni::JNIEnv;
pub(crate) struct JavaScalarUdf {
pub(crate) name: String,
pub(crate) signature: Signature,
- pub(crate) return_type: DataType,
+ /// The full return Field as the Java caller declared it. Carries the data type plus
+ /// nullability and any metadata; reused as both `return_type()` and the result of
+ /// `return_field_from_args()` so callers see the user's declaration verbatim.
+ pub(crate) return_field: FieldRef,
/// Global ref to the user's `org.apache.datafusion.ScalarFunction` instance.
pub(crate) udf_global_ref: GlobalRef,
/// Global ref to the `org.apache.datafusion.internal.JniBridge` class.
@@ -56,7 +60,7 @@ impl fmt::Debug for JavaScalarUdf {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("JavaScalarUdf")
.field("name", &self.name)
- .field("return_type", &self.return_type)
+ .field("return_field", &self.return_field)
.finish()
}
}
@@ -90,7 +94,17 @@ impl ScalarUDFImpl for JavaScalarUdf {
}
fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result {
- Ok(self.return_type.clone())
+ Ok(self.return_field.data_type().clone())
+ }
+
+ fn return_field_from_args(
+ &self,
+ _args: ReturnFieldArgs,
+ ) -> datafusion::error::Result {
+ // The default impl wraps return_type() in a fresh Field that's always nullable and
+ // carries no metadata. We hold the user's declared Field verbatim, so return it -- this
+ // preserves the declared nullability and any metadata they attached.
+ Ok(self.return_field.clone())
}
fn invoke_with_args(
@@ -249,12 +263,13 @@ impl ScalarUDFImpl for JavaScalarUdf {
let result_data = unsafe { from_ffi(result_array_ffi, &result_schema_ffi) }
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
- if result_data.data_type() != &self.return_type {
+ // 9. Validate type.
+ if result_data.data_type() != self.return_field.data_type() {
return Err(DataFusionError::Execution(format!(
"Java UDF '{}' returned vector of type {:?}; declared return type was {:?}",
self.name,
result_data.data_type(),
- self.return_type
+ self.return_field.data_type()
)));
}