diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index db4a6b7dcf2eb..1bbc26f3e52ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -75,8 +75,8 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
     val minNumPartitionsByGroup = if (coalesceGroups.length == 1) {
       Seq(math.max(minNumPartitions, 1))
     } else {
-      val sizes =
-        coalesceGroups.map(_.flatMap(_.shuffleStage.mapStats.map(_.bytesByPartitionId.sum)).sum)
+      val sizes = coalesceGroups.map(
+        _.shuffleStages.flatMap(_.shuffleStage.mapStats.map(_.bytesByPartitionId.sum)).sum)
       val totalSize = sizes.sum
       sizes.map { size =>
         val num = if (totalSize > 0) {
@@ -90,8 +90,8 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
 
     val specsMap = mutable.HashMap.empty[Int, Seq[ShufflePartitionSpec]]
     // Coalesce partitions for each coalesce group independently.
-    coalesceGroups.zip(minNumPartitionsByGroup).foreach { case (shuffleStages, minNumPartitions) =>
-      val advisoryTargetSize = advisoryPartitionSize(shuffleStages)
+    coalesceGroups.zip(minNumPartitionsByGroup).foreach { case (coalesceGroup, minNumPartitions) =>
+      val advisoryTargetSize = advisoryPartitionSize(coalesceGroup)
       val minPartitionSize = if (Utils.isTesting) {
         // In the tests, we usually set the target size to a very small value that is even smaller
         // than the default value of the min partition size. Here we also adjust the min partition
@@ -103,14 +103,14 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       }
 
       val newPartitionSpecs = ShufflePartitionsUtil.coalescePartitions(
-        shuffleStages.map(_.shuffleStage.mapStats),
-        shuffleStages.map(_.partitionSpecs),
+        coalesceGroup.shuffleStages.map(_.shuffleStage.mapStats),
+        coalesceGroup.shuffleStages.map(_.partitionSpecs),
         advisoryTargetSize = advisoryTargetSize,
         minNumPartitions = minNumPartitions,
         minPartitionSize = minPartitionSize)
 
       if (newPartitionSpecs.nonEmpty) {
-        shuffleStages.zip(newPartitionSpecs).map { case (stageInfo, partSpecs) =>
+        coalesceGroup.shuffleStages.zip(newPartitionSpecs).map { case (stageInfo, partSpecs) =>
           specsMap.put(stageInfo.shuffleStage.id, partSpecs)
         }
       }
@@ -126,9 +126,12 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
   // data sources may request a particular advisory partition size for the final write stage
   // if it happens, the advisory partition size will be set in ShuffleQueryStageExec
   // only one shuffle stage is expected in such cases
-  private def advisoryPartitionSize(shuffleStages: Seq[ShuffleStageInfo]): Long = {
+  private def advisoryPartitionSize(coalesceGroup: CoalesceGroup): Long = {
+    if (coalesceGroup.hasExplodingJoin) {
+      return conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE)
+    }
     val defaultAdvisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
-    shuffleStages match {
+    coalesceGroup.shuffleStages match {
       case Seq(stage) =>
         stage.shuffleStage.advisoryPartitionSize.getOrElse(defaultAdvisorySize)
       case _ =>
@@ -143,18 +146,18 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
    * 1) all leaf nodes of this child are exchange stages; and
    * 2) all these shuffle stages support coalescing.
   */
-  private def collectCoalesceGroups(plan: SparkPlan): Seq[Seq[ShuffleStageInfo]] = plan match {
+  private def collectCoalesceGroups(
+      plan: SparkPlan,
+      hasExplodingJoin: Boolean = false): Seq[CoalesceGroup] = plan match {
     case r @ AQEShuffleReadExec(q: ShuffleQueryStageExec, _) if isSupported(q.shuffle) =>
-      Seq(collectShuffleStageInfos(r))
-    case unary: UnaryExecNode => collectCoalesceGroups(unary.child)
-    case union: UnionExec => union.children.flatMap(collectCoalesceGroups)
-    case join: CartesianProductExec => join.children.flatMap(collectCoalesceGroups)
-    // Note that, `BroadcastQueryStageExec` is a valid case:
-    // If a join has been optimized from shuffled join to broadcast join, then the one side is
-    // `BroadcastQueryStageExec` and other side is `ShuffleQueryStageExec`. It can coalesce the
-    // shuffle side as we do not expect broadcast exchange has same partition number.
-    case join: BroadcastHashJoinExec => join.children.flatMap(collectCoalesceGroups)
-    case join: BroadcastNestedLoopJoinExec => join.children.flatMap(collectCoalesceGroups)
+      Seq(CoalesceGroup(collectShuffleStageInfos(r), hasExplodingJoin))
+    case unary: UnaryExecNode => collectCoalesceGroups(unary.child, hasExplodingJoin)
+    // If a plan node does not need compatible data partitioning for its children, then each of
+    // its children can be an individual coalesce group, and Spark will apply shuffle partition
+    // coalescing to them independently.
+    case p if !childrenNeedCompatiblePartitioning(p) =>
+      val hasExplodingJoinSoFar = hasExplodingJoin || isExplodingJoin(p)
+      p.children.flatMap(collectCoalesceGroups(_, hasExplodingJoinSoFar))
     // If not all leaf nodes are exchange query stages, it's not safe to reduce the number of
     // shuffle partitions, because we may break the assumption that all children of a spark plan
     // have same number of output partitions.
@@ -163,13 +166,30 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       // ShuffleExchanges introduced by repartition do not support partition number change.
       // We change the number of partitions only if all the ShuffleExchanges support it.
       if (shuffleStages.forall(s => isSupported(s.shuffleStage.shuffle))) {
-        Seq(shuffleStages)
+        // The recursion stops here, so we need to call `p.exists(isExplodingJoin)` to find out
+        // whether there is any exploding join in this sub-plan tree.
+        Seq(CoalesceGroup(shuffleStages, hasExplodingJoin || p.exists(isExplodingJoin)))
       } else {
         Seq.empty
       }
     case _ => Seq.empty
   }
 
+  private def childrenNeedCompatiblePartitioning(p: SparkPlan): Boolean = p match {
+    // TODO: match more plan nodes here.
+    case _: UnionExec => false
+    case _: CartesianProductExec => false
+    case _: BroadcastHashJoinExec => false
+    case _: BroadcastNestedLoopJoinExec => false
+    case _ => true
+  }
+
+  private def isExplodingJoin(p: SparkPlan): Boolean = p match {
+    case _: BroadcastNestedLoopJoinExec => true
+    case _: CartesianProductExec => true
+    case _ => false
+  }
+
   private def collectShuffleStageInfos(plan: SparkPlan): Seq[ShuffleStageInfo] = plan match {
     case ShuffleStageInfo(stage, specs) => Seq(new ShuffleStageInfo(stage, specs))
     case _ => plan.children.flatMap(collectShuffleStageInfos)
@@ -202,3 +222,7 @@ private object ShuffleStageInfo {
     case _ => None
   }
 }
+
+private case class CoalesceGroup(
+    shuffleStages: Seq[ShuffleStageInfo],
+    hasExplodingJoin: Boolean)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index d1d83f96c6702..f528c5584fee9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2861,6 +2861,26 @@ class AdaptiveQueryExecSuite
     val unionDF = aggDf1.union(aggDf2)
     checkAnswer(unionDF.select("id").distinct(), Seq(Row(null)))
   }
+
+  test("SPARK-47247: coalesce differently for BNLJ") {
+    Seq(true, false).foreach { expectCoalesce =>
+      val minPartitionSize = if (expectCoalesce) "64MB" else "1B"
+      withSQLConf(
+        SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "64MB",
+        SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key -> minPartitionSize) {
+        val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+          "SELECT /*+ broadcast(testData2) */ * " +
+            "FROM (SELECT value v, max(key) k from testData group by value) " +
+            "JOIN testData2 ON k + a > 0")
+        val bnlj = findTopLevelBroadcastNestedLoopJoin(adaptivePlan)
+        assert(bnlj.size == 1)
+        val coalescedReads = collect(adaptivePlan) {
+          case read: AQEShuffleReadExec if read.isCoalescedRead => read
+        }
+        assert(coalescedReads.nonEmpty == expectCoalesce)
+      }
+    }
+  }
 }
 
 /**
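
A note for reviewers on the sizing rule this patch introduces: when a coalesce group feeds an exploding join (`BroadcastNestedLoopJoinExec` or `CartesianProductExec`), `advisoryPartitionSize` now returns `COALESCE_PARTITIONS_MIN_PARTITION_SIZE` instead of the usual advisory size, so coalescing merges far fewer map outputs and the row-exploding stage keeps its parallelism. The standalone sketch below illustrates only that decision; `Group` and `pickTargetSize` are hypothetical names, not part of the patch, and the byte values assume the default configs (64MB advisory size, 1MB min partition size).

```scala
// Minimal, self-contained sketch of the target-size decision (not Spark code).
// Group and pickTargetSize are hypothetical; the values assume the defaults of
// ADVISORY_PARTITION_SIZE_IN_BYTES (64MB) and
// COALESCE_PARTITIONS_MIN_PARTITION_SIZE (1MB).
object TargetSizeSketch {
  // Mirrors CoalesceGroup: total shuffle bytes plus the exploding-join flag.
  case class Group(totalBytes: Long, hasExplodingJoin: Boolean)

  val advisorySize: Long = 64L << 20 // 64MB
  val minPartitionSize: Long = 1L << 20 // 1MB

  // The patch's rule: an exploding-join group coalesces toward the much
  // smaller min partition size, preserving reducer parallelism.
  def pickTargetSize(group: Group): Long =
    if (group.hasExplodingJoin) minPartitionSize else advisorySize

  def main(args: Array[String]): Unit = {
    val normal = Group(totalBytes = 256L << 20, hasExplodingJoin = false)
    val exploding = normal.copy(hasExplodingJoin = true)
    Seq(normal, exploding).foreach { g =>
      val partitions = math.max(1L, g.totalBytes / pickTargetSize(g))
      // normal -> ~4 partitions; exploding -> ~256 partitions
      println(s"hasExplodingJoin=${g.hasExplodingJoin} -> ~$partitions partitions")
    }
  }
}
```

This also explains the shape of the test above: with the min partition size set to 64MB a BNLJ group still coalesces (the target stays large), while with 1B the effective target is so small that no `AQEShuffleReadExec` becomes a coalesced read.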