Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Feb 21, 2016
1 parent 0034172 commit e082845
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,28 +228,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy

override def simpleString: String = statePrefix + super.simpleString

override def generateTreeString(
depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = {
if (depth > 0) {
lastChildren.init.foreach { isLast =>
val prefixFragment = if (isLast) " " else ": "
builder.append(prefixFragment)
}

val branch = if (lastChildren.last) "+- " else ":- "
builder.append(branch)
}

builder.append(simpleString)
builder.append("\n")

val allSubqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
val allChildren = children ++ allSubqueries.map(e => e.plan)
if (allChildren.nonEmpty) {
allChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
allChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
}

builder
override def treeChildren: Seq[PlanType] = {
val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}

/**
* All the nodes that will be used to generate tree string.
*/
protected def treeChildren: Seq[BaseType] = children

/**
* Appends the string represent of this node and its children to the given StringBuilder.
*
Expand All @@ -470,9 +475,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
builder.append(simpleString)
builder.append("\n")

if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
if (treeChildren.nonEmpty) {
treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
}

builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

// All the subquries and their Future of results.
// All the subqueries and their Future of results.
@transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()

/**
* Collects all the subqueries and create a Future to take the first two rows of them.
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.foreach { e =>
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// We only need the first row, try to take two rows so we can throw an exception if there
// are more than one rows returned.
Expand All @@ -139,7 +139,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}

/**
* Waits for all the subquires to finish and updates the results.
* Waits for all the subqueries to finish and updates the results.
*/
protected def waitForSubqueries(): Unit = {
// fill in the result of subqueries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._

setupTestData()

test("simple uncorrelated scalar subquery") {
assertResult(Array(Row(1))) {
Expand Down Expand Up @@ -64,6 +61,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
}

test("uncorrelated scalar subquery on testData") {
// initialize test Data
testData

assertResult(Array(Row(5))) {
sql("select (select key from testData where key > 3 limit 1) + 1").collect()
}
Expand Down

0 comments on commit e082845

Please sign in to comment.