Commit 3a8f08d: address comments

Davies Liu committed Feb 20, 2016
1 parent d0974cf
Showing 7 changed files with 33 additions and 64 deletions.
File 1 of 7
@@ -25,17 +25,16 @@ import org.apache.spark.sql.types.DataType
/**
 * An interface for subquery that is used in expressions.
 */
-abstract class SubqueryExpression extends LeafExpression{
+abstract class SubqueryExpression extends LeafExpression {

  /**
   * The logical plan of the query.
   */
  def query: LogicalPlan

  /**
-   * The underline plan for the query, could be logical plan or physical plan.
-   *
-   * This is used to generate tree string.
+   * Either a logical plan or a physical plan. The generated tree string (explain output) uses this
+   * field to explain the subquery.
   */
  def plan: QueryPlan[_]

@@ -48,7 +47,9 @@ abstract class SubqueryExpression extends LeafExpression{
/**
 * A subquery that will return only one row and one column.
 *
- * This is not evaluable, it should be converted into SparkScalaSubquery.
+ * This will be converted into [[execution.ScalarSubquery]] during physical planning.
+ *
+ * Note: `exprId` is used to generate a unique name in the explain string output.
 */
case class ScalarSubquery(
    query: LogicalPlan,
@@ -63,7 +64,7 @@ case class ScalarSubquery(

override def checkInputDataTypes(): TypeCheckResult = {
  if (query.schema.length != 1) {
-    TypeCheckResult.TypeCheckFailure("Scalar subquery can only have 1 column, but got " +
+    TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
      query.schema.length.toString)
  } else {
    TypeCheckResult.TypeCheckSuccess
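For context, a minimal sketch of the query shape this expression models: an uncorrelated scalar subquery that must produce exactly one row and one column. It assumes a SQLContext named sqlContext with tables t and s already registered, mirroring the parser tests later in this commit.

// Hypothetical session; `t` and `s` are assumed to be registered temp tables.
val df = sqlContext.sql("select (select max(b) from s) ss from t")
// An inner query producing two columns would be rejected by the
// checkInputDataTypes check above with:
// "Scalar subquery must return only one column, but got 2"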
File 2 of 7
@@ -90,16 +90,16 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) ::
Batch("Subquery", Once,
Subquery) :: Nil
OptimizeSubqueries) :: Nil
}

/**
* Optimize all the subqueries inside expression.
*/
object Subquery extends Rule[LogicalPlan] {
object OptimizeSubqueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case subquery: SubqueryExpression =>
subquery.withNewPlan(execute(subquery.query))
subquery.withNewPlan(Optimizer.this.execute(subquery.query))
}
}
}
Expand Down
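One subtlety in the hunk above: inside the nested OptimizeSubqueries object, a bare execute call is easy to misread, so the new code qualifies it as Optimizer.this.execute, the execute method of the enclosing RuleExecutor. A minimal, Spark-independent sketch of the same Outer.this idiom (all names here are illustrative):

class Outer {
  def execute(s: String): String = s.toUpperCase

  object Inner {
    // Outer.this names the enclosing instance explicitly, so this call still
    // resolves to Outer's method even if Inner later gains its own execute.
    def run(s: String): String = Outer.this.execute(s)
  }
}

// Usage:
//   val o = new Outer
//   o.Inner.run("subquery")  // returns "SUBQUERY"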
File 3 of 7
@@ -204,47 +204,8 @@ class CatalystQlSuite extends PlanTest {
  }

  test("subquery") {
-    comparePlans(
-      parser.parsePlan("select (select max(b) from s) ss from t"),
-      Project(
-        UnresolvedAlias(
-          Alias(
-            ScalarSubquery(
-              Project(
-                UnresolvedAlias(
-                  UnresolvedFunction("max", UnresolvedAttribute("b") :: Nil, false)) :: Nil,
-                UnresolvedRelation(TableIdentifier("s")))),
-            "ss")(ExprId(0))) :: Nil,
-        UnresolvedRelation(TableIdentifier("t"))))
-    comparePlans(
-      parser.parsePlan("select * from t where a = (select b from s)"),
-      Project(
-        UnresolvedAlias(
-          UnresolvedStar(None)) :: Nil,
-        Filter(
-          EqualTo(
-            UnresolvedAttribute("a"),
-            ScalarSubquery(
-              Project(
-                UnresolvedAlias(
-                  UnresolvedAttribute("b")) :: Nil,
-                UnresolvedRelation(TableIdentifier("s"))))),
-          UnresolvedRelation(TableIdentifier("t")))))
-    comparePlans(
-      parser.parsePlan("select * from t group by g having a > (select b from s)"),
-      Filter(
-        Cast(
-          GreaterThan(
-            UnresolvedAttribute("a"),
-            ScalarSubquery(
-              Project(
-                UnresolvedAlias(
-                  UnresolvedAttribute("b")) :: Nil,
-                UnresolvedRelation(TableIdentifier("s"))))),
-          BooleanType),
-        Aggregate(
-          UnresolvedAttribute("g") :: Nil,
-          UnresolvedAlias(UnresolvedStar(None)) :: Nil,
-          UnresolvedRelation(TableIdentifier("t")))))
+    parser.parsePlan("select (select max(b) from s) ss from t")
+    parser.parsePlan("select * from t where a = (select b from s)")
+    parser.parsePlan("select * from t group by g having a > (select b from s)")
  }
}
File 4 of 7
@@ -114,10 +114,15 @@ class AnalysisErrorSuite extends AnalysisTest {
  val dateLit = Literal.create(null, DateType)

  errorTest(
-    "invalid scalar subquery",
+    "scalar subquery with 2 columns",
    testRelation.select(
      (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)),
-    "Scalar subquery can only have 1 column, but got 2" :: Nil)
+    "Scalar subquery must return only one column, but got 2" :: Nil)
+
+  errorTest(
+    "scalar subquery with no column",
+    testRelation.select(ScalarSubquery(LocalRelation()).as('a)),
+    "Scalar subquery must return only one column, but got 0" :: Nil)

  errorTest(
    "single invalid type, single arg",
File 5 of 7
@@ -884,7 +884,7 @@ class SQLContext private[sql](
  @transient
  protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
    val batches = Seq(
-      Batch("Subquery", Once, ConvertSubquery(self)),
+      Batch("Subquery", Once, PlanSubqueries(self)),
      Batch("Add exchange", Once, EnsureRequirements(self)),
      Batch("Whole stage codegen", Once, CollapseCodegenStages(self))
    )
File 6 of 7
@@ -127,8 +127,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    doPrepare()

    // collect all the subqueries and submit jobs to execute them in background
-    val queryResults = ArrayBuffer[(SparkScalarSubquery, Future[Array[InternalRow]])]()
-    val allSubqueries = expressions.flatMap(_.collect {case e: SparkScalarSubquery => e})
+    val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
+    val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
    allSubqueries.foreach { e =>
      val futureResult = Future {
        e.plan.executeCollect()
@@ -146,9 +146,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
sys.error(s"Scalar subquery should return at most one row, but got ${rows.length}: " +
s"${e.query.treeString}")
}
// Analyzer will make sure that it only return on column
if (rows.length > 0) {
assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// the result should be null, since the expression already have null as default value,
// we don't need to update that.
}
}
}
Expand Down
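The surrounding waitForSubqueries logic follows a simple pattern: each scalar subquery is submitted as a background Future during preparation, and the main plan then blocks on all of them before executing. A minimal, Spark-independent sketch of that pattern, with stand-in types in place of SparkPlan and InternalRow:

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

object SubqueryWaitSketch extends App {
  // Stand-ins for physical subquery plans; each Future plays the role of
  // executeCollect() running in the background.
  val subqueries = Seq("q1", "q2")
  val running: Seq[(String, Future[Array[Int]])] =
    subqueries.map(q => q -> Future { Array(42) })

  // Block until every subquery result is available, enforcing the
  // at-most-one-row rule before the main plan runs.
  running.foreach { case (q, f) =>
    val rows = Await.result(f, Duration.Inf)
    require(rows.length <= 1, s"Scalar subquery $q should return at most one row")
  }
}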
File 7 of 7
@@ -18,8 +18,8 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExprId, ScalarSubquery, SubqueryExpression}
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{ExprId, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.DataType
 *
 * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
 */
-case class SparkScalarSubquery(
+case class ScalarSubquery(
    @transient executedPlan: SparkPlan,
    exprId: ExprId)
  extends SubqueryExpression with CodegenFallback {
@@ -58,14 +58,13 @@ case class SparkScalarSubquery(
/**
 * Convert the subquery from logical plan into executed plan.
 */
-private[sql] case class ConvertSubquery(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
  def apply(plan: SparkPlan): SparkPlan = {
    plan.transformAllExpressions {
      // Only scalar subquery will be executed separately, all others will be written as join.
-      case subquery: ScalarSubquery =>
+      case subquery: expressions.ScalarSubquery =>
        val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next()
        val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan)
-        SparkScalarSubquery(executedPlan, subquery.exprId)
+        ScalarSubquery(executedPlan, subquery.exprId)
    }
  }
}
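After this rename there are two classes called ScalarSubquery: the logical one in catalyst.expressions and the physical one defined here in execution. That is why the pattern above matches on expressions.ScalarSubquery while the result is built with the local physical class. A minimal sketch of the same disambiguation, using hypothetical packages:

package demo {
  package logical {
    case class ScalarSubquery(sql: String)
  }

  package physical {
    case class ScalarSubquery(plan: String)

    object Planner {
      // The sibling package is visible here simply as `logical`, so the
      // qualified pattern picks out the logical class, while the bare name
      // refers to the local physical one.
      def plan(e: Any): Any = e match {
        case s: logical.ScalarSubquery => ScalarSubquery("planned: " + s.sql)
        case other => other
      }
    }
  }
}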
