Branch 2.2 merge (#232)
* [SPARK-23816][CORE] Killed tasks should ignore FetchFailures.

SPARK-19276 ensured that FetchFailures do not get swallowed by other
layers of exception handling, but it also meant that a killed task could
look like a fetch failure.  This is particularly a problem with
speculative execution, where we expect to kill tasks as they are reading
shuffle data.  The fix is to ensure that we always check for killed
tasks first.
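
For illustration, a minimal sketch of the ordering the fix relies on (hypothetical
names and structure, not the actual Executor code): the killed check has to come
before the fetch-failure check.

```scala
// Illustrative sketch only: the point is that the "killed" check must come
// before the fetch-failure check, since a task killed while reading shuffle
// data may surface an exception that looks like a fetch failure.
object KilledBeforeFetchFailureSketch {
  final class FetchFailedException(msg: String) extends Exception(msg)

  def classify(error: Throwable, taskWasKilled: Boolean): String = error match {
    // Killed tasks are reported as killed, even if the failure surfaced as a
    // fetch failure while reading shuffle data.
    case _ if taskWasKilled      => "TaskKilled"
    case _: FetchFailedException => "FetchFailed"
    case other                   => s"ExceptionFailure(${other.getMessage})"
  }

  def main(args: Array[String]): Unit = {
    val e = new FetchFailedException("stream closed")
    println(classify(e, taskWasKilled = true))  // TaskKilled
    println(classify(e, taskWasKilled = false)) // FetchFailed
  }
}
```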

Added a new unit test that fails without the fix and ran it 1k times to
check for flakiness.  The full test suite was run on Jenkins.

Author: Imran Rashid <irashid@cloudera.com>

Closes apache#20987 from squito/SPARK-23816.

(cherry picked from commit 10f45bb)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARK-24007][SQL] EqualNullSafe for FloatType and DoubleType might generate a wrong result by codegen.

Codegen for `EqualNullSafe` (`<=>`) on `FloatType` and `DoubleType` can produce a wrong result.

```scala
scala> val df = Seq((Some(-1.0d), None), (None, Some(-1.0d))).toDF()
df: org.apache.spark.sql.DataFrame = [_1: double, _2: double]

scala> df.show()
+----+----+
|  _1|  _2|
+----+----+
|-1.0|null|
|null|-1.0|
+----+----+

scala> df.filter("_1 <=> _2").show()
+----+----+
|  _1|  _2|
+----+----+
|-1.0|null|
|null|-1.0|
+----+----+
```

The result should be empty, but two rows are returned.
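
For reference, a small sketch of the null-safe equality semantics the generated
code is expected to implement for nullable doubles (a hypothetical helper, not
the generated Java):

```scala
// Hypothetical reference implementation of <=> for nullable doubles; the real
// fix is in the generated code, this only documents the intended semantics.
object NullSafeEqSketch {
  def nullSafeEq(a: Option[Double], b: Option[Double]): Boolean = (a, b) match {
    case (None, None)       => true   // both NULL => true
    case (Some(x), Some(y)) => x == y // both non-NULL => ordinary equality
    case _                  => false  // exactly one NULL => false, never NULL
  }

  def main(args: Array[String]): Unit = {
    println(nullSafeEq(Some(-1.0), None)) // false (incorrectly true before the fix)
    println(nullSafeEq(None, Some(-1.0))) // false
    println(nullSafeEq(None, None))       // true
  }
}
```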

Added a test.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes apache#21094 from ueshin/issues/SPARK-24007/equalnullsafe.

(cherry picked from commit f09a9e9)
Signed-off-by: gatorsmile <gatorsmile@gmail.com>

* [SPARK-23963][SQL] Properly handle large number of columns in query on text-based Hive table

## What changes were proposed in this pull request?

TableReader would get disproportionately slower as the number of columns in the query increased.

I fixed the way TableReader was looking up metadata for each column in the row. Previously, it had been looking up this data in linked lists, accessing each linked list by an index (column number). Now it looks up this data in arrays, where indexing by column number is a constant-time operation.
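
As an illustration of the access-pattern change (toy collections, not the actual
TableReader code): indexed access on a linked list is linear per lookup, while an
array lookup is constant time.

```scala
import java.util.LinkedList

// Illustrative only: per-column metadata lookups in the hot row-processing
// loop should index into arrays rather than linked lists.
object ColumnLookupSketch {
  def main(args: Array[String]): Unit = {
    val numColumns = 4000
    // Before: a linked list indexed by column ordinal -- get(i) walks the
    // list, so processing a full row costs O(numColumns^2) node traversals.
    val asList = new LinkedList[String]()
    (0 until numColumns).foreach(i => asList.add(s"col$i"))

    // After: copy once into an array; every lookup is then constant time.
    val asArray: Array[String] = asList.toArray(new Array[String](asList.size))

    // Hot loop over every column of a row:
    var totalLen = 0L
    var i = 0
    while (i < numColumns) {
      totalLen += asArray(i).length // O(1) instead of asList.get(i), which is O(i)
      i += 1
    }
    println(s"looked up $numColumns columns, total name length $totalLen")
  }
}
```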

## How was this patch tested?

- Manual testing
- All sbt unit tests
- Python SQL tests

Author: Bruce Robbins <bersprockets@gmail.com>

Closes apache#21043 from bersprockets/tabreadfix.

* [MINOR][DOCS] Fix comments of SQLExecution#withExecutionId

## What changes were proposed in this pull request?
Fix a comment in `SQLExecution#withExecutionId`: change `BroadcastHashJoin.broadcastFuture` to `BroadcastExchangeExec.relationFuture`: https://github.com/apache/spark/blob/d28d5732ae205771f1f443b15b10e64dcffb5ff0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala#L66

## How was this patch tested?
N/A

Author: seancxmao <seancxmao@gmail.com>

Closes apache#21113 from seancxmao/SPARK-13136.

(cherry picked from commit c303b1b)
Signed-off-by: hyukjinkwon <gurwls223@apache.org>

* [SPARK-23941][MESOS] Mesos task failed on specific spark app name

## What changes were proposed in this pull request?
Shell-escape the name passed to spark-submit and change how conf attributes are shell-escaped.
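
As a rough illustration of the kind of quoting involved (a hypothetical helper,
not the exact escaping used in this PR):

```scala
// Hypothetical illustration of the problem: values interpolated into a shell
// command line must be escaped, e.g. by single-quoting and escaping embedded
// single quotes, so the shell passes them through as literal arguments.
object ShellEscapeSketch {
  def shellEscape(value: String): String =
    "'" + value.replace("'", "'\\''") + "'"

  def main(args: Array[String]): Unit = {
    val appName = """SparkPi $(touch /tmp/pwned); echo"""
    // Unescaped, this name would be interpreted by the shell; escaped, it
    // reaches spark-submit as a single literal argument.
    println(s"--conf spark.app.name=${shellEscape(appName)}")
  }
}
```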

## How was this patch tested?
This change has been tested manually with Hive-on-Spark on Mesos, and with the use case described in the issue: the SparkPi application run with a custom name that contains illegal shell characters.

With this PR, Hive-on-Spark on Mesos works like a charm with Hive 3.0.0-SNAPSHOT.

I state that this contribution is my original work and that I license the work to the project under the project’s open source license

Author: Bounkong Khamphousone <bounkong.khamphousone@ebiznext.com>

Closes apache#21014 from tiboun/fix/SPARK-23941.

(cherry picked from commit 6782359)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARK-23433][CORE] Late zombie task completions update all tasksets

A fetch failure can lead to multiple tasksets being active for a given
stage.  While there is only one "active" version of the taskset, the
earlier attempts can still have running tasks, which can complete
successfully.  So a task completion needs to update every taskset, so
that each one knows the partition is completed.  That way the final
active taskset does not try to submit another task for the same
partition, and each earlier attempt knows when it is completed and
when it should be marked as a "zombie".
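
A schematic sketch of that bookkeeping (hypothetical types, not the
TaskSchedulerImpl code): on a successful task, every taskset attempt for the
stage marks the partition as completed.

```scala
import scala.collection.mutable

// Hypothetical sketch: a successful task notifies *all* taskset attempts for
// the stage, not just the currently active one, so zombie attempts stop
// resubmitting the partition and know when they are done.
object MarkPartitionCompletedSketch {
  final class TaskSetAttempt(val stageAttempt: Int, numPartitions: Int) {
    private val completed = mutable.Set.empty[Int]
    var isZombie: Boolean = false

    def markPartitionCompleted(partition: Int): Unit = {
      completed += partition
      if (completed.size == numPartitions) isZombie = true // nothing left to run
    }
  }

  def main(args: Array[String]): Unit = {
    // Two attempts for the same stage exist after a fetch failure:
    val attempts = Seq(new TaskSetAttempt(0, 2), new TaskSetAttempt(1, 2))

    // A late task from attempt 0 finishes partition 0, then attempt 1 finishes partition 1;
    // both completions are broadcast to every attempt.
    attempts.foreach(_.markPartitionCompleted(0))
    attempts.foreach(_.markPartitionCompleted(1))

    println(attempts.map(_.isZombie)) // List(true, true): neither resubmits anything
  }
}
```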

Added a regression test.

Author: Imran Rashid <irashid@cloudera.com>

Closes apache#21131 from squito/SPARK-23433.

(cherry picked from commit 94641fe)
Signed-off-by: Imran Rashid <irashid@cloudera.com>

* [SPARK-23489][SQL][TEST][BRANCH-2.2] HiveExternalCatalogVersionsSuite should verify the downloaded file

## What changes were proposed in this pull request?

This is a backport of apache#21210 because `branch-2.2` also faces the same failures.

Although [SPARK-22654](https://issues.apache.org/jira/browse/SPARK-22654) made `HiveExternalCatalogVersionsSuite` try downloading from Apache mirrors up to three times, it has been flaky because it didn't verify the downloaded file. Some Apache mirrors terminate the download abnormally, and the *corrupted* file produces the following errors.

```
gzip: stdin: not in gzip format
tar: Child returned status 1
tar: Error is not recoverable: exiting now
22:46:32.700 WARN org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite:

===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.hive.HiveExternalCatalogVersionsSuite, thread names: Keep-Alive-Timer =====

*** RUN ABORTED ***
  java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "/tmp/test-spark/spark-2.2.0"): error=2, No such file or directory
```

This failure has been reported in two different ways; the case above, for example, shows up as Case 2 (`no failures`).

- Case 1. [Test Result (1 failure / +1)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4389/)
- Case 2. [Test Result (no failures)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4811/)

This PR aims to make `HiveExternalCatalogVersionsSuite` more robust by verifying the downloaded `tgz` file: the archive is extracted and the existence of `bin/spark-submit` is checked. If the file turns out to be empty or corrupted, `HiveExternalCatalogVersionsSuite` retries, just as it does for a download failure.
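
A rough sketch of the verify-and-retry idea (a hypothetical helper; the tool
invocations and the distribution directory name are assumptions for
illustration, not the suite's actual code):

```scala
import java.io.File
import scala.sys.process._

// Hypothetical sketch of "verify the downloaded tgz, otherwise retry": extract
// the archive and confirm bin/spark-submit exists before using the distribution.
object VerifiedDownloadSketch {
  def downloadAndVerify(url: String, targetDir: File, distDirName: String, attempts: Int = 3): Boolean = {
    (1 to attempts).exists { attempt =>
      val tgz = new File(targetDir, "spark.tgz")
      val downloaded = s"wget -q -O ${tgz.getAbsolutePath} $url".! == 0
      // A mirror may terminate the transfer early; tar then fails, or the
      // extracted tree is missing bin/spark-submit.
      val extracted = downloaded &&
        s"tar -xzf ${tgz.getAbsolutePath} -C ${targetDir.getAbsolutePath}".! == 0
      val ok = extracted && new File(targetDir, s"$distDirName/bin/spark-submit").exists()
      if (!ok) {
        println(s"attempt $attempt failed, cleaning up and retrying")
        tgz.delete()
      }
      ok
    }
  }
}
```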

## How was this patch tested?

Pass the Jenkins.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes apache#21232 from dongjoon-hyun/SPARK-23489-2.

* [SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly

## What changes were proposed in this pull request?

Accumulators written against the Spark 1.x API may no longer work with Spark 2.x, because `LegacyAccumulatorWrapper.isZero` may return the wrong answer if the user's `AccumulableParam` doesn't define `equals`/`hashCode`.

This PR fixes this by using a reference equality check in `LegacyAccumulatorWrapper.isZero`.
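
A condensed sketch of the idea (hypothetical class and field names, not the
actual wrapper): capture the zero instance once and compare against it by
reference.

```scala
// Hypothetical, condensed sketch of the fix: compare by reference against the
// captured zero instance instead of relying on a user-defined equals().
object LegacyAccumulatorSketch {
  trait AccumulableParam[R] {
    def zero(initial: R): R
    def addInPlace(a: R, b: R): R
  }

  // A legacy value class with no equals/hashCode:
  final class Counter(val n: Long)
  object CounterParam extends AccumulableParam[Counter] {
    def zero(initial: Counter): Counter = new Counter(0L)
    def addInPlace(a: Counter, b: Counter): Counter = new Counter(a.n + b.n)
  }

  final class LegacyWrapper[R](param: AccumulableParam[R], initial: R) {
    private val zeroInstance: R = param.zero(initial)
    private var value: R = zeroInstance

    // Reference equality: an equals-based check against a *different* freshly
    // built zero would fall back to reference equality and wrongly return false.
    def isZero: Boolean = value.asInstanceOf[AnyRef] eq zeroInstance.asInstanceOf[AnyRef]
    def add(v: R): Unit = { value = param.addInPlace(value, v) }
  }

  def main(args: Array[String]): Unit = {
    val acc = new LegacyWrapper(CounterParam, new Counter(0L))
    println(acc.isZero) // true right after registration
    acc.add(new Counter(5L))
    println(acc.isZero) // false after an update
  }
}
```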

## How was this patch tested?

a new test

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#21229 from cloud-fan/accumulator.

(cherry picked from commit 4d5de4d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

* [SPARK-21278][PYSPARK] Upgrade to Py4J 0.10.6

This PR bumps Py4J in order to fix the following float/double precision bug.
Py4J 0.10.5 fixed it (py4j/py4j#272), and the latest Py4J release is 0.10.6.

**BEFORE**
```
>>> df = spark.range(1)
>>> df.select(df['id'] + 17.133574204226083).show()
+--------------------+
|(id + 17.1335742042)|
+--------------------+
|       17.1335742042|
+--------------------+
```

**AFTER**
```
>>> df = spark.range(1)
>>> df.select(df['id'] + 17.133574204226083).show()
+-------------------------+
|(id + 17.133574204226083)|
+-------------------------+
|       17.133574204226083|
+-------------------------+
```

Manual.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes apache#18546 from dongjoon-hyun/SPARK-21278.

(cherry picked from commit c8d0aba)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARK-16406][SQL] Improve performance of LogicalPlan.resolve

`LogicalPlan.resolve(...)` uses linear searches to find an attribute matching a name. This is fine in normal cases, but gets problematic when you try to resolve a large number of columns on a plan with a large number of attributes.

This PR adds an indexing structure to `resolve(...)` in order to find potential matches more quickly. It improves the reference resolution time for the following code by 4x (11.8s -> 2.4s):

```scala
val n = 4000
val values = (1 to n).map(_.toString).mkString(", ")
val columns = (1 to n).map("column" + _).mkString(", ")
val query =
  s"""
     |SELECT $columns
     |FROM VALUES ($values) T($columns)
     |WHERE 1=2 AND 1 IN ($columns)
     |GROUP BY $columns
     |ORDER BY $columns
     |""".stripMargin

spark.time(sql(query))
```
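
A simplified sketch of such an indexing structure (hypothetical, not Catalyst's
actual code): group the output attributes by lower-cased name once, then resolve
each reference with a hash lookup instead of a linear scan.

```scala
// Hypothetical sketch: pre-build a name -> attributes index so that resolving
// k names against n attributes costs O(n + k) instead of O(n * k).
object ResolveIndexSketch {
  final case class Attribute(name: String)

  final class Resolver(output: Seq[Attribute]) {
    // Case-insensitive index, built once per plan.
    private val byName: Map[String, Seq[Attribute]] =
      output.groupBy(_.name.toLowerCase)

    // Returns all candidates; more than one match means an ambiguous reference.
    def resolve(name: String): Seq[Attribute] =
      byName.getOrElse(name.toLowerCase, Nil)
  }

  def main(args: Array[String]): Unit = {
    val attrs = (1 to 4000).map(i => Attribute(s"column$i"))
    val resolver = new Resolver(attrs)
    println(resolver.resolve("COLUMN2500")) // Vector(Attribute(column2500))
    println(resolver.resolve("missing"))    // List()
  }
}
```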

Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes apache#14083 from hvanhovell/SPARK-16406.

* [PYSPARK] Update py4j to version 0.10.7.

(cherry picked from commit cc613b5)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
(cherry picked from commit 323dc3a)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* [SPARKR] Match pyspark features in SparkR communication protocol.

(cherry picked from commit 628c7b5)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
(cherry picked from commit 16cd9ac)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>

* Keep old-style messages for AnalysisException with ambiguous references
markhamstra authored and csd-jenkins committed May 15, 2018
1 parent cdcacda commit 4f7963c
Showing 49 changed files with 804 additions and 164 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/)
(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
4 changes: 2 additions & 2 deletions R/pkg/R/client.R
@@ -19,7 +19,7 @@

# Creates a SparkR client connection object
# if one doesn't already exist
connectBackend <- function(hostname, port, timeout) {
connectBackend <- function(hostname, port, timeout, authSecret) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {

con <- socketConnection(host = hostname, port = port, server = FALSE,
blocking = TRUE, open = "wb", timeout = timeout)

doServerAuth(con, authSecret)
assign(".sparkRCon", con, envir = .sparkREnv)
con
}
10 changes: 7 additions & 3 deletions R/pkg/R/deserialize.R
@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
stop(paste("Unsupported type for deserialization", type)))
}

readString <- function(con) {
stringLen <- readInt(con)
raw <- readBin(con, raw(), stringLen, endian = "big")
readStringData <- function(con, len) {
raw <- readBin(con, raw(), len, endian = "big")
string <- rawToChar(raw)
Encoding(string) <- "UTF-8"
string
}

readString <- function(con) {
stringLen <- readInt(con)
readStringData(con, stringLen)
}

readInt <- function(con) {
readBin(con, integer(), n = 1, endian = "big")
}
39 changes: 34 additions & 5 deletions R/pkg/R/sparkR.R
@@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
" please use the --packages commandline instead", sep = ","))
}
backendPort <- existingPort
authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
if (nchar(authSecret) == 0) {
stop("Auth secret not provided in environment.")
}
} else {
path <- tempfile(pattern = "backend_port")
submitOps <- getClientModeSparkSubmitOpts(
@@ -189,16 +193,27 @@
monitorPort <- readInt(f)
rLibPath <- readString(f)
connectionTimeout <- readInt(f)

# Don't use readString() so that we can provide a useful
# error message if the R and Java versions are mismatched.
authSecretLen = readInt(f)
if (length(authSecretLen) == 0 || authSecretLen == 0) {
stop("Unexpected EOF in JVM connection data. Mismatched versions?")
}
authSecret <- readStringData(f, authSecretLen)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
length(monitorPort) == 0 || monitorPort == 0 ||
length(rLibPath) != 1) {
length(rLibPath) != 1 || length(authSecret) == 0) {
stop("JVM failed to launch")
}
assign(".monitorConn",
socketConnection(port = monitorPort, timeout = connectionTimeout),
envir = .sparkREnv)

monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
timeout = connectionTimeout, open = "wb")
doServerAuth(monitorConn, authSecret)

assign(".monitorConn", monitorConn, envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -208,7 +223,7 @@

.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort, timeout = connectionTimeout)
connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
},
error = function(err) {
stop("Failed to connect JVM\n")
@@ -632,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
NULL
}
}

# Utility function for sending auth data over a socket and checking the server's reply.
doServerAuth <- function(con, authSecret) {
if (nchar(authSecret) == 0) {
stop("Auth secret not provided.")
}
writeString(con, authSecret)
flush(con)
reply <- readString(con)
if (reply != "ok") {
close(con)
stop("Unexpected reply from server.")
}
}
4 changes: 3 additions & 1 deletion R/pkg/inst/worker/daemon.R
@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)

SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

while (TRUE) {
ready <- socketSelect(list(inputCon))
5 changes: 4 additions & 1 deletion R/pkg/inst/worker/worker.R
@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

outputCon <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))

# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
6 changes: 3 additions & 3 deletions bin/pyspark
@@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"

# In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
# and executor Python executables.

# Fail noisily if removed options are set
if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
echo "Error in pyspark startup:"
echo "Error in pyspark startup:"
echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
exit 1
fi
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON

# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH"
export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"

# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
2 changes: 1 addition & 1 deletion bin/pyspark2.cmd
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)

set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH%
set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%

set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
2 changes: 1 addition & 1 deletion core/pom.xml
@@ -335,7 +335,7 @@
<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
<version>0.10.4</version>
<version>0.10.7</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
11 changes: 2 additions & 9 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -17,13 +17,11 @@

package org.apache.spark

import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.KeyStore
import java.security.cert.X509Certificate
import javax.net.ssl._

import com.google.common.hash.HashCodes
import com.google.common.io.Files
import org.apache.hadoop.io.Text

@@ -435,12 +433,7 @@ private[spark] class SecurityManager(
val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY)
if (secretKey == null || secretKey.length == 0) {
logDebug("generateSecretKey: yarn mode, secret key from credentials is null")
val rnd = new SecureRandom()
val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
val secret = new Array[Byte](length)
rnd.nextBytes(secret)

val cookie = HashCodes.fromBytes(secret).toString()
val cookie = Utils.createSecret(sparkConf)
SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie)
cookie
} else {
@@ -17,26 +17,39 @@

package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket
import java.io.{DataOutputStream, File, FileOutputStream}
import java.net.InetAddress
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files

import py4j.GatewayServer

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
* Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
* back to its caller via a callback port specified by the caller.
* Process that starts a Py4J GatewayServer on an ephemeral port.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
initializeLogIfNecessary(true)

def main(args: Array[String]): Unit = Utils.tryOrExit {
// Start a GatewayServer on an ephemeral port
val gatewayServer: GatewayServer = new GatewayServer(null, 0)
def main(args: Array[String]): Unit = {
val secret = Utils.createSecret(new SparkConf())

// Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
// with the same secret, in case the app needs callbacks from the JVM to the underlying
// python processes.
val localhost = InetAddress.getLoopbackAddress()
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()

gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
@@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
logDebug(s"Started PythonGatewayServer on port $boundPort")
}

// Communicate the bound port back to the caller via the caller-specified callback port
val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
val callbackSocket = new Socket(callbackHost, callbackPort)
val dos = new DataOutputStream(callbackSocket.getOutputStream)
// Communicate the connection information back to the python process by writing the
// information in the requested file. This needs to match the read side in java_gateway.py.
val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
"connection", ".info").toFile()

val dos = new DataOutputStream(new FileOutputStream(tmpPath))
dos.writeInt(boundPort)

val secretBytes = secret.getBytes(UTF_8)
dos.writeInt(secretBytes.length)
dos.write(secretBytes, 0, secretBytes.length)
dos.close()
callbackSocket.close()

if (!tmpPath.renameTo(connectionInfoPath)) {
logError(s"Unable to write connection information to $connectionInfoPath.")
System.exit(1)
}

// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
while (System.in.read() != -1) {
29 changes: 22 additions & 7 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._


@@ -421,6 +422,12 @@ private[spark] object PythonRDD extends Logging {
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()

// Authentication helper used when serving iterator data.
private lazy val authHelper = {
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
new SocketAuthHelper(conf)
}

def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
@@ -443,12 +450,13 @@
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return the port number of a local socket which serves the data collected from this job.
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int]): Int = {
partitions: JArrayList[Int]): Array[Any] = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
@@ -461,13 +469,14 @@
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return the port number of a local socket which serves the data collected from this job.
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
}

@@ -698,8 +707,11 @@
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
Expand All @@ -709,11 +721,14 @@ private[spark] object PythonRDD extends Logging {
override def run() {
try {
val sock = serverSocket.accept()
authHelper.authClient(sock)

val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
} {
out.close()
sock.close()
}
} catch {
case NonFatal(e) =>
Expand All @@ -724,7 +739,7 @@ private[spark] object PythonRDD extends Logging {
}
}.start()

serverSocket.getLocalPort
Array(serverSocket.getLocalPort, authHelper.secret)
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)