Skip to content

Commit

Permalink
[SPARK-10371] [SQL] Implement subexpr elimination for UnsafeProjections
Browse files Browse the repository at this point in the history
This patch adds the building blocks for codegening subexpr elimination and implements
it end to end for UnsafeProjection. The building blocks can be used to do the same thing
for other operators.

It introduces some utilities to compute common sub expressions. Expressions can be added to
this data structure. The expr and its children will be recursively matched against existing
expressions (ones previously added) and grouped into common groups. This is built using
the existing `semanticEquals`. It does not understand things like commutative or associative
expressions. This can be done as future work.

After building this data structure, the codegen process takes advantage of it by:
  1. Generating a helper function in the generated class that computes the common
     subexpression. This is done for all common subexpressions that have at least
     two occurrences and the expression tree is sufficiently complex.
  2. When generating the apply() function, if the helper function exists, call that
     instead of regenerating the expression tree. Repeated calls to the helper function
     short-circuit the evaluation logic.
  • Loading branch information
nongli committed Nov 5, 2015
1 parent ce5e6a2 commit 2feafbc
Show file tree
Hide file tree
Showing 7 changed files with 459 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

/**
 * Groups semantically equivalent (sub)expression trees. Expressions are added to this
 * class and can subsequently be queried for the expressions they are equivalent to.
 * Expression trees are considered equal if, for the same input(s), the same result is
 * produced.
 */
class EquivalentExpressions {
  /**
   * Wrapper around an Expression that provides semantic equality, so it can serve as the
   * key of `equivalenceMap`: `equals` delegates to `semanticEquals` and `hashCode` to
   * `semanticHash`.
   */
  case class Expr(e: Expression) {
    // Computed once at construction and reused by every hashCode call.
    val hash = e.semanticHash()
    override def equals(o: Any): Boolean = o match {
      case other: Expr => e.semanticEquals(other.e)
      case _ => false
    }
    override def hashCode: Int = hash
  }

  // Maps the first expression seen of each equivalence group to every added expression
  // that is semantically equal to it (the first expression included).
  private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] =
    new mutable.HashMap[Expr, mutable.MutableList[Expression]]

  /**
   * Adds a single expression to this data structure, grouping it with any existing
   * equivalent expressions. Non-recursive. Non-deterministic expressions are never added.
   *
   * @return true if an equivalent expression had already been added, false otherwise
   *         (including when `expr` is non-deterministic).
   */
  def addExpr(expr: Expression): Boolean = {
    if (expr.deterministic) {
      val key = Expr(expr)
      equivalenceMap.get(key) match {
        case Some(group) =>
          group += expr
          true
        case None =>
          equivalenceMap.put(key, mutable.MutableList(expr))
          false
      }
    } else {
      false
    }
  }

  /**
   * Adds the expression to this data structure recursively. Stops descending when a
   * matching expression is found: if `root` has already been added, its children are not
   * added. Non-deterministic expressions are skipped together with their children.
   * If ignoreLeaf is true, leaf nodes are ignored.
   */
  def addExprTree(root: Expression, ignoreLeaf: Boolean): Unit = {
    val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
    // The explicit `deterministic` check is required: addExpr returns false for a
    // non-deterministic root, which would otherwise cause its children to be visited.
    if (!skip && root.deterministic && !addExpr(root)) {
      root.children.foreach(addExprTree(_, ignoreLeaf))
    }
  }

  /**
   * Returns all of the expression trees that are equivalent to `e`. Returns
   * an empty collection if there are none.
   */
  def getEquivalentExprs(e: Expression): Seq[Expression] = {
    equivalenceMap.getOrElse(Expr(e), Seq.empty)
  }

  /**
   * Returns all the equivalent sets of expressions.
   */
  def getAllEquivalentExprs: Seq[Seq[Expression]] = {
    equivalenceMap.values.map(_.toList).toList
  }

  /**
   * Returns the state of the data structure as a string. If all is false, skips sets of
   * equivalent expressions with cardinality 1.
   */
  def debugString(all: Boolean = false): String = {
    val sb = new mutable.StringBuilder()
    sb.append("Equivalent expressions:\n")
    equivalenceMap.values.foreach { group =>
      if (all || group.length > 1) {
        sb.append(" " + group.mkString(", ")).append("\n")
      }
    }
    sb.toString()
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
  // The expression is looked up in the context's subexpression-elimination map first.
  // Code is generated exactly once per call: either a call to the shared helper
  // function, or the inline evaluation produced by genCode.
  ctx.subExprEliminationExprs.get(this) match {
    case Some(subExprState) =>
      // This expression is repeated, meaning the code to evaluate it has already been
      // added as a function, `subExprState.fnName`. Just call that and expose the
      // isNull/value names recorded in the shared state.
      val code =
        s"""
           |/* $this */
           |${subExprState.fnName}(${ctx.INPUT_ROW});
           |""".stripMargin.trim
      GeneratedExpressionCode(code, subExprState.code.isNull, subExprState.code.value)

    case None =>
      val isNull = ctx.freshName("isNull")
      val primitive = ctx.freshName("primitive")
      val ve = GeneratedExpressionCode("", isNull, primitive)
      ve.code = genCode(ctx, ve)
      // Add `this` in the comment.
      // NOTE(review): `$this` is embedded verbatim in a Java block comment; presumably an
      // expression whose string form contains "*/" would break the generated source -
      // confirm whether that can occur.
      ve.copy(s"/* $this */\n" + ve.code.trim)
  }
}

