Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{HOST, PORT}
import org.apache.spark.internal.LogKeys.{HOST, PORT, SOCKET_ADDRESS}
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
Expand Down Expand Up @@ -717,35 +718,50 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
* collects a list of pickled strings that we pass to Python through a socket.
*/
private[spark] class PythonAccumulatorV2(
@transient private val serverHost: String,
private val serverPort: Int,
private val secretToken: String)
@transient private val serverHost: Option[String],
private val serverPort: Option[Int],
private val secretToken: Option[String],
@transient private val socketPath: Option[String])
extends CollectionAccumulator[Array[Byte]] with Logging {

Utils.checkHost(serverHost)
// Auxiliary constructor for the Unix domain socket transport (no host/port/secret needed).
def this(socketPath: String) = this(None, None, None, Some(socketPath))
// Auxiliary constructor for the TCP socket transport (no socket path needed).
def this(serverHost: String, serverPort: Int, secretToken: String) = this(
Some(serverHost), Some(serverPort), Some(secretToken), None)

serverHost.foreach(Utils.checkHost)

val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE)
val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)

/**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded RpcEndpoint anyway.
*/
@transient private var socket: Socket = _
@transient private var socket: SocketChannel = _

private def openSocket(): Socket = synchronized {
if (socket == null || socket.isClosed) {
socket = new Socket(serverHost, serverPort)
logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" +
log" port: ${MDC(PORT, serverPort)}")
/**
 * Opens (or reuses) the channel to the Python AccumulatorServer. A single channel is
 * reused across updates since they are all added by the DAGScheduler's single-threaded
 * RpcEndpoint (see the comment on the `socket` field above).
 *
 * Connects over a Unix domain socket when `isUnixDomainSock` is set, otherwise over TCP.
 * The secret token, when present, is sent once per new connection for authentication;
 * in the Unix domain socket case `secretToken` is `None`, so nothing is sent.
 */
private def openSocket(): SocketChannel = synchronized {
  if (socket == null || !socket.isOpen) {
    if (isUnixDomainSock) {
      socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get))
      // Log the socket path, not serverHost: serverHost is None for Unix domain sockets.
      logInfo(log"Connected to AccumulatorServer at socket: ${MDC(SOCKET_ADDRESS, socketPath.get)}")
    } else {
      socket = SocketChannel.open(new InetSocketAddress(serverHost.get, serverPort.get))
      // Unwrap the Options so the log shows the raw host/port rather than Some(...).
      logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost.get)}" +
        log" port: ${MDC(PORT, serverPort.get)}")
    }
    // Send the secret just for the initial authentication when opening a new connection.
    secretToken.foreach { token =>
      Channels.newOutputStream(socket).write(token.getBytes(StandardCharsets.UTF_8))
    }
  }
  socket
}

// Override needed so the copy's static type matches what PythonFunction expects
// (a PythonAccumulatorV2 rather than the parent CollectionAccumulator).
override def copyAndReset(): PythonAccumulatorV2 =
  new PythonAccumulatorV2(serverHost, serverPort, secretToken, socketPath)

