Skip to content

Commit

Permalink
Add flag to disable.
Browse files Browse the repository at this point in the history
  • Loading branch information
nongli committed Nov 9, 2015
1 parent a9bcdb0 commit 6cf0186
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ object UnsafeProjection {
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}

/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* TODO: refactor the plumbing and clean this up.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))

def generate(
expressions: Seq[Expression],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
create(canonicalize(expressions), subexpressionEliminationEnabled)
}

protected def create(expressions: Seq[Expression]): UnsafeProjection = {
val ctx = newCodeGenContext()
create(expressions, false)
}

val eval = createCode(ctx, expressions, true)
private def create(
expressions: Seq[Expression],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val ctx = newCodeGenContext()
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)

val code = s"""
public Object generate($exprType[] exprs) {
Expand Down
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ private[spark] object SQLConf {
doc = "When true, use the new optimized Tungsten physical execution backend.",
isPublic = false)

val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled",
defaultValue = Some(true), // use CODEGEN_ENABLED as default
doc = "When true, common subexpressions will be eliminated.",
isPublic = false)

val DIALECT = stringConf(
"spark.sql.dialect",
defaultValue = Some("sql"),
Expand Down Expand Up @@ -532,6 +537,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {

private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED))

private[spark] def subexpressionEliminationEnabled: Boolean =
getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled)

private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)

private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
} else {
false
}
val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) {
sqlContext.conf.subexpressionEliminationEnabled
} else {
false
}

/**
* Whether the "prepare" method is called.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
child.execute().mapPartitions { iter =>
val project = UnsafeProjection.create(projectList, child.output)
val project = UnsafeProjection.create(projectList, child.output,
subexpressionEliminationEnabled)
iter.map { row =>
numRows += 1
project(row)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2042,5 +2042,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
// Would be nice if semantic equals for `+` understood commutative
verifyCallCount(
df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)

// Try disabling it via configuration.
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
}

0 comments on commit 6cf0186

Please sign in to comment.