Skip to content

Commit

Permalink
[SPARK-10371][SQL] Implement subexpr elimination for UnsafeProjections
Browse files Browse the repository at this point in the history
This patch adds the building blocks for codegening subexpr elimination and implements
it end to end for UnsafeProjection. The building blocks can be used to do the same thing
for other operators.

It introduces some utilities to compute common sub expressions. Expressions can be added to
this data structure. The expr and its children will be recursively matched against existing
expressions (ones previously added) and grouped into common groups. This is built using
the existing `semanticEquals`. It does not understand things like commutative or associative
expressions. This can be done as future work.

After building this data structure, the codegen process takes advantage of it by:
  1. Generating a helper function in the generated class that computes the common
     subexpression. This is done for all common subexpressions that have at least
     two occurrences and the expression tree is sufficiently complex.
  2. When generating the apply() function, if the helper function exists, call that
     instead of regenerating the expression tree. Repeated calls to the helper function
     shortcircuit the evaluation logic.

Author: Nong Li <nong@databricks.com>
Author: Nong Li <nongli@gmail.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #9480 from nongli/spark-10371.

(cherry picked from commit 87aedc4)
Signed-off-by: Michael Armbrust <michael@databricks.com>
  • Loading branch information
nongli authored and marmbrus committed Nov 10, 2015
1 parent 5ccc1eb commit f38509a
Show file tree
Hide file tree
Showing 11 changed files with 523 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
* to this class and they subsequently query for expression equality. Expression trees are
* considered equal if for the same input(s), the same result is produced.
*/
class EquivalentExpressions {
/**
* Wrapper around an Expression that provides semantic equality.
*/
case class Expr(e: Expression) {
val hash = e.semanticHash()
override def equals(o: Any): Boolean = o match {
case other: Expr => e.semanticEquals(other.e)
case _ => false
}
override def hashCode: Int = hash
}

// For each expression, the set of equivalent expressions.
private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] =
new mutable.HashMap[Expr, mutable.MutableList[Expression]]

/**
* Adds each expression to this data structure, grouping them with existing equivalent
* expressions. Non-recursive.
* Returns if there was already a matching expression.
*/
def addExpr(expr: Expression): Boolean = {
if (expr.deterministic) {
val e: Expr = Expr(expr)
val f = equivalenceMap.get(e)
if (f.isDefined) {
f.get.+= (expr)
true
} else {
equivalenceMap.put(e, mutable.MutableList(expr))
false
}
} else {
false
}
}

/**
* Adds the expression to this datastructure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
* If ignoreLeaf is true, leaf nodes are ignored.
*/
def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
if (!skip && root.deterministic && !addExpr(root)) {
root.children.foreach(addExprTree(_, ignoreLeaf))
}
}

/**
* Returns all fo the expression trees that are equivalent to `e`. Returns
* an empty collection if there are none.
*/
def getEquivalentExprs(e: Expression): Seq[Expression] = {
equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList())
}

/**
* Returns all the equivalent sets of expressions.
*/
def getAllEquivalentExprs: Seq[Seq[Expression]] = {
equivalenceMap.values.map(_.toSeq).toSeq
}

/**
* Returns the state of the datastructure as a string. If all is false, skips sets of equivalent
* expressions with cardinality 1.
*/
def debugString(all: Boolean = false): String = {
val sb: mutable.StringBuilder = new StringBuilder()
sb.append("Equivalent expressions:\n")
equivalenceMap.foreach { case (k, v) => {
if (all || v.length > 1) {
sb.append(" " + v.mkString(", ")).append("\n")
}
}}
sb.toString()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* $this */\n" + ve.code)
val subExprState = ctx.subExprEliminationExprs.get(this)
if (subExprState.isDefined) {
// This expression is repeated meaning the code to evaluated has already been added
// as a function, `subExprState.fnName`. Just call that.
val code =
s"""
|/* $this */
|${subExprState.get.fnName}(${ctx.INPUT_ROW});
|""".stripMargin.trim
GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value)
} else {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* $this */\n" + ve.code.trim)
}
}

