Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-13404] [SQL] Create variables for input row when it's actually used #11274

Closed
wants to merge 14 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
ev.isNull = ctx.currentVars(ordinal).isNull
ev.value = ctx.currentVars(ordinal).value
""
val oev = ctx.currentVars(ordinal)
ev.isNull = oev.isNull
ev.value = oev.value
oev.code
} else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
It should be an empty string if `isNull` and `value` already exist, or if no code is
needed to evaluate them (literals).
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,16 @@ private[sql] case class PhysicalRDD(
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
val numOutputRows = metricTerm(ctx, "numOutputRows")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))

// The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
// by looking at the first value of the RDD and then calling the function which will process
// the remaining. It is faster to return batches.
// TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
// here which path to use. Fix this.


ctx.INPUT_ROW = row
ctx.currentVars = null
val columns1 = exprs.map(_.gen(ctx))
val scanBatches = ctx.freshName("processBatches")
ctx.addNewFunction(scanBatches,
s"""
Expand All @@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
| int numRows = $batch.numRows();
| if ($idx == 0) $numOutputRows.add(numRows);
|
| while ($idx < numRows) {
| while (!shouldStop() && $idx < numRows) {
| InternalRow $row = $batch.getRow($idx++);
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) return;
| ${consume(ctx, columns1).trim}
| }
| if (shouldStop()) return;
|
| if (!$input.hasNext()) {
| $batch = null;
Expand All @@ -186,30 +184,37 @@ private[sql] case class PhysicalRDD(
| }
| }""".stripMargin)

ctx.INPUT_ROW = row
ctx.currentVars = null
val columns2 = exprs.map(_.gen(ctx))
val inputRow = if (isUnsafeRow) row else null
val scanRows = ctx.freshName("processRows")
ctx.addNewFunction(scanRows,
s"""
| private void $scanRows(InternalRow $row) throws java.io.IOException {
| while (true) {
| boolean firstRow = true;
| while (!shouldStop() && (firstRow || $input.hasNext())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nongli Since we changed to use continue for predicates, it's tricky to get this right.

| if (firstRow) {
| firstRow = false;
| } else {
| $row = (InternalRow) $input.next();
| }
| $numOutputRows.add(1);
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) return;
| if (!$input.hasNext()) break;
| $row = (InternalRow)$input.next();
| ${consume(ctx, columns2, inputRow).trim}
| }
| }""".stripMargin)

val value = ctx.freshName("value")
s"""
| if ($batch != null) {
| $scanBatches();
| } else if ($input.hasNext()) {
| Object value = $input.next();
| if (value instanceof $columnarBatchClz) {
| $batch = ($columnarBatchClz)value;
| Object $value = $input.next();
| if ($value instanceof $columnarBatchClz) {
| $batch = ($columnarBatchClz)$value;
| $scanBatches();
| } else {
| $scanRows((InternalRow)value);
| $scanRows((InternalRow) $value);
| }
| }
""".stripMargin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ case class Expand(

val numOutput = metricTerm(ctx, "numOutputRows")
val i = ctx.freshName("i")
// these columns have to be declared before the loop.
val evaluate = evaluateVariables(outputColumns)
s"""
|${outputColumns.map(_.code).mkString("\n").trim}
|$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
| switch ($i) {
| ${cases.mkString("\n").trim}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@ trait CodegenSupport extends SparkPlan {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
doProduce(ctx)
s"""
|/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
|${doProduce(ctx)}
""".stripMargin
}

/**
* Generate the Java source code to process, should be overrided by subclass to support codegen.
* Generate the Java source code to process, should be overridden by subclass to support codegen.
*
* doProduce() usually generate the framework, for example, aggregation could generate this:
*
Expand All @@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
* # call child.produce()
* initialized = true;
* }
* while (hashmap.hasNext()) {
* while (!shouldStop() && hashmap.hasNext()) {
* row = hashmap.next();
* # build the aggregation results
* # create varialbles for results
* # call consume(), wich will call parent.doConsume()
* # create variables for results
* # call consume(), which will call parent.doConsume()
* }
*/
protected def doProduce(ctx: CodegenContext): String
Expand All @@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
}

/**
* Consume the columns generated from it's child, call doConsume() or emit the rows.
* Returns source code to evaluate all the variables, and clear the code of them, to prevent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a high level comment that describes the overall framework? I think the important things to include are:

  • how it works in general?
  • how should an operator that does not short circuit (e.g. project/sort) use this?
  • how should an operator that does short circuit use this (if different)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was imagining something like:

evaluateAttributes(Seq[Expression]) which evaluates all the attribute references in the tree that haven't been. This is kind of similar to what you have below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some variables could be generated in the middle of the plan, for example, aggregate, and join, so we can't always use the references of current plan to determine which expression is used or not. So I have two different functions here, we could pass in the used references to the function below.

* them from being evaluated twice.
*/
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the comment for ExprCode.code to specify what it means when it is empty.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
evaluate
}

/**
* Returns source code to evaluate the variables for required attributes, and clear the code
* of evaluated variables, to prevent them from being evaluated twice.
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
var evaluateVars = ""
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davies I was just reviewing build warnings, and it flags this line. ev.code is a Block rather than String. Should it be ev.code.nonEmpty && ... instead?

evaluateVars += ev.code.trim + "\n"
ev.code = ""
}
}
evaluateVars
}

/**
* The subset of inputSet that should be evaluated before this plan.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good place to document how this whole thing works in a couple of sentences. Something describing that we defer attribute access in the generated function. We access all the attributes needed by the operator at the beginning if it was not already referenced earlier in the pipeline.

Might also update the commit message with this since this is what most of the patch is about.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

*
* We will use this to insert some code to access those columns that are actually used by current
* plan before calling doConsume().
*/
def usedInputs: AttributeSet = references

/**
* Consume the columns generated from its child, call doConsume() or emit the rows.
*
* An operator could generate variables for the output, or a row, either one could be null.
*
* If the row is not null, we create variables to access the columns that are actually used by
* current plan before calling doConsume().
*/
def consumeChild(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {
ctx.freshNamePrefix = variablePrefix
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
val evals = child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
val inputVars =
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
}
} else {
input
}
s"""
| ${evals.map(_.code).mkString("\n")}
| ${doConsume(ctx, evals)}
""".stripMargin
} else {
doConsume(ctx, input)
}
s"""
|
|/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
|${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
|${doConsume(ctx, inputVars)}
""".stripMargin
}

/**
Expand All @@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
* For example, Filter will generate the code like this:
*
* # code to evaluate the predicate expression, result is isNull1 and value2
* if (isNull1 || value2) {
* # call consume(), which will call parent.doConsume()
* }
* if (isNull1 || !value2) continue;
* # call consume(), which will call parent.doConsume()
*/
protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
Expand Down Expand Up @@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
| while ($input.hasNext()) {
| while (!shouldStop() && $input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
| }
| }
""".stripMargin
}
Expand Down Expand Up @@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
val evaluateInputs = evaluateVariables(input)
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
|$evaluateInputs
|${code.code.trim}
|append(${code.value}.copy());
""".stripMargin.trim
Expand Down
Loading