/**
Expand Down Expand Up @@ -135,6 +147,7 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Returns true when two expressions will always compute the same result, even if they differ
* cosmetically (i.e. capitalization of names in attributes may be different).
* TODO: how should this deal with nonDeterministic
*/
def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
Expand All @@ -150,6 +163,30 @@ abstract class Expression extends TreeNode[Expression] {
checkSemantic(elements1, elements2)
}

/**
 * Returns the hash for this expression. Expressions that compute the same result, even if
 * they differ cosmetically, should return the same hash.
 *
 * The hash is built from the constructor arguments exposed by `productIterator`: nested
 * expressions (bare or wrapped in `Some`) contribute their own semantic hash, traversable
 * arguments are hashed element by element, null contributes 0, and anything else
 * contributes its regular `hashCode`.
 */
def semanticHash() : Int = {
  // Combines the element hashes with a seed of 17 and a multiplier of 37.
  // See http://stackoverflow.com/questions/113511/hash-code-implementation
  def computeHash(elements: Seq[Any]): Int =
    elements.foldLeft(17) { (acc, element) =>
      val elementHash: Int = element match {
        case expr: Expression => expr.semanticHash()
        case Some(expr: Expression) => expr.semanticHash()
        case items: Traversable[_] => computeHash(items.toSeq)
        case null => 0
        case other => other.hashCode()
      }
      acc * 37 + elementHash
    }

  computeHash(this.productIterator.toSeq)
}

/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,34 @@ class CodeGenContext {
addedFunctions += ((funcName, funcCode))
}

/**
 * Holds expressions that are equivalent. Used to perform subexpression elimination
 * during codegen.
 *
 * For expressions that appear more than once, generate additional code to prevent
 * recomputing the value.
 *
 * For example, consider the two expressions generated from this SQL statement:
 * SELECT (col1 + col2), (col1 + col2) / col3.
 *
 * equivalentExpressions will match the tree containing `col1 + col2` and it will only
 * be evaluated once.
 */
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions

// Per-group state for subexpression elimination: the name of the `isLoaded` flag, the
// generated code holder, the name of the helper function and the expression's data type.
case class SubExprEliminationState(
    isLoaded: String,
    code: GeneratedExpressionCode,
    fnName: String,
    dt: DataType)

// Every subexpression elimination state created so far; one entry per group of common
// subexpressions.
val subExprEliminationStates: mutable.ArrayBuffer[SubExprEliminationState] =
  new mutable.ArrayBuffer[SubExprEliminationState]()

// For each expression participating in subexpression elimination, the state it shares
// with the other members of its group.
val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] =
  mutable.HashMap.empty[Expression, SubExprEliminationState]

final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
Expand Down Expand Up @@ -317,6 +345,77 @@ class CodeGenContext {
functions.map(name => s"$name($row);").mkString("\n")
}
}

/**
 * Checks and sets up the state and codegen for subexpression elimination. This finds the
 * common subexpressions, generates the functions that evaluate those expressions and
 * populates the mapping of common subexpressions to the generated functions.
 */
