[SPARK-54554][SQL] Enable Dynamic Partition Pruning with CommandResult #53263
Conversation
Force-pushed from e3d2c69 to f7ac51c
CommandResult (from commands like SHOW PARTITIONS) should qualify for Dynamic Partition Pruning optimization in broadcast joins. Previously, CommandResult was not recognized as a selective predicate, causing full table scans even when the partition list was small. This change:
- Treats CommandResult as selective in hasSelectivePredicate()
- Sets CommandResult overhead to 0.0 in calculatePlanOverhead() (data already materialized, no I/O cost)
- Adds test coverage in DynamicPartitionPruningSuite
- Removes an unnecessary catch-all
- Updates the test name to reflect the Jira ticket
- Corrects wrong df values in DynamicPartitionPruningSuite
- Removes the test script

Co-authored-by: Tri Tam Hoang <tritam.hoang@gmail.com>
Force-pushed from d8857de to f8bd1e7
@dongjoon-hyun can you review this PR and provide feedback?
I have concerns about this approach. The motivating use case relies on parsing string outputs from SHOW PARTITIONS, which looks like an anti-pattern. Furthermore, blindly treating all CommandResult nodes as selective may not be safe.
Hi @disliketd,

Using SHOW PARTITIONS is not an anti-pattern; it's the only way to avoid scanning all partitions just to get metadata that's already available. In a production table with 1000 partitions where you want to process only the latest partition:
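A minimal sketch of the two queries being compared (the table and column names here are hypothetical, not from the PR, and a SparkSession `spark` is assumed):

// Hypothetical table `sales`, partitioned by `day_id`; names are illustrative only.
import org.apache.spark.sql.functions.max

// Metadata-driven: the partition list comes from the catalog, no data is read.
val latest = spark.sql("SHOW PARTITIONS sales")
  .selectExpr("cast(split(partition, '=')[1] as int) AS day_id")
  .agg(max("day_id").as("day_id"))

// With this PR, the join below qualifies for Dynamic Partition Pruning,
// so only the latest partition of `sales` is scanned.
val pruned = spark.table("sales").join(latest, "day_id")

// Scan-based alternative: the scalar subquery itself has to scan `sales`
// just to discover the maximum partition value.
val scanned = spark.sql(
  "SELECT * FROM sales WHERE day_id = (SELECT MAX(day_id) FROM sales)")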
The difference becomes more dramatic as the number of partitions grows. Moreover, this is especially noticeable for cloud deployments (S3, GCS, Azure Blob Storage). With this change, the optimized plan applies dynamic pruning:

Project [date_id#10125, product_id#10126, store_id#10128, units_sold#10127]
+- Join Inner, (cast(store_id#10128 as bigint) = cast(max_store_id#92129 as bigint))
:- Filter (isnotnull(store_id#10128) AND dynamicpruning#92130 [cast(store_id#10128 as bigint)])
: : +- Filter isnotnull(max_store_id#92129)
: : +- Aggregate [split(max(partition#92121), =, -1)[1] AS max_store_id#92129]
: : +- CommandResult [partition#92121], Execute ShowPartitionsCommand, [[store_id=0], [store_id=1], [store_id=10], [store_id=11], [store_id=12], [store_id=13], [store_id=14], [store_id=15], [store_id=16], [store_id=17], [store_id=18], [store_id=19], [store_id=2], [store_id=20], [store_id=21], [store_id=22], [store_id=23], [store_id=24], [store_id=25], [store_id=26], [store_id=27], [store_id=28], [store_id=29], [store_id=3], [store_id=30], ... 75 more fields]
: : +- ShowPartitionsCommand `spark_catalog`.`default`.`fact_stats_perf`, [partition#92121]
: +- Relation spark_catalog.default.fact_stats_perf[date_id#10125,product_id#10126,units_sold#10127,store_id#10128] parquet
+- Filter isnotnull(max_store_id#92129)
+- Aggregate [split(max(partition#92121), =, -1)[1] AS max_store_id#92129]
+- CommandResult [partition#92121], Execute ShowPartitionsCommand, [[store_id=0], [store_id=1], [store_id=10], [store_id=11], [store_id=12], [store_id=13], [store_id=14], [store_id=15], [store_id=16], [store_id=17], [store_id=18], [store_id=19], [store_id=2], [store_id=20], [store_id=21], [store_id=22], [store_id=23], [store_id=24], [store_id=25], [store_id=26], [store_id=27], [store_id=28], [store_id=29], [store_id=3], [store_id=30], ... 75 more fields]
+- ShowPartitionsCommand `spark_catalog`.`default`.`fact_stats_perf`, [partition#92121]

Our local benchmark showed a 27% improvement for scenario 1 and 46% for scenario 2 on Intel (Mac); on Arm (Mac), we saw roughly 41% for both scenarios. In cloud environments with network latency, the improvement can be orders of magnitude larger.
The implementation is not blind; there are existing safeguards in the PartitionPruning rule. Your suggestion would still require a table scan:

Project [date_id#20274, product_id#20275, store_id#20277, units_sold#20276]
+- Filter (isnotnull(store_id#20277) AND (store_id#20277 = scalar-subquery#102280 []))
: +- Aggregate [max(store_id#102286) AS max(store_id)#102282]
: +- Project [store_id#102286]
: +- Relation spark_catalog.default.fact_stats_perf[...] parquet
+- Relation spark_catalog.default.fact_stats_perf[...] parquet

Real-world experience with the full table scan: in a financial table we have in AWS, the table scan for finding max(partition) takes around 8-10 minutes; with the metadata-driven approach, it takes 2-4 seconds. When you have thousands of production pipelines and many users, this becomes a waste of time and money for an organization.

You can benchmark this locally yourself:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import scala.math._
object DPPPerformanceBenchmark {
case class Stats(min: Double, max: Double, median: Double, avg: Double, stdDev: Double) {
override def toString: String = {
f" Min: $min%8.2f ms\n" +
f" Max: $max%8.2f ms\n" +
f" Median: $median%8.2f ms\n" +
f" Average: $avg%8.2f ms\n" +
f" StdDev: $stdDev%8.2f ms"
}
}
def calculateStats(values: Seq[Double]): Stats = {
val sorted = values.sorted
val n = sorted.length
val avg = sorted.sum / n
val variance = sorted.map(x => scala.math.pow(x - avg, 2)).sum / n
val stdDev = scala.math.sqrt(variance)
val median = if (n % 2 == 0) {
(sorted(n/2 - 1) + sorted(n/2)) / 2.0
} else {
sorted(n/2)
}
Stats(sorted.head, sorted.last, median, avg, stdDev)
}
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("DPP Performance Benchmark")
.master("local[*]")
.config("spark.sql.dynamicPartitionPruning.enabled", "true")
.config("spark.sql.dynamicPartitionPruning.reuseBroadcastOnly", "true")
.getOrCreate()
import spark.implicits._
// Configuration
val NUM_PARTITIONS = 100
val ROWS_PER_PARTITION = 100
val NUM_ITERATIONS = 1000
println("=" * 80)
println("DYNAMIC PARTITION PRUNING PERFORMANCE BENCHMARK")
println("=" * 80)
println()
println(s"Configuration:")
println(s" Partitions: $NUM_PARTITIONS")
println(s" Rows per partition: $ROWS_PER_PARTITION")
println(s" Total rows: ${NUM_PARTITIONS * ROWS_PER_PARTITION}")
println(s" Iterations per test: $NUM_ITERATIONS")
println()
// Setup test data
spark.sql("DROP TABLE IF EXISTS fact_stats_perf")
val factData = (0 until NUM_PARTITIONS).flatMap { partitionId =>
(1 to ROWS_PER_PARTITION).map { i =>
(1000 * partitionId + i, partitionId, i % 50, 10 + (i % 50))
}
}.toDF("date_id", "store_id", "product_id", "units_sold")
factData
.write
.mode("overwrite")
.partitionBy("store_id")
.format("parquet")
.saveAsTable("fact_stats_perf")
spark.sql("ANALYZE TABLE fact_stats_perf COMPUTE STATISTICS FOR ALL COLUMNS")
println(s"Created fact_stats_perf with ${factData.count()} rows")
println(s"Partitions: ${spark.sql("SHOW PARTITIONS fact_stats_perf").count()}")
println()
// Warmup
println("JVM warm up (20 iterations)...")
for (_ <- 1 to 20) {
// DPP with SHOW PARTITIONS (metadata-only, fast)
val maxPartitionDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
.agg(org.apache.spark.sql.functions.max("partition").alias("max_partition"))
.selectExpr("split(max_partition, '=')[1] as max_store_id")
maxPartitionDF.createOrReplaceTempView("max_partition_warmup")
spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
JOIN max_partition_warmup m ON f.store_id = m.max_store_id
""").collect()
// Standard scalar subquery (requires table scan)
spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
WHERE f.store_id = (SELECT MAX(store_id) FROM fact_stats_perf)
""").collect()
}
// ============================================================================
// SCENARIO 1: Base Case - Single Partition Selection (MAX)
// ============================================================================
println("=" * 80)
println("SCENARIO 1: Base Case - Single Partition Selection (MAX)")
println("=" * 80)
println()
// Approach 1A: DPP with CommandResult
println(s"Running Approach 1A: DPP with CommandResult ($NUM_ITERATIONS iterations)...")
val dppTimes1 = (1 to NUM_ITERATIONS).map { i =>
if (i % 200 == 0) print(s"$i...")
val start = System.nanoTime()
val maxPartitionDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
.agg(org.apache.spark.sql.functions.max("partition").alias("max_partition"))
.selectExpr("split(max_partition, '=')[1] as max_store_id")
maxPartitionDF.createOrReplaceTempView("max_partition_dpp")
val df = spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
JOIN max_partition_dpp m ON f.store_id = m.max_store_id
""")
val result = df.collect()
val elapsed = (System.nanoTime() - start) / 1e6
if (i == 1) {
val hasDpp = df.queryExecution.optimizedPlan.toString().contains("dynamicpruning")
println(s"\n DPP optimization active: $hasDpp")
println(s" Result count: ${result.length}")
}
elapsed
}
// Approach 2A: Standard Scalar Subquery (requires table scan for MAX)
println(s"\nRunning Approach 2A: Standard Scalar Subquery ($NUM_ITERATIONS iterations)...")
val scalarTimes1 = (1 to NUM_ITERATIONS).map { i =>
if (i % 200 == 0) print(s"$i...")
val start = System.nanoTime()
val df = spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
WHERE f.store_id = (SELECT MAX(store_id) FROM fact_stats_perf)
""")
val result = df.collect()
val elapsed = (System.nanoTime() - start) / 1e6
if (i == 1) {
println(s"\n Result count: ${result.length}")
}
elapsed
}
println(" Done!\n")
val dppStats1 = calculateStats(dppTimes1)
val scalarStats1 = calculateStats(scalarTimes1)
println("Scenario 1 Results:")
println()
println("Approach 1A: DPP with CommandResult (JOIN)")
println(dppStats1)
println()
println("Approach 2A: Standard Scalar Subquery (WHERE with table scan)")
println(scalarStats1)
println()
val avgDiff1 = scalarStats1.avg - dppStats1.avg
val percentDiff1 = (avgDiff1 / scalarStats1.avg) * 100
println(f"Average difference: ${avgDiff1}%8.2f ms (${percentDiff1}%+6.2f%%)")
if (avgDiff1 > 0) {
println(f"DPP with CommandResult (metadata) is FASTER by ${avgDiff1.abs}%.2f ms (${percentDiff1.abs}%.2f%%)")
} else {
println(f"Standard Scalar Subquery (table scan) is FASTER by ${avgDiff1.abs}%.2f ms (${percentDiff1.abs}%.2f%%)")
}
println()
// ============================================================================
// SCENARIO 2: Large Selection - Many Partitions
// ============================================================================
println("=" * 80)
println("SCENARIO 2: Large Selection - Many Partitions (Top 50%)")
println("=" * 80)
println()
val topNPartitions = NUM_PARTITIONS / 2
println(s"Selecting top $topNPartitions partitions out of $NUM_PARTITIONS")
println()
// Approach 1B: DPP with CommandResult (many partitions)
println(s"Running Approach 1B: DPP with CommandResult ($NUM_ITERATIONS iterations)...")
val dppTimes2 = (1 to NUM_ITERATIONS).map { i =>
if (i % 200 == 0) print(s"$i...")
val start = System.nanoTime()
val topPartitionsDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
.selectExpr("CAST(split(partition, '=')[1] AS INT) as store_id")
.orderBy(desc("store_id"))
.limit(topNPartitions)
topPartitionsDF.createOrReplaceTempView("top_partitions_dpp")
val df = spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
JOIN top_partitions_dpp m ON f.store_id = m.store_id
""")
val result = df.collect()
val elapsed = (System.nanoTime() - start) / 1e6
if (i == 1) {
val hasDpp = df.queryExecution.optimizedPlan.toString().contains("dynamicpruning")
println(s"\n DPP optimization active: $hasDpp")
println(s" Result count: ${result.length}")
println(s" Selected partitions: $topNPartitions")
}
elapsed
}
println(" Done!")
// Approach 2B: Standard subquery with IN (using temp view for top partitions)
println(s"\nRunning Approach 2B: Standard IN Subquery ($NUM_ITERATIONS iterations)...")
val scalarTimes2 = (1 to NUM_ITERATIONS).map { i =>
if (i % 200 == 0) print(s"$i...")
val start = System.nanoTime()
// Get top N store_ids from the table itself (requires scan)
val topStoreIds = spark.sql(s"""
SELECT DISTINCT store_id
FROM fact_stats_perf
ORDER BY store_id DESC
LIMIT $topNPartitions
""")
topStoreIds.createOrReplaceTempView("top_stores_scalar")
val df = spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
WHERE f.store_id IN (SELECT store_id FROM top_stores_scalar)
""")
val result = df.collect()
val elapsed = (System.nanoTime() - start) / 1e6
if (i == 1) {
println(s"\n Result count: ${result.length}")
}
elapsed
}
val dppStats2 = calculateStats(dppTimes2)
val scalarStats2 = calculateStats(scalarTimes2)
println("Scenario 2 Results:")
println()
println("Approach 1B: DPP with CommandResult (JOIN, many partitions)")
println(dppStats2)
println()
println("Approach 2B: Standard IN Subquery (with table scan)")
println(scalarStats2)
println()
val avgDiff2 = scalarStats2.avg - dppStats2.avg
val percentDiff2 = (avgDiff2 / scalarStats2.avg) * 100
println(f"Average difference: ${avgDiff2}%8.2f ms (${percentDiff2}%+6.2f%%)")
if (avgDiff2 > 0) {
println(f"DPP with CommandResult (metadata) is FASTER by ${avgDiff2.abs}%.2f ms (${percentDiff2.abs}%.2f%%)")
} else {
println(f"Standard IN Subquery (table scan) is FASTER by ${avgDiff2.abs}%.2f ms (${percentDiff2.abs}%.2f%%)")
}
println()
// ============================================================================
// SUMMARY
// ============================================================================
println("=" * 80)
println("SUMMARY")
println("=" * 80)
println()
println("Configuration:")
println(f" Total partitions: $NUM_PARTITIONS")
println(f" Rows per partition: $ROWS_PER_PARTITION")
println(f" Iterations: $NUM_ITERATIONS")
println()
println("Scenario 1: Single Partition (MAX) - The motivating use case")
println(f" DPP avg: ${dppStats1.avg}%8.2f ms")
println(f" Scalar avg: ${scalarStats1.avg}%8.2f ms")
println(f" Difference: ${avgDiff1}%8.2f ms (${percentDiff1}%+6.2f%%)")
println()
println(s"Scenario 2: Many Partitions (Top $topNPartitions) - Reviewer's concern")
println(f" DPP avg: ${dppStats2.avg}%8.2f ms")
println(f" Scalar avg: ${scalarStats2.avg}%8.2f ms")
println(f" Difference: ${avgDiff2}%8.2f ms (${percentDiff2}%+6.2f%%)")
println()
// Print plan comparison for Scenario 1
println("=" * 80)
println("PLAN COMPARISON - Scenario 1 (Single Partition)")
println("=" * 80)
println()
println("DPP Approach - Optimized Plan:")
val dppDF = {
val maxPartitionDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
.agg(org.apache.spark.sql.functions.max("partition").alias("max_partition"))
.selectExpr("split(max_partition, '=')[1] as max_store_id")
maxPartitionDF.createOrReplaceTempView("max_partition_final")
spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
JOIN max_partition_final m ON f.store_id = m.max_store_id
""")
}
println(dppDF.queryExecution.optimizedPlan)
println()
println("Standard Scalar Subquery Approach - Optimized Plan:")
val scalarDF = spark.sql("""
SELECT f.date_id, f.product_id, f.store_id, f.units_sold
FROM fact_stats_perf f
WHERE f.store_id = (SELECT MAX(store_id) FROM fact_stats_perf)
""")
println(scalarDF.queryExecution.optimizedPlan)
println()
// Cleanup
spark.sql("DROP TABLE IF EXISTS fact_stats_perf")
}
}
DPPPerformanceBenchmark.main(Array(""))
What changes were proposed in this pull request?
This PR enables Dynamic Partition Pruning (DPP) optimization when joining with CommandResult nodes (e.g., results from SHOW PARTITIONS).
Changes made to sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala (a rough sketch follows below):
- Treat CommandResult as a selective predicate source in hasSelectivePredicate()
- Set the CommandResult overhead to 0.0 in calculatePlanOverhead(), since the data is already materialized and incurs no I/O cost
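As an excerpt-style sketch of the two changes (the surrounding match structure is assumed and simplified here, and is not the exact diff):

// Excerpt-style sketch only; the real match structure in PartitionPruning.scala may differ.
// CommandResult, Filter, LogicalPlan are the catalyst logical-plan types; isLikelySelective
// is the existing helper already used by the rule.

// 1. Recognize CommandResult (e.g. the output of SHOW PARTITIONS) as selective:
//    it is a small, already-materialized relation.
private def hasSelectivePredicate(plan: LogicalPlan): Boolean = plan.exists {
  case _: CommandResult => true                     // new case
  case f: Filter => isLikelySelective(f.condition)  // existing behavior
  case _ => false
}

// 2. Treat CommandResult as free when estimating the overhead of adding the DPP filter,
//    since its rows are already materialized and need no extra I/O.
private def calculatePlanOverhead(plan: LogicalPlan): Double = plan.collect {
  case _: CommandResult => 0.0           // new: no scan cost
  case p => estimatedScanCost(p)         // hypothetical stand-in for the existing estimate
}.sum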
Added test coverage in sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala to verify DPP works correctly with CommandResult.
Built and tested locally against tag v4.0.1 to verify both the results and the Spark plan.
https://issues.apache.org/jira/browse/SPARK-54554
Why are the changes needed?
Previously, when using SHOW PARTITIONS results in a broadcast join, Spark would perform full table scans instead of applying Dynamic Partition Pruning.
Example scenario where this matters:
import org.apache.spark.sql.functions.{col, max}
// Derive the maximum partition id from catalog metadata (no data scan).
val partitions = spark.sql("SHOW PARTITIONS fact_table")
  .selectExpr("cast(split(partition, '=')[1] as int) as partition_id")
  .agg(max("partition_id"))
// DPP then prunes the fact_table scan down to the matching partition(s).
spark.table("fact_table")
  .join(partitions, col("partition_id") === col("max(partition_id)"))
Before this fix: Full table scan of all partitions
After this fix: DPP prunes to only the relevant partition(s)
Does this PR introduce any user-facing change?
Yes. Queries that join partitioned tables with SHOW PARTITIONS results (or other commands returning CommandResult) will now benefit from Dynamic Partition Pruning, potentially improving performance by scanning fewer partitions.
The behavior change is transparent to users - existing queries will simply run faster without any code changes required.
How was this patch tested?
Added new test case "DPP with CommandResult from SHOW PARTITIONS in broadcast join" in DynamicPartitionPruningSuite (sketched below) that verifies:
- DPP is applied when joining with CommandResult
- Correct query results are returned
- Plan contains DynamicPruningSubquery operator
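A rough sketch of what such a test can look like (illustrative only; suite helpers such as withTable and checkAnswer are assumed to be available, and the exact assertions may differ):

// Illustrative sketch; the actual test in DynamicPartitionPruningSuite may differ.
test("DPP with CommandResult from SHOW PARTITIONS in broadcast join") {
  withTable("fact_sales") {
    spark.range(100)
      .selectExpr("id AS sale_id", "id % 10 AS store_id")
      .write.partitionBy("store_id").saveAsTable("fact_sales")

    // SHOW PARTITIONS produces a CommandResult node; derive the max partition from it.
    val maxPartition = spark.sql("SHOW PARTITIONS fact_sales")
      .selectExpr("cast(split(partition, '=')[1] as int) AS store_id")
      .agg(org.apache.spark.sql.functions.max("store_id").as("store_id"))

    val df = spark.table("fact_sales").join(maxPartition, "store_id")

    // With this change, the optimized plan should carry a dynamic pruning filter.
    assert(df.queryExecution.optimizedPlan.toString.contains("dynamicpruning"))

    // Only the last partition (store_id = 9) should remain in the result.
    checkAnswer(
      df.select("sale_id"),
      spark.table("fact_sales").where("store_id = 9").select("sale_id"))
  }
}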
Ran full DynamicPartitionPruning test suite (73 tests total) - all passed
Tested manually with local Spark build using various CommandResult scenarios
Was this patch authored or co-authored using generative AI tooling?
No