Skip to content

Commit

Permalink
[SPARK-12796] [SQL] Whole stage codegen
Browse files Browse the repository at this point in the history
This is the initial work for whole stage codegen, it support Projection/Filter/Range, we will continue work on this to support more physical operators.

A micro benchmark show that a query with range, filter and projection could be 3X faster then before.

It's turned on by default. For a tree that have at least two chained plans, a WholeStageCodegen will be inserted into it, for example, the following plan
```
Limit 10
+- Project [(id#5L + 1) AS (id + 1)#6L]
   +- Filter ((id#5L & 1) = 1)
      +- Range 0, 1, 4, 10, [id#5L]
```
will be translated into
```
Limit 10
+- WholeStageCodegen
      +- Project [(id#1L + 1) AS (id + 1)#2L]
         +- Filter ((id#1L & 1) = 1)
            +- Range 0, 1, 4, 10, [id#1L]
```

Here is the call graph to generate Java source for A and B (A  support codegen, but B does not):

```
  *   WholeStageCodegen       Plan A               FakeInput        Plan B
  * =========================================================================
  *
  * -> execute()
  *     |
  *  doExecute() -------->   produce()
  *                             |
  *                          doProduce()  -------> produce()
  *                                                   |
  *                                                doProduce() ---> execute()
  *                                                   |
  *                                                consume()
  *                          doConsume()  ------------|
  *                             |
  *  doConsume()  <-----    consume()
```

A SparkPlan that support codegen need to implement doProduce() and doConsume():

```
def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
```

Author: Davies Liu <davies@databricks.com>

Closes #10735 from davies/whole2.
  • Loading branch information
Davies Liu authored and davies committed Jan 16, 2016
1 parent 86972fa commit 3c0d236
Show file tree
Hide file tree
Showing 37 changed files with 694 additions and 107 deletions.
Expand Up @@ -61,7 +61,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
ev.isNull = ctx.currentVars(ordinal).isNull
ev.value = ctx.currentVars(ordinal).value
""
} else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
Expand Down
Expand Up @@ -55,6 +55,12 @@ class CodegenContext {
*/
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()

/**
* Holding a list of generated columns as input of current operator, will be used by
* BoundReference to generate code.
*/
var currentVars: Seq[ExprCode] = null

/**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
Expand All @@ -77,6 +83,16 @@ class CodegenContext {
mutableStates += ((javaType, variableName, initCode))
}

def declareMutableStates(): String = {
mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString("\n")
}

def initMutableStates(): String = {
mutableStates.map(_._3).mkString("\n")
}

/**
* Holding all the functions those will be added into generated class.
*/
Expand Down Expand Up @@ -111,6 +127,10 @@ class CodegenContext {
// The collection of sub-exression result resetting methods that need to be called on each row.
val subExprResetVariables = mutable.ArrayBuffer.empty[String]

def declareAddedFunctions(): String = {
addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}

final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
Expand All @@ -120,7 +140,7 @@ class CodegenContext {
final val JAVA_DOUBLE = "double"

/** The variable name of the input row in generated code. */
final val INPUT_ROW = "i"
final var INPUT_ROW = "i"

private val curId = new java.util.concurrent.atomic.AtomicInteger()

Expand Down Expand Up @@ -476,20 +496,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

protected val genericMutableRowType: String = classOf[GenericMutableRow].getName

protected def declareMutableStates(ctx: CodegenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString("\n")
}

protected def initMutableStates(ctx: CodegenContext): String = {
ctx.mutableStates.map(_._3).mkString("\n")
}

protected def declareAddedFunctions(ctx: CodegenContext): String = {
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
}

/**
* Generates a class for a given input expression. Called when there is not cached code
* already available.
Expand All @@ -505,16 +511,33 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/** Binds an input expression to a given input schema */
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType

/** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
generate(bind(expressions, inputSchema))

/** Generates the requested evaluator given already bound expression(s). */
def generate(expressions: InType): OutType = create(canonicalize(expressions))

/**
* Compile the Java source code into a Java class, using Janino.
* Create a new codegen context for expression evaluator, used to store those
* expressions that don't support codegen
*/
protected def compile(code: String): GeneratedClass = {
def newCodeGenContext(): CodegenContext = {
new CodegenContext
}
}

object CodeGenerator extends Logging {
/**
* Compile the Java source code into a Java class, using Janino.
*/
def compile(code: String): GeneratedClass = {
cache.get(code)
}

/**
* Compile the Java source code into a Java class, using Janino.
*/
* Compile the Java source code into a Java class, using Janino.
*/
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
Expand Down Expand Up @@ -577,19 +600,4 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
result
}
})

/** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
generate(bind(expressions, inputSchema))

/** Generates the requested evaluator given already bound expression(s). */
def generate(expressions: InType): OutType = create(canonicalize(expressions))

/**
* Create a new codegen context for expression evaluator, used to store those
* expressions that don't support codegen
*/
def newCodeGenContext(): CodegenContext = {
new CodegenContext
}
}
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic}
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}

