From d1de3002f0fad391cdea44011ce1458ca0348956 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 4 Nov 2017 01:14:22 +0100
Subject: [PATCH 1/2] move CodegenContext.copyResult to CodegenSupport

---
 .../expressions/codegen/CodeGenerator.scala   | 10 ------
 .../sql/execution/ColumnarBatchScan.scala     |  2 +-
 .../spark/sql/execution/ExpandExec.scala      |  3 +-
 .../spark/sql/execution/GenerateExec.scala    |  3 +-
 .../apache/spark/sql/execution/SortExec.scala | 14 ++++----
 .../sql/execution/WholeStageCodegenExec.scala | 33 ++++++++++++++-----
 .../aggregate/HashAggregateExec.scala         | 14 ++++----
 .../execution/basicPhysicalOperators.scala    |  5 +--
 .../joins/BroadcastHashJoinExec.scala         | 16 +++++++--
 .../execution/joins/SortMergeJoinExec.scala   |  3 +-
 10 files changed, 64 insertions(+), 39 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 58738b52b299f..98eda2a1ba92c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -139,16 +139,6 @@ class CodegenContext {
    */
   var currentVars: Seq[ExprCode] = null
 
-  /**
-   * Whether should we copy the result rows or not.
-   *
-   * If any operator inside WholeStageCodegen generate multiple rows from a single row (for
-   * example, Join), this should be true.
-   *
-   * If an operator starts a new pipeline, this should be reset to false before calling `consume()`.
-   */
-  var copyResult: Boolean = false
-
   /**
    * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
    * 3-tuple: java type, variable name, code to init it.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index eb01e126bcbef..1925bad8c3545 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -115,7 +115,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
     val numRows = ctx.freshName("numRows")
-    val shouldStop = if (isShouldStopRequired) {
+    val shouldStop = if (parent.needStopCheck) {
       s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
     } else {
       "// shouldStop check is eliminated"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index d5603b3b00914..33849f4389b92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -93,6 +93,8 @@ case class ExpandExec(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def needCopyResult: Boolean = true
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     /*
      * When the projections list looks like:
@@ -187,7 +189,6 @@ case class ExpandExec(
     val i = ctx.freshName("i")
     // these column have to declared before the loop.
     val evaluate = evaluateVariables(outputColumns)
-    ctx.copyResult = true
     s"""
        |$evaluate
        |for (int $i = 0; $i < ${projections.length}; $i ++) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 65ca37491b6a1..c142d3b5ed4f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -132,9 +132,10 @@ case class GenerateExec(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def needCopyResult: Boolean = true
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     ctx.currentVars = input
-    ctx.copyResult = true
 
     // Add input rows to the values when we are joining
     val values = if (join) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index ff71fd4dc7bb7..21765cdbd94cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -124,6 +124,14 @@ case class SortExec(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
+  // The result rows come from the sort buffer, so this operator doesn't need to copy its result
+  // even if its child does.
+  override def needCopyResult: Boolean = false
+
+  // The sort operator always consumes all the input rows before outputting any result, so we
+  // don't need a stop check before sorting.
+  override def needStopCheck: Boolean = false
+
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.freshName("needToSort")
     ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -148,10 +156,6 @@ case class SortExec(
          |   }
        """.stripMargin.trim)
 
-    // The child could change `copyResult` to true, but we had already consumed all the rows,
-    // so `copyResult` should be reset to `false`.
-    ctx.copyResult = false
-
     val outputRow = ctx.freshName("outputRow")
     val peakMemory = metricTerm(ctx, "peakMemory")
     val spillSize = metricTerm(ctx, "spillSize")
@@ -177,8 +181,6 @@ case class SortExec(
     """.stripMargin.trim
   }
 
-  protected override val shouldStopRequired = false
-
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     s"""
        |${row.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 286cb3bb0767c..367cefce18a11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -213,19 +213,32 @@ trait CodegenSupport extends SparkPlan {
   }
 
   /**
-   * For optimization to suppress shouldStop() in a loop of WholeStageCodegen.
-   * Returning true means we need to insert shouldStop() into the loop producing rows, if any.
+   * Whether or not the result rows of this operator should be copied before being buffered.
+   *
+   * If any operator inside WholeStageCodegen generates multiple rows from a single row (for
+   * example, Join), this should be true.
+   *
+   * If an operator starts a new pipeline, this should be false.
    */
-  def isShouldStopRequired: Boolean = {
-    return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
+  def needCopyResult: Boolean = {
+    if (children.isEmpty) {
+      false
+    } else if (children.length == 1) {
+      children.head.asInstanceOf[CodegenSupport].needStopCheck
+    } else {
+      throw new UnsupportedOperationException
+    }
   }
 
   /**
-   * Set to false if this plan consumes all rows produced by children but doesn't output row
-   * to buffer by calling append(), so the children don't require shouldStop()
-   * in the loop of producing rows.
+   * Whether or not the children of this operator should generate a stop check when consuming
+   * input rows. This is used to suppress shouldStop() in a loop of WholeStageCodegen.
+   *
+   * This should be false if an operator starts a new pipeline, which means it consumes all rows
+   * produced by children but doesn't output rows to the buffer by calling append(), so the
+   * children don't require shouldStop() in the loop of producing rows.
    */
-  protected def shouldStopRequired: Boolean = true
+  def needStopCheck: Boolean = parent.needStopCheck
 }
 
@@ -467,7 +480,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val doCopy = if (ctx.copyResult) {
+    val doCopy = if (needCopyResult) {
       ".copy()"
     } else {
       ""
@@ -487,6 +500,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       addSuffix: Boolean = false): StringBuilder = {
     child.generateTreeString(depth, lastChildren, builder, verbose, "*")
   }
+
+  override def needStopCheck: Boolean = true
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 43e5ff89afee6..2a208a2722550 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -149,6 +149,14 @@ case class HashAggregateExec(
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
+  // The result rows come from the aggregate buffer, or a single row (no grouping keys), so this
+  // operator doesn't need to copy its result even if its child does.
+  override def needCopyResult: Boolean = false
+
+  // The aggregate operator always consumes all the input rows before outputting any result, so
+  // we don't need a stop check before aggregating.
+  override def needStopCheck: Boolean = false
+
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
@@ -246,8 +254,6 @@ case class HashAggregateExec(
     """.stripMargin
   }
 
-  protected override val shouldStopRequired = false
-
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -651,10 +657,6 @@ case class HashAggregateExec(
     val outputFunc = generateResultFunction(ctx)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    // The child could change `copyResult` to true, but we had already consumed all the rows,
-    // so `copyResult` should be reset to `false`.
-    ctx.copyResult = false
-
     def outputFromGeneratedMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e58c3cec2df15..3c7daa0a45844 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -279,6 +279,8 @@ case class SampleExec(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def needCopyResult: Boolean = withReplacement
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
     val sampler = ctx.freshName("sampler")
@@ -286,7 +288,6 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
-      ctx.copyResult = true
 
       val initSamplerFuncName = ctx.addNewFunction(initSampler,
         s"""
@@ -450,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
     val range = ctx.freshName("range")
-    val shouldStop = if (isShouldStopRequired) {
+    val shouldStop = if (parent.needStopCheck) {
       s"if (shouldStop()) { $number = $value + ${step}L; return; }"
     } else {
       "// shouldStop check is eliminated"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b09da9bdacb99..41664d1fca32d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -76,6 +76,20 @@ case class BroadcastHashJoinExec(
     streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
+  override def needCopyResult: Boolean = joinType match {
+    case _: InnerLike | LeftOuter | RightOuter =>
+      // For inner and outer joins, one row from the streamed side may produce multiple result rows
+      // if the build side has duplicated keys. Then we need to copy the result rows before putting
+      // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
+      // we wait for the broadcast to be finished, which is a no-op because it's already finished
+      // when we wait for it in `doProduce`.
+      buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
+
+    // Other join types (semi, anti, existence) can produce at most one result row for one input
+    // row from the streamed side, so there is no need to copy the result rows.
+    case _ => false
+  }
+
   override def doProduce(ctx: CodegenContext): String = {
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
@@ -237,7 +251,6 @@ case class BroadcastHashJoinExec(
        """.stripMargin
 
     } else {
-      ctx.copyResult = true
       val matches = ctx.freshName("matches")
       val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       s"""
@@ -310,7 +323,6 @@ case class BroadcastHashJoinExec(
        """.stripMargin
 
     } else {
-      ctx.copyResult = true
       val matches = ctx.freshName("matches")
       val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       val found = ctx.freshName("found")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 4e02803552e82..cf7885f80d9fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -569,8 +569,9 @@ case class SortMergeJoinExec(
     }
   }
 
+  override def needCopyResult: Boolean = true
+
   override def doProduce(ctx: CodegenContext): String = {
-    ctx.copyResult = true
     val leftInput = ctx.freshName("leftInput")
     ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
     val rightInput = ctx.freshName("rightInput")

From 35c38d04a1ed7fbc0f637a54c797fc3b26e871a4 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 4 Nov 2017 17:06:13 +0100
Subject: [PATCH 2/2] fix mistakes

---
 .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 4 +++-
 .../spark/sql/execution/joins/BroadcastHashJoinExec.scala  | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 367cefce18a11..16b5706c03bf9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -224,7 +224,7 @@ trait CodegenSupport extends SparkPlan {
     if (children.isEmpty) {
       false
     } else if (children.length == 1) {
-      children.head.asInstanceOf[CodegenSupport].needStopCheck
+      children.head.asInstanceOf[CodegenSupport].needCopyResult
     } else {
       throw new UnsupportedOperationException
     }
@@ -291,6 +291,8 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       addSuffix: Boolean = false): StringBuilder = {
     child.generateTreeString(depth, lastChildren, builder, verbose, "")
   }
+
+  override def needCopyResult: Boolean = false
 }
 
 object WholeStageCodegenExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 41664d1fca32d..837b8525fed55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -83,7 +83,7 @@ case class BroadcastHashJoinExec(
       // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
       // we wait for the broadcast to be finished, which is a no-op because it's already finished
       // when we wait for it in `doProduce`.
-      buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
+      !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
 
     // Other join types (semi, anti, existence) can produce at most one result row for one input
     // row from the streamed side, so there is no need to copy the result rows.
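
Why the copy matters: whole-stage codegen writes every result row through a single reused
UnsafeRow, so an operator that fans one input row out into several result rows (join, expand,
generate, sample with replacement) hands its consumer the same object over and over. If those
rows are appended to a buffer without `.copy()`, every buffered entry ends up aliasing the last
row written. The following is a minimal, Spark-free sketch of that failure mode; `MutableRow`
and `CopyResultDemo` are invented stand-ins for UnsafeRow and the generated consume loop, not
Spark APIs.

object CopyResultDemo {
  // Stand-in for UnsafeRow: a mutable row that codegen would reuse for every output row.
  final class MutableRow(var value: Int) {
    def copy(): MutableRow = new MutableRow(value)
    override def toString: String = s"Row($value)"
  }

  def main(args: Array[String]): Unit = {
    val shared = new MutableRow(0) // the one row instance the generated loop reuses
    val buffer = scala.collection.mutable.ArrayBuffer.empty[MutableRow]

    // A join-like fan-out: three result rows, all written through the same instance.
    for (v <- 1 to 3) { shared.value = v; buffer += shared }
    println(buffer.mkString(", ")) // Row(3), Row(3), Row(3) -- every entry aliases `shared`

    buffer.clear()
    for (v <- 1 to 3) { shared.value = v; buffer += shared.copy() }
    println(buffer.mkString(", ")) // Row(1), Row(2), Row(3) -- each entry is a snapshot
  }
}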
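
How the two flags travel: after this patch, `needCopyResult` flows bottom-up (an operator
defaults to its single child's value, so a join anywhere below taints the pipeline until a
buffering operator such as sort or aggregate cuts it off), while `needStopCheck` flows top-down
from the WholeStageCodegenExec root and is likewise cut off by pipeline-starting operators. The
sketch below mirrors those defaults on a toy plan tree; `Op`, `Stage`, `Filter`, `Join`, `Sort`,
and `Scan` are hypothetical stand-ins for SparkPlan operators, not the real classes.

object FlagPropagationDemo {
  abstract class Op(val children: Seq[Op]) {
    var parent: Op = _
    children.foreach(_.parent = this) // wire child -> parent, as produce/consume does

    // Bottom-up default: inherit the copy requirement from the single child.
    def needCopyResult: Boolean = children match {
      case Seq()      => false
      case Seq(child) => child.needCopyResult
      case _          => throw new UnsupportedOperationException
    }
    // Top-down default: ask the parent whether a shouldStop() check is needed.
    def needStopCheck: Boolean = parent.needStopCheck
  }
  class Stage(child: Op) extends Op(Seq(child)) { override def needStopCheck = true }
  class Filter(child: Op) extends Op(Seq(child))
  class Join(child: Op) extends Op(Seq(child)) { override def needCopyResult = true }
  class Sort(child: Op) extends Op(Seq(child)) {
    override def needCopyResult = false // rows come from the sort buffer
    override def needStopCheck = false  // all input is consumed before any output
  }
  class Scan extends Op(Seq.empty)

  def main(args: Array[String]): Unit = {
    // Stage(Filter(Join(Scan))): the filter inherits the join's copy requirement.
    val scan = new Scan
    val plan = new Stage(new Filter(new Join(scan)))
    println(plan.children.head.needCopyResult) // true: a join below reuses one row instance
    println(scan.needStopCheck)                // true: rows reach the stage's append() loop

    // Stage(Sort(Join(Scan))): the sort cuts off both properties.
    val scan2 = new Scan
    val plan2 = new Stage(new Sort(new Join(scan2)))
    println(plan2.children.head.needCopyResult) // false
    println(scan2.needStopCheck)                // false: shouldStop() is eliminated below a sort
  }
}

Note the deliberate asymmetry in Sort: it kills both flags because it buffers its output rows
and consumes all of its input before producing any output.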
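
Finally, the second commit's `!` fix can be sanity-checked with plain collections: result rows
only need copying when the build side can match one streamed row more than once, i.e. exactly
when its keys are NOT unique, hence the negation. In this sketch `buildSide` is an ordinary Map
standing in for HashedRelation and `keyIsUnique` is recomputed by hand; none of this is Spark API.

object KeyIsUniqueDemo {
  def main(args: Array[String]): Unit = {
    // Build side of the join: key 2 is duplicated.
    val buildSide: Map[Int, Seq[String]] = Map(1 -> Seq("a"), 2 -> Seq("b", "c"))

    val keyIsUnique = buildSide.values.forall(_.size == 1)
    val needCopyResult = !keyIsUnique // the fix: copy precisely when keys are NOT unique
    println(s"keyIsUnique = $keyIsUnique, needCopyResult = $needCopyResult")

    // A streamed row with key 2 fans out into two result rows; if the generated code wrote
    // both through one reused UnsafeRow, buffering without .copy() would corrupt them.
    val matches = buildSide.getOrElse(2, Seq.empty)
    println(s"one streamed row produced ${matches.size} result rows: $matches")
  }
}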