[SPARK-49057][SQL] Do not block the AQE loop when submitting query stages

### What changes were proposed in this pull request?

We missed the fact that submitting a shuffle or broadcast query stage can be heavy, as it needs to submit subqueries and wait for the results. This blocks the AQE loop and hurts the parallelism of AQE.

This PR fixes the problem by using shuffle/broadcast's own thread pool to wait for subqueries and other preparations.

This PR also re-implements #45234 to avoid submitting the shuffle job if the query has failed and all query stages need to be cancelled.
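
As an illustration of the pattern, here is a minimal sketch (not the actual Spark code; `prepPool`, `runPreparation`, and `submitJob` are hypothetical stand-ins): the expensive preparation runs on the exchange's own thread pool and a `Promise` hands the result back, so the caller returns immediately and can keep submitting other stages.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future, Promise}

object NonBlockingSubmitSketch {
  // A dedicated pool for heavy preparation work, standing in for the
  // shuffle/broadcast exchange thread pools used by this PR.
  private val prepPool: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  // Hypothetical stand-ins for subquery execution and real job submission.
  private def runPreparation(): Unit = Thread.sleep(100)
  private def submitJob(): Future[Long] = Future.successful(42L)

  // Returns at once: preparation and submission happen on prepPool, so a
  // driver-side loop calling this for many stages is never blocked.
  def submitStage(): Future[Long] = {
    val promise = Promise[Long]()
    Future {
      try {
        runPreparation()                  // e.g. wait for subqueries
        promise.completeWith(submitJob()) // then submit the real job
      } catch {
        case e: Throwable => promise.tryFailure(e)
      }
    }(prepPool)
    promise.future
  }
}
```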

### Why are the changes needed?

better parallelism for AQE

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

new test case

### Was this patch authored or co-authored using generative AI tooling?

no

Closes #47533 from cloud-fan/aqe.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan committed Aug 5, 2024
1 parent 94f8872 commit f01eafd
Showing 7 changed files with 149 additions and 118 deletions.
@@ -170,6 +170,16 @@ object StaticSQLConf {
     .intConf
     .createWithDefault(1000)
 
+  val SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD =
+    buildStaticConf("spark.sql.shuffleExchange.maxThreadThreshold")
+      .internal()
+      .doc("The maximum degree of parallelism for doing preparation of shuffle exchange, " +
+        "which includes subquery execution, file listing, etc.")
+      .version("4.0.0")
+      .intConf
+      .checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
+      .createWithDefault(1024)
+
   val BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD =
     buildStaticConf("spark.sql.broadcastExchange.maxThreadThreshold")
       .internal()
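
The new threshold is an internal, static conf, so it can only be set when the session is created; a usage sketch (the value 256 is illustrative):

```scala
import org.apache.spark.sql.SparkSession

// Static SQL confs are fixed at session creation time. This caps the
// shuffle-exchange preparation pool at 256 threads instead of the default 1024.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.shuffleExchange.maxThreadThreshold", "256")
  .getOrCreate()
```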
@@ -194,7 +194,7 @@ case class ShuffleQueryStageExec(
 
   def advisoryPartitionSize: Option[Long] = shuffle.advisoryPartitionSize
 
-  override protected def doMaterialize(): Future[Any] = shuffle.submitShuffleJob
+  override protected def doMaterialize(): Future[Any] = shuffle.submitShuffleJob()
 
   override def newReuseInstance(
       newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
@@ -240,7 +240,7 @@ case class BroadcastQueryStageExec(
     throw SparkException.internalError(s"wrong plan for broadcast stage:\n ${plan.treeString}")
   }
 
-  override protected def doMaterialize(): Future[Any] = broadcast.submitBroadcastJob
+  override protected def doMaterialize(): Future[Any] = broadcast.submitBroadcastJob()
 
   override def newReuseInstance(
       newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.internal.MDC
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -61,23 +61,49 @@ trait BroadcastExchangeLike extends Exchange {
    */
   def relationFuture: Future[broadcast.Broadcast[Any]]
 
+  @transient
+  private lazy val promise = Promise[Unit]()
+
+  @transient
+  private lazy val scalaFuture: scala.concurrent.Future[Unit] = promise.future
+
+  @transient
+  private lazy val triggerFuture: Future[Any] = {
+    SQLExecution.withThreadLocalCaptured(session, BroadcastExchangeExec.executionContext) {
+      try {
+        // Trigger broadcast preparation which can involve expensive operations like waiting on
+        // subqueries and file listing.
+        executeQuery(null)
+        promise.trySuccess(())
+      } catch {
+        case e: Throwable =>
+          promise.tryFailure(e)
+          throw e
+      }
+    }
+  }
+
+  protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
+
   /**
    * The asynchronous job that materializes the broadcast. It's used for registering callbacks on
    * `relationFuture`. Note that calling this method may not start the execution of broadcast job.
    * It also does the preparations work, such as waiting for the subqueries.
    */
-  final def submitBroadcastJob: scala.concurrent.Future[broadcast.Broadcast[Any]] = executeQuery {
-    materializationStarted.set(true)
-    completionFuture
+  final def submitBroadcastJob(): scala.concurrent.Future[broadcast.Broadcast[Any]] = {
+    triggerFuture
+    scalaFuture.flatMap { _ =>
+      RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
+        completionFuture
+      }
+    }(BroadcastExchangeExec.executionContext)
   }
 
-  protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]
-
   /**
    * Cancels broadcast job with an optional reason.
    */
   final def cancelBroadcastJob(reason: Option[String]): Unit = {
-    if (isMaterializationStarted() && !this.relationFuture.isDone) {
+    if (!this.relationFuture.isDone) {
       reason match {
         case Some(r) => sparkContext.cancelJobsWithTag(this.jobTag, r)
         case None => sparkContext.cancelJobsWithTag(this.jobTag)
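
The lazy-trigger idiom above can be distilled as follows (a sketch with hypothetical names, not Spark code). Because `triggerFuture` is a `lazy val`, the first `submitBroadcastJob()` call starts the preparation exactly once, and every later call reuses the same completion future:

```scala
import scala.concurrent.{ExecutionContext, Future, Promise}

class LazyTrigger(prep: () => Unit)(implicit ec: ExecutionContext) {
  private val promise = Promise[Unit]()
  // Lazy: the first access starts `prep` exactly once; later accesses reuse it.
  private lazy val trigger: Future[Unit] = Future {
    try {
      prep()
      promise.trySuccess(())
    } catch {
      case e: Throwable =>
        promise.tryFailure(e)
        throw e
    }
  }
  def submit(): Future[Unit] = {
    trigger        // force the lazy val; idempotent across callers
    promise.future // every caller observes the same outcome
  }
}
```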
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import java.util.concurrent.atomic.AtomicBoolean
-
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -36,17 +34,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  * "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe.
  */
 abstract class Exchange extends UnaryExecNode {
-  /**
-   * This flag aims to detect if the stage materialization is started. This helps
-   * to avoid unnecessary AQE stage materialization when the stage is canceled.
-   */
-  protected val materializationStarted = new AtomicBoolean()
-
-  /**
-   * Exposes status if the materialization is started
-   */
-  def isMaterializationStarted(): Boolean = materializationStarted.get()
-
   override def output: Seq[Attribute] = child.output
   final override val nodePatterns: Seq[TreePattern] = Seq(EXCHANGE)
 
@@ -17,14 +17,15 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import java.util.concurrent.atomic.AtomicReference
 import java.util.function.Supplier
 
 import scala.collection.mutable
-import scala.concurrent.Future
+import scala.concurrent.{ExecutionContext, Future, Promise}
 
 import org.apache.spark._
 import org.apache.spark.internal.config
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -37,25 +38,15 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.MutablePair
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.util.{MutablePair, ThreadUtils}
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * Common trait for all shuffle exchange implementations to facilitate pattern matching.
  */
 trait ShuffleExchangeLike extends Exchange {
-
-  /**
-   * The asynchronous job that materializes the shuffle. It also does the preparations work,
-   * such as waiting for the subqueries.
-   */
-  @transient private lazy val shuffleFuture: Future[MapOutputStatistics] = executeQuery {
-    materializationStarted.set(true)
-    mapOutputStatisticsFuture
-  }
-
   /**
    * Returns the number of mappers of this shuffle.
    */
@@ -76,26 +67,72 @@
    */
   def shuffleOrigin: ShuffleOrigin
 
+  @transient
+  private lazy val promise = Promise[MapOutputStatistics]()
+
+  @transient
+  private lazy val completionFuture
+    : scala.concurrent.Future[MapOutputStatistics] = promise.future
+
+  @transient
+  private[sql] // Exposed for testing
+  val futureAction = new AtomicReference[Option[FutureAction[MapOutputStatistics]]](None)
+
+  @transient
+  private var isCancelled: Boolean = false
+
+  @transient
+  private lazy val triggerFuture: java.util.concurrent.Future[Any] = {
+    SQLExecution.withThreadLocalCaptured(session, ShuffleExchangeExec.executionContext) {
+      try {
+        // Trigger shuffle preparation which can involve expensive operations like waiting on
+        // subqueries and file listing.
+        executeQuery(null)
+        // Submit shuffle job if not cancelled.
+        this.synchronized {
+          if (isCancelled) {
+            promise.tryFailure(new SparkException("Shuffle cancelled."))
+          } else {
+            val shuffleJob = RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
+              mapOutputStatisticsFuture
+            }
+            shuffleJob match {
+              case action: FutureAction[MapOutputStatistics] => futureAction.set(Some(action))
+              case _ =>
+            }
+            promise.completeWith(shuffleJob)
+          }
+        }
+        null
+      } catch {
+        case e: Throwable =>
+          promise.tryFailure(e)
+          throw e
+      }
+    }
+  }
+
   /**
-   * Submits the shuffle job.
+   * The asynchronous job that materializes the shuffle. It also does the preparations work,
+   * such as waiting for the subqueries.
    */
-  final def submitShuffleJob: Future[MapOutputStatistics] = shuffleFuture
-
-  protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+  final def submitShuffleJob(): Future[MapOutputStatistics] = {
+    triggerFuture
+    completionFuture
+  }
 
   /**
    * Cancels the shuffle job with an optional reason.
    */
-  final def cancelShuffleJob(reason: Option[String]): Unit = {
-    if (isMaterializationStarted()) {
-      shuffleFuture match {
-        case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
-          action.cancel(reason)
-        case _ =>
-      }
+  final def cancelShuffleJob(reason: Option[String]): Unit = this.synchronized {
+    if (!isCancelled) {
+      isCancelled = true
+      futureAction.get().foreach(_.cancel(reason))
     }
   }
 
+  protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]
+
   /**
    * Returns the shuffle RDD with specified partition specs.
    */
@@ -231,6 +268,10 @@ case class ShuffleExchangeExec(
 
 object ShuffleExchangeExec {
 
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("shuffle-exchange",
+      SQLConf.get.getConf(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD)))
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
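
The cancel-before-submit handshake in `triggerFuture` and `cancelShuffleJob` above can be distilled to this sketch (hypothetical names, plain Scala futures instead of Spark's `FutureAction`). Both paths take the same lock, so a cancellation that wins the race keeps the shuffle job from ever being submitted, which is the behaviour #45234 originally aimed for:

```scala
import scala.concurrent.{ExecutionContext, Future, Promise}

class CancellableSubmit[T](submitJob: () => Future[T])(implicit ec: ExecutionContext) {
  private val promise = Promise[T]()
  private var isCancelled = false // guarded by `this`

  def start(): Future[T] = {
    Future {
      // Same lock as cancel(): either we observe the cancellation and skip
      // submission entirely, or we submit before cancel() can run.
      this.synchronized {
        if (isCancelled) {
          promise.tryFailure(new RuntimeException("cancelled before submit"))
        } else {
          promise.completeWith(submitJob())
        }
      }
    }
    promise.future
  }

  def cancel(): Unit = this.synchronized {
    if (!isCancelled) {
      isCancelled = true
      // In Spark, an already-submitted FutureAction would also be cancelled here.
    }
  }
}
```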
@@ -1006,7 +1006,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
     delegate.shuffleOrigin
   }
   override def mapOutputStatisticsFuture: Future[MapOutputStatistics] =
-    delegate.submitShuffleJob
+    delegate.submitShuffleJob()
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] =
     delegate.getShuffleRDD(partitionSpecs)
   override def runtimeStatistics: Statistics = {
@@ -1032,7 +1032,7 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
   override val runId: UUID = delegate.runId
   override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] =
     delegate.relationFuture
-  override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob
+  override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob()
   override def runtimeStatistics: Statistics = delegate.runtimeStatistics
   override def child: SparkPlan = delegate.child
   override protected def doPrepare(): Unit = delegate.prepare()