Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
<module>connector/protobuf</module>
<module>udf/worker/proto</module>
<module>udf/worker/core</module>
<module>udf/worker/grpc</module>
<!-- See additional modules enabled by profiles below -->
</modules>

Expand Down
37 changes: 23 additions & 14 deletions udf/worker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ WorkerDispatcher -- manages workers, creates sessions
|
v
WorkerSession -- one UDF execution
| 1. session.init(InitMessage(payload, inputSchema, outputSchema))
| 1. session.init(Init proto: udf payload + data format + schemas)
| 2. val results = session.process(inputBatches)
| 3. session.close()
```
Expand All @@ -34,19 +34,22 @@ provisioning service or daemon).
```
udf/worker/
├── proto/
│ worker_spec.proto -- UDFWorkerSpecification protobuf (+ generated Java classes)
│ common.proto -- shared enums (UDFWorkerDataFormat, etc.)
│ worker_spec.proto -- UDFWorkerSpecification protobuf
│ udf_protocol.proto -- UDF execution protocol (Init, UdfPayload, ...)
│ common.proto -- shared enums (UdfWorkerDataFormat, etc.)
└── core/ -- abstract interfaces
WorkerDispatcher.scala -- creates sessions, manages worker lifecycle
WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage
WorkerSession.scala -- per-UDF init/process/cancel/close
WorkerSessionFactory.scala -- protocol-level connection + session factory
WorkerConnection.scala -- transport channel abstraction
WorkerSecurityScope.scala -- security boundary for worker pooling
└── direct/ -- "direct" creation: local OS processes
DirectWorkerDispatcher.scala -- spawns processes, env lifecycle
DirectWorkerProcess.scala -- OS process + connection + UDS socket
DirectWorkerSession.scala -- session backed by a direct process
DirectWorkerSession.scala -- decorator: forwards to inner session,
releases worker ref-count on close
```

The `core/` package defines abstract interfaces that are independent of how
Expand Down Expand Up @@ -76,10 +79,12 @@ Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed.

```scala
import org.apache.spark.udf.worker.{
DirectWorker, ProcessCallable, UDFProtoCommunicationPattern,
UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification,
UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment}
DirectWorker, Init, ProcessCallable, UdfPayload,
UdfProtoCommunicationPattern, UdfWorkerDataFormat, UDFWorkerProperties,
UDFWorkerSpecification, UnixDomainSocket, WorkerCapabilities,
WorkerConnectionSpec, WorkerEnvironment}
import org.apache.spark.udf.worker.core._
import com.google.protobuf.ByteString

