Skip to content
Permalink
Browse files

[SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should …

…save/restore CSE state correctly

## What changes were proposed in this pull request?

Fixed `CodegenContext.withSubExprEliminationExprs()` so that it saves/restores CSE state correctly.

## How was this patch tested?

Added new unit test to verify that the old CSE state is indeed saved and restored around the `withSubExprEliminationExprs()` call. Manually verified that this test fails without this patch.

Author: Kris Mok <kris.mok@databricks.com>

Closes #20870 from rednaxelafx/codegen-subexpr-fix.

(cherry picked from commit 95e51ff)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information...
rednaxelafx authored and cloud-fan committed Mar 22, 2018
1 parent c9acd46 commit 4da8c22f77475d1b328375e97e2825e1dea78fdd
@@ -389,7 +389,7 @@ class CodegenContext {
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions

// Foreach expression that is participating in subexpression elimination, the state to use.
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]

// The collection of sub-expression result resetting methods that need to be called on each row.
val subexprFunctions = mutable.ArrayBuffer.empty[String]
@@ -1118,14 +1118,12 @@ class CodegenContext {
newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
f: => Seq[ExprCode]): Seq[ExprCode] = {
val oldsubExprEliminationExprs = subExprEliminationExprs
subExprEliminationExprs.clear
newSubExprEliminationExprs.foreach(subExprEliminationExprs += _)
subExprEliminationExprs = newSubExprEliminationExprs

val genCodes = f

// Restore previous subExprEliminationExprs
subExprEliminationExprs.clear
oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _)
subExprEliminationExprs = oldsubExprEliminationExprs
genCodes
}

@@ -1139,7 +1137,7 @@ class CodegenContext {
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree)
@@ -1152,10 +1150,10 @@ class CodegenContext {
// Generate the code for this expression tree.
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(subExprEliminationExprs.put(_, state))
e.foreach(localSubExprEliminationExprs.put(_, state))
eval.code.trim
}
SubExprCodes(codes, subExprEliminationExprs.toMap)
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}

/**
@@ -1203,7 +1201,7 @@ class CodegenContext {

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
e.foreach(subExprEliminationExprs.put(_, state))
subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}

@@ -442,4 +442,48 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ctx.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101)
assert(ctx.calculateParamLength(Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
}

test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {

val ref = BoundReference(0, IntegerType, true)
val add1 = Add(ref, ref)
val add2 = Add(add1, add1)

// raw testing of basic functionality
{
val ctx = new CodegenContext
val e = ref.genCode(ctx)
// before
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
assert(ctx.subExprEliminationExprs.contains(ref))
// call withSubExprEliminationExprs
ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) {
assert(ctx.subExprEliminationExprs.contains(add1))
assert(!ctx.subExprEliminationExprs.contains(ref))
Seq.empty
}
// after
assert(ctx.subExprEliminationExprs.nonEmpty)
assert(ctx.subExprEliminationExprs.contains(ref))
assert(!ctx.subExprEliminationExprs.contains(add1))
}

// emulate an actual codegen workload
{
val ctx = new CodegenContext
// before
ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
assert(ctx.subExprEliminationExprs.contains(add1))
// call withSubExprEliminationExprs
ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) {
assert(ctx.subExprEliminationExprs.contains(ref))
assert(!ctx.subExprEliminationExprs.contains(add1))
Seq.empty
}
// after
assert(ctx.subExprEliminationExprs.nonEmpty)
assert(ctx.subExprEliminationExprs.contains(add1))
assert(!ctx.subExprEliminationExprs.contains(ref))
}
}
}

0 comments on commit 4da8c22

Please sign in to comment.
You can’t perform that action at this time.