Skip to content

Commit

Permalink
[SPARK-14939][SQL] Add FoldablePropagation optimizer
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR aims to add new **FoldablePropagation** optimizer that propagates foldable expressions by replacing all attributes with the aliases of original foldable expression. Other optimizations will take advantage of the propagated foldable expressions: e.g. `EliminateSorts` optimizer now can handle the following Case 2 and 3. (Case 1 is the previous implementation.)

1. Literals and foldable expression, e.g. "ORDER BY 1.0, 'abc', Now()"
2. Foldable ordinals, e.g. "SELECT 1.0, 'abc', Now() ORDER BY 1, 2, 3"
3. Foldable aliases, e.g. "SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, z"

This PR has been generalized based on cloud-fan 's key ideas many times; he should be credited for the work he did.

**Before**
```
scala> sql("SELECT 1.0, Now() x ORDER BY 1, x").explain
== Physical Plan ==
WholeStageCodegen
:  +- Sort [1.0#5 ASC,x#0 ASC], true, 0
:     +- INPUT
+- Exchange rangepartitioning(1.0#5 ASC, x#0 ASC, 200), None
   +- WholeStageCodegen
      :  +- Project [1.0 AS 1.0#5,1461873043577000 AS x#0]
      :     +- INPUT
      +- Scan OneRowRelation[]
```

**After**
```
scala> sql("SELECT 1.0, Now() x ORDER BY 1, x").explain
== Physical Plan ==
WholeStageCodegen
:  +- Project [1.0 AS 1.0#5,1461873079484000 AS x#0]
:     +- INPUT
+- Scan OneRowRelation[]
```

## How was this patch tested?

Pass the Jenkins tests including a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12719 from dongjoon-hyun/SPARK-14939.
  • Loading branch information
dongjoon-hyun authored and cloud-fan committed May 19, 2016
1 parent e2ec32d commit 5907ebf
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
extends UnaryExpression with NonSQLExpression {

override def dataType: DataType = child.dataType

override def foldable: Boolean = false
override def nullable: Boolean = false

override def eval(input: InternalRow): Any =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -91,6 +91,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
CombineUnions,
// Constant folding and strength reduction
NullPropagation,
FoldablePropagation,
OptimizeIn(conf),
ConstantFolding,
LikeSimplification,
Expand Down Expand Up @@ -657,6 +658,45 @@ object NullPropagation extends Rule[LogicalPlan] {
}
}

/**
* Propagate foldable expressions:
* Replace attributes with aliases of the original foldable expressions if possible.
* Other optimizations will take advantage of the propagated foldable expressions.
*
* {{{
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
* ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
* }}}
*/
object FoldablePropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val foldableMap = AttributeMap(plan.flatMap {
case Project(projectList, _) => projectList.collect {
case a: Alias if a.resolved && a.child.foldable => (a.toAttribute, a)
}
case _ => Nil
})

if (foldableMap.isEmpty) {
plan
} else {
var stop = false
CleanupAliases(plan.transformUp {
case u: Union =>
stop = true
u
case c: Command =>
stop = true
c
case p: LogicalPlan if !stop => p.transformExpressions {
case a: AttributeReference if foldableMap.contains(a) =>
foldableMap(a)
}
})
}
}
}

