From af6cb56d59ca9499ab374be63c95730bf0758d6a Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 Apr 2016 13:49:31 -0700
Subject: [PATCH 1/5] fix subquery resolution

---
 .../sql/catalyst/analysis/Analyzer.scala      | 36 +++++++++++++------
 .../sql/catalyst/analysis/CheckAnalysis.scala | 18 ++++++----
 .../sql/catalyst/expressions/subquery.scala   | 19 +++++++---
 .../sql/catalyst/optimizer/Optimizer.scala    |  6 ++--
 .../org/apache/spark/sql/SubquerySuite.scala  | 27 ++++++++------
 5 files changed, 72 insertions(+), 34 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 236476900a519..cd4f356173b21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -865,15 +865,31 @@ class Analyzer(
      * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
      * sub-query by using the plan the predicates should be correlated to.
      */
-    private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = {
-      q transformUp {
-        case f @ Filter(cond, child) if child.resolved && !f.resolved =>
-          val newCond = resolveExpression(cond, p, throws = false)
-          if (!cond.fastEquals(newCond)) {
-            Filter(newCond, child)
-          } else {
-            f
-          }
+    private def resolveCorrelatedSubquery(
+        subquery: LogicalPlan,
+        outers: Seq[LogicalPlan]): LogicalPlan = {
+      val analyzed = execute(subquery)
+      if (analyzed.resolved) {
+        analyzed
+      } else {
+        // Only resolve the lowest plan that is not resolved by outer plan, otherwise it could be
+        // resolved by itself
+        val resolvedByOuter = analyzed transformDown {
+          case q: LogicalPlan if q.childrenResolved && !q.resolved =>
+            q transformExpressions {
+              case expr =>
+                outers.foldLeft(expr) { case (e, outer) =>
+                  // TODO: create alias for outer attributes, they may conflict with the attributes
+                  // from children of q.
+                  resolveExpression(e, outer, throws = false)
+                }
+            }
+        }
+        if (resolvedByOuter fastEquals analyzed) {
+          analyzed
+        } else {
+          resolveCorrelatedSubquery(resolvedByOuter, outers)
+        }
       }
     }
 
@@ -883,7 +899,7 @@ class Analyzer(
           case e: SubqueryExpression if !e.query.resolved =>
             // First resolve as much of the sub-query as possible. After that we use the children of
             // this plan to resolve the remaining correlated predicates.
-            e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates))
+            e.withNewPlan(resolveCorrelatedSubquery(e.query, q.children))
         }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 45e4d535c18cc..817e40f564507 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,16 +113,22 @@ trait CheckAnalysis extends PredicateHelper {
           case f @ Filter(condition, child) =>
             // Make sure that no correlated reference is below Aggregates, Outer Joins and on the
             // right hand side of Unions.
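[Note — a minimal, self-contained sketch of the fixed-point loop in resolveCorrelatedSubquery above: apply one resolution step, and recurse only if the step changed something, mirroring the `resolvedByOuter fastEquals analyzed` test. All names here are invented for illustration, not Spark's API.]

    object FixpointSketch {
      // Keep applying `step` until the value stops changing.
      @annotation.tailrec
      def fixpoint[A](start: A)(step: A => A): A = {
        val next = step(start)
        if (next == start) start else fixpoint(next)(step)
      }

      def main(args: Array[String]): Unit = {
        // Stand-in for "resolve one more unresolved layer against the outer plan".
        val resolved = fixpoint("??name")(s => if (s.startsWith("?")) s.tail else s)
        println(resolved) // name
      }
    }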
-            lazy val attributes = child.outputSet
+            lazy val outerAttributes = child.outputSet
             def failOnCorrelatedReference(
-                p: LogicalPlan,
-                message: String): Unit = p.transformAllExpressions {
-              case e: NamedExpression if attributes.contains(e) =>
-                failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+                plan: LogicalPlan,
+                message: String): Unit = plan foreach {
+              case p =>
+                lazy val inputs = p.inputSet
+                p.transformExpressions {
+                  case e: AttributeReference
+                      if !inputs.contains(e) && outerAttributes.contains(e) =>
+                    println(s"inputs: $inputs $outerAttributes")
+                    failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
+                }
             }
             def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach {
               case a @ Aggregate(_, _, source) =>
-                failOnCorrelatedReference(source, "an AGGREATE")
+                failOnCorrelatedReference(source, "an AGGREGATE")
               case j @ Join(left, _, RightOuter, _) =>
                 failOnCorrelatedReference(left, "a RIGHT OUTER JOIN")
               case j @ Join(_, right, jt, _) if jt != Inner =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index cbee0e61f7a7a..89c8201f20542 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -87,7 +87,6 @@ case class ScalarSubquery(
  */
 abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate {
   override def nullable: Boolean = false
-  override def plan: LogicalPlan = SubqueryAlias(prettyName, query)
 }
 
 object PredicateSubquery {
@@ -105,10 +104,14 @@
  *                    FROM b)
  * }}}
  */
-case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery {
+case class InSubQuery(
+    value: Expression,
+    query: LogicalPlan,
+    exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
   override def children: Seq[Expression] = value :: Nil
   override lazy val resolved: Boolean = value.resolved && query.resolved
-  override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan)
+  override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan, exprId)
+  override def plan: LogicalPlan = SubqueryAlias(s"subquery#${exprId.id}", query)
 
   /**
    * The unwrapped value side expressions.
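[Note — a hedged sketch of the id plumbing introduced above: giving every predicate subquery its own ExprId lets `plan` build a SubqueryAlias whose name, e.g. subquery#42, cannot collide with another subquery in the same statement. The counter below is a toy stand-in for NamedExpression.newExprId, not Spark's implementation.]

    object SubqueryIdSketch {
      private val counter = new java.util.concurrent.atomic.AtomicLong()
      def newExprId(): Long = counter.incrementAndGet()

      def main(args: Array[String]): Unit = {
        // Two subqueries in one statement get distinct, stable alias names.
        val in = newExprId()
        val exists = newExprId()
        println(s"subquery#$in, exists#$exists") // subquery#1, exists#2
      }
    }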
@@ -140,6 +143,8 @@ case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSu
     TypeCheckResult.TypeCheckSuccess
   }
+
+  override def toString: String = s"$value IN subquery#${exprId.id}"
 }
 
 /**
@@ -153,7 +158,11 @@
  *                WHERE b.id = a.id)
  * }}}
  */
-case class Exists(query: LogicalPlan) extends PredicateSubquery {
+case class Exists(
+    query: LogicalPlan,
+    exprId: ExprId = NamedExpression.newExprId) extends PredicateSubquery {
   override def children: Seq[Expression] = Nil
-  override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan)
+  override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan, exprId)
+  override def plan: LogicalPlan = SubqueryAlias(toString, query)
+  override def toString: String = s"exists#${exprId.id}"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ecc2d773e7753..994e31d8304d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1476,7 +1476,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       case f @ Filter(cond, child) =>
         // Find all correlated predicates.
         val (correlated, local) = splitConjunctivePredicates(cond).partition { e =>
-          e.references.intersect(references).nonEmpty
+          (e.references -- child.outputSet).intersect(references).nonEmpty
         }
         // Rewrite the filter without the correlated predicates if any.
         correlated match {
@@ -1536,10 +1536,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       // Filter the plan by applying left semi and left anti joins.
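[Note — a sketch of the rewrite applied by the foldLeft that follows, using plain Scala collections rather than Spark's execution: EXISTS becomes a left semi join (keep outer rows with at least one match) and NOT EXISTS becomes a left anti join (keep outer rows with no match). The tables model the suite's l(a, b) and r(c, d); the data values are invented.]

    object SemiAntiJoinSketch {
      val l = Seq((1, 2.0), (2, 1.0), (3, 3.0)) // l(a, b)
      val r = Seq((2, 6.0), (3, 2.0))           // r(c, d)

      def main(args: Array[String]): Unit = {
        // LEFT SEMI on l.a = r.c: rows of l that have a match in r.
        println(l.filter { case (a, _) => r.exists { case (c, _) => a == c } })
        // LEFT ANTI on l.a = r.c: rows of l that have no match in r.
        println(l.filterNot { case (a, _) => r.exists { case (c, _) => a == c } })
      }
    }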
       withSubquery.foldLeft(newFilter) {
-        case (p, Exists(sub)) =>
+        case (p, Exists(sub, _)) =>
           val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
           Join(p, resolved, LeftSemi, conditions.reduceOption(And))
-        case (p, Not(Exists(sub))) =>
+        case (p, Not(Exists(sub, _))) =>
           val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
           Join(p, resolved, LeftAnti, conditions.reduceOption(And))
         case (p, in: InSubQuery) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5742983fb9d07..2ea104d8c8135 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -115,21 +115,21 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
 
   test("EXISTS predicate subquery") {
     checkAnswer(
-      sql("select * from l where exists(select * from r where l.a = r.c)"),
+      sql("select * from l where exists (select * from r where l.a = r.c)"),
       Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
 
     checkAnswer(
-      sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"),
+      sql("select * from l where exists (select * from r where l.a = r.c) and l.a <= 2"),
       Row(2, 1.0) :: Row(2, 1.0) :: Nil)
   }
 
   test("NOT EXISTS predicate subquery") {
     checkAnswer(
-      sql("select * from l where not exists(select * from r where l.a = r.c)"),
+      sql("select * from l where not exists (select * from r where l.a = r.c)"),
       Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil)
 
     checkAnswer(
-      sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"),
+      sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)"),
       Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) ::
       Row(6, null) :: Nil)
   }
@@ -150,20 +150,20 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
 
   test("NOT IN predicate subquery") {
     checkAnswer(
-      sql("select * from l where a not in(select c from r)"),
+      sql("select * from l where a not in (select c from r)"),
       Nil)
 
     checkAnswer(
-      sql("select * from l where a not in(select c from r where c is not null)"),
+      sql("select * from l where a not in (select c from r where c is not null)"),
       Row(1, 2.0) :: Row(1, 2.0) :: Nil)
 
     checkAnswer(
-      sql("select * from l where a not in(select c from t where b < d)"),
+      sql("select * from l where a not in (select c from t where b < d)"),
       Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: Nil)
 
     // Empty sub-query
     checkAnswer(
-      sql("select * from l where a not in(select c from r where c > 10 and b < d)"),
+      sql("select * from l where a not in (select c from r where c > 10 and b < d)"),
       Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) ::
       Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
 
@@ -171,11 +171,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
 
   test("complex IN predicate subquery") {
     checkAnswer(
-      sql("select * from l where (a, b) not in(select c, d from r)"),
+      sql("select * from l where (a, b) not in (select c, d from r)"),
       Nil)
 
     checkAnswer(
-      sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"),
+      sql("select * from l where (a, b) not in (select c, d from t) and (a + b) is not null"),
       Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil)
   }
+
+  test("same column in subquery and outer table") {
+    checkAnswer(
+      sql("select a from l l1 where a in (select a as b from l where a < 3 group by a)"),
+      Row(1) :: Row(1) :: Row(2) :: Row(2) :: Nil
+    )
+  }
 }

From b5b4c74446d6c6aadcd5c95a7ed588edc72d0957 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 Apr 2016 15:00:50 -0700
Subject: [PATCH 2/5] fix column with same exprId

---
 .../sql/catalyst/analysis/Analyzer.scala      | 50 +++++++++++++------
 .../sql/catalyst/optimizer/Optimizer.scala    | 45 ++++++++++++++---
 .../org/apache/spark/sql/SubquerySuite.scala  |  2 +-
 3 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cd4f356173b21..28a1890e426b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -862,13 +862,14 @@ class Analyzer(
   object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
     /**
-     * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a
+     * Resolve the correlated predicates in the clauses (e.g. WHERE or HAVING) of a
      * sub-query by using the plan the predicates should be correlated to.
      */
     private def resolveCorrelatedSubquery(
-        subquery: LogicalPlan,
-        outers: Seq[LogicalPlan]): LogicalPlan = {
-      val analyzed = execute(subquery)
+        sub: LogicalPlan, outer: LogicalPlan,
+        aliases: scala.collection.mutable.Map[Attribute, Alias]): LogicalPlan = {
+      // First resolve as much of the sub-query as possible
+      val analyzed = execute(sub)
       if (analyzed.resolved) {
         analyzed
       } else {
@@ -877,29 +878,48 @@ class Analyzer(
         val resolvedByOuter = analyzed transformDown {
           case q: LogicalPlan if q.childrenResolved && !q.resolved =>
             q transformExpressions {
-              case expr =>
-                outers.foldLeft(expr) { case (e, outer) =>
-                  // TODO: create alias for outer attributes, they may conflict with the attributes
-                  // from children of q.
-                  resolveExpression(e, outer, throws = false)
+              case u @ UnresolvedAttribute(nameParts) =>
+                withPosition(u) {
+                  try {
+                    val outerAttrOpt = outer.resolve(nameParts, resolver)
+                    if (outerAttrOpt.isDefined) {
+                      // Create an alias for the attribute coming from the outer table, or it may
+                      // conflict with attributes from the subquery
+                      val alias = Alias(outerAttrOpt.get, "outer")()
+                      val attr = alias.toAttribute
+                      aliases += attr -> alias
+                      attr
+                    } else {
+                      u
+                    }
+                  } catch {
+                    case a: AnalysisException => u
+                  }
                 }
             }
         }
         if (resolvedByOuter fastEquals analyzed) {
           analyzed
         } else {
-          resolveCorrelatedSubquery(resolvedByOuter, outers)
+          resolveCorrelatedSubquery(resolvedByOuter, outer, aliases)
         }
       }
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case q: LogicalPlan if q.childrenResolved =>
-        q transformExpressions {
+      // Only a few unary nodes (Project/Filter/Aggregate/Having) could have a subquery
+      case q: UnaryNode if q.childrenResolved =>
+        val aliases = scala.collection.mutable.Map[Attribute, Alias]()
+        val newPlan = q transformExpressions {
           case e: SubqueryExpression if !e.query.resolved =>
-            // First resolve as much of the sub-query as possible. After that we use the children of
-            // this plan to resolve the remaining correlated predicates.
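[Note — a hedged sketch of why the resolution above aliases outer attributes: in a self-join such as the new test's `select a from l l1 where a in (select a from l ...)`, the outer `a` and the subquery's `a` are the same underlying attribute, so the outer reference is replaced by a fresh alias that can no longer be confused with the inner one. Toy classes below, not catalyst's.]

    object OuterAliasSketch {
      final case class Attr(name: String, exprId: Long)

      private var lastId = 0L
      def alias(outer: Attr): Attr = { lastId += 1; Attr("outer", lastId) }

      def main(args: Array[String]): Unit = {
        val inner = Attr("a", 1) // `a` produced inside the subquery
        val outer = Attr("a", 1) // the same `a` referenced from the outer l1
        println(outer == inner)        // true: indistinguishable, the bug
        println(alias(outer) == inner) // false: the alias disambiguates
      }
    }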
-            e.withNewPlan(resolveCorrelatedSubquery(e.query, q.children))
+            e.withNewPlan(resolveCorrelatedSubquery(e.query, q.child, aliases))
+        }
+        if (aliases.nonEmpty) {
+          val projs = q.child.output ++ aliases.values
+          Project(q.child.output,
+            newPlan.withNewChildren(Seq(Project(projs, q.child))))
+        } else {
+          newPlan
         }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 994e31d8304d8..e97fabe106b8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1517,10 +1517,31 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
    */
   private def pullOutCorrelatedPredicates(
       in: InSubQuery,
-      query: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+      query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = {
     val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
-    val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
-    (resolved, conditions)
+    // in.expressions may have the same
+    val outerAttributes = AttributeSet(in.expressions.flatMap(_.references))
+    if (outerAttributes.intersect(resolved.outputSet).nonEmpty) {
+      val aliases = mutable.Map[Attribute, Alias]()
+      val exprs = in.expressions.map { expr =>
+        expr transformUp {
+          case a: AttributeReference if resolved.outputSet.contains(a) =>
+            val alias = Alias(a, a.toString)()
+            val attr = alias.toAttribute
+            aliases += attr -> alias
+            attr
+        }
+      }
+      val newP = Project(query.output ++ aliases.values, query)
+      val newResolved = Project(resolved.output.map(a => Alias(a, a.toString)()),
+        resolved)
+      val conditions = joinCondition ++ exprs.zip(newResolved.output).map(EqualTo.tupled)
+      (newP, newResolved, conditions)
+    } else {
+      val conditions =
+        joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)
+      (query, resolved, conditions)
+    }
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -1543,10 +1564,15 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p)
           Join(p, resolved, LeftAnti, conditions.reduceOption(And))
         case (p, in: InSubQuery) =>
-          val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
-          Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+          val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+          if (newP fastEquals p) {
+            Join(p, resolved, LeftSemi, conditions.reduceOption(And))
+          } else {
+            Project(p.output,
+              Join(newP, resolved, LeftSemi, conditions.reduceOption(And)))
+          }
         case (p, Not(in: InSubQuery)) =>
-          val (resolved, conditions) = pullOutCorrelatedPredicates(in, p)
+          val (newP, resolved, conditions) = pullOutCorrelatedPredicates(in, p)
           // This is a NULL-aware (left) anti join (NAAJ).
           // Construct the condition. A NULL in one of the conditions is regarded as a positive
           // result; such a row will be filtered out by the Anti-Join operator.
           // Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS
           // if performance matters to you.
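[Note — a sketch of the NULL semantics behind the null-aware anti join discussed above, with Option standing in for SQL NULL and three-valued equality: a row survives `NOT IN` only if every comparison is definitely false, so any unknown comparison counts as a match and is filtered out by the anti join. Plain Scala, not Spark's operator; the data is invented.]

    object NotInNullSketch {
      // Three-valued equality: None means "unknown" (a NULL was involved).
      def eq3(x: Option[Int], y: Option[Int]): Option[Boolean] =
        for (a <- x; b <- y) yield a == b

      def main(args: Array[String]): Unit = {
        val left: Seq[Option[Int]] = Seq(Some(1), Some(2), None)
        val sub: Seq[Option[Int]]  = Seq(Some(2), None)
        // `v NOT IN sub`: keep v only if it is definitely absent from sub.
        val kept = left.filter(v => sub.forall(s => eq3(v, s).contains(false)))
        println(kept) // List(): the NULL in `sub` eliminates every row
      }
    }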
-          Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+          if (newP fastEquals p) {
+            Join(p, resolved, LeftAnti, Option(Or(anyNull, condition)))
+          } else {
+            Project(p.output,
+              Join(newP, resolved, LeftAnti, Option(Or(anyNull, condition))))
+          }
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 2ea104d8c8135..b4a08f87877b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -181,7 +181,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
 
   test("same column in subquery and outer table") {
     checkAnswer(
-      sql("select a from l l1 where a in (select a as b from l where a < 3 group by a)"),
+      sql("select a from l l1 where a in (select a from l where a < 3 group by a)"),
       Row(1) :: Row(1) :: Row(2) :: Row(2) :: Nil
     )
   }

From e04d1193f9a15785224722c40b3fdfae6eb6cfb4 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 Apr 2016 15:06:31 -0700
Subject: [PATCH 3/5] only create an alias when there is a conflict

---
 .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 28a1890e426b8..24136ef7e54a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -883,12 +883,16 @@ class Analyzer(
                   try {
                     val outerAttrOpt = outer.resolve(nameParts, resolver)
                     if (outerAttrOpt.isDefined) {
-                      // Create an alias for the attribute coming from the outer table, or it may
-                      // conflict with attributes from the subquery
-                      val alias = Alias(outerAttrOpt.get, "outer")()
-                      val attr = alias.toAttribute
-                      aliases += attr -> alias
-                      attr
+                      val outerAttr = outerAttrOpt.get
+                      if (q.inputSet.contains(outerAttr)) {
+                        // Got a conflict, create an alias for the attribute coming from the outer table
+                        val alias = Alias(outerAttr, outerAttr.toString)()
+                        val attr = alias.toAttribute
+                        aliases += attr -> alias
+                        attr
+                      } else {
+                        outerAttr
+                      }
                     } else {
                       u
                     }

From c3327a94a5fdaa303b979c96b0bbf593fdf3e6ff Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 Apr 2016 16:22:54 -0700
Subject: [PATCH 4/5] fix type coercion for InSubQuery

---
 .../sql/catalyst/analysis/CheckAnalysis.scala |  1 -
 .../catalyst/analysis/HiveTypeCoercion.scala  | 25 +++++++++++++++++++
 .../sql/catalyst/expressions/subquery.scala   |  6 ++---
 .../sql/catalyst/optimizer/Optimizer.scala    | 13 ++++++----
 4 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 817e40f564507..a50b9a1e1a9d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -122,7 +122,6 @@ trait CheckAnalysis extends PredicateHelper {
                 p.transformExpressions {
                   case e: AttributeReference
                       if !inputs.contains(e) && outerAttributes.contains(e) =>
-                    println(s"inputs: $inputs $outerAttributes")
                     failAnalysis(s"Accessing outer query column is not allowed in $message: $e")
                 }
             }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5323b79c57c4b..0306afb0d8bbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -376,6 +376,31 @@ object HiveTypeCoercion {
           case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
           case None => i
         }
+
+      case InSubQuery(struct: CreateStruct, subquery, exprId)
+          if struct.children.zip(subquery.output).exists(x => x._1.dataType != x._2.dataType) =>
+        val widerTypes: Seq[Option[DataType]] = struct.children.zip(subquery.output).map {
+          case (l, r) => findWiderTypeForTwo(l.dataType, r.dataType)
+        }
+        val newStruct = struct.withNewChildren(struct.children.zip(widerTypes).map {
+          case (e, Some(t)) => Cast(e, t)
+          case (e, _) => e
+        })
+        val newSubquery = Project(subquery.output.zip(widerTypes).map {
+          case (a, Some(t)) => Alias(Cast(a, t), a.toString)()
+          case (a, _) => a
+        }, subquery)
+        InSubQuery(newStruct, newSubquery, exprId)
+
+      case sub @ InSubQuery(expr, subquery, exprId)
+          if expr.dataType != subquery.output.head.dataType =>
+        findWiderTypeForTwo(expr.dataType, subquery.output.head.dataType) match {
+          case Some(t) =>
+            val attr = subquery.output.head
+            val proj = Seq(Alias(Cast(attr, t), attr.toString)())
+            InSubQuery(Cast(expr, t), Project(proj, subquery), exprId)
+          case _ => sub
+        }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 89c8201f20542..1993bd2587d1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -127,7 +127,7 @@ case class InSubQuery(
   override def checkInputDataTypes(): TypeCheckResult = {
     // Check the number of arguments.
     if (expressions.length != query.output.length) {
-      TypeCheckResult.TypeCheckFailure(
+      return TypeCheckResult.TypeCheckFailure(
         s"The number of fields in the value (${expressions.length}) does not match with " +
           s"the number of columns in the subquery (${query.output.length})")
     }
 
     // Check the argument types.
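[Note — a toy model of the coercion rule added in HiveTypeCoercion above: when the IN value and the subquery column disagree on type, both sides are cast to a wider common type, as findWiderTypeForTwo would choose for int vs double. The types and tree below are invented stand-ins, not catalyst's classes.]

    object InCoercionSketch {
      sealed trait Expr { def dataType: String }
      final case class Col(name: String, dataType: String) extends Expr
      final case class Cast(child: Expr, dataType: String) extends Expr

      // Toy version of "find the wider of two types".
      def wider(a: String, b: String): Option[String] = (a, b) match {
        case _ if a == b                           => Some(a)
        case ("int", "double") | ("double", "int") => Some("double")
        case _                                     => None
      }

      // Mirrors the single-column case: cast both sides to the wider type.
      def coerce(value: Expr, subqueryCol: Expr): (Expr, Expr) =
        wider(value.dataType, subqueryCol.dataType)
          .map(t => (Cast(value, t): Expr, Cast(subqueryCol, t): Expr))
          .getOrElse((value, subqueryCol))

      def main(args: Array[String]): Unit =
        println(coerce(Col("a", "int"), Col("c", "double")))
    }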
     expressions.zip(query.output).zipWithIndex.foreach {
       case ((e, a), i) if e.dataType != a.dataType =>
-        TypeCheckResult.TypeCheckFailure(
-          s"The data type of value[$i](${e.dataType}) does not match " +
+        return TypeCheckResult.TypeCheckFailure(
+          s"The data type of value[$i] (${e.dataType}) does not match " +
             s"subquery column '${a.name}' (${a.dataType}).")
       case _ =>
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e97fabe106b8a..9a7bfaee00bc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1519,7 +1519,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       in: InSubQuery,
       query: LogicalPlan): (LogicalPlan, LogicalPlan, Seq[Expression]) = {
     val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query)
-    // in.expressions may have the same
+    // Check whether some attributes have the same exprId but come from different sides
     val outerAttributes = AttributeSet(in.expressions.flatMap(_.references))
     if (outerAttributes.intersect(resolved.outputSet).nonEmpty) {
       val aliases = mutable.Map[Attribute, Alias]()
@@ -1533,10 +1533,13 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         }
       }
       val newP = Project(query.output ++ aliases.values, query)
-      val newResolved = Project(resolved.output.map(a => Alias(a, a.toString)()),
-        resolved)
-      val conditions = joinCondition ++ exprs.zip(newResolved.output).map(EqualTo.tupled)
-      (newP, newResolved, conditions)
+      val projection = resolved.output.map {
+        case a if outerAttributes.contains(a) => Alias(a, a.toString)()
+        case a => a
+      }
+      val subquery = Project(projection, resolved)
+      val conditions = joinCondition ++ exprs.zip(subquery.output).map(EqualTo.tupled)
+      (newP, subquery, conditions)
     } else {
       val conditions =
         joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled)

From fd8c75cfbc847c29ba313fce5f43903396714356 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 Apr 2016 20:08:39 -0700
Subject: [PATCH 5/5] fix tests

---
 .../scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 71919366999ab..f5439d70addc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -35,6 +35,10 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
     plan transformAllExpressions {
       case s: ScalarSubquery =>
         ScalarSubquery(s.query, ExprId(0))
+      case s: InSubQuery =>
+        InSubQuery(s.value, s.query, ExprId(0))
+      case e: Exists =>
+        Exists(e.query, ExprId(0))
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
       case a: Alias =>
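[Note — a sketch of why PlanTest rewrites the new InSubQuery and Exists ids to ExprId(0) above: two semantically identical plans can differ only in the unique ids minted by NamedExpression.newExprId, so tests normalize the ids before comparing. Toy classes below, not catalyst's.]

    object NormalizeIdSketch {
      final case class Attr(name: String, exprId: Long)
      final case class Plan(output: Seq[Attr])

      def normalize(p: Plan): Plan =
        p.copy(output = p.output.map(_.copy(exprId = 0)))

      def main(args: Array[String]): Unit = {
        val p1 = Plan(Seq(Attr("a", 17)))
        val p2 = Plan(Seq(Attr("a", 42)))
        println(p1 == p2)                       // false: only the ids differ
        println(normalize(p1) == normalize(p2)) // true after normalization
      }
    }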