Skip to content

Commit

Permalink
SPARK-47148 - Refactor AQE APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
erenavsarogullari committed Apr 29, 2024
1 parent 7950208 commit 1105282
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

package org.apache.spark.sql.execution.adaptive

import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.Future

import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException}
import org.apache.spark.{MapOutputStatistics, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -56,25 +56,13 @@ abstract class QueryStageExec extends LeafExecNode {
*/
val name: String = s"${this.getClass.getSimpleName}-$id"

/**
* This flag aims to detect if the stage materialization is started. This helps
* to avoid unnecessary stage materialization when the stage is canceled.
*/
private val materializationStarted = new AtomicBoolean()

/**
* Exposes status if the materialization is started
*/
def isMaterializationStarted(): Boolean = materializationStarted.get()

/**
* Materialize this query stage, to prepare for the execution, like submitting map stages,
* broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
* stage is ready.
*/
final def materialize(): Future[Any] = {
logDebug(s"Materialize query stage: $name")
materializationStarted.set(true)
doMaterialize()
}

Expand Down Expand Up @@ -168,7 +156,12 @@ abstract class ExchangeQueryStageExec extends QueryStageExec {
/**
* Cancel the stage materialization if in progress; otherwise do nothing.
*/
def cancel(): Unit
final def cancel(): Unit = {
// Log the request once here, then hand off to the stage-specific doCancel().
logDebug(s"Cancel query stage: $name")
doCancel()
}

/**
 * Stage-specific cancellation hook invoked by [[cancel]] after the cancel request is logged.
 * Implementations cancel the job that backs this stage.
 */
protected def doCancel(): Unit

/**
* The canonicalized plan before applying query stage optimizer rules.
Expand Down Expand Up @@ -201,9 +194,7 @@ case class ShuffleQueryStageExec(

def advisoryPartitionSize: Option[Long] = shuffle.advisoryPartitionSize

@transient private lazy val shuffleFuture = shuffle.submitShuffleJob

override protected def doMaterialize(): Future[Any] = shuffleFuture
override protected def doMaterialize(): Future[Any] = shuffle.submitShuffleJob

override def newReuseInstance(
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
Expand All @@ -215,16 +206,7 @@ case class ShuffleQueryStageExec(
reuse
}

override def cancel(): Unit = {
if (isMaterializationStarted()) {
shuffleFuture match {
case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
action.cancel()
logInfo(s"$name is cancelled.")
case _ =>
}
}
}
// Cancellation is delegated to the shuffle exchange's cancelShuffleJob.
override protected def doCancel(): Unit = shuffle.cancelShuffleJob

/**
* Returns the Option[MapOutputStatistics]. If the shuffle map stage has no partition,
Expand Down Expand Up @@ -258,9 +240,7 @@ case class BroadcastQueryStageExec(
throw SparkException.internalError(s"wrong plan for broadcast stage:\n ${plan.treeString}")
}

override protected def doMaterialize(): Future[Any] = {
broadcast.submitBroadcastJob
}
override protected def doMaterialize(): Future[Any] = broadcast.submitBroadcastJob

override def newReuseInstance(
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
Expand All @@ -272,13 +252,7 @@ case class BroadcastQueryStageExec(
reuse
}

override def cancel(): Unit = {
if (isMaterializationStarted() && !broadcast.relationFuture.isDone) {
sparkContext.cancelJobsWithTag(broadcast.jobTag)
broadcast.relationFuture.cancel(true)
logInfo(s"$name is cancelled.")
}
}
// Cancellation is delegated to the broadcast exchange's cancelBroadcastJob().
override protected def doCancel(): Unit = broadcast.cancelBroadcastJob()

override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,22 @@ trait BroadcastExchangeLike extends Exchange {
* It also does the preparations work, such as waiting for the subqueries.
*/
final def submitBroadcastJob: scala.concurrent.Future[broadcast.Broadcast[Any]] = executeQuery {
materializationStarted.set(true)
completionFuture
}

protected def completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]]

/**
* Cancels broadcast job.
*/
final def cancelBroadcastJob(): Unit = {
  // Nothing to cancel unless the job was submitted and its future is still pending.
  val stillRunning = isMaterializationStarted() && !relationFuture.isDone
  if (stillRunning) {
    // Cancel the Spark jobs carrying this exchange's tag, then cancel the future itself.
    sparkContext.cancelJobsWithTag(jobTag)
    relationFuture.cancel(true)
  }
}

/**
* Returns the runtime statistics after broadcast materialization.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.exchange

import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
Expand All @@ -34,6 +36,17 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
* "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe.
*/
abstract class Exchange extends UnaryExecNode {
/**
* Tracks whether the stage materialization has been started. Cancellation checks this flag to
* avoid unnecessarily materializing an AQE stage that is being canceled.
*/
protected val materializationStarted = new AtomicBoolean()

/**
* Returns whether the materialization has been started.
*/
def isMaterializationStarted(): Boolean = materializationStarted.get()

override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(EXCHANGE)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ import org.apache.spark.util.random.XORShiftRandom
*/
trait ShuffleExchangeLike extends Exchange {

/**
* The asynchronous job that materializes the shuffle. It also does the preparation work,
* such as waiting for the subqueries.
*/
@transient private lazy val shuffleFuture: Future[MapOutputStatistics] = executeQuery {
// Record that submission has begun; cancelShuffleJob checks this flag before using the future.
materializationStarted.set(true)
mapOutputStatisticsFuture
}

/**
* Returns the number of mappers of this shuffle.
*/
Expand All @@ -68,15 +77,25 @@ trait ShuffleExchangeLike extends Exchange {
def shuffleOrigin: ShuffleOrigin

/**
* The asynchronous job that materializes the shuffle. It also does the preparations work,
* such as waiting for the subqueries.
* Submits the shuffle job.
*/
final def submitShuffleJob: Future[MapOutputStatistics] = executeQuery {
mapOutputStatisticsFuture
}
final def submitShuffleJob: Future[MapOutputStatistics] = shuffleFuture

protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]

/**
* Cancels the shuffle job.
*/
final def cancelShuffleJob: Unit = {
  // Check the flag first: `shuffleFuture` is lazy, so it must not be evaluated here unless
  // submission has already begun.
  val submitted = if (isMaterializationStarted()) Some(shuffleFuture) else None
  // Only a FutureAction that has not completed yet is cancelled; anything else is left alone.
  submitted
    .collect { case action: FutureAction[MapOutputStatistics] if !action.isCompleted => action }
    .foreach(_.cancel())
}

/**
* Returns the shuffle RDD with specified partition specs.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,15 +930,15 @@ class AdaptiveQueryExecSuite
s"3 ShuffleQueryStages. Physical Plan: $adaptivePlan")
shuffleQueryStageExecs.foreach(sqse => assert(sqse.name.contains("ShuffleQueryStageExec-")))
// First ShuffleQueryStage is materialized so it needs to be canceled.
assert(shuffleQueryStageExecs(0).isMaterializationStarted(),
assert(shuffleQueryStageExecs(0).shuffle.isMaterializationStarted(),
"Materialization should be started.")
// The second ShuffleQueryStage's materialization failed, so
// it is excluded from the cancellation due to earlyFailedStage.
assert(shuffleQueryStageExecs(1).isMaterializationStarted(),
assert(shuffleQueryStageExecs(1).shuffle.isMaterializationStarted(),
"Materialization should be started but it is failed.")
// The last ShuffleQueryStage is not materialized yet, so it does not need
// to be canceled and is simply skipped during the cancellation.
assert(!shuffleQueryStageExecs(2).isMaterializationStarted(),
assert(!shuffleQueryStageExecs(2).shuffle.isMaterializationStarted(),
"Materialization should not be started.")
}
} finally {
Expand Down Expand Up @@ -976,7 +976,7 @@ class AdaptiveQueryExecSuite
broadcastQueryStageExecs.foreach { bqse =>
assert(bqse.name.contains("BroadcastQueryStageExec-"))
// Both BroadcastQueryStages are materialized at the beginning.
assert(bqse.isMaterializationStarted(),
assert(bqse.broadcast.isMaterializationStarted(),
s"${bqse.name}' s materialization should be started.")
}
}
Expand Down

0 comments on commit 1105282

Please sign in to comment.