EnsureRequirements.scala
@@ -140,6 +140,13 @@ case class EnsureRequirements(
      // Choose all the specs that can be used to shuffle other children
      val candidateSpecs = specs
        .filter(_._2.canCreatePartitioning)
        .filter {
Member:
just for reference, were both checks needed? i.e. this and the other check in 'checkKeyGroupCompatible'?

Contributor Author:
yeah, both checks are needed. We need the check in checkKeyGroupCompatible for the case where both children are key-grouped partitionings, while this check handles the case where only one child is a key-grouped partitioning and is shuffling a non-KGP plan.

Member:
Trying to understand this too. In checkKeyGroupCompatible we already make sure that both children are of KeyGroupedPartitioning. This new check additionally requires that the leaf nodes of both are all KeyGroupedPartitionedScan?

Contributor Author:
checkKeyGroupCompatible applies to the case where we have 2 KeyGroupedPartitioned scans that are being joined against each other. For example, something like:

SortMergeJoinExec ...
  +- BatchScanExec tbl1 ... -> reporting KeyGroupedPartitioning
  +- BatchScanExec tbl2 ... -> reporting KeyGroupedPartitioning

If only one child is KeyGroupedPartitioned, we can (in general) still avoid the shuffle for that child:

SortMergeJoinExec ...
  +- BatchScanExec tbl1 ... -> reporting KeyGroupedPartitioning
  +- ShuffleExchangeExec KeyGroupedPartitioning
    +- BatchScanExec tbl2 ... -> reporting UnknownPartitioning

However, if the child reporting the KeyGroupedPartitioning is not a BatchScanExec, then we can't safely push down the JOIN keys, making it unsafe to do this. This may arise if we call .checkpoint() on a BatchScanExec:

SortMergeJoinExec ...
  +- RDDScanExec ... -> reporting KeyGroupedPartitioning (coming from ckpt of tbl1 scan)
  +- ShuffleExchangeExec KeyGroupedPartitioning
    +- BatchScanExec tbl2 ... -> reporting UnknownPartitioning

This extra check is for this second case, where we want to make sure that we're not using a KeyGroupedPartitioning to shuffle another child of a JOIN without being able to push down JOIN keys. The test "SPARK-53322: checkpointed scans can't shuffle other children on SPJ" is for this case, and will fail without this change.
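
Roughly, a minimal sketch of this second case (it mirrors the new test below; tbl1/tbl2, the join column, and the spark session are illustrative, not names from the suite):

  // tbl1 is partitioned by id; tbl2 is unpartitioned
  // after checkpoint(), the left side becomes an RDDScanExec that still reports KeyGroupedPartitioning
  val left = spark.read.table("testcat.ns.tbl1").checkpoint()
  // the right side reports UnknownPartitioning
  val right = spark.read.table("testcat.ns.tbl2")
  val joined = left.join(right, left("id") === right("id"))
  // Without this check, planning could pick the checkpointed side's KeyGroupedShuffleSpec to shuffle
  // the right side, even though the join key positions can't be pushed into the left (checkpointed) scan.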

Member:
I see, thanks for the explanation!

          // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into
          // the scan (for join key positions). If these parameters can't be pushed down, this
          // spec can't be used to shuffle other children.
          case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx))
          case _ => true
        }
        .filter(p => !shouldConsiderMinParallelism ||
          children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions)
      val bestSpecOpt = if (candidateSpecs.isEmpty) {
@@ -402,6 +409,24 @@ case class EnsureRequirements(
}
}

  /**
   * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be
   * eligible for SPJ parameter pushdown, all leaf nodes must be a KeyGroupedPartitioning-aware
   * scan.
   *
   * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the
   * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be
   * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily
   * re-grouped or padded with empty partitions according to the partition values on the other side
   * of the join).
   */
  private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = {
    plan.collectLeaves().forall {
      case _: KeyGroupedPartitionedScan[_] => true
      case _ => false
    }
  }
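
  // Illustration (reviewer sketch, not part of the patch): a DSv2 BatchScanExec is a
  // KeyGroupedPartitionedScan, so a plan whose leaves are all such scans passes this check, while a
  // plan whose leaf is an RDDScanExec produced by Dataset.checkpoint() does not, even though that
  // leaf may still report a KeyGroupedPartitioning.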

  /**
   * Checks whether two children, `left` and `right`, of a join operator have compatible
   * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
@@ -413,6 +438,12 @@
      left: SparkPlan,
      right: SparkPlan,
      requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
    // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an
    // SPJ.
    if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) {
      return None
    }

    parent match {
      case smj: SortMergeJoinExec =>
        checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution)
KeyGroupedPartitioningSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions.{col, max}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._
@@ -2626,4 +2627,148 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
      assert(scans.forall(_.inputRDD.partitions.length == 2))
    }
  }

test("SPARK-53322: checkpointed scans avoid shuffles for aggregates") {
withTempDir { dir =>
spark.sparkContext.setCheckpointDir(dir.getPath)
val itemsPartitions = Array(identity("id"))
createTable(items, itemsColumns, itemsPartitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

val scanDF = spark.read.table(s"testcat.ns.$items").checkpoint()
val df = scanDF.groupBy("id").agg(max("price").as("res")).select("res")
checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))

val shuffles = collectAllShuffles(df.queryExecution.executedPlan)
assert(shuffles.isEmpty,
"should not contain shuffle when not grouping by partition values")
}
}

test("SPARK-53322: checkpointed scans aren't used for SPJ") {
withTempDir { dir =>
spark.sparkContext.setCheckpointDir(dir.getPath)
val itemsPartitions = Array(identity("id"))
createTable(items, itemsColumns, itemsPartitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")

val purchase_partitions = Array(identity("item_id"))
createTable(purchases, purchasesColumns, purchase_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
s"(4, 20.0, cast('2020-01-04' as timestamp))")

for {
pushdownValues <- Seq(true, false)
checkpointBothScans <- Seq(true, false)
} {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
val scanDF2 = if (checkpointBothScans) {
spark.read.table(s"testcat.ns.$purchases").checkpoint().as("p")
} else {
spark.read.table(s"testcat.ns.$purchases").as("p")
}

val df = scanDF1
.join(scanDF2, col("id") === col("item_id"))
.selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
.orderBy("id", "purchase_price", "sale_price")
checkAnswer(
df,
Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
)
// 1 shuffle for SORT and 2 shuffles for JOIN are expected.
assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3)
}
}
}
}

test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") {
withTempDir { dir =>
spark.sparkContext.setCheckpointDir(dir.getPath)
val itemsPartitions = Array(identity("id"))
createTable(items, itemsColumns, itemsPartitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
s"(4, 20.0, cast('2020-01-04' as timestamp))")

Seq(true, false).foreach { pushdownValues =>
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString) {
val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p")

val df = scanDF1
.join(scanDF2, col("id") === col("item_id"))
.selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
.orderBy("id", "purchase_price", "sale_price")
checkAnswer(
df,
Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
)
// 1 shuffle for SORT and 2 shuffles for JOIN are expected.
assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3)
}
}
}
}

test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") {
withTempDir { dir =>
spark.sparkContext.setCheckpointDir(dir.getPath)
val itemsPartitions = Array(identity("id"))
createTable(items, itemsColumns, itemsPartitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " +
s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " +
s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))")

createTable(purchases, purchasesColumns, Array(identity("item_id")))
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
s"(1, 40.0, cast('2020-01-01' as timestamp)), " +
s"(3, 25.5, cast('2020-01-03' as timestamp)), " +
s"(4, 20.0, cast('2020-01-04' as timestamp))")

withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i")
val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p")

val df = scanDF1
.join(scanDF2, col("id") === col("item_id"))
.selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price")
.orderBy("id", "purchase_price", "sale_price")
checkAnswer(
df,
Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5))
)

// One shuffle for the sort and one shuffle for one side of the JOIN are expected.
assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2)
}
}
}
}