From 3b1649105869c72ccb16f86732e04829aaae0e93 Mon Sep 17 00:00:00 2001 From: frreiss Date: Mon, 16 May 2016 10:58:00 -0700 Subject: [PATCH 1/5] Commit before merge. --- .../sql/catalyst/optimizer/Optimizer.scala | 44 ++++++++++++- .../org/apache/spark/sql/SubquerySuite.scala | 61 +++++++++++++++++++ 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 350b60134e3e0..eb04e02f515db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1595,6 +1595,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { newExpression.asInstanceOf[E] } + /** + * Statically evaluate an expression containing one or more aggregates on an empty input. + */ + private def evalOnZeroTups(expr : Expression) : Option[Any] = { + // AggregateExpressions are Unevaluable, so we need to replace all aggregates + // in the expression with the value they would return for zero input tuples. + val rewrittenExpr = expr transform { + case a @ AggregateExpression(aggFunc, _, _, resultId) => + val resultLit = aggFunc.defaultResult match { + case Some(lit) => lit + case None => Literal.default(NullType) + } + Alias(resultLit, "aggVal") (exprId = resultId) + } + Option(rewrittenExpr.eval()) + } + /** * Construct a new child plan by left joining the given subqueries to a base plan. */ @@ -1603,9 +1620,32 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(query, conditions, _)) => + val aggOutputExpr = query.asInstanceOf[Aggregate].aggregateExpressions.head + val origOutput = query.output.head + + // Ensure the rewritten subquery returns the same result when a tuple from the + // outer query block does not join with the subquery block. + // val (outputExpr, rewrittenQuery) = aggFunc.defaultResult match { + val (outputExpr, rewrittenQuery) = evalOnZeroTups(aggOutputExpr) match { + case Some(value) => + val origExprId = origOutput.exprId + val newExprId = NamedExpression.newExprId + + // Renumber the original output, because the outer query refers to its ID. + val newQuery = query transformExpressions { + case Alias(c, n) => Alias(c, n)(exprId = newExprId) + } + val coalesceExpr = Alias( + Coalesce(Seq(newQuery.output.head, Literal(value))), + origOutput.name) (exprId = origExprId) + (coalesceExpr, newQuery) + + case None => (origOutput, query) + } + Project( - currentChild.output :+ query.output.head, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + currentChild.output :+ outputExpr, + Join(currentChild, rewrittenQuery, LeftOuter, conditions.reduceOption(And))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 17ac0c8c6e496..e47733aa20e0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -293,4 +293,65 @@ class SubquerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Nil) } + + test("COUNT bug in WHERE clause (Filter)") { + + // Case 1: Canonical example of the COUNT bug + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"), + Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) + + // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses + // a rewrite that is vulnerable to the COUNT bug + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + + // Case 3: COUNT bug without a COUNT aggregate + checkAnswer( + sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"), + Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) + + } + + test("COUNT bug in SELECT clause (Project)") { + checkAnswer( + sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"), + Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) + :: Row(null, 0) :: Row(6, 1) :: Nil) + } + + test("COUNT bug in HAVING clause (Filter)") { + checkAnswer( + sql("select l.a as grp_a from l group by l.a " + + "having (select count(*) from r where grp_a = r.c) = 0 " + + "order by grp_a"), + Row(null) :: Row(1) :: Nil) + } + + test("COUNT bug in Aggregate") { + checkAnswer( + sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " + + "from l group by l.a order by aval"), + Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) + } + + test("COUNT bug negative examples") { + // Case 1: Potential COUNT bug case that was working correctly prior to the fix + checkAnswer( + sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) + + // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"), + Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) + + // Case 3: COUNT inside aggregate expression but no COUNT bug. + checkAnswer( + sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"), + Nil) + + + } } From 1615d560310a59b08a4c03677dd53eb3b9b49e06 Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 19 May 2016 19:01:33 -0700 Subject: [PATCH 2/5] Second version of the updated rewrite --- .../sql/catalyst/optimizer/Optimizer.scala | 178 +++++++++++++++--- .../org/apache/spark/sql/SubquerySuite.scala | 35 ++-- 2 files changed, 173 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3829b48cee60c..80dc936075307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1647,23 +1647,136 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { newExpression.asInstanceOf[E] } + /** + * Statically evaluate an expression containing zero or more placeholders, given a set + * of bindings for placeholder values. + */ + private def evalExpr(expr : Expression, bindings : Map[Long, Option[Any]]) : Option[Any] = { + val rewrittenExpr = expr transform { + case r @ AttributeReference(_, dataType, _, _) => + bindings(r.exprId.id) match { + case Some(v) => Literal.create(v, dataType) + case None => Literal.default(NullType) + } + } + Option(rewrittenExpr.eval()) + } + /** * Statically evaluate an expression containing one or more aggregates on an empty input. */ - private def evalOnZeroTups(expr : Expression) : Option[Any] = { + private def evalAggOnZeroTups(expr : Expression) : Option[Any] = { // AggregateExpressions are Unevaluable, so we need to replace all aggregates // in the expression with the value they would return for zero input tuples. val rewrittenExpr = expr transform { case a @ AggregateExpression(aggFunc, _, _, resultId) => - val resultLit = aggFunc.defaultResult match { - case Some(lit) => lit - case None => Literal.default(NullType) - } - Alias(resultLit, "aggVal") (exprId = resultId) + aggFunc.defaultResult.getOrElse(Literal.default(NullType)) } Option(rewrittenExpr.eval()) } + /** + * Statically evaluate a scalar subquery on an empty input. + * + * WARNING: This method only covers subqueries that pass the checks under + * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in + * CheckAnalysis become less restrictive, this method will need to change. + */ + private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { + // Inputs to this method will start with a chain of zero or more SubqueryAlias + // and Project operators, followed by an optional Filter, followed by an + // Aggregate. Traverse the operators recursively. + def evalPlan(lp : LogicalPlan) : Map[Long, Option[Any]] = { + lp match { + case SubqueryAlias(_, child) => evalPlan(child) + case Filter(condition, child) => + val bindings = evalPlan(child) + if (bindings.size == 0) bindings + else { + val exprResult = evalExpr(condition, bindings).getOrElse(false) + .asInstanceOf[Boolean] + if (exprResult) bindings else Map() + } + + case Project(projectList, child) => + val bindings = evalPlan(child) + if (bindings.size == 0) { + bindings + } else { + projectList.map(ne => (ne.exprId.id, evalExpr(ne, bindings))).toMap + } + + case Aggregate(_, aggExprs, _) => + // Some of the expressions under the Aggregate node are the join columns + // for joining with the outer query block. Fill those expressions in with + // nulls and statically evaluate the remainder. + aggExprs.map(ne => ne match { + case AttributeReference(_, _, _, _) => (ne.exprId.id, None) + case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId.id, None) + case _ => (ne.exprId.id, evalAggOnZeroTups(ne)) + }).toMap + + case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + } + } + + val resultMap = evalPlan(plan) + + // By convention, the scalar subquery result is the leftmost field. + resultMap(plan.output.head.exprId.id) + } + + /** + * Split the plan for a scalar subquery into the parts above the Aggregate node + * (first part of returned value) and the parts below the Aggregate node, including + * the Aggregate (second part of returned value) + */ + private def splitSubquery(plan : LogicalPlan) : Tuple2[Seq[LogicalPlan], Aggregate] = { + var topPart = List[LogicalPlan]() + var bottomPart : LogicalPlan = plan + while (! bottomPart.isInstanceOf[Aggregate]) { + topPart = bottomPart :: topPart + bottomPart = bottomPart.children.head + } + (topPart, bottomPart.asInstanceOf[Aggregate]) + } + + /** + * Rewrite the nodes above the Aggregate in a subquery so that they generate an + * auxiliary column "isFiltered" + * @param subqueryPlan plan before rewrite + * @param filteredId expression ID for the "isFiltered" column + */ + private def addIsFiltered(subqueryPlan : LogicalPlan, filteredId : ExprId) : LogicalPlan = { + val isFilteredRef = AttributeReference("isFiltered", BooleanType)(exprId = filteredId) + val (topPart, aggNode) = splitSubquery(subqueryPlan) + var rewrittenQuery: LogicalPlan = null + if (topPart.size > 0 && topPart.head.isInstanceOf[Filter]) { + // Correlated subquery has a HAVING clause + // Rewrite the Filter into a Project that returns the value of the filtering predicate + val origFilter = topPart.head.asInstanceOf[Filter] + var topRemainder = topPart.tail + val newProjectList = + origFilter.output :+ Alias(origFilter.condition, "isFiltered")(exprId = filteredId) + val filterAsProject = Project(newProjectList, origFilter.child) + + rewrittenQuery = filterAsProject + while (topRemainder.size > 0) { + rewrittenQuery = topRemainder.head match { + case Project(origList, _) => Project(origList :+ isFilteredRef, rewrittenQuery) + case SubqueryAlias(alias, _) => SubqueryAlias(alias, rewrittenQuery) + } + topRemainder = topRemainder.tail + } + } else { + // Correlated subquery without HAVING clause + // Add an additional Project that adds a constant value for "isFiltered" + rewrittenQuery = Project(subqueryPlan.output :+ Alias(Literal(false), "isFiltered") + (exprId = filteredId), subqueryPlan) + } + return rewrittenQuery + } + /** * Construct a new child plan by left joining the given subqueries to a base plan. */ @@ -1672,32 +1785,39 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(query, conditions, _)) => - val aggOutputExpr = query.asInstanceOf[Aggregate].aggregateExpressions.head val origOutput = query.output.head - // Ensure the rewritten subquery returns the same result when a tuple from the - // outer query block does not join with the subquery block. - // val (outputExpr, rewrittenQuery) = aggFunc.defaultResult match { - val (outputExpr, rewrittenQuery) = evalOnZeroTups(aggOutputExpr) match { - case Some(value) => - val origExprId = origOutput.exprId - val newExprId = NamedExpression.newExprId - - // Renumber the original output, because the outer query refers to its ID. - val newQuery = query transformExpressions { - case Alias(c, n) => Alias(c, n)(exprId = newExprId) - } - val coalesceExpr = Alias( - Coalesce(Seq(newQuery.output.head, Literal(value))), - origOutput.name) (exprId = origExprId) - (coalesceExpr, newQuery) - - case None => (origOutput, query) - } + val resultWithZeroTups = evalSubqueryOnZeroTups(query) + if (resultWithZeroTups.isEmpty) { + Project( + currentChild.output :+ origOutput, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + } else { + // Renumber the original output, because the outer query refers to its ID. + val newExprId = NamedExpression.newExprId + val renumberedQuery = query transformExpressions { + case a@Alias(c, n) if a.exprId == origOutput.exprId => Alias(c, n)(exprId = newExprId) + } - Project( - currentChild.output :+ outputExpr, - Join(currentChild, rewrittenQuery, LeftOuter, conditions.reduceOption(And))) + val filteredId = NamedExpression.newExprId + val isFilteredRef = AttributeReference("isFiltered", BooleanType)(exprId = filteredId) + val withIsFiltered = addIsFiltered(renumberedQuery, filteredId) + val aggValRef = renumberedQuery.output.head + + // CASE WHEN isFiltered IS NULL THEN COALESCE(aggVal, resultOnZeroTups) + // WHEN isFiltered THEN CAST(null AS ) + // ELSE aggVal END + val caseExpr = Alias(CaseWhen( + Seq((IsNull(isFilteredRef), Coalesce(Seq(aggValRef, + Literal(resultWithZeroTups.getOrElse(null))))), + (isFilteredRef, Literal(null, aggValRef.dataType))), + aggValRef), + origOutput.name)(exprId = origOutput.exprId) + + Project( + currentChild.output :+ caseExpr, + Join(currentChild, withIsFiltered, LeftOuter, conditions.reduceOption(And))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index e47733aa20e0f..f28503cae4054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -294,34 +294,30 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(3) :: Nil) } - test("COUNT bug in WHERE clause (Filter)") { - + test("SPARK-15370: COUNT bug in WHERE clause (Filter)") { // Case 1: Canonical example of the COUNT bug checkAnswer( sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"), Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) - // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses // a rewrite that is vulnerable to the COUNT bug checkAnswer( sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"), Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) - // Case 3: COUNT bug without a COUNT aggregate checkAnswer( sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"), Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) - } - test("COUNT bug in SELECT clause (Project)") { + test("SPARK-15370: COUNT bug in SELECT clause (Project)") { checkAnswer( sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"), Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) :: Row(null, 0) :: Row(6, 1) :: Nil) } - test("COUNT bug in HAVING clause (Filter)") { + test("SPARK-15370: COUNT bug in HAVING clause (Filter)") { checkAnswer( sql("select l.a as grp_a from l group by l.a " + "having (select count(*) from r where grp_a = r.c) = 0 " + @@ -329,29 +325,46 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(null) :: Row(1) :: Nil) } - test("COUNT bug in Aggregate") { + test("SPARK-15370: COUNT bug in Aggregate") { checkAnswer( sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " + "from l group by l.a order by aval"), Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) } - test("COUNT bug negative examples") { + test("SPARK-15370: COUNT bug negative examples") { // Case 1: Potential COUNT bug case that was working correctly prior to the fix checkAnswer( sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"), Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) - // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. checkAnswer( sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"), Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) - // Case 3: COUNT inside aggregate expression but no COUNT bug. checkAnswer( sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"), Nil) + } + test("SPARK-15370: COUNT bug in subquery in subquery in subquery") { + checkAnswer( + sql("""select l.a from l + |where ( + | select cntPlusOne + 1 as cntPlusTwo from ( + | select cnt + 1 as cntPlusOne from ( + | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0 + | ) + | ) + |) = 2""".stripMargin), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + test("SPARK-15370: COUNT bug with nasty predicate expr") { + checkAnswer( + sql("select l.a from l where " + + "(select case when count(*) = 1 then null else count(*) end as cnt " + + "from r where l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) } } From 8cd2877179dded4557c8da92e5b16011637289b0 Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 9 Jun 2016 22:02:47 -0700 Subject: [PATCH 3/5] Addressing additional corner cases and review comments. --- .../sql/catalyst/expressions/predicates.scala | 7 +- .../sql/catalyst/optimizer/Optimizer.scala | 186 ++++++++++-------- .../org/apache/spark/sql/SubquerySuite.scala | 7 + 3 files changed, 120 insertions(+), 80 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8a6cf53782b91..a3b098afe5728 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,8 +69,11 @@ trait PredicateHelper { protected def replaceAlias( condition: Expression, aliases: AttributeMap[Expression]): Expression = { - condition.transform { - case a: Attribute => aliases.getOrElse(a, a) + // Use transformUp to prevent infinite recursion when the replacement expression + // redefines the same ExprId, + condition.transformUp { + case a: Attribute => + aliases.getOrElse(a, a) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 500e08ef0fbe7..f583ea92d69df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -527,7 +527,8 @@ object CollapseProject extends Rule[LogicalPlan] { // Substitute any attributes that are produced by the lower projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - val rewrittenUpper = upper.map(_.transform { + // Use transformUp to prevent infinite recursion. + val rewrittenUpper = upper.map(_.transformUp { case a: Attribute => aliases.getOrElse(a, a) }) // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. @@ -1698,10 +1699,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { * Statically evaluate an expression containing zero or more placeholders, given a set * of bindings for placeholder values. */ - private def evalExpr(expr : Expression, bindings : Map[Long, Option[Any]]) : Option[Any] = { + private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { val rewrittenExpr = expr transform { case r @ AttributeReference(_, dataType, _, _) => - bindings(r.exprId.id) match { + bindings(r.exprId) match { case Some(v) => Literal.create(v, dataType) case None => Literal.default(NullType) } @@ -1712,12 +1713,15 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { /** * Statically evaluate an expression containing one or more aggregates on an empty input. */ - private def evalAggOnZeroTups(expr : Expression) : Option[Any] = { + private def evalAggOnZeroTups(expr: Expression) : Option[Any] = { // AggregateExpressions are Unevaluable, so we need to replace all aggregates // in the expression with the value they would return for zero input tuples. + // Also replace attribute refs (for example, for grouping columns) with NULL. val rewrittenExpr = expr transform { case a @ AggregateExpression(aggFunc, _, _, resultId) => aggFunc.defaultResult.getOrElse(Literal.default(NullType)) + + case AttributeReference(_, _, _, _) => Literal.default(NullType) } Option(rewrittenExpr.eval()) } @@ -1733,24 +1737,24 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // Inputs to this method will start with a chain of zero or more SubqueryAlias // and Project operators, followed by an optional Filter, followed by an // Aggregate. Traverse the operators recursively. - def evalPlan(lp : LogicalPlan) : Map[Long, Option[Any]] = { + def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = { lp match { case SubqueryAlias(_, child) => evalPlan(child) case Filter(condition, child) => val bindings = evalPlan(child) - if (bindings.size == 0) bindings + if (bindings.isEmpty) bindings else { val exprResult = evalExpr(condition, bindings).getOrElse(false) .asInstanceOf[Boolean] - if (exprResult) bindings else Map() + if (exprResult) bindings else Map.empty } case Project(projectList, child) => val bindings = evalPlan(child) - if (bindings.size == 0) { + if (bindings.isEmpty) { bindings } else { - projectList.map(ne => (ne.exprId.id, evalExpr(ne, bindings))).toMap + projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap } case Aggregate(_, aggExprs, _) => @@ -1758,9 +1762,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // for joining with the outer query block. Fill those expressions in with // nulls and statically evaluate the remainder. aggExprs.map(ne => ne match { - case AttributeReference(_, _, _, _) => (ne.exprId.id, None) - case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId.id, None) - case _ => (ne.exprId.id, evalAggOnZeroTups(ne)) + case AttributeReference(_, _, _, _) => (ne.exprId, None) + case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None) + case _ => (ne.exprId, evalAggOnZeroTups(ne)) }).toMap case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") @@ -1770,60 +1774,49 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. - resultMap(plan.output.head.exprId.id) + resultMap(plan.output.head.exprId) } /** - * Split the plan for a scalar subquery into the parts above the Aggregate node - * (first part of returned value) and the parts below the Aggregate node, including - * the Aggregate (second part of returned value) + * Split the plan for a scalar subquery into the parts above the innermost query block + * (first part of returned value), the HAVING clause of the innermost query block + * (optional second part) and the parts below the HAVING CLAUSE (third part). */ - private def splitSubquery(plan : LogicalPlan) : Tuple2[Seq[LogicalPlan], Aggregate] = { - var topPart = List[LogicalPlan]() + private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { + val topPart = ArrayBuffer.empty[LogicalPlan] var bottomPart : LogicalPlan = plan - while (! bottomPart.isInstanceOf[Aggregate]) { - topPart = bottomPart :: topPart - bottomPart = bottomPart.children.head - } - (topPart, bottomPart.asInstanceOf[Aggregate]) - } + while (true) { + bottomPart match { + case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) => + return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate]) - /** - * Rewrite the nodes above the Aggregate in a subquery so that they generate an - * auxiliary column "isFiltered" - * @param subqueryPlan plan before rewrite - * @param filteredId expression ID for the "isFiltered" column - */ - private def addIsFiltered(subqueryPlan : LogicalPlan, filteredId : ExprId) : LogicalPlan = { - val isFilteredRef = AttributeReference("isFiltered", BooleanType)(exprId = filteredId) - val (topPart, aggNode) = splitSubquery(subqueryPlan) - var rewrittenQuery: LogicalPlan = null - if (topPart.size > 0 && topPart.head.isInstanceOf[Filter]) { - // Correlated subquery has a HAVING clause - // Rewrite the Filter into a Project that returns the value of the filtering predicate - val origFilter = topPart.head.asInstanceOf[Filter] - var topRemainder = topPart.tail - val newProjectList = - origFilter.output :+ Alias(origFilter.condition, "isFiltered")(exprId = filteredId) - val filterAsProject = Project(newProjectList, origFilter.child) - - rewrittenQuery = filterAsProject - while (topRemainder.size > 0) { - rewrittenQuery = topRemainder.head match { - case Project(origList, _) => Project(origList :+ isFilteredRef, rewrittenQuery) - case SubqueryAlias(alias, _) => SubqueryAlias(alias, rewrittenQuery) - } - topRemainder = topRemainder.tail + case aggPart@Aggregate(_, _, _) => + // No HAVING clause + return (topPart, None, aggPart) + + case p@Project(_, child) => + topPart += p + bottomPart = child + + case s@SubqueryAlias(_, child) => + topPart += s + bottomPart = child + + case Filter(_, op@_) => + sys.error(s"Correlated subquery has unexpected operator $op below filter") + + case op@_ => sys.error(s"Unexpected operator $op in correlated subquery") } - } else { - // Correlated subquery without HAVING clause - // Add an additional Project that adds a constant value for "isFiltered" - rewrittenQuery = Project(subqueryPlan.output :+ Alias(Literal(false), "isFiltered") - (exprId = filteredId), subqueryPlan) } - return rewrittenQuery + + sys.error("This line should be unreachable") } + + + // Name of generated column used in rewrite below + val ALWAYS_TRUE_COLNAME = "alwaysTrue" + /** * Construct a new child plan by left joining the given subqueries to a base plan. */ @@ -1836,34 +1829,71 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { val resultWithZeroTups = evalSubqueryOnZeroTups(query) if (resultWithZeroTups.isEmpty) { + // CASE 1: Subquery guaranteed not to have the COUNT bug Project( currentChild.output :+ origOutput, Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) } else { - // Renumber the original output, because the outer query refers to its ID. - val newExprId = NamedExpression.newExprId - val renumberedQuery = query transformExpressions { - case a@Alias(c, n) if a.exprId == origOutput.exprId => Alias(c, n)(exprId = newExprId) - } + // Subquery might have the COUNT bug. Add appropriate corrections. + val (topPart, havingNode, aggNode) = splitSubquery(query) + + // The next two cases add a leading column to the outer join input to make it + // possible to distinguish between the case when no tuples join and the case + // when the tuple that joins contains null values. + // The leading column always has the value TRUE. + val alwaysTrueExprId = NamedExpression.newExprId + val alwaysTrueExpr = Alias(Literal.TrueLiteral, + ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) + val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, + BooleanType)(exprId = alwaysTrueExprId) + + val aggValRef = query.output.head + + if (!havingNode.isDefined) { + // CASE 2: Subquery with no HAVING clause + Project( + currentChild.output :+ + Alias( + If(IsNull(alwaysTrueRef), + Literal(resultWithZeroTups.get, origOutput.dataType), + aggValRef), origOutput.name)(exprId = origOutput.exprId), + Join(currentChild, + Project(query.output :+ alwaysTrueExpr, query), + LeftOuter, conditions.reduceOption(And))) - val filteredId = NamedExpression.newExprId - val isFilteredRef = AttributeReference("isFiltered", BooleanType)(exprId = filteredId) - val withIsFiltered = addIsFiltered(renumberedQuery, filteredId) - val aggValRef = renumberedQuery.output.head - - // CASE WHEN isFiltered IS NULL THEN COALESCE(aggVal, resultOnZeroTups) - // WHEN isFiltered THEN CAST(null AS ) - // ELSE aggVal END - val caseExpr = Alias(CaseWhen( - Seq((IsNull(isFilteredRef), Coalesce(Seq(aggValRef, - Literal(resultWithZeroTups.getOrElse(null))))), - (isFilteredRef, Literal(null, aggValRef.dataType))), - aggValRef), - origOutput.name)(exprId = origOutput.exprId) + } else { + // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. + // Need to modify any operators below the join to pass through all columns + // referenced in the HAVING clause. + var subqueryRoot : UnaryNode = aggNode + val havingInputs : Seq[NamedExpression] = aggNode.output + + topPart.reverse.foreach( + _ match { + case Project(projList, _) => + subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) + case s@SubqueryAlias(alias, _) => subqueryRoot = SubqueryAlias(alias, subqueryRoot) + case op@_ => sys.error(s"Unexpected operator $op in corelated subquery") + } + ) + + // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups + // WHEN NOT (original HAVING clause expr) THEN CAST(null AS ) + // ELSE (aggregate value) END AS (original column name) + val caseExpr = Alias(CaseWhen( + Seq[(Expression, Expression)] ( + (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, origOutput.dataType)), + (Not(havingNode.get.condition), Literal(null, aggValRef.dataType)) + ), aggValRef + ), origOutput.name) (exprId = origOutput.exprId) + + Project( + currentChild.output :+ caseExpr, + Join(currentChild, + Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), + LeftOuter, conditions.reduceOption(And))) - Project( - currentChild.output :+ caseExpr, - Join(currentChild, withIsFiltered, LeftOuter, conditions.reduceOption(And))) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 9122b1923a60b..9e9761547183f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -367,4 +367,11 @@ class SubquerySuite extends QueryTest with SharedSQLContext { "from r where l.a = r.c) = 0"), Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) } + + test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { + checkAnswer( + sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: + Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) + } } From e5c592032b5604a8f8f10326ecd10ade22b5dc43 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 12 Jun 2016 16:43:30 -0700 Subject: [PATCH 4/5] Style fixes --- .../sql/catalyst/optimizer/Optimizer.scala | 118 +++++++++--------- 1 file changed, 56 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f583ea92d69df..e797608669078 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1701,9 +1701,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { */ private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { val rewrittenExpr = expr transform { - case r @ AttributeReference(_, dataType, _, _) => + case r: AttributeReference => bindings(r.exprId) match { - case Some(v) => Literal.create(v, dataType) + case Some(v) => Literal.create(v, r.dataType) case None => Literal.default(NullType) } } @@ -1721,7 +1721,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case a @ AggregateExpression(aggFunc, _, _, resultId) => aggFunc.defaultResult.getOrElse(Literal.default(NullType)) - case AttributeReference(_, _, _, _) => Literal.default(NullType) + case _: AttributeReference => Literal.default(NullType) } Option(rewrittenExpr.eval()) } @@ -1737,38 +1737,36 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // Inputs to this method will start with a chain of zero or more SubqueryAlias // and Project operators, followed by an optional Filter, followed by an // Aggregate. Traverse the operators recursively. - def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = { - lp match { - case SubqueryAlias(_, child) => evalPlan(child) - case Filter(condition, child) => - val bindings = evalPlan(child) - if (bindings.isEmpty) bindings - else { - val exprResult = evalExpr(condition, bindings).getOrElse(false) - .asInstanceOf[Boolean] - if (exprResult) bindings else Map.empty - } + def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { + case SubqueryAlias(_, child) => evalPlan(child) + case Filter(condition, child) => + val bindings = evalPlan(child) + if (bindings.isEmpty) bindings + else { + val exprResult = evalExpr(condition, bindings).getOrElse(false) + .asInstanceOf[Boolean] + if (exprResult) bindings else Map.empty + } - case Project(projectList, child) => - val bindings = evalPlan(child) - if (bindings.isEmpty) { - bindings - } else { - projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap - } + case Project(projectList, child) => + val bindings = evalPlan(child) + if (bindings.isEmpty) { + bindings + } else { + projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap + } - case Aggregate(_, aggExprs, _) => - // Some of the expressions under the Aggregate node are the join columns - // for joining with the outer query block. Fill those expressions in with - // nulls and statically evaluate the remainder. - aggExprs.map(ne => ne match { - case AttributeReference(_, _, _, _) => (ne.exprId, None) - case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None) - case _ => (ne.exprId, evalAggOnZeroTups(ne)) - }).toMap - - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") - } + case Aggregate(_, aggExprs, _) => + // Some of the expressions under the Aggregate node are the join columns + // for joining with the outer query block. Fill those expressions in with + // nulls and statically evaluate the remainder. + aggExprs.map { + case ref: AttributeReference => (ref.exprId, None) + case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None) + case ne => (ne.exprId, evalAggOnZeroTups(ne)) + }.toMap + + case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") } val resultMap = evalPlan(plan) @@ -1784,36 +1782,34 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { */ private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { val topPart = ArrayBuffer.empty[LogicalPlan] - var bottomPart : LogicalPlan = plan + var bottomPart: LogicalPlan = plan while (true) { bottomPart match { - case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) => - return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate]) + case havingPart @ Filter(_, aggPart: Aggregate) => + return (topPart, Option(havingPart), aggPart) - case aggPart@Aggregate(_, _, _) => + case aggPart: Aggregate => // No HAVING clause return (topPart, None, aggPart) - case p@Project(_, child) => + case p @ Project(_, child) => topPart += p bottomPart = child - case s@SubqueryAlias(_, child) => + case s @ SubqueryAlias(_, child) => topPart += s bottomPart = child - case Filter(_, op@_) => + case Filter(_, op) => sys.error(s"Correlated subquery has unexpected operator $op below filter") - case op@_ => sys.error(s"Unexpected operator $op in correlated subquery") + case op @ _ => sys.error(s"Unexpected operator $op in correlated subquery") } } sys.error("This line should be unreachable") } - - // Name of generated column used in rewrite below val ALWAYS_TRUE_COLNAME = "alwaysTrue" @@ -1849,13 +1845,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { val aggValRef = query.output.head - if (!havingNode.isDefined) { + if (havingNode.isEmpty) { // CASE 2: Subquery with no HAVING clause Project( currentChild.output :+ Alias( If(IsNull(alwaysTrueRef), - Literal(resultWithZeroTups.get, origOutput.dataType), + Literal.create(resultWithZeroTups.get, origOutput.dataType), aggValRef), origOutput.name)(exprId = origOutput.exprId), Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), @@ -1865,27 +1861,25 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. // Need to modify any operators below the join to pass through all columns // referenced in the HAVING clause. - var subqueryRoot : UnaryNode = aggNode - val havingInputs : Seq[NamedExpression] = aggNode.output - - topPart.reverse.foreach( - _ match { - case Project(projList, _) => - subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s@SubqueryAlias(alias, _) => subqueryRoot = SubqueryAlias(alias, subqueryRoot) - case op@_ => sys.error(s"Unexpected operator $op in corelated subquery") - } - ) + var subqueryRoot: UnaryNode = aggNode + val havingInputs: Seq[NamedExpression] = aggNode.output + + topPart.reverse.foreach { + case Project(projList, _) => + subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) + case s @ SubqueryAlias(alias, _) => + subqueryRoot = SubqueryAlias(alias, subqueryRoot) + case op => sys.error(s"Unexpected operator $op in corelated subquery") + } // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups // WHEN NOT (original HAVING clause expr) THEN CAST(null AS ) // ELSE (aggregate value) END AS (original column name) - val caseExpr = Alias(CaseWhen( - Seq[(Expression, Expression)] ( - (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, origOutput.dataType)), - (Not(havingNode.get.condition), Literal(null, aggValRef.dataType)) - ), aggValRef - ), origOutput.name) (exprId = origOutput.exprId) + val caseExpr = Alias(CaseWhen(Seq( + (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)), + (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), + aggValRef), + origOutput.name)(exprId = origOutput.exprId) Project( currentChild.output :+ caseExpr, From 30dd0bd7d560151085e53667fcc4f6a8895844ed Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 12 Jun 2016 16:57:18 -0700 Subject: [PATCH 5/5] Some simplification --- .../sql/catalyst/optimizer/Optimizer.scala | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 91e65f665e8d5..7b9b21f416415 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1716,10 +1716,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { case (p, PredicateSubquery(sub, conditions, _, _)) => - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceOption(And), p) + val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) case (p, Not(PredicateSubquery(sub, conditions, false, _))) => - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceOption(And), p) + val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) case (p, Not(PredicateSubquery(sub, conditions, true, _))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr @@ -1728,11 +1728,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions.reduceLeftOption(And), p) + val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or) Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get))) case (p, predicate) => - val (newCond, inputPlan) = rewriteExistentialExpr(Option(predicate), p) + val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) } } @@ -1745,22 +1745,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { * are blocked in the Analyzer. */ private def rewriteExistentialExpr( - expr: Option[Expression], + exprs: Seq[Expression], plan: LogicalPlan): (Option[Expression], LogicalPlan) = { var newPlan = plan - expr match { - case Some(e) => - val newExpr = e transformUp { - case PredicateSubquery(sub, conditions, nullAware, _) => - // TODO: support null-aware join - val exists = AttributeReference("exists", BooleanType, nullable = false)() - newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) - exists + val newExprs = exprs.map { e => + e transformUp { + case PredicateSubquery(sub, conditions, nullAware, _) => + // TODO: support null-aware join + val exists = AttributeReference("exists", BooleanType, nullable = false)() + newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) + exists } - (Option(newExpr), newPlan) - case None => - (expr, plan) } + (newExprs.reduceOption(And), newPlan) } }