// 1. Define a worker spec (direct creation mode).
val spec = UDFWorkerSpecification.newBuilder()
Expand All @@ -90,9 +95,9 @@ val spec = UDFWorkerSpecification.newBuilder()
.addCommand("pip").addCommand("install").addCommand("my_udf_worker").build())
.build())
.setCapabilities(WorkerCapabilities.newBuilder()
.addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
.addSupportedDataFormats(UdfWorkerDataFormat.ARROW)
.addSupportedCommunicationPatterns(
UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
UdfProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
.build())
.setDirect(DirectWorker.newBuilder()
.setRunner(ProcessCallable.newBuilder()
Expand All @@ -112,10 +117,14 @@ val dispatcher: WorkerDispatcher = ...
val session = dispatcher.createSession(securityScope = None)
try {
// 4. Initialize with the serialized function and schemas.
session.init(InitMessage(
functionPayload = serializedFunction,
inputSchema = arrowInputSchema,
outputSchema = arrowOutputSchema))
session.init(Init.newBuilder()
.setUdf(UdfPayload.newBuilder()
.setPayload(ByteString.copyFrom(serializedFunction))
.setFormat("py-cloudpickle-v3"))
.setDataFormat(UdfWorkerDataFormat.ARROW)
.setInputSchema(ByteString.copyFrom(arrowInputSchema))
.setOutputSchema(ByteString.copyFrom(arrowOutputSchema))
.build())

// 5. Process data -- Iterator in, Iterator out.
val results: Iterator[Array[Byte]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,7 @@ package org.apache.spark.udf.worker.core
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.annotation.Experimental

/**
* :: Experimental ::
* Carries all information needed to initialize a UDF execution on a worker.
*
* This message is passed to [[WorkerSession#init]] and contains the function
* definition, schemas, and any additional configuration.
*
* Placeholder: will be replaced by a generated proto message once the
* UDF wire protocol lands. Do not rely on case-class equality --
* `Array[Byte]` fields compare by reference.
*
* @param functionPayload serialized function (e.g., pickled Python, JVM bytes)
* @param inputSchema serialized input schema (e.g., Arrow schema bytes)
* @param outputSchema serialized output schema (e.g., Arrow schema bytes)
* @param properties additional key-value configuration. Can carry
* protocol-specific or engine-specific metadata that
* does not yet have a dedicated field.
*/
@Experimental
case class InitMessage(
functionPayload: Array[Byte],
inputSchema: Array[Byte],
outputSchema: Array[Byte],
properties: Map[String, String] = Map.empty)
import org.apache.spark.udf.worker.Init

/**
* :: Experimental ::
Expand All @@ -62,7 +38,10 @@ case class InitMessage(
* {{{
* val session = dispatcher.createSession(securityScope = None)
* try {
* session.init(InitMessage(functionPayload, inputSchema, outputSchema))
* session.init(Init.newBuilder()
* .setUdf(UdfPayload.newBuilder().setPayload(callable).setFormat(fmt))
* .setDataFormat(UdfWorkerDataFormat.ARROW)
* .build())
* val results = session.process(inputBatches)
* results.foreach(handleBatch)
* } finally {
Expand All @@ -74,7 +53,18 @@ case class InitMessage(
* - [[init]] must be called exactly once before [[process]].
* - [[process]] must be called at most once per session.
* - [[close]] must always be called (use try-finally).
* - [[cancel]] may be called at any time to abort execution.
* - [[cancel]] may be called at any time, including before [[init]]
* or after [[process]]/[[close]] has returned. Implementations
* treat such calls as a no-op so that callers driven by a task
* interruption listener (which has no view into the session state)
* do not need to coordinate with the thread driving [[process]].
*
* Cancel-vs-finish race: when the session driver has finished
* sending input (and therefore queued an implicit finish on the
* underlying transport) and a [[cancel]] arrives concurrently, both
* are valid stream-terminating actions; the response side carries
* either a `FinishResponse` or a `CancelResponse` depending on which
* the worker observes first, and either is acceptable to the caller.
*
* The lifecycle is enforced here: [[init]] and [[process]] are `final`
* and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards.
Expand All @@ -93,10 +83,12 @@ abstract class WorkerSession extends AutoCloseable {
*
* Throws `IllegalStateException` if called more than once.
*
* @param message the initialization parameters including the serialized
* function, input/output schemas, and configuration.
* @param message the [[Init]] proto carrying the UDF body, the wire
* data format, optional input/output schemas, and any
* engine-side session context the worker needs to start
* processing.
*/
final def init(message: InitMessage): Unit = {
final def init(message: Init): Unit = {
if (!initialized.compareAndSet(false, true)) {
throw new IllegalStateException("init has already been called on this session")
}
Expand Down Expand Up @@ -128,7 +120,7 @@ abstract class WorkerSession extends AutoCloseable {
}

/** Subclass hook for [[init]]. Called once, after the guard. */
protected def doInit(message: InitMessage): Unit
protected def doInit(message: Init): Unit

/** Subclass hook for [[process]]. Called at most once, after the guard. */
protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]]
Expand All @@ -138,11 +130,29 @@ abstract class WorkerSession extends AutoCloseable {
*
* '''Thread-safety:''' implementations must allow [[cancel]] to be called
* from a thread different from the one driving [[process]] (typically a
* task interruption thread). It may be invoked at any point after
* [[init]] and should be a no-op if execution has already finished.
* task interruption thread).
*
* '''Lifecycle:''' [[cancel]] is idempotent and safe at any point in
* the session's life:
* - before [[init]] -- nothing has been sent on the transport yet,
* so [[cancel]] is a no-op (the session may still be closed
* normally via [[close]]).
* - between [[init]] and [[process]] -- transitions the session
* into a cancelled state; subsequent [[process]] calls observe
* the cancellation.
* - during [[process]] -- aborts the active stream.
* - after [[process]] / [[close]] has returned -- a no-op.
*
* Implementations are responsible for the no-op behavior described
* above so that callers (e.g. task interruption listeners) do not
* need to coordinate with the thread driving [[process]].
*/
def cancel(): Unit

/** Closes this session and releases resources. */
/**
* Closes this session and releases resources. Idempotent; safe to
* call from a `finally` block regardless of whether [[init]],
* [[process]], or [[cancel]] have been invoked.
*/
override def close(): Unit
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.udf.worker.core

import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 * Protocol-level factory for [[WorkerConnection]]s and per-execution
 * [[WorkerSession]]s.
 *
 * The factory is responsible for everything that depends on the protocol
 * the worker speaks (e.g. gRPC) -- building the transport-level connection
 * and producing per-execution sessions on top of it. It is '''not'''
 * responsible for creating, terminating, or otherwise managing the worker
 * itself; that is the [[WorkerDispatcher]]'s job. As long as the dispatcher
 * gives the factory an address to dial, and later a connection to talk
 * over, the factory does not need to know whether the worker is a
 * locally-spawned process, a daemon, or a remote service.
 *
 * '''Ownership:''' one factory instance is owned by one dispatcher and is
 * closed by that dispatcher; the connections it creates are owned by the
 * dispatcher's worker handles, not by the factory. Implementations that
 * hold protocol-level shared resources (e.g. a Netty event loop group)
 * should release them in [[close]].
 */
@Experimental
trait WorkerSessionFactory extends AutoCloseable {

  /**
   * Builds a transport-level connection to a worker reachable at
   * `address`. The address format is whatever the dispatcher has agreed
   * with the worker spec's transport (e.g. a UDS path, a TCP host:port);
   * this trait imposes no format of its own -- the string is opaque here
   * and interpreted by the concrete implementation.
   *
   * The returned connection is owned by the dispatcher's worker handle
   * (not by this factory); its lifecycle is tied to worker teardown.
   *
   * @param address transport-specific endpoint to dial.
   * @return an established [[WorkerConnection]] to the worker.
   */
  def createConnection(address: String): WorkerConnection

  /**
   * Creates a per-execution [[WorkerSession]] over an already-established
   * [[WorkerConnection]]. The factory does not own the connection -- the
   * worker handle does -- and a connection may back many sessions over
   * its lifetime. (Whether sessions on one connection may be concurrent
   * is implementation-defined; this trait does not constrain it.)
   *
   * @param connection a live connection previously produced by
   *                   [[createConnection]].
   * @return a fresh session; the caller drives its init/process/close
   *         lifecycle and must close it when done.
   */
  def createSession(connection: WorkerConnection): WorkerSession

  /**
   * Releases factory-level resources (event loops, thread pools, ...).
   * Called by the dispatcher on close. Default is a no-op so factories
   * that have nothing to release can omit the override.
   */
  override def close(): Unit = ()
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.nio.file.attribute.PosixFilePermissions

import org.apache.spark.annotation.Experimental
import org.apache.spark.udf.worker.UDFWorkerSpecification
import org.apache.spark.udf.worker.core.{UnixSocketWorkerConnection, WorkerLogger}
import org.apache.spark.udf.worker.core.{WorkerLogger, WorkerSessionFactory}
import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POLL_INTERVAL_MS

/**
Expand All @@ -31,14 +31,20 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POL
* transport. Allocates a private 0700 socket directory at construction;
* each worker is given a UDS path inside it.
*
* Concrete subclasses implement [[createConnection]] (with a UDS protocol
* of choice) and [[createSessionForWorker]].
* Pair this dispatcher with any [[WorkerSessionFactory]] whose
* [[WorkerSessionFactory#createConnection]] knows how to dial a UDS path
* (e.g. a gRPC-over-UDS factory).
*
* @param sessionFactory protocol-specific connection + session factory.
* Owned by this dispatcher; closed by the base on
* dispatcher close.
*/
@Experimental
abstract class DirectUnixSocketWorkerDispatcher(
class DirectUnixSocketWorkerDispatcher(
workerSpec: UDFWorkerSpecification,
sessionFactory: WorkerSessionFactory,
logger: WorkerLogger = WorkerLogger.NoOp)
extends DirectWorkerDispatcher(workerSpec, logger) {
extends DirectWorkerDispatcher(workerSpec, sessionFactory, logger) {

// Removed explicitly in closeTransport(). deleteOnExit is avoided because
// the JDK retains the path for the JVM lifetime, which leaks in
Expand Down Expand Up @@ -100,8 +106,6 @@ abstract class DirectUnixSocketWorkerDispatcher(
s"got ${conn.getTransportCase}")
}

override protected def createConnection(address: String): UnixSocketWorkerConnection

private def throwWorkerExitedBeforeSocket(
process: Process,
address: String,
Expand All @@ -117,8 +121,12 @@ abstract class DirectUnixSocketWorkerDispatcher(
* On non-POSIX filesystems falls back to best-effort `File.setXxx`,
* which is TOCTOU-racy and weaker; a WARN surfaces if the platform
* refuses the setters.
*
* Visible for testing: tests override this to escape Spark's parent-POM
* `java.io.tmpdir=target/tmp`, which on long checkout paths pushes UDS
* paths past Linux's 108-char cap.
*/
private def createPrivateTempDirectory(): Path = {
protected def createPrivateTempDirectory(): Path = {
val attr = PosixFilePermissions.asFileAttribute(
PosixFilePermissions.fromString("rwx------"))
try {
Expand Down
Loading