SPARK-47148 - Minor refactoring and define QueryStage name
erenavsarogullari committed Mar 6, 2024
1 parent ef8c50e commit 355aeb0
Showing 2 changed files with 55 additions and 30 deletions.
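This commit hoists the materializationStarted flag and isMaterializationStarted() from ExchangeQueryStageExec up into the QueryStageExec base class, sets the flag once inside materialize() instead of in each subclass's lazy future, and introduces a name field (simple class name plus stage id) that the debug/info logs and assertion messages now reuse. The sketch below only illustrates that naming and flag pattern; the Demo* classes are hypothetical stand-ins, not the actual Spark operators.

import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical miniature of the refactored base class: the name is derived once
// from the concrete class and stage id, and materialize() records that
// materialization has started before delegating to doMaterialize().
abstract class DemoQueryStage(val id: Int) {
  // Unique within a single query plan, e.g. "DemoShuffleStage-0".
  val name: String = s"${this.getClass.getSimpleName}-$id"

  private val materializationStarted = new AtomicBoolean()

  def isMaterializationStarted(): Boolean = materializationStarted.get()

  final def materialize(): Unit = {
    println(s"Materialize query stage: $name")
    materializationStarted.set(true)
    doMaterialize()
  }

  protected def doMaterialize(): Unit
}

class DemoShuffleStage(id: Int) extends DemoQueryStage(id) {
  override protected def doMaterialize(): Unit = ()
}

object DemoApp extends App {
  val stage = new DemoShuffleStage(0)
  assert(stage.name == "DemoShuffleStage-0")   // unique, human-readable stage name
  stage.materialize()
  assert(stage.isMaterializationStarted())     // flag is set by materialize()
}

Deriving the name once in the base class keeps messages such as "ShuffleQueryStageExec-3 is cancelled." consistent across shuffle, broadcast, and table-cache stages, and the new suite case below asserts on exactly those names.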
@@ -51,13 +51,30 @@ abstract class QueryStageExec extends LeafExecNode {
    */
   val plan: SparkPlan
 
+  /**
+   * Name of this query stage which is unique in the entire query plan.
+   */
+  val name: String = s"${this.getClass.getSimpleName}-$id"
+
+  /**
+   * This flag aims to detect if the stage materialization is started. This helps
+   * to avoid unnecessary stage materialization when the stage is canceled.
+   */
+  private val materializationStarted = new AtomicBoolean()
+
+  /**
+   * Exposes status if the materialization is started
+   */
+  def isMaterializationStarted(): Boolean = materializationStarted.get()
+
   /**
    * Materialize this query stage, to prepare for the execution, like submitting map stages,
    * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
    * stage is ready.
    */
   final def materialize(): Future[Any] = {
-    logDebug(s"Materialize query stage ${this.getClass.getSimpleName}: $id")
+    logDebug(s"Materialize query stage: $name")
+    materializationStarted.set(true)
     doMaterialize()
   }
 
@@ -148,17 +165,6 @@ abstract class QueryStageExec extends LeafExecNode {
  */
 abstract class ExchangeQueryStageExec extends QueryStageExec {
 
-  /**
-   * This flag aims to detect if the stage materialization is started. This helps
-   * to avoid unnecessary stage materialization when the stage is canceled.
-   */
-  @transient protected val materializationStarted = new AtomicBoolean()
-
-  /**
-   * Exposes status if the materialization is started
-   */
-  def isMaterializationStarted(): Boolean = materializationStarted.get()
-
   /**
    * Cancel the stage materialization if in progress; otherwise do nothing.
    */
@@ -195,10 +201,7 @@ case class ShuffleQueryStageExec(

   def advisoryPartitionSize: Option[Long] = shuffle.advisoryPartitionSize
 
-  @transient private lazy val shuffleFuture = {
-    materializationStarted.set(true)
-    shuffle.submitShuffleJob
-  }
+  @transient private lazy val shuffleFuture = shuffle.submitShuffleJob
 
   override protected def doMaterialize(): Future[Any] = shuffleFuture
 
@@ -217,7 +220,7 @@ case class ShuffleQueryStageExec(
       shuffleFuture match {
         case action: FutureAction[MapOutputStatistics] if !action.isCompleted =>
           action.cancel()
-          logInfo(s"${this.getClass.getSimpleName}-$id is cancelled.")
+          logInfo(s"$name is cancelled.")
         case _ =>
       }
     }
@@ -228,7 +231,7 @@ case class ShuffleQueryStageExec(
    * this method returns None, as there is no map statistics.
    */
   def mapStats: Option[MapOutputStatistics] = {
-    assert(resultOption.get().isDefined, s"${getClass.getSimpleName} should already be ready")
+    assert(resultOption.get().isDefined, s"$name should already be ready")
     val stats = resultOption.get().get.asInstanceOf[MapOutputStatistics]
     Option(stats)
   }
@@ -255,13 +258,8 @@ case class BroadcastQueryStageExec(
       throw SparkException.internalError(s"wrong plan for broadcast stage:\n ${plan.treeString}")
   }
 
-  @transient private lazy val broadcastFuture = {
-    materializationStarted.set(true)
-    broadcast.submitBroadcastJob
-  }
-
   override protected def doMaterialize(): Future[Any] = {
-    broadcastFuture
+    broadcast.submitBroadcastJob
   }
 
   override def newReuseInstance(
@@ -278,7 +276,7 @@ case class BroadcastQueryStageExec(
     if (isMaterializationStarted() && !broadcast.relationFuture.isDone) {
       sparkContext.cancelJobsWithTag(broadcast.jobTag)
       broadcast.relationFuture.cancel(true)
-      logInfo(s"${this.getClass.getSimpleName}-$id is cancelled.")
+      logInfo(s"$name is cancelled.")
     }
   }
 
@@ -903,8 +903,7 @@ class AdaptiveQueryExecSuite
       SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTable("bucketed_table1", "bucketed_table2", "bucketed_table3") {
-        val df =
-          (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+        val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
         df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table1")
         df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table2")
         df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table3")
@@ -958,8 +957,7 @@ class AdaptiveQueryExecSuite
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
       withTable("bucketed_table1", "bucketed_table2", "bucketed_table3") {
-        val df =
-          (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+        val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
         df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table1")
         df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table2")
         df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table3")
@@ -995,13 +993,42 @@ class AdaptiveQueryExecSuite
         assert(broadcastQueryStageExecs.length == 3, s"$adaptivePlan")
         broadcastQueryStageExecs.foreach { bqse =>
           assert(bqse.isMaterializationStarted(),
-            s"${bqse.getClass.getName}-${bqse.id}' s materialization should be started before " +
+            s"${bqse.name}' s materialization should be started before " +
               "BroadcastQueryStage-1' s materialization is failed.")
         }
       }
     }
   }
 
+  test("SPARK-47148: Check AQE QueryStages names") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
+      withTable("bucketed_table1", "bucketed_table2") {
+        val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+        df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table1")
+        df.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table2")
+
+        val df1 = spark.table("bucketed_table1").persist()
+        val df2 = spark.table("bucketed_table2").persist()
+        val joinedDF = df1.join(df2, Seq("i", "j", "k")).join(df1, Seq("i"))
+          .repartition(5).sort("i")
+        joinedDF.collect()
+
+        // Verify QueryStageExecs names
+        val adaptivePlanOfJoinedDF =
+          joinedDF.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+        val queryStageExecs = collect(adaptivePlanOfJoinedDF) {
+          case qse: QueryStageExec => qse
+        }
+        assert(queryStageExecs.size == 7, s"$adaptivePlanOfJoinedDF")
+        assert(queryStageExecs.filter(_.name.contains("TableCacheQueryStageExec-")).size == 3)
+        assert(queryStageExecs.filter(_.name.contains("BroadcastQueryStageExec-")).size == 2)
+        assert(queryStageExecs.filter(_.name.contains("ShuffleQueryStageExec-")).size == 2)
+      }
+    }
+  }
+
   test("SPARK-30403: AQE should handle InSubquery") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",