Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Project [_aggregateexpression#0L AS count_if((a > 0))#0L]
+- Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS _aggregateexpression#0L]
+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class ProtoToParsedPlanTestSuite
Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key,
"org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin")
.set(org.apache.spark.sql.internal.SQLConf.ANSI_ENABLED.key, false.toString)
.set(org.apache.spark.sql.internal.SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key, false.toString)
}

protected val suiteBaseResourcePath = commonResourcePath.resolve("query-tests")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION}
import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION}
import org.apache.spark.sql.types.DataType

/**
Expand All @@ -27,6 +27,10 @@ import org.apache.spark.sql.types.DataType
*/
case class With(child: Expression, defs: Seq[CommonExpressionDef])
extends Expression with Unevaluable {
// We do not allow With to be created with an AggregateExpression in the child, as this would
// create a dangling CommonExpressionRef after rewriting it in RewriteWithExpression.
assert(!child.containsPattern(AGGREGATE_EXPRESSION))

override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION)
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,65 @@ import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
import org.apache.spark.sql.internal.SQLConf

/**
* Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
* just inline them if they are cheap.
*
* Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
* after this rule.
*
* Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
* usage, we should support aggregate/window functions as well.
*/
object RewriteWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
// For aggregates, separate the computation of the aggregations themselves from the final
// result by moving the final result computation into a projection above it. This prevents
// this rule from producing an invalid Aggregate operator.
case p @ PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child)
if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
// PhysicalAggregation returns aggregateExpressions as attribute references, which we change
// to aliases so that they can be referred to by resultExpressions.
val aggExprs = aggregateExpressions.map(
ae => Alias(ae, "_aggregateexpression")(ae.resultId))
val aggExprIds = aggExprs.map(_.exprId).toSet
val resExprs = resultExpressions.map(_.transform {
case a: AttributeReference if aggExprIds.contains(a.exprId) =>
a.withName("_aggregateexpression")
}.asInstanceOf[NamedExpression])
// Rewrite the projection and the aggregate separately and then piece them together.
val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
val rewrittenAgg = applyInternal(agg)
val proj = Project(resExprs, rewrittenAgg)
applyInternal(proj)
case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
val inputPlans = p.children.toArray
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans)
}
newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
// Since we add extra Projects with extra columns to pre-evaluate the common expressions,
// the current operator may have extra columns if it inherits the output columns from its
// child, and we need to project away the extra columns to keep the plan schema unchanged.
assert(p.output.length <= newPlan.output.length)
if (p.output.length < newPlan.output.length) {
assert(p.outputSet.subsetOf(newPlan.outputSet))
Project(p.output, newPlan)
} else {
newPlan
}
applyInternal(p)
}
}

/**
 * Rewrites every `With` expression appearing in the expressions of the single operator `p`.
 *
 * NOTE(review): `rewriteWithExprAndInputPlans` appears to mutate `inputPlans` in place,
 * wrapping child plans in extra `Project`s that pre-evaluate common expressions — confirm
 * against its definition; the `var newPlan` + `withNewChildren` sequence below only makes
 * sense under that assumption.
 *
 * @param p the operator whose expressions may contain `With`
 * @return an equivalent plan with `With` expressions eliminated; schema is preserved
 */
private def applyInternal(p: LogicalPlan): LogicalPlan = {
  // Snapshot the children into a mutable array so the expression rewriter can replace
  // individual children (by index) with Projects computing `_common_expr_*` columns.
  val inputPlans = p.children.toArray
  var newPlan: LogicalPlan = p.mapExpressions { expr =>
    rewriteWithExprAndInputPlans(expr, inputPlans)
  }
  // Re-attach the (possibly replaced) children produced by the rewriter.
  newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
  // Since we add extra Projects with extra columns to pre-evaluate the common expressions,
  // the current operator may have extra columns if it inherits the output columns from its
  // child, and we need to project away the extra columns to keep the plan schema unchanged.
  assert(p.output.length <= newPlan.output.length)
  if (p.output.length < newPlan.output.length) {
    // The rewrite must only ever ADD columns; the original output must survive intact.
    assert(p.outputSet.subsetOf(newPlan.outputSet))
    Project(p.output, newPlan)
  } else {
    newPlan
  }
}

Expand Down Expand Up @@ -93,7 +122,12 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
s"_common_expr_${id.id}"
} else {
s"_common_expr_$index"
}
val alias = Alias(child, aliasName)()
val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex))
if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
// We have to inline the common expression if it cannot be put in a Project.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,30 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f)
}

/**
 * Same as `transformUpWithSubqueries` except allows for pruning opportunities: subtrees
 * (including subquery plans) whose `TreePatternBits` fail `cond` are skipped entirely.
 *
 * The traversal is bottom-up (post-order) over this plan; when a node's expressions contain
 * subquery plans (`PlanExpression`), each nested plan is recursively transformed with the
 * same rule before `f` is applied to the node itself.
 *
 * @param cond   pruning predicate over a node's tree-pattern bits; `f` is only attempted
 *               where this returns true
 * @param ruleId rule identifier used for idempotence-based pruning (defaults to unknown,
 *               which disables that optimization)
 * @param f      the partial function to apply to qualifying nodes
 */
def transformUpWithSubqueriesAndPruning(
    cond: TreePatternBits => Boolean,
    ruleId: RuleId = UnknownRuleId)
  (f: PartialFunction[PlanType, PlanType]): PlanType = {
  // Wrap `f` so that, at every node, nested subquery plans are transformed first and then
  // `f` is applied to the node itself. Defined at every node (isDefinedAt = true) so the
  // subquery descent always runs, even where `f` itself would not match.
  val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
    override def isDefinedAt(x: PlanType): Boolean = true

    override def apply(plan: PlanType): PlanType = {
      // Only descend into expressions that actually contain a plan (subquery) AND pass the
      // caller's pruning condition.
      val transformed = plan.transformExpressionsUpWithPruning(t =>
        t.containsPattern(PLAN_EXPRESSION) && cond(t)) {
        case planExpression: PlanExpression[PlanType@unchecked] =>
          val newPlan = planExpression.plan.transformUpWithSubqueriesAndPruning(cond, ruleId)(f)
          planExpression.withNewPlan(newPlan)
      }
      f.applyOrElse[PlanType, PlanType](transformed, identity)
    }
  }

  transformUpWithPruning(cond, ruleId)(g)
}

/**
* This method is the top-down (pre-order) counterpart of transformUpWithSubqueries.
* Returns a copy of this node where the given partial function has been recursively applied
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3429,6 +3429,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Controls how RewriteWithExpression names the `_common_expr_*` aliases it introduces:
// by CommonExpressionId (stable, collision-free) or by definition index (deterministic
// across runs, which golden-file tests rely on).
val USE_COMMON_EXPR_ID_FOR_ALIAS =
  buildConf("spark.sql.useCommonExprIdForAlias")
    .internal()
    // Fix: the original concatenation read "...to ensure" + "that...", producing the
    // run-together word "ensurethat" in the rendered doc. A space is required at the
    // string-literal boundary.
    .doc("When true, use the common expression ID for the alias when rewriting With " +
      "expressions. Otherwise, use the index of the common expression definition. When true " +
      "this avoids duplicate alias names, but is helpful to set to false for testing to " +
      "ensure that alias names are consistent.")
    .version("4.0.0")
    .booleanConf
    .createWithDefault(true)

val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES =
buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues")
.internal()
Expand Down
Loading