diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index b97d765afcf7..088ece6554c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -140,6 +140,13 @@ case class EnsureRequirements(
       // Choose all the specs that can be used to shuffle other children
       val candidateSpecs = specs
         .filter(_._2.canCreatePartitioning)
+        .filter {
+          // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into
+          // the scan (for join key positions). If these parameters can't be pushed down, this
+          // spec can't be used to shuffle other children.
+          case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx))
+          case _ => true
+        }
         .filter(p => !shouldConsiderMinParallelism ||
           children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions)
       val bestSpecOpt = if (candidateSpecs.isEmpty) {
@@ -402,6 +409,24 @@ case class EnsureRequirements(
     }
   }
 
+  /**
+   * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be
+   * eligible for SPJ parameter pushdown, every leaf node must be a KeyGroupedPartitioning-aware
+   * scan.
+   *
+   * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the
+   * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be
+   * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily
+   * re-grouped or padded with empty partitions according to the partition values on the other side
+   * of the join).
+   */
+  private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = {
+    plan.collectLeaves().forall {
+      case _: KeyGroupedPartitionedScan[_] => true
+      case _ => false
+    }
+  }
+
   /**
    * Checks whether two children, `left` and `right`, of a join operator have compatible
    * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
@@ -413,6 +438,12 @@ case class EnsureRequirements(
       left: SparkPlan,
       right: SparkPlan,
       requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+    // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an
+    // SPJ.
+    if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) {
+      return None
+    }
+
     parent match {
       case smj: SortMergeJoinExec =>
         checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index c73e8e16fbbb..aba866e96b5c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.functions.{col, max}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf._
 import org.apache.spark.sql.types._
@@ -2626,4 +2627,148 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       assert(scans.forall(_.inputRDD.partitions.length == 2))
     }
   }
+
+  test("SPARK-53322: checkpointed scans avoid shuffles for aggregates") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+      val scanDF = spark.read.table(s"testcat.ns.$items").checkpoint()
+      val df = scanDF.groupBy("id").agg(max("price").as("res")).select("res")
+      checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
+
+      val shuffles = collectAllShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.isEmpty,
+        "should not contain shuffle when grouping by partition values")
+    }
+  }
+
+  test("SPARK-53322: checkpointed scans aren't used for SPJ") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val itemsPartitions = Array(identity("id"))
+      createTable(items, itemsColumns, itemsPartitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")
+
+      val purchase_partitions = Array(identity("item_id"))
+      createTable(purchases, purchasesColumns, purchase_partitions)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
+        s"(4, 20.0, cast('2020-01-04' as timestamp))")
+
+      for {
+        pushdownValues <- Seq(true, false)
+        checkpointBothScans <- Seq(true, false)
+      } {
+        withSQLConf(
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
+          val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
+          val scanDF2 = if (checkpointBothScans) {
+            spark.read.table(s"testcat.ns.$purchases").checkpoint().as("p")
+          } else {
+            spark.read.table(s"testcat.ns.$purchases").as("p")
+          }
+
+          val df = scanDF1
+            .join(scanDF2, col("id") === col("item_id"))
.selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") + .orderBy("id", "purchase_price", "sale_price") + checkAnswer( + df, + Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) + ) + // 1 shuffle for SORT and 2 shuffles for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + } + } + } + } + + test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val itemsPartitions = Array(identity("id")) + createTable(items, itemsColumns, itemsPartitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(3, 25.5, cast('2020-01-03' as timestamp)), " + + s"(4, 20.0, cast('2020-01-04' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) { + val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") + val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p") + + val df = scanDF1 + .join(scanDF2, col("id") === col("item_id")) + .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") + .orderBy("id", "purchase_price", "sale_price") + checkAnswer( + df, + Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) + ) + // 1 shuffle for SORT and 2 shuffles for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + } + } + } + } + + test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val itemsPartitions = Array(identity("id")) + createTable(items, itemsColumns, itemsPartitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") + + createTable(purchases, purchasesColumns, Array(identity("item_id"))) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(3, 25.5, cast('2020-01-03' as timestamp)), " + + s"(4, 20.0, cast('2020-01-04' as timestamp))") + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { + val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") + val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p") + + val df = scanDF1 + .join(scanDF2, col("id") === col("item_id")) + .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") + .orderBy("id", "purchase_price", "sale_price") + checkAnswer( + df, + Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) + ) + + // One shuffle for the sort and one shuffle for one side of the JOIN are expected. 
+        assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2)
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index b94ca4673641..1cc0d795d74f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
 import org.apache.spark.sql.execution.window.WindowExec
@@ -92,11 +93,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
 
   test("reorder should handle KeyGroupedPartitioning") {
     // partitioning on the left
-    val plan1 = DummySparkPlan(
+    val plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(Seq(
         years(exprA), bucket(4, exprB), days(exprC)), 4)
     )
-    val plan2 = DummySparkPlan(
+    val plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(Seq(
         years(exprB), bucket(4, exprA), days(exprD)), 4)
     )
@@ -114,7 +115,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // partitioning on the right
-    val plan3 = DummySparkPlan(
+    val plan3 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(Seq(
         bucket(4, exprD), days(exprA), years(exprC)), 4)
     )
@@ -778,9 +779,9 @@ class EnsureRequirementsSuite extends SharedSparkSession {
 
   test("Check with KeyGroupedPartitioning") {
     // simplest case: identity transforms
-    var plan1 = DummySparkPlan(
-      outputPartitioning = KeyGroupedPartitioning(exprA :: exprB :: Nil, 5))
-    var plan2 = DummySparkPlan(
+    var plan1 = new DummySparkPlanWithBatchScanChild(
+      KeyGroupedPartitioning(exprA :: exprB :: Nil, 5))
+    var plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(exprA :: exprC :: Nil, 5))
     var smjExec = SortMergeJoinExec(
       exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2)
@@ -794,11 +795,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // matching bucket transforms from both sides
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4)
     )
@@ -814,11 +815,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // partition collections
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = PartitioningCollection(Seq(
         KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4),
         KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4))
@@ -844,10 +845,10 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // bucket + years transforms from both sides
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, 4)
     )
     smjExec = SortMergeJoinExec(
@@ -863,11 +864,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
 
     // by default spark.sql.requireAllClusterKeysForCoPartition is true, so when there isn't
    // exact match on all partition keys, Spark will fallback to shuffle.
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(4, exprC) :: Nil, 4)
     )
@@ -884,11 +885,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
   }
 
   test(s"KeyGroupedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") {
-    var plan1 = DummySparkPlan(
+    var plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprB) :: years(exprC) :: Nil, 4)
     )
-    var plan2 = DummySparkPlan(
+    var plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprC) :: years(exprB) :: Nil, 4)
     )
@@ -906,11 +907,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // should also work with distributions with duplicated keys
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: years(exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: years(exprC) :: Nil, 4)
     )
@@ -926,10 +927,10 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // both partitioning and distribution have duplicated keys
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, 5))
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, 5))
     smjExec = SortMergeJoinExec(
@@ -944,11 +945,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // invalid case: partitioning key positions don't match
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprB) :: bucket(4, exprC) :: Nil, 4)
     )
@@ -965,11 +966,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // invalid case: different number of buckets (we don't support coalescing/repartitioning yet)
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         bucket(4, exprA) :: bucket(8, exprC) :: Nil, 4)
     )
@@ -985,10 +986,10 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // invalid case: partition key positions match but with different transforms
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, 4)
     )
     smjExec = SortMergeJoinExec(
@@ -1004,11 +1005,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
 
     // invalid case: multiple references in transform
-    plan1 = DummySparkPlan(
+    plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4)
     )
-    plan2 = DummySparkPlan(
+    plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(
         years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4)
     )
@@ -1030,11 +1031,11 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     val rightPartValues = Seq(Array[Any](1, 1), Array[Any](2, 2), Array[Any](3, 3))
       .map(new GenericInternalRow(_))
 
-    var plan1 = DummySparkPlan(
+    var plan1 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil,
         leftPartValues.length, leftPartValues)
     )
-    var plan2 = DummySparkPlan(
+    var plan2 = new DummySparkPlanWithBatchScanChild(
       outputPartitioning = KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil,
         rightPartValues.length, rightPartValues)
     )
@@ -1052,7 +1053,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // With partition collections
-    plan1 = DummySparkPlan(outputPartitioning =
+    plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
       PartitioningCollection(
         Seq(KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil,
           leftPartValues.length, leftPartValues),
@@ -1076,7 +1077,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     }
 
     // Nested partition collections
-    plan2 = DummySparkPlan(outputPartitioning =
+    plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
       PartitioningCollection(
         Seq(
           PartitioningCollection(
@@ -1119,7 +1120,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     val a1 = AttributeReference("a1", IntegerType)()
     val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
 
-    val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning(
+    val plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = KeyGroupedPartitioning(
       identity(a1) :: Nil, 4, partitionValue))
     val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)
 
@@ -1389,4 +1390,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
   def days(expr: Expression): TransformExpression = {
     TransformExpression(DaysFunction, Seq(expr))
   }
+
+  private class DummySparkPlanWithBatchScanChild(outputPartitioning: Partitioning)
+    extends DummySparkPlan(
+      children = Seq(BatchScanExec(Seq.empty, null, Seq.empty, table = null)),
+      outputPartitioning = outputPartitioning,
+      requiredChildDistribution = Seq(UnspecifiedDistribution),
+      requiredChildOrdering = Seq(Seq.empty)
+    )
 }