
@dwsmith1983 (Contributor) commented Nov 29, 2025

What changes were proposed in this pull request?

This PR enables Dynamic Partition Pruning (DPP) optimization when joining with CommandResult nodes (e.g., results from SHOW PARTITIONS).

Changes made to sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala:

  1. Modified hasSelectivePredicate() to recognize CommandResult as selective (line 212)
  2. Modified calculatePlanOverhead() to return 0.0 overhead for CommandResult since data is already materialized (lines 187, 199)
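
As a rough illustration of the intent of both changes, here is a simplified sketch (not the merged diff; DppCommandResultSketch is a placeholder name, and the real calculatePlanOverhead also discounts cached relations):

import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, Filter, LogicalPlan}

// Placeholder object mirroring the shape of the change inside PartitionPruning.
object DppCommandResultSketch extends PredicateHelper {
  // 1. A CommandResult leaf (e.g. SHOW PARTITIONS output) counts as selective,
  //    so joining against it can drive a DPP filter on the partitioned side.
  def hasSelectivePredicate(plan: LogicalPlan): Boolean = plan.exists {
    case _: CommandResult => true
    case f: Filter => isLikelySelective(f.condition)
    case _ => false
  }

  // 2. CommandResult leaves contribute zero overhead: their rows are already
  //    materialized on the driver, so re-evaluating them needs no extra I/O.
  def calculatePlanOverhead(plan: LogicalPlan): Float =
    plan.collectLeaves().map {
      case _: CommandResult => 0.0f
      case leaf => leaf.stats.sizeInBytes.toFloat
    }.sum
}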

Added test coverage in sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala to verify DPP works correctly with CommandResult.

Built and tested locally against tag v4.0.1 to verify both the results and the resulting Spark plan.

https://issues.apache.org/jira/browse/SPARK-54554

Why are the changes needed?

Previously, when using SHOW PARTITIONS results in a broadcast join, Spark would perform full table scans instead of applying Dynamic Partition Pruning.

Example scenario where this matters:
// assumes: import org.apache.spark.sql.functions._
val partitions = spark.sql("SHOW PARTITIONS fact_table")
  .selectExpr("cast(split(partition, '=')[1] as int) as partition_id")
  .agg(max("partition_id"))

spark.table("fact_table")
  .join(partitions, col("partition_id") === col("max(partition_id)"))

Before this fix: Full table scan of all partitions
After this fix: DPP prunes to only the relevant partition(s)
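
One quick way to confirm the pruning kicked in, reusing the partitions DataFrame from the example above (assuming an active SparkSession named spark and functions._ imported), is to look for a dynamic pruning subquery in the optimized plan:

val pruned = spark.table("fact_table")
  .join(partitions, col("partition_id") === col("max(partition_id)"))

// After this fix the optimized plan should contain a dynamicpruning subquery
// on fact_table's partition column; before the fix it does not appear.
assert(pruned.queryExecution.optimizedPlan.toString.contains("dynamicpruning"))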

Does this PR introduce any user-facing change?

Yes. Queries that join partitioned tables with SHOW PARTITIONS results (or other commands returning CommandResult) will now benefit from Dynamic Partition Pruning, potentially improving performance by scanning fewer partitions.

The behavior change is transparent to users: existing queries simply run faster, with no code changes required.

How was this patch tested?

Added new test case "DPP with CommandResult from SHOW PARTITIONS in broadcast join" in DynamicPartitionPruningSuite that verifies:
- DPP is applied when joining with CommandResult
- Correct query results are returned
- Plan contains DynamicPruningSubquery operator

Ran the full DynamicPartitionPruning test suite (73 tests total); all passed.

Tested manually with a local Spark build across various CommandResult scenarios.
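
For reference, a rough sketch of what such a test looks like (approximate only: the committed test lives in DynamicPartitionPruningSuite and relies on the suite's usual fixtures and imports such as testImplicits, functions._ and Row; the table name and data here are made up):

test("SPARK-54554: DPP with CommandResult from SHOW PARTITIONS in broadcast join") {
  withTable("fact_cmd") {
    spark.range(100)
      .selectExpr("id AS value", "id % 10 AS part")
      .write.partitionBy("part").saveAsTable("fact_cmd")

    // Build side: the max partition value, derived purely from metadata.
    val maxPart = sql("SHOW PARTITIONS fact_cmd")
      .selectExpr("CAST(split(partition, '=')[1] AS INT) AS part_value")
      .agg(max("part_value").as("max_part"))

    val df = spark.table("fact_cmd").join(maxPart, col("part") === col("max_part"))

    // DPP should inject a DynamicPruningSubquery on the fact table scan, and
    // the result should only contain rows from the max partition (part = 9).
    assert(df.queryExecution.optimizedPlan.toString.contains("dynamicpruning"))
    checkAnswer(df.select("value"), (0 until 10).map(i => Row(i * 10L + 9)))
  }
}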

Was this patch authored or co-authored using generative AI tooling?

No

@github-actions bot added the SQL label Nov 29, 2025
@dwsmith1983 changed the title Spark 54554 dpp command result [SPARK-54554][SQL] Enable Dynamic Partition Pruning with CommandResult Nov 29, 2025
@dwsmith1983 force-pushed the SPARK-54554-dpp-command-result branch 3 times, most recently from e3d2c69 to f7ac51c on November 29, 2025 09:50
@dwsmith1983 marked this pull request as draft November 29, 2025 13:44
CommandResult (from commands like SHOW PARTITIONS) should qualify for
Dynamic Partition Pruning optimization in broadcast joins. Previously,
CommandResult was not recognized as a selective predicate, causing
full table scans even when the partition list was small.

This change:
- Treats CommandResult as selective in hasSelectivePredicate()
- Sets CommandResult overhead to 0.0 in calculatePlanOverhead()
  (data already materialized, no I/O cost)
- Adds test coverage in DynamicPartitionPruningSuite
- Removed unnecessary catch all
- Updated test name to reflect Jira ticket
- Corrected wrong df values in DynamicPartitionPruningSuite
- Updated Jira ticket in test name
- Removed test script

Co-authored-by: Tri Tam Hoang <tritam.hoang@gmail.com>
@dwsmith1983 force-pushed the SPARK-54554-dpp-command-result branch from d8857de to f8bd1e7 on November 29, 2025 16:23
@dwsmith1983 marked this pull request as ready for review November 29, 2025 16:24
@dwsmith1983 (Contributor, Author) commented:

@dongjoon-hyun can you review this PR and provide feedback?

@disliketd commented:

I have concerns about this approach. The motivating use case relies on parsing string outputs from SHOW PARTITIONS to drive logic, which is an anti-pattern compared to standard scalar subqueries (WHERE col = (SELECT MAX(col)...)).

Furthermore, blindly treating all CommandResult nodes as 'selective' (hasSelectivePredicate = true) seems risky. If the command returns all partitions, we incur the DPP overhead without any pruning benefit. We shouldn't modify core optimizer heuristics to support a fragile query pattern.

@dwsmith1983 (Contributor, Author) commented Dec 5, 2025

Hi @disliketd,

Using SHOW PARTITIONS is not an anti-pattern; it is the only way to avoid scanning all partitions just to obtain metadata that is already available in the catalog.

In a production table with 1000 partitions where you want to process only the latest partition:

  • Without this PR: scan all 1000 partitions to compute MAX, then scan the latest partition again, i.e. 1001 partition scans
  • With this PR: one metadata lookup plus a scan of 1 partition, i.e. 1 partition scan

The gap widens as the number of partitions grows, and it is especially noticeable on cloud object stores (S3, GCS, Azure Blob Storage).

Project [date_id#10125, product_id#10126, store_id#10128, units_sold#10127]
+- Join Inner, (cast(store_id#10128 as bigint) = cast(max_store_id#92129 as bigint))
   :- Filter (isnotnull(store_id#10128) AND dynamicpruning#92130 [cast(store_id#10128 as bigint)])
   :  :  +- Filter isnotnull(max_store_id#92129)
   :  :     +- Aggregate [split(max(partition#92121), =, -1)[1] AS max_store_id#92129]
   :  :        +- CommandResult [partition#92121], Execute ShowPartitionsCommand, [[store_id=0], [store_id=1], [store_id=10], [store_id=11], [store_id=12], [store_id=13], [store_id=14], [store_id=15], [store_id=16], [store_id=17], [store_id=18], [store_id=19], [store_id=2], [store_id=20], [store_id=21], [store_id=22], [store_id=23], [store_id=24], [store_id=25], [store_id=26], [store_id=27], [store_id=28], [store_id=29], [store_id=3], [store_id=30], ... 75 more fields]
   :  :              +- ShowPartitionsCommand `spark_catalog`.`default`.`fact_stats_perf`, [partition#92121]
   :  +- Relation spark_catalog.default.fact_stats_perf[date_id#10125,product_id#10126,units_sold#10127,store_id#10128] parquet
   +- Filter isnotnull(max_store_id#92129)
      +- Aggregate [split(max(partition#92121), =, -1)[1] AS max_store_id#92129]
         +- CommandResult [partition#92121], Execute ShowPartitionsCommand, [[store_id=0], [store_id=1], [store_id=10], [store_id=11], [store_id=12], [store_id=13], [store_id=14], [store_id=15], [store_id=16], [store_id=17], [store_id=18], [store_id=19], [store_id=2], [store_id=20], [store_id=21], [store_id=22], [store_id=23], [store_id=24], [store_id=25], [store_id=26], [store_id=27], [store_id=28], [store_id=29], [store_id=3], [store_id=30], ... 75 more fields]
               +- ShowPartitionsCommand `spark_catalog`.`default`.`fact_stats_perf`, [partition#92121]

Our local benchmark on an Intel Mac showed a 27% improvement for scenario 1 and 46% for scenario 2; on an Arm Mac we saw roughly 41% for both scenarios. In cloud environments with network latency, the improvement is orders of magnitude larger:

  • Eliminating hundreds/thousands of object storage API calls
  • No network round-trips for each partition scan
  • No rate limiting delays

The implementation is not blind: the existing safeguards in the pruningHasBenefit function still apply before the filter is injected.
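
To make that concrete, the benefit check is roughly of this shape (a deliberately simplified, hypothetical illustration, not the actual pruningHasBenefit source, which also factors in column distinct counts and broadcast reuse):

// Hypothetical simplification: prune only when the estimated bytes saved on
// the partitioned side outweigh the cost of evaluating the filtering side.
def worthPruning(
    prunedSideSizeInBytes: BigInt,      // size of the partitioned scan
    expectedSurvivingFraction: Double,  // fraction of data expected to remain
    filteringSideOverhead: Double): Boolean = {
  val estimatedBytesSaved =
    (1.0 - expectedSurvivingFraction) * prunedSideSizeInBytes.toDouble
  estimatedBytesSaved > filteringSideOverhead
}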

Your suggested scalar subquery would still require a full table scan, as the plan below shows.

Project [date_id#20274, product_id#20275, store_id#20277, units_sold#20276]
+- Filter (isnotnull(store_id#20277) AND (store_id#20277 = scalar-subquery#102280 []))
   :  +- Aggregate [max(store_id#102286) AS max(store_id)#102282]
   :     +- Project [store_id#102286]
   :        +- Relation spark_catalog.default.fact_stats_perf[...] parquet
   +- Relation spark_catalog.default.fact_stats_perf[...] parquet

Real-world experience with the full table scan: on a financial table we maintain in AWS, the table scan to find max(partition) takes around 8-10 minutes, while the metadata-driven approach takes 2-4 seconds. With thousands of production pipelines and many users, that difference is a significant waste of time and money for an organization.

You can benchmark this locally yourself:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import scala.math._

object DPPPerformanceBenchmark {
  case class Stats(min: Double, max: Double, median: Double, avg: Double, stdDev: Double) {
    override def toString: String = {
      f"  Min:     $min%8.2f ms\n" +
      f"  Max:     $max%8.2f ms\n" +
      f"  Median:  $median%8.2f ms\n" +
      f"  Average: $avg%8.2f ms\n" +
      f"  StdDev:  $stdDev%8.2f ms"
    }
  }

  def calculateStats(values: Seq[Double]): Stats = {
    val sorted = values.sorted
    val n = sorted.length
    val avg = sorted.sum / n
    val variance = sorted.map(x => scala.math.pow(x - avg, 2)).sum / n
    val stdDev = scala.math.sqrt(variance)

    val median = if (n % 2 == 0) {
      (sorted(n/2 - 1) + sorted(n/2)) / 2.0
    } else {
      sorted(n/2)
    }

    Stats(sorted.head, sorted.last, median, avg, stdDev)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("DPP Performance Benchmark")
      .master("local[*]")
      .config("spark.sql.dynamicPartitionPruning.enabled", "true")
      .config("spark.sql.dynamicPartitionPruning.reuseBroadcastOnly", "true")
      .getOrCreate()

    import spark.implicits._

    // Configuration
    val NUM_PARTITIONS = 100
    val ROWS_PER_PARTITION = 100
    val NUM_ITERATIONS = 1000

    println("=" * 80)
    println("DYNAMIC PARTITION PRUNING PERFORMANCE BENCHMARK")
    println("=" * 80)
    println()
    println(s"Configuration:")
    println(s"  Partitions: $NUM_PARTITIONS")
    println(s"  Rows per partition: $ROWS_PER_PARTITION")
    println(s"  Total rows: ${NUM_PARTITIONS * ROWS_PER_PARTITION}")
    println(s"  Iterations per test: $NUM_ITERATIONS")
    println()

    // Setup test data
    spark.sql("DROP TABLE IF EXISTS fact_stats_perf")

    val factData = (0 until NUM_PARTITIONS).flatMap { partitionId =>
      (1 to ROWS_PER_PARTITION).map { i =>
        (1000 * partitionId + i, partitionId, i % 50, 10 + (i % 50))
      }
    }.toDF("date_id", "store_id", "product_id", "units_sold")

    factData
      .write
      .mode("overwrite")
      .partitionBy("store_id")
      .format("parquet")
      .saveAsTable("fact_stats_perf")

    spark.sql("ANALYZE TABLE fact_stats_perf COMPUTE STATISTICS FOR ALL COLUMNS")

    println(s"Created fact_stats_perf with ${factData.count()} rows")
    println(s"Partitions: ${spark.sql("SHOW PARTITIONS fact_stats_perf").count()}")
    println()

    // Warmup
    println("JVM warm up (20 iterations)...")
    for (_ <- 1 to 20) {
      // DPP with SHOW PARTITIONS (metadata-only, fast)
      val maxPartitionDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
        .agg(org.apache.spark.sql.functions.max("partition").alias("max_partition"))
        .selectExpr("split(max_partition, '=')[1] as max_store_id")
      maxPartitionDF.createOrReplaceTempView("max_partition_warmup")
      spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        JOIN max_partition_warmup m ON f.store_id = m.max_store_id
      """).collect()

      // Standard scalar subquery (requires table scan)
      spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        WHERE f.store_id = (SELECT MAX(store_id) FROM fact_stats_perf)
      """).collect()
    }

    // ============================================================================
    // SCENARIO 1: Base Case - Single Partition Selection (MAX)
    // ============================================================================
    println("=" * 80)
    println("SCENARIO 1: Base Case - Single Partition Selection (MAX)")
    println("=" * 80)
    println()

    // Approach 1A: DPP with CommandResult
    println(s"Running Approach 1A: DPP with CommandResult ($NUM_ITERATIONS iterations)...")
    val dppTimes1 = (1 to NUM_ITERATIONS).map { i =>
      if (i % 200 == 0) print(s"$i...")

      val start = System.nanoTime()
      val maxPartitionDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
        .agg(org.apache.spark.sql.functions.max("partition").alias("max_partition"))
        .selectExpr("split(max_partition, '=')[1] as max_store_id")
      maxPartitionDF.createOrReplaceTempView("max_partition_dpp")

      val df = spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        JOIN max_partition_dpp m ON f.store_id = m.max_store_id
      """)

      val result = df.collect()
      val elapsed = (System.nanoTime() - start) / 1e6

      if (i == 1) {
        val hasDpp = df.queryExecution.optimizedPlan.toString().contains("dynamicpruning")
        println(s"\n  DPP optimization active: $hasDpp")
        println(s"  Result count: ${result.length}")
      }

      elapsed
    }

    // Approach 2A: Standard Scalar Subquery (requires table scan for MAX)
    println(s"\nRunning Approach 2A: Standard Scalar Subquery ($NUM_ITERATIONS iterations)...")
    val scalarTimes1 = (1 to NUM_ITERATIONS).map { i =>
      if (i % 200 == 0) print(s"$i...")

      val start = System.nanoTime()
      val df = spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        WHERE f.store_id = (SELECT MAX(store_id) FROM fact_stats_perf)
      """)

      val result = df.collect()
      val elapsed = (System.nanoTime() - start) / 1e6

      if (i == 1) {
        println(s"\n  Result count: ${result.length}")
      }

      elapsed
    }
    println(" Done!\n")

    val dppStats1 = calculateStats(dppTimes1)
    val scalarStats1 = calculateStats(scalarTimes1)

    println("Scenario 1 Results:")
    println()
    println("Approach 1A: DPP with CommandResult (JOIN)")
    println(dppStats1)
    println()
    println("Approach 2A: Standard Scalar Subquery (WHERE with table scan)")
    println(scalarStats1)
    println()

    val avgDiff1 = scalarStats1.avg - dppStats1.avg
    val percentDiff1 = (avgDiff1 / scalarStats1.avg) * 100
    println(f"Average difference: ${avgDiff1}%8.2f ms (${percentDiff1}%+6.2f%%)")
    if (avgDiff1 > 0) {
      println(f"DPP with CommandResult (metadata) is FASTER by ${avgDiff1.abs}%.2f ms (${percentDiff1.abs}%.2f%%)")
    } else {
      println(f"Standard Scalar Subquery (table scan) is FASTER by ${avgDiff1.abs}%.2f ms (${percentDiff1.abs}%.2f%%)")
    }
    println()

    // ============================================================================
    // SCENARIO 2: Large Selection - Many Partitions
    // ============================================================================
    println("=" * 80)
    println("SCENARIO 2: Large Selection - Many Partitions (Top 50%)")
    println("=" * 80)
    println()

    val topNPartitions = NUM_PARTITIONS / 2
    println(s"Selecting top $topNPartitions partitions out of $NUM_PARTITIONS")
    println()

    // Approach 1B: DPP with CommandResult (many partitions)
    println(s"Running Approach 1B: DPP with CommandResult ($NUM_ITERATIONS iterations)...")
    val dppTimes2 = (1 to NUM_ITERATIONS).map { i =>
      if (i % 200 == 0) print(s"$i...")

      val start = System.nanoTime()
      val topPartitionsDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
        .selectExpr("CAST(split(partition, '=')[1] AS INT) as store_id")
        .orderBy(desc("store_id"))
        .limit(topNPartitions)
      topPartitionsDF.createOrReplaceTempView("top_partitions_dpp")

      val df = spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        JOIN top_partitions_dpp m ON f.store_id = m.store_id
      """)

      val result = df.collect()
      val elapsed = (System.nanoTime() - start) / 1e6

      if (i == 1) {
        val hasDpp = df.queryExecution.optimizedPlan.toString().contains("dynamicpruning")
        println(s"\n  DPP optimization active: $hasDpp")
        println(s"  Result count: ${result.length}")
        println(s"  Selected partitions: $topNPartitions")
      }

      elapsed
    }
    println(" Done!")

    // Approach 2B: Standard subquery with IN (using temp view for top partitions)
    println(s"\nRunning Approach 2B: Standard IN Subquery ($NUM_ITERATIONS iterations)...")
    val scalarTimes2 = (1 to NUM_ITERATIONS).map { i =>
      if (i % 200 == 0) print(s"$i...")

      val start = System.nanoTime()

      // Get top N store_ids from the table itself (requires scan)
      val topStoreIds = spark.sql(s"""
        SELECT DISTINCT store_id
        FROM fact_stats_perf
        ORDER BY store_id DESC
        LIMIT $topNPartitions
      """)
      topStoreIds.createOrReplaceTempView("top_stores_scalar")

      val df = spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        WHERE f.store_id IN (SELECT store_id FROM top_stores_scalar)
      """)

      val result = df.collect()
      val elapsed = (System.nanoTime() - start) / 1e6

      if (i == 1) {
        println(s"\n  Result count: ${result.length}")
      }

      elapsed
    }

    val dppStats2 = calculateStats(dppTimes2)
    val scalarStats2 = calculateStats(scalarTimes2)

    println("Scenario 2 Results:")
    println()
    println("Approach 1B: DPP with CommandResult (JOIN, many partitions)")
    println(dppStats2)
    println()
    println("Approach 2B: Standard IN Subquery (with table scan)")
    println(scalarStats2)
    println()

    val avgDiff2 = scalarStats2.avg - dppStats2.avg
    val percentDiff2 = (avgDiff2 / scalarStats2.avg) * 100
    println(f"Average difference: ${avgDiff2}%8.2f ms (${percentDiff2}%+6.2f%%)")
    if (avgDiff2 > 0) {
      println(f"DPP with CommandResult (metadata) is FASTER by ${avgDiff2.abs}%.2f ms (${percentDiff2.abs}%.2f%%)")
    } else {
      println(f"Standard IN Subquery (table scan) is FASTER by ${avgDiff2.abs}%.2f ms (${percentDiff2.abs}%.2f%%)")
    }
    println()

    // ============================================================================
    // SUMMARY
    // ============================================================================
    println("=" * 80)
    println("SUMMARY")
    println("=" * 80)
    println()
    println("Configuration:")
    println(f"  Total partitions: $NUM_PARTITIONS")
    println(f"  Rows per partition: $ROWS_PER_PARTITION")
    println(f"  Iterations: $NUM_ITERATIONS")
    println()
    println("Scenario 1: Single Partition (MAX) - The motivating use case")
    println(f"  DPP avg:    ${dppStats1.avg}%8.2f ms")
    println(f"  Scalar avg: ${scalarStats1.avg}%8.2f ms")
    println(f"  Difference: ${avgDiff1}%8.2f ms (${percentDiff1}%+6.2f%%)")
    println()
    println(s"Scenario 2: Many Partitions (Top $topNPartitions) - Reviewer's concern")
    println(f"  DPP avg:    ${dppStats2.avg}%8.2f ms")
    println(f"  Scalar avg: ${scalarStats2.avg}%8.2f ms")
    println(f"  Difference: ${avgDiff2}%8.2f ms (${percentDiff2}%+6.2f%%)")
    println()

    // Print plan comparison for Scenario 1
    println("=" * 80)
    println("PLAN COMPARISON - Scenario 1 (Single Partition)")
    println("=" * 80)
    println()
    println("DPP Approach - Optimized Plan:")
    val dppDF = {
      val maxPartitionDF = spark.sql("SHOW PARTITIONS fact_stats_perf")
        .agg(org.apache.spark.sql.functions.max("partition").alias("max_partition"))
        .selectExpr("split(max_partition, '=')[1] as max_store_id")
      maxPartitionDF.createOrReplaceTempView("max_partition_final")
      spark.sql("""
        SELECT f.date_id, f.product_id, f.store_id, f.units_sold
        FROM fact_stats_perf f
        JOIN max_partition_final m ON f.store_id = m.max_store_id
      """)
    }
    println(dppDF.queryExecution.optimizedPlan)
    println()

    println("Standard Scalar Subquery Approach - Optimized Plan:")
    val scalarDF = spark.sql("""
      SELECT f.date_id, f.product_id, f.store_id, f.units_sold
      FROM fact_stats_perf f
      WHERE f.store_id = (SELECT MAX(store_id) FROM fact_stats_perf)
    """)
    println(scalarDF.queryExecution.optimizedPlan)
    println()

    // Cleanup
    spark.sql("DROP TABLE IF EXISTS fact_stats_perf")
  }
}

DPPPerformanceBenchmark.main(Array(""))
