move wait subqueries into execute()/produce()
Davies Liu committed Feb 20, 2016
1 parent 7596173 commit 0034172
Showing 4 changed files with 49 additions and 41 deletions.
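The operators touched here support scalar subqueries used as expressions: subqueries that must return at most one row and one column. A query of roughly the following shape exercises this code path (a hedged sketch; the table `t`, its column `value`, and the `sqlContext` binding are assumed for illustration, not taken from this commit):

// Hypothetical usage; assumes an existing SQLContext `sqlContext` and a
// registered table `t` with an integer column `value`. The inner
// (SELECT max(value) FROM t) is planned as a ScalarSubquery expression.
val df = sqlContext.sql("SELECT (SELECT max(value) FROM t) + 1 FROM t")
df.collect()  // prepare() kicks off the subquery; execute() now blocks on its result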
@@ -115,44 +115,59 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
  final def execute(): RDD[InternalRow] = {
    RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
      prepare()
+     waitForSubqueries()
      doExecute()
    }
  }

+ // All the subqueries and their Futures of results.
+ @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
+
+ /**
+  * Collects all the subqueries and creates a Future that takes the first two rows of each.
+  */
+ protected def prepareSubqueries(): Unit = {
+   val allSubqueries = expressions.flatMap(_.collect { case e: ScalarSubquery => e })
+   allSubqueries.foreach { e =>
+     val futureResult = Future {
+       // We only need the first row; take two rows so we can throw an exception if more
+       // than one row is returned.
+       e.executedPlan.executeTake(2)
+     }(SparkPlan.subqueryExecutionContext)
+     queryResults += e -> futureResult
+   }
+ }
+
+ /**
+  * Waits for all the subqueries to finish and updates the results.
+  */
+ protected def waitForSubqueries(): Unit = {
+   // fill in the results of the subqueries
+   queryResults.foreach {
+     case (e, futureResult) =>
+       val rows = Await.result(futureResult, Duration.Inf)
+       if (rows.length > 1) {
+         sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
+       }
+       if (rows.length == 1) {
+         assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
+         e.updateResult(rows(0).get(0, e.dataType))
+       } else {
+         // No rows were returned, so the result should be null.
+         e.updateResult(null)
+       }
+   }
+   queryResults.clear()
+ }

  /**
   * Prepare a SparkPlan for execution. It's idempotent.
   */
  final def prepare(): Unit = {
    if (prepareCalled.compareAndSet(false, true)) {
      doPrepare()

-     // collect all the subqueries and submit jobs to execute them in the background
-     val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
-     val allSubqueries = expressions.flatMap(_.collect { case e: ScalarSubquery => e })
-     allSubqueries.foreach { e =>
-       val futureResult = Future {
-         e.plan.executeTake(2)
-       }(SparkPlan.subqueryExecutionContext)
-       queryResults += e -> futureResult
-     }
-
+     prepareSubqueries()
      children.foreach(_.prepare())
-
-     // fill in the results of the subqueries
-     queryResults.foreach {
-       case (e, futureResult) =>
-         val rows = Await.result(futureResult, Duration.Inf)
-         if (rows.length > 1) {
-           sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-         }
-         if (rows.length == 1) {
-           assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
-           e.updateResult(rows(0).get(0, e.dataType))
-         } else {
-           // No rows were returned, so the result should be null.
-           e.updateResult(null)
-         }
-     }
    }
  }

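A minimal, self-contained sketch of the split introduced above, assuming nothing from Spark (the names PlanNode, subqueryEc, and the Array[Int] result type are invented): prepareSubqueries() only submits work to an execution context, while waitForSubqueries() blocks, so the subqueries of every operator in the tree can run concurrently between prepare() and execute().

import java.util.concurrent.Executors
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object SubqueryLifecycleSketch {
  // stands in for SparkPlan.subqueryExecutionContext
  private val subqueryEc: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(16))

  final class PlanNode(subqueries: Seq[() => Array[Int]]) {
    private val pending = ArrayBuffer[Future[Array[Int]]]()

    // prepare(): kick off every subquery job, but do not block
    def prepareSubqueries(): Unit =
      subqueries.foreach(q => pending += Future(q())(subqueryEc))

    // execute()/produce(): block until every result has arrived
    def waitForSubqueries(): Seq[Array[Int]] = {
      val results = pending.map(f => Await.result(f, Duration.Inf)).toList
      pending.clear()
      results
    }
  }
}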
@@ -73,9 +73,10 @@ trait CodegenSupport extends SparkPlan
  /**
   * Returns Java source code to process the rows from upstream.
   */
- def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
+ final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
    this.parent = parent
    ctx.freshNamePrefix = variablePrefix
+   waitForSubqueries()
    doProduce(ctx)
  }

@@ -101,7 +102,7 @@ trait CodegenSupport extends SparkPlan
  /**
   * Consume the columns generated by the current SparkPlan, and call its parent.
   */
- def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
+ final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
    if (input != null) {
      assert(input.length == output.length)
    }
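Why the wait also moves into produce(): whole-stage codegen generates the Java source for an entire stage up front, and (as the subquery.scala change below shows) ScalarSubquery's generated code now embeds the already-computed result, so the results must exist before doProduce() runs. Making produce() and consume() final turns them into template methods that no operator can override around. A schematic sketch with simplified signatures (assumed shape, not Spark's exact API):

// Template-method shape of the change: the final wrapper enforces that
// waitForSubqueries() always runs before any code generation happens.
trait CodegenSketch {
  protected def waitForSubqueries(): Unit
  protected def doProduce(): String

  final def produce(): String = {
    waitForSubqueries() // subquery results get baked into the generated code
    doProduce()
  }
}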
@@ -355,6 +355,6 @@ case class Subquery(name: String, child: SparkPlan) extends UnaryNode {
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  protected override def doExecute(): RDD[InternalRow] = {
-   child.execute()
+   throw new UnsupportedOperationException
  }
}
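The swap to UnsupportedOperationException is consistent with the lifecycle above (an inference from this diff, not stated in the commit message): a subquery's rows are now fetched only through executeTake(2) inside prepareSubqueries(), so the Subquery wrapper node itself should never be executed directly, and a stray call now fails loudly instead of silently re-running the child plan.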
@@ -18,11 +18,11 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
- import org.apache.spark.sql.catalyst.{expressions, InternalRow}
- import org.apache.spark.sql.catalyst.expressions.{ExprId, SubqueryExpression}
+ import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
+ import org.apache.spark.sql.catalyst.{InternalRow, expressions}
import org.apache.spark.sql.types.DataType

@@ -55,15 +55,7 @@ case class ScalarSubquery(
  override def eval(input: InternalRow): Any = result

  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-   val thisTerm = ctx.addReferenceObj("subquery", this)
-   val isNull = ctx.freshName("isNull")
-   ctx.addMutableState("boolean", isNull, s"$isNull = $thisTerm.eval(null) == null;")
-   val value = ctx.freshName("value")
-   ctx.addMutableState(ctx.javaType(dataType), value,
-     s"$value = (${ctx.boxedType(dataType)}) $thisTerm.eval(null);")
-   ev.isNull = isNull
-   ev.value = value
-   ""
+   Literal.create(result, dataType).genCode(ctx, ev)
  }
}

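The genCode rewrite works because of the ordering enforced above: by the time any code is generated, waitForSubqueries() has already called updateResult(), so the subquery is semantically a constant and can reuse Literal's code generation instead of hand-rolled mutable state. A tiny model of that equivalence (simplified types and invented names, not Spark's classes):

// Once `result` is fixed, eval ignores its input row entirely,
// which is exactly a literal's contract, so codegen can delegate to one.
class ScalarSubquerySketch {
  private var result: Any = null
  def updateResult(v: Any): Unit = { result = v } // done by waitForSubqueries()
  def eval(input: Any): Any = result              // same as Literal(result).eval(input)
}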