override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
Expand All @@ -758,8 +774,9 @@ private[spark] class PythonAccumulatorV2(
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = openSocket()
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
val in = Channels.newInputStream(socket)
val out = new DataOutputStream(
new BufferedOutputStream(Channels.newOutputStream(socket), bufferSize))
val values = other.value
out.writeInt(values.size)
for (array <- values.asScala) {
Expand Down
60 changes: 44 additions & 16 deletions python/pyspark/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# limitations under the License.
#

import os
import sys
import select
import struct
import socketserver as SocketServer
import socketserver
import threading
from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union
from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union, Optional

from pyspark.serializers import read_int, CPickleSerializer
from pyspark.errors import PySparkRuntimeError
Expand Down Expand Up @@ -252,7 +253,7 @@ def addInPlace(self, value1: U, value2: U) -> U:
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) # type: ignore[type-var]


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
class UpdateRequestHandler(socketserver.StreamRequestHandler):

"""
This handler will keep polling updates from the same socket until the
Expand Down Expand Up @@ -293,37 +294,64 @@ def authenticate_and_accum_updates() -> bool:
"The value of the provided token to the AccumulatorServer is not correct."
)

# first we keep polling till we've received the authentication token
poll(authenticate_and_accum_updates)
# Unix Domain Socket does not need the auth.
if auth_token is not None:
# first we keep polling till we've received the authentication token
poll(authenticate_and_accum_updates)

# now we've authenticated, don't need to check for the token anymore
poll(accum_updates)


class AccumulatorServer(SocketServer.TCPServer):
class AccumulatorTCPServer(socketserver.TCPServer):
    """
    A simple TCP server that receives serialized accumulator updates, and that
    intercepts shutdown() in order to interrupt our continuous polling on the
    handler.
    """

    # Flag read by the request handler's polling loop; set to True to stop polling.
    server_shutdown = False

    def __init__(
        self,
        server_address: Tuple[str, int],
        RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
        auth_token: str,
    ):
        super().__init__(server_address, RequestHandlerClass)
        # Token the handler checks before accepting accumulator updates.
        self.auth_token = auth_token

    def shutdown(self) -> None:
        # Signal the handler's polling loop to stop before shutting the server down.
        self.server_shutdown = True
        super().shutdown()
        self.server_close()


class AccumulatorUnixServer(socketserver.UnixStreamServer):
    """
    A Unix Domain Socket server that receives serialized accumulator updates.
    No auth token is used (auth_token is always None) since Unix domain sockets
    do not need the authentication step. shutdown() is intercepted to stop the
    handler's polling loop and to remove the socket file.
    """

    # Flag read by the request handler's polling loop; set to True to stop polling.
    server_shutdown = False

    def __init__(
        self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
    ):
        super().__init__(socket_path, RequestHandlerClass)
        # Unix domain sockets skip authentication, so no token is stored.
        self.auth_token = None

    def shutdown(self) -> None:
        # Signal the handler's polling loop to stop before shutting the server down.
        self.server_shutdown = True
        super().shutdown()
        self.server_close()
        # Clean up the socket file the server bound to (server_address is the path).
        if os.path.exists(self.server_address):  # type: ignore[arg-type]
            os.remove(self.server_address)  # type: ignore[arg-type]


def _start_update_server(
auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None
) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]:
"""Start a TCP or Unix Domain Socket server for accumulator updates."""
if is_unix_domain_sock:
assert socket_path is not None
if os.path.exists(socket_path):
os.remove(socket_path)
server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
else:
server = AccumulatorTCPServer(
("localhost", 0), UpdateRequestHandler, auth_token
) # type: ignore[assignment]


def _start_update_server(auth_token: str) -> AccumulatorServer:
"""Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
Expand Down
21 changes: 18 additions & 3 deletions python/pyspark/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import uuid
import os
import shutil
import signal
Expand Down Expand Up @@ -305,11 +306,25 @@ def _do_init(
# they will be passed back to us through a TCP server
assert self._gateway is not None
auth_token = self._gateway.gateway_parameters.auth_token
is_unix_domain_sock = (
self._conf.get("spark.python.unix.domain.socket.enabled", "false").lower() == "true"
)
socket_path = None
if is_unix_domain_sock:
socket_dir = self._conf.get("spark.python.unix.domain.socket.dir")
if socket_dir is None:
socket_dir = getattr(self._jvm, "java.lang.System").getProperty("java.io.tmpdir")
socket_path = os.path.join(socket_dir, f".{uuid.uuid4()}.sock")
start_update_server = accumulators._start_update_server
self._accumulatorServer = start_update_server(auth_token)
(host, port) = self._accumulatorServer.server_address
self._accumulatorServer = start_update_server(auth_token, is_unix_domain_sock, socket_path)
assert self._jvm is not None
self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
if is_unix_domain_sock:
self._javaAccumulator = self._jvm.PythonAccumulatorV2(
self._accumulatorServer.server_address
)
else:
(host, port) = self._accumulatorServer.server_address # type: ignore[misc]
self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
self._jsc.sc().register(self._javaAccumulator)

# If encryption is enabled, we need to setup a server in the jvm to read broadcast
Expand Down