private def subexpressionElimination(expressions: Seq[Expression]) = {
  // Register every expression tree (leaf nodes are ignored) so that equivalent
  // subexpressions end up grouped together.
  for (tree <- expressions) {
    equivalentExpressions.addExprTree(tree, true)
  }

  // Only groups with at least two members are worth a helper function.
  val repeatedGroups = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
  for (group <- repeatedGroups) {
    // Any member can represent the group; the first one is used to generate the code.
    val expr = group.head
    val isLoaded = freshName("isLoaded")
    val isNull = freshName("isNull")
    val primitive = freshName("primitive")
    val fnName = freshName("evalExpr")

    // Generate the code for this expression tree and wrap it in a function that
    // evaluates it only while the `isLoaded` flag is unset.
    val code = expr.gen(this)
    val fn =
      s"""
         |private void $fnName(InternalRow ${INPUT_ROW}) {
         | if (!$isLoaded) {
         | ${code.code.trim}
         | $isLoaded = true;
         | $isNull = ${code.isNull};
         | $primitive = ${code.value};
         | }
         |}
         |""".stripMargin
    // Repurpose the generated-code holder: from here on it carries the helper function
    // source and the names of the variables the helper stores its result in.
    code.code = fn
    code.isNull = isNull
    code.value = primitive

    addNewFunction(fnName, fn)

    // Record one state for the group and map every member of the group to it. Once an
    // expression is present in `subExprEliminationExprs`, generating code for it emits
    // a call to `fn` instead of the inline evaluation. Ideally this decision would be
    // cost based.
    //
    // Costs of eliminating a subexpression:
    //   1. An extra function call, although this is probably *good* as the JIT can
    //      decide whether or not to inline it.
    //   2. An extra branch to check isLoaded. This branch is likely to be predicted
    //      correctly very often; when the value is not loaded, it is because of a
    //      prior branch.
    //   3. An extra store into isLoaded.
    // Benefits:
    //   1. The expression logic is not run again. Even for a simple expression this is
    //      likely to outweigh cost 3 above.
    //   2. Less generated code.
    // Currently this is done for every non-leaf-only expression tree (i.e. trees with
    // at least two nodes), as the cost is expected to be low.
    val state = SubExprEliminationState(isLoaded, code, fnName, expr.dataType)
    subExprEliminationStates += state
    group.foreach(member => subExprEliminationExprs.put(member, state))
  }
}

/**
 * Generates code for the given expressions, in order. If doSubexpressionElimination is
 * true, subexpression elimination is set up first. Subexpression elimination assumes that
 * the code for each expression will be combined in the `expressions` order.
 */
def generateExpressions(expressions: Seq[Expression],
    doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = {
  if (doSubexpressionElimination) {
    subexpressionElimination(expressions)
  }
  expressions.map(_.gen(this))
}
}

/**
Expand All @@ -341,15 +440,26 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
// Builds the member-variable declarations of the generated class: one private field per
// registered mutable state, followed by the fields backing subexpression elimination.
protected def declareMutableStates(ctx: CodeGenContext): String = {
  ctx.mutableStates.map { case (javaType, variableName, _) =>
    s"private $javaType $variableName;"
  }.mkString("\n") + "\n" +
    // Maintain the loaded value and isNull as member variables. This is necessary if the
    // codegen function is split across multiple functions.
    // TODO: maintaining this as a local variable probably allows the compiler to do better
    // optimizations.
    ctx.subExprEliminationStates.map { s => {
      s"""
         | private boolean ${s.isLoaded} = false;
         | private boolean ${s.code.isNull};
         | private ${ctx.javaType(s.dt)} ${s.code.value} = ${ctx.defaultValue(s.dt)};
         |""".stripMargin
    }}.mkString("\n").trim
}

// Joins the initialization code (third tuple component) of every registered mutable
// state, one entry per line.
protected def initMutableStates(ctx: CodeGenContext): String = {
  val initStatements = ctx.mutableStates.map { case (_, _, initCode) => initCode }
  initStatements.mkString("\n")
}

// Concatenates the source of every function registered on the context, one per line,
// with leading and trailing whitespace removed from the combined result.
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
  ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
${input.code}
if (${input.isNull}) {
$setNull
${setNull.trim}
} else {
$writeField
${writeField.trim}
}
"""
}

s"""
$rowWriter.initialize($bufferHolder, ${inputs.length});
${ctx.splitExpressions(row, writeFields)}
"""
""".trim
}

// TODO: if the nullability of array element is correct, we can use it to save null check.
Expand Down Expand Up @@ -275,8 +275,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
}

def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
val exprEvals = expressions.map(e => e.gen(ctx))
def createCode(ctx: CodeGenContext, expressions: Seq[Expression],
useSubexprElimination: Boolean = false): GeneratedExpressionCode = {
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val exprTypes = expressions.map(_.dataType)

val result = ctx.freshName("result")
Expand All @@ -285,10 +286,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")

// Reset the isLoaded flag for each row.
val subexprReset = ctx.subExprEliminationStates.map(s => {
s"${s.isLoaded} = false;"
}).mkString("\n")

val code =
s"""
$bufferHolder.reset();
$subexprReset
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}

$result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
"""
GeneratedExpressionCode(code, "false", result)
Expand All @@ -303,7 +311,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
val ctx = newCodeGenContext()

val eval = createCode(ctx, expressions)
val eval = createCode(ctx, expressions, true)

val code = s"""
public Object generate($exprType[] exprs) {
Expand All @@ -315,6 +323,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private $exprType[] expressions;

${declareMutableStates(ctx)}

${declareAddedFunctions(ctx)}

public SpecificUnsafeProjection($exprType[] expressions) {
Expand All @@ -328,7 +337,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}

public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
${eval.code}
${eval.code.trim}
return ${eval.value};
}
}
Expand Down

0 comments on commit 2feafbc

Please sign in to comment.