Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-45760][SQL] Add With expression to avoid duplicating expressions #43623

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Aggregate [count(if (((a#0 > 0) = false)) null else (a#0 > 0)) AS count_if((a > 0))#0L]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Project [if ((regexp_extract(g#0, \d{2}(a|b|m), 0) = )) null else regexp_extract(g#0, \d{2}(a|b|m), 0) AS regexp_substr(g, \d{2}(a|b|m))#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS regexp_substr(g, \d{2}(a|b|m))#0]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_0#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
import org.apache.spark.sql.catalyst.optimizer.{ReplaceExpressions, RewriteWithExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.SessionHolder
Expand Down Expand Up @@ -181,8 +183,15 @@ class ProtoToParsedPlanTestSuite
val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
val catalystPlan =
analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker)
val actual =
removeMemoryAddress(normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString)
val finalAnalyzedPlan = {
object Helper extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Finish Analysis", Once, ReplaceExpressions) ::
Batch("Rewrite With expression", FixedPoint(10), RewriteWithExpression) :: Nil
}
Helper.execute(catalystPlan)
}
val actual = removeMemoryAddress(normalizeExprIds(finalAnalyzedPlan).treeString)
val goldenFile = goldenFilePath.resolve(relativePath).getParent.resolve(name + ".explain")
Try(readGoldenFile(goldenFile)) match {
case Success(expected) if expected == actual => // Test passes.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION}
import org.apache.spark.sql.types.DataType

/**
 * An expression holder that keeps a list of common expressions and allows the actual expression
 * (`child`) to reference them through [[CommonExpressionRef]]. The common expressions are
 * guaranteed to be evaluated only once even if they are referenced more than once. This is
 * similar to a CTE, but at the expression level.
 *
 * The node is `Unevaluable`; its data type and nullability are taken from `child`.
 *
 * @param child the expression that may reference the common expressions
 * @param defs  the common expression definitions available to `child`
 */
case class With(child: Expression, defs: Seq[CommonExpressionDef])
  extends Expression with Unevaluable {

  override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION)

  // The main expression always comes first, followed by the definitions.
  override def children: Seq[Expression] = child +: defs

  override def dataType: DataType = child.dataType

  override def nullable: Boolean = child.nullable

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): Expression = {
    // Mirror the layout of `children`: the head is the main expression, the rest are definitions.
    val replacedChild = newChildren.head
    val replacedDefs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef])
    copy(child = replacedChild, defs = replacedDefs)
  }
}

/**
 * A wrapper around a common expression that carries the id by which [[CommonExpressionRef]]
 * nodes refer to it.
 *
 * @param child the wrapped common expression
 * @param id    the id of this definition; defaults to a fresh value from
 *              `CommonExpressionDef.newId`
 */
case class CommonExpressionDef(child: Expression, id: Long = CommonExpressionDef.newId)
  extends UnaryExpression with Unevaluable {

  override def dataType: DataType = child.dataType

  // Keep the id stable when only the wrapped expression changes.
  override protected def withNewChildInternal(newChild: Expression): Expression =
    CommonExpressionDef(child = newChild, id = id)
}

/**
 * A reference to a common expression, identified by the id of its [[CommonExpressionDef]]. Only
 * resolved common expressions can be referenced, so that the data type and nullability of the
 * reference node can be determined.
 *
 * @param id       the id of the referenced common expression definition
 * @param dataType the data type of the referenced common expression
 * @param nullable whether the referenced common expression is nullable
 */
case class CommonExpressionRef(id: Long, dataType: DataType, nullable: Boolean)
  extends LeafExpression with Unevaluable {

  override val nodePatterns: Seq[TreePattern] = Seq(COMMON_EXPR_REF)

  /** Creates a reference that copies the id, data type and nullability of `exprDef`. */
  def this(exprDef: CommonExpressionDef) =
    this(id = exprDef.id, dataType = exprDef.dataType, nullable = exprDef.nullable)
}

object CommonExpressionDef {
  import java.util.concurrent.atomic.AtomicLong

  // Counter backing `newId`; it starts at zero.
  private[sql] val curId = new AtomicLong()

  /** Returns the current counter value and advances the counter by one. */
  def newId: Long = {
    val allocated = curId.getAndIncrement()
    allocated
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {

def this(left: Expression, right: Expression) = {
this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left))
this(left, right, {
val commonExpr = CommonExpressionDef(left)
val ref = new CommonExpressionRef(commonExpr)
With(If(EqualTo(ref, right), Literal.create(null, left.dataType), ref), Seq(commonExpr))
})
}

override def parameters: Seq[Expression] = Seq(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ abstract class Optimizer(catalogManager: CatalogManager)

val batches = (
Batch("Finish Analysis", Once, FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expressions
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CommonExpressionRef, Expression, With}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}

/**
 * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
 * just inline them if they are cheap.
 *
 * For every plan node whose expressions contain a `With`:
 *  1. each rewritable `With` is replaced by its child, with the `CommonExpressionRef`s it defines
 *     substituted either by the common expression itself (when `CollapseProject.isCheap` accepts
 *     it) or by an attribute pointing to a new `_common_expr_N` alias;
 *  2. each new alias is added through a `Project` on top of the first child plan whose output
 *     covers the alias' references;
 *  3. aliases that cannot be placed on any child are inlined back into the expressions;
 *  4. if the node's output changed, a final `Project` restores the original output.
 *
 * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
 * usage, we should support aggregate/window functions as well.
 */
object RewriteWithExpression extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
      case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
        // Aliases for the non-cheap common expressions that must be pre-evaluated below `p`.
        val commonExprs = mutable.ArrayBuffer.empty[Alias]
        // `With` can be nested, we should only rewrite the leaf `With` expression, as the outer
        // `With` needs to add its own Project, in the next iteration when it becomes leaf.
        // This is done via "transform down" and checking that the common expression definitions
        // do not contain nested `With`.
        var newPlan: LogicalPlan = p.transformExpressionsDown {
          case With(child, defs) if defs.forall(!_.containsPattern(WITH_EXPRESSION)) =>
            val idToCheapExpr = mutable.HashMap.empty[Long, Expression]
            val idToNonCheapExpr = mutable.HashMap.empty[Long, Alias]
            defs.zipWithIndex.foreach { case (commonExprDef, index) =>
              if (CollapseProject.isCheap(commonExprDef.child)) {
                idToCheapExpr(commonExprDef.id) = commonExprDef.child
              } else {
                // TODO: we should calculate the ref count and also inline the common expression
                //       if its ref count is 1.
                val alias = Alias(commonExprDef.child, s"_common_expr_$index")()
                commonExprs += alias
                idToNonCheapExpr(commonExprDef.id) = alias
              }
            }

            child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
              // `child` may itself contain a nested `With`, whose references carry ids that are
              // not defined here. Only replace the references defined by the current `With` and
              // leave the others untouched, so that they can be resolved when the `With` that
              // defines them is rewritten (looking them up here would fail with a
              // `NoSuchElementException`).
              case ref: CommonExpressionRef
                  if idToCheapExpr.contains(ref.id) || idToNonCheapExpr.contains(ref.id) =>
                idToCheapExpr.getOrElse(ref.id, idToNonCheapExpr(ref.id).toAttribute)
            }
        }

        // Place each alias on the first child plan whose output covers all of its references.
        var exprsToAdd = commonExprs.toSeq
        val newChildren = newPlan.children.map { child =>
          val (newExprs, others) = exprsToAdd.partition(_.references.subsetOf(child.outputSet))
          exprsToAdd = others
          if (newExprs.nonEmpty) {
            Project(child.output ++ newExprs, child)
          } else {
            child
          }
        }

        if (exprsToAdd.nonEmpty) {
          // When we cannot rewrite the common expressions, force to inline them so that the query
          // can still run. This can happen if the join condition contains `With` and the common
          // expression references columns from both join sides.
          // TODO: things can go wrong if the common expression is nondeterministic. We don't fix
          //       it for now to match the old buggy behavior when certain `RuntimeReplaceable`
          //       did not use the `With` expression.
          val attrToExpr = AttributeMap(exprsToAdd.map { alias =>
            alias.toAttribute -> alias.child
          })
          newPlan = newPlan.transformExpressionsUp {
            case a: Attribute => attrToExpr.getOrElse(a, a)
          }
        }

        newPlan = newPlan.withNewChildren(newChildren)
        // Adding a `Project` below may change what `p` outputs; restore the original output.
        if (p.output == newPlan.output) {
          newPlan
        } else {
          Project(p.output, newPlan)
        }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ object TreePattern extends Enumeration {
val CASE_WHEN: Value = Value
val CAST: Value = Value
val COALESCE: Value = Value
val COMMON_EXPR_REF: Value = Value
val CONCAT: Value = Value
val COUNT: Value = Value
val CREATE_NAMED_STRUCT: Value = Value
Expand Down Expand Up @@ -132,6 +133,7 @@ object TreePattern extends Enumeration {
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
val WINDOW_GROUP_LIMIT: Value = Value
val WITH_EXPRESSION: Value = Value
val WITH_WINDOW_DEFINITION: Value = Value

// Unresolved expression patterns (Alphabetically ordered)
Expand Down
Loading