[CARMEL-6113] Support BHJ fallback to SMJ in AQE (#1042)
* [CARMEL-6113] Support BHJ fallback to SMJ in AQE

* fix code style

* Add log if FallbackBroadcastStage takes effect
xingchaozh authored and GitHub Enterprise committed Aug 19, 2022
1 parent 20fb439 commit e12e8d3
Showing 9 changed files with 371 additions and 22 deletions.
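
For context: under adaptive query execution (AQE), a broadcast hash join (BHJ) whose build side turns out too large at runtime previously failed the whole query. This change lets AQE catch the typed broadcast failure and re-plan the join as a sort-merge join (SMJ). A minimal sketch of exercising the feature, assuming a build of this fork with the conf key added below (table names are illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  // AQE must be enabled for stage-level fallback to apply.
  .config("spark.sql.adaptive.enabled", "true")
  // Conf introduced by this patch: re-plan a failed broadcast stage as SMJ.
  .config("spark.sql.applyStageFallbackPlan.enabled", "true")
  .getOrCreate()

// If the planner picks a BHJ but the build side exceeds the broadcast
// row/size limits at runtime, BroadcastExchangeExec now throws
// CannotBroadcastTableException and AQE retries the join as an SMJ.
spark.sql("SELECT * FROM t1 JOIN t2 ON t1.a = t2.a").show()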
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkException.scala
@@ -122,3 +122,9 @@ private[spark] case class CarmelUnsupportedOperationException(
s"CMR-${errorCode.id} $message"
}
}

/**
* Exception thrown when the table cannot be broadcast.
*/
private[spark] case class CannotBroadcastException(message: String)
extends SparkException(message)
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1281,6 +1281,14 @@ object SQLConf
.booleanConf
.createWithDefault(false)

val AUTO_APPLY_STAGE_FALLBACK_PLAN_ENABLED =
buildConf("spark.sql.applyStageFallbackPlan.enabled")
.internal()
.doc("When true, we will try to apply fallback plan for the failed stage in AQE.")
.version("3.0.0")
.booleanConf
.createWithDefault(false)

val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled")
.internal()
.doc("When false, we will throw an error if a query contains a cartesian product without " +
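
Since this is a plain boolean SQLConf entry, it can also be toggled per session; a hedged example using the key name from the patch above:

// Enable stage fallback for the current session only.
spark.conf.set("spark.sql.applyStageFallbackPlan.enabled", "true")
// Equivalent SQL form:
spark.sql("SET spark.sql.applyStageFallbackPlan.enabled=true")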
sql/core/src/main/scala/org/apache/spark/sql/execution/CannotBroadcastTableException.scala (new file)
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.SparkException

/**
* Exception thrown when the table cannot be broadcast.
*/
private[spark] case class CannotBroadcastTableException(message: String)
extends SparkException(message)
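
A hedged sketch of caller-side handling, mirroring the pattern match AdaptiveSparkPlanExec performs later in this diff. Because the exception is private[spark], such code would have to live inside an org.apache.spark package; the demo object and materialize helper are hypothetical:

package org.apache.spark.sql.execution

object BroadcastFallbackDemo {
  // Stand-in for materializing a broadcast query stage.
  def materialize(): Unit =
    throw CannotBroadcastTableException("Cannot broadcast the table ...")

  def run(): Unit =
    try materialize() catch {
      case e: CannotBroadcastTableException =>
        // Typed failure: the build side is too big to broadcast, so the
        // caller can fall back to a non-broadcast join instead of failing.
        println(s"Broadcast failed, falling back to SMJ: ${e.getMessage}")
    }
}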
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
Expand Up @@ -32,9 +32,9 @@ import org.apache.spark.scheduler.RepeatableIterator
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.QueryPlanningContext
import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, Literal}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
@@ -228,11 +228,19 @@ case class AdaptiveSparkPlanExec(
val nextMsg = events.take()
val rem = new util.ArrayList[StageMaterializationEvent]()
events.drainTo(rem)

var stagesToFallback = Seq.empty[(QueryStageExec, Throwable, FallbackStageHandler)]
(Seq(nextMsg) ++ rem.asScala).foreach {
case StageSuccess(stage, res) =>
stage.resultOption.set(Some(res))
case StageFailure(stage, ex) =>
errors.append(ex)
ex match {
case CannotBroadcastTableException(_)
if conf.getConf(SQLConf.AUTO_APPLY_STAGE_FALLBACK_PLAN_ENABLED) =>
stagesToFallback =
Seq((stage, ex, FallbackBroadcastStage)) ++ stagesToFallback
case _ => errors.append(ex)
}
}

// In case of errors, we cancel all running stages and throw exception.
@@ -251,17 +259,30 @@
// the current physical plan. Once a new plan is adopted and both logical and physical
// plans are updated, we can clear the query stage list because at this point the two plans
// are semantically and physically in sync again.
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
val (newPhysicalPlan, newLogicalPlan) = reOptimize(logicalPlan)
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (deserveNewCost(origCost, newCost) ||

def optimize(logicalPlan: LogicalPlan, consideringCost: Boolean): Unit = {
val (newPhysicalPlan, newLogicalPlan) = reOptimize(logicalPlan)
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if ((!consideringCost && currentPhysicalPlan != newPhysicalPlan) ||
deserveNewCost(origCost, newCost) ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]
logOnLevel(s"Plan changed from $currentPhysicalPlan to $newPhysicalPlan")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]
}
}

val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan,
stagesToReplace)
optimize(logicalPlan, consideringCost = true)

if (stagesToFallback.nonEmpty) {
val logicalPlanWithoutFallback = replaceFallbackStages(logicalPlan,
stagesToFallback)
optimize(logicalPlanWithoutFallback, consideringCost = false)
}

// Now that some stages have finished, we can try creating new stages.
@@ -576,6 +597,54 @@ case class AdaptiveSparkPlanExec(
logicalPlan
}

private def replaceFallbackStages(
plan: LogicalPlan,
stagesToFallback: Seq[(QueryStageExec, Throwable, FallbackStageHandler)]): LogicalPlan = {
var logicalPlan = plan
stagesToFallback.foreach {
case (stage, ex, fallbackHandler) if currentPhysicalPlan.find(_.eq(stage)).isDefined =>
val logicalNodeOpt = stage.getTagValue(TEMP_LOGICAL_PLAN_TAG).orElse(stage.logicalLink)
assert(logicalNodeOpt.isDefined)
val logicalNode = logicalNodeOpt.get
val physicalNode = currentPhysicalPlan.collectFirst {
case p if p.eq(stage) ||
p.getTagValue(TEMP_LOGICAL_PLAN_TAG).exists(logicalNode.eq) ||
p.logicalLink.exists(logicalNode.eq) => p
}
assert(physicalNode.isDefined)

// Replace the corresponding logical node via fallback handler
val (newLogicalPlan, effectiveCount) =
fallbackHandler(logicalPlan, LogicalQueryStage(logicalNode, physicalNode.get))

if (effectiveCount == 0) {
cleanUpAndThrowException(stagesToFallback.map(_._2), None)
}

assert(newLogicalPlan != logicalPlan,
  s"logicalNode: $logicalNode; " +
  s"logicalPlan: $logicalPlan; " +
  s"physicalPlan: $currentPhysicalPlan; " +
  s"stage: $stage")
logicalPlan = newLogicalPlan
case _ =>
// Invalid stages
cleanUpAndThrowException(stagesToFallback.map(_._2), None)
}

// All stages were replaced successfully, so cancel them now.
stagesToFallback.foreach(s => {
try {
s._1.cancel()
} catch {
case NonFatal(t) =>
logError(s"Exception in cancelling query stage: ${s._1.treeString}", t)
}
})

logicalPlan
}

/**
* Re-optimize and run physical planning on the current logical plan based on the latest stats.
*/
@@ -718,3 +787,7 @@ case class StageSuccess(stage: QueryStageExec, result: Any) extends StageMaterializationEvent
* The materialization of a query stage hit an error and failed.
*/
case class StageFailure(stage: QueryStageExec, error: Throwable) extends StageMaterializationEvent

trait FallbackStageHandler {
def apply(logicalPlan: LogicalPlan, stage: LogicalQueryStage): (LogicalPlan, Int)
}
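
The trait above is the extension point for other fallback strategies. An illustrative no-op implementation (not part of this patch) shows the contract: the Int returned is the number of plan nodes the handler actually rewrote, and 0 signals that the fallback failed:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.adaptive.{FallbackStageHandler, LogicalQueryStage}

object NoOpFallbackHandler extends FallbackStageHandler {
  override def apply(
      logicalPlan: LogicalPlan,
      stage: LogicalQueryStage): (LogicalPlan, Int) = {
    // Leave the plan untouched and report zero effective rewrites, which
    // makes AdaptiveSparkPlanExec clean up and rethrow the original error.
    (logicalPlan, 0)
  }
}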
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/FallbackBroadcastStage.scala (new file)
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Join, LogicalPlan, NO_BROADCAST_HASH}

/**
 * This fallback handler detects a stage that has been proven impossible to broadcast
 * and adds a no-broadcast-hash-join hint so the join will not be planned as a
 * broadcast join again.
 */
object FallbackBroadcastStage extends FallbackStageHandler with Logging {
override def apply(logicalPlan: LogicalPlan, stage: LogicalQueryStage): (LogicalPlan, Int) = {
var effectiveCount = 0
val originLogicalQueryStage = stage match {
case LogicalQueryStage(l@LogicalQueryStage(_, _), _) => l
case _ => stage
}
val newLogicalPlan = logicalPlan.transformDown {
// A join whose left child is (or wraps) the stage that failed to broadcast.
case j @ ExtractEquiJoinKeys(_, _, _, _, left, _, hint) if left.sameResult(stage) ||
left.sameResult(originLogicalQueryStage) =>
var newHint = hint
if (!hint.leftHint.exists(_.strategy.isDefined)) {
newHint = newHint.copy(leftHint =
Some(hint.leftHint.getOrElse(HintInfo()).copy(strategy = Some(NO_BROADCAST_HASH))))
}

if (newHint.ne(hint)) {
effectiveCount = effectiveCount + 1

logInfo(s"FallbackBroadcastStage takes effect for " +
s"logicalPlan: $logicalPlan " +
s"stage: $stage")
j.copy(hint = newHint, left = if (left.sameResult(stage)) {
stage.logicalPlan
} else originLogicalQueryStage.logicalPlan)
} else {
j
}

// A join whose right child is (or wraps) the stage that failed to broadcast.
case j @ ExtractEquiJoinKeys(_, _, _, _, _, right, hint) if right.sameResult(stage) ||
right.sameResult(originLogicalQueryStage) =>
var newHint = hint
if (!hint.rightHint.exists(_.strategy.isDefined)) {
newHint = newHint.copy(rightHint =
Some(hint.rightHint.getOrElse(HintInfo()).copy(strategy = Some(NO_BROADCAST_HASH))))
}

if (newHint.ne(hint)) {
effectiveCount = effectiveCount + 1

logInfo(s"FallbackBroadcastStage takes effect for " +
s"logicalPlan: $logicalPlan " +
s"stage: $stage")
j.copy(hint = newHint, right = if (right.sameResult(stage)) {
stage.logicalPlan
} else originLogicalQueryStage.logicalPlan)
} else {
j
}
}
(newLogicalPlan, effectiveCount)
}
}
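
Downstream, the injected NO_BROADCAST_HASH hint is what steers join selection away from BHJ on the next planning pass. A simplified sketch of that check (the helper name is illustrative; Spark's actual JoinSelection logic is more involved):

import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, NO_BROADCAST_HASH}

// Once a side carries NO_BROADCAST_HASH, it stops being a broadcast
// candidate, so the planner falls through to sort-merge join.
def mayBroadcastSide(sideHint: Option[HintInfo]): Boolean =
  !sideHint.flatMap(_.strategy).contains(NO_BROADCAST_HASH)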
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning}
import org.apache.spark.sql.execution.{BroadcastRelationManager, FileSourceScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{BroadcastRelationManager, CannotBroadcastTableException, FileSourceScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.BroadcastRelationManager.BroadcastRelationInfo
import org.apache.spark.sql.execution.joins.{HashedRelation, TokenTreeBroadcastMode, TreeRelation}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -140,13 +140,13 @@ case class BroadcastExchangeExec(
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
val (numRows, totalBytes, input) = child.executeCollectIterator()
if (numRows >= broadcastMaxRowNum) {
throw new SparkException(
throw CannotBroadcastTableException(
s"Cannot broadcast the table with more than" +
s" $broadcastMaxRowNum rows: $numRows rows. " +
s"Please analyze these tables through: $analyzeTblMsg")
}
if (totalBytes >= broadcastMaxDataSize) {
throw new SparkException(
throw CannotBroadcastTableException(
s"Cannot broadcast the table that is larger than " +
s"${broadcastMaxDataSize >> 30}GB: ${totalBytes >> 30}GB")
}
@@ -193,7 +193,7 @@

longMetric("dataSize") += dataSize
if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
throw new SparkException(
throw CannotBroadcastTableException(
s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
}

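
For intuition on the thresholds above, a small worked example (values hypothetical; `>> 30` converts bytes to whole gigabytes, since 1 GB = 1L << 30 bytes). As with the earlier sketch, real callers would live inside an org.apache.spark package because the exception is private[spark]:

val dataSize: Long = 9L << 30            // 9 GB collected on the driver
val maxBroadcastTableBytes = 8L << 30    // the hard 8 GB broadcast cap
if (dataSize >= maxBroadcastTableBytes) {
  // With this patch the failure is typed, so AQE can catch it and fall back
  // to a sort-merge join instead of failing the whole query.
  throw CannotBroadcastTableException(
    s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
}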
BroadcastRangeExchangeExec.scala
@@ -30,10 +30,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InterpretedProjection, Projection, UnsafeRow}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.planning.RangeJoin
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.execution.{BroadcastRelationManager, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{BroadcastRelationManager, CannotBroadcastTableException, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.BroadcastRelationManager.BroadcastRelationInfo
import org.apache.spark.sql.execution.joins.{RangeIndex, RangeInfo}
import org.apache.spark.sql.execution.joins.RangeIndex.RangeEvent
@@ -106,12 +105,12 @@ case class BroadcastRangeExchangeExec(
val (buildRows, totalBytes, input) = child.executeCollectIterator()

if (buildRows >= broadcastMaxRowNum) {
throw new SparkException(
throw CannotBroadcastTableException(
s"Cannot broadcast the table with more than" +
s" $broadcastMaxRowNum rows: $buildRows rows")
}
if (totalBytes >= broadcastMaxDataSize) {
throw new SparkException(
throw CannotBroadcastTableException(
s"Cannot broadcast the table that is larger than " +
s"${broadcastMaxDataSize >> 30}GB: ${totalBytes >> 30}GB")
}
DynamicDataPruningSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}

import org.scalatest.GivenWhenThen

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, DynamicPruningExpression, Expression, Literal}
import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec, InSubqueryExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
@@ -31,7 +32,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{Decimal, StringType}


abstract class DynamicDataPruningSuiteBase
extends QueryTest
with SharedSparkSession
@@ -252,6 +252,33 @@ abstract class DynamicDataPruningSuiteBase
}
}

test("simple aggregate triggers shuffle pruning failure and " +
"should not handled by stage fallback") {
withSQLConf(SQLConf.DYNAMIC_DATA_PRUNING_ENABLED.key -> "true",
SQLConf.DYNAMIC_DATA_PRUNING_SIDE_THRESHOLD.key -> "10K",
SQLConf.BROADCAST_MAX_ROW_NUM.key -> "1",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_APPLY_STAGE_FALLBACK_PLAN_ENABLED.key -> "true") {
val df = sql(
"""
|SELECT t11.a,
| t11.cnt
|FROM (SELECT a,
| Count(b) AS cnt
| FROM t1
| GROUP BY a) t11
| JOIN t3
| ON t11.a = t3.a AND t3.b < 2
|""".stripMargin)

val e = intercept[SparkException] {
checkDataPruningPredicate(df, false, true)
checkAnswer(df, Row(0, 1) :: Row(1, 1) :: Nil)
}
assert(e.getCause.getMessage.contains("Cannot broadcast the table"))
}
}

test("dynamic filter push down to datasource") {
withSQLConf(
SQLConf.DYNAMIC_DATA_PRUNING_SIDE_THRESHOLD.key -> "10K",
(One of the nine changed files did not load in this view.)