Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10371] [SQL] Implement subexpr elimination for UnsafeProjections #9480

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
* to this class and they subsequently query for expression equality. Expression trees are
* considered equal if for the same input(s), the same result is produced.
*/
class EquivalentExpressions {
/**
* Wrapper around an Expression that provides semantic equality.
*/
case class Expr(e: Expression) {
val hash = e.semanticHash()
override def equals(o: Any): Boolean = o match {
case other: Expr => e.semanticEquals(other.e)
case _ => false
}
override def hashCode: Int = hash
}

// For each expression, the set of equivalent expressions.
private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] =
new mutable.HashMap[Expr, mutable.MutableList[Expression]]

/**
* Adds each expression to this data structure, grouping them with existing equivalent
* expressions. Non-recursive.
* Returns if there was already a matching expression.
*/
def addExpr(expr: Expression): Boolean = {
if (expr.deterministic) {
val e: Expr = Expr(expr)
val f = equivalenceMap.get(e)
if (f.isDefined) {
f.get.+= (expr)
true
} else {
equivalenceMap.put(e, mutable.MutableList(expr))
false
}
} else {
false
}
}

/**
* Adds the expression to this datastructure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
* If ignoreLeaf is true, leaf nodes are ignored.
*/
def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
if (!skip && root.deterministic && !addExpr(root)) {
root.children.foreach(addExprTree(_, ignoreLeaf))
}
}

/**
* Returns all fo the expression trees that are equivalent to `e`. Returns
* an empty collection if there are none.
*/
def getEquivalentExprs(e: Expression): Seq[Expression] = {
equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList())
}

/**
* Returns all the equivalent sets of expressions.
*/
def getAllEquivalentExprs: Seq[Seq[Expression]] = {
equivalenceMap.values.map(_.toSeq).toSeq
}

/**
* Returns the state of the datastructure as a string. If all is false, skips sets of equivalent
* expressions with cardinality 1.
*/
def debugString(all: Boolean = false): String = {
val sb: mutable.StringBuilder = new StringBuilder()
sb.append("Equivalent expressions:\n")
equivalenceMap.foreach { case (k, v) => {
if (all || v.length > 1) {
sb.append(" " + v.mkString(", ")).append("\n")
}
}}
sb.toString()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* $this */\n" + ve.code)
val subExprState = ctx.subExprEliminationExprs.get(this)
if (subExprState.isDefined) {
// This expression is repeated meaning the code to evaluated has already been added
// as a function, `subExprState.fnName`. Just call that.
val code =
s"""
|/* $this */
|${subExprState.get.fnName}(${ctx.INPUT_ROW});
|""".stripMargin.trim
GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value)
} else {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* $this */\n" + ve.code.trim)
}
}

/**
Expand Down Expand Up @@ -145,11 +157,37 @@ abstract class Expression extends TreeNode[Expression] {
case (i1, i2) => i1 == i2
}
}
// Non-determinstic expressions cannot be equal
if (!deterministic || !other.deterministic) return false
val elements1 = this.productIterator.toSeq
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
checkSemantic(elements1, elements2)
}

/**
* Returns the hash for this expression. Expressions that compute the same result, even if
* they differ cosmetically should return the same hash.
*/
def semanticHash() : Int = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since all of the expressions are the case class, probably we don't need to our own way to computeHash.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Looking at thecomments on semanticEquals, we want to ignore cosmetic differences.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Identical hash values doesn't mean the identical values, I am suggesting to use the hashCode plus semanticEquals to identity the common expression, that's also the motivation of semanticEquals. See AttributeReference for details.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's what I do. I put the exprs in a hash set using semantic hash/semantic equals.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I am still confusing, why we can't use the hashCode instead.
The motivation of semanticEquals is to ignore the AttributeReference.name in comparison for AttributeReference, and the AttributeReference.hashCode also does ignore the AttributeReference.name, I don't think the cosmetic differences really exists for hashCode. Isn't it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And there is more discussion on the semanticEquals can be found at #6587, even, I don't think we need the semanticEquals if we changed the implementation of AttributeReference.equals, as it does make lots of code complicated like this one, by using the semanticEquals, other than equals or ==.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a quick look at 6587 and I agree with michael about semantic equals. I think when equivalence classes are added, semantic equals is even more different than equals.

That being said, this patch is not the motivation to do this. If we decide to remove semanticEquals, this patch can be updated trivially to use equals.

