From 28c107532448e9681a92edf12825b009bc0404ab Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Fri, 14 Nov 2025 09:37:16 -0800 Subject: [PATCH 1/4] fixes --- .../exchange/EnsureRequirements.scala | 33 ++++ .../KeyGroupedPartitioningSuite.scala | 145 ++++++++++++++++++ .../exchange/EnsureRequirementsSuite.scala | 71 +++++---- 3 files changed, 222 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b97d765afcf7..13a0f7b8a812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -140,6 +140,13 @@ case class EnsureRequirements( // Choose all the specs that can be used to shuffle other children val candidateSpecs = specs .filter(_._2.canCreatePartitioning) + .filter { + // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into + // the scan (for join key positions). If these parameters can't be pushed down, this + // spec can't be used to shuffle other children. + case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx)) + case _ => true + } .filter(p => !shouldConsiderMinParallelism || children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions) val bestSpecOpt = if (candidateSpecs.isEmpty) { @@ -402,6 +409,26 @@ case class EnsureRequirements( } } + /** + * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be + * eligible for SPJ parameter pushdown, all leaf nodes must be KeyGroupedPartitioning-aware + * scans. + * + * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the + * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be + * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily + * re-grouped or padded with empty partitions according to the partition values on the other side + * of the join). + */ + private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = { + plan.collectLeaves().forall { + case _: KeyGroupedPartitionedScan[_] => true + case f: FileSourceScanExec => + f.relation.location.isInstanceOf[KeyGroupedPartitionedScan[_]] + case _ => false + } + } + /** * Checks whether two children, `left` and `right`, of a join operator have compatible * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join. @@ -413,6 +440,12 @@ case class EnsureRequirements( left: SparkPlan, right: SparkPlan, requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = { + // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an SPJ.
+ if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) { + return None + } + parent match { case smj: SortMergeJoinExec => checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index d4ae597811dd..d34664403429 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.functions.{col, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ @@ -2626,4 +2627,148 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(scans.forall(_.inputRDD.partitions.length == 2)) } } + + test("SPARK-53322: checkpointed scans avoid shuffles for aggregates") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val scanDF = spark.read.table(s"testcat.ns.$items").checkpoint() + val df = scanDF.groupBy("id").agg(max("price").as("res")).select("res") + checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0))) + + val shuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, + "should not contain shuffle when grouping by partition values") + } + } + + test("SPARK-53322: checkpointed scans aren't used for SPJ") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") + + val purchase_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchase_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(3, 25.5, cast('2020-01-03' as timestamp)), " + + s"(4, 20.0, cast('2020-01-04' as timestamp))") + + for { + pushdownValues <- Seq(true, false) + checkpointBothScans <- Seq(true, false) + } { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) { + val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") + val scanDF2 = if (checkpointBothScans) { + spark.read.table(s"testcat.ns.$purchases").checkpoint().as("p") + } else { + spark.read.table(s"testcat.ns.$purchases").as("p") + } + + val df = scanDF1 .join(scanDF2, col("id") ===
col("item_id")) + .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") + .orderBy("id", "purchase_price", "sale_price") + checkAnswer( + df, + Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) + ) + // 1 shuffle for SORT and 2 shuffles for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + } + } + } + } + + test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(3, 25.5, cast('2020-01-03' as timestamp)), " + + s"(4, 20.0, cast('2020-01-04' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) { + val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") + val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p") + + val df = scanDF1 + .join(scanDF2, col("id") === col("item_id")) + .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") + .orderBy("id", "purchase_price", "sale_price") + checkAnswer( + df, + Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) + ) + // 1 shuffle for SORT and 2 shuffles for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + } + } + } + } + + test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") + + createTable(purchases, purchasesColumns, Array(identity("item_id"))) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(3, 25.5, cast('2020-01-03' as timestamp)), " + + s"(4, 20.0, cast('2020-01-04' as timestamp))") + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { + val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") + val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p") + + val df = scanDF1 + .join(scanDF2, col("id") === col("item_id")) + .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") + .orderBy("id", "purchase_price", "sale_price") + checkAnswer( + df, + Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) + ) + + // One shuffle for the sort and one shuffle for one side of the JOIN are expected. 
+ assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index b94ca4673641..d6bf84ac7cf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec @@ -45,6 +46,17 @@ class EnsureRequirementsSuite extends SharedSparkSession { private val EnsureRequirements = new EnsureRequirements() + /** Helper to add dummy BatchScanExec child to a dummy plan (to ensure SPJ can kick in). */ + private implicit class DummySparkPlanExt(dummyPlan: DummySparkPlan) { + def withDummyBatchScanChild: DummySparkPlan = { + dummyPlan.copy( + children = Seq(BatchScanExec(Seq.empty, null, Seq.empty, table = null)), + requiredChildDistribution = Seq(UnspecifiedDistribution), + requiredChildOrdering = Seq(Seq.empty) + ) + } + } + test("reorder should handle PartitioningCollection") { val plan1 = DummySparkPlan( outputPartitioning = PartitioningCollection(Seq( @@ -95,11 +107,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { val plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(Seq( years(exprA), bucket(4, exprB), days(exprC)), 4) - ) + ).withDummyBatchScanChild val plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(Seq( years(exprB), bucket(4, exprA), days(exprD)), 4) - ) + ).withDummyBatchScanChild val smjExec = SortMergeJoinExec( exprB :: exprC :: exprA :: Nil, exprA :: exprD :: exprB :: Nil, Inner, None, plan1, plan2 @@ -117,7 +129,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { val plan3 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(Seq( bucket(4, exprD), days(exprA), years(exprC)), 4) - ) + ).withDummyBatchScanChild val smjExec2 = SortMergeJoinExec( exprB :: exprD :: exprC :: Nil, exprA :: exprC :: exprD :: Nil, Inner, None, plan1, plan3 @@ -780,8 +792,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // simplest case: identity transforms var plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(exprA :: exprB :: Nil, 5)) + .withDummyBatchScanChild var plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(exprA :: exprC :: Nil, 5)) + .withDummyBatchScanChild var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -797,11 +811,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = 
SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -817,13 +831,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = PartitioningCollection(Seq( KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4), KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4)) ) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -846,10 +860,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // bucket + years transforms from both sides plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -866,11 +880,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -887,11 +901,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { var plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprB) :: years(exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild var plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprC) :: years(exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild // simple case var smjExec = SortMergeJoinExec( @@ -909,11 +923,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: years(exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: years(exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -929,9 +943,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, 5)) + .withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, 5)) + .withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -947,11 +963,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( 
bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprB) :: bucket(4, exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: Nil, exprA :: exprB :: exprC :: Nil, Inner, None, plan1, plan2) @@ -968,11 +984,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(8, exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -987,10 +1003,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partition key positions match but with different transforms plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -1007,11 +1023,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) - ) + ).withDummyBatchScanChild plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -1033,11 +1049,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { var plan1 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues.length, leftPartValues) - ) + ).withDummyBatchScanChild var plan2 = DummySparkPlan( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues.length, rightPartValues) - ) + ).withDummyBatchScanChild // simple case var smjExec = SortMergeJoinExec( @@ -1059,7 +1075,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues.length, leftPartValues)) ) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) @@ -1093,7 +1109,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { rightPartValues.length, rightPartValues))) ) ) - ) + ).withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) @@ -1121,6 +1137,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning( identity(a1) :: Nil, 4, partitionValue)) + .withDummyBatchScanChild val plan2 = 
DummySparkPlan(outputPartitioning = SinglePartition) val smjExec = ShuffledHashJoinExec( From 61920d8f394ff4edd414f5f5c3bc43acad664ee8 Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Mon, 17 Nov 2025 09:10:03 -0800 Subject: [PATCH 2/4] fix --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 13a0f7b8a812..0790fe357e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -423,8 +423,6 @@ case class EnsureRequirements( private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = { plan.collectLeaves().forall { case _: KeyGroupedPartitionedScan[_] => true - case f: FileSourceScanExec => - f.relation.location.isInstanceOf[KeyGroupedPartitionedScan[_]] case _ => false } } From 260d00899ec8ca33188dc76e70035fcbba0bd17a Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Mon, 17 Nov 2025 15:28:47 -0800 Subject: [PATCH 3/4] fix --- .../exchange/EnsureRequirements.scala | 2 +- .../exchange/EnsureRequirementsSuite.scala | 144 +++++++++--------- 2 files changed, 69 insertions(+), 77 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 0790fe357e52..088ece6554c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -144,7 +144,7 @@ case class EnsureRequirements( // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into // the scan (for join key positions). If these parameters can't be pushed down, this // spec can't be used to shuffle other children. - case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx)) + case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx)) case _ => true } .filter(p => !shouldConsiderMinParallelism || diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index d6bf84ac7cf3..1cc0d795d74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -46,17 +46,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { private val EnsureRequirements = new EnsureRequirements() - /** Helper to add dummy BatchScanExec child to a dummy plan (to ensure SPJ can kick in). 
*/ - private implicit class DummySparkPlanExt(dummyPlan: DummySparkPlan) { - def withDummyBatchScanChild: DummySparkPlan = { - dummyPlan.copy( - children = Seq(BatchScanExec(Seq.empty, null, Seq.empty, table = null)), - requiredChildDistribution = Seq(UnspecifiedDistribution), - requiredChildOrdering = Seq(Seq.empty) - ) - } - } - test("reorder should handle PartitioningCollection") { val plan1 = DummySparkPlan( outputPartitioning = PartitioningCollection(Seq( @@ -104,14 +93,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("reorder should handle KeyGroupedPartitioning") { // partitioning on the left - val plan1 = DummySparkPlan( + val plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(Seq( years(exprA), bucket(4, exprB), days(exprC)), 4) - ).withDummyBatchScanChild - val plan2 = DummySparkPlan( + ) + val plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(Seq( years(exprB), bucket(4, exprA), days(exprD)), 4) - ).withDummyBatchScanChild + ) val smjExec = SortMergeJoinExec( exprB :: exprC :: exprA :: Nil, exprA :: exprD :: exprB :: Nil, Inner, None, plan1, plan2 @@ -126,10 +115,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // partitioning on the right - val plan3 = DummySparkPlan( + val plan3 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(Seq( bucket(4, exprD), days(exprA), years(exprC)), 4) - ).withDummyBatchScanChild + ) val smjExec2 = SortMergeJoinExec( exprB :: exprD :: exprC :: Nil, exprA :: exprC :: exprD :: Nil, Inner, None, plan1, plan3 @@ -790,12 +779,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("Check with KeyGroupedPartitioning") { // simplest case: identity transforms - var plan1 = DummySparkPlan( - outputPartitioning = KeyGroupedPartitioning(exprA :: exprB :: Nil, 5)) - .withDummyBatchScanChild - var plan2 = DummySparkPlan( + var plan1 = new DummySparkPlanWithBatchScanChild( + KeyGroupedPartitioning(exprA :: exprB :: Nil, 5)) + var plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(exprA :: exprC :: Nil, 5)) - .withDummyBatchScanChild var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -808,14 +795,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // matching bucket transforms from both sides - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -828,16 +815,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // partition collections - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = PartitioningCollection(Seq( KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4), 
KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4)) ) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -858,12 +845,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // bucket + years transforms from both sides - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -877,14 +864,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { // by default spark.sql.requireAllClusterKeysForCoPartition is true, so when there isn't // exact match on all partition keys, Spark will fallback to shuffle. - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -898,14 +885,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } test(s"KeyGroupedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { - var plan1 = DummySparkPlan( + var plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprB) :: years(exprC) :: Nil, 4) - ).withDummyBatchScanChild - var plan2 = DummySparkPlan( + ) + var plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprC) :: years(exprB) :: Nil, 4) - ).withDummyBatchScanChild + ) // simple case var smjExec = SortMergeJoinExec( @@ -920,14 +907,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // should also work with distributions with duplicated keys - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: years(exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: years(exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -940,14 +927,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // both partitioning and distribution have duplicated keys - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, 5)) - .withDummyBatchScanChild - plan2 = DummySparkPlan( + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: bucket(4, exprC) :: 
days(exprA) :: Nil, 5)) - .withDummyBatchScanChild smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -960,14 +945,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // invalid case: partitioning key positions don't match - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprB) :: bucket(4, exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: Nil, exprA :: exprB :: exprC :: Nil, Inner, None, plan1, plan2) @@ -981,14 +966,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // invalid case: different number of buckets (we don't support coalescing/repartitioning yet) - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( bucket(4, exprA) :: bucket(8, exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -1001,12 +986,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // invalid case: partition key positions match but with different transforms - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -1020,14 +1005,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: multiple references in transform - plan1 = DummySparkPlan( + plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) - ).withDummyBatchScanChild - plan2 = DummySparkPlan( + ) + plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning( years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -1046,14 +1031,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { val rightPartValues = Seq(Array[Any](1, 1), Array[Any](2, 2), Array[Any](3, 3)) .map(new GenericInternalRow(_)) - var plan1 = DummySparkPlan( + var plan1 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues.length, leftPartValues) - ).withDummyBatchScanChild - var plan2 = DummySparkPlan( + ) + var 
plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues.length, rightPartValues) - ).withDummyBatchScanChild + ) // simple case var smjExec = SortMergeJoinExec( @@ -1068,14 +1053,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // With partition collections - plan1 = DummySparkPlan(outputPartitioning = + plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = PartitioningCollection( Seq(KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues.length, leftPartValues), KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues.length, leftPartValues)) ) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) @@ -1092,7 +1077,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { } // Nested partition collections - plan2 = DummySparkPlan(outputPartitioning = + plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning = PartitioningCollection( Seq( PartitioningCollection( @@ -1109,7 +1094,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { rightPartValues.length, rightPartValues))) ) ) - ).withDummyBatchScanChild + ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) @@ -1135,9 +1120,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { val a1 = AttributeReference("a1", IntegerType)() val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) - val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning( + val plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = KeyGroupedPartitioning( identity(a1) :: Nil, 4, partitionValue)) - .withDummyBatchScanChild val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) val smjExec = ShuffledHashJoinExec( @@ -1406,4 +1390,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { def days(expr: Expression): TransformExpression = { TransformExpression(DaysFunction, Seq(expr)) } + + private class DummySparkPlanWithBatchScanChild(outputPartitioning: Partitioning) + extends DummySparkPlan( + children = Seq(BatchScanExec(Seq.empty, null, Seq.empty, table = null)), + outputPartitioning = outputPartitioning, + requiredChildDistribution = Seq(UnspecifiedDistribution), + requiredChildOrdering = Seq(Seq.empty) + ) } From 58a638e637331b369d7b1454729c0977b2f68906 Mon Sep 17 00:00:00 2001 From: Chirag Singh Date: Tue, 18 Nov 2025 07:52:06 -0800 Subject: [PATCH 4/4] fix --- .../connector/KeyGroupedPartitioningSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index d34664403429..7798397d96b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -2631,8 +2631,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { test("SPARK-53322: checkpointed scans avoid shuffles for aggregates") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) - val items_partitions = Array(identity("id")) - createTable(items, itemsColumns, items_partitions) + val 
itemsPartitions = Array(identity("id")) + createTable(items, itemsColumns, itemsPartitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " + @@ -2652,8 +2652,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { test("SPARK-53322: checkpointed scans aren't used for SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) - val items_partitions = Array(identity("id")) - createTable(items, itemsColumns, items_partitions) + val itemsPartitions = Array(identity("id")) + createTable(items, itemsColumns, itemsPartitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + @@ -2698,8 +2698,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) - val items_partitions = Array(identity("id")) - createTable(items, itemsColumns, items_partitions) + val itemsPartitions = Array(identity("id")) + createTable(items, itemsColumns, itemsPartitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + @@ -2737,8 +2737,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) - val items_partitions = Array(identity("id")) - createTable(items, itemsColumns, items_partitions) + val itemsPartitions = Array(identity("id")) + createTable(items, itemsColumns, itemsPartitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +