Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions core/src/main/java/org/apache/datafusion/ScalarFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import java.util.List;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;

/**
* A Java-implemented scalar SQL function. Implementations declare their own name, signature, and
Expand All @@ -40,15 +40,24 @@ public interface ScalarFunction {
String name();

/**
* Declared argument types, in positional order. The function is registered with an exact
* Declared argument fields, in positional order. The function is registered with an exact
* signature; calls whose argument types do not match exactly are rejected.
*
* <p>Each entry is an Arrow {@link Field} -- a name plus a {@code FieldType} plus an optional
* list of child fields. Use {@link Field#nullable(String,
* org.apache.arrow.vector.types.pojo.ArrowType)} for primitive types (e.g. {@code
* Field.nullable("arg0", new ArrowType.Int(32, true))}). Nested types like {@code List}, {@code
* Struct}, and {@code Map} require the children list to carry element / member / key / value type
* information; constructing a {@code Field} via {@code new Field(name, FieldType, children)} is
* the canonical Arrow way to do that.
*/
List<ArrowType> argTypes();
List<Field> argFields();

/**
* Declared return type. The returned {@link ColumnarValue}'s vector must have this exact type.
* Declared return field. The returned {@link ColumnarValue}'s vector must have this exact type,
* including any nested children. Same construction rules as {@link #argFields()}.
*/
ArrowType returnType();
Field returnField();

/**
* Volatility classification. Use {@link Volatility#IMMUTABLE} for pure functions, {@link
Expand Down
22 changes: 11 additions & 11 deletions core/src/main/java/org/apache/datafusion/ScalarUdf.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import java.util.List;
import java.util.Objects;

import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;

/**
* A scalar UDF registration handle: pairs a {@link ScalarFunction} implementation with the metadata
Expand All @@ -35,8 +35,8 @@
public final class ScalarUdf {
private final ScalarFunction impl;
private final String name;
private final List<ArrowType> argTypes;
private final ArrowType returnType;
private final List<Field> argFields;
private final Field returnField;
private final Volatility volatility;

/**
Expand All @@ -48,8 +48,8 @@ public final class ScalarUdf {
public ScalarUdf(ScalarFunction impl) {
this.impl = Objects.requireNonNull(impl, "impl");
this.name = Objects.requireNonNull(impl.name(), "impl.name()");
this.argTypes = Objects.requireNonNull(impl.argTypes(), "impl.argTypes()");
this.returnType = Objects.requireNonNull(impl.returnType(), "impl.returnType()");
this.argFields = Objects.requireNonNull(impl.argFields(), "impl.argFields()");
this.returnField = Objects.requireNonNull(impl.returnField(), "impl.returnField()");
this.volatility = Objects.requireNonNull(impl.volatility(), "impl.volatility()");
}

Expand All @@ -68,14 +68,14 @@ public String name() {
return name;
}

/** Declared argument types; cached from {@link ScalarFunction#argTypes()}. */
public List<ArrowType> argTypes() {
return argTypes;
/** Declared argument fields; cached from {@link ScalarFunction#argFields()}. */
public List<Field> argFields() {
return argFields;
}

/** Declared return type; cached from {@link ScalarFunction#returnType()}. */
public ArrowType returnType() {
return returnType;
/** Declared return field; cached from {@link ScalarFunction#returnField()}. */
public Field returnField() {
return returnField;
}

/** Volatility classification; cached from {@link ScalarFunction#volatility()}. */
Expand Down
14 changes: 4 additions & 10 deletions core/src/main/java/org/apache/datafusion/SessionContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;

/**
Expand Down Expand Up @@ -367,7 +365,7 @@ public DataFrame readArrow(String path, ArrowReadOptions options) {
* via the UDF's name or referenced in DataFusion plans deserialised with {@link #fromProto}.
*
* <p>The UDF is registered with an exact signature: the runtime will reject calls whose argument
* types do not match the declared {@link ScalarFunction#argTypes()} exactly.
* types do not match the declared {@link ScalarFunction#argFields()} exactly.
*
* @throws RuntimeException if registration fails (e.g., name already registered with an
* incompatible signature, schema serialisation failure).
Expand All @@ -379,14 +377,10 @@ public void registerUdf(ScalarUdf udf) {
java.util.Objects.requireNonNull(udf, "udf");
ScalarFunction impl = udf.impl();
String name = udf.name();
ArrowType returnType = udf.returnType();
List<ArrowType> argTypes = udf.argTypes();
Volatility volatility = udf.volatility();
List<Field> fields = new ArrayList<>(argTypes.size() + 1);
fields.add(new Field("return", FieldType.nullable(returnType), null));
for (int i = 0; i < argTypes.size(); i++) {
fields.add(new Field("arg" + i, FieldType.nullable(argTypes.get(i)), null));
}
List<Field> fields = new ArrayList<>(udf.argFields().size() + 1);
fields.add(udf.returnField());
fields.addAll(udf.argFields());
Schema signatureSchema = new Schema(fields);
byte[] signatureBytes = serializeSchemaIpc(signatureSchema);
registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), impl);
Expand Down
Loading