[SPARK-13031][SQL] cleanup codegen and improve test coverage
1. enable whole stage codegen during tests even when only one operator supports it.
2. split doProduce() into two APIs: upstream() and doProduce()
3. generate prefix for fresh names of each operator
4. pass UnsafeRow to the parent directly (avoiding getters and re-creating the UnsafeRow)
5. fix bugs and tests.

This PR re-opens #10944 and fixes the bug.

Author: Davies Liu <davies@databricks.com>

Closes #10977 from davies/gen_refactor.
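
To make the new contract concrete, here is a minimal sketch of a leaf operator under the split API. SimpleScan is hypothetical (not part of this commit); its body mirrors the InputAdapter implementation in the diff below and assumes the imports already present in that file.

    // Hypothetical leaf operator: upstream() supplies only the input RDD,
    // while doProduce() supplies only the generated Java source.
    case class SimpleScan(
        override val output: Seq[Attribute],
        rdd: RDD[InternalRow]) extends LeafNode with CodegenSupport {

      override def upstream(): RDD[InternalRow] = rdd

      override protected def doProduce(ctx: CodegenContext): String = {
        // freshName() now yields names like "SimpleScan_row0", because produce()
        // sets ctx.freshNamePrefix to the operator's nodeName.
        val row = ctx.freshName("row")
        ctx.INPUT_ROW = row
        ctx.currentVars = null
        val columns = output.zipWithIndex.map { case (a, i) =>
          BoundReference(i, a.dataType, a.nullable).gen(ctx)
        }
        s"""
           | while (input.hasNext()) {
           |   InternalRow $row = (InternalRow) input.next();
           |   ${columns.map(_.code).mkString("\n")}
           |   ${consume(ctx, columns)}
           | }
         """.stripMargin
      }

      // Execution is driven by the enclosing WholeStageCodegen.
      override protected def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException
    }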
Davies Liu authored and rxin committed Jan 29, 2016
1 parent 8d3cc3d commit 55561e7
Showing 11 changed files with 350 additions and 205 deletions.
@@ -144,14 +144,23 @@ class CodegenContext {

private val curId = new java.util.concurrent.atomic.AtomicInteger()

+/**
+* A prefix used to generate fresh names.
+*/
+var freshNamePrefix = ""
+
/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
-def freshName(prefix: String): String = {
-s"$prefix${curId.getAndIncrement}"
+def freshName(name: String): String = {
+if (freshNamePrefix == "") {
+s"$name${curId.getAndIncrement}"
+} else {
+s"${freshNamePrefix}_$name${curId.getAndIncrement}"
+}
}

/**
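The naming change above is easy to see in isolation. NameGen below is a hypothetical, self-contained re-implementation of just this logic (it is not Spark code):

    import java.util.concurrent.atomic.AtomicInteger

    class NameGen {
      private val curId = new AtomicInteger()
      var freshNamePrefix = ""
      def freshName(name: String): String =
        if (freshNamePrefix == "") s"$name${curId.getAndIncrement}"
        else s"${freshNamePrefix}_$name${curId.getAndIncrement}"
    }

    val gen = new NameGen
    gen.freshName("value")          // "value0"
    gen.freshNamePrefix = "Filter"  // produce()/consumeChild() set this to nodeName
    gen.freshName("value")          // "Filter_value1"

With the prefix set per operator, any name in the generated source can be traced back to the operator that emitted it.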
@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
if (this.isNull_$i) {
${ctx.setColumn("mutableRow", e.dataType, i, null)};
${ctx.setColumn("mutableRow", e.dataType, i, "null")};
} else {
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
@@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.util.Utils

/**
* An interface for those physical operators that support codegen.
@@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null

/**
-* Returns an input RDD of InternalRow and Java source code to process them.
+* Returns the RDD of InternalRow which generates the input rows.
*/
-def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
+def upstream(): RDD[InternalRow]
+
+/**
+* Returns Java source code to process the rows from upstream.
+*/
+def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
+ctx.freshNamePrefix = nodeName
doProduce(ctx)
}

Expand All @@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
-protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
+protected def doProduce(ctx: CodegenContext): String

/**
-* Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
+* Consume the columns generated from the current SparkPlan and call its parent.
*/
-protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
-assert(columns.length == output.length)
-parent.doConsume(ctx, this, columns)
+def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
+if (input != null) {
+assert(input.length == output.length)
+}
+parent.consumeChild(ctx, this, input, row)
}

+/**
+* Consume the columns generated from its child, calling doConsume() or emitting the rows.
+*/
+def consumeChild(
+ctx: CodegenContext,
+child: SparkPlan,
+input: Seq[ExprCode],
+row: String = null): String = {
+ctx.freshNamePrefix = nodeName
+if (row != null) {
+ctx.currentVars = null
+ctx.INPUT_ROW = row
+val evals = child.output.zipWithIndex.map { case (attr, i) =>
+BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+}
+s"""
+| ${evals.map(_.code).mkString("\n")}
+| ${doConsume(ctx, evals)}
+""".stripMargin
+} else {
+doConsume(ctx, input)
+}
+}

/**
* Generate the Java source code to process the rows from child SparkPlan.
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
-def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
+protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+throw new UnsupportedOperationException
+}
}


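For a concrete consumer, a predicate operator only needs doConsume(): it evaluates its condition against the incoming column variables and forwards them to its own consume() when the row survives. A hedged sketch of a Filter-like operator (assuming condition: Expression is already bound to the child's output):

    override protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
      // Evaluate the predicate against the pipelined column variables.
      ctx.currentVars = input
      val eval = condition.gen(ctx)
      s"""
         | ${eval.code}
         | if (!${eval.isNull} && ${eval.value}) {
         |   ${consume(ctx, input)}
         | }
       """.stripMargin
    }
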
@@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan {
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {

override def output: Seq[Attribute] = child.output
+override def outputPartitioning: Partitioning = child.outputPartitioning
+override def outputOrdering: Seq[SortOrder] = child.outputOrdering

+override def doPrepare(): Unit = {
+child.prepare()
+}

-override def supportCodegen: Boolean = true
+override def doExecute(): RDD[InternalRow] = {
+child.execute()
+}

-override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+override def supportCodegen: Boolean = false
+
+override def upstream(): RDD[InternalRow] = {
+child.execute()
+}
+
+override def doProduce(ctx: CodegenContext): String = {
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
val code = s"""
| while (input.hasNext()) {
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| }
""".stripMargin
-(child.execute(), code)
}

-def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
-throw new UnsupportedOperationException
-}
-
-override def doExecute(): RDD[InternalRow] = {
-throw new UnsupportedOperationException
-}

override def simpleString: String = "INPUT"
@@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
* -> execute()
*     |
-*  doExecute() --------> produce()
+*  doExecute() ---------> upstream() -------> upstream() ------> execute()
+*     |
+*      -----------------> produce()
*                            |
*                         doProduce() -------> produce()
*                                                 |
-*                                              doProduce() ---> execute()
+*                                              doProduce()
*                                                 |
-*                                              consume()
-*                         doConsume() ------------|
+*                        consumeChild() <-----------|
*                            |
-*  doConsume() <----- consume()
+*                         doConsume()
+*                            |
+*  consumeChild() <----- consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
@@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
extends SparkPlan with CodegenSupport {

+override def supportCodegen: Boolean = false
+
override def output: Seq[Attribute] = plan.output
+override def outputPartitioning: Partitioning = plan.outputPartitioning
+override def outputOrdering: Seq[SortOrder] = plan.outputOrdering

+override def doPrepare(): Unit = {
+plan.prepare()
+}
+
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
-val (rdd, code) = plan.produce(ctx, this)
+val code = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}

class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

private Object[] references;
${ctx.declareMutableStates()}
+${ctx.declareAddedFunctions()}

public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
}

-protected void processNext() {
+protected void processNext() throws java.io.IOException {
$code
}
}
"""
"""

+// try to compile, helpful for debug
+// println(s"${CodeFormatter.format(source)}")
+CodeGenerator.compile(source)

-rdd.mapPartitions { iter =>
+plan.upstream().mapPartitions { iter =>

val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
@@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}

-override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+override def upstream(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}

-override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
-if (input.nonEmpty) {
-val colExprs = output.zipWithIndex.map { case (attr, i) =>
-BoundReference(i, attr.dataType, attr.nullable)
-}
-// generate the code to create a UnsafeRow
-ctx.currentVars = input
-val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-s"""
-| ${code.code.trim}
-| currentRow = ${code.value};
-| return;
-""".stripMargin
-} else {
-// There is no columns
-s"""
-| currentRow = unsafeRow;
-| return;
-""".stripMargin
-}
+override def doProduce(ctx: CodegenContext): String = {
+throw new UnsupportedOperationException
+}
+
+override def consumeChild(
+ctx: CodegenContext,
+child: SparkPlan,
+input: Seq[ExprCode],
+row: String = null): String = {
+
+if (row != null) {
+// There is an UnsafeRow already
+s"""
+| currentRow = $row;
+| return;
+""".stripMargin
+} else {
+assert(input != null)
+if (input.nonEmpty) {
+val colExprs = output.zipWithIndex.map { case (attr, i) =>
+BoundReference(i, attr.dataType, attr.nullable)
+}
+// generate the code to create a UnsafeRow
+ctx.currentVars = input
+val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+s"""
+| ${code.code.trim}
+| currentRow = ${code.value};
+| return;
+""".stripMargin
+} else {
+// There are no columns
+s"""
+| currentRow = unsafeRow;
+| return;
+""".stripMargin
+}
+}
}

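Putting the pieces together, the body spliced into processNext() for a scan -> filter stage looks roughly like the sketch below. The identifiers are invented, but they show the per-operator prefixes from freshName() and the currentRow/return handoff emitted by consumeChild() above:

    // Purely illustrative; real output also includes null checks and the
    // UnsafeRow projection code generated by GenerateUnsafeProjection.
    val sketchedProcessNext: String =
      """
        |protected void processNext() throws java.io.IOException {
        |  while (input.hasNext()) {
        |    InternalRow InputAdapter_row0 = (InternalRow) input.next();
        |    long Filter_value1 = InputAdapter_row0.getLong(0);
        |    if (Filter_value1 > 5L) {
        |      // ... UnsafeRow projection of the output columns ...
        |      currentRow = WholeStageCodegen_result2;
        |      return;
        |    }
        |  }
        |}
      """.stripMargin
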
@@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
builder.append(simpleString)
builder.append("\n")

-plan.generateTreeString(depth + 1, lastChildren :+ children.isEmpty :+ true, builder)
+plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case plan: CodegenSupport if supportCodegen(plan) &&
// Whole stage codegen is only useful when there are at least two levels of operators that
// support it (save at least one projection/iterator).
-plan.children.exists(supportCodegen) =>
+(Utils.isTesting || plan.children.exists(supportCodegen)) =>

var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
case p if !supportCodegen(p) =>
-inputs += p
-InputAdapter(p)
+val input = apply(p) // collapse them recursively
+inputs += input
+InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
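The rule's net effect, as a hedged before/after sketch (constructors simplified; exprs, cond, partitioning and scan are placeholders, and Exchange stands in for any operator without codegen support):

    val before: SparkPlan = Project(exprs, Filter(cond, Exchange(partitioning, scan)))

    val after = CollapseCodegenStages(sqlContext).apply(before)
    // after is roughly:
    //   WholeStageCodegen(
    //     Project(exprs, Filter(cond, InputAdapter(collapsedExchange))),
    //     Seq(collapsedExchange))
    // where collapsedExchange is the Exchange with its own subtree collapsed by
    // the recursive apply(p) call above; the Exchange marks the upstream()
    // boundary of the generated stage.

With Utils.isTesting in the guard, a stage with a single codegen-capable operator is still collapsed during tests, which is what restores coverage for operators exercised in isolation.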
@@ -238,7 +238,7 @@ abstract class AggregationIterator(
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
} else {
-// Grouping-only: we only output values of grouping expressions.
+// Grouping-only: we only output values based on grouping expressions.
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
(currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
resultProjection(currentGroupingKey)
