Skip to content

Commit

Permalink
[SPARKR] Match pyspark features in SparkR communication protocol.
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcelo Vanzin committed May 9, 2018
1 parent cc613b5 commit 628c7b5
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 29 deletions.
4 changes: 2 additions & 2 deletions R/pkg/R/client.R
Expand Up @@ -19,7 +19,7 @@

# Creates a SparkR client connection object
# if one doesn't already exist
connectBackend <- function(hostname, port, timeout) {
connectBackend <- function(hostname, port, timeout, authSecret) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
Expand All @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {

con <- socketConnection(host = hostname, port = port, server = FALSE,
blocking = TRUE, open = "wb", timeout = timeout)

doServerAuth(con, authSecret)
assign(".sparkRCon", con, envir = .sparkREnv)
con
}
Expand Down
10 changes: 7 additions & 3 deletions R/pkg/R/deserialize.R
Expand Up @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
stop(paste("Unsupported type for deserialization", type)))
}

readString <- function(con) {
stringLen <- readInt(con)
raw <- readBin(con, raw(), stringLen, endian = "big")
readStringData <- function(con, len) {
raw <- readBin(con, raw(), len, endian = "big")
string <- rawToChar(raw)
Encoding(string) <- "UTF-8"
string
}

readString <- function(con) {
stringLen <- readInt(con)
readStringData(con, stringLen)
}

readInt <- function(con) {
readBin(con, integer(), n = 1, endian = "big")
}
Expand Down
39 changes: 34 additions & 5 deletions R/pkg/R/sparkR.R
Expand Up @@ -158,6 +158,10 @@ sparkR.sparkContext <- function(
" please use the --packages commandline instead", sep = ","))
}
backendPort <- existingPort
authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
if (nchar(authSecret) == 0) {
stop("Auth secret not provided in environment.")
}
} else {
path <- tempfile(pattern = "backend_port")
submitOps <- getClientModeSparkSubmitOpts(
Expand Down Expand Up @@ -186,16 +190,27 @@ sparkR.sparkContext <- function(
monitorPort <- readInt(f)
rLibPath <- readString(f)
connectionTimeout <- readInt(f)

# Don't use readString() so that we can provide a useful
# error message if the R and Java versions are mismatched.
authSecretLen = readInt(f)
if (length(authSecretLen) == 0 || authSecretLen == 0) {
stop("Unexpected EOF in JVM connection data. Mismatched versions?")
}
authSecret <- readStringData(f, authSecretLen)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
length(monitorPort) == 0 || monitorPort == 0 ||
length(rLibPath) != 1) {
length(rLibPath) != 1 || length(authSecret) == 0) {
stop("JVM failed to launch")
}
assign(".monitorConn",
socketConnection(port = monitorPort, timeout = connectionTimeout),
envir = .sparkREnv)

monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
timeout = connectionTimeout, open = "wb")
doServerAuth(monitorConn, authSecret)

assign(".monitorConn", monitorConn, envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
Expand All @@ -205,7 +220,7 @@ sparkR.sparkContext <- function(

.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort, timeout = connectionTimeout)
connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
},
error = function(err) {
stop("Failed to connect JVM\n")
Expand Down Expand Up @@ -687,3 +702,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
NULL
}
}

# Utility function for sending auth data over a socket and checking the server's reply.
doServerAuth <- function(con, authSecret) {
if (nchar(authSecret) == 0) {
stop("Auth secret not provided.")
}
writeString(con, authSecret)
flush(con)
reply <- readString(con)
if (reply != "ok") {
close(con)
stop("Unexpected reply from server.")
}
}
4 changes: 3 additions & 1 deletion R/pkg/inst/worker/daemon.R
Expand Up @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)

SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

# Waits indefinitely for a socket connecion by default.
selectTimeout <- NULL
Expand Down
5 changes: 4 additions & 1 deletion R/pkg/inst/worker/worker.R
Expand Up @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

