
[SPARK-20774][SPARK-27036][SQL] Cancel the running broadcast execution on BroadcastTimeout

## What changes were proposed in this pull request?

In the existing code, when the broadcast future times out it only fails the query: the Spark job computing the broadcast input and the hashed-relation construction running in the future keep running. This wastes resources and slows down other jobs. This PR cancels both the running job and the thread constructing the hashed relation.
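For review context, the shape of the fix can be summarized in a few lines. This is a hedged sketch, not the code in the diff below; the helper name and parameters are illustrative only:

```scala
import java.util.concurrent.{Future, TimeUnit, TimeoutException}

import org.apache.spark.SparkContext

// Sketch: wait on a cancellable (Java) future with a bound; on timeout,
// cancel the job group the broadcast job runs under, then interrupt the
// thread building the hashed relation.
object BroadcastCancelSketch {
  def awaitOrCancel[T](sc: SparkContext, groupId: String,
      future: Future[T], timeoutSecs: Long): T = {
    try {
      future.get(timeoutSecs, TimeUnit.SECONDS)
    } catch {
      case ex: TimeoutException =>
        if (!future.isDone) {
          sc.cancelJobGroup(groupId) // cancels the job(s) started under groupId
          future.cancel(true)        // interrupts the driver-side build thread
        }
        throw ex
    }
  }
}
```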

## How was this patch tested?

Add new test suite `BroadcastExchangeSuite`.

Closes #24595 from jiangxb1987/SPARK-20774.

Authored-by: Xingbo Jiang <xingbo.jiang@databricks.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
jiangxb1987 authored and gatorsmile committed May 15, 2019
1 parent efa3035 commit 0bba5cf56832f0690a4ebd733d01a0416e4c7252
@@ -2026,7 +2026,10 @@ class SQLConf extends Serializable with Logging {

   def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD)
 
-  def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT)
+  def broadcastTimeout: Long = {
+    val timeoutValue = getConf(BROADCAST_TIMEOUT)
+    if (timeoutValue < 0) Long.MaxValue else timeoutValue
+  }
 
   def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
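Note on the semantics: a negative `spark.sql.broadcastTimeout` used to map to `Duration.Inf`, but `java.util.concurrent.Future.get` only takes a finite timeout, so a negative value now maps to `Long.MaxValue` seconds, an effectively unbounded wait. A tiny self-contained sketch of that behavior (the helper name is hypothetical):

```scala
import java.util.concurrent.{Callable, Executors, TimeUnit}

object TimeoutSemantics {
  // Hypothetical helper mirroring the new SQLConf.broadcastTimeout logic:
  // a negative configured value means "wait forever".
  def effectiveTimeout(configured: Long): Long =
    if (configured < 0) Long.MaxValue else configured

  def main(args: Array[String]): Unit = {
    val pool = Executors.newSingleThreadExecutor()
    val future = pool.submit(new Callable[Int] {
      override def call(): Int = 42
    })
    // Long.MaxValue seconds is effectively an unbounded wait for get().
    val result = future.get(effectiveTimeout(-1), TimeUnit.SECONDS)
    println(result) // 42
    pool.shutdown()
  }
}
```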

@@ -17,10 +17,11 @@

 package org.apache.spark.sql.execution.exchange
 
-import java.util.concurrent.TimeoutException
+import java.util.UUID
+import java.util.concurrent._
 
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext
+import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
 import org.apache.spark.{broadcast, SparkException}
@@ -43,6 +44,8 @@ case class BroadcastExchangeExec(
     mode: BroadcastMode,
     child: SparkPlan) extends Exchange {
 
+  private val runId: UUID = UUID.randomUUID
+
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
@@ -56,79 +59,79 @@ case class BroadcastExchangeExec(
   }
 
   @transient
-  private val timeout: Duration = {
-    val timeoutValue = sqlContext.conf.broadcastTimeout
-    if (timeoutValue < 0) {
-      Duration.Inf
-    } else {
-      timeoutValue.seconds
-    }
-  }
+  private val timeout: Long = SQLConf.get.broadcastTimeout
 
   @transient
   private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
-        try {
-          val beforeCollect = System.nanoTime()
-          // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-          val (numRows, input) = child.executeCollectIterator()
-          if (numRows >= 512000000) {
-            throw new SparkException(
-              s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
-          }
-
-          val beforeBuild = System.nanoTime()
-          longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
-
-          // Construct the relation.
-          val relation = mode.transform(input, Some(numRows))
-
-          val dataSize = relation match {
-            case map: HashedRelation =>
-              map.estimatedSize
-            case arr: Array[InternalRow] =>
-              arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-            case _ =>
-              throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " +
-                relation.getClass.getName)
-          }
-
-          longMetric("dataSize") += dataSize
-          if (dataSize >= (8L << 30)) {
-            throw new SparkException(
-              s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
-          }
-
-          val beforeBroadcast = System.nanoTime()
-          longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
-
-          // Broadcast the relation
-          val broadcasted = sparkContext.broadcast(relation)
-          longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
-
-          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-          broadcasted
-        } catch {
-          // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
-          // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
-          // will catch this exception and re-throw the wrapped fatal throwable.
-          case oe: OutOfMemoryError =>
-            throw new SparkFatalException(
-              new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
-                s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
-                s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
-                s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
-              .initCause(oe.getCause))
-          case e if !NonFatal(e) =>
-            throw new SparkFatalException(e)
-        }
-      }
-    }(BroadcastExchangeExec.executionContext)
+    val task = new Callable[broadcast.Broadcast[Any]]() {
+      override def call(): broadcast.Broadcast[Any] = {
+        // This will run in another thread. Set the execution id so that we can connect these jobs
+        // with the correct execution.
+        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+          try {
+            // Setup a job group here so later it may get cancelled by groupId if necessary.
+            sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
+              interruptOnCancel = true)
+            val beforeCollect = System.nanoTime()
+            // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+            val (numRows, input) = child.executeCollectIterator()
+            if (numRows >= 512000000) {
+              throw new SparkException(
+                s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
+            }
+
+            val beforeBuild = System.nanoTime()
+            longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
+
+            // Construct the relation.
+            val relation = mode.transform(input, Some(numRows))
+
+            val dataSize = relation match {
+              case map: HashedRelation =>
+                map.estimatedSize
+              case arr: Array[InternalRow] =>
+                arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+              case _ =>
+                throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
+                  s"type: ${relation.getClass.getName}")
+            }
+
+            longMetric("dataSize") += dataSize
+            if (dataSize >= (8L << 30)) {
+              throw new SparkException(
+                s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+            }
+
+            val beforeBroadcast = System.nanoTime()
+            longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
+
+            // Broadcast the relation
+            val broadcasted = sparkContext.broadcast(relation)
+            longMetric("broadcastTime") += NANOSECONDS.toMillis(
+              System.nanoTime() - beforeBroadcast)
+
+            SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+            broadcasted
+          } catch {
+            // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+            // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+            // will catch this exception and re-throw the wrapped fatal throwable.
+            case oe: OutOfMemoryError =>
+              throw new SparkFatalException(
+                new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
+                  "worker nodes. As a workaround, you can either disable broadcast by setting " +
+                  s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
+                  s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
+                  .initCause(oe.getCause))
+            case e if !NonFatal(e) =>
+              throw new SparkFatalException(e)
+          }
+        }
+      }
+    }
+    BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
   }
 
   override protected def doPrepare(): Unit = {
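The switch from a `scala.concurrent.Future` block to a `Callable` submitted to the companion object's `executionContext` is what makes the interrupt possible: a `java.util.concurrent.Future` supports `cancel(true)`, which interrupts the thread building the hashed relation, whereas a Scala `Future` offers no cancellation handle. A minimal standalone sketch of that pattern, with illustrative names that are not Spark's:

```scala
import java.util.concurrent.{Callable, Executors, TimeUnit, TimeoutException}

object CancellableFutureSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newCachedThreadPool()
    val task = new Callable[String] {
      override def call(): String = {
        Thread.sleep(60000) // stands in for the expensive hashed-relation build
        "relation"
      }
    }
    val future = pool.submit(task)
    try {
      future.get(1, TimeUnit.SECONDS)
    } catch {
      case _: TimeoutException =>
        // Unlike scala.concurrent.Future, this handle can interrupt the worker thread.
        if (!future.isDone) future.cancel(true)
        println("timed out; task interrupted")
    }
    pool.shutdown()
  }
}
```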
@@ -143,11 +146,15 @@ case class BroadcastExchangeExec(

   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
-      ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
     } catch {
       case ex: TimeoutException =>
-        logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex)
-        throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " +
+        logError(s"Could not execute broadcast in $timeout secs.", ex)
+        if (!relationFuture.isDone) {
+          sparkContext.cancelJobGroup(runId.toString)
+          relationFuture.cancel(true)
+        }
+        throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
           s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
           s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
           ex)
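On timeout the driver first cancels the job group that `call()` registered, which kills any stages the broadcast job started (their tasks are interrupted because the group was set up with `interruptOnCancel = true`), and then interrupts the driver-side builder thread via `relationFuture.cancel(true)`. The job-group mechanism can be exercised on its own; a minimal sketch against a local `SparkSession` (the group id and app name below are arbitrary labels, not Spark API):

```scala
import org.apache.spark.sql.SparkSession

object JobGroupCancelSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cancel-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Run a slow job in a background thread under a named job group, with
    // interruptOnCancel = true so its task threads get interrupted on cancel.
    val runner = new Thread {
      override def run(): Unit = {
        sc.setJobGroup("demo-group", "slow job", interruptOnCancel = true)
        try sc.parallelize(1 to 4, 4).map { i => Thread.sleep(30000); i }.count()
        catch { case e: Exception => println(s"job failed: ${e.getMessage}") }
      }
    }
    runner.start()

    Thread.sleep(2000)              // let the job get scheduled
    sc.cancelJobGroup("demo-group") // cancels the running stages of that group
    runner.join()
    spark.stop()
  }
}
```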
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import java.util.concurrent.{CountDownLatch, TimeUnit}

import org.apache.spark.SparkException
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.HashedRelation
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class BroadcastExchangeSuite extends SparkPlanTest with SharedSQLContext {

  import testImplicits._

  test("BroadcastExchange should cancel the job group if timeout") {
    val startLatch = new CountDownLatch(1)
    val endLatch = new CountDownLatch(1)
    var jobEvents: Seq[SparkListenerEvent] = Seq.empty[SparkListenerEvent]
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
        jobEvents :+= jobEnd
        endLatch.countDown()
      }
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        jobEvents :+= jobStart
        startLatch.countDown()
      }
    })

    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "0") {
      val df = spark.range(100).join(spark.range(15).as[Long].map { x =>
        Thread.sleep(5000)
        x
      }).where("id = value")

      // Get the broadcast exchange physical plan.
      val hashExchange = df.queryExecution.executedPlan
        .collect { case p: BroadcastExchangeExec => p }.head

      // Materialize the future and wait for the job to be scheduled.
      hashExchange.prepare()
      startLatch.await(5, TimeUnit.SECONDS)

      // Check that the timeout exception is raised by just executing the exchange.
      val hashEx = intercept[SparkException] {
        hashExchange.executeBroadcast[HashedRelation]()
      }
      assert(hashEx.getMessage.contains("Could not execute broadcast"))

      // Wait until the cancellation is posted, then check the results.
      endLatch.await(5, TimeUnit.SECONDS)
      assert(jobCancelled())
    }

    def jobCancelled(): Boolean = {
      val events = jobEvents.toArray
      val hasStart = events(0).isInstanceOf[SparkListenerJobStart]
      val hasCancelled = events(1).asInstanceOf[SparkListenerJobEnd].jobResult
        .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job group")
      events.length == 2 && hasStart && hasCancelled
    }
  }

  test("set broadcastTimeout to -1") {
    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "-1") {
      val df = spark.range(1).toDF()
      val joinDF = df.join(broadcast(df), "id")
      val broadcastExchangeExec = joinDF.queryExecution.executedPlan
        .collect { case p: BroadcastExchangeExec => p }
      assert(broadcastExchangeExec.size == 1, "one and only one BroadcastExchangeExec")
      assert(joinDF.collect().length == 1)
    }
  }
}
