diff --git a/pom.xml b/pom.xml index d7dac399c2aed..0e961f3cb9b2b 100644 --- a/pom.xml +++ b/pom.xml @@ -113,6 +113,7 @@ connector/protobuf udf/worker/proto udf/worker/core + udf/worker/grpc diff --git a/udf/worker/README.md b/udf/worker/README.md index b843c430d0e04..199e51ee7d2ba 100644 --- a/udf/worker/README.md +++ b/udf/worker/README.md @@ -19,7 +19,7 @@ WorkerDispatcher -- manages workers, creates sessions | v WorkerSession -- one UDF execution - | 1. session.init(InitMessage(payload, inputSchema, outputSchema)) + | 1. session.init(Init proto: udf payload + data format + schemas) | 2. val results = session.process(inputBatches) | 3. session.close() ``` @@ -34,19 +34,22 @@ provisioning service or daemon). ``` udf/worker/ ├── proto/ -│ worker_spec.proto -- UDFWorkerSpecification protobuf (+ generated Java classes) -│ common.proto -- shared enums (UDFWorkerDataFormat, etc.) +│ worker_spec.proto -- UDFWorkerSpecification protobuf +│ udf_protocol.proto -- UDF execution protocol (Init, UdfPayload, ...) +│ common.proto -- shared enums (UdfWorkerDataFormat, etc.) │ └── core/ -- abstract interfaces WorkerDispatcher.scala -- creates sessions, manages worker lifecycle - WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage + WorkerSession.scala -- per-UDF init/process/cancel/close + WorkerSessionFactory.scala -- protocol-level connection + session factory WorkerConnection.scala -- transport channel abstraction WorkerSecurityScope.scala -- security boundary for worker pooling │ └── direct/ -- "direct" creation: local OS processes DirectWorkerDispatcher.scala -- spawns processes, env lifecycle DirectWorkerProcess.scala -- OS process + connection + UDS socket - DirectWorkerSession.scala -- session backed by a direct process + DirectWorkerSession.scala -- decorator: forwards to inner session, + releases worker ref-count on close ``` The `core/` package defines abstract interfaces that are independent of how @@ -76,10 +79,12 @@ Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed. ```scala import org.apache.spark.udf.worker.{ - DirectWorker, ProcessCallable, UDFProtoCommunicationPattern, - UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification, - UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment} + DirectWorker, Init, ProcessCallable, UdfPayload, + UdfProtoCommunicationPattern, UdfWorkerDataFormat, UDFWorkerProperties, + UDFWorkerSpecification, UnixDomainSocket, WorkerCapabilities, + WorkerConnectionSpec, WorkerEnvironment} import org.apache.spark.udf.worker.core._ +import com.google.protobuf.ByteString // 1. Define a worker spec (direct creation mode). val spec = UDFWorkerSpecification.newBuilder() @@ -90,9 +95,9 @@ val spec = UDFWorkerSpecification.newBuilder() .addCommand("pip").addCommand("install").addCommand("my_udf_worker").build()) .build()) .setCapabilities(WorkerCapabilities.newBuilder() - .addSupportedDataFormats(UDFWorkerDataFormat.ARROW) + .addSupportedDataFormats(UdfWorkerDataFormat.ARROW) .addSupportedCommunicationPatterns( - UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING) + UdfProtoCommunicationPattern.BIDIRECTIONAL_STREAMING) .build()) .setDirect(DirectWorker.newBuilder() .setRunner(ProcessCallable.newBuilder() @@ -112,10 +117,14 @@ val dispatcher: WorkerDispatcher = ... val session = dispatcher.createSession(securityScope = None) try { // 4. Initialize with the serialized function and schemas. 
- session.init(InitMessage( - functionPayload = serializedFunction, - inputSchema = arrowInputSchema, - outputSchema = arrowOutputSchema)) + session.init(Init.newBuilder() + .setUdf(UdfPayload.newBuilder() + .setPayload(ByteString.copyFrom(serializedFunction)) + .setFormat("py-cloudpickle-v3")) + .setDataFormat(UdfWorkerDataFormat.ARROW) + .setInputSchema(ByteString.copyFrom(arrowInputSchema)) + .setOutputSchema(ByteString.copyFrom(arrowOutputSchema)) + .build()) // 5. Process data -- Iterator in, Iterator out. val results: Iterator[Array[Byte]] = diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala index f4c4091688c94..4a7701981f223 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala @@ -19,31 +19,7 @@ package org.apache.spark.udf.worker.core import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.annotation.Experimental - -/** - * :: Experimental :: - * Carries all information needed to initialize a UDF execution on a worker. - * - * This message is passed to [[WorkerSession#init]] and contains the function - * definition, schemas, and any additional configuration. - * - * Placeholder: will be replaced by a generated proto message once the - * UDF wire protocol lands. Do not rely on case-class equality -- - * `Array[Byte]` fields compare by reference. - * - * @param functionPayload serialized function (e.g., pickled Python, JVM bytes) - * @param inputSchema serialized input schema (e.g., Arrow schema bytes) - * @param outputSchema serialized output schema (e.g., Arrow schema bytes) - * @param properties additional key-value configuration. Can carry - * protocol-specific or engine-specific metadata that - * does not yet have a dedicated field. - */ -@Experimental -case class InitMessage( - functionPayload: Array[Byte], - inputSchema: Array[Byte], - outputSchema: Array[Byte], - properties: Map[String, String] = Map.empty) +import org.apache.spark.udf.worker.Init /** * :: Experimental :: @@ -62,7 +38,10 @@ case class InitMessage( * {{{ * val session = dispatcher.createSession(securityScope = None) * try { - * session.init(InitMessage(functionPayload, inputSchema, outputSchema)) + * session.init(Init.newBuilder() + * .setUdf(UdfPayload.newBuilder().setPayload(callable).setFormat(fmt)) + * .setDataFormat(UdfWorkerDataFormat.ARROW) + * .build()) * val results = session.process(inputBatches) * results.foreach(handleBatch) * } finally { @@ -74,7 +53,18 @@ case class InitMessage( * - [[init]] must be called exactly once before [[process]]. * - [[process]] must be called at most once per session. * - [[close]] must always be called (use try-finally). - * - [[cancel]] may be called at any time to abort execution. + * - [[cancel]] may be called at any time, including before [[init]] + * or after [[process]]/[[close]] has returned. Implementations + * treat such calls as a no-op so that callers driven by a task + * interruption listener (which has no view into the session state) + * do not need to coordinate with the thread driving [[process]]. 
+ * + * Cancel-vs-finish race: when the session driver has finished + * sending input (and therefore queued an implicit finish on the + * underlying transport) and a [[cancel]] arrives concurrently, both + * are valid stream-terminating actions; the response side carries + * either a `FinishResponse` or a `CancelResponse` depending on which + * the worker observes first, and either is acceptable to the caller. * * The lifecycle is enforced here: [[init]] and [[process]] are `final` * and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards. @@ -93,10 +83,12 @@ abstract class WorkerSession extends AutoCloseable { * * Throws `IllegalStateException` if called more than once. * - * @param message the initialization parameters including the serialized - * function, input/output schemas, and configuration. + * @param message the [[Init]] proto carrying the UDF body, the wire + * data format, optional input/output schemas, and any + * engine-side session context the worker needs to start + * processing. */ - final def init(message: InitMessage): Unit = { + final def init(message: Init): Unit = { if (!initialized.compareAndSet(false, true)) { throw new IllegalStateException("init has already been called on this session") } @@ -128,7 +120,7 @@ abstract class WorkerSession extends AutoCloseable { } /** Subclass hook for [[init]]. Called once, after the guard. */ - protected def doInit(message: InitMessage): Unit + protected def doInit(message: Init): Unit /** Subclass hook for [[process]]. Called at most once, after the guard. */ protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] @@ -138,11 +130,29 @@ abstract class WorkerSession extends AutoCloseable { * * '''Thread-safety:''' implementations must allow [[cancel]] to be called * from a thread different from the one driving [[process]] (typically a - * task interruption thread). It may be invoked at any point after - * [[init]] and should be a no-op if execution has already finished. + * task interruption thread). + * + * '''Lifecycle:''' [[cancel]] is idempotent and safe at any point in + * the session's life: + * - before [[init]] -- nothing has been sent on the transport yet, + * so [[cancel]] is a no-op (the session may still be closed + * normally via [[close]]). + * - between [[init]] and [[process]] -- transitions the session + * into a cancelled state; subsequent [[process]] calls observe + * the cancellation. + * - during [[process]] -- aborts the active stream. + * - after [[process]] / [[close]] has returned -- a no-op. + * + * Implementations are responsible for the no-op behavior described + * above so that callers (e.g. task interruption listeners) do not + * need to coordinate with the thread driving [[process]]. */ def cancel(): Unit - /** Closes this session and releases resources. */ + /** + * Closes this session and releases resources. Idempotent; safe to + * call from a `finally` block regardless of whether [[init]], + * [[process]], or [[cancel]] have been invoked. + */ override def close(): Unit } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSessionFactory.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSessionFactory.scala new file mode 100644 index 0000000000000..005e902a4a69e --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSessionFactory.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Protocol-level factory for [[WorkerConnection]]s and per-execution + * [[WorkerSession]]s. + * + * The factory is responsible for everything that depends on the protocol + * the worker speaks (e.g. gRPC) -- building the transport-level connection + * and producing per-execution sessions on top of it. It is '''not''' + * responsible for creating, terminating, or otherwise managing the worker + * itself; that is the [[WorkerDispatcher]]'s job. As long as the dispatcher + * gives the factory an address to dial, and later a connection to talk + * over, the factory does not need to know whether the worker is a + * locally-spawned process, a daemon, or a remote service. + * + * One factory instance is owned by one dispatcher and is closed by the + * dispatcher. Implementations that own protocol-level shared resources + * (e.g. a Netty event loop group) should release them in [[close]]. + */ +@Experimental +trait WorkerSessionFactory extends AutoCloseable { + + /** + * Builds a transport-level connection to a worker reachable at + * `address`. The address format is whatever the dispatcher has agreed + * with the worker spec's transport (e.g. a UDS path, a TCP host:port). + * + * The returned connection is owned by the dispatcher's worker handle; + * its lifecycle is tied to worker teardown. + */ + def createConnection(address: String): WorkerConnection + + /** + * Creates a per-execution [[WorkerSession]] over an already-established + * [[WorkerConnection]]. The factory does not own the connection -- the + * worker handle does -- and a connection may back many sessions over + * its lifetime. + */ + def createSession(connection: WorkerConnection): WorkerSession + + /** + * Releases factory-level resources (event loops, thread pools, ...). + * Called by the dispatcher on close. Default is a no-op so factories + * that have nothing to release can omit the override. 
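+   *
+   * A sketch of a factory that does own shared state (the event loop
+   * group name is illustrative):
+   * {{{
+   * override def close(): Unit = eventLoopGroup.shutdownGracefully()
+   * }}}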
+ */ + override def close(): Unit = () +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala index 8da0354187e4f..e608187687cb1 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala @@ -22,7 +22,7 @@ import java.nio.file.attribute.PosixFilePermissions import org.apache.spark.annotation.Experimental import org.apache.spark.udf.worker.UDFWorkerSpecification -import org.apache.spark.udf.worker.core.{UnixSocketWorkerConnection, WorkerLogger} +import org.apache.spark.udf.worker.core.{WorkerLogger, WorkerSessionFactory} import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POLL_INTERVAL_MS /** @@ -31,14 +31,20 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POL * transport. Allocates a private 0700 socket directory at construction; * each worker is given a UDS path inside it. * - * Concrete subclasses implement [[createConnection]] (with a UDS protocol - * of choice) and [[createSessionForWorker]]. + * Pair this dispatcher with any [[WorkerSessionFactory]] whose + * [[WorkerSessionFactory#createConnection]] knows how to dial a UDS path + * (e.g. a gRPC-over-UDS factory). + * + * @param sessionFactory protocol-specific connection + session factory. + * Owned by this dispatcher; closed by the base on + * dispatcher close. */ @Experimental -abstract class DirectUnixSocketWorkerDispatcher( +class DirectUnixSocketWorkerDispatcher( workerSpec: UDFWorkerSpecification, + sessionFactory: WorkerSessionFactory, logger: WorkerLogger = WorkerLogger.NoOp) - extends DirectWorkerDispatcher(workerSpec, logger) { + extends DirectWorkerDispatcher(workerSpec, sessionFactory, logger) { // Removed explicitly in closeTransport(). deleteOnExit is avoided because // the JDK retains the path for the JVM lifetime, which leaks in @@ -100,8 +106,6 @@ abstract class DirectUnixSocketWorkerDispatcher( s"got ${conn.getTransportCase}") } - override protected def createConnection(address: String): UnixSocketWorkerConnection - private def throwWorkerExitedBeforeSocket( process: Process, address: String, @@ -117,8 +121,12 @@ abstract class DirectUnixSocketWorkerDispatcher( * On non-POSIX filesystems falls back to best-effort `File.setXxx`, * which is TOCTOU-racy and weaker; a WARN surfaces if the platform * refuses the setters. + * + * Visible for testing: tests override this to escape Spark's parent-POM + * `java.io.tmpdir=target/tmp`, which on long checkout paths pushes UDS + * paths past Linux's 108-char cap. 
*/ - private def createPrivateTempDirectory(): Path = { + protected def createPrivateTempDirectory(): Path = { val attr = PosixFilePermissions.asFileAttribute( PosixFilePermissions.fromString("rwx------")) try { diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index afaf23791d80f..1b14323714c4e 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -29,8 +29,8 @@ import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification} -import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, - WorkerLogger, WorkerSecurityScope, WorkerSession} +import org.apache.spark.udf.worker.core.{WorkerDispatcher, WorkerLogger, + WorkerSecurityScope, WorkerSession, WorkerSessionFactory} import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, DEFAULT_CALLABLE_TIMEOUT_MS, DEFAULT_GRACEFUL_TIMEOUT_MS, DEFAULT_INIT_TIMEOUT_MS, ENGINE_MAX_TIMEOUT_MS, EnvironmentState, MAX_OUTPUT_SCAN_BYTES, @@ -46,13 +46,17 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR * currently gets a fresh worker that is terminated when the session closes * (the single-reference case of the future pooling policy). * - * Subclasses implement [[createConnection]] and [[createSessionForWorker]] - * to provide protocol-specific behavior (e.g., gRPC, raw sockets). + * Subclasses implement transport-level concerns (endpoint allocation, + * readiness wait, transport-state cleanup). The protocol the worker speaks + * (gRPC, raw sockets, ...) is delegated to a [[WorkerSessionFactory]] + * passed in by the caller. * * For workers obtained through a provisioning service or daemon (indirect * creation), see the `indirect` package (TODO). * * @param workerSpec worker specification (proto) + * @param sessionFactory protocol-specific connection + session factory. + * Owned by this dispatcher; closed from [[close]]. * @param logger [[WorkerLogger]] used for dispatcher-internal messages. * The framework does not depend on any concrete logging * backend; callers should pass an adapter that forwards @@ -62,6 +66,7 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR @Experimental abstract class DirectWorkerDispatcher( override val workerSpec: UDFWorkerSpecification, + sessionFactory: WorkerSessionFactory, protected val logger: WorkerLogger = WorkerLogger.NoOp) extends WorkerDispatcher { @@ -156,12 +161,6 @@ abstract class DirectWorkerDispatcher( */ protected def validateTransportSupport(): Unit - /** Creates a protocol-specific connection to a worker at the given address. */ - protected def createConnection(address: String): WorkerConnection - - /** Creates a protocol-specific session for the given worker. 
*/ - protected def createSessionForWorker(worker: DirectWorkerProcess): WorkerSession - override def createSession( securityScope: Option[WorkerSecurityScope]): WorkerSession = { require(securityScope.isEmpty, @@ -180,7 +179,8 @@ abstract class DirectWorkerDispatcher( throwClosed() } try { - createSessionForWorker(worker) + val inner = sessionFactory.createSession(worker.connection) + new DirectWorkerSession(worker, inner, logger) } catch { case e: InterruptedException => Thread.currentThread().interrupt() @@ -237,6 +237,10 @@ abstract class DirectWorkerDispatcher( case NonFatal(e) => logger.warn("Error cleaning up transport state", e) } + try sessionFactory.close() catch { + case NonFatal(e) => + logger.warn("Error closing session factory", e) + } deregisterEnvironmentCleanupHook() runEnvironmentCleanup() } @@ -381,7 +385,7 @@ abstract class DirectWorkerDispatcher( try { waitForReady(address, process, outputFile.toFile) - val connection = createConnection(address) + val connection = sessionFactory.createConnection(address) val artifacts = new WorkerArtifacts(process, connection, outputFile, logger) new DirectWorkerProcess( workerId, artifacts, gracefulTimeoutMs, logger, diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala index 7cdc5329350e3..0a68440c6afdd 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -18,39 +18,53 @@ package org.apache.spark.udf.worker.core.direct import java.util.concurrent.atomic.AtomicBoolean +import scala.util.control.NonFatal + import org.apache.spark.annotation.Experimental -import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession} +import org.apache.spark.udf.worker.Init +import org.apache.spark.udf.worker.core.{WorkerLogger, WorkerSession} /** * :: Experimental :: - * A [[WorkerSession]] backed by a locally-spawned [[DirectWorkerProcess]]. - * - * This is the session type returned by [[DirectWorkerDispatcher]]. It ties - * the session lifecycle to the worker's ref-count: the dispatcher increments - * the count before construction, and [[close]] decrements it, so the - * dispatcher knows when a worker process is idle and can be terminated or - * reused. + * Lifecycle decorator that ties a protocol-level [[WorkerSession]] to the + * direct worker that backs it. Forwards [[doInit]] / [[doProcess]] / + * [[cancel]] to the inner session, and on [[close]] additionally fires + * the underlying [[DirectWorkerProcess]]'s session ref-count so the + * dispatcher knows when a worker has gone idle. * - * Subclasses implement the protocol-specific data transmission - * ([[init]], [[process]], [[cancel]]). + * Constructed exclusively by [[DirectWorkerDispatcher]]; protocol code + * (gRPC and friends) does not extend or see this class. * - * @param workerProcess the direct worker process backing this session. - * Internal to the `core` package and test code -- the - * worker handle is a dispatcher implementation detail, - * not part of the public WorkerSession API. + * @param workerProcess the direct worker backing this session. + * @param inner the protocol-specific session this decorator forwards to. + * @param logger logger used when the inner close raises. 
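+ *
+ * Assembly sketch (mirrors [[DirectWorkerDispatcher]]'s `createSession`):
+ * {{{
+ * val inner = sessionFactory.createSession(worker.connection)
+ * new DirectWorkerSession(worker, inner, logger)
+ * }}}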
*/ @Experimental -abstract class DirectWorkerSession( - private[core] val workerProcess: DirectWorkerProcess) extends WorkerSession { +final class DirectWorkerSession private[direct] ( + private[core] val workerProcess: DirectWorkerProcess, + private val inner: WorkerSession, + private val logger: WorkerLogger = WorkerLogger.NoOp) extends WorkerSession { private val released = new AtomicBoolean(false) - /** The connection to the worker for this session. */ - def connection: WorkerConnection = workerProcess.connection + override protected def doInit(message: Init): Unit = inner.init(message) + + override protected def doProcess( + input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = inner.process(input) + + override def cancel(): Unit = inner.cancel() + /** + * Closes the inner protocol session, then releases the worker's session + * ref-count. The ref-count release is run unconditionally so a faulty + * inner close cannot strand the worker as "still in use". + */ override def close(): Unit = { - if (released.compareAndSet(false, true)) { - workerProcess.releaseSession() + try inner.close() catch { + case NonFatal(e) => + logger.warn(s"Error closing inner session for worker ${workerProcess.id}", e) + } finally { + if (released.compareAndSet(false, true)) workerProcess.releaseSession() } } } diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 60f5e2211b702..cf829880801c4 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.udf.worker.{ - DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, + DirectWorker, Init, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec, WorkerEnvironment} import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher, @@ -45,45 +45,50 @@ class SocketFileConnection(socketPath: String) } /** - * A stub [[DirectWorkerSession]] for process-lifecycle tests that don't - * need actual data transmission. + * A stub [[WorkerSession]] for process-lifecycle tests that don't need + * actual data transmission. The dispatcher wraps this in a + * [[DirectWorkerSession]] decorator that owns the worker ref-count. * - * TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]] + * TODO: [[cancel]] is a no-op here. Once a concrete protocol session * with real data-plane wiring lands, add tests exercising cancel() in * particular: cancel from a different thread than process(), cancel * after process() has returned, and cancel before init (should be a * no-op). Tracking the thread-safety contract in the docstring on * [[org.apache.spark.udf.worker.core.WorkerSession.cancel]]. 
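 *
 * Sketch of the before-init case (names as in this suite):
 * {{{
 * val session = dispatcher.createSession(None)
 * session.cancel() // before init: contractually a no-op
 * session.close()
 * }}}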
*/ -class StubWorkerSession( - workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) { - - override protected def doInit(message: InitMessage): Unit = {} +class StubWorkerSession extends WorkerSession { + override protected def doInit(message: Init): Unit = {} override protected def doProcess( input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = Iterator.empty override def cancel(): Unit = {} + + override def close(): Unit = {} } /** - * A [[DirectUnixSocketWorkerDispatcher]] subclass for testing that uses - * a socket-file connection and stub sessions instead of a real protocol - * implementation. + * A [[WorkerSessionFactory]] for tests: hands out + * [[SocketFileConnection]] connections and [[StubWorkerSession]] + * sessions instead of any real protocol implementation. */ -class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification) - extends DirectUnixSocketWorkerDispatcher(spec) { - - override protected def createConnection( - socketPath: String): UnixSocketWorkerConnection = - new SocketFileConnection(socketPath) +class TestSessionFactory extends WorkerSessionFactory { + override def createConnection(address: String): WorkerConnection = + new SocketFileConnection(address) - override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = - new StubWorkerSession(worker) + override def createSession(connection: WorkerConnection): WorkerSession = + new StubWorkerSession } +/** + * Test convenience: a [[DirectUnixSocketWorkerDispatcher]] pre-wired + * with the stub factory, so tests can stay focused on dispatcher + * behaviour and not the factory plumbing. + */ +class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification) + extends DirectUnixSocketWorkerDispatcher(spec, new TestSessionFactory) + /** * Tests for [[DirectWorkerDispatcher]] process lifecycle: spawning workers * and terminating them on close. @@ -145,14 +150,14 @@ class DirectWorkerDispatcherSuite } // Narrow the publicly-typed WorkerSession returned by `createSession` back - // down to StubWorkerSession in one place, with a descriptive failure if + // down to DirectWorkerSession in one place, with a descriptive failure if // the cast is ever wrong, so individual tests don't scatter `asInstanceOf` // (which would throw ClassCastException rather than a useful message). - private def createStubSession(): StubWorkerSession = + private def createStubSession(): DirectWorkerSession = dispatcher.createSession(None) match { - case stub: StubWorkerSession => stub + case dws: DirectWorkerSession => dws case other => fail( - s"Expected StubWorkerSession, got ${other.getClass.getSimpleName}") + s"Expected DirectWorkerSession, got ${other.getClass.getSimpleName}") } // The whole suite uses UDS as the only transport, so reaching past the @@ -181,7 +186,7 @@ class DirectWorkerDispatcherSuite dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) val threads = 8 - val sessions = new java.util.concurrent.ConcurrentLinkedQueue[StubWorkerSession]() + val sessions = new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerSession]() val startGate = new java.util.concurrent.CountDownLatch(1) val doneGate = new java.util.concurrent.CountDownLatch(threads) val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() @@ -360,24 +365,24 @@ class DirectWorkerDispatcherSuite // should remain in either case. 
val readyLatch = new java.util.concurrent.CountDownLatch(1) val releaseLatch = new java.util.concurrent.CountDownLatch(1) - val capturedWorkers = - new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]() - val racing = new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { - override protected def createConnection( - socketPath: String): UnixSocketWorkerConnection = - new SocketFileConnection(socketPath) - override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = { - capturedWorkers.add(worker) + val capturedConnections = + new java.util.concurrent.ConcurrentLinkedQueue[WorkerConnection]() + val blockingFactory = new WorkerSessionFactory { + override def createConnection(address: String): WorkerConnection = + new SocketFileConnection(address) + override def createSession(connection: WorkerConnection): WorkerSession = { + capturedConnections.add(connection) readyLatch.countDown() // Block here so dispatcher.close() runs while createSession is in // flight. Use a generous wait so a slow CI doesn't time out. if (!releaseLatch.await(30, java.util.concurrent.TimeUnit.SECONDS)) { fail("releaseLatch never fired -- test orchestration broken") } - new StubWorkerSession(worker) + new StubWorkerSession } } + val racing = new DirectUnixSocketWorkerDispatcher( + specWithRunner(defaultRunner), blockingFactory) try { val outcome = new java.util.concurrent.atomic.AtomicReference[Either[Throwable, WorkerSession]]() @@ -394,7 +399,7 @@ class DirectWorkerDispatcherSuite // Wait for thread A to have published the worker and entered the // blocking override. assert(readyLatch.await(10, java.util.concurrent.TimeUnit.SECONDS), - "createSession thread never reached createSessionForWorker") + "createSession thread never reached the factory's createSession") val closeThread = new Thread(() => racing.close(), "close-racer") closeThread.start() @@ -409,10 +414,10 @@ class DirectWorkerDispatcherSuite assert(!createThread.isAlive, "createSession thread did not finish") assert(!closeThread.isAlive, "close thread did not finish") - val captured = capturedWorkers.toArray(Array.empty[DirectWorkerProcess]) + val captured = capturedConnections.toArray(Array.empty[WorkerConnection]) assert(captured.length == 1, - s"expected exactly one worker spawned, got ${captured.length}") - val worker = captured(0) + s"expected exactly one connection spawned, got ${captured.length}") + val sockPath = captured(0).asInstanceOf[UnixSocketWorkerConnection].socketPath outcome.get() match { case Left(e: IllegalStateException) => @@ -425,20 +430,21 @@ class DirectWorkerDispatcherSuite s"expected dispatcher-closed error, got: ${e.getMessage}") case Left(other) => fail(s"unexpected exception from racing createSession: $other") - case Right(_) => + case Right(s: DirectWorkerSession) => // close() iterated the published worker and tore it down; the // returned session points at a worker that should now be dead. + val worker = s.workerProcess + val deadline = System.currentTimeMillis() + 5000 + while (worker.process.isAlive && System.currentTimeMillis() < deadline) { + Thread.sleep(50) + } + assert(!worker.process.isAlive, + s"worker process should be terminated after close, still alive at $sockPath") + case Right(other) => + fail(s"expected DirectWorkerSession, got ${other.getClass.getSimpleName}") } - // Whichever path won, the worker must not still be running and the - // socket file must be gone. 
- val deadline = System.currentTimeMillis() + 5000 - while (worker.process.isAlive && System.currentTimeMillis() < deadline) { - Thread.sleep(50) - } - val sockPath = udsPath(worker) - assert(!worker.process.isAlive, - s"worker process should be terminated after close, still alive at $sockPath") + // Whichever branch won, the socket file must be gone. assert(!new java.io.File(sockPath).exists(), s"socket file $sockPath should have been removed") } finally { @@ -533,32 +539,33 @@ class DirectWorkerDispatcherSuite // -- Error-path tests ------------------------------------------------------- - test("worker is cleaned up when createSessionForWorker throws") { - // A dispatcher whose createSessionForWorker always throws. The spawned - // worker must be terminated rather than leaked until dispatcher.close(). - var capturedWorker: DirectWorkerProcess = null - val failingDispatcher = - new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { - override protected def createConnection( - socketPath: String): UnixSocketWorkerConnection = - new SocketFileConnection(socketPath) - override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = { - capturedWorker = worker - throw new RuntimeException("session creation failed") - } + test("worker is cleaned up when session factory createSession throws") { + // A factory whose createSession always throws. The spawned worker must + // be terminated rather than leaked until dispatcher.close(). + var capturedSocketPath: String = null + val failingFactory = new WorkerSessionFactory { + override def createConnection(address: String): WorkerConnection = + new SocketFileConnection(address) + override def createSession(connection: WorkerConnection): WorkerSession = { + capturedSocketPath = connection.asInstanceOf[UnixSocketWorkerConnection].socketPath + throw new RuntimeException("session creation failed") } + } + val failingDispatcher = new DirectUnixSocketWorkerDispatcher( + specWithRunner(defaultRunner), failingFactory) try { val ex = intercept[RuntimeException] { failingDispatcher.createSession(None) } assert(ex.getMessage.contains("session creation failed")) - assert(capturedWorker != null, "worker should have been spawned before the failure") - assert(!capturedWorker.process.isAlive, - "worker process should have been terminated after session creation failed") - assert(capturedWorker.activeSessions == 0, - "worker session count should be released after failure") + assert(capturedSocketPath != null, + "factory should have been called before the failure") + // Worker teardown removes the UDS socket file (via SocketFileConnection's + // base-class close), so the file going away is the observable that the + // worker was reaped rather than leaked. 
+      assert(!new File(capturedSocketPath).exists(),
+        s"socket file $capturedSocketPath should be cleaned up after failure")
     } finally {
       failingDispatcher.close()
     }
   }
@@ -591,19 +598,18 @@
         s"expected UDS-only error, got: ${ex.getMessage}")
   }
 
-  test("socket file is cleaned up when createConnection throws") {
+  test("socket file is cleaned up when factory createConnection throws") {
     val capturedSocketPaths =
       new java.util.concurrent.ConcurrentLinkedQueue[String]()
-    val failingDispatcher =
-      new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) {
-        override protected def createConnection(
-            socketPath: String): UnixSocketWorkerConnection = {
-          capturedSocketPaths.add(socketPath)
-          throw new RuntimeException("connection creation failed")
-        }
-        override protected def createSessionForWorker(
-            worker: DirectWorkerProcess): WorkerSession =
-          new StubWorkerSession(worker)
+    val failingFactory = new WorkerSessionFactory {
+      override def createConnection(address: String): WorkerConnection = {
+        capturedSocketPaths.add(address)
+        throw new RuntimeException("connection creation failed")
       }
+      override def createSession(connection: WorkerConnection): WorkerSession =
+        new StubWorkerSession
+    }
+    val failingDispatcher = new DirectUnixSocketWorkerDispatcher(
+      specWithRunner(defaultRunner), failingFactory)
     try {
       val ex = intercept[RuntimeException] {
         failingDispatcher.createSession(None)
@@ -760,14 +766,8 @@
         .addCommand("sleep 30").build()
     val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build()
     val shortTimeoutDispatcher =
-      new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+      new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), new TestSessionFactory) {
         override protected def callableTimeoutMs: Long = 500L
-        override protected def createConnection(
-            socketPath: String): UnixSocketWorkerConnection =
-          new SocketFileConnection(socketPath)
-        override protected def createSessionForWorker(
-            worker: DirectWorkerProcess): WorkerSession =
-          new StubWorkerSession(worker)
       }
     try {
       val ex = intercept[DirectWorkerTimeoutException] {
@@ -897,14 +897,8 @@
           s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build())
       .build()
     val timeoutDispatcher =
-      new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+      new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env), new TestSessionFactory) {
         override protected def callableTimeoutMs: Long = 500L
-        override protected def createConnection(
-            socketPath: String): UnixSocketWorkerConnection =
-          new SocketFileConnection(socketPath)
-        override protected def createSessionForWorker(
-            worker: DirectWorkerProcess): WorkerSession =
-          new StubWorkerSession(worker)
       }
     try {
       val first = intercept[DirectWorkerTimeoutException] {
diff --git a/udf/worker/grpc/pom.xml b/udf/worker/grpc/pom.xml
new file mode 100644
index 0000000000000..677c26c8b106a
--- /dev/null
+++ b/udf/worker/grpc/pom.xml
@@ -0,0 +1,86 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.13</artifactId>
+    <version>4.2.0-SNAPSHOT</version>
+    <relativePath>../../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-udf-worker-grpc_2.13</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project UDF Worker gRPC</name>
+  <url>https://spark.apache.org/</url>
+
+  <properties>
+    <sbt.project.name>udf-worker-grpc</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-udf-worker-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-udf-worker-proto_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-netty-shaded</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-stub</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-protobuf</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>net.alchim31.maven</groupId>
+        <artifactId>scala-maven-plugin</artifactId>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerConnection.scala b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerConnection.scala
new file mode 100644
index 0000000000000..efe45185862bf
--- /dev/null
+++ b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerConnection.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import java.util.concurrent.TimeUnit
+
+import io.grpc.{ConnectivityState, ManagedChannel}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.core.UnixSocketWorkerConnection
+
+/**
+ * :: Experimental ::
+ * A [[UnixSocketWorkerConnection]] backed by a gRPC [[ManagedChannel]].
+ *
+ * The Netty event loop group used by the channel is owned by the
+ * [[GrpcWorkerSessionFactory]] that created this connection (shared
+ * across all connections it produces), so [[close]] only shuts down
+ * the channel here -- not the event loop. The base class handles
+ * socket-file removal.
+ *
+ * @param channel the gRPC managed channel connected to the worker
+ * @param socketPath the UDS path the channel is bound to; passed to the
+ *                   base class for ownership/cleanup of the socket file
+ */
+@Experimental
+class GrpcWorkerConnection(
+    val channel: ManagedChannel,
+    socketPath: String) extends UnixSocketWorkerConnection(socketPath) {
+
+  private val SHUTDOWN_TIMEOUT_MS = 5000L
+
+  /**
+   * A channel is considered active unless explicitly shut down.
+   * Transient states (IDLE, CONNECTING, TRANSIENT_FAILURE) may still
+   * recover, so we only report inactive for SHUTDOWN.
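+   *
+   * (Sketch) a pool can use this as a cheap pre-reuse check; the
+   * helper is illustrative:
+   * {{{
+   * if (!connection.isActive) evictAndRedial(connection)
+   * }}}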
+ */ + override def isActive: Boolean = { + val state = channel.getState(false) + state != ConnectivityState.SHUTDOWN + } + + override def close(): Unit = { + channel.shutdown() + try { + channel.awaitTermination(SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } catch { + case _: InterruptedException => + Thread.currentThread().interrupt() + } finally { + if (!channel.isTerminated) { + channel.shutdownNow() + } + } + super.close() + } +} diff --git a/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSession.scala b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSession.scala new file mode 100644 index 0000000000000..825aba8ebb8d7 --- /dev/null +++ b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSession.scala @@ -0,0 +1,541 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.grpc + +import java.util.concurrent.{CancellationException, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import com.google.protobuf.ByteString +import io.grpc.stub.{CallStreamObserver, StreamObserver} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker._ +import org.apache.spark.udf.worker.core.WorkerSession + +/** + * :: Experimental :: + * A [[WorkerSession]] that implements the UDF protocol over a + * bidirectional gRPC stream + * ([[WorkerGrpc.WorkerStub#execute]]). + * + * Each instance maps to one UDF execution on a single gRPC stream. + * The number of output batches is independent of the number of input + * batches (depends on UDF shape: scalar is 1:1, aggregate is N:1, etc.). + * + * ==Threading model== + * + * This session uses a '''single caller thread''' for both sending input + * and receiving results. gRPC's internal I/O threads handle network + * transport and invoke the response [[StreamObserver]] callbacks, which + * enqueue responses into [[responseQueue]]. 
The caller thread + * interleaves sending and polling: + * + * {{{ + * Caller thread gRPC I/O thread Worker (gRPC server) + * | | | + * |-- init(): send Init ----------|----> Init ----------->| + * | poll responseQueue | | + * |<--- enqueue InitResponse <----|<---- InitResponse ----| + * | | | + * |-- process(input): | | + * | loop: | | + * | poll responseQueue | | + * | if result: yield batch | | + * | if empty & input & ready: | | + * | send DataRequest -------|----> DataRequest ----->| + * | |<---- DataResponse ----| + * |<--- enqueue DataResponse <----| | + * | yield batch | | + * | (input exhausted) | | + * | send Finish + complete ---|----> Finish --------->| + * | |<---- FinishResponse --| + * |<--- enqueue FinishResponse <--| | + * | |<---- onCompleted -----| + * |<--- enqueue StreamEnd <-------| | + * | hasNext -> false | | + * | | | + * |-- close() ------------------->| | + * }}} + * + * ==Backpressure== + * + * The caller thread uses [[CallStreamObserver.isReady]] to check + * whether the gRPC transport buffer can accept more data before sending + * each batch. When the buffer is full (worker is slow to consume), the + * caller instead polls for results, naturally throttling the send rate. + * This avoids unbounded memory growth that would occur if the sender + * pushed data faster than the worker can process it. + * + * The interleaved send/receive loop also provides implicit backpressure: + * results are always preferred over sending more input. A batch is only + * sent when the result queue is empty and the transport is ready, which + * limits the amount of in-flight data to what the worker can absorb. + * + * ==Error handling== + * + * The gRPC callback [[StreamObserver.onError]] enqueues the error into + * [[responseQueue]]. On the next poll, the caller thread re-throws it. + * Before sending each input batch, the caller checks for errors already + * received from the worker, allowing early termination without pushing + * data into a failed stream. + * + * @param connection the gRPC connection to the worker. The connection is + * owned by the dispatcher's worker handle; this session + * uses but does not close it. + */ +@Experimental +class GrpcWorkerSession( + connection: GrpcWorkerConnection, + controlResponseTimeoutMs: Long = GrpcWorkerSession.DefaultControlResponseTimeoutMs) + extends WorkerSession { + + // TODO: wire order ("Init -> Data* -> Finish|Cancel") is enforced by control + // flow rather than an explicit state machine, so a future change that e.g. + // sends Cancel before Init would compile cleanly with no defensive guard. + + // TODO: configurable timeouts -- POLL_TIMEOUT_MS, controlResponseTimeoutMs, + // and the graceful shutdown timeout in close() are all hardcoded today. + private val POLL_TIMEOUT_MS = 100L + + // Events enqueued by the gRPC callback thread and consumed by the caller. + // A small sealed ADT is clearer than Either + a reference-equality sentinel, + // and gives the pattern matches below exhaustiveness checking for free. + private sealed trait StreamEvent + private case class StreamResponse(udf: UdfResponse) extends StreamEvent + private case class StreamError(cause: Throwable) extends StreamEvent + private case object StreamEnd extends StreamEvent + + /** + * Events from the gRPC callback thread. Unbounded because every data + * message is drained by the iterator and the few control messages + * (Init/Finish/Cancel ack) never pile up; the queue is a hand-off, not + * a buffer that grows with traffic. 
+ */ + private val responseQueue = new LinkedBlockingQueue[StreamEvent]() + + /** Set to true by the gRPC callback when the server stream ends. */ + @volatile private var completed = false + + /** Captures the first error from the worker so the caller can check it. */ + @volatile private var workerError: Throwable = _ + + private val closed = new AtomicBoolean(false) + + /** True once the client stream has been finished or cancelled. */ + private val streamFinished = new AtomicBoolean(false) + + /** + * True once the Init request has been written to the request stream. + * Gates [[cancel]] so it cannot send a Cancel as the first wire message + * (which would violate the proto's `Init -> ... -> Cancel|Finish` ordering). + */ + private val initSent = new AtomicBoolean(false) + + /** + * True once an explicit [[cancel]] has been requested by the caller. + * Distinct from [[streamFinished]] (which is also set by close/finish) so + * [[doProcess]] can surface a cancellation to a caller that called + * [[cancel]] between [[init]] and [[process]] rather than silently + * returning an empty iterator. + */ + private val cancelRequested = new AtomicBoolean(false) + + /** + * Guards all access to [[requestObserver]] so that cancel() and close() + * (which may be called from task-cancellation threads) cannot race with + * the caller thread's sends. gRPC StreamObserver is not thread-safe. + */ + private val streamLock = new Object + + private val stub = WorkerGrpc.newStub(connection.channel) + + // Open the bidirectional stream. The gRPC I/O thread writes to + // responseQueue; the caller thread reads from it. + private val requestObserver: StreamObserver[UdfRequest] = + stub.execute(new StreamObserver[UdfResponse] { + override def onNext(response: UdfResponse): Unit = { + enqueueOrInterrupt(StreamResponse(response)) + } + + override def onError(t: Throwable): Unit = { + workerError = t + completed = true + enqueueOrInterrupt(StreamError(t)) + } + + override def onCompleted(): Unit = { + completed = true + // Enqueue StreamEnd so the consumer unblocks immediately instead of + // spin-polling on the `completed` volatile. + enqueueOrInterrupt(StreamEnd) + } + + // gRPC callbacks must not throw checked exceptions; if `put` is + // ever interrupted (the unbounded queue makes this unrealistic in + // practice), restore the interrupt flag and capture the failure as + // a worker error so the caller surfaces it. + private def enqueueOrInterrupt(event: StreamEvent): Unit = { + try { + responseQueue.put(event) + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + workerError = ie + completed = true + } + } + }) + + // TODO: payload chunking. Large UDF payloads can exceed gRPC's per-message + // size limit. The proto defines PayloadChunk for streaming; this method + // forwards the payload inline and does not yet chunk. + // TODO: capability validation. Init.dataFormat is not checked against the + // spec's supported_data_formats, nor BIDIRECTIONAL_STREAMING against + // supported_communication_patterns. Belongs in the dispatcher/factory. + override protected def doInit(message: Init): Unit = { + val request = UdfRequest.newBuilder() + .setControl(UdfControlRequest.newBuilder().setInit(message)) + .build() + streamLock.synchronized { + requestObserver.onNext(request) + // Set after onNext returns so cancel() only sends Cancel when Init + // has actually been written. Inside the lock so cancel() observes a + // consistent ordering between "Init went on the wire" and the flag. 
+ initSent.set(true) + } + awaitControlResponse(UdfControlResponse.ControlCase.INIT) + } + + /** + * Streams input to the worker and returns an iterator of results. + * + * Uses a '''single-thread interleaved''' model: the caller thread + * alternates between polling for results and sending input batches. + * Results are always preferred over sending -- a new batch is only + * sent when the result queue is empty and the gRPC transport is ready + * ([[CallStreamObserver.isReady]]). This naturally throttles the send + * rate to the worker's processing speed and limits memory usage. + * + * The number of result batches is independent of input batches; + * it depends on the UDF shape (scalar: 1:1, aggregate: N:1, etc.). + * Generator-style UDFs may produce output even with empty input. + */ + override protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = { + // If cancel() ran between init and process, surface it explicitly so + // the caller doesn't get an empty iterator that's indistinguishable + // from a generator UDF that legitimately produced nothing. + if (cancelRequested.get()) { + throw new CancellationException("session was cancelled before process") + } + + // If input is already empty, send Finish immediately. + // Generator-style UDFs may still produce output without input. + var inputExhausted = false + if (!input.hasNext) { + inputExhausted = true + sendFinishAndComplete() + } + + // gRPC stub methods return a CallStreamObserver in practice; the public + // interface only promises StreamObserver. Use a typed match so a future + // change in grpc-java surfaces a clear error rather than ClassCastException. + val callObserver = requestObserver match { + case c: CallStreamObserver[UdfRequest @unchecked] => c + case other => throw new IllegalStateException( + s"expected CallStreamObserver from gRPC stub, got ${other.getClass.getName}") + } + + new Iterator[Array[Byte]] { + private var pendingBatch: Option[Array[Byte]] = None + private var exhausted = false + + override def hasNext: Boolean = { + if (exhausted) return false + if (pendingBatch.isDefined) return true + fillNextBatch() + pendingBatch.isDefined + } + + override def next(): Array[Byte] = { + if (!hasNext) throw new NoSuchElementException("No more result batches") + val result = pendingBatch.get + pendingBatch = None + result + } + + /** + * Interleaved send/receive loop. On each iteration: + * + * 1. Check for worker errors -- stop early if the stream has failed. + * 2. Poll the result queue -- prefer consuming results to reduce + * memory pressure from buffered responses. + * 3. If no result is available and there is more input to send and + * the gRPC transport is ready, send the next input batch. + * + * This ordering ensures the caller never sends faster than it + * consumes, providing natural backpressure. + */ + private def fillNextBatch(): Unit = { + while (pendingBatch.isEmpty && !exhausted) { + // 1. Check for worker errors before doing any work. + throwIfErrored() + + // 2. Poll the result queue. The timeout is required: an aggregate- + // like UDF only emits output after consuming enough input, so a + // blocking poll here would deadlock -- we cannot send the next + // input batch (step 3 below) until the poll returns, and the + // worker has nothing to send back yet. + // TODO: liveness during execution. The proto's Manage RPC + // (Heartbeat / ShutdownRequest) is declared but unused, and + // WorkerConnection.isActive only reflects the gRPC channel + // state, not the worker process. 
A hung worker leaves this + // loop polling every POLL_TIMEOUT_MS forever. + val item = responseQueue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + if (item != null) { + item match { + case StreamEnd => + exhausted = true + return + case StreamResponse(response) if response.hasData => + val dataResponse = response.getData + val bytes = dataResponse.getData.toByteArray + pendingBatch = Some(bytes) + return + case StreamResponse(_) => + // Control response (FinishResponse, etc.); not a data batch. + // TODO: FinishResponse.data is silently dropped. The proto + // carries an optional final payload here; we have no API + // to surface it outside the data stream. + case StreamError(t) => throw t + } + } else if (completed && responseQueue.isEmpty) { + exhausted = true + return + } + + // 3. No result available -- send more input if possible. + // Only send when the transport is ready (backpressure: the gRPC + // transport buffer has room). If not ready, loop back to poll + // -- the worker is slow and we should wait for results. + if (!inputExhausted && input.hasNext && callObserver.isReady) { + try { + val data = input.next() + sendDataBatch(data) + if (!input.hasNext) { + inputExhausted = true + sendFinishAndComplete() + } + } catch { + case NonFatal(e) => + // Error from upstream (e.g., input serialization failure). + // Notify the worker so it can clean up, then propagate. + finishStreamWithError(e) + throw e + } + } else if (!inputExhausted && !input.hasNext) { + // Input exhausted between the last check and now. + inputExhausted = true + sendFinishAndComplete() + } + } + } + } + } + + // TODO: cancel cannot quickly unblock an in-flight process() poll. cancel() + // from another thread relies on the worker (or transport teardown) to + // deliver StreamEnd before the poll loop notices, so task-interruption + // latency is bounded by POLL_TIMEOUT_MS plus a worker round-trip. The + // contract permits this -- the response side may carry FinishResponse or + // CancelResponse depending on which the worker observes first. + override def cancel(): Unit = { + cancelRequested.set(true) + // Per WorkerSession's lifecycle contract, cancel before init is a + // no-op: nothing has been sent on the transport yet, and sending + // Cancel as the first request would violate the proto's + // `Init -> ... -> Cancel|Finish` ordering. + if (!initSent.get()) return + if (streamFinished.compareAndSet(false, true)) { + streamLock.synchronized { + val request = UdfRequest.newBuilder() + .setControl(UdfControlRequest.newBuilder().setCancel(Cancel.getDefaultInstance)) + .build() + // Best-effort: cancel can race with the worker tearing down the + // stream (e.g. crash, transport reset). Swallow non-fatal errors + // so a task-interruption thread is never disrupted by cleanup. + try { + requestObserver.onNext(request) + requestObserver.onCompleted() + } catch { + case NonFatal(_) => // stream already terminated; cancel is best-effort + } + } + } + } + + /** + * Closes this session. If the stream was not properly finished (via + * process() or cancel()), sends a Cancel to notify the worker. + * + * If sending Cancel fails (e.g., the transport is already broken), + * falls back to [[StreamObserver.onError]] to ensure the server + * receives a termination signal and does not wait indefinitely. + */ + override def close(): Unit = { + if (!closed.compareAndSet(false, true)) return + // Mirror cancel(): if init never went on the wire, the proto has not + // started, so closing is purely local cleanup. 
+ if (!initSent.get()) return + if (!streamFinished.compareAndSet(false, true)) return + streamLock.synchronized { + var sentCancel = false + try { + val request = UdfRequest.newBuilder() + .setControl(UdfControlRequest.newBuilder() + .setCancel(Cancel.getDefaultInstance)) + .build() + requestObserver.onNext(request) + sentCancel = true + requestObserver.onCompleted() + } catch { + case NonFatal(_) if !sentCancel => + // onNext threw before onCompleted ran; no terminal call has + // succeeded yet, so onError is a valid terminator that lets + // the server stop waiting. Kept under streamLock so the + // (non-thread-safe) StreamObserver is only ever touched by + // one thread at a time. + try { + requestObserver.onError( + new RuntimeException("session closed before normal termination")) + } catch { + case NonFatal(_) => // stream already terminated + } + case NonFatal(_) => + // onCompleted threw after onNext succeeded. The terminal call + // is ambiguous; per the gRPC StreamObserver contract we must + // not also call onError, which would be a second terminator. + } + } + } + + private def sendDataBatch(data: Array[Byte]): Unit = { + val request = UdfRequest.newBuilder() + .setData(DataRequest.newBuilder().setData(ByteString.copyFrom(data))) + .build() + streamLock.synchronized { + if (!streamFinished.get()) { + requestObserver.onNext(request) + } + } + } + + /** + * Sends Finish and completes the client stream. The `streamFinished` + * CAS ensures only one code path (normal completion, cancel, or close) + * wins and actually touches the observer. + */ + private def sendFinishAndComplete(): Unit = { + if (streamFinished.compareAndSet(false, true)) { + streamLock.synchronized { + val finish = UdfRequest.newBuilder() + .setControl(UdfControlRequest.newBuilder() + .setFinish(Finish.getDefaultInstance)) + .build() + requestObserver.onNext(finish) + requestObserver.onCompleted() + } + } + } + + /** + * Terminates the stream with a client-side error. Used when the input + * iterator throws, so the worker can clean up rather than wait for + * data that will never arrive. + */ + private def finishStreamWithError(cause: Throwable): Unit = { + if (streamFinished.compareAndSet(false, true)) { + streamLock.synchronized { + try { + requestObserver.onError(cause) + } catch { + case NonFatal(_) => // stream already terminated + } + } + } + } + + /** + * Throws the worker's reported error, if any. Called before each send + * so we don't push data into a stream the worker has already failed. + */ + private def throwIfErrored(): Unit = { + val err = workerError + if (err != null) throw err + } + + /** + * Blocks until a control response of `expected` type arrives from the + * worker, or the timeout expires. + * + * A data response arriving before the control ack is a protocol + * violation (the worker must ack the matching control before any data). + * A control response of the wrong type (e.g. `FinishResponse` arriving + * where `InitResponse` is expected) is also a protocol violation; both + * are surfaced as errors rather than silently accepted. + * + * Uses `nanoTime` for a monotonic deadline so an NTP step cannot make + * the wait spuriously time out (or never time out). 
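+   *
+   * (Example: with a 30s budget, an NTP step that jumps the wall clock
+   * backwards mid-wait has no effect here -- `nanoTime` keeps ticking
+   * forward, so the wait still expires roughly 30s after it started.)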
+ */ + private def awaitControlResponse(expected: UdfControlResponse.ControlCase): Unit = { + val timeoutNanos = TimeUnit.MILLISECONDS.toNanos(controlResponseTimeoutMs) + val deadlineNanos = System.nanoTime() + timeoutNanos + var remainingNanos = timeoutNanos + while (remainingNanos > 0L) { + val item = responseQueue.poll(remainingNanos, TimeUnit.NANOSECONDS) + if (item != null) { + item match { + case StreamResponse(response) if response.hasControl => + val actual = response.getControl.getControlCase + if (actual == expected) return + throw new RuntimeException( + s"Worker sent control response of type $actual where " + + s"$expected was expected (protocol violation)") + case StreamResponse(_) => + throw new RuntimeException( + "Worker sent a non-control response before the expected " + + "control response (protocol violation)") + case StreamError(t) => throw t + case StreamEnd => + throw new RuntimeException( + "Worker stream completed without sending expected control response") + } + } + remainingNanos = deadlineNanos - System.nanoTime() + } + val channelState = connection.channel.getState(false) + throw new RuntimeException( + s"Timed out after ${controlResponseTimeoutMs}ms waiting for control " + + s"response (channel state: $channelState)") + } +} + +private[grpc] object GrpcWorkerSession { + /** Default deadline for awaiting any single control response (Init/Finish/Cancel). */ + val DefaultControlResponseTimeoutMs: Long = 30000L +} diff --git a/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionFactory.scala b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionFactory.scala new file mode 100644 index 0000000000000..04af0c8af2849 --- /dev/null +++ b/udf/worker/grpc/src/main/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionFactory.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.grpc + +import java.util.concurrent.TimeUnit + +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder +import io.grpc.netty.shaded.io.netty.channel.epoll.{EpollDomainSocketChannel, EpollEventLoopGroup} +import io.grpc.netty.shaded.io.netty.channel.unix.DomainSocketAddress + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession, + WorkerSessionFactory} + +/** + * :: Experimental :: + * A [[WorkerSessionFactory]] that hands out gRPC connections (over UDS) + * and [[GrpcWorkerSession]] instances on top of them. 
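+ * Call [[close]] once the owning dispatcher is done with the factory;
+ * it tears down the shared Netty event loop group (see below).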
+ * + * Plug it into any direct dispatcher whose transport is UDS, e.g.: + * {{{ + * val dispatcher = new DirectUnixSocketWorkerDispatcher( + * spec, new GrpcWorkerSessionFactory) + * }}} + * + * '''Platform limitation:''' uses Linux epoll for UDS transport (via + * Netty's [[EpollDomainSocketChannel]]). It is '''Linux-only''' and + * will not work on macOS (which would require kqueue) or Windows. + * The [[EpollEventLoopGroup]] is created in the constructor, so + * instantiating this factory on a non-Linux host fails immediately + * with a `LinkageError` / native-library load failure -- not lazily + * at session-creation time. TCP/IP transport is also not currently + * supported. A single [[EpollEventLoopGroup]] is shared across all + * connections produced by this factory to avoid excessive thread + * creation. + * + * TODO: Support TCP/IP transport and macOS (kqueue). + */ +@Experimental +class GrpcWorkerSessionFactory extends WorkerSessionFactory { + + private val SHUTDOWN_TIMEOUT_MS = 5000L + + private val sharedEventLoopGroup = new EpollEventLoopGroup() + + override def createConnection(address: String): GrpcWorkerConnection = { + val channel = NettyChannelBuilder + .forAddress(new DomainSocketAddress(address)) + .eventLoopGroup(sharedEventLoopGroup) + .channelType(classOf[EpollDomainSocketChannel]) + .usePlaintext() + .build() + new GrpcWorkerConnection(channel, address) + } + + override def createSession(connection: WorkerConnection): WorkerSession = { + val grpc = connection match { + case g: GrpcWorkerConnection => g + case other => throw new IllegalArgumentException( + "GrpcWorkerSessionFactory requires a GrpcWorkerConnection, " + + s"got ${other.getClass.getName}") + } + new GrpcWorkerSession(grpc) + } + + override def close(): Unit = { + sharedEventLoopGroup + .shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS) + .await(SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } +} diff --git a/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/EchoUDFWorkerService.scala b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/EchoUDFWorkerService.scala new file mode 100644 index 0000000000000..ad7fb3ab2ae76 --- /dev/null +++ b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/EchoUDFWorkerService.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.udf.worker.core.grpc + +import java.util.concurrent.{ConcurrentLinkedQueue, TimeUnit} + +import scala.jdk.CollectionConverters._ + +import io.grpc.Server +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder +import io.grpc.netty.shaded.io.netty.channel.EventLoopGroup +import io.grpc.netty.shaded.io.netty.channel.epoll.{EpollEventLoopGroup, + EpollServerDomainSocketChannel} +import io.grpc.netty.shaded.io.netty.channel.unix.DomainSocketAddress +import io.grpc.stub.StreamObserver + +import org.apache.spark.udf.worker._ + +/** + * A dummy UDF worker gRPC service that echoes data back to the client. + * + * {{{ + * Caller (GrpcWorkerSession) gRPC I/O EchoUDFWorkerService + * | | | + * |--- onNext(Init) -------------|----> Init ------------>| + * | |<----- InitResponse ----| + * |--- onNext(Data) -------------|----> Data ------------>| + * | |<----- Data (echo) -----| + * | ... repeats per input batch ... | + * |--- onNext(Finish) -----------|----> Finish ---------->| + * | |<----- FinishResponse --| + * |--- onCompleted ------- | | + * | |<----- onCompleted -----| + * }}} + * + * Each inbound request produces exactly one outbound response synchronously + * on the gRPC I/O thread, in arrival order. There is no UDF logic and no + * buffering. + * + * Limits / non-features: + * - Permissive: accepts any wire order; does not enforce the proto's + * "Init -> Data* -> Finish|Cancel" sequence, and silently drops requests + * that match no branch (the proto says receivers MUST reject these). + * - On `onError`, the throwable is recorded for [[assertNoErrors()]] but + * is not propagated to the response stream; callers may need to forcibly + * shut the channel down for cleanup. + * - [[assertNoErrors()]] is the only assertion hook -- it does not verify + * that anything was received, only that nothing errored. + * + * Tests should call [[assertNoErrors()]] after exercising the stream to + * verify no gRPC errors were silently swallowed on the server side. + */ +class EchoUDFWorkerService extends WorkerGrpc.WorkerImplBase { + + private val errors = new ConcurrentLinkedQueue[Throwable]() + + // Recording surface for tests. Captures every UdfRequest the service + // received, in arrival order, so tests can assert wire-order and field + // round-trip without changing the echo behaviour. + private val received = new ConcurrentLinkedQueue[UdfRequest]() + + /** + * Asserts that no gRPC errors were received on the server side. + * Call this at the end of each test to catch errors that might + * otherwise go unnoticed (e.g., stream resets, serialization failures). + * Reports all captured errors so failures aren't masked when more than + * one stream fails in the same test. + */ + // scalastyle:off throwerror + def assertNoErrors(): Unit = { + val snapshot = errors.iterator().asScala.toList + if (snapshot.nonEmpty) { + val summary = snapshot.map(_.getMessage).mkString("; ") + throw new AssertionError( + s"Echo service received ${snapshot.size} error(s): $summary", snapshot.head) + } + } + // scalastyle:on throwerror + + /** Snapshot of every [[UdfRequest]] received, in arrival order. */ + def receivedRequests: List[UdfRequest] = received.iterator().asScala.toList + + /** Snapshot of the request-case sequence (Init / Data / Control{Finish,Cancel}). 
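+   * Strictly, the [[UdfRequest]] oneof cases are CONTROL and DATA;
+   * Init / Finish / Cancel are nested control cases, so Finish and
+   * Cancel both surface here as CONTROL.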
*/ + def receivedSequence: List[UdfRequest.RequestCase] = + received.iterator().asScala.map(_.getRequestCase).toList + + override def execute( + responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = { + new StreamObserver[UdfRequest] { + override def onNext(request: UdfRequest): Unit = { + received.add(request) + if (request.hasControl) { + val ctrl = request.getControl + if (ctrl.hasInit) { + responseObserver.onNext( + UdfResponse.newBuilder() + .setControl(UdfControlResponse.newBuilder() + .setInit(InitResponse.getDefaultInstance)) + .build()) + } else if (ctrl.hasFinish) { + responseObserver.onNext( + UdfResponse.newBuilder() + .setControl(UdfControlResponse.newBuilder() + .setFinish(FinishResponse.getDefaultInstance)) + .build()) + } else if (ctrl.hasCancel) { + responseObserver.onNext( + UdfResponse.newBuilder() + .setControl(UdfControlResponse.newBuilder() + .setCancel(CancelResponse.getDefaultInstance)) + .build()) + } + } else if (request.hasData) { + val dataReq = request.getData + responseObserver.onNext( + UdfResponse.newBuilder() + .setData(DataResponse.newBuilder().setData(dataReq.getData)) + .build()) + } + } + + override def onError(t: Throwable): Unit = { + errors.add(t) + } + + override def onCompleted(): Unit = { + responseObserver.onCompleted() + } + } + } +} + +/** + * A test fixture that swallows everything: never sends a response, never + * calls onCompleted. Used to exercise the [[GrpcWorkerSession]] init + * timeout path without waiting on real network failure modes. + */ +class IgnoringInitUDFWorkerService extends WorkerGrpc.WorkerImplBase { + override def execute( + responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = { + new StreamObserver[UdfRequest] { + override def onNext(request: UdfRequest): Unit = () + override def onError(t: Throwable): Unit = () + override def onCompleted(): Unit = () + } + } +} + +/** + * A test fixture that acks Init but emits an `onError` on the first + * [[DataRequest]] it sees. Used to exercise the worker-error-mid-stream + * path that surfaces via [[GrpcWorkerSession]]'s `workerError` and the + * `StreamError` queue event. + */ +class FailingMidStreamUDFWorkerService(failureMessage: String) + extends WorkerGrpc.WorkerImplBase { + override def execute( + responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = { + new StreamObserver[UdfRequest] { + override def onNext(request: UdfRequest): Unit = { + if (request.hasControl && request.getControl.hasInit) { + responseObserver.onNext( + UdfResponse.newBuilder() + .setControl(UdfControlResponse.newBuilder() + .setInit(InitResponse.getDefaultInstance)) + .build()) + } else if (request.hasData) { + responseObserver.onError(new RuntimeException(failureMessage)) + } + } + override def onError(t: Throwable): Unit = () + override def onCompleted(): Unit = () + } + } +} + +/** + * Result of starting an [[EchoUDFWorkerService]], providing access to + * both the gRPC server and the service instance for test assertions. + */ +case class EchoServerHandle( + server: Server, + service: EchoUDFWorkerService, + bossGroup: EventLoopGroup, + workerGroup: EventLoopGroup) + +/** + * Generic gRPC server handle used by misbehaving-service fixtures + * ([[IgnoringInitUDFWorkerService]], [[FailingMidStreamUDFWorkerService]]). + * No service-specific accessors -- tests interact via [[GrpcWorkerSession]] + * behaviour rather than the server. 
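+ * Tear down with [[EchoUDFWorkerServer.stop]], which shuts down the
+ * server and both event loop groups.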
+ */ +case class GrpcServerHandle( + server: Server, + bossGroup: EventLoopGroup, + workerGroup: EventLoopGroup) + +/** + * Helper to start/stop an [[EchoUDFWorkerService]] on a Unix domain socket. + */ +object EchoUDFWorkerServer { + + private val SHUTDOWN_TIMEOUT_MS = 5000L + + def start(socketPath: String): EchoServerHandle = { + val service = new EchoUDFWorkerService + val bossGroup = new EpollEventLoopGroup() + val workerGroup = new EpollEventLoopGroup() + val server = NettyServerBuilder + .forAddress(new DomainSocketAddress(socketPath)) + .channelType(classOf[EpollServerDomainSocketChannel]) + .bossEventLoopGroup(bossGroup) + .workerEventLoopGroup(workerGroup) + .addService(service) + .build() + .start() + EchoServerHandle(server, service, bossGroup, workerGroup) + } + + /** Start any [[WorkerGrpc.WorkerImplBase]] on a UDS path. */ + def startWith( + service: WorkerGrpc.WorkerImplBase, + socketPath: String): GrpcServerHandle = { + val bossGroup = new EpollEventLoopGroup() + val workerGroup = new EpollEventLoopGroup() + val server = NettyServerBuilder + .forAddress(new DomainSocketAddress(socketPath)) + .channelType(classOf[EpollServerDomainSocketChannel]) + .bossEventLoopGroup(bossGroup) + .workerEventLoopGroup(workerGroup) + .addService(service) + .build() + .start() + GrpcServerHandle(server, bossGroup, workerGroup) + } + + def stop(handle: GrpcServerHandle): Unit = { + if (handle != null && handle.server != null) { + handle.server.shutdownNow() + handle.server.awaitTermination(5, TimeUnit.SECONDS) + handle.bossGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, + TimeUnit.MILLISECONDS) + handle.workerGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, + TimeUnit.MILLISECONDS) + } + } + + def stop(handle: EchoServerHandle): Unit = { + if (handle != null && handle.server != null) { + handle.server.shutdownNow() + handle.server.awaitTermination(5, TimeUnit.SECONDS) + handle.bossGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, + TimeUnit.MILLISECONDS) + handle.workerGroup.shutdownGracefully(0, SHUTDOWN_TIMEOUT_MS, + TimeUnit.MILLISECONDS) + } + } +} diff --git a/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcDirectWorkerDispatcherSuite.scala b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcDirectWorkerDispatcherSuite.scala new file mode 100644 index 0000000000000..c93942d469200 --- /dev/null +++ b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcDirectWorkerDispatcherSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.udf.worker.core.grpc
+
+import com.google.protobuf.ByteString
+
+// scalastyle:off funsuite
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.udf.worker.{
+  DirectWorker, Init, ProcessCallable, UdfWorkerDataFormat, UDFWorkerProperties,
+  UDFWorkerSpecification, UdfPayload, UnixDomainSocket,
+  WorkerConnectionSpec => WorkerConnectionProto}
+import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
+  DirectWorkerSession}
+
+/**
+ * A [[GrpcWorkerSessionFactory]] that, on each [[createConnection]],
+ * also spins up an in-process [[EchoUDFWorkerService]] bound to the
+ * same UDS path. This lets us test the full
+ * dispatcher -> session -> gRPC data round-trip without requiring an
+ * external worker binary.
+ *
+ * '''Design note:''' The bash worker script (`touch` + `sleep` loop)
+ * serves only as a lifecycle placeholder to satisfy the
+ * process-spawning logic in [[DirectUnixSocketWorkerDispatcher]]. It does NOT
+ * serve gRPC itself; the in-process echo server bound here does.
+ */
+class TestEchoBackedFactory extends GrpcWorkerSessionFactory {
+
+  var lastHandle: EchoServerHandle = _
+
+  override def createConnection(address: String): GrpcWorkerConnection = {
+    new java.io.File(address).delete()
+    lastHandle = EchoUDFWorkerServer.start(address)
+    super.createConnection(address)
+  }
+}
+
+/**
+ * Test-only [[DirectUnixSocketWorkerDispatcher]] that places its socket
+ * directory under `/tmp` rather than `java.io.tmpdir`. Linux caps UDS
+ * paths at 108 chars, and Spark's parent POM forces
+ * `java.io.tmpdir=target/tmp`, which on a normal checkout already pushes
+ * the path past the cap. `Files.createTempDirectory` reads
+ * `java.io.tmpdir` from a `static final` cache set at JVM start, so a
+ * test-time `System.setProperty` cannot recover; we route around it by
+ * overriding the protected hook the dispatcher exposes for testing.
+ */
+class TestDirectUnixSocketWorkerDispatcher(
+    spec: org.apache.spark.udf.worker.UDFWorkerSpecification,
+    factory: GrpcWorkerSessionFactory)
+  extends DirectUnixSocketWorkerDispatcher(spec, factory) {
+
+  override protected def createPrivateTempDirectory(): java.nio.file.Path = {
+    val attr = java.nio.file.attribute.PosixFilePermissions.asFileAttribute(
+      java.nio.file.attribute.PosixFilePermissions.fromString("rwx------"))
+    java.nio.file.Files.createTempDirectory(
+      java.nio.file.Paths.get("/tmp"), "g", attr)
+  }
+}
+
+/**
+ * End-to-end tests: UDFWorkerSpecification -> DirectUnixSocketWorkerDispatcher
+ * (with [[GrpcWorkerSessionFactory]]) -> [[GrpcWorkerSession]] -> data
+ * round-trip over UDS.
+ */
+class GrpcDirectWorkerDispatcherSuite
+  extends AnyFunSuite with BeforeAndAfterEach {
+// scalastyle:on funsuite
+
+  private val workerScript =
+    """
+      |#!/bin/bash
+      |SOCKET_PATH=""
+      |while [[ $# -gt 0 ]]; do
+      |  case "$1" in
+      |    --connection) SOCKET_PATH="$2"; shift 2 ;;
+      |    *) shift ;;
+      |  esac
+      |done
+      |cleanup() { rm -f "$SOCKET_PATH"; exit 0; }
+      |trap cleanup SIGTERM
+      |touch "$SOCKET_PATH"
+      |while true; do sleep 1; done
+    """.stripMargin.trim
+
+  private var dispatcher: DirectUnixSocketWorkerDispatcher = _
+  private var factory: TestEchoBackedFactory = _
+
+  override def afterEach(): Unit = {
+    if (dispatcher != null) {
+      // Assert no silent gRPC errors before tearing down.
+ if (factory != null && factory.lastHandle != null) { + factory.lastHandle.service.assertNoErrors() + EchoUDFWorkerServer.stop(factory.lastHandle) + } + dispatcher.close() + dispatcher = null + factory = null + } + super.afterEach() + } + + private def buildSpec(): UDFWorkerSpecification = { + val runner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand(workerScript).addCommand("--") + .build() + val properties = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionProto.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance) + .build()) + .build() + UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(runner) + .setProperties(properties) + .build()) + .build() + } + + /** Minimal Init proto sufficient for the echo server. */ + private def buildInit( + payload: Array[Byte] = Array.emptyByteArray, + sessionConf: Map[String, String] = Map.empty): Init = { + val udf = UdfPayload.newBuilder() + .setPayload(ByteString.copyFrom(payload)) + .setFormat("test") + .build() + val builder = Init.newBuilder() + .setUdf(udf) + .setDataFormat(UdfWorkerDataFormat.ARROW) + sessionConf.foreach { case (k, v) => builder.putSessionConf(k, v) } + builder.build() + } + + private def setup(): DirectUnixSocketWorkerDispatcher = { + factory = new TestEchoBackedFactory + dispatcher = new TestDirectUnixSocketWorkerDispatcher(buildSpec(), factory) + dispatcher + } + + test("end-to-end: spec -> dispatcher -> session -> process over UDS") { + val disp = setup() + + val session = disp.createSession(None) + try { + session.init(buildInit( + payload = "my-udf".getBytes, + sessionConf = Map("mode" -> "test"))) + + val input = Iterator("batch-1".getBytes, "batch-2".getBytes) + val output = session.process(input).toList + + assert(output.size == 2) + assert(new String(output(0)) == "batch-1") + assert(new String(output(1)) == "batch-2") + } finally { + session.close() + } + } + + test("end-to-end: dispatcher close terminates worker") { + val disp = setup() + + val session = disp.createSession(None) + session.init(buildInit()) + + val worker = session.asInstanceOf[DirectWorkerSession].workerProcess + assert(worker.process.isAlive) + + val output = session.process(Iterator("data".getBytes)).toList + assert(output.size == 1) + session.close() + + factory.lastHandle.service.assertNoErrors() + EchoUDFWorkerServer.stop(factory.lastHandle) + disp.close() + dispatcher = null + factory = null + + assert(!worker.process.isAlive, + "worker should be terminated after dispatcher close") + } +} diff --git a/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionSuite.scala b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionSuite.scala new file mode 100644 index 0000000000000..d7a2dc86c628e --- /dev/null +++ b/udf/worker/grpc/src/test/scala/org/apache/spark/udf/worker/core/grpc/GrpcWorkerSessionSuite.scala @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.grpc + +import java.nio.file.{Files, Paths} +import java.util.UUID +import java.util.concurrent.{CountDownLatch, FutureTask, TimeUnit} + +import com.google.protobuf.ByteString +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder +import io.grpc.netty.shaded.io.netty.channel.epoll.{EpollDomainSocketChannel, + EpollEventLoopGroup} +import io.grpc.netty.shaded.io.netty.channel.unix.DomainSocketAddress + +// scalastyle:off funsuite +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.udf.worker.{Init, UdfWorkerDataFormat, UdfPayload} + +/** + * Unit tests for [[GrpcWorkerSession]] using an [[EchoUDFWorkerService]] + * bound to a Unix domain socket. + */ +class GrpcWorkerSessionSuite + extends AnyFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +// scalastyle:on funsuite + + private var handle: EchoServerHandle = _ + private var grpcHandle: GrpcServerHandle = _ + private var connection: GrpcWorkerConnection = _ + private var socketPath: String = _ + + private val testEventLoopGroup = new EpollEventLoopGroup() + + private def generateSocketPath(): String = { + // Linux UDS paths are capped at 108 chars. Spark's parent POM forces + // java.io.tmpdir to `target/tmp`, which on a normal checkout already + // pushes the path past the cap. Use the OS tmpdir directly, plus a + // short prefix and a truncated UUID, to stay well under the limit. + val tmpDir = Files.createTempDirectory(Paths.get("/tmp"), "g") + tmpDir.resolve(s"e-${UUID.randomUUID().toString.take(8)}.sock").toString + } + + /** Minimal Init proto sufficient for the echo server. */ + private def buildInit( + payload: Array[Byte] = Array.emptyByteArray, + inputSchema: Array[Byte] = Array.emptyByteArray, + outputSchema: Array[Byte] = Array.emptyByteArray, + sessionConf: Map[String, String] = Map.empty): Init = { + val udf = UdfPayload.newBuilder() + .setPayload(ByteString.copyFrom(payload)) + .setFormat("test") + .build() + val builder = Init.newBuilder() + .setUdf(udf) + .setDataFormat(UdfWorkerDataFormat.ARROW) + if (inputSchema.nonEmpty) builder.setInputSchema(ByteString.copyFrom(inputSchema)) + if (outputSchema.nonEmpty) builder.setOutputSchema(ByteString.copyFrom(outputSchema)) + sessionConf.foreach { case (k, v) => builder.putSessionConf(k, v) } + builder.build() + } + + private def createSessionOnEchoServer(): GrpcWorkerSession = { + socketPath = generateSocketPath() + handle = EchoUDFWorkerServer.start(socketPath) + connection = newConnection(socketPath) + new GrpcWorkerSession(connection) + } + + /** + * Bind a custom [[org.apache.spark.udf.worker.WorkerGrpc.WorkerImplBase]] + * to a fresh UDS path and return a session against it. Used by the + * misbehaving-fixture tests (init timeout, worker error mid-stream). 
+ */ + private def createSessionWithService( + service: org.apache.spark.udf.worker.WorkerGrpc.WorkerImplBase, + controlResponseTimeoutMs: Long = + GrpcWorkerSession.DefaultControlResponseTimeoutMs): GrpcWorkerSession = { + socketPath = generateSocketPath() + grpcHandle = EchoUDFWorkerServer.startWith(service, socketPath) + connection = newConnection(socketPath) + new GrpcWorkerSession(connection, controlResponseTimeoutMs) + } + + private def newConnection(endpoint: String): GrpcWorkerConnection = { + val channel = NettyChannelBuilder + .forAddress(new DomainSocketAddress(endpoint)) + .eventLoopGroup(testEventLoopGroup) + .channelType(classOf[EpollDomainSocketChannel]) + .usePlaintext() + .build() + new GrpcWorkerConnection(channel, endpoint) + } + + override def afterEach(): Unit = { + // Assert no silent gRPC errors before tearing down. + if (handle != null) { + handle.service.assertNoErrors() + EchoUDFWorkerServer.stop(handle) + handle = null + } + if (grpcHandle != null) { + EchoUDFWorkerServer.stop(grpcHandle) + grpcHandle = null + } + if (connection != null) { + connection.close() + connection = null + } + super.afterEach() + } + + override def afterAll(): Unit = { + testEventLoopGroup.shutdownGracefully(0, 5000, TimeUnit.MILLISECONDS) + super.afterAll() + } + + test("init sends Init and receives InitResponse") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit( + payload = "test-function".getBytes, + inputSchema = "input-schema".getBytes, + outputSchema = "output-schema".getBytes, + sessionConf = Map("key" -> "value"))) + } finally { + session.close() + } + } + + test("process echoes single batch") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + + val input = Iterator("hello-world".getBytes) + val output = session.process(input).toList + + assert(output.size == 1) + assert(new String(output.head) == "hello-world") + } finally { + session.close() + } + } + + test("process echoes multiple batches") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + + val batches = (1 to 5).map(i => s"batch-$i".getBytes) + val output = session.process(batches.iterator).toList + + assert(output.size == 5) + output.zip(batches).foreach { case (result, expected) => + assert(result sameElements expected) + } + } finally { + session.close() + } + } + + test("process with empty input returns empty output") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + + val output = session.process(Iterator.empty).toList + assert(output.isEmpty) + } finally { + session.close() + } + } + + test("cancel aborts the session") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + session.cancel() + } finally { + session.close() + } + } + + // -- Error-path tests ------------------------------------------------------- + + test("init called twice throws") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + val ex = intercept[IllegalStateException] { + session.init(buildInit()) + } + assert(ex.getMessage.contains("init has already been called")) + } finally { + session.close() + } + } + + test("process called twice throws") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + session.process(Iterator.empty).toList + val ex = intercept[IllegalStateException] { + session.process(Iterator.empty) + } + assert(ex.getMessage.contains("process has already been called")) + } finally { + session.close() + } + } + + 
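+  // Illustrative sketch (not a test): the concurrency pattern the cancel
+  // tests below pin down. A task-interruption listener may call cancel()
+  // from any thread while the caller thread drives process(); `listener`
+  // is a hypothetical hook, not part of this module.
+  //
+  //   listener.onTaskInterrupted { () => session.cancel() }  // any thread
+  //   val out = session.process(batches)                     // caller thread
+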
test("cancel after init does not throw") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + session.cancel() + // close() after cancel should be safe + } finally { + session.close() + } + } + + test("cancel before init is a no-op") { + // Per WorkerSession's lifecycle contract, cancel before init must not + // touch the wire (sending Cancel as the first request would violate the + // proto's `Init -> ... -> Cancel|Finish` ordering). The session must + // still close cleanly afterwards. + val session = createSessionOnEchoServer() + try { + session.cancel() + // No init was sent; the echo server should not have observed any + // wire traffic, including Cancel. + handle.service.assertNoErrors() + } finally { + session.close() + } + } + + test("process after cancel throws CancellationException") { + // Per WorkerSession's lifecycle contract, a cancel between init and + // process must surface as an observable signal -- not a silently + // empty iterator that's indistinguishable from a generator UDF that + // produced nothing. + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + session.cancel() + intercept[java.util.concurrent.CancellationException] { + session.process(Iterator.empty) + } + } finally { + session.close() + } + } + + test("cancel during process from another thread does not hang") { + // Threading invariant: cancel() called from a task-interruption thread + // while process() is running on the caller thread must let process + // exit (cleanly, with empty/partial output, or by throwing). The + // contract permits either FinishResponse or CancelResponse; the only + // requirement is that the iterator does not block forever. + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + + // hasNext blocks on `release` after the first batch so cancel + // arrives while process() is in fillNextBatch. Bounded await so + // a regression cannot hang the test runner. + val inProcess = new CountDownLatch(1) + val release = new CountDownLatch(1) + val input = new Iterator[Array[Byte]] { + private var idx = 0 + override def hasNext: Boolean = { + if (idx == 1) { + inProcess.countDown() + release.await(2, TimeUnit.SECONDS) + } + idx < 5 + } + override def next(): Array[Byte] = { + val v = s"batch-$idx".getBytes + idx += 1 + v + } + } + + val task = new FutureTask[List[Array[Byte]]]( + () => session.process(input).toList) + val thread = new Thread(task, "process-thread") + thread.setDaemon(true) + thread.start() + + assert(inProcess.await(2, TimeUnit.SECONDS), + "process() did not reach the input iterator in time") + session.cancel() + release.countDown() + + // The future should resolve (return or throw) within the deadline. + // Either is acceptable per contract; the assertion here is that we + // did not hang. + try { + task.get(2, TimeUnit.SECONDS) + } catch { + case _: java.util.concurrent.ExecutionException => // also acceptable + case _: java.util.concurrent.TimeoutException => + fail("process() did not return within 2s after cancel") + } + assert(!thread.isAlive, "process thread should have exited after cancel") + } finally { + session.close() + } + } + + test("worker error mid-stream surfaces from process iterator") { + // gRPC translates a server-side `responseObserver.onError(...)` into a + // StatusRuntimeException on the client; the exact message wording is + // implementation-defined. We only assert the failure surfaces from the + // iterator (not silently dropped, not hung). 
+ val session = createSessionWithService( + new FailingMidStreamUDFWorkerService("worker-failed-mid-stream")) + try { + session.init(buildInit()) + intercept[Throwable] { + session.process(Iterator("data".getBytes)).toList + } + } finally { + session.close() + } + } + + test("init throws when worker never sends InitResponse") { + // Use a fixture that swallows Init and a small timeout so the test + // exercises the timeout path without waiting on real network failure. + val session = createSessionWithService( + new IgnoringInitUDFWorkerService, + controlResponseTimeoutMs = 250L) + try { + val ex = intercept[RuntimeException] { + session.init(buildInit()) + } + assert(ex.getMessage.contains("Timed out"), + s"expected timeout error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("channel state"), + "timeout error should report channel state for debuggability") + } finally { + session.close() + } + } + + test("Init fields round-trip to the worker intact") { + val session = createSessionOnEchoServer() + try { + val payload = "test-callable".getBytes + val inputSchema = "test-input-schema".getBytes + val outputSchema = "test-output-schema".getBytes + session.init(buildInit( + payload = payload, + inputSchema = inputSchema, + outputSchema = outputSchema, + sessionConf = Map("k1" -> "v1", "k2" -> "v2"))) + // Drain so Finish is on the wire and recorded. + session.process(Iterator.empty).toList + + val received = handle.service.receivedRequests + val initMsg = received.headOption + .flatMap(r => if (r.hasControl && r.getControl.hasInit) Some(r.getControl.getInit) + else None) + .getOrElse(fail("expected first received request to be Init")) + + assert(initMsg.getDataFormat == UdfWorkerDataFormat.ARROW) + assert(initMsg.getUdf.getPayload.toByteArray sameElements payload) + assert(initMsg.getUdf.getFormat == "test") + assert(initMsg.getInputSchema.toByteArray sameElements inputSchema) + assert(initMsg.getOutputSchema.toByteArray sameElements outputSchema) + assert(initMsg.getSessionConfMap.get("k1") == "v1") + assert(initMsg.getSessionConfMap.get("k2") == "v2") + } finally { + session.close() + } + } + + test("request wire order is Init -> Data* -> Finish") { + val session = createSessionOnEchoServer() + try { + session.init(buildInit()) + val batches = (1 to 3).map(i => s"b$i".getBytes) + session.process(batches.iterator).toList + + val received = handle.service.receivedRequests + assert(received.nonEmpty, "expected at least Init + Finish") + // First must be Init. + assert(received.head.hasControl && received.head.getControl.hasInit, + s"first request was not Init: ${received.head}") + // Last must be Finish (not Cancel) on the happy path. + assert(received.last.hasControl && received.last.getControl.hasFinish, + s"last request was not Finish: ${received.last}") + // Middle entries are all Data, in input order. 
+      val middle = received.drop(1).dropRight(1)
+      assert(middle.size == 3, s"expected 3 Data messages, got ${middle.size}")
+      middle.zip(batches).foreach { case (req, expected) =>
+        assert(req.hasData, s"middle request was not Data: $req")
+        assert(req.getData.getData.toByteArray sameElements expected)
+      }
+    } finally {
+      session.close()
+    }
+  }
+}
diff --git a/udf/worker/proto/pom.xml b/udf/worker/proto/pom.xml
index 6850ae6938a39..843d843719bd9 100644
--- a/udf/worker/proto/pom.xml
+++ b/udf/worker/proto/pom.xml
@@ -51,32 +51,28 @@
     <dependency>
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-library</artifactId>
     </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-api</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-protobuf</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.grpc</groupId>
+      <artifactId>grpc-stub</artifactId>
+      <version>${io.grpc.version}</version>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
-      <plugin>
-        <groupId>com.github.os72</groupId>
-        <artifactId>protoc-jar-maven-plugin</artifactId>
-        <version>${protoc-jar-maven-plugin.version}</version>
-        <executions>
-          <execution>
-            <phase>generate-sources</phase>
-            <goals>
-              <goal>run</goal>
-            </goals>
-            <configuration>
-              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
-              <protocVersion>${protobuf.version}</protocVersion>
-              <inputDirectories>
-                <include>src/main/protobuf</include>
-              </inputDirectories>
-            </configuration>
-          </execution>
-        </executions>
-      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-shade-plugin</artifactId>
@@ -113,4 +109,86 @@
     </plugins>
   </build>
+
+  <profiles>
+    <profile>
+      <id>default-protoc</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+      </activation>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>eu.maveniverse.maven.plugins</groupId>
+            <artifactId>nisse-plugin3</artifactId>
+            <version>0.7.0</version>
+            <executions>
+              <execution>
+                <id>set-os-detector-properties</id>
+                <goals>
+                  <goal>inject-properties</goal>
+                </goals>
+                <phase>validate</phase>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <groupId>org.xolstice.maven.plugins</groupId>
+            <artifactId>protobuf-maven-plugin</artifactId>
+            <version>0.6.1</version>
+            <configuration>
+              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+              <pluginId>grpc-java</pluginId>
+              <pluginArtifact>io.grpc:protoc-gen-grpc-java:${io.grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
+              <protoSourceRoot>src/main/protobuf</protoSourceRoot>
+            </configuration>
+            <executions>
+              <execution>
+                <goals>
+                  <goal>compile</goal>
+                  <goal>compile-custom</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+    <profile>
+      <id>user-defined-protoc</id>
+      <properties>
+        <spark.protoc.executable.path>${env.SPARK_PROTOC_EXEC_PATH}</spark.protoc.executable.path>
+        <connect.plugin.executable.path>${env.CONNECT_PLUGIN_EXEC_PATH}</connect.plugin.executable.path>
+      </properties>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.xolstice.maven.plugins</groupId>
+            <artifactId>protobuf-maven-plugin</artifactId>
+            <version>0.6.1</version>
+            <configuration>
+              <protocExecutable>${spark.protoc.executable.path}</protocExecutable>
+              <pluginId>grpc-java</pluginId>
+              <pluginExecutable>${connect.plugin.executable.path}</pluginExecutable>
+              <protoSourceRoot>src/main/protobuf</protoSourceRoot>
+            </configuration>
+            <executions>
+              <execution>
+                <goals>
+                  <goal>compile</goal>
+                  <goal>compile-custom</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+  </profiles>
diff --git a/udf/worker/proto/src/main/protobuf/common.proto b/udf/worker/proto/src/main/protobuf/common.proto
index ee032def73efe..c08becb082d4e 100644
--- a/udf/worker/proto/src/main/protobuf/common.proto
+++ b/udf/worker/proto/src/main/protobuf/common.proto
@@ -24,9 +24,9 @@ option java_package = "org.apache.spark.udf.worker";
 option java_multiple_files = true;
 
 // The UDF in & output data format.
-enum UDFWorkerDataFormat {
+enum UdfWorkerDataFormat {
   UDF_WORKER_DATA_FORMAT_UNSPECIFIED = 0;
-
+  // The worker accepts and produces Apache arrow batches.
   ARROW = 1;
 }
@@ -39,10 +39,10 @@
 // framing their phases as messages on the stream, but that is a design
 // question worth revisiting as additional UDF types are added -- for
 // example, aggregation may prefer a multi-round or specialized pattern.
-enum UDFProtoCommunicationPattern {
+enum UdfProtoCommunicationPattern {
   UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED = 0;
 
-  // Data exachanged as a bidrectional
+  // Data exchanged as a bidirectional
   // stream of bytes.
   BIDIRECTIONAL_STREAMING = 1;
 }
diff --git a/udf/worker/proto/src/main/protobuf/udf_protocol.proto b/udf/worker/proto/src/main/protobuf/udf_protocol.proto
new file mode 100644
index 0000000000000..dcd7746333dcf
--- /dev/null
+++ b/udf/worker/proto/src/main/protobuf/udf_protocol.proto
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +import "common.proto"; + +package org.apache.spark.udf.worker; + +option java_package = "org.apache.spark.udf.worker"; +option java_multiple_files = true; + +// ===================================================================== +// Language-agnostic UDF execution protocol. +// +// The Spark engine acts as the gRPC client; a UDF worker (in any +// language) acts as the gRPC server. +// ===================================================================== + +// The default UDF gRPC service. A worker that exposes this service +// MUST do so over the default connection of the worker specification. +// +// In future, additional connections (e.g. a separate channel) may be +// reserved by the worker spec for other purposes. +service Worker { + // Per-execution stream. Exactly one [[Init]] is sent first, followed + // by 0..N data batches in either direction, terminated by exactly + // one [[Finish]] or [[Cancel]] from the engine. The worker MUST + // respond with the matching Init / Finish / Cancel responses on the + // response stream. + // + // For stateful execution, the state is maintained per bi-directional + // stream, mapping to a `WorkerSession` on the engine side + // (`org.apache.spark.udf.worker.core.WorkerSession`). + rpc Execute(stream UdfRequest) returns (stream UdfResponse); + + // Worker-scoped management RPC, independent of any per-execution + // stream. Used for heartbeat, capability query, and graceful + // shutdown. Kept unary so it does not depend on the lifecycle of an + // active Execute stream. + rpc Manage(WorkerRequest) returns (WorkerResponse); +} + +// ===================================================================== +// Execute stream -- envelope +// ===================================================================== + +// Engine -> Worker. Either a control message ([[Init]] / [[PayloadChunk]] +// / [[Finish]] / [[Cancel]]) or a data message. +message UdfRequest { + // Exactly one branch MUST be set; receivers MUST reject messages + // with no branch set. + oneof request { + UdfControlRequest control = 1; + DataRequest data = 2; + } +} + +// Worker -> Engine. Either a control response ([[InitResponse]] / +// [[FinishResponse]] / [[CancelResponse]]) or a data response message. +message UdfResponse { + // Exactly one branch MUST be set; receivers MUST reject messages + // with no branch set. + oneof response { + UdfControlResponse control = 1; + DataResponse data = 2; + } +} + +// Engine -> Worker control messages. +// +// Wire order on an Execute stream is exactly: +// Init { ... } +// PayloadChunk { ... }* // optional; 0..N chunks, only used when +// // the single UDF payload on Init is too +// // large to fit inline. +// ( DataRequest | )* +// Finish { ... } OR Cancel { ... 
} // exactly one terminator +// +// The worker MUST emit [[InitResponse]] before sending any +// [[DataResponse]], and MUST emit exactly one [[FinishResponse]] or +// [[CancelResponse]] before closing the response stream. +// +// A worker that receives messages out of this order (e.g. a second Init, +// a PayloadChunk after the first DataRequest, a DataRequest before Init, +// or a Cancel before Init) MUST close the stream with an error. +message UdfControlRequest { + // Exactly one branch MUST be set; receivers MUST reject messages + // with no branch set. + oneof control { + Init init = 1; + PayloadChunk payload = 2; + Finish finish = 3; + Cancel cancel = 4; + } +} + +// Worker -> Engine control messages. +message UdfControlResponse { + // Exactly one branch MUST be set; receivers MUST reject messages + // with no branch set. + oneof control { + InitResponse init = 1; + FinishResponse finish = 2; + CancelResponse cancel = 3; + } +} + +// ===================================================================== +// Init phase +// ===================================================================== + +// Sent once, as the first message on an Execute stream. Describes +// the UDF body to run plus the minimum metadata the worker needs to +// start processing it. +// +// Today the protocol mandates exactly one Init per UDF execution +// (one Init -> data -> Finish). This is the simplest contract and +// covers all currently supported UDF kinds. In the future we may +// evolve to support multiple init phases on the same stream -- e.g. +// when worker setup requires an interactive handshake (negotiate a +// schema, exchange capabilities, fetch driver-side metadata, ...) +// before the data plane opens. Such an extension would be additive +// and would not change the single-Init semantics already in use. +// +// Engine vs. client split: +// * Most fields on Init are engine-side. They describe what +// flows on the wire for this session ([[data_format]] / +// [[input_schema]] / [[output_schema]] -- matching the worker +// spec, not the function's view) and what per-session +// context the worker needs ([[timezone]], [[session_conf]], +// [[task_context]], [[parameters]]). +// * [[UdfPayload]] carries everything the client side of Spark +// (where the UDF is defined and serialized) packs -- the +// serialized callable, an opaque format tag, and any encoder +// metadata bundled with the callable. The wire protocol does +// not enumerate encoder shapes; that is left to the client and +// worker to agree on per UDF type. +message Init { + // (Required) Wire format used for [[DataRequest.data]] and + // [[DataResponse.data]] for the life of this session. Must be + // one of the formats the worker declared in + // [[WorkerCapabilities.supported_data_formats]]; the client side + // of the protocol picks one at planning time and sticks with it. + // Receivers MUST reject an Init whose [[data_format]] is + // `UDF_WORKER_DATA_FORMAT_UNSPECIFIED`. + UdfWorkerDataFormat data_format = 1; + + // (Required) The UDF body to execute on the worker for this + // session. Exactly one payload per Execute stream. + UdfPayload udf = 2; + + // (Optional) Schema of the input data plane in the wire format + // declared by [[data_format]] -- e.g. an Arrow IPC schema when + // data_format = ARROW. This is an engine-side requirement: it + // describes the bytes the engine will actually put on + // [[DataRequest.data]] for this session, matching what the + // worker advertised in its spec. 
It is NOT necessarily the
+  // schema the function definer expressed; the UDF's own type
+  // information lives inside [[UdfPayload]], typically embedded
+  // alongside the callable in [[UdfPayload.payload]] (e.g. as
+  // input/output encoders chosen per UDF type).
+  //
+  // Left unset when the worker can derive the schema from the
+  // payload alone.
+  optional bytes input_schema = 3;
+
+  // (Optional) Schema of the output data plane in the wire format
+  // declared by [[data_format]]. Same semantics as
+  // [[input_schema]] -- engine-side requirement describing the
+  // bytes the engine expects on [[DataResponse.data]].
+  optional bytes output_schema = 4;
+
+  // (Optional; defaults to an empty map.) Per-task context
+  // provided by the engine. Common keys identify the task instance
+  // for diagnostics, logging, and stateful workers -- e.g.
+  // partition id, task attempt id, stage id, micro-batch id.
+  // Engine and worker agree on the keys they share; the protocol
+  // does not enumerate them.
+  map<string, string> task_context = 5;
+
+  // (Optional; defaults to an empty map.) Worker-private knobs not
+  // already captured by typed fields above. Free-form; both sides
+  // agree on the keys they need.
+  //
+  // Any key that two languages converge on is a candidate for
+  // promotion to a structured proto field -- once promoted, it gets
+  // a typed field number from the reserved range right after this
+  // block and is removed from [[session_conf]]. [[timezone]] below
+  // is an example of a key that has already been promoted.
+  map<string, string> session_conf = 6;
+
+  // (Optional) Session timezone, promoted out of [[session_conf]]
+  // because every eval needs it for timestamp encoding/decoding.
+  //
+  // Format follows Spark's `spark.sql.session.timeZone` config --
+  // typically an IANA TZ id (e.g. "America/Los_Angeles") or a
+  // fixed offset (e.g. "+08:00"). The engine MUST pass the value
+  // it would resolve from the session conf without further
+  // transformation, so the worker can interpret it the same way
+  // Spark does.
+  optional string timezone = 7;
+
+  // Reserved for future typed Init fields, in particular keys
+  // graduated from [[session_conf]] (see the [[timezone]] precedent
+  // above). Numbers >= 100 are intentionally NOT reserved here; if
+  // a future revision needs an opaque escape-hatch field, give it a
+  // number >= 100 alongside [[parameters]] and add a field-level
+  // comment so the convention stays visible.
+  reserved 8 to 99;
+
+  // (Optional) Engine-packed opaque parameters specific to a
+  // particular kind of UDF execution. The escape hatch for
+  // anything the engine needs the worker to see at init time
+  // that is not already captured by the typed fields above and
+  // does not fit naturally into [[task_context]]. The encoding
+  // is agreed between the engine and the worker; the protocol
+  // does not interpret it. The matching response, also opaque
+  // bytes, is returned via [[InitResponse.data]].
+  //
+  // Numbers >= 100 are reserved by convention for opaque
+  // escape-hatch fields like this one; new typed fields use the
+  // reserved 8..99 range.
+  //
+  // Client-side init data (anything packed by the layer that
+  // defines and serializes the UDF) does NOT belong here -- it
+  // travels inside [[UdfPayload.payload]] instead.
+  optional bytes parameters = 100;
+}
+
+// Acknowledgment for [[Init]]. The worker MUST send exactly one
+// [[InitResponse]] before any [[DataResponse]].
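+//
+// A minimal exchange, assuming a worker with no init-time payload:
+//   engine -> Init { data_format: ARROW, udf: ... }
+//   worker -> InitResponse {}          (data left unset)
+//   ... data phase follows ...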
+// +// The init phase allows the engine to interact with the UDF before +// data starts flowing -- the worker can return inline bytes here for +// the engine (or higher-level code on the engine side) to consume +// during setup. The semantics of those bytes are agreed between the +// client side of the protocol and the worker; this message itself is +// otherwise opaque. +message InitResponse { + // (Optional) Inline init result returned by the worker. Opaque + // to the protocol; the client side of the protocol and the + // worker agree on what (if anything) it carries. + optional bytes data = 1; +} + +// Optional. Used to stream the single UDF payload when it does not +// fit in a single gRPC message. The default is to send the payload +// inline on [[UdfPayload.payload]]; chunking is only needed when a +// payload exceeds the gRPC message size limit. +// +// When used, chunks are sent zero or more times after [[Init]] and +// before the first [[DataRequest]]. The worker concatenates the +// inline [[UdfPayload.payload]] (if any) followed by all chunks in +// arrival order to form the final payload. +// +// Chunks are part of the Init handshake, not standalone control +// messages: they extend [[Init.udf.payload]] and are not +// individually acknowledged. The single [[InitResponse]] covers +// Init plus all of its chunks together. +message PayloadChunk { + // (Required, non-empty.) Bytes appended to the [[Init.udf]] + // payload. + bytes data = 1; + + // (Optional) Set to true on the final chunk. Receivers MAY use + // this as an early signal that the payload is complete and + // decoding can begin; receivers that prefer to wait for the + // first [[DataRequest]] (which marks the end of the chunking + // phase) MAY ignore this. When unset, the receiver determines + // completeness by the arrival of the first [[DataRequest]]. + optional bool last = 2; +} + +// ===================================================================== +// Data phase +// +// `data` is intentionally a top-level `bytes` field on both request +// and response messages -- not nested inside a wrapper -- so that +// implementations can avoid an extra copy when reading or writing +// the payload. The wire format (Arrow IPC etc.) is declared once per +// session via [[Init.data_format]] and stays the same for the life +// of the stream. +// ===================================================================== + +// Engine -> Worker per-batch payload. +message DataRequest { + // (Required, non-empty.) Encoded data bytes for one batch in the + // session-declared format. + bytes data = 1; +} + +// Worker -> Engine per-batch payload. The worker emits zero or more +// [[DataResponse]]s between [[InitResponse]] and [[FinishResponse]] / +// [[CancelResponse]]. Sink-style UDFs (which consume input but +// produce no output rows on the data plane) emit exactly zero. +message DataResponse { + // (Required, non-empty.) Encoded data bytes for one batch in the + // session-declared format. + bytes data = 1; +} + +// ===================================================================== +// Finish / Cancel phase +// ===================================================================== + +// Sent by the engine when no more input batches will arrive. The +// worker MUST drain any remaining output, then emit +// [[FinishResponse]] and close the response stream. +// +// Exactly one of [[Finish]] or [[Cancel]] is sent per Execute stream; +// they are mutually exclusive. 
+// If the engine has already sent [[Finish]] it MUST NOT send
+// [[Cancel]] afterwards (and vice versa).
+message Finish {}
+
+// Worker -> Engine completion message. May carry summary metrics.
+message FinishResponse {
+  // Final metrics aggregated over the whole session (e.g. rows
+  // in/out, time per phase). Free-form; names are worker-defined.
+  map<string, int64> metrics = 1;
+
+  // (Optional) Inline finish result returned by the worker.
+  // Mirrors [[InitResponse.data]] -- the finish phase allows the
+  // engine to interact with the UDF after data has stopped
+  // flowing, with the worker returning opaque bytes the engine (or
+  // higher-level code) may consume during teardown. The semantics
+  // of those bytes are agreed between the client side of the
+  // protocol and the worker.
+  optional bytes data = 2;
+}
+
+// Engine -> Worker explicit cancel. Distinct from a gRPC stream error
+// so the worker can run cleanup deterministically (release file
+// handles, drop temp state, etc.). After receiving [[Cancel]] the
+// worker MUST stop emitting [[DataResponse]] messages, run cleanup,
+// and emit [[CancelResponse]] before closing.
+//
+// Exactly one of [[Finish]] or [[Cancel]] is sent per Execute stream;
+// see [[Finish]]. [[Cancel]] is the cooperative cancellation path;
+// gRPC-level stream errors are the involuntary fallback. If the
+// stream breaks before [[CancelResponse]] arrives, the engine
+// considers the worker uncancellable for this session and relies on
+// process-level cleanup.
+message Cancel {
+  // (Optional) Free-form reason for diagnostics.
+  optional string reason = 1;
+}
+
+// Worker -> Engine acknowledgment of [[Cancel]].
+message CancelResponse {}
+
+// The single UDF body delivered to the worker on [[Init]]. Opaque to
+// the engine: the engine forwards [[payload]] and [[format]]
+// unchanged, and the worker decodes them per the format the client
+// and worker have agreed on.
+message UdfPayload {
+  // (Required, may be empty when chunked.) Serialized UDF bundle,
+  // opaque to the engine. The encoding is declared in [[format]].
+  //
+  // The bundle is not necessarily just the serialized callable;
+  // it is up to the client side of the protocol and the worker to
+  // agree on what is packed inside it -- e.g. custom encoders for
+  // user-defined types, type hints, or any other metadata the
+  // worker needs to invoke the UDF.
+  //
+  // For payloads too large to fit in a single gRPC message, this
+  // field MAY be left empty (zero-length bytes) and the bytes
+  // delivered via the [[PayloadChunk]] mechanism instead. See
+  // [[PayloadChunk]] for chunking semantics.
+  bytes payload = 1;
+
+  // (Required, non-empty.) Format tag identifying the encoding of
+  // [[payload]] (e.g. "py-cloudpickle-v3", "wasm-v1"). The engine
+  // does not interpret this; the client side of the protocol and
+  // the worker agree on its meaning.
+  string format = 2;
+
+  // (Optional) Total payload size in bytes. Useful when chunked
+  // streaming is used so the worker can pre-allocate buffers.
+  optional int64 payload_size = 3;
+
+  // (Optional) Human-readable name for diagnostics and metrics.
+  optional string name = 4;
+
+  // (Optional) Worker / language-specific dispatch hint. A
+  // free-form string the worker uses to pick the code path that
+  // handles this payload. The protocol does not enumerate eval
+  // types because they are language-specific; the client side of
+  // the protocol and the worker agree on the namespace and the
+  // values.
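+  //
+  // Purely illustrative: a Python worker might dispatch on values
+  // like "scalar" vs. "grouped-map", while a WASM worker might key
+  // on an exported entry-point name; the protocol treats all of
+  // these as opaque strings.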
+  //
+  // When the worker can derive the eval type from the payload
+  // itself (embedded metadata, format tag, etc.), this field is
+  // left unset. Otherwise the client side of the protocol sets it
+  // explicitly.
+  optional string eval_type = 5;
+}
+
+// =====================================================================
+// Manage RPC -- worker-scoped operations independent of Execute
+// =====================================================================
+
+// Engine -> Worker. Wraps the manage operations in a oneof so the RPC
+// is a single typed call, leaving room for future operations
+// (capability query, profiling, ...).
+message WorkerRequest {
+  // Exactly one branch MUST be set; receivers MUST reject messages
+  // with no branch set.
+  oneof manage {
+    Heartbeat heartbeat = 1;
+    ShutdownRequest shutdown = 2;
+  }
+}
+
+// Worker -> Engine.
+message WorkerResponse {
+  // Exactly one branch MUST be set, mirroring the request oneof.
+  oneof manage {
+    HeartbeatResponse heartbeat = 1;
+    ShutdownResponse shutdown = 2;
+  }
+}
+
+// Liveness probe. The engine may send this periodically to detect a
+// hung worker. The worker SHOULD reply within a small bounded time.
+message Heartbeat {}
+
+// Acknowledgment for [[Heartbeat]].
+message HeartbeatResponse {}
+
+// Engine-initiated graceful shutdown request. Independent of SIGTERM
+// (which is the OS-level fallback) -- this lets the worker know the
+// engine has finished with it and intends no further Execute streams.
+message ShutdownRequest {
+  // (Optional) Free-form reason for diagnostics.
+  optional string reason = 1;
+}
+
+// Worker -> Engine acknowledgment of [[ShutdownRequest]].
+message ShutdownResponse {}
diff --git a/udf/worker/proto/src/main/protobuf/worker_spec.proto b/udf/worker/proto/src/main/protobuf/worker_spec.proto
index 83dac4f962e5f..6433d6169f412 100644
--- a/udf/worker/proto/src/main/protobuf/worker_spec.proto
+++ b/udf/worker/proto/src/main/protobuf/worker_spec.proto
@@ -124,7 +124,7 @@ message WorkerCapabilities {
   // is reported by the engine as part of the UDF protocol's init message.
   //
   // (Required)
-  repeated UDFWorkerDataFormat supported_data_formats = 1;
+  repeated UdfWorkerDataFormat supported_data_formats = 1;

   // Which UDF protocol communication patterns the worker
   // supports. This should list all supported patterns.
@@ -135,7 +135,7 @@ message WorkerCapabilities {
   // the query will fail during query planning.
   //
   // (Required)
-  repeated UDFProtoCommunicationPattern supported_communication_patterns = 2;
+  repeated UdfProtoCommunicationPattern supported_communication_patterns = 2;

   // Whether multiple, concurrent UDF
   // connections are supported by this worker