outputCon <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
Expand Down
38 changes: 38 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.r

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper

private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {

override protected def readUtf8(s: Socket): String = {
SerDe.readString(new DataInputStream(s.getInputStream()))
}

override protected def writeUtf8(str: String, s: Socket): Unit = {
val out = s.getOutputStream()
SerDe.writeString(new DataOutputStream(out), str)
out.flush()
}

}
43 changes: 37 additions & 6 deletions core/src/main/scala/org/apache/spark/api/r/RBackend.scala
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.api.r

import java.io.{DataOutputStream, File, FileOutputStream, IOException}
import java.net.{InetAddress, InetSocketAddress, ServerSocket}
import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException}
import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
import java.util.concurrent.TimeUnit

import io.netty.bootstrap.ServerBootstrap
Expand All @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils

/**
* Netty-based backend server that is used to communicate between R and Java.
Expand All @@ -45,14 +47,15 @@ private[spark] class RBackend {
/** Tracks JVM objects returned to R for this RBackend instance. */
private[r] val jvmObjectTracker = new JVMObjectTracker

def init(): Int = {
def init(): (Int, RAuthHelper) = {
val conf = new SparkConf()
val backendConnectionTimeout = conf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
bossGroup = new NioEventLoopGroup(
conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
val authHelper = new RAuthHelper(conf)

bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
Expand All @@ -71,13 +74,16 @@ private[spark] class RBackend {
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
.addLast(new RBackendAuthHandler(authHelper.secret))
.addLast("handler", handler)
}
})

channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0))
channelFuture.syncUninterruptibly()
channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()

val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
(port, authHelper)
}

def run(): Unit = {
Expand Down Expand Up @@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
val sparkRBackend = new RBackend()
try {
// bind to random port
val boundPort = sparkRBackend.init()
val (boundPort, authHelper) = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
// Connection timeout is set by socket client. To make it configurable we will pass the
Expand All @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging {
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
dos.writeInt(backendConnectionTimeout)
SerDe.writeString(dos, authHelper.secret)
dos.close()
f.renameTo(new File(path))

Expand All @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging {
val buf = new Array[Byte](1024)
// shutdown JVM if R does not connect back in 10 seconds
serverSocket.setSoTimeout(10000)

// Wait for the R process to connect back, ignoring any failed auth attempts. Allow
// a max number of connection attempts to avoid looping forever.
try {
val inSocket = serverSocket.accept()
var remainingAttempts = 10
var inSocket: Socket = null
while (inSocket == null) {
inSocket = serverSocket.accept()
try {
authHelper.authClient(inSocket)
} catch {
case e: Exception =>
remainingAttempts -= 1
if (remainingAttempts == 0) {
val msg = "Too many failed authentication attempts."
logError(msg)
throw new IllegalStateException(msg)
}
logInfo("Client connection failed authentication.")
inSocket = null
}
}

serverSocket.close()

// wait for the end of socket, closed if R process die
inSocket.getInputStream().read(buf)
} finally {
serverSocket.close()
sparkRBackend.close()
System.exit(0)
}
Expand All @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging {
}
System.exit(0)
}

}
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.r

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
* Authentication handler for connections from the R process.
*/
private class RBackendAuthHandler(secret: String)
extends SimpleChannelInboundHandler[Array[Byte]] with Logging {

override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
// The R code adds a null terminator to serialized strings, so ignore it here.
val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
try {
require(secret == clientSecret, "Auth secret mismatch.")
ctx.pipeline().remove(this)
writeReply("ok", ctx.channel())
} catch {
case e: Exception =>
logInfo("Authentication failure.", e)
writeReply("err", ctx.channel())
ctx.close()
}
}

private def writeReply(reply: String, chan: Channel): Unit = {
val out = new ByteArrayOutputStream()
SerDe.writeString(new DataOutputStream(out), reply)
chan.writeAndFlush(out.toByteArray())
}

}

0 comments on commit 628c7b5

Please sign in to comment.