/**
* Generate a list of additional filters from an operator's existing constraint but remove those
* that are either already part of the operator's condition or are part of the operator's child
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,34 @@ class AggregateOptimizeSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Aggregate", FixedPoint(100),
FoldablePropagation,
RemoveLiteralFromGroupExpressions,
RemoveRepetitionFromGroupExpressions) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

test("remove literals in grouping expression") {
val input = LocalRelation('a.int, 'b.int)
val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

val query =
input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
val optimized = Optimize.execute(query)
comparePlans(optimized, correctAnswer)
}

val correctAnswer = input.groupBy('a)(sum('b))
test("Remove aliased literals") {
val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

comparePlans(optimized, correctAnswer)
}

test("remove repetition in grouping expression") {
val input = LocalRelation('a.int, 'b.int, 'c.int)

val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
val optimized = Optimize.execute(analyzer.execute(query))

val correctAnswer = analyzer.execute(input.groupBy('a + 1, 'b + 2)(sum('c)))
val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

comparePlans(optimized, correctAnswer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class EliminateSortsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Eliminate Sorts", Once,
Batch("Eliminate Sorts", FixedPoint(10),
FoldablePropagation,
EliminateSorts) :: Nil
}

Expand Down Expand Up @@ -69,4 +70,16 @@ class EliminateSortsSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("Remove no-op alias") {
val x = testRelation

val query = x.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.orderBy('x.asc, 'y.asc, 'b.desc)
val optimized = Optimize.execute(analyzer.execute(query))
val correctAnswer = analyzer.execute(
x.select('a.as('x), Year(CurrentDate()).as('y), 'b).orderBy('x.asc, 'b.desc))

comparePlans(optimized, correctAnswer)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class FoldablePropagationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Foldable Propagation", FixedPoint(20),
FoldablePropagation) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int)

test("Propagate from subquery") {
val query = OneRowRelation
.select(Literal(1).as('a), Literal(2).as('b))
.subquery('T)
.select('a, 'b)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = OneRowRelation
.select(Literal(1).as('a), Literal(2).as('b))
.subquery('T)
.select(Literal(1).as('a), Literal(2).as('b)).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate to select clause") {
val query = testRelation
.select('a.as('x), "str".as('y), 'b.as('z))
.select('x, 'y, 'z)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = testRelation
.select('a.as('x), "str".as('y), 'b.as('z))
.select('x, "str".as('y), 'z).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate to where clause") {
val query = testRelation
.select("str".as('y))
.where('y === "str" && "str" === 'y)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = testRelation
.select("str".as('y))
.where("str".as('y) === "str" && "str" === "str".as('y)).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate to orderBy clause") {
val query = testRelation
.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.orderBy('x.asc, 'y.asc, 'b.desc)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = testRelation
.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.orderBy('x.asc, SortOrder(Year(CurrentDate()), Ascending), 'b.desc).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate to groupBy clause") {
val query = testRelation
.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.groupBy('x, 'y, 'b)(sum('x), avg('y).as('AVG), count('b))
val optimized = Optimize.execute(query.analyze)
val correctAnswer = testRelation
.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.groupBy('x, Year(CurrentDate()).as('y), 'b)(sum('x), avg(Year(CurrentDate())).as('AVG),
count('b)).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate in a complex query") {
val query = testRelation
.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.where('x > 1 && 'y === 2016 && 'b > 1)
.groupBy('x, 'y, 'b)(sum('x), avg('y).as('AVG), count('b))
.orderBy('x.asc, 'AVG.asc)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = testRelation
.select('a.as('x), Year(CurrentDate()).as('y), 'b)
.where('x > 1 && Year(CurrentDate()).as('y) === 2016 && 'b > 1)
.groupBy('x, Year(CurrentDate()).as("y"), 'b)(sum('x), avg(Year(CurrentDate())).as('AVG),
count('b))
.orderBy('x.asc, 'AVG.asc).analyze

comparePlans(optimized, correctAnswer)
}

test("Propagate in subqueries of Union queries") {
val query = Union(
Seq(
testRelation.select(Literal(1).as('x), 'a).select('x + 'a),
testRelation.select(Literal(2).as('x), 'a).select('x + 'a)))
.select('x)
val optimized = Optimize.execute(query.analyze)
val correctAnswer = Union(
Seq(
testRelation.select(Literal(1).as('x), 'a).select((Literal(1).as('x) + 'a).as("(x + a)")),
testRelation.select(Literal(2).as('x), 'a).select((Literal(2).as('x) + 'a).as("(x + a)"))))
.select('x).analyze

comparePlans(optimized, correctAnswer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2499,6 +2499,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("Eliminate noop ordinal ORDER BY") {
withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "true") {
val plan1 = sql("SELECT 1.0, 'abc', year(current_date()) ORDER BY 1, 2, 3")
val plan2 = sql("SELECT 1.0, 'abc', year(current_date())")
comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
}
}

test("check code injection is prevented") {
// The end of comment (*/) should be escaped.
var literal =
Expand Down

0 comments on commit 5907ebf

Please sign in to comment.