/**
Expand Down Expand Up @@ -145,11 +157,37 @@ abstract class Expression extends TreeNode[Expression] {
case (i1, i2) => i1 == i2
}
}
// Non-determinstic expressions cannot be equal
if (!deterministic || !other.deterministic) return false
val elements1 = this.productIterator.toSeq
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
checkSemantic(elements1, elements2)
}

/**
* Returns the hash for this expression. Expressions that compute the same result, even if
* they differ cosmetically should return the same hash.
*/
def semanticHash() : Int = {
def computeHash(e: Seq[Any]): Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var hash: Int = 17
e.foreach(i => {
val h: Int = i match {
case (e: Expression) => e.semanticHash()
case (Some(e: Expression)) => e.semanticHash()
case (t: Traversable[_]) => computeHash(t.toSeq)
case null => 0
case (o) => o.hashCode()
}
hash = hash * 37 + h
})
hash
}

computeHash(this.productIterator.toSeq)
}

/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ object UnsafeProjection {
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}

/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* TODO: refactor the plumbing and clean this up.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ class CodeGenContext {
addedFunctions += ((funcName, funcCode))
}

/**
* Holds expressions that are equivalent. Used to perform subexpression elimination
* during codegen.
*
* For expressions that appear more than once, generate additional code to prevent
* recomputing the value.
*
* For example, consider two exprsesion generated from this SQL statement:
* SELECT (col1 + col2), (col1 + col2) / col3.
*
* equivalentExpressions will match the tree containing `col1 + col2` and it will only
* be evaluated once.
*/
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions

// State used for subexpression elimination.
case class SubExprEliminationState(
val isLoaded: String, code: GeneratedExpressionCode, val fnName: String)

// Foreach expression that is participating in subexpression elimination, the state to use.
val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] =
mutable.HashMap[Expression, SubExprEliminationState]()

// The collection of isLoaded variables that need to be reset on each row.
val subExprIsLoadedVariables: mutable.ArrayBuffer[String] =
mutable.ArrayBuffer.empty[String]

final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
Expand Down Expand Up @@ -317,6 +344,87 @@ class CodeGenContext {
functions.map(name => s"$name($row);").mkString("\n")
}
}

/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpresses, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]) = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the exprs that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
commonExprs.foreach(e => {
val expr = e.head
val isLoaded = freshName("isLoaded")
val isNull = freshName("isNull")
val primitive = freshName("primitive")
val fnName = freshName("evalExpr")

// Generate the code for this expression tree and wrap it in a function.
val code = expr.gen(this)
val fn =
s"""
|private void $fnName(InternalRow ${INPUT_ROW}) {
| if (!$isLoaded) {
| ${code.code.trim}
| $isLoaded = true;
| $isNull = ${code.isNull};
| $primitive = ${code.value};
| }
|}
""".stripMargin
code.code = fn
code.isNull = isNull
code.value = primitive

addNewFunction(fnName, fn)

// Add a state and a mapping of the common subexpressions that are associate with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
//
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
// 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
// very often. The reason it is not loaded is because of a prior branch.
// 3. Extra store into isLoaded.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.

// Maintain the loaded value and isNull as member variables. This is necessary if the codegen
// function is split across multiple functions.
// TODO: maintaining this as a local variable probably allows the compiler to do better
// optimizations.
addMutableState("boolean", isLoaded, s"$isLoaded = false;")
addMutableState("boolean", isNull, s"$isNull = false;")
addMutableState(javaType(expr.dataType), primitive,
s"$primitive = ${defaultValue(expr.dataType)};")
subExprIsLoadedVariables += isLoaded

val state = SubExprEliminationState(isLoaded, code, fnName)
e.foreach(subExprEliminationExprs.put(_, state))
})
}

/**
* Generates code for expressions. If doSubexpressionElimination is true, subexpression
* elimination will be performed. Subexpression elimination assumes that the code will for each
* expression will be combined in the `expressions` order.
*/
def generateExpressions(expressions: Seq[Expression],
doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = {
if (doSubexpressionElimination) subexpressionElimination(expressions)
expressions.map(e => e.gen(this))
}
}

/**
Expand Down Expand Up @@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}

protected def declareAddedFunctions(ctx: CodeGenContext): String = {
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
}

/**
Expand Down

0 comments on commit f38509a

Please sign in to comment.