[SPARK-12796] [SQL] Whole stage codegen #10735

Closed
wants to merge 24 commits into from
@@ -61,7 +61,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
ev.isNull = ctx.currentVars(ordinal).isNull
ev.value = ctx.currentVars(ordinal).value
""
} else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
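The new branch in BoundReference.genCode above is the central hook for whole stage codegen: when ctx.currentVars is set, a bound reference resolves to a column variable that the parent operator's generated code has already produced, instead of emitting another INPUT_ROW access. A minimal standalone sketch of that idea; MiniContext, genBoundReference and the int-only fallback are made-up simplifications, not the real Catalyst classes:

// Simplified model of the currentVars short-circuit (hypothetical names).
object CurrentVarsSketch {
  case class ExprCode(code: String, isNull: String, value: String)

  class MiniContext {
    var INPUT_ROW = "i"
    // When non-null, holds the variables already produced by upstream generated code.
    var currentVars: Seq[ExprCode] = null
  }

  def genBoundReference(ctx: MiniContext, ordinal: Int): ExprCode = {
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      // Reuse the column variable computed by the parent operator: no row access at all.
      ctx.currentVars(ordinal)
    } else {
      // Fall back to reading the column out of the input row, as before this patch.
      val value = s"value_$ordinal"
      ExprCode(s"int $value = ${ctx.INPUT_ROW}.getInt($ordinal);", "false", value)
    }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new MiniContext
    ctx.currentVars = Seq(ExprCode("", "false", "inputadapter_value_0"))
    println(genBoundReference(ctx, 0)) // reuses inputadapter_value_0, emits no code
    ctx.currentVars = null
    println(genBoundReference(ctx, 0)) // emits "int value_0 = i.getInt(0);"
  }
}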
@@ -55,6 +55,12 @@ class CodegenContext {
*/
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()

/**
* Holds the list of generated columns used as input of the current operator; used by
* BoundReference to generate code.
*/
var currentVars: Seq[ExprCode] = null

/**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
@@ -77,6 +83,16 @@ class CodegenContext {
mutableStates += ((javaType, variableName, initCode))
}

def declareMutableStates(): String = {
mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString("\n")
}

def initMutableStates(): String = {
mutableStates.map(_._3).mkString("\n")
}

/**
* Holding all the functions that will be added into the generated class.
*/
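For reference, here is a standalone copy of the same helper logic, showing what declareMutableStates and initMutableStates emit for a couple of registered (javaType, variableName, initCode) tuples; the generator objects further down splice these two fragments into the field-declaration and constructor sections of their class templates. The example states are made up:

object MutableStatesSketch {
  import scala.collection.mutable.ArrayBuffer

  // Same 3-tuple shape as CodegenContext.mutableStates: (java type, variable name, init code).
  val mutableStates = ArrayBuffer[(String, String, String)](
    ("long", "count_0", "count_0 = 0L;"),
    ("UnsafeRow", "result_0", "result_0 = new UnsafeRow();"))

  def declareMutableStates(): String =
    mutableStates.map { case (javaType, variableName, _) =>
      s"private $javaType $variableName;"
    }.mkString("\n")

  def initMutableStates(): String = mutableStates.map(_._3).mkString("\n")

  def main(args: Array[String]): Unit = {
    println(declareMutableStates()) // two "private <type> <name>;" field declarations
    println(initMutableStates())    // the two init statements, one per line
  }
}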
@@ -111,6 +127,10 @@ class CodegenContext {
// The collection of sub-expression result resetting methods that need to be called on each row.
val subExprResetVariables = mutable.ArrayBuffer.empty[String]

def declareAddedFunctions(): String = {
addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}

final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -120,7 +140,7 @@ class CodegenContext {
final val JAVA_DOUBLE = "double"

/** The variable name of the input row in generated code. */
final val INPUT_ROW = "i"
final var INPUT_ROW = "i"

private val curId = new java.util.concurrent.atomic.AtomicInteger()

@@ -476,20 +496,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

protected val genericMutableRowType: String = classOf[GenericMutableRow].getName

protected def declareMutableStates(ctx: CodegenContext): String = {
ctx.mutableStates.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString("\n")
}

protected def initMutableStates(ctx: CodegenContext): String = {
ctx.mutableStates.map(_._3).mkString("\n")
}

protected def declareAddedFunctions(ctx: CodegenContext): String = {
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
}

/**
* Generates a class for a given input expression. Called when there is not cached code
* already available.
@@ -505,16 +511,33 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/** Binds an input expression to a given input schema */
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType

/** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
generate(bind(expressions, inputSchema))

/** Generates the requested evaluator given already bound expression(s). */
def generate(expressions: InType): OutType = create(canonicalize(expressions))

/**
* Compile the Java source code into a Java class, using Janino.
* Create a new codegen context for an expression evaluator, used to store those
* expressions that don't support codegen.
*/
protected def compile(code: String): GeneratedClass = {
def newCodeGenContext(): CodegenContext = {
new CodegenContext
}
}

object CodeGenerator extends Logging {
/**
* Compile the Java source code into a Java class, using Janino.
*/
def compile(code: String): GeneratedClass = {
cache.get(code)
}

/**
* Compile the Java source code into a Java class, using Janino.
*/
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
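For readers unfamiliar with Janino: compile and doCompile cook a string containing a Java class body into a loadable class, with a cache in front. A minimal illustration of the underlying Janino call outside of Spark; the class body here is made up, and the real doCompile additionally sets the parent class loader, default imports and an extended class:

import org.codehaus.janino.ClassBodyEvaluator

object JaninoSketch {
  def main(args: Array[String]): Unit = {
    // Compile a Java class body at runtime and load the resulting class.
    val evaluator = new ClassBodyEvaluator()
    evaluator.cook("public int addOne(int x) { return x + 1; }")
    val clazz = evaluator.getClazz
    val instance = clazz.newInstance().asInstanceOf[AnyRef]
    val result = clazz.getMethod("addOne", classOf[Int]).invoke(instance, Integer.valueOf(41))
    println(result) // 42
  }
}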
@@ -577,19 +600,4 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
result
}
})

/** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
generate(bind(expressions, inputSchema))

/** Generates the requested evaluator given already bound expression(s). */
def generate(expressions: InType): OutType = create(canonicalize(expressions))

/**
* Create a new codegen context for expression evaluator, used to store those
* expressions that don't support codegen
*/
def newCodeGenContext(): CodegenContext = {
new CodegenContext
}
}
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic}
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}

/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -30,13 +30,15 @@ trait CodegenFallback extends Expression {
case _ =>
}

// A LeafExpression does not need `input`
val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
val idx = ctx.references.length
ctx.references += this
val objectTerm = ctx.freshName("obj")
if (nullable) {
s"""
/* expression: ${this.toCommentSafeString} */
Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
@@ -47,7 +49,7 @@ trait CodegenFallback extends Expression {
ev.isNull = "false"
s"""
/* expression: ${this.toCommentSafeString} */
Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
"""
}
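A simplified model of the fallback pattern above (made-up names, not the Catalyst API): an expression that cannot be compiled is appended to the references array, and the generated Java calls back into its interpreted eval(); after this change, leaf expressions receive null instead of the input row, since they never read it.

object FallbackSketch {
  import scala.collection.mutable.ArrayBuffer

  trait MiniExpression { def isLeaf: Boolean }

  def genFallback(references: ArrayBuffer[MiniExpression],
                  expr: MiniExpression,
                  inputRow: String): String = {
    // Leaf expressions never touch the input row, so pass null instead of INPUT_ROW.
    val input = if (expr.isLeaf) "null" else inputRow
    val idx = references.length
    references += expr
    s"Object obj$idx = ((Expression) references[$idx]).eval($input);"
  }

  def main(args: Array[String]): Unit = {
    val refs = ArrayBuffer.empty[MiniExpression]
    val leaf = new MiniExpression { val isLeaf = true }
    val nonLeaf = new MiniExpression { val isLeaf = false }
    println(genFallback(refs, leaf, "i"))    // ... references[0]).eval(null);
    println(genFallback(refs, nonLeaf, "i")) // ... references[1]).eval(i);
  }
}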
@@ -107,13 +107,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

private Object[] references;
private MutableRow mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificMutableProjection(Object[] references) {
this.references = references;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public ${classOf[BaseMutableProjection].getName} target(MutableRow row) {
@@ -138,7 +138,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
() => {
c.generate(ctx.references.toArray).asInstanceOf[MutableProjection]
}
@@ -118,12 +118,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
class SpecificOrdering extends ${classOf[BaseOrdering].getName} {

private Object[] references;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificOrdering(Object[] references) {
this.references = references;
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public int compare(InternalRow a, InternalRow b) {
@@ -135,6 +135,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR

logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}")

compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}
@@ -47,12 +47,12 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool

class SpecificPredicate extends ${classOf[Predicate].getName} {
private final Object[] references;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificPredicate(Object[] references) {
this.references = references;
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
@@ -63,7 +63,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool

logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
(r: InternalRow) => p.eval(r)
}
}
@@ -160,13 +160,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

private Object[] references;
private MutableRow mutableRow;
${declareMutableStates(ctx)}
${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificSafeProjection(Object[] references) {
this.references = references;
mutableRow = new $genericMutableRowType(${expressions.size});
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

public java.lang.Object apply(java.lang.Object _i) {
@@ -179,7 +179,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
c.generate(ctx.references.toArray).asInstanceOf[Projection]
}
}
@@ -338,14 +338,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {

private Object[] references;

${declareMutableStates(ctx)}

${declareAddedFunctions(ctx)}
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public SpecificUnsafeProjection(Object[] references) {
this.references = references;
${initMutableStates(ctx)}
${ctx.initMutableStates()}
}

// Scala.Function1 need this
@@ -362,7 +360,7 @@

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
}
}
@@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U

logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")

val c = compile(code)
val c = CodeGenerator.compile(code)
c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
}
}
@@ -224,6 +224,7 @@ object CaseWhen {
}
}


/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* When a = b, returns c; when a = d, returns e; else returns f.
@@ -351,8 +351,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
val hasher = classOf[Murmur3_x86_32].getName
def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
def inlineValue(v: String): ExprCode =
ExprCode(code = "", isNull = "false", value = v)
def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)

dataType match {
case NullType => inlineValue(seed)
@@ -452,7 +452,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and
* `lastChildren` for the root node should be empty.
*/
protected def generateTreeString(
def generateTreeString(
depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = {
if (depth > 0) {
lastChildren.init.foreach { isLast =>
3 changes: 0 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql

import java.io.CharArrayWriter
import java.util.Properties

import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -39,12 +38,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
new DataFrame(sqlContext, logicalPlan)
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -489,6 +489,13 @@ private[spark] object SQLConf {
isPublic = false,
doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.")

val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage",
defaultValue = Some(true),
doc = "When true, the whole stage (of multiple operators) will be compiled into single java" +
" method",
isPublic = false)


object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
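A hedged usage sketch for the new flag: toggle spark.sql.codegen.wholeStage on a SQLContext and compare the physical plans printed by explain(). The surrounding calls (range, filter, explain, setConf) are the existing public API; the difference between the two plans is what this patch is expected to introduce.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object WholeStageFlagDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("wholestage"))
    val sqlContext = new SQLContext(sc)

    sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
    sqlContext.range(0, 1000).filter("id > 100").explain() // plan with collapsed codegen stages

    sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
    sqlContext.range(0, 1000).filter("id > 100").explain() // per-operator codegen, as before

    sc.stop()
  }
}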
@@ -561,6 +568,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon

private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW)

private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)

def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)

private[spark] def subexpressionEliminationEnabled: Boolean =
@@ -904,7 +904,8 @@ class SQLContext private[sql](
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches = Seq(
Batch("Add exchange", Once, EnsureRequirements(self))
Batch("Add exchange", Once, EnsureRequirements(self)),
Batch("Whole stage codegen", Once, CollapseCodegenStages(self))
)
}

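The new preparation batch is where the PR title comes from: CollapseCodegenStages walks the physical plan and fuses maximal chains of codegen-capable operators so that they are compiled into a single Java method. A conceptual standalone sketch of that collapsing step, using toy plan nodes and a toy rule rather than Spark's SparkPlan or the real CollapseCodegenStages:

object CollapseSketch {
  sealed trait Plan { def supportsCodegen: Boolean; def children: Seq[Plan] }
  case class Op(name: String, children: Seq[Plan], supportsCodegen: Boolean = true) extends Plan
  case class WholeStage(child: Plan) extends Plan {
    val supportsCodegen = false
    def children: Seq[Plan] = Seq(child)
  }

  // Wrap each maximal chain of codegen-capable operators in a single WholeStage node.
  def collapse(plan: Plan): Plan = plan match {
    case op: Op if op.supportsCodegen => WholeStage(fuse(op))
    case op: Op                       => op.copy(children = op.children.map(collapse))
    case other                        => other
  }

  // Keep fusing down the tree until an operator that cannot be code-generated is hit;
  // that operator then becomes the root of its own, separately collapsed, subtree.
  private def fuse(plan: Plan): Plan = plan match {
    case op: Op if op.supportsCodegen => op.copy(children = op.children.map(fuse))
    case other                        => collapse(other)
  }

  def main(args: Array[String]): Unit = {
    val plan = Op("Project", Seq(Op("Filter", Seq(
      Op("Exchange", Seq(Op("Scan", Nil)), supportsCodegen = false)))))
    // Produces two stages: {Project, Filter} above the Exchange and {Scan} below it.
    println(collapse(plan))
  }
}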