/**
* A trait that can be used to provide a fallback mode for expression code generation.
Expand All @@ -30,13 +30,15 @@ trait CodegenFallback extends Expression {
case _ =>
}

// LeafNode does not need `input`
val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
val idx = ctx.references.length
ctx.references += this
val objectTerm = ctx.freshName("obj")
if (nullable) {
s"""
/* expression: ${this.toCommentSafeString} */
Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
Expand All @@ -47,7 +49,7 @@ trait CodegenFallback extends Expression {
ev.isNull = "false"
s"""
/* expression: ${this.toCommentSafeString} */
Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
"""
}
Expand Down
Expand Up @@ -107,13 +107,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

private Object[] references;
private MutableRow mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificMutableProjection(Object[] references) {
this.references = references;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public ${classOf[BaseMutableProjection].getName} target(MutableRow row) {
Expand All @@ -138,7 +138,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
() => {
c.generate(ctx.references.toArray).asInstanceOf[MutableProjection]
}
Expand Down
Expand Up @@ -118,12 +118,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
class SpecificOrdering extends ${classOf[BaseOrdering].getName} {

private Object[] references;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificOrdering(Object[] references) {
this.references = references;
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public int compare(InternalRow a, InternalRow b) {
Expand All @@ -135,6 +135,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR

logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}")

compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}
Expand Up @@ -47,12 +47,12 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool

class SpecificPredicate extends ${classOf[Predicate].getName} {
private final Object[] references;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificPredicate(Object[] references) {
this.references = references;
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
Expand All @@ -63,7 +63,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool

logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
(r: InternalRow) => p.eval(r)
}
}
Expand Up @@ -160,13 +160,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

private Object[] references;
private MutableRow mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificSafeProjection(Object[] references) {
this.references = references;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public java.lang.Object apply(java.lang.Object _i) {
Expand All @@ -179,7 +179,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
c.generate(ctx.references.toArray).asInstanceOf[Projection]
}
}
Expand Up @@ -338,14 +338,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {

private Object[] references;

${declareMutableStates(ctx)}

${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificUnsafeProjection(Object[] references) {
this.references = references;
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

// Scala.Function1 need this
Expand All @@ -362,7 +360,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
}
}
Expand Up @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U

logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
}
}
Expand Up @@ -224,6 +224,7 @@ object CaseWhen {
}
}


/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* When a = b, returns c; when a = d, returns e; else returns f.
Expand Down
Expand Up @@ -351,8 +351,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
val hasher = classOf[Murmur3_x86_32].getName
def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
def inlineValue(v: String): ExprCode =
ExprCode(code = "", isNull = "false", value = v)
def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)

dataType match {
case NullType => inlineValue(seed)
Expand Down
Expand Up @@ -452,7 +452,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and
* `lastChildren` for the root node should be empty.
*/
protected def generateTreeString(
def generateTreeString(
depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = {
if (depth > 0) {
lastChildren.init.foreach { isLast =>
Expand Down
3 changes: 0 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql

import java.io.CharArrayWriter
import java.util.Properties

import scala.language.implicitConversions
import scala.reflect.ClassTag
Expand All @@ -39,12 +38,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
new DataFrame(sqlContext, logicalPlan)
Expand Down
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Expand Up @@ -489,6 +489,13 @@ private[spark] object SQLConf {
isPublic = false,
doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.")

val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage",
defaultValue = Some(true),
doc = "When true, the whole stage (of multiple operators) will be compiled into single java" +
" method",
isPublic = false)


object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
Expand Down Expand Up @@ -561,6 +568,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon

private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW)

private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)

def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)

private[spark] def subexpressionEliminationEnabled: Boolean =
Expand Down
Expand Up @@ -904,7 +904,8 @@ class SQLContext private[sql](
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches = Seq(
Batch("Add exchange", Once, EnsureRequirements(self))
Batch("Add exchange", Once, EnsureRequirements(self)),
Batch("Whole stage codegen", Once, CollapseCodegenStages(self))
)
}

Expand Down

0 comments on commit 3c0d236

Please sign in to comment.