diff --git a/pom.xml b/pom.xml
index d7dac399c2aed..0e961f3cb9b2b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -113,6 +113,7 @@
connector/protobuf
udf/worker/proto
udf/worker/core
+ udf/worker/grpc
diff --git a/udf/worker/README.md b/udf/worker/README.md
index b843c430d0e04..199e51ee7d2ba 100644
--- a/udf/worker/README.md
+++ b/udf/worker/README.md
@@ -19,7 +19,7 @@ WorkerDispatcher -- manages workers, creates sessions
|
v
WorkerSession -- one UDF execution
- | 1. session.init(InitMessage(payload, inputSchema, outputSchema))
+ |  1. session.init(Init proto: UDF payload + data format + schemas)
| 2. val results = session.process(inputBatches)
| 3. session.close()
```
@@ -34,19 +34,22 @@ provisioning service or daemon).
```
udf/worker/
├── proto/
-│ worker_spec.proto -- UDFWorkerSpecification protobuf (+ generated Java classes)
-│ common.proto -- shared enums (UDFWorkerDataFormat, etc.)
+│ worker_spec.proto -- UDFWorkerSpecification protobuf
+│ udf_protocol.proto -- UDF execution protocol (Init, UdfPayload, ...)
+│ common.proto -- shared enums (UdfWorkerDataFormat, etc.)
│
└── core/ -- abstract interfaces
WorkerDispatcher.scala -- creates sessions, manages worker lifecycle
- WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage
+ WorkerSession.scala -- per-UDF init/process/cancel/close
+ WorkerSessionFactory.scala -- protocol-level connection + session factory
WorkerConnection.scala -- transport channel abstraction
WorkerSecurityScope.scala -- security boundary for worker pooling
│
└── direct/ -- "direct" creation: local OS processes
DirectWorkerDispatcher.scala -- spawns processes, env lifecycle
DirectWorkerProcess.scala -- OS process + connection + UDS socket
- DirectWorkerSession.scala -- session backed by a direct process
+ DirectWorkerSession.scala -- decorator: forwards to inner session,
+ releases worker ref-count on close
```
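+
+The `grpc/` module added in this change plugs into `core/` through the
+factory seam. A hedged sketch of the composition (the factory's
+constructor arguments here are an assumption, not part of this diff):
+
+```scala
+import org.apache.spark.udf.worker.core.direct.DirectUnixSocketWorkerDispatcher
+import org.apache.spark.udf.worker.core.grpc.GrpcWorkerSessionFactory
+
+// The dispatcher owns worker processes; the factory owns the protocol.
+val dispatcher = new DirectUnixSocketWorkerDispatcher(
+  spec,                            // UDFWorkerSpecification (see below)
+  new GrpcWorkerSessionFactory())  // gRPC-over-UDS connections + sessions
+```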
The `core/` package defines abstract interfaces that are independent of how
@@ -76,10 +79,12 @@ Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed.
```scala
import org.apache.spark.udf.worker.{
- DirectWorker, ProcessCallable, UDFProtoCommunicationPattern,
- UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification,
- UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment}
+ DirectWorker, Init, ProcessCallable, UdfPayload,
+ UdfProtoCommunicationPattern, UdfWorkerDataFormat, UDFWorkerProperties,
+ UDFWorkerSpecification, UnixDomainSocket, WorkerCapabilities,
+ WorkerConnectionSpec, WorkerEnvironment}
import org.apache.spark.udf.worker.core._
+import com.google.protobuf.ByteString
// 1. Define a worker spec (direct creation mode).
val spec = UDFWorkerSpecification.newBuilder()
@@ -90,9 +95,9 @@ val spec = UDFWorkerSpecification.newBuilder()
.addCommand("pip").addCommand("install").addCommand("my_udf_worker").build())
.build())
.setCapabilities(WorkerCapabilities.newBuilder()
- .addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
+ .addSupportedDataFormats(UdfWorkerDataFormat.ARROW)
.addSupportedCommunicationPatterns(
- UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
+ UdfProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
.build())
.setDirect(DirectWorker.newBuilder()
.setRunner(ProcessCallable.newBuilder()
@@ -112,10 +117,14 @@ val dispatcher: WorkerDispatcher = ...
val session = dispatcher.createSession(securityScope = None)
try {
// 4. Initialize with the serialized function and schemas.
- session.init(InitMessage(
- functionPayload = serializedFunction,
- inputSchema = arrowInputSchema,
- outputSchema = arrowOutputSchema))
+ session.init(Init.newBuilder()
+ .setUdf(UdfPayload.newBuilder()
+ .setPayload(ByteString.copyFrom(serializedFunction))
+ .setFormat("py-cloudpickle-v3"))
+ .setDataFormat(UdfWorkerDataFormat.ARROW)
+ .setInputSchema(ByteString.copyFrom(arrowInputSchema))
+ .setOutputSchema(ByteString.copyFrom(arrowOutputSchema))
+ .build())
// 5. Process data -- Iterator in, Iterator out.
val results: Iterator[Array[Byte]] =
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala
index f4c4091688c94..4a7701981f223 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala
@@ -19,31 +19,7 @@ package org.apache.spark.udf.worker.core
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.annotation.Experimental
-
-/**
- * :: Experimental ::
- * Carries all information needed to initialize a UDF execution on a worker.
- *
- * This message is passed to [[WorkerSession#init]] and contains the function
- * definition, schemas, and any additional configuration.
- *
- * Placeholder: will be replaced by a generated proto message once the
- * UDF wire protocol lands. Do not rely on case-class equality --
- * `Array[Byte]` fields compare by reference.
- *
- * @param functionPayload serialized function (e.g., pickled Python, JVM bytes)
- * @param inputSchema serialized input schema (e.g., Arrow schema bytes)
- * @param outputSchema serialized output schema (e.g., Arrow schema bytes)
- * @param properties additional key-value configuration. Can carry
- * protocol-specific or engine-specific metadata that
- * does not yet have a dedicated field.
- */
-@Experimental
-case class InitMessage(
- functionPayload: Array[Byte],
- inputSchema: Array[Byte],
- outputSchema: Array[Byte],
- properties: Map[String, String] = Map.empty)
+import org.apache.spark.udf.worker.Init
/**
* :: Experimental ::
@@ -62,7 +38,10 @@ case class InitMessage(
* {{{
* val session = dispatcher.createSession(securityScope = None)
* try {
- * session.init(InitMessage(functionPayload, inputSchema, outputSchema))
+ * session.init(Init.newBuilder()
+ * .setUdf(UdfPayload.newBuilder().setPayload(callable).setFormat(fmt))
+ * .setDataFormat(UdfWorkerDataFormat.ARROW)
+ * .build())
* val results = session.process(inputBatches)
* results.foreach(handleBatch)
* } finally {
@@ -74,7 +53,18 @@ case class InitMessage(
* - [[init]] must be called exactly once before [[process]].
* - [[process]] must be called at most once per session.
* - [[close]] must always be called (use try-finally).
- * - [[cancel]] may be called at any time to abort execution.
+ * - [[cancel]] may be called at any time, including before [[init]]
+ * or after [[process]]/[[close]] has returned. Implementations
+ * treat such calls as a no-op so that callers driven by a task
+ * interruption listener (which has no view into the session state)
+ * do not need to coordinate with the thread driving [[process]].
+ *
+ * Cancel-vs-finish race: when the session driver has finished
+ * sending input (and therefore queued an implicit finish on the
+ * underlying transport) and a [[cancel]] arrives concurrently, both
+ * are valid stream-terminating actions; the response side carries
+ * either a `FinishResponse` or a `CancelResponse` depending on which
+ * the worker observes first, and either is acceptable to the caller.
*
* The lifecycle is enforced here: [[init]] and [[process]] are `final`
* and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards.
@@ -93,10 +83,12 @@ abstract class WorkerSession extends AutoCloseable {
*
* Throws `IllegalStateException` if called more than once.
*
- * @param message the initialization parameters including the serialized
- * function, input/output schemas, and configuration.
+ * @param message the [[Init]] proto carrying the UDF body, the wire
+ * data format, optional input/output schemas, and any
+ * engine-side session context the worker needs to start
+ * processing.
*/
- final def init(message: InitMessage): Unit = {
+ final def init(message: Init): Unit = {
if (!initialized.compareAndSet(false, true)) {
throw new IllegalStateException("init has already been called on this session")
}
@@ -128,7 +120,7 @@ abstract class WorkerSession extends AutoCloseable {
}
/** Subclass hook for [[init]]. Called once, after the guard. */
- protected def doInit(message: InitMessage): Unit
+ protected def doInit(message: Init): Unit
/** Subclass hook for [[process]]. Called at most once, after the guard. */
protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]]
@@ -138,11 +130,29 @@ abstract class WorkerSession extends AutoCloseable {
*
* '''Thread-safety:''' implementations must allow [[cancel]] to be called
* from a thread different from the one driving [[process]] (typically a
- * task interruption thread). It may be invoked at any point after
- * [[init]] and should be a no-op if execution has already finished.
+ * task interruption thread).
+ *
+ * '''Lifecycle:''' [[cancel]] is idempotent and safe at any point in
+ * the session's life:
+ * - before [[init]] -- nothing has been sent on the transport yet,
+ * so [[cancel]] is a no-op (the session may still be closed
+ * normally via [[close]]).
+ * - between [[init]] and [[process]] -- transitions the session
+ * into a cancelled state; subsequent [[process]] calls observe
+ * the cancellation.
+ * - during [[process]] -- aborts the active stream.
+ * - after [[process]] / [[close]] has returned -- a no-op.
+ *
+ * Implementations are responsible for the no-op behavior described
+ * above so that callers (e.g. task interruption listeners) do not
+ * need to coordinate with the thread driving [[process]].
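+ *
+ * {{{
+ * // Hedged sketch: wiring cancel() to a task-interruption callback.
+ * // `onTaskInterrupted` is a hypothetical hook; the point is that it may
+ * // fire at any lifecycle stage and needs no extra coordination.
+ * onTaskInterrupted { () => session.cancel() }
+ * }}}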
*/
def cancel(): Unit
- /** Closes this session and releases resources. */
+ /**
+ * Closes this session and releases resources. Idempotent; safe to
+ * call from a `finally` block regardless of whether [[init]],
+ * [[process]], or [[cancel]] have been invoked.
+ */
override def close(): Unit
}
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSessionFactory.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSessionFactory.scala
new file mode 100644
index 0000000000000..005e902a4a69e
--- /dev/null
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSessionFactory.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Protocol-level factory for [[WorkerConnection]]s and per-execution
+ * [[WorkerSession]]s.
+ *
+ * The factory is responsible for everything that depends on the protocol
+ * the worker speaks (e.g. gRPC) -- building the transport-level connection
+ * and producing per-execution sessions on top of it. It is '''not'''
+ * responsible for creating, terminating, or otherwise managing the worker
+ * itself; that is the [[WorkerDispatcher]]'s job. As long as the dispatcher
+ * gives the factory an address to dial, and later a connection to talk
+ * over, the factory does not need to know whether the worker is a
+ * locally-spawned process, a daemon, or a remote service.
+ *
+ * One factory instance is owned by one dispatcher and is closed by the
+ * dispatcher. Implementations that own protocol-level shared resources
+ * (e.g. a Netty event loop group) should release them in [[close]].
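+ *
+ * A hedged sketch of a minimal implementation (the `MyProto*` types are
+ * hypothetical stand-ins for a concrete protocol):
+ *
+ * {{{
+ * class MyProtoSessionFactory extends WorkerSessionFactory {
+ *   override def createConnection(address: String): WorkerConnection =
+ *     new MyProtoConnection(address)   // dial the dispatcher-chosen address
+ *   override def createSession(connection: WorkerConnection): WorkerSession =
+ *     new MyProtoSession(connection)   // one UDF execution per session
+ * }
+ * }}}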
+ */
+@Experimental
+trait WorkerSessionFactory extends AutoCloseable {
+
+ /**
+ * Builds a transport-level connection to a worker reachable at
+ * `address`. The address format is whatever the dispatcher has agreed
+ * with the worker spec's transport (e.g. a UDS path, a TCP host:port).
+ *
+ * The returned connection is owned by the dispatcher's worker handle;
+ * its lifecycle is tied to worker teardown.
+ */
+ def createConnection(address: String): WorkerConnection
+
+ /**
+ * Creates a per-execution [[WorkerSession]] over an already-established
+ * [[WorkerConnection]]. The factory does not own the connection -- the
+ * worker handle does -- and a connection may back many sessions over
+ * its lifetime.
+ */
+ def createSession(connection: WorkerConnection): WorkerSession
+
+ /**
+ * Releases factory-level resources (event loops, thread pools, ...).
+ * Called by the dispatcher on close. Default is a no-op so factories
+ * that have nothing to release can omit the override.
+ */
+ override def close(): Unit = ()
+}
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala
index 8da0354187e4f..e608187687cb1 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala
@@ -22,7 +22,7 @@ import java.nio.file.attribute.PosixFilePermissions
import org.apache.spark.annotation.Experimental
import org.apache.spark.udf.worker.UDFWorkerSpecification
-import org.apache.spark.udf.worker.core.{UnixSocketWorkerConnection, WorkerLogger}
+import org.apache.spark.udf.worker.core.{WorkerLogger, WorkerSessionFactory}
import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POLL_INTERVAL_MS
/**
@@ -31,14 +31,20 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POL
* transport. Allocates a private 0700 socket directory at construction;
* each worker is given a UDS path inside it.
*
- * Concrete subclasses implement [[createConnection]] (with a UDS protocol
- * of choice) and [[createSessionForWorker]].
+ * Pair this dispatcher with any [[WorkerSessionFactory]] whose
+ * [[WorkerSessionFactory#createConnection]] knows how to dial a UDS path
+ * (e.g. a gRPC-over-UDS factory).
+ *
+ * @param sessionFactory protocol-specific connection + session factory.
+ * Owned by this dispatcher; closed by the base on
+ * dispatcher close.
*/
@Experimental
-abstract class DirectUnixSocketWorkerDispatcher(
+class DirectUnixSocketWorkerDispatcher(
workerSpec: UDFWorkerSpecification,
+ sessionFactory: WorkerSessionFactory,
logger: WorkerLogger = WorkerLogger.NoOp)
- extends DirectWorkerDispatcher(workerSpec, logger) {
+ extends DirectWorkerDispatcher(workerSpec, sessionFactory, logger) {
// Removed explicitly in closeTransport(). deleteOnExit is avoided because
// the JDK retains the path for the JVM lifetime, which leaks in
@@ -100,8 +106,6 @@ abstract class DirectUnixSocketWorkerDispatcher(
s"got ${conn.getTransportCase}")
}
- override protected def createConnection(address: String): UnixSocketWorkerConnection
-
private def throwWorkerExitedBeforeSocket(
process: Process,
address: String,
@@ -117,8 +121,12 @@ abstract class DirectUnixSocketWorkerDispatcher(
* On non-POSIX filesystems falls back to best-effort `File.setXxx`,
* which is TOCTOU-racy and weaker; a WARN surfaces if the platform
* refuses the setters.
+ *
+ * Visible for testing: tests override this to escape Spark's parent-POM
+ * `java.io.tmpdir=target/tmp`, which on long checkout paths pushes UDS
+ * paths past Linux's 108-char cap.
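+ *
+ * {{{
+ * // Hedged sketch of such a test override (prefix name is illustrative;
+ * // imports: java.nio.file.{Files, Paths} and PosixFilePermissions).
+ * override protected def createPrivateTempDirectory(): Path =
+ *   Files.createTempDirectory(Paths.get("/tmp"), "spark-udf-test-",
+ *     PosixFilePermissions.asFileAttribute(
+ *       PosixFilePermissions.fromString("rwx------")))
+ * }}}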
*/
- private def createPrivateTempDirectory(): Path = {
+ protected def createPrivateTempDirectory(): Path = {
val attr = PosixFilePermissions.asFileAttribute(
PosixFilePermissions.fromString("rwx------"))
try {
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
index afaf23791d80f..1b14323714c4e 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
@@ -29,8 +29,8 @@ import scala.util.control.NonFatal
import org.apache.spark.annotation.Experimental
import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification}
-import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher,
- WorkerLogger, WorkerSecurityScope, WorkerSession}
+import org.apache.spark.udf.worker.core.{WorkerDispatcher, WorkerLogger,
+ WorkerSecurityScope, WorkerSession, WorkerSessionFactory}
import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult,
DEFAULT_CALLABLE_TIMEOUT_MS, DEFAULT_GRACEFUL_TIMEOUT_MS, DEFAULT_INIT_TIMEOUT_MS,
ENGINE_MAX_TIMEOUT_MS, EnvironmentState, MAX_OUTPUT_SCAN_BYTES,
@@ -46,13 +46,17 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR
* currently gets a fresh worker that is terminated when the session closes
* (the single-reference case of the future pooling policy).
*
- * Subclasses implement [[createConnection]] and [[createSessionForWorker]]
- * to provide protocol-specific behavior (e.g., gRPC, raw sockets).
+ * Subclasses implement transport-level concerns (endpoint allocation,
+ * readiness wait, transport-state cleanup). The protocol the worker speaks
+ * (gRPC, raw sockets, ...) is delegated to a [[WorkerSessionFactory]]
+ * passed in by the caller.
*
* For workers obtained through a provisioning service or daemon (indirect
* creation), see the `indirect` package (TODO).
*
* @param workerSpec worker specification (proto)
+ * @param sessionFactory protocol-specific connection + session factory.
+ * Owned by this dispatcher; closed from [[close]].
* @param logger [[WorkerLogger]] used for dispatcher-internal messages.
* The framework does not depend on any concrete logging
* backend; callers should pass an adapter that forwards
@@ -62,6 +66,7 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR
@Experimental
abstract class DirectWorkerDispatcher(
override val workerSpec: UDFWorkerSpecification,
+ sessionFactory: WorkerSessionFactory,
protected val logger: WorkerLogger = WorkerLogger.NoOp)
extends WorkerDispatcher {
@@ -156,12 +161,6 @@ abstract class DirectWorkerDispatcher(
*/
protected def validateTransportSupport(): Unit
- /** Creates a protocol-specific connection to a worker at the given address. */
- protected def createConnection(address: String): WorkerConnection
-
- /** Creates a protocol-specific session for the given worker. */
- protected def createSessionForWorker(worker: DirectWorkerProcess): WorkerSession
-
override def createSession(
securityScope: Option[WorkerSecurityScope]): WorkerSession = {
require(securityScope.isEmpty,
@@ -180,7 +179,8 @@ abstract class DirectWorkerDispatcher(
throwClosed()
}
try {
- createSessionForWorker(worker)
+ val inner = sessionFactory.createSession(worker.connection)
+ new DirectWorkerSession(worker, inner, logger)
} catch {
case e: InterruptedException =>
Thread.currentThread().interrupt()
@@ -237,6 +237,10 @@ abstract class DirectWorkerDispatcher(
case NonFatal(e) =>
logger.warn("Error cleaning up transport state", e)
}
+ try sessionFactory.close() catch {
+ case NonFatal(e) =>
+ logger.warn("Error closing session factory", e)
+ }
deregisterEnvironmentCleanupHook()
runEnvironmentCleanup()
}
@@ -381,7 +385,7 @@ abstract class DirectWorkerDispatcher(
try {
waitForReady(address, process, outputFile.toFile)
- val connection = createConnection(address)
+ val connection = sessionFactory.createConnection(address)
val artifacts = new WorkerArtifacts(process, connection, outputFile, logger)
new DirectWorkerProcess(
workerId, artifacts, gracefulTimeoutMs, logger,
diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala
index 7cdc5329350e3..0a68440c6afdd 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala
@@ -18,39 +18,53 @@ package org.apache.spark.udf.worker.core.direct
import java.util.concurrent.atomic.AtomicBoolean
+import scala.util.control.NonFatal
+
import org.apache.spark.annotation.Experimental
-import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession}
+import org.apache.spark.udf.worker.Init
+import org.apache.spark.udf.worker.core.{WorkerLogger, WorkerSession}
/**
* :: Experimental ::
- * A [[WorkerSession]] backed by a locally-spawned [[DirectWorkerProcess]].
- *
- * This is the session type returned by [[DirectWorkerDispatcher]]. It ties
- * the session lifecycle to the worker's ref-count: the dispatcher increments
- * the count before construction, and [[close]] decrements it, so the
- * dispatcher knows when a worker process is idle and can be terminated or
- * reused.
+ * Lifecycle decorator that ties a protocol-level [[WorkerSession]] to the
+ * direct worker that backs it. Forwards [[doInit]] / [[doProcess]] /
+ * [[cancel]] to the inner session, and on [[close]] additionally fires
+ * the underlying [[DirectWorkerProcess]]'s session ref-count so the
+ * dispatcher knows when a worker has gone idle.
*
- * Subclasses implement the protocol-specific data transmission
- * ([[init]], [[process]], [[cancel]]).
+ * Constructed exclusively by [[DirectWorkerDispatcher]]; protocol code
+ * (gRPC and friends) does not extend or see this class.
*
- * @param workerProcess the direct worker process backing this session.
- * Internal to the `core` package and test code -- the
- * worker handle is a dispatcher implementation detail,
- * not part of the public WorkerSession API.
+ * @param workerProcess the direct worker backing this session.
+ * @param inner the protocol-specific session this decorator forwards to.
+ * @param logger logger used when the inner close raises.
*/
@Experimental
-abstract class DirectWorkerSession(
- private[core] val workerProcess: DirectWorkerProcess) extends WorkerSession {
+final class DirectWorkerSession private[direct] (
+ private[core] val workerProcess: DirectWorkerProcess,
+ private val inner: WorkerSession,
+ private val logger: WorkerLogger = WorkerLogger.NoOp) extends WorkerSession {
private val released = new AtomicBoolean(false)
- /** The connection to the worker for this session. */
- def connection: WorkerConnection = workerProcess.connection
+ override protected def doInit(message: Init): Unit = inner.init(message)
+
+ override protected def doProcess(
+ input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = inner.process(input)
+
+ override def cancel(): Unit = inner.cancel()
+ /**
+ * Closes the inner protocol session, then releases the worker's session
+ * ref-count. The ref-count release is run unconditionally so a faulty
+ * inner close cannot strand the worker as "still in use".
+ */
override def close(): Unit = {
- if (released.compareAndSet(false, true)) {
- workerProcess.releaseSession()
+ try inner.close() catch {
+ case NonFatal(e) =>
+ logger.warn(s"Error closing inner session for worker ${workerProcess.id}", e)
+ } finally {
+ if (released.compareAndSet(false, true)) workerProcess.releaseSession()
}
}
}
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
index 60f5e2211b702..cf829880801c4 100644
--- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.udf.worker.{
- DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
+ DirectWorker, Init, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec,
WorkerEnvironment}
import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
@@ -45,45 +45,50 @@ class SocketFileConnection(socketPath: String)
}
/**
- * A stub [[DirectWorkerSession]] for process-lifecycle tests that don't
- * need actual data transmission.
+ * A stub [[WorkerSession]] for process-lifecycle tests that don't need
+ * actual data transmission. The dispatcher wraps this in a
+ * [[DirectWorkerSession]] decorator that owns the worker ref-count.
*
- * TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]]
+ * TODO: [[cancel]] is a no-op here. Once a concrete protocol session
* with real data-plane wiring lands, add tests exercising cancel() in
* particular: cancel from a different thread than process(), cancel
* after process() has returned, and cancel before init (should be a
* no-op). Tracking the thread-safety contract in the docstring on
* [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
*/
-class StubWorkerSession(
- workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) {
-
- override protected def doInit(message: InitMessage): Unit = {}
+class StubWorkerSession extends WorkerSession {
+ override protected def doInit(message: Init): Unit = {}
override protected def doProcess(
input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
Iterator.empty
override def cancel(): Unit = {}
+
+ override def close(): Unit = {}
}
/**
- * A [[DirectUnixSocketWorkerDispatcher]] subclass for testing that uses
- * a socket-file connection and stub sessions instead of a real protocol
- * implementation.
+ * A [[WorkerSessionFactory]] for tests: hands out
+ * [[SocketFileConnection]] connections and [[StubWorkerSession]]
+ * sessions instead of any real protocol implementation.
*/
-class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification)
- extends DirectUnixSocketWorkerDispatcher(spec) {
-
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection =
- new SocketFileConnection(socketPath)
+class TestSessionFactory extends WorkerSessionFactory {
+ override def createConnection(address: String): WorkerConnection =
+ new SocketFileConnection(address)
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession =
- new StubWorkerSession(worker)
+ override def createSession(connection: WorkerConnection): WorkerSession =
+ new StubWorkerSession
}
+/**
+ * Test convenience: a [[DirectUnixSocketWorkerDispatcher]] pre-wired
+ * with the stub factory, so tests can stay focused on dispatcher
+ * behaviour and not the factory plumbing.
+ */
+class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification)
+ extends DirectUnixSocketWorkerDispatcher(spec, new TestSessionFactory)
+
/**
* Tests for [[DirectWorkerDispatcher]] process lifecycle: spawning workers
* and terminating them on close.
@@ -145,14 +150,14 @@ class DirectWorkerDispatcherSuite
}
// Narrow the publicly-typed WorkerSession returned by `createSession` back
- // down to StubWorkerSession in one place, with a descriptive failure if
+ // down to DirectWorkerSession in one place, with a descriptive failure if
// the cast is ever wrong, so individual tests don't scatter `asInstanceOf`
// (which would throw ClassCastException rather than a useful message).
- private def createStubSession(): StubWorkerSession =
+ private def createStubSession(): DirectWorkerSession =
dispatcher.createSession(None) match {
- case stub: StubWorkerSession => stub
+ case dws: DirectWorkerSession => dws
case other => fail(
- s"Expected StubWorkerSession, got ${other.getClass.getSimpleName}")
+ s"Expected DirectWorkerSession, got ${other.getClass.getSimpleName}")
}
// The whole suite uses UDS as the only transport, so reaching past the
@@ -181,7 +186,7 @@ class DirectWorkerDispatcherSuite
dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner))
val threads = 8
- val sessions = new java.util.concurrent.ConcurrentLinkedQueue[StubWorkerSession]()
+ val sessions = new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerSession]()
val startGate = new java.util.concurrent.CountDownLatch(1)
val doneGate = new java.util.concurrent.CountDownLatch(threads)
val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]()
@@ -360,24 +365,24 @@ class DirectWorkerDispatcherSuite
// should remain in either case.
val readyLatch = new java.util.concurrent.CountDownLatch(1)
val releaseLatch = new java.util.concurrent.CountDownLatch(1)
- val capturedWorkers =
- new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]()
- val racing = new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection =
- new SocketFileConnection(socketPath)
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession = {
- capturedWorkers.add(worker)
+ val capturedConnections =
+ new java.util.concurrent.ConcurrentLinkedQueue[WorkerConnection]()
+ val blockingFactory = new WorkerSessionFactory {
+ override def createConnection(address: String): WorkerConnection =
+ new SocketFileConnection(address)
+ override def createSession(connection: WorkerConnection): WorkerSession = {
+ capturedConnections.add(connection)
readyLatch.countDown()
// Block here so dispatcher.close() runs while createSession is in
// flight. Use a generous wait so a slow CI doesn't time out.
if (!releaseLatch.await(30, java.util.concurrent.TimeUnit.SECONDS)) {
fail("releaseLatch never fired -- test orchestration broken")
}
- new StubWorkerSession(worker)
+ new StubWorkerSession
}
}
+ val racing = new DirectUnixSocketWorkerDispatcher(
+ specWithRunner(defaultRunner), blockingFactory)
try {
val outcome =
new java.util.concurrent.atomic.AtomicReference[Either[Throwable, WorkerSession]]()
@@ -394,7 +399,7 @@ class DirectWorkerDispatcherSuite
// Wait for thread A to have published the worker and entered the
// blocking override.
assert(readyLatch.await(10, java.util.concurrent.TimeUnit.SECONDS),
- "createSession thread never reached createSessionForWorker")
+ "createSession thread never reached the factory's createSession")
val closeThread = new Thread(() => racing.close(), "close-racer")
closeThread.start()
@@ -409,10 +414,10 @@ class DirectWorkerDispatcherSuite
assert(!createThread.isAlive, "createSession thread did not finish")
assert(!closeThread.isAlive, "close thread did not finish")
- val captured = capturedWorkers.toArray(Array.empty[DirectWorkerProcess])
+ val captured = capturedConnections.toArray(Array.empty[WorkerConnection])
assert(captured.length == 1,
- s"expected exactly one worker spawned, got ${captured.length}")
- val worker = captured(0)
+ s"expected exactly one connection spawned, got ${captured.length}")
+ val sockPath = captured(0).asInstanceOf[UnixSocketWorkerConnection].socketPath
outcome.get() match {
case Left(e: IllegalStateException) =>
@@ -425,20 +430,21 @@ class DirectWorkerDispatcherSuite
s"expected dispatcher-closed error, got: ${e.getMessage}")
case Left(other) =>
fail(s"unexpected exception from racing createSession: $other")
- case Right(_) =>
+ case Right(s: DirectWorkerSession) =>
// close() iterated the published worker and tore it down; the
// returned session points at a worker that should now be dead.
+ val worker = s.workerProcess
+ val deadline = System.currentTimeMillis() + 5000
+ while (worker.process.isAlive && System.currentTimeMillis() < deadline) {
+ Thread.sleep(50)
+ }
+ assert(!worker.process.isAlive,
+ s"worker process should be terminated after close, still alive at $sockPath")
+ case Right(other) =>
+ fail(s"expected DirectWorkerSession, got ${other.getClass.getSimpleName}")
}
- // Whichever path won, the worker must not still be running and the
- // socket file must be gone.
- val deadline = System.currentTimeMillis() + 5000
- while (worker.process.isAlive && System.currentTimeMillis() < deadline) {
- Thread.sleep(50)
- }
- val sockPath = udsPath(worker)
- assert(!worker.process.isAlive,
- s"worker process should be terminated after close, still alive at $sockPath")
+ // Whichever branch won, the socket file must be gone.
assert(!new java.io.File(sockPath).exists(),
s"socket file $sockPath should have been removed")
} finally {
@@ -533,32 +539,33 @@ class DirectWorkerDispatcherSuite
// -- Error-path tests -------------------------------------------------------
- test("worker is cleaned up when createSessionForWorker throws") {
- // A dispatcher whose createSessionForWorker always throws. The spawned
- // worker must be terminated rather than leaked until dispatcher.close().
- var capturedWorker: DirectWorkerProcess = null
- val failingDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection =
- new SocketFileConnection(socketPath)
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession = {
- capturedWorker = worker
- throw new RuntimeException("session creation failed")
- }
+ test("worker is cleaned up when session factory createSession throws") {
+ // A factory whose createSession always throws. The spawned worker must
+ // be terminated rather than leaked until dispatcher.close().
+ var capturedSocketPath: String = null
+ val failingFactory = new WorkerSessionFactory {
+ override def createConnection(address: String): WorkerConnection =
+ new SocketFileConnection(address)
+ override def createSession(connection: WorkerConnection): WorkerSession = {
+ capturedSocketPath = connection.asInstanceOf[UnixSocketWorkerConnection].socketPath
+ throw new RuntimeException("session creation failed")
}
+ }
+ val failingDispatcher = new DirectUnixSocketWorkerDispatcher(
+ specWithRunner(defaultRunner), failingFactory)
try {
val ex = intercept[RuntimeException] {
failingDispatcher.createSession(None)
}
assert(ex.getMessage.contains("session creation failed"))
- assert(capturedWorker != null, "worker should have been spawned before the failure")
- assert(!capturedWorker.process.isAlive,
- "worker process should have been terminated after session creation failed")
- assert(capturedWorker.activeSessions == 0,
- "worker session count should be released after failure")
+ assert(capturedSocketPath != null,
+ "factory should have been called before the failure")
+ // Worker teardown removes the UDS socket file (via SocketFileConnection's
+ // base-class close), so the file going away is the observable that the
+ // worker was reaped rather than leaked.
+ assert(!new File(capturedSocketPath).exists(),
+ s"socket file $capturedSocketPath should be cleaned up after failure")
} finally {
failingDispatcher.close()
}
@@ -591,19 +598,18 @@ class DirectWorkerDispatcherSuite
s"expected UDS-only error, got: ${ex.getMessage}")
}
- test("socket file is cleaned up when createConnection throws") {
+ test("socket file is cleaned up when factory createConnection throws") {
val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]()
- val failingDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection = {
- capturedSocketPaths.add(socketPath)
- throw new RuntimeException("connection creation failed")
- }
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession =
- new StubWorkerSession(worker)
+ val failingFactory = new WorkerSessionFactory {
+ override def createConnection(address: String): WorkerConnection = {
+ capturedSocketPaths.add(address)
+ throw new RuntimeException("connection creation failed")
}
+ override def createSession(connection: WorkerConnection): WorkerSession =
+ new StubWorkerSession
+ }
+ val failingDispatcher = new DirectUnixSocketWorkerDispatcher(
+ specWithRunner(defaultRunner), failingFactory)
try {
val ex = intercept[RuntimeException] {
failingDispatcher.createSession(None)
@@ -760,14 +766,8 @@ class DirectWorkerDispatcherSuite
.addCommand("sleep 30").build()
val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build()
val shortTimeoutDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+ new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), new TestSessionFactory) {
override protected def callableTimeoutMs: Long = 500L
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection =
- new SocketFileConnection(socketPath)
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession =
- new StubWorkerSession(worker)
}
try {
val ex = intercept[DirectWorkerTimeoutException] {
@@ -897,14 +897,8 @@ class DirectWorkerDispatcherSuite
s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build())
.build()
val timeoutDispatcher =
- new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+ new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), new TestSessionFactory) {
override protected def callableTimeoutMs: Long = 500L
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection =
- new SocketFileConnection(socketPath)
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession =
- new StubWorkerSession(worker)
}
try {
val first = intercept[DirectWorkerTimeoutException] {
diff --git a/udf/worker/grpc/pom.xml b/udf/worker/grpc/pom.xml
new file mode 100644
index 0000000000000..677c26c8b106a
--- /dev/null
+++ b/udf/worker/grpc/pom.xml
@@ -0,0 +1,86 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.13</artifactId>
+    <version>4.2.0-SNAPSHOT</version>
+    <relativePath>../../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-udf-worker-grpc_2.13</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project UDF Worker gRPC</name>
+  <url>https://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>udf-worker-grpc</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-udf-worker-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-udf-worker-proto_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-netty-shaded</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-stub</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-protobuf</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerConnection.scala b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerConnection.scala
new file mode 100644
index 0000000000000..efe45185862bf
--- /dev/null
+++ b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerConnection.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import java.util.concurrent.TimeUnit
+
+import io.grpc.{ConnectivityState, ManagedChannel}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.core.UnixSocketWorkerConnection
+
+/**
+ * :: Experimental ::
+ * A [[UnixSocketWorkerConnection]] backed by a gRPC [[ManagedChannel]].
+ *
+ * The Netty event loop group used by the channel is owned by the
+ * [[GrpcWorkerSessionFactory]] that created this connection (shared
+ * across all connections it produces), so [[close]] only shuts down
+ * the channel here -- not the event loop. The base class handles
+ * socket-file removal.
+ *
+ * @param channel the gRPC managed channel connected to the worker
+ * @param socketPath the UDS path the channel is bound to; passed to the
+ * base class for ownership/cleanup of the socket file
+ */
+@Experimental
+class GrpcWorkerConnection(
+ val channel: ManagedChannel,
+ socketPath: String) extends UnixSocketWorkerConnection(socketPath) {
+
+ private val SHUTDOWN_TIMEOUT_MS = 5000L
+
+ /**
+ * A channel is considered active unless explicitly shut down.
+ * Transient states (IDLE, CONNECTING, TRANSIENT_FAILURE) may still
+ * recover, so we only report inactive for SHUTDOWN.
+ */
+ override def isActive: Boolean = {
+ val state = channel.getState(false)
+ state != ConnectivityState.SHUTDOWN
+ }
+
+ override def close(): Unit = {
+ channel.shutdown()
+ try {
+ channel.awaitTermination(SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ } catch {
+ case _: InterruptedException =>
+ Thread.currentThread().interrupt()
+ } finally {
+ if (!channel.isTerminated) {
+ channel.shutdownNow()
+ }
+ }
+ super.close()
+ }
+}
diff --git a/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSession.scala b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSession.scala
new file mode 100644
index 0000000000000..825aba8ebb8d7
--- /dev/null
+++ b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSession.scala
@@ -0,0 +1,541 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import java.util.concurrent.{CancellationException, LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.util.control.NonFatal
+
+import com.google.protobuf.ByteString
+import io.grpc.stub.{CallStreamObserver, StreamObserver}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker._
+import org.apache.spark.udf.worker.core.WorkerSession
+
+/**
+ * :: Experimental ::
+ * A [[WorkerSession]] that implements the UDF protocol over a
+ * bidirectional gRPC stream
+ * ([[WorkerGrpc.WorkerStub#execute]]).
+ *
+ * Each instance maps to one UDF execution on a single gRPC stream.
+ * The number of output batches is independent of the number of input
+ * batches (depends on UDF shape: scalar is 1:1, aggregate is N:1, etc.).
+ *
+ * ==Threading model==
+ *
+ * This session uses a '''single caller thread''' for both sending input
+ * and receiving results. gRPC's internal I/O threads handle network
+ * transport and invoke the response [[StreamObserver]] callbacks, which
+ * enqueue responses into [[responseQueue]]. The caller thread
+ * interleaves sending and polling:
+ *
+ * {{{
+ * Caller thread gRPC I/O thread Worker (gRPC server)
+ * | | |
+ * |-- init(): send Init ----------|----> Init ----------->|
+ * | poll responseQueue | |
+ * |<--- enqueue InitResponse <----|<---- InitResponse ----|
+ * | | |
+ * |-- process(input): | |
+ * | loop: | |
+ * | poll responseQueue | |
+ * | if result: yield batch | |
+ * | if empty & input & ready: | |
+ * | send DataRequest -------|----> DataRequest ----->|
+ * | |<---- DataResponse ----|
+ * |<--- enqueue DataResponse <----| |
+ * | yield batch | |
+ * | (input exhausted) | |
+ * | send Finish + complete ---|----> Finish --------->|
+ * | |<---- FinishResponse --|
+ * |<--- enqueue FinishResponse <--| |
+ * | |<---- onCompleted -----|
+ * |<--- enqueue StreamEnd <-------| |
+ * | hasNext -> false | |
+ * | | |
+ * |-- close() ------------------->| |
+ * }}}
+ *
+ * ==Backpressure==
+ *
+ * The caller thread uses [[CallStreamObserver.isReady]] to check
+ * whether the gRPC transport buffer can accept more data before sending
+ * each batch. When the buffer is full (worker is slow to consume), the
+ * caller instead polls for results, naturally throttling the send rate.
+ * This avoids unbounded memory growth that would occur if the sender
+ * pushed data faster than the worker can process it.
+ *
+ * The interleaved send/receive loop also provides implicit backpressure:
+ * results are always preferred over sending more input. A batch is only
+ * sent when the result queue is empty and the transport is ready, which
+ * limits the amount of in-flight data to what the worker can absorb.
+ *
+ * ==Error handling==
+ *
+ * The gRPC callback [[StreamObserver.onError]] enqueues the error into
+ * [[responseQueue]]. On the next poll, the caller thread re-throws it.
+ * Before sending each input batch, the caller checks for errors already
+ * received from the worker, allowing early termination without pushing
+ * data into a failed stream.
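+ *
+ * A hedged caller-side sketch (`emit` and `handleWorkerFailure` are
+ * hypothetical): worker failures surface as exceptions thrown while the
+ * result iterator is consumed, not from the `process` call itself.
+ *
+ * {{{
+ * val results = session.process(inputBatches)
+ * try {
+ *   results.foreach(emit)
+ * } catch {
+ *   case e: io.grpc.StatusRuntimeException => handleWorkerFailure(e)
+ * }
+ * }}}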
+ *
+ * @param connection the gRPC connection to the worker. The connection is
+ * owned by the dispatcher's worker handle; this session
+ * uses but does not close it.
+ */
+@Experimental
+class GrpcWorkerSession(
+ connection: GrpcWorkerConnection,
+ controlResponseTimeoutMs: Long = GrpcWorkerSession.DefaultControlResponseTimeoutMs)
+ extends WorkerSession {
+
+ // TODO: wire order ("Init -> Data* -> Finish|Cancel") is enforced by control
+ // flow rather than an explicit state machine, so a future change that e.g.
+ // sends Cancel before Init would compile cleanly with no defensive guard.
+
+ // TODO: configurable timeouts -- POLL_TIMEOUT_MS, controlResponseTimeoutMs,
+ // and the graceful shutdown timeout in close() are all hardcoded today.
+ private val POLL_TIMEOUT_MS = 100L
+
+ // Events enqueued by the gRPC callback thread and consumed by the caller.
+ // A small sealed ADT is clearer than Either + a reference-equality sentinel,
+ // and gives the pattern matches below exhaustiveness checking for free.
+ private sealed trait StreamEvent
+ private case class StreamResponse(udf: UdfResponse) extends StreamEvent
+ private case class StreamError(cause: Throwable) extends StreamEvent
+ private case object StreamEnd extends StreamEvent
+
+ /**
+ * Events from the gRPC callback thread. Unbounded because every data
+ * message is drained by the iterator and the few control messages
+ * (Init/Finish/Cancel ack) never pile up; the queue is a hand-off, not
+ * a buffer that grows with traffic.
+ */
+ private val responseQueue = new LinkedBlockingQueue[StreamEvent]()
+
+ /** Set to true by the gRPC callback when the server stream ends. */
+ @volatile private var completed = false
+
+ /** Captures the first error from the worker so the caller can check it. */
+ @volatile private var workerError: Throwable = _
+
+ private val closed = new AtomicBoolean(false)
+
+ /** True once the client stream has been finished or cancelled. */
+ private val streamFinished = new AtomicBoolean(false)
+
+ /**
+ * True once the Init request has been written to the request stream.
+ * Gates [[cancel]] so it cannot send a Cancel as the first wire message
+ * (which would violate the proto's `Init -> ... -> Cancel|Finish` ordering).
+ */
+ private val initSent = new AtomicBoolean(false)
+
+ /**
+ * True once an explicit [[cancel]] has been requested by the caller.
+ * Distinct from [[streamFinished]] (which is also set by close/finish) so
+ * [[doProcess]] can surface a cancellation to a caller that called
+ * [[cancel]] between [[init]] and [[process]] rather than silently
+ * returning an empty iterator.
+ */
+ private val cancelRequested = new AtomicBoolean(false)
+
+ /**
+ * Guards all access to [[requestObserver]] so that cancel() and close()
+ * (which may be called from task-cancellation threads) cannot race with
+ * the caller thread's sends. gRPC StreamObserver is not thread-safe.
+ */
+ private val streamLock = new Object
+
+ private val stub = WorkerGrpc.newStub(connection.channel)
+
+ // Open the bidirectional stream. The gRPC I/O thread writes to
+ // responseQueue; the caller thread reads from it.
+ private val requestObserver: StreamObserver[UdfRequest] =
+ stub.execute(new StreamObserver[UdfResponse] {
+ override def onNext(response: UdfResponse): Unit = {
+ enqueueOrInterrupt(StreamResponse(response))
+ }
+
+ override def onError(t: Throwable): Unit = {
+ workerError = t
+ completed = true
+ enqueueOrInterrupt(StreamError(t))
+ }
+
+ override def onCompleted(): Unit = {
+ completed = true
+ // Enqueue StreamEnd so the consumer unblocks immediately instead of
+ // spin-polling on the `completed` volatile.
+ enqueueOrInterrupt(StreamEnd)
+ }
+
+ // gRPC callbacks must not throw checked exceptions; if `put` is
+ // ever interrupted (the unbounded queue makes this unrealistic in
+ // practice), restore the interrupt flag and capture the failure as
+ // a worker error so the caller surfaces it.
+ private def enqueueOrInterrupt(event: StreamEvent): Unit = {
+ try {
+ responseQueue.put(event)
+ } catch {
+ case ie: InterruptedException =>
+ Thread.currentThread().interrupt()
+ workerError = ie
+ completed = true
+ }
+ }
+ })
+
+ // TODO: payload chunking. Large UDF payloads can exceed gRPC's per-message
+ // size limit. The proto defines PayloadChunk for streaming; this method
+ // forwards the payload inline and does not yet chunk.
+ // TODO: capability validation. Init.dataFormat is not checked against the
+ // spec's supported_data_formats, nor BIDIRECTIONAL_STREAMING against
+ // supported_communication_patterns. Belongs in the dispatcher/factory.
+ override protected def doInit(message: Init): Unit = {
+ val request = UdfRequest.newBuilder()
+ .setControl(UdfControlRequest.newBuilder().setInit(message))
+ .build()
+ streamLock.synchronized {
+ requestObserver.onNext(request)
+ // Set after onNext returns so cancel() only sends Cancel when Init
+ // has actually been written. Inside the lock so cancel() observes a
+ // consistent ordering between "Init went on the wire" and the flag.
+ initSent.set(true)
+ }
+ awaitControlResponse(UdfControlResponse.ControlCase.INIT)
+ }
+
+ /**
+ * Streams input to the worker and returns an iterator of results.
+ *
+ * Uses a '''single-thread interleaved''' model: the caller thread
+ * alternates between polling for results and sending input batches.
+ * Results are always preferred over sending -- a new batch is only
+ * sent when the result queue is empty and the gRPC transport is ready
+ * ([[CallStreamObserver.isReady]]). This naturally throttles the send
+ * rate to the worker's processing speed and limits memory usage.
+ *
+ * The number of result batches is independent of input batches;
+ * it depends on the UDF shape (scalar: 1:1, aggregate: N:1, etc.).
+ * Generator-style UDFs may produce output even with empty input.
+ */
+ override protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = {
+ // If cancel() ran between init and process, surface it explicitly so
+ // the caller doesn't get an empty iterator that's indistinguishable
+ // from a generator UDF that legitimately produced nothing.
+ if (cancelRequested.get()) {
+ throw new CancellationException("session was cancelled before process")
+ }
+
+ // If input is already empty, send Finish immediately.
+ // Generator-style UDFs may still produce output without input.
+ var inputExhausted = false
+ if (!input.hasNext) {
+ inputExhausted = true
+ sendFinishAndComplete()
+ }
+
+ // gRPC stub methods return a CallStreamObserver in practice; the public
+ // interface only promises StreamObserver. Use a typed match so a future
+ // change in grpc-java surfaces a clear error rather than ClassCastException.
+ val callObserver = requestObserver match {
+ case c: CallStreamObserver[UdfRequest @unchecked] => c
+ case other => throw new IllegalStateException(
+ s"expected CallStreamObserver from gRPC stub, got ${other.getClass.getName}")
+ }
+
+ new Iterator[Array[Byte]] {
+ private var pendingBatch: Option[Array[Byte]] = None
+ private var exhausted = false
+
+ override def hasNext: Boolean = {
+ if (exhausted) return false
+ if (pendingBatch.isDefined) return true
+ fillNextBatch()
+ pendingBatch.isDefined
+ }
+
+ override def next(): Array[Byte] = {
+ if (!hasNext) throw new NoSuchElementException("No more result batches")
+ val result = pendingBatch.get
+ pendingBatch = None
+ result
+ }
+
+ /**
+ * Interleaved send/receive loop. On each iteration:
+ *
+ * 1. Check for worker errors -- stop early if the stream has failed.
+ * 2. Poll the result queue -- prefer consuming results to reduce
+ * memory pressure from buffered responses.
+ * 3. If no result is available and there is more input to send and
+ * the gRPC transport is ready, send the next input batch.
+ *
+ * This ordering ensures the caller never sends faster than it
+ * consumes, providing natural backpressure.
+ */
+ private def fillNextBatch(): Unit = {
+ while (pendingBatch.isEmpty && !exhausted) {
+ // 1. Check for worker errors before doing any work.
+ throwIfErrored()
+
+ // 2. Poll the result queue. The timeout is required: an aggregate-
+ // like UDF only emits output after consuming enough input, so a
+ // blocking poll here would deadlock -- we cannot send the next
+ // input batch (step 3 below) until the poll returns, and the
+ // worker has nothing to send back yet.
+ // TODO: liveness during execution. The proto's Manage RPC
+ // (Heartbeat / ShutdownRequest) is declared but unused, and
+ // WorkerConnection.isActive only reflects the gRPC channel
+ // state, not the worker process. A hung worker leaves this
+ // loop polling every POLL_TIMEOUT_MS forever.
+ val item = responseQueue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ if (item != null) {
+ item match {
+ case StreamEnd =>
+ exhausted = true
+ return
+ case StreamResponse(response) if response.hasData =>
+ val dataResponse = response.getData
+ val bytes = dataResponse.getData.toByteArray
+ pendingBatch = Some(bytes)
+ return
+ case StreamResponse(_) =>
+ // Control response (FinishResponse, etc.); not a data batch.
+ // TODO: FinishResponse.data is silently dropped. The proto
+ // carries an optional final payload here; we have no API
+ // to surface it outside the data stream.
+ case StreamError(t) => throw t
+ }
+ } else if (completed && responseQueue.isEmpty) {
+ exhausted = true
+ return
+ }
+
+ // 3. No result available -- send more input if possible.
+ // Only send when the transport is ready (backpressure: the gRPC
+ // transport buffer has room). If not ready, loop back to poll
+ // -- the worker is slow and we should wait for results.
+ if (!inputExhausted && input.hasNext && callObserver.isReady) {
+ try {
+ val data = input.next()
+ sendDataBatch(data)
+ if (!input.hasNext) {
+ inputExhausted = true
+ sendFinishAndComplete()
+ }
+ } catch {
+ case NonFatal(e) =>
+ // Error from upstream (e.g., input serialization failure).
+ // Notify the worker so it can clean up, then propagate.
+ finishStreamWithError(e)
+ throw e
+ }
+ } else if (!inputExhausted && !input.hasNext) {
+ // Input exhausted between the last check and now.
+ inputExhausted = true
+ sendFinishAndComplete()
+ }
+ }
+ }
+ }
+ }
+
+ // TODO: cancel cannot quickly unblock an in-flight process() poll. cancel()
+ // from another thread relies on the worker (or transport teardown) to
+ // deliver StreamEnd before the poll loop notices, so task-interruption
+ // latency is bounded by POLL_TIMEOUT_MS plus a worker round-trip. The
+ // contract permits this -- the response side may carry FinishResponse or
+ // CancelResponse depending on which the worker observes first.
+ override def cancel(): Unit = {
+ cancelRequested.set(true)
+ // Per WorkerSession's lifecycle contract, cancel before init is a
+ // no-op: nothing has been sent on the transport yet, and sending
+ // Cancel as the first request would violate the proto's
+ // `Init -> ... -> Cancel|Finish` ordering.
+ if (!initSent.get()) return
+ if (streamFinished.compareAndSet(false, true)) {
+ streamLock.synchronized {
+ val request = UdfRequest.newBuilder()
+ .setControl(UdfControlRequest.newBuilder().setCancel(Cancel.getDefaultInstance))
+ .build()
+ // Best-effort: cancel can race with the worker tearing down the
+ // stream (e.g. crash, transport reset). Swallow non-fatal errors
+ // so a task-interruption thread is never disrupted by cleanup.
+ try {
+ requestObserver.onNext(request)
+ requestObserver.onCompleted()
+ } catch {
+ case NonFatal(_) => // stream already terminated; cancel is best-effort
+ }
+ }
+ }
+ }
+
+ /**
+ * Closes this session. If the stream was not properly finished (via
+ * process() or cancel()), sends a Cancel to notify the worker.
+ *
+ * If sending Cancel fails (e.g., the transport is already broken),
+ * falls back to [[StreamObserver.onError]] to ensure the server
+ * receives a termination signal and does not wait indefinitely.
+ */
+ override def close(): Unit = {
+ if (!closed.compareAndSet(false, true)) return
+ // Mirror cancel(): if init never went on the wire, the proto has not
+ // started, so closing is purely local cleanup.
+ if (!initSent.get()) return
+ if (!streamFinished.compareAndSet(false, true)) return
+ streamLock.synchronized {
+ var sentCancel = false
+ try {
+ val request = UdfRequest.newBuilder()
+ .setControl(UdfControlRequest.newBuilder()
+ .setCancel(Cancel.getDefaultInstance))
+ .build()
+ requestObserver.onNext(request)
+ sentCancel = true
+ requestObserver.onCompleted()
+ } catch {
+ case NonFatal(_) if !sentCancel =>
+ // onNext threw before onCompleted ran; no terminal call has
+ // succeeded yet, so onError is a valid terminator that lets
+ // the server stop waiting. Kept under streamLock so the
+ // (non-thread-safe) StreamObserver is only ever touched by
+ // one thread at a time.
+ try {
+ requestObserver.onError(
+ new RuntimeException("session closed before normal termination"))
+ } catch {
+ case NonFatal(_) => // stream already terminated
+ }
+ case NonFatal(_) =>
+ // onCompleted threw after onNext succeeded. The terminal call
+ // is ambiguous; per the gRPC StreamObserver contract we must
+ // not also call onError, which would be a second terminator.
+ }
+ }
+ }
+
+ private def sendDataBatch(data: Array[Byte]): Unit = {
+ val request = UdfRequest.newBuilder()
+ .setData(DataRequest.newBuilder().setData(ByteString.copyFrom(data)))
+ .build()
+ streamLock.synchronized {
+ if (!streamFinished.get()) {
+ requestObserver.onNext(request)
+ }
+ }
+ }
+
+ /**
+ * Sends Finish and completes the client stream. The `streamFinished`
+ * CAS ensures only one code path (normal completion, cancel, or close)
+ * wins and actually touches the observer.
+ */
+ private def sendFinishAndComplete(): Unit = {
+ if (streamFinished.compareAndSet(false, true)) {
+ streamLock.synchronized {
+ val finish = UdfRequest.newBuilder()
+ .setControl(UdfControlRequest.newBuilder()
+ .setFinish(Finish.getDefaultInstance))
+ .build()
+ requestObserver.onNext(finish)
+ requestObserver.onCompleted()
+ }
+ }
+ }
+
+ /**
+ * Terminates the stream with a client-side error. Used when the input
+ * iterator throws, so the worker can clean up rather than wait for
+ * data that will never arrive.
+ */
+ private def finishStreamWithError(cause: Throwable): Unit = {
+ if (streamFinished.compareAndSet(false, true)) {
+ streamLock.synchronized {
+ try {
+ requestObserver.onError(cause)
+ } catch {
+ case NonFatal(_) => // stream already terminated
+ }
+ }
+ }
+ }
+
+ /**
+ * Throws the worker's reported error, if any. Called before each send
+ * so we don't push data into a stream the worker has already failed.
+ */
+ private def throwIfErrored(): Unit = {
+ val err = workerError
+ if (err != null) throw err
+ }
+
+ /**
+ * Blocks until a control response of `expected` type arrives from the
+ * worker, or the timeout expires.
+ *
+ * A data response arriving before the control ack is a protocol
+ * violation (the worker must ack the matching control before any data).
+ * A control response of the wrong type (e.g. `FinishResponse` arriving
+ * where `InitResponse` is expected) is also a protocol violation; both
+ * are surfaced as errors rather than silently accepted.
+ *
+ * Uses `nanoTime` for a monotonic deadline so an NTP step cannot make
+ * the wait spuriously time out (or never time out).
+ */
+ private def awaitControlResponse(expected: UdfControlResponse.ControlCase): Unit = {
+ val timeoutNanos = TimeUnit.MILLISECONDS.toNanos(controlResponseTimeoutMs)
+ val deadlineNanos = System.nanoTime() + timeoutNanos
+ var remainingNanos = timeoutNanos
+ while (remainingNanos > 0L) {
+ val item = responseQueue.poll(remainingNanos, TimeUnit.NANOSECONDS)
+ if (item != null) {
+ item match {
+ case StreamResponse(response) if response.hasControl =>
+ val actual = response.getControl.getControlCase
+ if (actual == expected) return
+ throw new RuntimeException(
+ s"Worker sent control response of type $actual where " +
+ s"$expected was expected (protocol violation)")
+ case StreamResponse(_) =>
+ throw new RuntimeException(
+ "Worker sent a non-control response before the expected " +
+ "control response (protocol violation)")
+ case StreamError(t) => throw t
+ case StreamEnd =>
+ throw new RuntimeException(
+ "Worker stream completed without sending expected control response")
+ }
+ }
+ remainingNanos = deadlineNanos - System.nanoTime()
+ }
+ val channelState = connection.channel.getState(false)
+ throw new RuntimeException(
+ s"Timed out after ${controlResponseTimeoutMs}ms waiting for control " +
+ s"response (channel state: $channelState)")
+ }
+}
+
+private[grpc] object GrpcWorkerSession {
+ /** Default deadline for awaiting any single control response (Init/Finish/Cancel). */
+ val DefaultControlResponseTimeoutMs: Long = 30000L
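+
+ // Tests can shorten this per session via the constructor, e.g.
+ // `new GrpcWorkerSession(connection, 250L)` -- see the init-timeout
+ // test in GrpcWorkerSessionSuite.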
+}
diff --git a/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionFactory.scala b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionFactory.scala
new file mode 100644
index 0000000000000..04af0c8af2849
--- /dev/null
+++ b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionFactory.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import java.util.concurrent.TimeUnit
+
+import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder
+import io.grpc.netty.shaded.io.netty.channel.epoll.{EpollDomainSocketChannel, EpollEventLoopGroup}
+import io.grpc.netty.shaded.io.netty.channel.unix.DomainSocketAddress
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession,
+ WorkerSessionFactory}
+
+/**
+ * :: Experimental ::
+ * A [[WorkerSessionFactory]] that hands out gRPC connections (over UDS)
+ * and [[GrpcWorkerSession]] instances on top of them.
+ *
+ * Plug it into any direct dispatcher whose transport is UDS, e.g.:
+ * {{{
+ * val dispatcher = new DirectUnixSocketWorkerDispatcher(
+ * spec, new GrpcWorkerSessionFactory)
+ * }}}
+ *
+ * '''Platform limitation:''' uses Linux epoll for UDS transport (via
+ * Netty's [[EpollDomainSocketChannel]]). It is '''Linux-only''' and
+ * will not work on macOS (which would require kqueue) or Windows.
+ * The [[EpollEventLoopGroup]] is created in the constructor, so
+ * instantiating this factory on a non-Linux host fails immediately
+ * with a `LinkageError` / native-library load failure -- not lazily
+ * at session-creation time. TCP/IP transport is also not currently
+ * supported. A single [[EpollEventLoopGroup]] is shared across all
+ * connections produced by this factory to avoid excessive thread
+ * creation.
+ *
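+ * A caller that must degrade gracefully off-Linux can guard
+ * construction; a sketch using Netty's runtime probe
+ * (`Epoll.isAvailable`, shipped in the same shaded package as the
+ * channel types above):
+ * {{{
+ * import io.grpc.netty.shaded.io.netty.channel.epoll.Epoll
+ *
+ * val factory =
+ *   if (Epoll.isAvailable) new GrpcWorkerSessionFactory
+ *   else sys.error("gRPC-over-UDS workers require Linux epoll")
+ * }}}
+ *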
+ * TODO: Support TCP/IP transport and macOS (kqueue).
+ */
+@Experimental
+class GrpcWorkerSessionFactory extends WorkerSessionFactory {
+
+ private val SHUTDOWN_TIMEOUT_MS = 5000L
+
+ private val sharedEventLoopGroup = new EpollEventLoopGroup()
+
+ override def createConnection(address: String): GrpcWorkerConnection = {
+ val channel = NettyChannelBuilder
+ .forAddress(new DomainSocketAddress(address))
+ .eventLoopGroup(sharedEventLoopGroup)
+ .channelType(classOf[EpollDomainSocketChannel])
+ .usePlaintext()
+ .build()
+ new GrpcWorkerConnection(channel, address)
+ }
+
+ override def createSession(connection: WorkerConnection): WorkerSession = {
+ val grpc = connection match {
+ case g: GrpcWorkerConnection => g
+ case other => throw new IllegalArgumentException(
+ "GrpcWorkerSessionFactory requires a GrpcWorkerConnection, " +
+ s"got ${other.getClass.getName}")
+ }
+ new GrpcWorkerSession(grpc)
+ }
+
+ override def close(): Unit = {
+ sharedEventLoopGroup
+ .shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ .await(SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ }
+}
diff --git a/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/EchoUDFWorkerService.scala b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/EchoUDFWorkerService.scala
new file mode 100644
index 0000000000000..ad7fb3ab2ae76
--- /dev/null
+++ b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/EchoUDFWorkerService.scala
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import java.util.concurrent.{ConcurrentLinkedQueue, TimeUnit}
+
+import scala.jdk.CollectionConverters._
+
+import io.grpc.Server
+import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder
+import io.grpc.netty.shaded.io.netty.channel.EventLoopGroup
+import io.grpc.netty.shaded.io.netty.channel.epoll.{EpollEventLoopGroup,
+ EpollServerDomainSocketChannel}
+import io.grpc.netty.shaded.io.netty.channel.unix.DomainSocketAddress
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.udf.worker._
+
+/**
+ * A dummy UDF worker gRPC service that echoes data back to the client.
+ *
+ * {{{
+ * Caller (GrpcWorkerSession) gRPC I/O EchoUDFWorkerService
+ * | | |
+ * |--- onNext(Init) -------------|----> Init ------------>|
+ * | |<----- InitResponse ----|
+ * |--- onNext(Data) -------------|----> Data ------------>|
+ * | |<----- Data (echo) -----|
+ * | ... repeats per input batch ... |
+ * |--- onNext(Finish) -----------|----> Finish ---------->|
+ * | |<----- FinishResponse --|
+ * |--- onCompleted --------------|----> onCompleted ----->|
+ * | |<----- onCompleted -----|
+ * }}}
+ *
+ * Each inbound request produces exactly one outbound response synchronously
+ * on the gRPC I/O thread, in arrival order. There is no UDF logic and no
+ * buffering.
+ *
+ * Limits / non-features:
+ * - Permissive: accepts any wire order; does not enforce the proto's
+ * "Init -> Data* -> Finish|Cancel" sequence, and silently drops requests
+ * that match no branch (the proto says receivers MUST reject these).
+ * - On `onError`, the throwable is recorded for [[assertNoErrors()]] but
+ * is not propagated to the response stream; callers may need to forcibly
+ * shut the channel down for cleanup.
+ * - [[assertNoErrors()]] is the only assertion hook -- it does not verify
+ * that anything was received, only that nothing errored.
+ *
+ * Tests should call [[assertNoErrors()]] after exercising the stream to
+ * verify no gRPC errors were silently swallowed on the server side.
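+ *
+ * A typical wiring (mirroring the suites in this package; `socketPath`
+ * is test-chosen):
+ * {{{
+ * val handle = EchoUDFWorkerServer.start(socketPath)
+ * try {
+ *   // ... exercise a GrpcWorkerSession against socketPath ...
+ * } finally {
+ *   handle.service.assertNoErrors()
+ *   EchoUDFWorkerServer.stop(handle)
+ * }
+ * }}}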
+ */
+class EchoUDFWorkerService extends WorkerGrpc.WorkerImplBase {
+
+ private val errors = new ConcurrentLinkedQueue[Throwable]()
+
+ // Recording surface for tests. Captures every UdfRequest the service
+ // received, in arrival order, so tests can assert wire-order and field
+ // round-trip without changing the echo behaviour.
+ private val received = new ConcurrentLinkedQueue[UdfRequest]()
+
+ /**
+ * Asserts that no gRPC errors were received on the server side.
+ * Call this at the end of each test to catch errors that might
+ * otherwise go unnoticed (e.g., stream resets, serialization failures).
+ * Reports all captured errors so failures aren't masked when more than
+ * one stream fails in the same test.
+ */
+ // scalastyle:off throwerror
+ def assertNoErrors(): Unit = {
+ val snapshot = errors.iterator().asScala.toList
+ if (snapshot.nonEmpty) {
+ val summary = snapshot.map(_.getMessage).mkString("; ")
+ throw new AssertionError(
+ s"Echo service received ${snapshot.size} error(s): $summary", snapshot.head)
+ }
+ }
+ // scalastyle:on throwerror
+
+ /** Snapshot of every [[UdfRequest]] received, in arrival order. */
+ def receivedRequests: List[UdfRequest] = received.iterator().asScala.toList
+
+ /** Snapshot of the request-case sequence (Init / Data / Control{Finish,Cancel}). */
+ def receivedSequence: List[UdfRequest.RequestCase] =
+ received.iterator().asScala.map(_.getRequestCase).toList
+
+ override def execute(
+ responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = {
+ new StreamObserver[UdfRequest] {
+ override def onNext(request: UdfRequest): Unit = {
+ received.add(request)
+ if (request.hasControl) {
+ val ctrl = request.getControl
+ if (ctrl.hasInit) {
+ responseObserver.onNext(
+ UdfResponse.newBuilder()
+ .setControl(UdfControlResponse.newBuilder()
+ .setInit(InitResponse.getDefaultInstance))
+ .build())
+ } else if (ctrl.hasFinish) {
+ responseObserver.onNext(
+ UdfResponse.newBuilder()
+ .setControl(UdfControlResponse.newBuilder()
+ .setFinish(FinishResponse.getDefaultInstance))
+ .build())
+ } else if (ctrl.hasCancel) {
+ responseObserver.onNext(
+ UdfResponse.newBuilder()
+ .setControl(UdfControlResponse.newBuilder()
+ .setCancel(CancelResponse.getDefaultInstance))
+ .build())
+ }
+ } else if (request.hasData) {
+ val dataReq = request.getData
+ responseObserver.onNext(
+ UdfResponse.newBuilder()
+ .setData(DataResponse.newBuilder().setData(dataReq.getData))
+ .build())
+ }
+ }
+
+ override def onError(t: Throwable): Unit = {
+ errors.add(t)
+ }
+
+ override def onCompleted(): Unit = {
+ responseObserver.onCompleted()
+ }
+ }
+ }
+}
+
+/**
+ * A test fixture that swallows everything: never sends a response, never
+ * calls onCompleted. Used to exercise the [[GrpcWorkerSession]] init
+ * timeout path without waiting on real network failure modes.
+ */
+class IgnoringInitUDFWorkerService extends WorkerGrpc.WorkerImplBase {
+ override def execute(
+ responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = {
+ new StreamObserver[UdfRequest] {
+ override def onNext(request: UdfRequest): Unit = ()
+ override def onError(t: Throwable): Unit = ()
+ override def onCompleted(): Unit = ()
+ }
+ }
+}
+
+/**
+ * A test fixture that acks Init but emits an `onError` on the first
+ * [[DataRequest]] it sees. Used to exercise the worker-error-mid-stream
+ * path that surfaces via [[GrpcWorkerSession]]'s `workerError` and the
+ * `StreamError` queue event.
+ */
+class FailingMidStreamUDFWorkerService(failureMessage: String)
+ extends WorkerGrpc.WorkerImplBase {
+ override def execute(
+ responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = {
+ new StreamObserver[UdfRequest] {
+ override def onNext(request: UdfRequest): Unit = {
+ if (request.hasControl && request.getControl.hasInit) {
+ responseObserver.onNext(
+ UdfResponse.newBuilder()
+ .setControl(UdfControlResponse.newBuilder()
+ .setInit(InitResponse.getDefaultInstance))
+ .build())
+ } else if (request.hasData) {
+ responseObserver.onError(new RuntimeException(failureMessage))
+ }
+ }
+ override def onError(t: Throwable): Unit = ()
+ override def onCompleted(): Unit = ()
+ }
+ }
+}
+
+/**
+ * Result of starting an [[EchoUDFWorkerService]], providing access to
+ * both the gRPC server and the service instance for test assertions.
+ */
+case class EchoServerHandle(
+ server: Server,
+ service: EchoUDFWorkerService,
+ bossGroup: EventLoopGroup,
+ workerGroup: EventLoopGroup)
+
+/**
+ * Generic gRPC server handle used by misbehaving-service fixtures
+ * ([[IgnoringInitUDFWorkerService]], [[FailingMidStreamUDFWorkerService]]).
+ * No service-specific accessors -- tests interact via [[GrpcWorkerSession]]
+ * behaviour rather than the server.
+ */
+case class GrpcServerHandle(
+ server: Server,
+ bossGroup: EventLoopGroup,
+ workerGroup: EventLoopGroup)
+
+/**
+ * Helper to start/stop an [[EchoUDFWorkerService]] on a Unix domain socket.
+ */
+object EchoUDFWorkerServer {
+
+ private val SHUTDOWN_TIMEOUT_MS = 5000L
+
+ def start(socketPath: String): EchoServerHandle = {
+ val service = new EchoUDFWorkerService
+ val bossGroup = new EpollEventLoopGroup()
+ val workerGroup = new EpollEventLoopGroup()
+ val server = NettyServerBuilder
+ .forAddress(new DomainSocketAddress(socketPath))
+ .channelType(classOf[EpollServerDomainSocketChannel])
+ .bossEventLoopGroup(bossGroup)
+ .workerEventLoopGroup(workerGroup)
+ .addService(service)
+ .build()
+ .start()
+ EchoServerHandle(server, service, bossGroup, workerGroup)
+ }
+
+ /** Start any [[WorkerGrpc.WorkerImplBase]] on a UDS path. */
+ def startWith(
+ service: WorkerGrpc.WorkerImplBase,
+ socketPath: String): GrpcServerHandle = {
+ val bossGroup = new EpollEventLoopGroup()
+ val workerGroup = new EpollEventLoopGroup()
+ val server = NettyServerBuilder
+ .forAddress(new DomainSocketAddress(socketPath))
+ .channelType(classOf[EpollServerDomainSocketChannel])
+ .bossEventLoopGroup(bossGroup)
+ .workerEventLoopGroup(workerGroup)
+ .addService(service)
+ .build()
+ .start()
+ GrpcServerHandle(server, bossGroup, workerGroup)
+ }
+
+ def stop(handle: GrpcServerHandle): Unit = {
+ if (handle != null && handle.server != null) {
+ handle.server.shutdownNow()
+ handle.server.awaitTermination(5, TimeUnit.SECONDS)
+ handle.bossGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS,
+ TimeUnit.MILLISECONDS)
+ handle.workerGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS,
+ TimeUnit.MILLISECONDS)
+ }
+ }
+
+ def stop(handle: EchoServerHandle): Unit = {
+ if (handle != null && handle.server != null) {
+ handle.server.shutdownNow()
+ handle.server.awaitTermination(5, TimeUnit.SECONDS)
+ handle.bossGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS,
+ TimeUnit.MILLISECONDS)
+ handle.workerGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS,
+ TimeUnit.MILLISECONDS)
+ }
+ }
+}
diff --git a/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcDirectWorkerDispatcherSuite.scala b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcDirectWorkerDispatcherSuite.scala
new file mode 100644
index 0000000000000..c93942d469200
--- /dev/null
+++ b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcDirectWorkerDispatcherSuite.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import com.google.protobuf.ByteString
+
+// scalastyle:off funsuite
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.udf.worker.{
+ DirectWorker, Init, ProcessCallable, UdfPayload, UdfWorkerDataFormat,
+ UDFWorkerProperties, UDFWorkerSpecification, UnixDomainSocket,
+ WorkerConnectionSpec => WorkerConnectionProto}
+import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
+ DirectWorkerSession}
+
+/**
+ * A [[GrpcWorkerSessionFactory]] that, on each [[createConnection]],
+ * also spins up an in-process [[EchoUDFWorkerService]] bound to the
+ * same UDS path. This lets us test the full
+ * dispatcher -> session -> gRPC data round-trip without requiring an
+ * external worker binary.
+ *
+ * '''Design note:''' The bash worker script (`touch` + `sleep` loop)
+ * serves only as a lifecycle placeholder to satisfy the
+ * process-spawning logic in [[DirectUnixSocketWorkerDispatcher]]. It does NOT
+ * serve gRPC itself; the in-process echo server bound here does.
+ */
+class TestEchoBackedFactory extends GrpcWorkerSessionFactory {
+
+ var lastHandle: EchoServerHandle = _
+
+ override def createConnection(address: String): GrpcWorkerConnection = {
+ new java.io.File(address).delete()
+ lastHandle = EchoUDFWorkerServer.start(address)
+ super.createConnection(address)
+ }
+}
+
+/**
+ * Test-only [[DirectUnixSocketWorkerDispatcher]] that places its socket
+ * directory under `/tmp` rather than `java.io.tmpdir`. Linux caps UDS
+ * paths at 108 chars, and Spark's parent POM forces
+ * `java.io.tmpdir=target/tmp`, which on a normal checkout already pushes
+ * the path past the cap. `Files.createTempDirectory` reads
+ * `java.io.tmpdir` from a `static final` cache set at JVM start, so a
+ * test-time `System.setProperty` cannot recover; we route around it by
+ * overriding the protected hook the dispatcher exposes for testing.
+ */
+class TestDirectUnixSocketWorkerDispatcher(
+ spec: org.apache.spark.udf.worker.UDFWorkerSpecification,
+ factory: GrpcWorkerSessionFactory)
+ extends DirectUnixSocketWorkerDispatcher(spec, factory) {
+
+ override protected def createPrivateTempDirectory(): java.nio.file.Path = {
+ val attr = java.nio.file.attribute.PosixFilePermissions.asFileAttribute(
+ java.nio.file.attribute.PosixFilePermissions.fromString("rwx------"))
+ java.nio.file.Files.createTempDirectory(
+ java.nio.file.Paths.get("/tmp"), "g", attr)
+ }
+}
+
+/**
+ * End-to-end tests: UDFWorkerSpecification -> DirectUnixSocketWorkerDispatcher
+ * (with [[GrpcWorkerSessionFactory]]) -> [[GrpcWorkerSession]] -> data
+ * round-trip over UDS.
+ */
+class GrpcDirectWorkerDispatcherSuite
+ extends AnyFunSuite with BeforeAndAfterEach {
+// scalastyle:on funsuite
+
+ private val workerScript =
+ """
+ |#!/bin/bash
+ |SOCKET_PATH=""
+ |while [[ $# -gt 0 ]]; do
+ | case "$1" in
+ | --connection) SOCKET_PATH="$2"; shift 2 ;;
+ | *) shift ;;
+ | esac
+ |done
+ |cleanup() { rm -f "$SOCKET_PATH"; exit 0; }
+ |trap cleanup SIGTERM
+ |touch "$SOCKET_PATH"
+ |while true; do sleep 1; done
+ """.stripMargin.trim
+
+ private var dispatcher: DirectUnixSocketWorkerDispatcher = _
+ private var factory: TestEchoBackedFactory = _
+
+ override def afterEach(): Unit = {
+ if (dispatcher != null) {
+ // Assert no silent gRPC errors before tearing down.
+ if (factory != null && factory.lastHandle != null) {
+ factory.lastHandle.service.assertNoErrors()
+ EchoUDFWorkerServer.stop(factory.lastHandle)
+ }
+ dispatcher.close()
+ dispatcher = null
+ factory = null
+ }
+ super.afterEach()
+ }
+
+ private def buildSpec(): UDFWorkerSpecification = {
+ val runner = ProcessCallable.newBuilder()
+ .addCommand("bash").addCommand("-c").addCommand(workerScript).addCommand("--")
+ .build()
+ val properties = UDFWorkerProperties.newBuilder()
+ .setConnection(WorkerConnectionProto.newBuilder()
+ .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance)
+ .build())
+ .build()
+ UDFWorkerSpecification.newBuilder()
+ .setDirect(DirectWorker.newBuilder()
+ .setRunner(runner)
+ .setProperties(properties)
+ .build())
+ .build()
+ }
+
+ /** Minimal Init proto sufficient for the echo server. */
+ private def buildInit(
+ payload: Array[Byte] = Array.emptyByteArray,
+ sessionConf: Map[String, String] = Map.empty): Init = {
+ val udf = UdfPayload.newBuilder()
+ .setPayload(ByteString.copyFrom(payload))
+ .setFormat("test")
+ .build()
+ val builder = Init.newBuilder()
+ .setUdf(udf)
+ .setDataFormat(UdfWorkerDataFormat.ARROW)
+ sessionConf.foreach { case (k, v) => builder.putSessionConf(k, v) }
+ builder.build()
+ }
+
+ private def setup(): DirectUnixSocketWorkerDispatcher = {
+ factory = new TestEchoBackedFactory
+ dispatcher = new TestDirectUnixSocketWorkerDispatcher(buildSpec(), factory)
+ dispatcher
+ }
+
+ test("end-to-end: spec -> dispatcher -> session -> process over UDS") {
+ val disp = setup()
+
+ val session = disp.createSession(None)
+ try {
+ session.init(buildInit(
+ payload = "my-udf".getBytes,
+ sessionConf = Map("mode" -> "test")))
+
+ val input = Iterator("batch-1".getBytes, "batch-2".getBytes)
+ val output = session.process(input).toList
+
+ assert(output.size == 2)
+ assert(new String(output(0)) == "batch-1")
+ assert(new String(output(1)) == "batch-2")
+ } finally {
+ session.close()
+ }
+ }
+
+ test("end-to-end: dispatcher close terminates worker") {
+ val disp = setup()
+
+ val session = disp.createSession(None)
+ session.init(buildInit())
+
+ val worker = session.asInstanceOf[DirectWorkerSession].workerProcess
+ assert(worker.process.isAlive)
+
+ val output = session.process(Iterator("data".getBytes)).toList
+ assert(output.size == 1)
+ session.close()
+
+ factory.lastHandle.service.assertNoErrors()
+ EchoUDFWorkerServer.stop(factory.lastHandle)
+ disp.close()
+ dispatcher = null
+ factory = null
+
+ assert(!worker.process.isAlive,
+ "worker should be terminated after dispatcher close")
+ }
+}
diff --git a/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionSuite.scala b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionSuite.scala
new file mode 100644
index 0000000000000..d7a2dc86c628e
--- /dev/null
+++ b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionSuite.scala
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import java.nio.file.{Files, Paths}
+import java.util.UUID
+import java.util.concurrent.{CountDownLatch, FutureTask, TimeUnit}
+
+import com.google.protobuf.ByteString
+import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder
+import io.grpc.netty.shaded.io.netty.channel.epoll.{EpollDomainSocketChannel,
+ EpollEventLoopGroup}
+import io.grpc.netty.shaded.io.netty.channel.unix.DomainSocketAddress
+
+// scalastyle:off funsuite
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.udf.worker.{Init, UdfPayload, UdfWorkerDataFormat}
+
+/**
+ * Unit tests for [[GrpcWorkerSession]] using an [[EchoUDFWorkerService]]
+ * bound to a Unix domain socket.
+ */
+class GrpcWorkerSessionSuite
+ extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
+// scalastyle:on funsuite
+
+ private var handle: EchoServerHandle = _
+ private var grpcHandle: GrpcServerHandle = _
+ private var connection: GrpcWorkerConnection = _
+ private var socketPath: String = _
+
+ private val testEventLoopGroup = new EpollEventLoopGroup()
+
+ private def generateSocketPath(): String = {
+ // Linux UDS paths are capped at 108 chars. Spark's parent POM forces
+ // java.io.tmpdir to `target/tmp`, which on a normal checkout already
+ // pushes the path past the cap. Use the OS tmpdir directly, plus a
+ // short prefix and a truncated UUID, to stay well under the limit.
+ val tmpDir = Files.createTempDirectory(Paths.get("/tmp"), "g")
+ tmpDir.resolve(s"e-${UUID.randomUUID().toString.take(8)}.sock").toString
+ }
+
+ /** Minimal Init proto sufficient for the echo server. */
+ private def buildInit(
+ payload: Array[Byte] = Array.emptyByteArray,
+ inputSchema: Array[Byte] = Array.emptyByteArray,
+ outputSchema: Array[Byte] = Array.emptyByteArray,
+ sessionConf: Map[String, String] = Map.empty): Init = {
+ val udf = UdfPayload.newBuilder()
+ .setPayload(ByteString.copyFrom(payload))
+ .setFormat("test")
+ .build()
+ val builder = Init.newBuilder()
+ .setUdf(udf)
+ .setDataFormat(UdfWorkerDataFormat.ARROW)
+ if (inputSchema.nonEmpty) builder.setInputSchema(ByteString.copyFrom(inputSchema))
+ if (outputSchema.nonEmpty) builder.setOutputSchema(ByteString.copyFrom(outputSchema))
+ sessionConf.foreach { case (k, v) => builder.putSessionConf(k, v) }
+ builder.build()
+ }
+
+ private def createSessionOnEchoServer(): GrpcWorkerSession = {
+ socketPath = generateSocketPath()
+ handle = EchoUDFWorkerServer.start(socketPath)
+ connection = newConnection(socketPath)
+ new GrpcWorkerSession(connection)
+ }
+
+ /**
+ * Bind a custom [[org.apache.spark.udf.worker.WorkerGrpc.WorkerImplBase]]
+ * to a fresh UDS path and return a session against it. Used by the
+ * misbehaving-fixture tests (init timeout, worker error mid-stream).
+ */
+ private def createSessionWithService(
+ service: org.apache.spark.udf.worker.WorkerGrpc.WorkerImplBase,
+ controlResponseTimeoutMs: Long =
+ GrpcWorkerSession.DefaultControlResponseTimeoutMs): GrpcWorkerSession = {
+ socketPath = generateSocketPath()
+ grpcHandle = EchoUDFWorkerServer.startWith(service, socketPath)
+ connection = newConnection(socketPath)
+ new GrpcWorkerSession(connection, controlResponseTimeoutMs)
+ }
+
+ private def newConnection(endpoint: String): GrpcWorkerConnection = {
+ val channel = NettyChannelBuilder
+ .forAddress(new DomainSocketAddress(endpoint))
+ .eventLoopGroup(testEventLoopGroup)
+ .channelType(classOf[EpollDomainSocketChannel])
+ .usePlaintext()
+ .build()
+ new GrpcWorkerConnection(channel, endpoint)
+ }
+
+ override def afterEach(): Unit = {
+ // Assert no silent gRPC errors before tearing down.
+ if (handle != null) {
+ handle.service.assertNoErrors()
+ EchoUDFWorkerServer.stop(handle)
+ handle = null
+ }
+ if (grpcHandle != null) {
+ EchoUDFWorkerServer.stop(grpcHandle)
+ grpcHandle = null
+ }
+ if (connection != null) {
+ connection.close()
+ connection = null
+ }
+ super.afterEach()
+ }
+
+ override def afterAll(): Unit = {
+ testEventLoopGroup.shutdownGracefully(0, 5000, TimeUnit.MILLISECONDS)
+ super.afterAll()
+ }
+
+ test("init sends Init and receives InitResponse") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit(
+ payload = "test-function".getBytes,
+ inputSchema = "input-schema".getBytes,
+ outputSchema = "output-schema".getBytes,
+ sessionConf = Map("key" -> "value")))
+ } finally {
+ session.close()
+ }
+ }
+
+ test("process echoes single batch") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+
+ val input = Iterator("hello-world".getBytes)
+ val output = session.process(input).toList
+
+ assert(output.size == 1)
+ assert(new String(output.head) == "hello-world")
+ } finally {
+ session.close()
+ }
+ }
+
+ test("process echoes multiple batches") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+
+ val batches = (1 to 5).map(i => s"batch-$i".getBytes)
+ val output = session.process(batches.iterator).toList
+
+ assert(output.size == 5)
+ output.zip(batches).foreach { case (result, expected) =>
+ assert(result sameElements expected)
+ }
+ } finally {
+ session.close()
+ }
+ }
+
+ test("process with empty input returns empty output") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+
+ val output = session.process(Iterator.empty).toList
+ assert(output.isEmpty)
+ } finally {
+ session.close()
+ }
+ }
+
+ test("cancel aborts the session") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+ session.cancel()
+ } finally {
+ session.close()
+ }
+ }
+
+ // -- Error-path tests -------------------------------------------------------
+
+ test("init called twice throws") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+ val ex = intercept[IllegalStateException] {
+ session.init(buildInit())
+ }
+ assert(ex.getMessage.contains("init has already been called"))
+ } finally {
+ session.close()
+ }
+ }
+
+ test("process called twice throws") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+ session.process(Iterator.empty).toList
+ val ex = intercept[IllegalStateException] {
+ session.process(Iterator.empty)
+ }
+ assert(ex.getMessage.contains("process has already been called"))
+ } finally {
+ session.close()
+ }
+ }
+
+ test("cancel after init does not throw") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+ session.cancel()
+ // close() after cancel should be safe
+ } finally {
+ session.close()
+ }
+ }
+
+ test("cancel before init is a no-op") {
+ // Per WorkerSession's lifecycle contract, cancel before init must not
+ // touch the wire (sending Cancel as the first request would violate the
+ // proto's `Init -> ... -> Cancel|Finish` ordering). The session must
+ // still close cleanly afterwards.
+ val session = createSessionOnEchoServer()
+ try {
+ session.cancel()
+ // No init was sent; the echo server should not have observed any
+ // wire traffic, including Cancel.
+ handle.service.assertNoErrors()
+ } finally {
+ session.close()
+ }
+ }
+
+ test("process after cancel throws CancellationException") {
+ // Per WorkerSession's lifecycle contract, a cancel between init and
+ // process must surface as an observable signal -- not a silently
+ // empty iterator that's indistinguishable from a generator UDF that
+ // produced nothing.
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+ session.cancel()
+ intercept[java.util.concurrent.CancellationException] {
+ session.process(Iterator.empty)
+ }
+ } finally {
+ session.close()
+ }
+ }
+
+ test("cancel during process from another thread does not hang") {
+ // Threading invariant: cancel() called from a task-interruption thread
+ // while process() is running on the caller thread must let process
+ // exit (cleanly, with empty/partial output, or by throwing). The
+ // contract permits either FinishResponse or CancelResponse; the only
+ // requirement is that the iterator does not block forever.
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+
+ // hasNext blocks on `release` after the first batch so cancel
+ // arrives while process() is in fillNextBatch. Bounded await so
+ // a regression cannot hang the test runner.
+ val inProcess = new CountDownLatch(1)
+ val release = new CountDownLatch(1)
+ val input = new Iterator[Array[Byte]] {
+ private var idx = 0
+ override def hasNext: Boolean = {
+ if (idx == 1) {
+ inProcess.countDown()
+ release.await(2, TimeUnit.SECONDS)
+ }
+ idx < 5
+ }
+ override def next(): Array[Byte] = {
+ val v = s"batch-$idx".getBytes
+ idx += 1
+ v
+ }
+ }
+
+ val task = new FutureTask[List[Array[Byte]]](
+ () => session.process(input).toList)
+ val thread = new Thread(task, "process-thread")
+ thread.setDaemon(true)
+ thread.start()
+
+ assert(inProcess.await(2, TimeUnit.SECONDS),
+ "process() did not reach the input iterator in time")
+ session.cancel()
+ release.countDown()
+
+ // The future should resolve (return or throw) within the deadline.
+ // Either is acceptable per contract; the assertion here is that we
+ // did not hang.
+ try {
+ task.get(2, TimeUnit.SECONDS)
+ } catch {
+ case _: java.util.concurrent.ExecutionException => // also acceptable
+ case _: java.util.concurrent.TimeoutException =>
+ fail("process() did not return within 2s after cancel")
+ }
+ assert(!thread.isAlive, "process thread should have exited after cancel")
+ } finally {
+ session.close()
+ }
+ }
+
+ test("worker error mid-stream surfaces from process iterator") {
+ // gRPC translates a server-side `responseObserver.onError(...)` into a
+ // StatusRuntimeException on the client; the exact message wording is
+ // implementation-defined. We only assert the failure surfaces from the
+ // iterator (not silently dropped, not hung).
+ val session = createSessionWithService(
+ new FailingMidStreamUDFWorkerService("worker-failed-mid-stream"))
+ try {
+ session.init(buildInit())
+ intercept[Throwable] {
+ session.process(Iterator("data".getBytes)).toList
+ }
+ } finally {
+ session.close()
+ }
+ }
+
+ test("init throws when worker never sends InitResponse") {
+ // Use a fixture that swallows Init and a small timeout so the test
+ // exercises the timeout path without waiting on real network failure.
+ val session = createSessionWithService(
+ new IgnoringInitUDFWorkerService,
+ controlResponseTimeoutMs = 250L)
+ try {
+ val ex = intercept[RuntimeException] {
+ session.init(buildInit())
+ }
+ assert(ex.getMessage.contains("Timed out"),
+ s"expected timeout error, got: ${ex.getMessage}")
+ assert(ex.getMessage.contains("channel state"),
+ "timeout error should report channel state for debuggability")
+ } finally {
+ session.close()
+ }
+ }
+
+ test("Init fields round-trip to the worker intact") {
+ val session = createSessionOnEchoServer()
+ try {
+ val payload = "test-callable".getBytes
+ val inputSchema = "test-input-schema".getBytes
+ val outputSchema = "test-output-schema".getBytes
+ session.init(buildInit(
+ payload = payload,
+ inputSchema = inputSchema,
+ outputSchema = outputSchema,
+ sessionConf = Map("k1" -> "v1", "k2" -> "v2")))
+ // Drain so Finish is on the wire and recorded.
+ session.process(Iterator.empty).toList
+
+ val received = handle.service.receivedRequests
+ val initMsg = received.headOption
+ .flatMap(r => if (r.hasControl && r.getControl.hasInit) Some(r.getControl.getInit)
+ else None)
+ .getOrElse(fail("expected first received request to be Init"))
+
+ assert(initMsg.getDataFormat == UdfWorkerDataFormat.ARROW)
+ assert(initMsg.getUdf.getPayload.toByteArray sameElements payload)
+ assert(initMsg.getUdf.getFormat == "test")
+ assert(initMsg.getInputSchema.toByteArray sameElements inputSchema)
+ assert(initMsg.getOutputSchema.toByteArray sameElements outputSchema)
+ assert(initMsg.getSessionConfMap.get("k1") == "v1")
+ assert(initMsg.getSessionConfMap.get("k2") == "v2")
+ } finally {
+ session.close()
+ }
+ }
+
+ test("request wire order is Init -> Data* -> Finish") {
+ val session = createSessionOnEchoServer()
+ try {
+ session.init(buildInit())
+ val batches = (1 to 3).map(i => s"b$i".getBytes)
+ session.process(batches.iterator).toList
+
+ val received = handle.service.receivedRequests
+ assert(received.nonEmpty, "expected at least Init + Finish")
+ // First must be Init.
+ assert(received.head.hasControl && received.head.getControl.hasInit,
+ s"first request was not Init: ${received.head}")
+ // Last must be Finish (not Cancel) on the happy path.
+ assert(received.last.hasControl && received.last.getControl.hasFinish,
+ s"last request was not Finish: ${received.last}")
+ // Middle entries are all Data, in input order.
+ val middle = received.drop(1).dropRight(1)
+ assert(middle.size == 3, s"expected 3 Data messages, got ${middle.size}")
+ middle.zip(batches).foreach { case (req, expected) =>
+ assert(req.hasData, s"middle request was not Data: $req")
+ assert(req.getData.getData.toByteArray sameElements expected)
+ }
+ } finally {
+ session.close()
+ }
+ }
+}
diff --git a/udf/worker/proto/pom.xml b/udf/worker/proto/pom.xml
index 6850ae6938a39..843d843719bd9 100644
--- a/udf/worker/proto/pom.xml
+++ b/udf/worker/proto/pom.xml
@@ -51,32 +51,28 @@
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-library</artifactId>
     </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-api</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-protobuf</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-stub</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
-      <plugin>
-        <groupId>com.github.os72</groupId>
-        <artifactId>protoc-jar-maven-plugin</artifactId>
-        <version>${protoc-jar-maven-plugin.version}</version>
-        <executions>
-          <execution>
-            <phase>generate-sources</phase>
-            <goals>
-              <goal>run</goal>
-            </goals>
-            <configuration>
-              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
-              <protocVersion>${protobuf.version}</protocVersion>
-              <inputDirectories>
-                <include>src/main/protobuf</include>
-              </inputDirectories>
-            </configuration>
-          </execution>
-        </executions>
-      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-shade-plugin</artifactId>
@@ -113,4 +109,86 @@
     </plugins>
   </build>
+
+  <profiles>
+    <profile>
+      <id>default-protoc</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+      </activation>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>eu.maveniverse.maven.plugins</groupId>
+            <artifactId>nisse-plugin3</artifactId>
+            <version>0.7.0</version>
+            <executions>
+              <execution>
+                <id>set-os-detector-properties</id>
+                <goals>
+                  <goal>inject-properties</goal>
+                </goals>
+                <phase>validate</phase>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <groupId>org.xolstice.maven.plugins</groupId>
+            <artifactId>protobuf-maven-plugin</artifactId>
+            <version>0.6.1</version>
+            <configuration>
+              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+              <pluginId>grpc-java</pluginId>
+              <pluginArtifact>io.grpc:protoc-gen-grpc-java:${io.grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
+              <protoSourceRoot>src/main/protobuf</protoSourceRoot>
+            </configuration>
+            <executions>
+              <execution>
+                <goals>
+                  <goal>compile</goal>
+                  <goal>compile-custom</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+    <profile>
+      <id>user-defined-protoc</id>
+      <properties>
+        <spark.protoc.executable.path>${env.SPARK_PROTOC_EXEC_PATH}</spark.protoc.executable.path>
+        <connect.plugin.executable.path>${env.CONNECT_PLUGIN_EXEC_PATH}</connect.plugin.executable.path>
+      </properties>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.xolstice.maven.plugins</groupId>
+            <artifactId>protobuf-maven-plugin</artifactId>
+            <version>0.6.1</version>
+            <configuration>
+              <protocExecutable>${spark.protoc.executable.path}</protocExecutable>
+              <pluginId>grpc-java</pluginId>
+              <pluginExecutable>${connect.plugin.executable.path}</pluginExecutable>
+              <protoSourceRoot>src/main/protobuf</protoSourceRoot>
+            </configuration>
+            <executions>
+              <execution>
+                <goals>
+                  <goal>compile</goal>
+                  <goal>compile-custom</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+  </profiles>
 </project>
diff --git a/udf/worker/proto/src/main/protobuf/common.proto b/udf/worker/proto/src/main/protobuf/common.proto
index ee032def73efe..c08becb082d4e 100644
--- a/udf/worker/proto/src/main/protobuf/common.proto
+++ b/udf/worker/proto/src/main/protobuf/common.proto
@@ -24,9 +24,9 @@ option java_package = "org.apache.spark.udf.worker";
option java_multiple_files = true;
// The UDF in & output data format.
-enum UDFWorkerDataFormat {
+enum UdfWorkerDataFormat {
UDF_WORKER_DATA_FORMAT_UNSPECIFIED = 0;
-
+
// The worker accepts and produces Apache arrow batches.
ARROW = 1;
}
@@ -39,10 +39,10 @@ enum UDFWorkerDataFormat {
// framing their phases as messages on the stream, but that is a design
// question worth revisiting as additional UDF types are added -- for
// example, aggregation may prefer a multi-round or specialized pattern.
-enum UDFProtoCommunicationPattern {
+enum UdfProtoCommunicationPattern {
UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED = 0;
- // Data exachanged as a bidrectional
+ // Data exchanged as a bidirectional
// stream of bytes.
BIDIRECTIONAL_STREAMING = 1;
}
diff --git a/udf/worker/proto/src/main/protobuf/udf_protocol.proto b/udf/worker/proto/src/main/protobuf/udf_protocol.proto
new file mode 100644
index 0000000000000..dcd7746333dcf
--- /dev/null
+++ b/udf/worker/proto/src/main/protobuf/udf_protocol.proto
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+import "common.proto";
+
+package org.apache.spark.udf.worker;
+
+option java_package = "org.apache.spark.udf.worker";
+option java_multiple_files = true;
+
+// =====================================================================
+// Language-agnostic UDF execution protocol.
+//
+// The Spark engine acts as the gRPC client; a UDF worker (in any
+// language) acts as the gRPC server.
+// =====================================================================
+
+// The default UDF gRPC service. A worker that exposes this service
+// MUST do so over the default connection of the worker specification.
+//
+// In the future, additional connections (e.g. a separate channel) may be
+// reserved by the worker spec for other purposes.
+service Worker {
+ // Per-execution stream. Exactly one [[Init]] is sent first, followed
+ // by 0..N data batches in either direction, terminated by exactly
+ // one [[Finish]] or [[Cancel]] from the engine. The worker MUST
+ // respond with the matching Init / Finish / Cancel responses on the
+ // response stream.
+ //
+ // For stateful execution, the state is maintained per bi-directional
+ // stream, mapping to a `WorkerSession` on the engine side
+ // (`org.apache.spark.udf.worker.core.WorkerSession`).
+ rpc Execute(stream UdfRequest) returns (stream UdfResponse);
+
+ // Worker-scoped management RPC, independent of any per-execution
+ // stream. Used for heartbeat, capability query, and graceful
+ // shutdown. Kept unary so it does not depend on the lifecycle of an
+ // active Execute stream.
+ rpc Manage(WorkerRequest) returns (WorkerResponse);
+}
+
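+// A happy-path Execute exchange (responses flow on the gRPC response
+// stream; DataResponses need not be 1:1 with DataRequests):
+//
+//   engine -> Init                worker -> InitResponse
+//   engine -> DataRequest ...     worker -> DataResponse ...
+//   engine -> Finish              worker -> FinishResponse, then
+//                                 closes the response stream
+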
+// =====================================================================
+// Execute stream -- envelope
+// =====================================================================
+
+// Engine -> Worker. Either a control message ([[Init]] / [[PayloadChunk]]
+// / [[Finish]] / [[Cancel]]) or a data message.
+message UdfRequest {
+ // Exactly one branch MUST be set; receivers MUST reject messages
+ // with no branch set.
+ oneof request {
+ UdfControlRequest control = 1;
+ DataRequest data = 2;
+ }
+}
+
+// Worker -> Engine. Either a control response ([[InitResponse]] /
+// [[FinishResponse]] / [[CancelResponse]]) or a data response message.
+message UdfResponse {
+ // Exactly one branch MUST be set; receivers MUST reject messages
+ // with no branch set.
+ oneof response {
+ UdfControlResponse control = 1;
+ DataResponse data = 2;
+ }
+}
+
+// Engine -> Worker control messages.
+//
+// Wire order on an Execute stream is exactly:
+// Init { ... }
+// PayloadChunk { ... }* // optional; 0..N chunks, only used when
+// // the single UDF payload on Init is too
+// // large to fit inline.
+// DataRequest { ... }* // 0..N input batches
+// Finish { ... } OR Cancel { ... } // exactly one terminator
+//
+// The worker MUST emit [[InitResponse]] before sending any
+// [[DataResponse]], and MUST emit exactly one [[FinishResponse]] or
+// [[CancelResponse]] before closing the response stream.
+//
+// A worker that receives messages out of this order (e.g. a second Init,
+// a PayloadChunk after the first DataRequest, a DataRequest before Init,
+// or a Cancel before Init) MUST close the stream with an error.
+message UdfControlRequest {
+ // Exactly one branch MUST be set; receivers MUST reject messages
+ // with no branch set.
+ oneof control {
+ Init init = 1;
+ PayloadChunk payload = 2;
+ Finish finish = 3;
+ Cancel cancel = 4;
+ }
+}
+
+// Worker -> Engine control messages.
+message UdfControlResponse {
+ // Exactly one branch MUST be set; receivers MUST reject messages
+ // with no branch set.
+ oneof control {
+ InitResponse init = 1;
+ FinishResponse finish = 2;
+ CancelResponse cancel = 3;
+ }
+}
+
+// =====================================================================
+// Init phase
+// =====================================================================
+
+// Sent once, as the first message on an Execute stream. Describes
+// the UDF body to run plus the minimum metadata the worker needs to
+// start processing it.
+//
+// Today the protocol mandates exactly one Init per UDF execution
+// (one Init -> data -> Finish). This is the simplest contract and
+// covers all currently supported UDF kinds. In the future we may
+// evolve to support multiple init phases on the same stream -- e.g.
+// when worker setup requires an interactive handshake (negotiate a
+// schema, exchange capabilities, fetch driver-side metadata, ...)
+// before the data plane opens. Such an extension would be additive
+// and would not change the single-Init semantics already in use.
+//
+// Engine vs. client split:
+// * Most fields on Init are engine-side. They describe what
+// flows on the wire for this session ([[data_format]] /
+// [[input_schema]] / [[output_schema]] -- matching the worker
+// spec, not the function's view) and what per-session
+// context the worker needs ([[timezone]], [[session_conf]],
+// [[task_context]], [[parameters]]).
+// * [[UdfPayload]] carries everything the client side of Spark
+// (where the UDF is defined and serialized) packs -- the
+// serialized callable, an opaque format tag, and any encoder
+// metadata bundled with the callable. The wire protocol does
+// not enumerate encoder shapes; that is left to the client and
+// worker to agree on per UDF type.
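+//
+// For illustration, a minimal Init in proto text format (values are
+// placeholders, not prescribed by the protocol):
+//
+//   data_format: ARROW
+//   udf { payload: "<serialized callable>" format: "<client format tag>" }
+//   session_conf { key: "mode" value: "test" }
+//   timezone: "America/Los_Angeles"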
+message Init {
+ // (Required) Wire format used for [[DataRequest.data]] and
+ // [[DataResponse.data]] for the life of this session. Must be
+ // one of the formats the worker declared in
+ // [[WorkerCapabilities.supported_data_formats]]; the client side
+ // of the protocol picks one at planning time and sticks with it.
+ // Receivers MUST reject an Init whose [[data_format]] is
+ // `UDF_WORKER_DATA_FORMAT_UNSPECIFIED`.
+ UdfWorkerDataFormat data_format = 1;
+
+ // (Required) The UDF body to execute on the worker for this
+ // session. Exactly one payload per Execute stream.
+ UdfPayload udf = 2;
+
+ // (Optional) Schema of the input data plane in the wire format
+ // declared by [[data_format]] -- e.g. an Arrow IPC schema when
+ // data_format = ARROW. This is an engine-side requirement: it
+ // describes the bytes the engine will actually put on
+ // [[DataRequest.data]] for this session, matching what the
+ // worker advertised in its spec. It is NOT necessarily the
+ // schema the function definer expressed; the UDF's own type
+ // information lives inside [[UdfPayload]], typically embedded
+ // alongside the callable in [[UdfPayload.payload]] (e.g. as
+ // input/output encoders chosen per UDF type).
+ //
+ // Left unset when the worker can derive the schema from the
+ // payload alone.
+ optional bytes input_schema = 3;
+
+ // (Optional) Schema of the output data plane in the wire format
+ // declared by [[data_format]]. Same semantics as
+ // [[input_schema]] -- engine-side requirement describing the
+ // bytes the engine expects on [[DataResponse.data]].
+ optional bytes output_schema = 4;
+
+ // (Optional; defaults to an empty map.) Per-task context
+ // provided by the engine. Common keys identify the task instance
+ // for diagnostics, logging, and stateful workers -- e.g.
+ // partition id, task attempt id, stage id, micro-batch id.
+ // Engine and worker agree on the keys they share; the protocol
+ // does not enumerate them.
+ map<string, string> task_context = 5;
+
+ // (Optional; defaults to an empty map.) Worker-private knobs not
+ // already captured by typed fields above. Free-form; both sides
+ // agree on the keys they need.
+ //
+ // Any key that two languages converge on is a candidate for
+ // promotion to a structured proto field -- once promoted, it gets
+ // a typed field number from the reserved range right after this
+ // block and is removed from [[session_conf]]. [[timezone]] below
+ // is an example of a key that has already been promoted.
+ map<string, string> session_conf = 6;
+
+ // (Optional) Session timezone, promoted out of [[session_conf]]
+ // because every eval needs it for timestamp encoding/decoding.
+ //
+ // Format follows Spark's `spark.sql.session.timeZone` config --
+ // typically an IANA TZ id (e.g. "America/Los_Angeles") or a
+ // fixed offset (e.g. "+08:00"). The engine MUST pass the value
+ // it would resolve from the session conf without further
+ // transformation, so the worker can interpret it the same way
+ // Spark does.
+ optional string timezone = 7;
+
+ // Reserved for future typed Init fields, in particular keys
+ // graduated from [[session_conf]] (see the [[timezone]] precedent
+ // above). Numbers >= 100 are intentionally NOT reserved here; if
+ // a future revision needs an opaque escape-hatch field, give it a
+ // number >= 100 alongside [[parameters]] and add a field-level
+ // comment so the convention stays visible.
+ reserved 8 to 99;
+
+ // (Optional) Engine-packed opaque parameters specific to a
+ // particular kind of UDF execution. The escape hatch for
+ // anything the engine needs the worker to see at init time
+ // that is not already captured by the typed fields above and
+ // does not fit naturally into [[task_context]]. The encoding
+ // is agreed between the engine and the worker; the protocol
+ // does not interpret it. The matching response, also opaque
+ // bytes, is returned via [[InitResponse.data]].
+ //
+ // Numbers >= 100 are reserved by convention for opaque
+ // escape-hatch fields like this one; new typed fields use the
+ // reserved 8..99 range.
+ //
+ // Client-side init data (anything packed by the layer that
+ // defines and serializes the UDF) does NOT belong here -- it
+ // travels inside [[UdfPayload.payload]] instead.
+ optional bytes parameters = 100;
+}
+
+// Acknowledgment for [[Init]]. The worker MUST send exactly one
+// [[InitResponse]] before any [[DataResponse]].
+//
+// The init phase allows the engine to interact with the UDF before
+// data starts flowing -- the worker can return inline bytes here for
+// the engine (or higher-level code on the engine side) to consume
+// during setup. The semantics of those bytes are agreed between the
+// client side of the protocol and the worker; this message itself is
+// otherwise opaque.
+message InitResponse {
+ // (Optional) Inline init result returned by the worker. Opaque
+ // to the protocol; the client side of the protocol and the
+ // worker agree on what (if anything) it carries.
+ optional bytes data = 1;
+}
+
+// Optional. Used to stream the single UDF payload when it does not
+// fit in a single gRPC message. The default is to send the payload
+// inline on [[UdfPayload.payload]]; chunking is only needed when a
+// payload exceeds the gRPC message size limit.
+//
+// When used, chunks are sent zero or more times after [[Init]] and
+// before the first [[DataRequest]]. The worker concatenates the
+// inline [[UdfPayload.payload]] (if any) followed by all chunks in
+// arrival order to form the final payload.
+//
+// Chunks are part of the Init handshake, not standalone control
+// messages: they extend [[Init.udf.payload]] and are not
+// individually acknowledged. The single [[InitResponse]] covers
+// Init plus all of its chunks together.
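+//
+// For illustration, a payload too large to inline might travel as:
+//
+//   Init         { udf { payload: <first slice> ... } }
+//   PayloadChunk { data: <middle slice> }
+//   PayloadChunk { data: <final slice> last: true }
+//   DataRequest  { ... }  // first batch; also marks the chunk phase done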
+message PayloadChunk {
+ // (Required, non-empty.) Bytes appended to the [[Init.udf]]
+ // payload.
+ bytes data = 1;
+
+ // (Optional) Set to true on the final chunk. Receivers MAY use
+ // this as an early signal that the payload is complete and
+ // decoding can begin; receivers that prefer to wait for the
+ // first [[DataRequest]] (which marks the end of the chunking
+ // phase) MAY ignore this. When unset, the receiver determines
+ // completeness by the arrival of the first [[DataRequest]].
+ optional bool last = 2;
+}
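+
+// Illustrative (non-normative) wire sequence for a chunked
+// payload; the sizes are hypothetical:
+//
+//   engine -> worker: Init           udf.payload empty,
+//                                    udf.payload_size = 6 MiB
+//   engine -> worker: PayloadChunk   data = first 4 MiB
+//   engine -> worker: PayloadChunk   data = last 2 MiB, last = true
+//   worker -> engine: InitResponse   covers Init plus both chunks
+//   engine -> worker: DataRequest    first batch; ends the
+//                                    chunking phase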
+
+// =====================================================================
+// Data phase
+//
+// `data` is intentionally a top-level `bytes` field on both request
+// and response messages -- not nested inside a wrapper -- so that
+// implementations can avoid an extra copy when reading or writing
+// the payload. The wire format (Arrow IPC etc.) is declared once per
+// session via [[Init.data_format]] and stays the same for the life
+// of the stream.
+// =====================================================================
+
+// Engine -> Worker per-batch payload.
+message DataRequest {
+ // (Required, non-empty.) Encoded data bytes for one batch in the
+ // session-declared format.
+ bytes data = 1;
+}
+
+// Worker -> Engine per-batch payload. The worker emits zero or more
+// [[DataResponse]]s between [[InitResponse]] and [[FinishResponse]] /
+// [[CancelResponse]]. Sink-style UDFs (which consume input but
+// produce no output rows on the data plane) emit exactly zero.
+message DataResponse {
+ // (Required, non-empty.) Encoded data bytes for one batch in the
+ // session-declared format.
+ bytes data = 1;
+}
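+
+// Illustrative (non-normative) data-phase interleaving. Requests
+// and responses need not pair up 1:1 -- a worker may buffer
+// several input batches before emitting any output:
+//
+//   engine -> worker: DataRequest    batch 1
+//   engine -> worker: DataRequest    batch 2
+//   worker -> engine: DataResponse   output for batches 1-2
+//   engine -> worker: Finish
+//   worker -> engine: DataResponse   remaining output, drained
+//   worker -> engine: FinishResponse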
+
+// =====================================================================
+// Finish / Cancel phase
+// =====================================================================
+
+// Sent by the engine when no more input batches will arrive. The
+// worker MUST drain any remaining output, then emit
+// [[FinishResponse]] and close the response stream.
+//
+// Exactly one of [[Finish]] or [[Cancel]] is sent per Execute stream;
+// they are mutually exclusive. If the engine has already sent
+// [[Finish]] it MUST NOT send [[Cancel]] afterwards (and vice versa).
+message Finish {}
+
+// Worker -> Engine completion message. May carry summary metrics.
+message FinishResponse {
+ // Final metrics aggregated over the whole session (e.g. rows
+ // in/out, time per phase). Free-form; names are worker-defined.
+ map<string, string> metrics = 1;
+
+ // (Optional) Inline finish result returned by the worker.
+ // Mirrors [[InitResponse.data]] -- the finish phase allows the
+ // engine to interact with the UDF after data has stopped
+ // flowing, with the worker returning opaque bytes the engine (or
+ // higher-level code) may consume during teardown. The semantics
+ // of those bytes are agreed between the client side of the
+ // protocol and the worker.
+ optional bytes data = 2;
+}
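+
+// Illustrative (non-normative) FinishResponse in proto text
+// format; the metric names are hypothetical (worker-defined):
+//
+//   metrics { key: "rows_in"   value: "100000" }
+//   metrics { key: "rows_out"  value: "99873" }
+//   metrics { key: "eval_ms"   value: "1843" }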
+
+// Engine -> Worker explicit cancel. Distinct from a gRPC stream error
+// so the worker can run cleanup deterministically (release file
+// handles, drop temp state, etc.). After receiving [[Cancel]] the
+// worker MUST stop emitting [[DataResponse]] messages, run cleanup,
+// and emit [[CancelResponse]] before closing.
+//
+// Exactly one of [[Finish]] or [[Cancel]] is sent per Execute stream;
+// see [[Finish]]. [[Cancel]] is the cooperative cancellation path;
+// gRPC-level stream errors are the involuntary fallback. If the
+// stream breaks before [[CancelResponse]] arrives, the engine
+// considers the worker uncancellable for this session and relies on
+// process-level cleanup.
+message Cancel {
+ // (Optional) Free-form reason for diagnostics.
+ optional string reason = 1;
+}
+
+// Worker -> Engine acknowledgment of [[Cancel]].
+message CancelResponse {}
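+
+// Illustrative (non-normative) cooperative cancel; the reason
+// string is hypothetical:
+//
+//   engine -> worker: DataRequest      batch 1
+//   engine -> worker: Cancel           reason = "query aborted"
+//     (worker stops emitting DataResponse and runs cleanup)
+//   worker -> engine: CancelResponse
+//     (worker closes the response stream)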
+
+// The single UDF body delivered to the worker on [[Init]]. Opaque to
+// the engine: the engine forwards [[payload]] and [[format]]
+// unchanged, and the worker decodes them per the format the client
+// and worker have agreed on.
+message UdfPayload {
+ // (Required, may be empty when chunked.) Serialized UDF bundle,
+ // opaque to the engine. The encoding is declared in [[format]].
+ //
+ // The bundle is not necessarily just the serialized callable;
+ // it is up to the client side of the protocol and the worker to
+ // agree on what is packed inside it -- e.g. custom encoders for
+ // user-defined types, type hints, or any other metadata the
+ // worker needs to invoke the UDF.
+ //
+ // For payloads too large to fit in a single gRPC message, this
+ // field MAY be left empty (zero-length bytes) and the bytes
+ // delivered via the [[PayloadChunk]] mechanism instead. See
+ // [[PayloadChunk]] for chunking semantics.
+ bytes payload = 1;
+
+ // (Required, non-empty.) Format tag identifying the encoding of
+ // [[payload]] (e.g. "py-cloudpickle-v3", "wasm-v1"). Engine does
+ // not interpret this; the client side of the protocol and the
+ // worker agree on its meaning.
+ string format = 2;
+
+ // (Optional) Total payload size in bytes. Useful when chunked
+ // streaming is used so the worker can pre-allocate buffers.
+ optional int64 payload_size = 3;
+
+ // (Optional) Human-readable name for diagnostics and metrics.
+ optional string name = 4;
+
+ // (Optional) Worker / language-specific dispatch hint. A
+ // free-form string the worker uses to pick the code path that
+ // handles this payload. The protocol does not enumerate eval
+ // types because they are language-specific; the client side of
+ // the protocol and the worker agree on the namespace and the
+ // values.
+ //
+ // When the worker can derive the eval type from the payload
+ // itself (embedded metadata, format tag, etc.), this field is
+ // left unset. Otherwise the client side of the protocol sets it
+ // explicitly.
+ optional string eval_type = 5;
+}
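+
+// Illustrative (non-normative) UdfPayload in proto text format.
+// The format tag repeats an example from [[format]]; the name
+// and eval_type values are hypothetical:
+//
+//   payload: "<serialized callable + custom encoders>"
+//   format: "py-cloudpickle-v3"
+//   name: "my_upper_udf"
+//   eval_type: "scalar-pandas"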
+
+// =====================================================================
+// Manage RPC -- worker-scoped operations independent of Execute
+// =====================================================================
+
+// Engine -> Worker. Wraps the manage operations in a oneof so the RPC
+// is a single typed call, leaving room for future operations
+// (capability query, profiling, ...).
+message WorkerRequest {
+ // Exactly one branch MUST be set; receivers MUST reject messages
+ // with no branch set.
+ oneof manage {
+ Heartbeat heartbeat = 1;
+ ShutdownRequest shutdown = 2;
+ }
+}
+
+// Worker -> Engine.
+message WorkerResponse {
+ // Exactly one branch MUST be set, mirroring the request oneof.
+ oneof manage {
+ HeartbeatResponse heartbeat = 1;
+ ShutdownResponse shutdown = 2;
+ }
+}
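+
+// Illustrative (non-normative) manage exchanges; the shutdown
+// reason is hypothetical:
+//
+//   engine -> worker: WorkerRequest  { heartbeat {} }
+//   worker -> engine: WorkerResponse { heartbeat {} }
+//
+//   engine -> worker: WorkerRequest  { shutdown { reason: "idle" } }
+//   worker -> engine: WorkerResponse { shutdown {} }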
+
+// Liveness probe. The engine may send this periodically to detect a
+// hung worker. The worker SHOULD reply within a small bounded time.
+message Heartbeat {}
+
+// Acknowledgment for [[Heartbeat]].
+message HeartbeatResponse {}
+
+// Engine-initiated graceful shutdown request. Independent of SIGTERM
+// (which is the OS-level fallback) -- this lets the worker know the
+// engine has finished with it and intends no further Execute streams.
+message ShutdownRequest {
+ // (Optional) Free-form reason for diagnostics.
+ optional string reason = 1;
+}
+
+// Worker -> Engine acknowledgment of [[ShutdownRequest]].
+message ShutdownResponse {}
diff --git a/udf/worker/proto/src/main/protobuf/worker_spec.proto b/udf/worker/proto/src/main/protobuf/worker_spec.proto
index 83dac4f962e5f..6433d6169f412 100644
--- a/udf/worker/proto/src/main/protobuf/worker_spec.proto
+++ b/udf/worker/proto/src/main/protobuf/worker_spec.proto
@@ -124,7 +124,7 @@ message WorkerCapabilities {
// is reported by the engine as part of the UDF protocol's init message.
//
// (Required)
- repeated UDFWorkerDataFormat supported_data_formats = 1;
+ repeated UdfWorkerDataFormat supported_data_formats = 1;
// Which UDF protocol communication patterns the worker
// supports. This should list all supported patterns.
@@ -135,7 +135,7 @@ message WorkerCapabilities {
// the query will fail during query planning.
//
// (Required)
- repeated UDFProtoCommunicationPattern supported_communication_patterns = 2;
+ repeated UdfProtoCommunicationPattern supported_communication_patterns = 2;
// Whether multiple, concurrent UDF
// connections are supported by this worker