Skip to content

Commit

Permalink
[SPARK-13306][SQL] Addendum to uncorrelated scalar subquery
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This pull request fixes some minor issues (documentation, test flakiness, test organization) with #11190, which was merged earlier tonight.

## How was the this patch tested?
unit tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #11285 from rxin/subquery.
  • Loading branch information
rxin committed Feb 21, 2016
1 parent 0947f09 commit af441dd
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,12 @@ class Analyzer(
}
substituted.getOrElse(u)
case other =>
// This can't be done in ResolveSubquery because that does not know the CTE.
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression =>
e.withNewPlan(substituteCTE(e.query, cteRelations))
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,3 @@ case class Literal protected (value: Any, dataType: DataType)
case _ => value.toString
}
}

// TODO: Specialize
case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
extends LeafExpression with CodegenFallback {

def update(expression: Expression, input: InternalRow): Unit = {
value = expression.eval(input)
}

override def eval(input: InternalRow): Any = value
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ abstract class SubqueryExpression extends LeafExpression {
}

/**
* A subquery that will return only one row and one column.
*
* This will be converted into [[execution.ScalarSubquery]] during physical planning.
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
*
* Note: `exprId` is used to have unique name in explain string output.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure.
*/
@transient
protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null)
protected[spark] final val sqlContext = SQLContext.getActive().orNull

protected def sparkContext = sqlContext.sparkContext

Expand Down Expand Up @@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

// All the subqueries and their Future of results.
@transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
/**
* List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]

/**
* Collects all the subqueries and create a Future to take the first two rows of them.
* Finds scalar subquery expressions in this plan node and starts evaluating them.
* The list of subqueries are added to [[subqueryResults]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// We only need the first row, try to take two rows so we can throw an exception if there
// are more than one rows returned.
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
queryResults += e -> futureResult
subqueryResults += e -> futureResult
}
}

/**
* Waits for all the subqueries to finish and updates the results.
* Blocks the thread until all subqueries finish evaluation and update the results.
*/
protected def waitForSubqueries(): Unit = {
// fill in the result of subqueries
queryResults.foreach {
case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// There is no rows returned, the result should be null.
e.updateResult(null)
}
subqueryResults.foreach { case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
}
queryResults.clear()
subqueryResults.clear()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ case class ScalarSubquery(
/**
* Convert the subquery from logical plan into executed plan.
*/
private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
Expand Down
61 changes: 30 additions & 31 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,65 +20,64 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("simple uncorrelated scalar subquery") {
assertResult(Array(Row(1))) {
sql("select (select 1 as b) as b").collect()
}

assertResult(Array(Row(1))) {
sql("with t2 as (select 1 as b, 2 as c) " +
"select a from (select 1 as a union all select 2 as a) t " +
"where a = (select max(b) from t2) ").collect()
}

assertResult(Array(Row(3))) {
sql("select (select (select 1) + 1) + 1").collect()
}

// more than one columns
val error = intercept[AnalysisException] {
sql("select (select 1, 2) as b").collect()
}
assert(error.message contains "Scalar subquery must return only one column, but got 2")

// more than one rows
val error2 = intercept[RuntimeException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
assert(error2.getMessage contains
"more than one row returned by a subquery used as an expression")

// string type
assertResult(Array(Row("s"))) {
sql("select (select 's' as s) as b").collect()
}
}

// zero rows
test("uncorrelated scalar subquery in CTE") {
assertResult(Array(Row(1))) {
sql("with t2 as (select 1 as b, 2 as c) " +
"select a from (select 1 as a union all select 2 as a) t " +
"where a = (select max(b) from t2) ").collect()
}
}

test("uncorrelated scalar subquery should return null if there is 0 rows") {
assertResult(Array(Row(null))) {
sql("select (select 's' as s limit 0) as b").collect()
}
}

test("uncorrelated scalar subquery on testData") {
// initialize test Data
testData
test("runtime error when the number of rows is greater than 1") {
val error2 = intercept[RuntimeException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
assert(error2.getMessage.contains(
"more than one row returned by a subquery used as an expression"))
}

test("uncorrelated scalar subquery on a DataFrame generated query") {
val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
df.registerTempTable("subqueryData")

assertResult(Array(Row(5))) {
sql("select (select key from testData where key > 3 limit 1) + 1").collect()
assertResult(Array(Row(4))) {
sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect()
}

assertResult(Array(Row(-100))) {
sql("select -(select max(key) from testData)").collect()
assertResult(Array(Row(-3))) {
sql("select -(select max(key) from subqueryData)").collect()
}

assertResult(Array(Row(null))) {
sql("select (select value from testData limit 0)").collect()
sql("select (select value from subqueryData limit 0)").collect()
}

assertResult(Array(Row("99"))) {
sql("select (select min(value) from testData" +
" where key = (select max(key) from testData) - 1)").collect()
assertResult(Array(Row("two"))) {
sql("select (select min(value) from subqueryData" +
" where key = (select max(key) from subqueryData) - 1)").collect()
}
}
}

0 comments on commit af441dd

Please sign in to comment.