2 changes: 1 addition & 1 deletion R/pkg/R/context.R
@@ -181,7 +181,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
parallelism <- as.integer(numSlices)
jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, parallelism)
authSecret <- callJMethod(jserver, "secret")
port <- callJMethod(jserver, "port")
port <- callJMethod(jserver, "connInfo")
conn <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
doServerAuth(conn, authSecret)
31 changes: 16 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.nio.channels.{Channels, SocketChannel}
import java.nio.charset.StandardCharsets
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

@@ -231,9 +232,9 @@ private[spark] object PythonRDD extends Logging {
* server object that can be used to join the JVM serving thread in Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
val handleFunc = (sock: Socket) => {
val out = new DataOutputStream(sock.getOutputStream)
val in = new DataInputStream(sock.getInputStream)
val handleFunc = (sock: SocketChannel) => {
val out = new DataOutputStream(Channels.newOutputStream(sock))
val in = new DataInputStream(Channels.newInputStream(sock))
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
@@ -287,7 +288,7 @@ private[spark] object PythonRDD extends Logging {
}

val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
Array(server.port, server.secret, server)
Array(server.connInfo, server.secret, server)
}

def readRDDFromFile(
@@ -831,21 +832,21 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial

def setupEncryptionServer(): Array[Any] = {
encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
override def handleConnection(sock: Socket): Unit = {
override def handleConnection(sock: SocketChannel): Unit = {
val env = SparkEnv.get
val in = sock.getInputStream()
val in = Channels.newInputStream(sock)
val abspath = new File(path).getAbsolutePath
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath))
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
}
}
Array(encryptionServer.port, encryptionServer.secret)
Array(encryptionServer.connInfo, encryptionServer.secret)
}

def setupDecryptionServer(): Array[Any] = {
decryptionServer = new SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
override def handleConnection(sock: Socket): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
override def handleConnection(sock: SocketChannel): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(sock)))
Utils.tryWithSafeFinally {
val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path))
Utils.tryWithSafeFinally {
@@ -859,7 +860,7 @@
}
}
}
Array(decryptionServer.port, decryptionServer.secret)
Array(decryptionServer.connInfo, decryptionServer.secret)
}

def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()
@@ -945,8 +946,8 @@ private[spark] class EncryptedPythonBroadcastServer(
val idsAndFiles: Seq[(Long, String)])
extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {

override def handleConnection(socket: Socket): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
override def handleConnection(socket: SocketChannel): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(socket)))
var socketIn: InputStream = null
// send the broadcast id, then the decrypted data. We don't need to send the length;
// the Python pickle module just needs a stream.
@@ -962,7 +963,7 @@
}
logTrace("waiting for python to accept broadcast data over socket")
out.flush()
socketIn = socket.getInputStream()
socketIn = Channels.newInputStream(socket)
socketIn.read()
logTrace("done serving broadcast data")
} {
@@ -983,8 +984,8 @@
private[spark] abstract class PythonRDDServer
extends SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {

def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
val in = sock.getInputStream()
def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
val in = Channels.newInputStream(sock)
val dechunkedInput: InputStream = new DechunkedInputStream(in)
streamToRDD(dechunkedInput)
}
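Every handler in this file follows the same adaptation: the connected `SocketChannel` is wrapped with `Channels.newInputStream`/`Channels.newOutputStream`, so the existing stream-based logic works unchanged whether the channel is backed by TCP or by a Unix domain socket. A minimal sketch of that pattern (the `withStreams` helper is illustrative, not a Spark API):

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.nio.channels.{Channels, SocketChannel}

// Wrap a connected SocketChannel in the stream types the old Socket-based
// handlers expected; the handler body stays transport-agnostic.
def withStreams[T](sock: SocketChannel)(f: (DataInputStream, DataOutputStream) => T): T = {
  val in = new DataInputStream(Channels.newInputStream(sock))
  val out = new DataOutputStream(Channels.newOutputStream(sock))
  try f(in, out) finally sock.close()
}
```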
100 changes: 58 additions & 42 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,9 +20,9 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.channels.{AsynchronousCloseException, Channels, SelectionKey, ServerSocketChannel, SocketChannel}
import java.nio.file.{Files => JavaFiles, Path}
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

@@ -201,9 +201,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Python accumulator is always set in production except in tests. See SPARK-27893
private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)

// Expose a ServerSocket to support method calls via socket from Python side. Only relevant for
// for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details.
private[spark] var serverSocket: Option[ServerSocket] = None
// Expose a ServerSocketChannel to support method calls via socket from Python side.
// Only relevant for tasks that are a part of barrier stage, refer
// `BarrierTaskContext` for details.
private[spark] var serverSocketChannel: Option[ServerSocketChannel] = None

// Authentication helper used when serving method calls via socket from Python side.
private lazy val authHelper = new SocketAuthHelper(conf)
@@ -347,6 +348,11 @@
def writeNextInputToStream(dataOut: DataOutputStream): Boolean

def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
Contributor:

Instead of if/else's everywhere, you could consider using an interface with two implementations. That improves readability of the code by a lot.

Member Author:

Yup, I will do it in a separate PR that factors those out. I think we can take the common code out to utils.
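As a rough illustration, the suggested refactoring could look like the sketch below: one small trait with a TCP and a Unix-domain implementation, so call sites lose their if/else branches. All names here (`ServerTransport`, `TcpTransport`, `UnixDomainTransport`) are hypothetical, not part of this PR:

```scala
import java.io.File
import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, UnixDomainSocketAddress}
import java.nio.channels.ServerSocketChannel
import java.util.UUID

sealed trait ServerTransport {
  def bind(): ServerSocketChannel
  def connInfo: Any // port (Int) for TCP, socket path (String) for UDS
}

class TcpTransport extends ServerTransport {
  private val channel = ServerSocketChannel.open()
  def bind(): ServerSocketChannel = {
    // Bind to an ephemeral loopback port with a backlog of 1.
    channel.bind(new InetSocketAddress(InetAddress.getLoopbackAddress, 0), 1)
    channel
  }
  def connInfo: Any = channel.socket().getLocalPort
}

class UnixDomainTransport(dir: File) extends ServerTransport {
  private val path = new File(dir, s".${UUID.randomUUID()}.sock")
  private val channel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
  def bind(): ServerSocketChannel = {
    path.deleteOnExit() // best-effort cleanup, mirroring the code above
    channel.bind(UnixDomainSocketAddress.of(path.getPath))
    channel
  }
  def connInfo: Any = path.getPath
}
```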

lazy val sockPath = new File(
authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
.getOrElse(System.getProperty("java.io.tmpdir")),
s".${UUID.randomUUID()}.sock")
try {
// Partition index
dataOut.writeInt(partitionIndex)
@@ -356,27 +362,34 @@
// Init a ServerSocket to accept method calls from Python side.
val isBarrier = context.isInstanceOf[BarrierTaskContext]
if (isBarrier) {
serverSocket = Some(new ServerSocket(/* port */ 0,
/* backlog */ 1,
InetAddress.getByName("localhost")))
// A call to accept() for ServerSocket shall block indefinitely.
serverSocket.foreach(_.setSoTimeout(0))
if (isUnixDomainSock) {
serverSocketChannel = Some(ServerSocketChannel.open(StandardProtocolFamily.UNIX))
sockPath.deleteOnExit()
serverSocketChannel.get.bind(UnixDomainSocketAddress.of(sockPath.getPath))
Contributor:

What are the file permissions on the UDS?

Member Author:

It's srwxrwxr-x, so only the owner and members of the same group can connect. Others cannot, because you need write access to connect to a socket; the read and execute bits on the socket file itself are mostly irrelevant to how socket communication works.
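For anyone who wants to check or tighten this, a small `java.nio.file` sketch (the path is hypothetical, and narrowing to owner-only is an illustration, not something this PR does):

```scala
import java.nio.file.{Files, Paths}
import java.nio.file.attribute.PosixFilePermissions

val sockFile = Paths.get("/tmp/.example.sock") // hypothetical socket path

// Connecting to a Unix domain socket requires write permission on the file,
// so the group/other write bits control who can connect.
println(PosixFilePermissions.toString(Files.getPosixFilePermissions(sockFile)))

// Restrict connections to the owning user only (illustrative).
Files.setPosixFilePermissions(sockFile, PosixFilePermissions.fromString("rwx------"))
```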

} else {
serverSocketChannel = Some(ServerSocketChannel.open())
serverSocketChannel.foreach(_.bind(
new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1))
// A call to accept() for ServerSocket shall block indefinitely.
serverSocketChannel.foreach(_.socket().setSoTimeout(0))
}

new Thread("accept-connections") {
setDaemon(true)

override def run(): Unit = {
while (!serverSocket.get.isClosed()) {
var sock: Socket = null
while (serverSocketChannel.get.isOpen()) {
var sock: SocketChannel = null
try {
sock = serverSocket.get.accept()
sock = serverSocketChannel.get.accept()
// Wait for function call from python side.
sock.setSoTimeout(10000)
if (!isUnixDomainSock) sock.socket().setSoTimeout(10000)
authHelper.authClient(sock)
val input = new DataInputStream(sock.getInputStream())
val input = new DataInputStream(Channels.newInputStream(sock))
val requestMethod = input.readInt()
// The BarrierTaskContext function may wait indefinitely, so the socket shall not
// time out before the function finishes.
sock.setSoTimeout(0)
if (!isUnixDomainSock) sock.socket().setSoTimeout(0)
requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
barrierAndServe(requestMethod, sock)
Expand All @@ -385,13 +398,14 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
barrierAndServe(requestMethod, sock, message)
case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
sock.getOutputStream))
Channels.newOutputStream(sock)))
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
}
} catch {
case e: SocketException if e.getMessage.contains("Socket closed") =>
// It is possible that the ServerSocket is not closed, but the native socket
// has already been closed, we shall catch and silently ignore this case.
case _: AsynchronousCloseException =>
// Ignore to make less noisy. These will be closed when tasks
// are finished by listeners.
if (isUnixDomainSock) sockPath.delete()
} finally {
if (sock != null) {
sock.close()
@@ -401,33 +415,35 @@
}
}.start()
}
val secret = if (isBarrier) {
authHelper.secret
} else {
""
}
if (isBarrier) {
// Close ServerSocket on task completion.
serverSocket.foreach { server =>
context.addTaskCompletionListener[Unit](_ => server.close())
serverSocketChannel.foreach { server =>
context.addTaskCompletionListener[Unit] { _ =>
server.close()
if (isUnixDomainSock) sockPath.delete()
}
}
val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
if (boundPort == -1) {
val message = "ServerSocket failed to bind to Java side."
logError(message)
throw new SparkException(message)
if (isUnixDomainSock) {
logDebug(s"Started ServerSocket on with Unix Domain Socket $sockPath.")
dataOut.writeBoolean(/* isBarrier = */true)
dataOut.writeInt(-1)
PythonRDD.writeUTF(sockPath.getPath, dataOut)
} else {
val boundPort: Int = serverSocketChannel.map(_.socket().getLocalPort).getOrElse(-1)
if (boundPort == -1) {
val message = "ServerSocket failed to bind to Java side."
logError(message)
throw new SparkException(message)
}
logDebug(s"Started ServerSocket on port $boundPort.")
dataOut.writeBoolean(/* isBarrier = */true)
dataOut.writeInt(boundPort)
PythonRDD.writeUTF(authHelper.secret, dataOut)
}
logDebug(s"Started ServerSocket on port $boundPort.")
dataOut.writeBoolean(/* isBarrier = */true)
dataOut.writeInt(boundPort)
} else {
dataOut.writeBoolean(/* isBarrier = */false)
dataOut.writeInt(0)
}
// Write out the TaskContextInfo
val secretBytes = secret.getBytes(UTF_8)
dataOut.writeInt(secretBytes.length)
dataOut.write(secretBytes, 0, secretBytes.length)
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
@@ -485,12 +501,12 @@
/**
* Gateway to call BarrierTaskContext methods.
*/
def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
def barrierAndServe(requestMethod: Int, sock: SocketChannel, message: String = ""): Unit = {
require(
serverSocket.isDefined,
serverSocketChannel.isDefined,
"No available ServerSocket to redirect the BarrierTaskContext method call."
)
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
val out = new DataOutputStream(new BufferedOutputStream(Channels.newOutputStream(sock)))
try {
val messages = requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
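The handshake written above encodes the transport in the port value: `-1` followed by a UTF string means the string is a Unix domain socket path, while a non-negative port is a loopback TCP port followed by the auth secret. A hypothetical client-side counterpart (the Python worker does the equivalent; `connect` here is illustrative, not Spark code):

```scala
import java.net.{InetSocketAddress, UnixDomainSocketAddress}
import java.nio.channels.SocketChannel

// Interpret the (port, string) pair written by the runner: port == -1 means
// the string is a Unix domain socket path; otherwise the string is the auth
// secret and the port is a loopback TCP port.
def connect(port: Int, pathOrSecret: String): SocketChannel =
  if (port == -1) SocketChannel.open(UnixDomainSocketAddress.of(pathOrSecret))
  else SocketChannel.open(new InetSocketAddress("localhost", port))
```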