From 3b4eb1fbd8a351c29a12bfd94ec4cdbee803f416 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 19 Nov 2021 15:24:52 -0800
Subject: [PATCH] [SPARK-37379][SQL] Add tree pattern pruning to
 CTESubstitution rule

### What changes were proposed in this pull request?

This PR adds tree pattern pruning to the `CTESubstitution` analyzer rule. The rule will now exit early if the tree does not contain an `UnresolvedWith` node.

### Why are the changes needed?

Analysis is eagerly performed after every DataFrame transformation. If a user's program performs a long chain of _n_ transformations to construct a large query plan, this can lead to _O(n^2)_ performance costs from `CTESubstitution`: the rule is applied _n_ times and each application traverses the entire logical plan tree (which contains _O(n)_ nodes).

In the case of chained `withColumn` calls (leading to stacked `Project` nodes) it's possible to see _O(n^3)_ slowdowns, where _n_ is the number of projects: there are _n_ separate analysis phases, each of which calls `CTESubstitution.traverseAndSubstituteCTE`, and each call visits each of the _n_ `Project` nodes and each of their _O(n)_ expressions.

Very large DataFrame plans typically do not use CTEs because there is no DataFrame syntax for them (although they might appear in the plan if `sql(someQueryWithCTE)` is used). As a result, this PR's proposed optimization to skip `CTESubstitution` can greatly reduce the analysis cost for such plans.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

I believe that the rule's correctness is covered by existing tests.

As a toy benchmark, I ran

```
import org.apache.spark.sql.DataFrame
org.apache.spark.sql.catalyst.rules.RuleExecutor.resetMetrics()
(1 to 600).foldLeft(spark.range(100).toDF)((df: DataFrame, i: Int) => df.withColumn(s"col$i", $"id" % i))
println(org.apache.spark.sql.catalyst.rules.RuleExecutor.dumpTimeSpent())
```

on my laptop before and after this PR's changes (simulating an _O(n^3)_ case). Skipping `CTESubstitution` cut the running time from ~28.4 seconds to ~15.5 seconds. The bulk of the remaining time comes from `DeduplicateRelations`, for which I plan to submit a separate optimization PR.
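For intuition on why the new early-exit check is cheap: each `TreeNode` caches a bitset of the tree patterns that occur anywhere in its subtree, so `containsPattern` is a constant-time bitset lookup rather than another plan traversal. The sketch below illustrates only the caching idea; `SimpleNode` and its members are illustrative stand-ins, not Spark's actual `TreeNode` implementation.

```
import scala.collection.immutable.BitSet

// Toy model of pattern-bit caching. Each node contributes its own
// patterns and unions in the cached bits of its children, so asking
// "does this subtree contain pattern X?" never re-walks the tree.
abstract class SimpleNode {
  def children: Seq[SimpleNode]

  // Patterns contributed by this node itself (cf. the nodePatterns
  // override in the diff below).
  protected def nodePatterns: Seq[Int] = Seq.empty

  // Computed lazily once per node and then reused across all rule
  // invocations; every later containsPattern call is O(1).
  lazy val treePatternBits: BitSet =
    children.foldLeft(BitSet(nodePatterns: _*))(_ | _.treePatternBits)

  final def containsPattern(pattern: Int): Boolean =
    treePatternBits.contains(pattern)
}
```

Because the bits are computed once per node and reused, each of the _n_ analysis passes pays O(1) for the guard instead of O(n) for a full traversal.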
Closes #34658 from JoshRosen/CTESubstitution-tree-pattern-pruning.

Authored-by: Josh Rosen
Signed-off-by: Josh Rosen
---
 .../apache/spark/sql/catalyst/analysis/CTESubstitution.scala  | 3 +++
 .../sql/catalyst/plans/logical/basicLogicalOperators.scala    | 2 ++
 .../org/apache/spark/sql/catalyst/trees/TreePatterns.scala    | 1 +
 3 files changed, 6 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index ec3d957f92ee3..2e2d415954695 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -48,6 +48,9 @@ import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, Lega
  */
 object CTESubstitution extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(UNRESOLVED_WITH)) {
+      return plan
+    }
     val isCommand = plan.find {
       case _: Command | _: ParsedStatement | _: InsertIntoDir => true
       case _ => false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index f1b954d6c7eae..e8a632d01598f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -626,6 +626,8 @@ object View {
 case class UnresolvedWith(
     child: LogicalPlan,
     cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_WITH)
+
   override def output: Seq[Attribute] = child.output
 
   override def simpleString(maxFields: Int): String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 6c1b64dd0af6e..aad90ff695e10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -111,6 +111,7 @@ object TreePattern extends Enumeration {
   val REPARTITION_OPERATION: Value = Value
   val UNION: Value = Value
   val UNRESOLVED_RELATION: Value = Value
+  val UNRESOLVED_WITH: Value = Value
   val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
   val WITH_WINDOW_DEFINITION: Value = Value
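A closing note for future rule authors: the three hunks above follow a reusable recipe (register a pattern, advertise it on the node, guard the rule). The template below is a hedged sketch of that recipe, not code from this patch; `MY_PATTERN`, `MyNode`, and `MyRule` are hypothetical names, and the imports are assumed to match the surroundings of the files touched above rather than a verified standalone build.

```
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.rules.Rule

// Step 1 (TreePatterns.scala): register the new pattern alongside the
// existing ones, e.g. `val MY_PATTERN: Value = Value`.

// Step 2: the node advertises its pattern, so the cached pattern bits of
// every ancestor include MY_PATTERN whenever this node is in the subtree.
case class MyNode(child: LogicalPlan) extends UnaryNode {
  final override val nodePatterns: Seq[TreePattern] = Seq(MY_PATTERN)
  override def output: Seq[Attribute] = child.output
  override protected def withNewChildInternal(newChild: LogicalPlan): MyNode =
    copy(child = newChild)
}

// Step 3: the rule bails out with a single bitset test when no MyNode can
// be present, instead of traversing the whole plan on every invocation.
object MyRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!plan.containsPattern(MY_PATTERN)) {
      return plan
    }
    plan // the actual rewrite would go here
  }
}
```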