Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions core/src/main/java/org/apache/datafusion/ScalarUdf.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.datafusion;

import java.util.List;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;

/**
* A Java-implemented scalar SQL function. Register an instance with {@link
* SessionContext#registerUdf} to make it callable from SQL or DataFrame plans.
*
* <p>Implementations may be invoked concurrently by DataFusion on multiple worker threads. If the
* implementation carries mutable state, the implementation must synchronize it.
*/
@FunctionalInterface
public interface ScalarUdf {
/**
* Compute the function result for one input batch.
*
* @param allocator the {@link BufferAllocator} that MUST be used for any new {@link FieldVector}
* allocation, including the result. Buffers allocated from other allocators will not survive
* the JNI handoff.
* @param args one {@link FieldVector} per declared argument, all of the same length. These are
* read-only views; the implementation must NOT close them.
* @return a {@link FieldVector} of the declared return type and the same length as the inputs.
* Ownership transfers to the framework on return; the implementation must NOT close the
* returned vector.
*/
FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args);
}
43 changes: 43 additions & 0 deletions core/src/main/java/org/apache/datafusion/SessionContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;

/**
Expand Down Expand Up @@ -204,6 +209,41 @@ public DataFrame readParquet(String path, ParquetReadOptions options) {
return new DataFrame(dfHandle);
}

/**
* Register a Java-implemented scalar UDF. After registration, the function can be invoked by SQL
* via its {@code name} or referenced in DataFusion plans deserialised with {@link #fromProto}.
*
* <p>Argument and return types are declared at registration time. The UDF is registered with an
* exact signature: the runtime will reject calls whose argument types do not match {@code
* argTypes} exactly.
*
* @throws RuntimeException if registration fails (e.g., name already registered with an
* incompatible signature, schema serialisation failure).
*/
public void registerUdf(
String name,
ScalarUdf udf,
ArrowType returnType,
List<ArrowType> argTypes,
Volatility volatility) {
if (nativeHandle == 0) {
throw new IllegalStateException("SessionContext is closed");
}
java.util.Objects.requireNonNull(name, "name");
java.util.Objects.requireNonNull(udf, "udf");
java.util.Objects.requireNonNull(returnType, "returnType");
java.util.Objects.requireNonNull(argTypes, "argTypes");
java.util.Objects.requireNonNull(volatility, "volatility");
List<Field> fields = new ArrayList<>(argTypes.size() + 1);
fields.add(new Field("return", FieldType.nullable(returnType), null));
for (int i = 0; i < argTypes.size(); i++) {
fields.add(new Field("arg" + i, FieldType.nullable(argTypes.get(i)), null));
}
Schema signatureSchema = new Schema(fields);
byte[] signatureBytes = serializeSchemaIpc(signatureSchema);
registerScalarUdf(nativeHandle, name, signatureBytes, volatility.code(), udf);
}

private static byte[] serializeSchemaIpc(Schema schema) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (BufferAllocator allocator = new RootAllocator();
Expand Down Expand Up @@ -248,4 +288,7 @@ private static native long readCsvWithOptions(
long handle, String path, byte[] optionsBytes, byte[] schemaIpcBytes);

private static native void closeSessionContext(long handle);

private static native void registerScalarUdf(
long handle, String name, byte[] signatureSchemaBytes, byte volatility, ScalarUdf udf);
}
48 changes: 48 additions & 0 deletions core/src/main/java/org/apache/datafusion/Volatility.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.datafusion;

/**
* Volatility classification for a UDF. Mirrors DataFusion's {@code Volatility} enum.
*
* <ul>
* <li>{@link #IMMUTABLE} — pure function: same inputs always produce the same output; safe to
* constant-fold and common-subexpression-eliminate.
* <li>{@link #STABLE} — deterministic within a single query but not across queries (e.g., {@code
* now()}).
* <li>{@link #VOLATILE} — may return a different value on every call (e.g., {@code random()}).
* </ul>
*/
public enum Volatility {
IMMUTABLE((byte) 0),
STABLE((byte) 1),
VOLATILE((byte) 2);

private final byte code;

Volatility(byte code) {
this.code = code;
}

/** Stable byte code for FFI. */
public byte code() {
return code;
}
}
93 changes: 93 additions & 0 deletions core/src/main/java/org/apache/datafusion/internal/JniBridge.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.datafusion.internal;

import java.util.List;

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.datafusion.ScalarUdf;

/** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */
public final class JniBridge {
/**
* Shared allocator for UDF inputs/outputs. Created once at class-load time; never closed.
* Outstanding allocations are released by the FFI structs' release callbacks when the native side
* drops them after {@code from_ffi}.
*/
private static final RootAllocator ALLOCATOR = new RootAllocator();

private JniBridge() {}

/**
* Invoke a scalar UDF for one batch. Called from native code; not for application use.
*
* @param udf the registered {@link ScalarUdf} instance
* @param argsArrayAddr address of a populated {@code FFI_ArrowArray} struct holding the input
* batch as a struct array (one field per UDF argument)
* @param argsSchemaAddr address of the matching {@code FFI_ArrowSchema}
* @param resultArrayAddr address of an empty {@code FFI_ArrowArray} the bridge writes into
* @param resultSchemaAddr address of an empty {@code FFI_ArrowSchema} the bridge writes into
* @param expectedRowCount the row count the result vector must have
*/
public static void invokeScalarUdf(
ScalarUdf udf,
long argsArrayAddr,
long argsSchemaAddr,
long resultArrayAddr,
long resultSchemaAddr,
int expectedRowCount) {
ArrowArray argsArr = ArrowArray.wrap(argsArrayAddr);
ArrowSchema argsSch = ArrowSchema.wrap(argsSchemaAddr);
ArrowArray resultArr = ArrowArray.wrap(resultArrayAddr);
ArrowSchema resultSch = ArrowSchema.wrap(resultSchemaAddr);

try (VectorSchemaRoot root = Data.importVectorSchemaRoot(ALLOCATOR, argsArr, argsSch, null)) {
List<FieldVector> argVectors = root.getFieldVectors();

FieldVector result = udf.evaluate(ALLOCATOR, argVectors);

if (result == null) {
throw new IllegalStateException("ScalarUdf.evaluate returned null");
}
if (result.getValueCount() != expectedRowCount) {
try {
throw new IllegalStateException(
"ScalarUdf.evaluate returned vector with "
+ result.getValueCount()
+ " rows; expected "
+ expectedRowCount);
} finally {
result.close();
}
}

try {
Data.exportVector(ALLOCATOR, result, null, resultArr, resultSch);
} finally {
result.close();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ void limitRejectsNegativeArgs() {
@Test
void distinctRemovesDuplicates() {
try (SessionContext ctx = new SessionContext();
DataFrame source =
ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)");
DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)");
DataFrame deduped = source.distinct()) {
assertEquals(3L, deduped.count());
}
Expand Down
Loading
Loading