diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 2482fdf2de5ea..a8f5e1f63d4c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.DataType
 /**
  * An interface for subquery that is used in expressions.
  */
-abstract class SubqueryExpression extends LeafExpression{
+abstract class SubqueryExpression extends LeafExpression {
 
   /**
    * The logical plan of the query.
@@ -33,9 +33,8 @@ abstract class SubqueryExpression extends LeafExpression{
   def query: LogicalPlan
 
   /**
-   * The underline plan for the query, could be logical plan or physical plan.
-   *
-   * This is used to generate tree string.
+   * Either a logical plan or a physical plan. The generated tree string (explain output) uses this
+   * field to explain the subquery.
    */
   def plan: QueryPlan[_]
 
@@ -48,7 +47,9 @@ abstract class SubqueryExpression extends LeafExpression{
 /**
  * A subquery that will return only one row and one column.
  *
- * This is not evaluable, it should be converted into SparkScalaSubquery.
+ * This will be converted into [[execution.ScalarSubquery]] during physical planning.
+ *
+ * Note: `exprId` is used to have a unique name in the explain string output.
  */
 case class ScalarSubquery(
     query: LogicalPlan,
@@ -63,7 +64,7 @@ case class ScalarSubquery(
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (query.schema.length != 1) {
-      TypeCheckResult.TypeCheckFailure("Scalar subquery can only have 1 column, but got " +
+      TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
         query.schema.length.toString)
     } else {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1f61aac2b1381..f1f438075164e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -90,16 +90,16 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
     Batch("LocalRelation", FixedPoint(100),
       ConvertToLocalRelation) ::
     Batch("Subquery", Once,
-      Subquery) :: Nil
+      OptimizeSubqueries) :: Nil
   }
 
   /**
    * Optimize all the subqueries inside expression.
    */
-  object Subquery extends Rule[LogicalPlan] {
+  object OptimizeSubqueries extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case subquery: SubqueryExpression =>
-        subquery.withNewPlan(execute(subquery.query))
+        subquery.withNewPlan(Optimizer.this.execute(subquery.query))
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
index 9d80d4b2ce5b3..ed7121831ac29 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
@@ -204,47 +204,8 @@ class CatalystQlSuite extends PlanTest {
   }
 
   test("subquery") {
-    comparePlans(
-      parser.parsePlan("select (select max(b) from s) ss from t"),
-      Project(
-        UnresolvedAlias(
-          Alias(
-            ScalarSubquery(
-              Project(
-                UnresolvedAlias(
-                  UnresolvedFunction("max", UnresolvedAttribute("b") :: Nil, false)) :: Nil,
-                UnresolvedRelation(TableIdentifier("s")))),
-            "ss")(ExprId(0))) :: Nil,
-        UnresolvedRelation(TableIdentifier("t"))))
-    comparePlans(
-      parser.parsePlan("select * from t where a = (select b from s)"),
-      Project(
-        UnresolvedAlias(
-          UnresolvedStar(None)) :: Nil,
-        Filter(
-          EqualTo(
-            UnresolvedAttribute("a"),
-            ScalarSubquery(
-              Project(
-                UnresolvedAlias(
-                  UnresolvedAttribute("b")) :: Nil,
-                UnresolvedRelation(TableIdentifier("s"))))),
-          UnresolvedRelation(TableIdentifier("t")))))
-    comparePlans(
-      parser.parsePlan("select * from t group by g having a > (select b from s)"),
-      Filter(
-        Cast(
-          GreaterThan(
-            UnresolvedAttribute("a"),
-            ScalarSubquery(
-              Project(
-                UnresolvedAlias(
-                  UnresolvedAttribute("b")) :: Nil,
-                UnresolvedRelation(TableIdentifier("s"))))),
-          BooleanType),
-        Aggregate(
-          UnresolvedAttribute("g") :: Nil,
-          UnresolvedAlias(UnresolvedStar(None)) :: Nil,
-          UnresolvedRelation(TableIdentifier("t")))))
+    parser.parsePlan("select (select max(b) from s) ss from t")
+    parser.parsePlan("select * from t where a = (select b from s)")
+    parser.parsePlan("select * from t group by g having a > (select b from s)")
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index de10ba9c91372..ca6dcd8bdfb84 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -114,10 +114,15 @@ class AnalysisErrorSuite extends AnalysisTest {
   val dateLit = Literal.create(null, DateType)
 
   errorTest(
-    "invalid scalar subquery",
+    "scalar subquery with 2 columns",
     testRelation.select(
       (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)),
-    "Scalar subquery can only have 1 column, but got 2" :: Nil)
+    "Scalar subquery must return only one column, but got 2" :: Nil)
+
+  errorTest(
+    "scalar subquery with no column",
+    testRelation.select(ScalarSubquery(LocalRelation()).as('a)),
+    "Scalar subquery must return only one column, but got 0" :: Nil)
 
   errorTest(
     "single invalid type, single arg",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a4ce0eb592a18..55325c1662e2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -884,7 +884,7 @@ class SQLContext private[sql](
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
     val batches = Seq(
-      Batch("Subquery", Once, ConvertSubquery(self)),
+      Batch("Subquery", Once, PlanSubqueries(self)),
       Batch("Add exchange", Once, EnsureRequirements(self)),
       Batch("Whole stage codegen", Once, CollapseCodegenStages(self))
     )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ef08365a583a8..ad29942420899 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -127,8 +127,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     doPrepare()
 
     // collect all the subqueries and submit jobs to execute them in background
-    val queryResults = ArrayBuffer[(SparkScalarSubquery, Future[Array[InternalRow]])]()
-    val allSubqueries = expressions.flatMap(_.collect {case e: SparkScalarSubquery => e})
+    val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
+    val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
     allSubqueries.foreach { e =>
       val futureResult = Future {
         e.plan.executeCollect()
@@ -146,9 +146,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
         sys.error(s"Scalar subquery should return at most one row, but got ${rows.length}: " +
           s"${e.query.treeString}")
       }
-      // Analyzer will make sure that it only return on column
       if (rows.length > 0) {
+        assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
         e.updateResult(rows(0).get(0, e.dataType))
+      } else {
+        // the result should be null; since the expression already has null as its default
+        // value, we don't need to update it.
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 78fab20e8b652..67426dde6b5cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExprId, ScalarSubquery, SubqueryExpression}
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{ExprId, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.DataType
  *
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
  */
-case class SparkScalarSubquery(
+case class ScalarSubquery(
     @transient executedPlan: SparkPlan,
     exprId: ExprId)
   extends SubqueryExpression with CodegenFallback {
@@ -58,14 +58,13 @@
 /**
  * Convert the subquery from logical plan into executed plan.
  */
-private[sql] case class ConvertSubquery(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
-      // Only scalar subquery will be executed separately, all others will be written as join.
-      case subquery: ScalarSubquery =>
+      case subquery: expressions.ScalarSubquery =>
         val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next()
         val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan)
-        SparkScalarSubquery(executedPlan, subquery.exprId)
+        ScalarSubquery(executedPlan, subquery.exprId)
     }
   }
 }
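
Reviewer note: the interesting move in the OptimizeSubqueries hunk above is that expressions can embed whole plans, so the optimizer must re-enter itself on every nested subquery plan; spelling the call as `Optimizer.this.execute(subquery.query)` makes explicit that the nested plan goes through the enclosing RuleExecutor's execute, not through anything defined on the Rule object itself. Below is a minimal, self-contained sketch of just that recursion pattern; every name in it (Expr, Plan, ScalarSub, ToyOptimizer) is a hypothetical toy stand-in, not Spark's actual Catalyst machinery.

    // Toy expression/plan trees; hypothetical stand-ins for Catalyst's.
    sealed trait Expr
    final case class Lit(v: Int) extends Expr
    final case class Add(l: Expr, r: Expr) extends Expr
    final case class ScalarSub(plan: Plan) extends Expr // an expression embedding a whole plan

    final case class Plan(output: Expr) {
      // Bottom-up expression rewrite, a toy analogue of transformAllExpressions.
      def transformExpressions(f: PartialFunction[Expr, Expr]): Plan = {
        def rewrite(e: Expr): Expr = {
          val withNewChildren = e match {
            case Add(l, r) => Add(rewrite(l), rewrite(r))
            case other => other
          }
          f.applyOrElse(withNewChildren, identity[Expr])
        }
        Plan(rewrite(output))
      }
    }

    object ToyOptimizer {
      // One example rule: fold an Add of two literals into a literal.
      private def foldConstants(p: Plan): Plan = p.transformExpressions {
        case Add(Lit(a), Lit(b)) => Lit(a + b)
      }

      def execute(p: Plan): Plan = {
        // Re-enter the whole optimizer on nested subquery plans: the same move
        // as subquery.withNewPlan(Optimizer.this.execute(subquery.query)).
        val withOptimizedSubqueries = p.transformExpressions {
          case ScalarSub(inner) => ScalarSub(execute(inner))
        }
        foldConstants(withOptimizedSubqueries)
      }
    }

For example, ToyOptimizer.execute(Plan(ScalarSub(Plan(Add(Lit(1), Lit(2)))))) returns Plan(ScalarSub(Plan(Lit(3)))): the nested plan receives the full rule pipeline even though it only appears inside an expression, which is what the Subquery batch achieves for logical plans.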
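
Similarly, the SparkPlan.prepare() hunk first submits every scalar subquery as a background job and only then blocks for the results, enforcing the at-most-one-row and exactly-one-column invariants before patching the value into the expression. Here is a rough sketch of that control flow under the same caveat: Row, PhysicalPlan, ScalarSubqueryExpr and PrepareSketch are hypothetical stand-ins, not Spark's InternalRow, SparkPlan or the actual prepare() implementation.

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration

    // Hypothetical stand-ins for InternalRow and a physical plan.
    final case class Row(values: Seq[Any]) {
      def numFields: Int = values.length
      def get(i: Int): Any = values(i)
    }

    trait PhysicalPlan {
      def executeCollect(): Array[Row]
    }

    // A scalar subquery expression: evaluates to null until its plan has run.
    final class ScalarSubqueryExpr(val plan: PhysicalPlan) {
      private var result: Any = null // null is already the default, as the new else branch notes
      def updateResult(v: Any): Unit = { result = v }
      def eval(): Any = result
    }

    object PrepareSketch {
      def prepare(subqueries: Seq[ScalarSubqueryExpr]): Unit = {
        // Kick off every subquery as a background job first...
        val queryResults = subqueries.map(e => e -> Future(e.plan.executeCollect()))
        // ...then block for each result and patch it into the expression.
        queryResults.foreach { case (e, futureResult) =>
          val rows = Await.result(futureResult, Duration.Inf)
          if (rows.length > 1) {
            sys.error(s"Scalar subquery should return at most one row, but got ${rows.length}")
          }
          if (rows.nonEmpty) {
            assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
            e.updateResult(rows(0).get(0))
          }
          // Zero rows: keep the default null, matching the else branch in the patch.
        }
      }
    }

With val sub = new ScalarSubqueryExpr(new PhysicalPlan { def executeCollect() = Array(Row(Seq(42))) }), calling PrepareSketch.prepare(Seq(sub)) makes sub.eval() return 42, while a subquery producing zero rows leaves eval() at null.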