Regarding hashCode vs semanticHash code, I think it does no? It looks to me like the hash everything, including the cosmetic stuff but please correct me if I'm wrong. In general, i think it makes sense to implement hash if you implement equals.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you look at the methods equals, semanticEquals and hashCode of the AttributeReference, you will see that they are not matched, as equals will take consideration of the name, but the other 2 are not, that's why I am thinking we can also use the hashCode, instead of adding the new method semanticHash.

Anyway, it's not an external API, we can change it back anytime, as 1.6 is almost code freeze, and this is critical for people now.

def computeHash(e: Seq[Any]): Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var hash: Int = 17
e.foreach(i => {
val h: Int = i match {
case (e: Expression) => e.semanticHash()
case (Some(e: Expression)) => e.semanticHash()
case (t: Traversable[_]) => computeHash(t.toSeq)
case null => 0
case (o) => o.hashCode()
}
hash = hash * 37 + h
})
hash
}

computeHash(this.productIterator.toSeq)
}

/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ object UnsafeProjection {
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}

/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* TODO: refactor the plumbing and clean this up.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,33 @@ class CodeGenContext {
addedFunctions += ((funcName, funcCode))
}

/**
* Holds expressions that are equivalent. Used to perform subexpression elimination
* during codegen.
*
* For expressions that appear more than once, generate additional code to prevent
* recomputing the value.
*
* For example, consider two exprsesion generated from this SQL statement:
* SELECT (col1 + col2), (col1 + col2) / col3.
*
* equivalentExpressions will match the tree containing `col1 + col2` and it will only
* be evaluated once.
*/
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions

// State used for subexpression elimination.
case class SubExprEliminationState(
val isLoaded: String, code: GeneratedExpressionCode, val fnName: String)

// Foreach expression that is participating in subexpression elimination, the state to use.
val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] =
mutable.HashMap[Expression, SubExprEliminationState]()

// The collection of isLoaded variables that need to be reset on each row.
val subExprIsLoadedVariables: mutable.ArrayBuffer[String] =
mutable.ArrayBuffer.empty[String]

final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
Expand Down Expand Up @@ -317,6 +344,87 @@ class CodeGenContext {
functions.map(name => s"$name($row);").mkString("\n")
}
}

/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpresses, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]) = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the exprs that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
commonExprs.foreach(e => {
val expr = e.head
val isLoaded = freshName("isLoaded")
val isNull = freshName("isNull")
val primitive = freshName("primitive")
val fnName = freshName("evalExpr")

// Generate the code for this expression tree and wrap it in a function.
val code = expr.gen(this)
val fn =
s"""
|private void $fnName(InternalRow ${INPUT_ROW}) {
| if (!$isLoaded) {
| ${code.code.trim}
| $isLoaded = true;
| $isNull = ${code.isNull};
| $primitive = ${code.value};
| }
|}
""".stripMargin
code.code = fn
code.isNull = isNull
code.value = primitive

addNewFunction(fnName, fn)

// Add a state and a mapping of the common subexpressions that are associate with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
//
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
// 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
// very often. The reason it is not loaded is because of a prior branch.
// 3. Extra store into isLoaded.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.

// Maintain the loaded value and isNull as member variables. This is necessary if the codegen
// function is split across multiple functions.
// TODO: maintaining this as a local variable probably allows the compiler to do better
// optimizations.
addMutableState("boolean", isLoaded, s"$isLoaded = false;")
addMutableState("boolean", isNull, s"$isNull = false;")
addMutableState(javaType(expr.dataType), primitive,
s"$primitive = ${defaultValue(expr.dataType)};")
subExprIsLoadedVariables += isLoaded

val state = SubExprEliminationState(isLoaded, code, fnName)
e.foreach(subExprEliminationExprs.put(_, state))
})
}

/**
* Generates code for expressions. If doSubexpressionElimination is true, subexpression
* elimination will be performed. Subexpression elimination assumes that the code will for each
* expression will be combined in the `expressions` order.
*/
def generateExpressions(expressions: Seq[Expression],
doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = {
if (doSubexpressionElimination) subexpressionElimination(expressions)
expressions.map(e => e.gen(this))
}
}

/**
Expand Down Expand Up @@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}

protected def declareAddedFunctions(ctx: CodeGenContext): String = {
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
}

/**
Expand Down