Skip to content

Commit

Permalink
SPARK-47148 - Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
erenavsarogullari committed Mar 1, 2024
1 parent 4639938 commit 53dd089
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -791,8 +791,7 @@ case class AdaptiveSparkPlanExec(
// so we should avoid calling cancel on it to re-trigger the failure again.
case s: ExchangeQueryStageExec if !earlyFailedStage.contains(s.id) =>
try {
val status = s.cancel()
logInfo(s"$s cancellation has ended with status: $status")
s.cancel()
} catch {
case NonFatal(t) =>
logError(s"Exception in cancelling query stage: ${s.treeString}", t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.StageCancellationStatus.{StageCancellationStatus, UNPERFORMED}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -154,19 +153,17 @@ abstract class ExchangeQueryStageExec extends QueryStageExec {
* unnecessary stage materialization when the stage is canceled. This flag aims to detect
* this kind of cases and skip unnecessary stage materialization.
*/
@transient protected val isMaterializationStarted = new AtomicBoolean()

@transient protected var stageCancellationStatus: StageCancellationStatus = UNPERFORMED
@transient protected val materializationStarted = new AtomicBoolean()

/**
* Exposes the stage cancellation status.
* Exposes status if the materialization is started
*/
def getCancellationStatus(): StageCancellationStatus = stageCancellationStatus
def isMaterializationStarted() = materializationStarted.get()

/**
* Cancel the stage materialization if in progress; otherwise do nothing.
*/
def cancel(): StageCancellationStatus
def cancel(): Unit

/**
* The canonicalized plan before applying query stage optimizer rules.
Expand All @@ -178,11 +175,6 @@ abstract class ExchangeQueryStageExec extends QueryStageExec {
def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec
}

object StageCancellationStatus extends Enumeration {
type StageCancellationStatus = Value
val UNPERFORMED, COMPLETED, SKIPPED = Value
}

/**
* A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
*
Expand All @@ -205,7 +197,7 @@ case class ShuffleQueryStageExec(
def advisoryPartitionSize: Option[Long] = shuffle.advisoryPartitionSize

@transient private lazy val shuffleFuture = {
isMaterializationStarted.set(true)
materializationStarted.set(true)
shuffle.submitShuffleJob
}

Expand All @@ -221,18 +213,16 @@ case class ShuffleQueryStageExec(
reuse
}

override def cancel(): StageCancellationStatus = {
if (isMaterializationStarted.get()) {
override def cancel(): Unit = {
if (isMaterializationStarted()) {
shuffleFuture match {
case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
case action: FutureAction[MapOutputStatistics] if !action.isCompleted => {
action.cancel()
stageCancellationStatus = StageCancellationStatus.COMPLETED
logInfo(s"${this.getClass.getSimpleName}-$id is cancelled.")
}
case _ =>
}
} else {
stageCancellationStatus = StageCancellationStatus.SKIPPED
}
stageCancellationStatus
}

/**
Expand Down Expand Up @@ -268,7 +258,7 @@ case class BroadcastQueryStageExec(
}

@transient private lazy val broadcastFuture = {
isMaterializationStarted.set(true)
materializationStarted.set(true)
broadcast.submitBroadcastJob
}

Expand All @@ -286,15 +276,12 @@ case class BroadcastQueryStageExec(
reuse
}

override def cancel(): StageCancellationStatus = {
if (isMaterializationStarted.get() && !broadcast.relationFuture.isDone) {
override def cancel(): Unit = {
if (isMaterializationStarted() && !broadcast.relationFuture.isDone) {
sparkContext.cancelJobsWithTag(broadcast.jobTag)
broadcast.relationFuture.cancel(true)
stageCancellationStatus = StageCancellationStatus.COMPLETED
} else {
stageCancellationStatus = StageCancellationStatus.SKIPPED
logInfo(s"${this.getClass.getSimpleName}-$id is cancelled.")
}
stageCancellationStatus
}

override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -939,20 +939,21 @@ class AdaptiveQueryExecSuite
}
assert(shuffleQueryStageExecs.length == 3, s"$adaptivePlan")
// First ShuffleQueryStage is materialized so it needs to be canceled.
assert(shuffleQueryStageExecs(0).getCancellationStatus() ==
StageCancellationStatus.COMPLETED)
assert(shuffleQueryStageExecs(0).isMaterializationStarted(),
"Materialization should be started.")
// Second ShuffleQueryStage materialization is failed so
// it is earlyFailedStage and it is excluded from the cancellation.
assert(shuffleQueryStageExecs(1).getCancellationStatus() ==
StageCancellationStatus.UNPERFORMED)
// Last ShuffleQueryStage is not materialized so it does not require to be canceled.
assert(shuffleQueryStageExecs(2).getCancellationStatus() ==
StageCancellationStatus.SKIPPED)
// it is excluded from the cancellation due to earlyFailedStage.
assert(shuffleQueryStageExecs(1).isMaterializationStarted(),
"Materialization should be started but it is failed.")
// Last ShuffleQueryStage is not materialized yet so it does not require
// to be canceled and it is just skipped from the cancellation.
assert(!shuffleQueryStageExecs(2).isMaterializationStarted(),
"Materialization should not be started.")
}
}
}

test("SPARK-47148: Check AQE BroadcastQueryStages cancellation status") {
test("SPARK-47148: Check if BroadcastQueryStage materialization is started") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
Expand Down Expand Up @@ -992,12 +993,11 @@ class AdaptiveQueryExecSuite
case r: BroadcastQueryStageExec => r
}
assert(broadcastQueryStageExecs.length == 3, s"$adaptivePlan")
assert(broadcastQueryStageExecs(0).getCancellationStatus() ==
StageCancellationStatus.SKIPPED)
assert(broadcastQueryStageExecs(1).getCancellationStatus() ==
StageCancellationStatus.COMPLETED)
assert(broadcastQueryStageExecs(2).getCancellationStatus() ==
StageCancellationStatus.COMPLETED)
broadcastQueryStageExecs.foreach { bqse =>
assert(bqse.isMaterializationStarted(),
s"${bqse.getClass.getName}-${bqse.id}' s materialization should be started before " +
"BroadcastQueryStage-1' s materialization is failed.")
}
}
}
}
Expand Down

0 comments on commit 53dd089

Please sign in to comment.