From fb6c8cd182dae1aded07baf59cb185f1afde84e7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Mar 2016 09:01:04 +0000 Subject: [PATCH 01/17] Modify output nullable with constraint for Join. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 5 ++- .../plans/logical/basicOperators.scala | 33 +++++++++++++++++++ .../optimizer/NullFilteringSuite.scala | 25 ++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 920e989d058dc..d335d405ec68e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -26,16 +26,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def output: Seq[Attribute] + def outputForConstraint: Seq[Attribute] = output + /** * Extracts the relevant constraints from a given set of constraints based on the attributes that * appear in the [[outputSet]]. */ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + val relatedOutputSet = AttributeSet(outputForConstraint) constraints .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + constraint.references.nonEmpty && constraint.references.subsetOf(relatedOutputSet)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09ea3fea6a694..0cc9234c2c33d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -250,7 +250,40 @@ case class Join( condition: Option[Expression]) extends BinaryNode with PredicateHelper { + private def leftNotNulls = constraints + .filter(_.isInstanceOf[IsNotNull]) + .filter(_.references.subsetOf(left.outputSet)) + .flatMap(_.references.map(_.exprId)) + + private def notNullLeftOutput = left.output.map { o => + if (leftNotNulls.contains(o.exprId)) o.withNullability(false) else o + } + + private def rightNotNulls = constraints + .filter(_.isInstanceOf[IsNotNull]) + .filter(_.references.subsetOf(right.outputSet)) + .flatMap(_.references.map(_.exprId)) + + private def notNullRightOutput = right.output.map { o => + if (rightNotNulls.contains(o.exprId)) o.withNullability(false) else o + } + override def output: Seq[Attribute] = { + joinType match { + case LeftSemi => + notNullLeftOutput + case LeftOuter => + notNullLeftOutput ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ notNullRightOutput + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + notNullLeftOutput ++ notNullRightOutput + } + } + + override def outputForConstraint: Seq[Attribute] = { joinType match { case LeftSemi => left.output diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala index 142e4ae6e4399..6853bc5d27d4a 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala @@ -31,6 +31,21 @@ class NullFilteringSuite extends PlanTest { Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil } + def compareNullability(query: LogicalPlan): Unit = { + val constraints = query.constraints + val output = query.output + + val notNullOutput = query.constraints + .filter(_.isInstanceOf[IsNotNull]) + .flatMap(_.references) + + notNullOutput.foreach { o => + if (query.outputSet.contains(o)) { + assert(query.output.exists(q => o.exprId == q.exprId && !q.nullable)) + } + } + } + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("filter: filter out nulls in condition") { @@ -46,6 +61,7 @@ class NullFilteringSuite extends PlanTest { val originalQuery = x.join(y, condition = Some(("x.a".attr === "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5))) .analyze + compareNullability(originalQuery) val left = x.where(IsNotNull('a) && IsNotNull('b)) val right = y.where(IsNotNull('a) && IsNotNull('c)) val correctAnswer = left.join(right, @@ -53,6 +69,7 @@ class NullFilteringSuite extends PlanTest { .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) + compareNullability(optimized) } test("single inner join: filter out nulls on either side on non equal keys") { @@ -61,6 +78,7 @@ class NullFilteringSuite extends PlanTest { val originalQuery = x.join(y, condition = Some(("x.a".attr =!= "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5))) .analyze + compareNullability(originalQuery) val left = x.where(IsNotNull('a) && IsNotNull('b)) val right = y.where(IsNotNull('a) && IsNotNull('c)) val correctAnswer = left.join(right, @@ -68,6 +86,7 @@ class NullFilteringSuite extends PlanTest { .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) + compareNullability(optimized) } test("single inner join with pre-existing filters: filter out nulls on either side") { @@ -75,12 +94,14 @@ class NullFilteringSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = x.where('b > 5).join(y.where('c === 10), condition = Some("x.a".attr === "y.a".attr)).analyze + compareNullability(originalQuery) val left = x.where(IsNotNull('a) && IsNotNull('b) && 'b > 5) val right = y.where(IsNotNull('a) && IsNotNull('c) && 'c === 10) val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) + compareNullability(optimized) } test("single outer join: no null filters are generated") { @@ -88,8 +109,10 @@ class NullFilteringSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = x.join(y, FullOuter, condition = Some("x.a".attr === "y.a".attr)).analyze + compareNullability(originalQuery) val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) + compareNullability(optimized) } test("multiple inner joins: filter out nulls on all sides on equi-join keys") { @@ -102,11 +125,13 @@ class NullFilteringSuite extends PlanTest { .join(t2, condition = Some("t1.b".attr === "t2.b".attr)) .join(t3, condition = Some("t2.b".attr === "t3.b".attr)) .join(t4, condition = Some("t3.b".attr === "t4.b".attr)).analyze + compareNullability(originalQuery) val correctAnswer = t1.where(IsNotNull('b)) .join(t2.where(IsNotNull('b)), condition = Some("t1.b".attr === 
"t2.b".attr)) .join(t3.where(IsNotNull('b)), condition = Some("t2.b".attr === "t3.b".attr)) .join(t4.where(IsNotNull('b)), condition = Some("t3.b".attr === "t4.b".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) + compareNullability(optimized) } } From 2e4eca4c8336da790321569709475eaff8f193b5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Mar 2016 09:51:11 +0000 Subject: [PATCH 02/17] Replace attributes in condition with correct ones. --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 7302b63646d66..c98f99f83f6f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -101,7 +101,12 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition) => + case join @ Join(left, right, joinType, cond) => + val attributeRewrites = join.output.map(o => o.exprId -> o).toMap + val condition = cond.map(_.transform { + case a: AttributeReference => attributeRewrites(a.exprId) + }) + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. From aef73d5e71b9c997ac98a7172036c7c79f9e9b1c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 16 Mar 2016 04:32:41 +0000 Subject: [PATCH 03/17] Refactor. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 5 +-- .../plans/logical/basicOperators.scala | 33 +++++++++---------- .../optimizer/NullFilteringSuite.scala | 23 ++++++------- 3 files changed, 29 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d335d405ec68e..920e989d058dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -26,19 +26,16 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def output: Seq[Attribute] - def outputForConstraint: Seq[Attribute] = output - /** * Extracts the relevant constraints from a given set of constraints based on the attributes that * appear in the [[outputSet]]. 
*/ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { - val relatedOutputSet = AttributeSet(outputForConstraint) constraints .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(relatedOutputSet)) + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0cc9234c2c33d..3c585cb640eff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -250,19 +250,33 @@ case class Join( condition: Option[Expression]) extends BinaryNode with PredicateHelper { - private def leftNotNulls = constraints + private def notNullsFromCondition: Set[Attribute] = condition.toSet.flatMap { (e: Expression) => + e.collect { + case e @ EqualTo(_, _) => e.references + case g @ GreaterThan(_, _) => g.references + case g @ GreaterThanOrEqual(_, _) => g.references + case l @ LessThan(_, _) => l.references + case l @ LessThanOrEqual(_, _) => l.references + case n @ Not(EqualTo(_, _)) => n.references + case e @ IsNotNull(a: Attribute) => Set(a) + } + }.foldLeft(Set.empty[Attribute])(_ union _.toSet) + + private def leftNotNulls = left.constraints .filter(_.isInstanceOf[IsNotNull]) .filter(_.references.subsetOf(left.outputSet)) .flatMap(_.references.map(_.exprId)) + .union(notNullsFromCondition.map(_.exprId)) private def notNullLeftOutput = left.output.map { o => if (leftNotNulls.contains(o.exprId)) o.withNullability(false) else o } - private def rightNotNulls = constraints + private def rightNotNulls = right.constraints .filter(_.isInstanceOf[IsNotNull]) .filter(_.references.subsetOf(right.outputSet)) .flatMap(_.references.map(_.exprId)) + .union(notNullsFromCondition.map(_.exprId)) private def notNullRightOutput = right.output.map { o => if (rightNotNulls.contains(o.exprId)) o.withNullability(false) else o @@ -283,21 +297,6 @@ case class Join( } } - override def outputForConstraint: Seq[Attribute] = { - joinType match { - case LeftSemi => - left.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - override protected def validConstraints: Set[Expression] = { joinType match { case Inner if condition.isDefined => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala index 6853bc5d27d4a..d42c6cd369e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala @@ -31,7 +31,7 @@ class NullFilteringSuite extends PlanTest { Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil } - def compareNullability(query: LogicalPlan): Unit = { + def checkNullability(query: LogicalPlan): Unit = { val constraints = query.constraints 
val output = query.output @@ -61,7 +61,7 @@ class NullFilteringSuite extends PlanTest { val originalQuery = x.join(y, condition = Some(("x.a".attr === "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5))) .analyze - compareNullability(originalQuery) + checkNullability(originalQuery) val left = x.where(IsNotNull('a) && IsNotNull('b)) val right = y.where(IsNotNull('a) && IsNotNull('c)) val correctAnswer = left.join(right, @@ -69,7 +69,7 @@ class NullFilteringSuite extends PlanTest { .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) - compareNullability(optimized) + checkNullability(optimized) } test("single inner join: filter out nulls on either side on non equal keys") { @@ -78,7 +78,7 @@ class NullFilteringSuite extends PlanTest { val originalQuery = x.join(y, condition = Some(("x.a".attr =!= "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5))) .analyze - compareNullability(originalQuery) + checkNullability(originalQuery) val left = x.where(IsNotNull('a) && IsNotNull('b)) val right = y.where(IsNotNull('a) && IsNotNull('c)) val correctAnswer = left.join(right, @@ -86,7 +86,7 @@ class NullFilteringSuite extends PlanTest { .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) - compareNullability(optimized) + checkNullability(optimized) } test("single inner join with pre-existing filters: filter out nulls on either side") { @@ -94,14 +94,14 @@ class NullFilteringSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = x.where('b > 5).join(y.where('c === 10), condition = Some("x.a".attr === "y.a".attr)).analyze - compareNullability(originalQuery) + checkNullability(originalQuery) val left = x.where(IsNotNull('a) && IsNotNull('b) && 'b > 5) val right = y.where(IsNotNull('a) && IsNotNull('c) && 'c === 10) val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) - compareNullability(optimized) + checkNullability(optimized) } test("single outer join: no null filters are generated") { @@ -109,10 +109,10 @@ class NullFilteringSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = x.join(y, FullOuter, condition = Some("x.a".attr === "y.a".attr)).analyze - compareNullability(originalQuery) + checkNullability(originalQuery) val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) - compareNullability(optimized) + checkNullability(optimized) } test("multiple inner joins: filter out nulls on all sides on equi-join keys") { @@ -125,13 +125,14 @@ class NullFilteringSuite extends PlanTest { .join(t2, condition = Some("t1.b".attr === "t2.b".attr)) .join(t3, condition = Some("t2.b".attr === "t3.b".attr)) .join(t4, condition = Some("t3.b".attr === "t4.b".attr)).analyze - compareNullability(originalQuery) + checkNullability(originalQuery) val correctAnswer = t1.where(IsNotNull('b)) .join(t2.where(IsNotNull('b)), condition = Some("t1.b".attr === "t2.b".attr)) .join(t3.where(IsNotNull('b)), condition = Some("t2.b".attr === "t3.b".attr)) .join(t4.where(IsNotNull('b)), condition = Some("t3.b".attr === "t4.b".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) - compareNullability(optimized) + checkNullability(correctAnswer) + checkNullability(optimized) } } From 5bf4b4b544ef2aa25d93c974e94f8314a6626ef7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 16 Mar 2016 04:47:27 
+0000 Subject: [PATCH 04/17] Refactor. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../plans/logical/basicOperators.scala | 20 +++++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 920e989d058dc..d966562348457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -43,7 +43,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form * `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + protected def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. For all other cases, we return an empty set of constraints constraints.map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3c585cb640eff..c2bdf05133d5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -250,33 +250,23 @@ case class Join( condition: Option[Expression]) extends BinaryNode with PredicateHelper { - private def notNullsFromCondition: Set[Attribute] = condition.toSet.flatMap { (e: Expression) => - e.collect { - case e @ EqualTo(_, _) => e.references - case g @ GreaterThan(_, _) => g.references - case g @ GreaterThanOrEqual(_, _) => g.references - case l @ LessThan(_, _) => l.references - case l @ LessThanOrEqual(_, _) => l.references - case n @ Not(EqualTo(_, _)) => n.references - case e @ IsNotNull(a: Attribute) => Set(a) - } - }.foldLeft(Set.empty[Attribute])(_ union _.toSet) + private def notNullsFromCondition: Set[Expression] = + constructIsNotNullConstraints(condition.toSet.flatMap(splitConjunctivePredicates)) + .filter(_.references.nonEmpty) - private def leftNotNulls = left.constraints + private def leftNotNulls = left.constraints.union(notNullsFromCondition) .filter(_.isInstanceOf[IsNotNull]) .filter(_.references.subsetOf(left.outputSet)) .flatMap(_.references.map(_.exprId)) - .union(notNullsFromCondition.map(_.exprId)) private def notNullLeftOutput = left.output.map { o => if (leftNotNulls.contains(o.exprId)) o.withNullability(false) else o } - private def rightNotNulls = right.constraints + private def rightNotNulls = right.constraints.union(notNullsFromCondition) .filter(_.isInstanceOf[IsNotNull]) .filter(_.references.subsetOf(right.outputSet)) .flatMap(_.references.map(_.exprId)) - .union(notNullsFromCondition.map(_.exprId)) private def notNullRightOutput = right.output.map { o => if (rightNotNulls.contains(o.exprId)) o.withNullability(false) else o From 93a73b79fbeee9c1bd722f5aa66257409d3c2512 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 16 Mar 2016 04:59:28 +0000 Subject: [PATCH 05/17] Modify output nullability with constraints for Filter operator. 
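Roughly, the change below lets a Filter report tighter nullability: if its condition implies `IsNotNull(a)` for some child attribute `a`, then `a` can be marked non-nullable in the Filter's output even though the child still reports it as nullable. A minimal self-contained sketch of that derivation, using simplified stand-in types (`Attr`, `Expr`) rather than the Catalyst classes:

    // Simplified stand-ins for Catalyst's Attribute / Expression types.
    case class Attr(name: String, exprId: Long, nullable: Boolean = true)

    sealed trait Expr
    case class IsNotNull(a: Attr) extends Expr
    case class EqualTo(l: Attr, r: Attr) extends Expr       // null-rejecting comparison
    case class GreaterThan(l: Attr, r: Attr) extends Expr   // null-rejecting comparison
    case class And(l: Expr, r: Expr) extends Expr

    object FilterNullability {
      // Split a conjunction into its top-level conjuncts.
      def conjuncts(e: Expr): Seq[Expr] = e match {
        case And(l, r) => conjuncts(l) ++ conjuncts(r)
        case other     => Seq(other)
      }

      // exprIds that the condition forces to be non-null.
      def notNullIds(condition: Expr): Set[Long] = conjuncts(condition).flatMap {
        case IsNotNull(a)      => Seq(a.exprId)
        case EqualTo(l, r)     => Seq(l.exprId, r.exprId)
        case GreaterThan(l, r) => Seq(l.exprId, r.exprId)
        case _                 => Nil
      }.toSet

      // Tighten the child's output using what the condition implies.
      def output(childOutput: Seq[Attr], condition: Expr): Seq[Attr] = {
        val ids = notNullIds(condition)
        childOutput.map(a => if (ids.contains(a.exprId)) a.copy(nullable = false) else a)
      }
    }

    object FilterNullabilityDemo extends App {
      val a = Attr("a", 1L); val b = Attr("b", 2L); val c = Attr("c", 3L)
      // WHERE a = b AND c IS NOT NULL  =>  a, b, c all become non-nullable downstream.
      val out = FilterNullability.output(Seq(a, b, c), And(EqualTo(a, b), IsNotNull(c)))
      assert(out.forall(!_.nullable))
      println(out)
    }

Treating null-rejecting comparisons the same way as an explicit `IS NOT NULL` mirrors the equality-and-range cases that `constructIsNotNullConstraints` handles in the diff below.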
--- .../sql/catalyst/plans/logical/basicOperators.scala | 11 ++++++++++- .../sql/catalyst/optimizer/NullFilteringSuite.scala | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index c2bdf05133d5b..243d04bb4fcd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -103,7 +103,16 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode with PredicateHelper { - override def output: Seq[Attribute] = child.output + + private def notNulls = + constructIsNotNullConstraints(splitConjunctivePredicates(condition).toSet) + .filter(x => x.isInstanceOf[IsNotNull] && x.references.nonEmpty) + .filter(_.references.subsetOf(child.outputSet)) + .flatMap(_.references.map(_.exprId)) + + override def output: Seq[Attribute] = child.output.map { o => + if (notNulls.contains(o.exprId)) o.withNullability(false) else o + } override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala index d42c6cd369e72..5448846d7e589 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala @@ -53,6 +53,9 @@ class NullFilteringSuite extends PlanTest { val correctAnswer = testRelation.where(IsNotNull('a) && 'a === 1).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) + checkNullability(originalQuery) + checkNullability(correctAnswer) + checkNullability(optimized) } test("single inner join: filter out nulls on either side on equi-join keys") { From c7d54a0fb78c826903c0db8f1b1ac7b0d54bb303 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 16 Mar 2016 07:30:08 +0000 Subject: [PATCH 06/17] Fix a bug. --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index c98f99f83f6f7..3e9e761eb7d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -104,7 +104,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case join @ Join(left, right, joinType, cond) => val attributeRewrites = join.output.map(o => o.exprId -> o).toMap val condition = cond.map(_.transform { - case a: AttributeReference => attributeRewrites(a.exprId) + case a: AttributeReference if attributeRewrites.contains(a.exprId) => + attributeRewrites(a.exprId) }) logDebug(s"Considering join on: $condition") From 76a8566e9506e6a41107c8cb244a76d4525b7a44 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 16 Mar 2016 10:04:32 +0000 Subject: [PATCH 07/17] Fix test. 
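The test fix below simply turns the derived not-null sets in `Join` from `def`s into `lazy val`s. Hedging a little, since the commit message only says "Fix test": the practical effect is that walking the children's constraints happens at most once per `Join` instance instead of on every `output` call, and the result stays stable for the lifetime of the node. A tiny standalone illustration of the difference:

    object LazyValVsDef extends App {
      var evaluations = 0

      class Plan {
        // Re-runs the (potentially expensive) derivation on every access.
        def notNullsAsDef: Set[Long] = { evaluations += 1; Set(1L, 2L) }
        // Runs the derivation at most once and caches the result.
        lazy val notNullsAsLazyVal: Set[Long] = { evaluations += 1; Set(1L, 2L) }
      }

      val p = new Plan
      p.notNullsAsDef; p.notNullsAsDef
      assert(evaluations == 2)            // def: evaluated twice

      evaluations = 0
      p.notNullsAsLazyVal; p.notNullsAsLazyVal
      assert(evaluations == 1)            // lazy val: evaluated once
      println("ok")
    }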
--- .../sql/catalyst/plans/logical/basicOperators.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 243d04bb4fcd2..c3ac41fea376a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -259,25 +259,25 @@ case class Join( condition: Option[Expression]) extends BinaryNode with PredicateHelper { - private def notNullsFromCondition: Set[Expression] = + private lazy val notNullsFromCondition: Set[Expression] = constructIsNotNullConstraints(condition.toSet.flatMap(splitConjunctivePredicates)) .filter(_.references.nonEmpty) - private def leftNotNulls = left.constraints.union(notNullsFromCondition) + private lazy val leftNotNulls = left.constraints.union(notNullsFromCondition) .filter(_.isInstanceOf[IsNotNull]) .filter(_.references.subsetOf(left.outputSet)) .flatMap(_.references.map(_.exprId)) - private def notNullLeftOutput = left.output.map { o => + private lazy val notNullLeftOutput = left.output.map { o => if (leftNotNulls.contains(o.exprId)) o.withNullability(false) else o } - private def rightNotNulls = right.constraints.union(notNullsFromCondition) + private lazy val rightNotNulls = right.constraints.union(notNullsFromCondition) .filter(_.isInstanceOf[IsNotNull]) .filter(_.references.subsetOf(right.outputSet)) .flatMap(_.references.map(_.exprId)) - private def notNullRightOutput = right.output.map { o => + private lazy val notNullRightOutput = right.output.map { o => if (rightNotNulls.contains(o.exprId)) o.withNullability(false) else o } From 6b4a98cdd44a447e6ab8c3aee908647243e62449 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Mar 2016 07:54:05 +0000 Subject: [PATCH 08/17] fix bug. --- .../sql/catalyst/optimizer/Optimizer.scala | 58 +++++++++++++------ .../sql/catalyst/planning/patterns.scala | 3 +- .../InferFiltersFromConstraintsSuite.scala | 6 +- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3f57b0758eaff..c64e9f5947263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -827,7 +827,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { (ExpressionSet(splitConjunctivePredicates(fc)) -- ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match { case Some(ac) => - Filter(And(ac, nc), grandChild) + Filter(RewriteAttributes.execute(grandChild.outputSet, And(ac, nc)), grandChild) case None => nf } @@ -1113,29 +1113,36 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { joinType match { case Inner => // push down the single side `where` condition into respective sides - val newLeft = leftFilterConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) - val newRight = rightFilterConditions. 
- reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newLeft = leftFilterConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(left.outputSet, x), left)) + .getOrElse(left) + val newRight = rightFilterConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(right.outputSet, x), right)) + .getOrElse(right) val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) + .map(RewriteAttributes.execute(newLeft.outputSet ++ newRight.outputSet, _)) Join(newLeft, newRight, Inner, newJoinCond) case RightOuter => // push down the right side only `where` condition val newLeft = left - val newRight = rightFilterConditions. - reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newRight = rightFilterConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(right.outputSet, x), right)) + .getOrElse(right) val newJoinCond = joinCondition + .map(RewriteAttributes.execute(newLeft.outputSet ++ newRight.outputSet, _)) val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) case _ @ (LeftOuter | LeftSemi) => // push down the left side only `where` condition - val newLeft = leftFilterConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newLeft = leftFilterConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(left.outputSet, x), left)) + .getOrElse(left) val newRight = right val newJoinCond = joinCondition + .map(RewriteAttributes.execute(newLeft.outputSet ++ newRight.outputSet, _)) val newJoin = Join(newLeft, newRight, joinType, newJoinCond) (rightFilterConditions ++ commonFilterCondition). @@ -1152,27 +1159,34 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { joinType match { case _ @ (Inner | LeftSemi) => // push down the single side only join filter for both sides sub queries - val newLeft = leftJoinConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) - val newRight = rightJoinConditions. - reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newLeft = leftJoinConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(left.outputSet, x), left)) + .getOrElse(left) + val newRight = rightJoinConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(right.outputSet, x), right)) + .getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) + .map(RewriteAttributes.execute(newLeft.outputSet ++ newRight.outputSet, _)) Join(newLeft, newRight, joinType, newJoinCond) case RightOuter => // push down the left side only join filter for left side sub query - val newLeft = leftJoinConditions. - reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newLeft = leftJoinConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(left.outputSet, x), left)) + .getOrElse(left) val newRight = right val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) + .map(RewriteAttributes.execute(newLeft.outputSet ++ newRight.outputSet, _)) Join(newLeft, newRight, RightOuter, newJoinCond) case LeftOuter => // push down the right side only join filter for right sub query val newLeft = left - val newRight = rightJoinConditions. 
- reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newRight = rightJoinConditions.reduceLeftOption(And) + .map(x => Filter(RewriteAttributes.execute(right.outputSet, x), right)) + .getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) + .map(RewriteAttributes.execute(newLeft.outputSet ++ newRight.outputSet, _)) Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f @@ -1331,3 +1345,13 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +object RewriteAttributes { + def execute(outputs: AttributeSet, expr: Expression) = { + val attributeRewrites = outputs.map(o => o.exprId -> o).toMap + expr.transform { + case a: AttributeReference if attributeRewrites.contains(a.exprId) => + attributeRewrites(a.exprId) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3e9e761eb7d8d..d5d08ebc7e8d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -102,7 +102,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, cond) => - val attributeRewrites = join.output.map(o => o.exprId -> o).toMap + val attributeRewrites = + (join.left.outputSet ++ join.right.outputSet).map(o => o.exprId -> o).toMap val condition = cond.map(_.transform { case a: AttributeReference if attributeRewrites.contains(a.exprId) => attributeRewrites(a.exprId) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9c21582865b14..bafb291c8b7ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -71,10 +71,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5 && "y.a".attr === 1) val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) - //comparePlans(optimized, correctAnswer) + comparePlans(optimized, correctAnswer) checkNullability(optimized) + checkNullability(correctAnswer) } - /* + test("single inner join: filter out nulls on either side on non equal keys") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -152,5 +153,4 @@ class InferFiltersFromConstraintsSuite extends PlanTest { checkNullability(correctAnswer) checkNullability(optimized) } - */ } From 665bf50880b3f87353c5037946db28f8dda7d1c2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Mar 2016 11:41:35 +0000 Subject: [PATCH 09/17] Fix scala style. 
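Patch 08 above threads most of the predicate-pushdown rules through a new `RewriteAttributes.execute(outputs, expr)` helper (the style fix below only adds its missing return type). The underlying problem: once a child starts reporting some attributes as non-nullable, an expression built earlier may still hold the old, nullable attribute instances; rewriting references by `exprId` re-points the expression at the child's current attributes. A self-contained sketch of that rewrite with simplified stand-in types (`Attr`, `Expr`, `AttrRef`), not the Catalyst API:

    // Simplified attribute / expression model keyed by exprId, like Catalyst's.
    case class Attr(name: String, exprId: Long, nullable: Boolean = true)

    sealed trait Expr {
      def mapAttrs(f: Attr => Attr): Expr = this match {
        case a: AttrRef    => AttrRef(f(a.attr))
        case EqualTo(l, r) => EqualTo(l.mapAttrs(f), r.mapAttrs(f))
        case other         => other
      }
    }
    case class AttrRef(attr: Attr) extends Expr
    case class Literal(v: Int) extends Expr
    case class EqualTo(l: Expr, r: Expr) extends Expr

    object RewriteAttrs {
      // Replace every attribute reference with the same-exprId attribute from `outputs`,
      // so the expression picks up e.g. updated nullability. Unknown exprIds are left alone.
      def execute(outputs: Seq[Attr], expr: Expr): Expr = {
        val byId = outputs.map(a => a.exprId -> a).toMap
        expr.mapAttrs(a => byId.getOrElse(a.exprId, a))
      }
    }

    object RewriteAttrsDemo extends App {
      val staleA = Attr("a", 1L, nullable = true)    // captured before pushdown
      val freshA = Attr("a", 1L, nullable = false)   // child output after pushdown
      val cond   = EqualTo(AttrRef(staleA), Literal(5))
      val fixed  = RewriteAttrs.execute(Seq(freshA), cond)
      println(fixed)                                 // now references the non-nullable a
    }

Using `getOrElse(a.exprId, a)` rather than a bare map lookup mirrors the guard added in patch 06, which avoids failing on references that are not part of the supplied output set.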
--- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c64e9f5947263..65179283b2d6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1347,7 +1347,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } object RewriteAttributes { - def execute(outputs: AttributeSet, expr: Expression) = { + def execute(outputs: AttributeSet, expr: Expression): Expression = { val attributeRewrites = outputs.map(o => o.exprId -> o).toMap expr.transform { case a: AttributeReference if attributeRewrites.contains(a.exprId) => From ac795610e2091a1534b80e7eea01630c4cb5deb4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 18 Mar 2016 07:45:12 +0000 Subject: [PATCH 10/17] Fix it. --- .../plans/logical/basicOperators.scala | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index c3ac41fea376a..8c7a5e51b5279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -263,19 +263,29 @@ case class Join( constructIsNotNullConstraints(condition.toSet.flatMap(splitConjunctivePredicates)) .filter(_.references.nonEmpty) - private lazy val leftNotNulls = left.constraints.union(notNullsFromCondition) - .filter(_.isInstanceOf[IsNotNull]) - .filter(_.references.subsetOf(left.outputSet)) - .flatMap(_.references.map(_.exprId)) + private lazy val leftNotNulls = { + val constraints = joinType match { + case Inner | LeftSemi => left.constraints.union(notNullsFromCondition) + case _ => left.constraints + } + constraints.filter(_.isInstanceOf[IsNotNull]) + .filter(_.references.subsetOf(left.outputSet)) + .flatMap(_.references.map(_.exprId)) + } private lazy val notNullLeftOutput = left.output.map { o => if (leftNotNulls.contains(o.exprId)) o.withNullability(false) else o } - private lazy val rightNotNulls = right.constraints.union(notNullsFromCondition) - .filter(_.isInstanceOf[IsNotNull]) - .filter(_.references.subsetOf(right.outputSet)) - .flatMap(_.references.map(_.exprId)) + private lazy val rightNotNulls = { + val constraints = joinType match { + case Inner => right.constraints.union(notNullsFromCondition) + case _ => right.constraints + } + constraints.filter(_.isInstanceOf[IsNotNull]) + .filter(_.references.subsetOf(right.outputSet)) + .flatMap(_.references.map(_.exprId)) + } private lazy val notNullRightOutput = right.output.map { o => if (rightNotNulls.contains(o.exprId)) o.withNullability(false) else o From 7f68967eeb3f303c552dadc760788d3fe9d090f5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 Mar 2016 04:33:06 +0000 Subject: [PATCH 11/17] Modify attribute nullability for filter pushdown. 
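The join change just above (patch 10) makes the constraint usage depend on the join type: not-nulls implied by the join condition tighten the left output only for Inner and LeftSemi joins and the right output only for Inner joins, while `output` keeps forcing the non-preserved side(s) of outer joins back to nullable. The reason is that an outer join pads the non-preserved side with nulls, so facts derived from the condition must not leak into that side's reported nullability. A rough, self-contained sketch of that selection (hypothetical types, not the real `Join` operator):

    sealed trait JoinType
    case object Inner      extends JoinType
    case object LeftSemi   extends JoinType
    case object LeftOuter  extends JoinType
    case object RightOuter extends JoinType
    case object FullOuter  extends JoinType

    object JoinNullability {
      /** exprIds that may be marked non-null on each side, given the join type. */
      def usableNotNulls(
          joinType: JoinType,
          fromLeftChild: Set[Long],    // not-nulls already guaranteed by the left child
          fromRightChild: Set[Long],   // not-nulls already guaranteed by the right child
          fromCondition: Set[Long]     // not-nulls implied by the join condition
      ): (Set[Long], Set[Long]) = joinType match {
        // Inner joins drop rows that fail the condition, so both sides can use it.
        case Inner      => (fromLeftChild ++ fromCondition, fromRightChild ++ fromCondition)
        // Left-semi output only contains left rows, so only the left side benefits.
        case LeftSemi   => (fromLeftChild ++ fromCondition, fromRightChild)
        // Outer joins pad the non-preserved side with nulls: the condition can't help there,
        // and the padded side must stay nullable.
        case LeftOuter  => (fromLeftChild, Set.empty[Long])
        case RightOuter => (Set.empty[Long], fromRightChild)
        case FullOuter  => (Set.empty[Long], Set.empty[Long])
      }
    }

    object JoinNullabilityDemo extends App {
      // A not-null fact from the condition applies to both sides for Inner, to neither for FullOuter.
      println(JoinNullability.usableNotNulls(Inner, Set.empty[Long], Set.empty[Long], Set(1L)))
      println(JoinNullability.usableNotNulls(FullOuter, Set.empty[Long], Set.empty[Long], Set(1L)))
    }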
--- .../sql/catalyst/optimizer/Optimizer.scala | 38 ++++++++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 21 +++++----- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 65179283b2d6d..870e8929f9750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -898,7 +898,10 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe // If there is no nondeterministic conditions, push down the whole condition. if (nondeterministic.isEmpty) { - project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + val newFilter = Filter(replaceAlias(condition, aliasMap), grandChild) + val newFields = fields.map(RewriteAttributes.execute(newFilter.outputSet, _)) + .asInstanceOf[Seq[NamedExpression]] + project.copy(projectList = newFields, child = newFilter) } else { // If they are all nondeterministic conditions, leave it un-changed. if (deterministic.isEmpty) { @@ -907,8 +910,11 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe // Push down the small conditions without nondeterministic expressions. val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + val newFilter = Filter(pushedCondition, grandChild) + val newFields = fields.map(RewriteAttributes.execute(newFilter.outputSet, _)) + .asInstanceOf[Seq[NamedExpression]] Filter(nondeterministic.reduce(And), - project.copy(child = Filter(pushedCondition, grandChild))) + project.copy(projectList = newFields, child = newFilter)) } } } @@ -930,9 +936,17 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) + val newFilter = Filter(pushDownPredicate, g.child) + val newGeneratorOutput = + g.generatorOutput.map(RewriteAttributes.execute(newFilter.outputSet, _)) val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, - g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) + g.qualifier, g.generatorOutput, newFilter) + if (stayUp.isEmpty) { + newGenerate + } else { + val newCondition = RewriteAttributes.execute(newGenerate.outputSet, stayUp.reduce(And)) + Filter(newCondition, newGenerate) + } } else { filter } @@ -964,10 +978,22 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) - val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) + val newFilter = Filter(replaced, aggregate.child) + val newGroupingExpressions = + aggregate.groupingExpressions.map(RewriteAttributes.execute(newFilter.outputSet, _)) + val newAggregateExpressions = + aggregate.aggregateExpressions.map(RewriteAttributes.execute(newFilter.outputSet, _)) + .asInstanceOf[Seq[NamedExpression]] + val newAggregate = aggregate.copy(groupingExpressions = newGroupingExpressions, + aggregateExpressions = newAggregateExpressions, child = newFilter) // If there is no more filter to stay up, just eliminate the filter. // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". 
- if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) + if (stayUp.isEmpty) { + newAggregate + } else { + val newCondition = RewriteAttributes.execute(newAggregate.outputSet, stayUp.reduce(And)) + Filter(newCondition, newAggregate) + } } else { filter } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 8b568b6dd6acd..efb9507c3e3bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -80,6 +80,8 @@ class AnalysisSuite extends AnalysisTest { test("resolve sort references - filter/limit") { val a = testRelation2.output(0) val b = testRelation2.output(1) + val aNotNullable = testRelation2.output(0).withNullability(false) + val bNotNullable = testRelation2.output(1).withNullability(false) val c = testRelation2.output(2) // Case 1: one missing attribute is in the leaf node and another is in the unary node @@ -88,10 +90,10 @@ class AnalysisSuite extends AnalysisTest { .where('b > "str").select('a) .sortBy('b.asc, 'c.desc) val expected1 = testRelation2 - .where(a > "str").select(a, b, c) - .where(b > "str").select(a, b, c) - .sortBy(b.asc, c.desc) - .select(a) + .where(a > "str").select(aNotNullable, b, c) + .where(b > "str").select(aNotNullable, bNotNullable, c) + .sortBy(bNotNullable.asc, c.desc) + .select(aNotNullable) checkAnalysis(plan1, expected1) // Case 2: all the missing attributes are in the leaf node @@ -100,15 +102,16 @@ class AnalysisSuite extends AnalysisTest { .where('a > "str").select('a) .sortBy('b.asc, 'c.desc) val expected2 = testRelation2 - .where(a > "str").select(a, b, c) - .where(a > "str").select(a, b, c) + .where(a > "str").select(aNotNullable, b, c) + .where(aNotNullable > "str").select(aNotNullable, b, c) .sortBy(b.asc, c.desc) - .select(a) + .select(aNotNullable) checkAnalysis(plan2, expected2) } test("resolve sort references - join") { val a = testRelation2.output(0) + val aNotNullable = testRelation2.output(0).withNullability(false) val b = testRelation2.output(1) val c = testRelation2.output(2) val h = testRelation3.output(3) @@ -118,9 +121,9 @@ class AnalysisSuite extends AnalysisTest { .where('a > "str").select('a, 'b) .sortBy('c.desc, 'h.asc) val expected = testRelation2.join(testRelation3) - .where(a > "str").select(a, b, c, h) + .where(a > "str").select(aNotNullable, b, c, h) .sortBy(c.desc, h.asc) - .select(a, b) + .select(aNotNullable, b) checkAnalysis(plan, expected) } From a7b8daef9e82e184226f101a5fd81fcc070dc25c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Mar 2016 04:32:25 +0000 Subject: [PATCH 12/17] Reset nullabilty for project and filter list in preparing scaning in memory relation. 
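The planner change below resets the nullable flag to `true` when collecting the attributes referenced by the project list and the filter predicates. The diff alone doesn't show the consumer, but a plausible reading is that these sets are later matched against the relation's original, still-nullable attributes, and two attributes that differ only in the nullable flag should count as the same column. A tiny sketch of that normalization idea, with a hypothetical `Attr` type:

    case class Attr(name: String, exprId: Long, nullable: Boolean = true)

    object NormalizeNullability {
      // Compare attributes as columns: normalize the nullable flag away before comparing.
      def sameColumn(x: Attr, y: Attr): Boolean =
        x.copy(nullable = true) == y.copy(nullable = true)
    }

    object NormalizeNullabilityDemo extends App {
      val fromRelation = Attr("a", 1L, nullable = true)   // as the scan reports it
      val fromFilter   = Attr("a", 1L, nullable = false)  // tightened further up the plan
      assert(fromRelation != fromFilter)                                // plain equality disagrees
      assert(NormalizeNullability.sameColumn(fromRelation, fromFilter)) // normalized match works
      println("ok")
    }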
--- .../scala/org/apache/spark/sql/execution/SparkPlanner.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 9da2c74c62fc6..5296175960e3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -65,8 +65,9 @@ class SparkPlanner( prunePushedDownFilters: Seq[Expression] => Seq[Expression], scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val projectSet = AttributeSet(projectList.flatMap(_.references).map(_.withNullability(true))) + val filterSet = + AttributeSet(filterPredicates.flatMap(_.references).map(_.withNullability(true))) val filterCondition: Option[Expression] = prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) From 23b328d1c01806841943ad8dd0ab3eed8963d7e2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Mar 2016 05:02:21 +0000 Subject: [PATCH 13/17] Unnecessary change removed. --- .../apache/spark/sql/catalyst/planning/patterns.scala | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index d5d08ebc7e8d9..7302b63646d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -101,14 +101,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, cond) => - val attributeRewrites = - (join.left.outputSet ++ join.right.outputSet).map(o => o.exprId -> o).toMap - val condition = cond.map(_.transform { - case a: AttributeReference if attributeRewrites.contains(a.exprId) => - attributeRewrites(a.exprId) - }) - + case join @ Join(left, right, joinType, condition) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. From da3f35b4d315cc3c2576ac781cd3e8beef5eb774 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Mar 2016 08:16:21 +0000 Subject: [PATCH 14/17] Fix python test. 
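The `Dataset.drop` change below stops relying on plain attribute equality (`attr != expression`) and instead compares by `exprId`. Once upstream operators can hand back the same column with a different nullable flag, object equality no longer identifies the column to drop, but the `exprId` still does. A small sketch of the difference with a hypothetical simplified `Attr` type, not the Dataset API:

    case class Attr(name: String, exprId: Long, nullable: Boolean = true)

    object DropColumn {
      // Keep every column whose exprId differs from the one being dropped.
      def drop(output: Seq[Attr], toDrop: Attr): Seq[Attr] =
        output.filter(_.exprId != toDrop.exprId)
    }

    object DropColumnDemo extends App {
      val planOutput = Seq(Attr("a", 1L, nullable = false), Attr("b", 2L))
      val resolvedA  = Attr("a", 1L, nullable = true)   // same column, stale nullability

      // Plain equality misses the match and drops nothing:
      assert(planOutput.filterNot(_ == resolvedA).size == 2)
      // exprId-based comparison drops the intended column:
      assert(DropColumn.drop(planOutput, resolvedA).map(_.name) == Seq("b"))
      println("ok")
    }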
--- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ac2ca3c5a35d7..2c41548ec77cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1517,7 +1517,10 @@ class Dataset[T] private[sql]( } val attrs = this.logicalPlan.output val colsAfterDrop = attrs.filter { attr => - attr != expression + expression match { + case ar: AttributeReference => attr.exprId != ar.exprId + case _ => true + } }.map(attr => Column(attr)) select(colsAfterDrop : _*) } From cdc5878e4ccb98f87b4496b85a2ef95e1722a1f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 24 Mar 2016 03:09:45 +0000 Subject: [PATCH 15/17] Correct output nullability of logical plans. --- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../sql/catalyst/catalog/interface.scala | 2 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 27 ++++- .../plans/logical/LocalRelation.scala | 2 +- .../plans/logical/ScriptTransformation.scala | 4 +- .../plans/logical/basicOperators.scala | 114 ++++++------------ .../sql/catalyst/plans/logical/commands.scala | 4 +- .../sql/catalyst/plans/logical/object.scala | 4 +- .../catalyst/plans/logical/partitioning.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 2 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 6 +- .../apache/spark/sql/execution/Expand.scala | 4 +- .../apache/spark/sql/execution/Generate.scala | 7 +- .../spark/sql/execution/LocalTableScan.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 3 + .../spark/sql/execution/basicOperators.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 2 +- .../sql/execution/command/commands.scala | 22 ++-- .../datasources/LogicalRelation.scala | 4 +- .../spark/sql/execution/datasources/ddl.scala | 6 +- .../spark/sql/execution/debug/package.scala | 2 +- .../python/BatchPythonEvaluation.scala | 5 +- .../sql/execution/python/EvaluatePython.scala | 2 +- .../streaming/StreamingRelation.scala | 3 +- .../spark/sql/ExtraStrategiesSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 6 +- .../apache/spark/sql/hive/SQLBuilder.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../hive/execution/ScriptTransformation.scala | 2 +- 31 files changed, 123 insertions(+), 132 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 01afa01ae95c5..0a15f467636ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -43,7 +43,7 @@ case class UnresolvedRelation( /** Returns a `.` separated name for this relation. 
*/ def tableName: String = tableIdentifier.unquotedString - override def output: Seq[Attribute] = Nil + override def outputBeforeConstraints: Seq[Attribute] = Nil override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c4e49614c5c35..774c9aaf8ea52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -286,7 +286,7 @@ case class CatalogRelation( extends LeafNode { // TODO: implement this - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty require(metadata.name.database == Some(db), "provided database does not much the one specified in the table definition") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d966562348457..463d7436fc279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -24,18 +24,32 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => - def output: Seq[Attribute] + def output: Seq[Attribute] = { + val isNotNulls = constraints.collect { + case IsNotNull(e) if e.isInstanceOf[Attribute] => e.asInstanceOf[Attribute].exprId + } + outputBeforeConstraints.map { o => + if (isNotNulls.contains(o.exprId)) { + o.withNullability(false) + } else { + o + } + } + } + + protected def outputBeforeConstraints: Seq[Attribute] /** * Extracts the relevant constraints from a given set of constraints based on the attributes that - * appear in the [[outputSet]]. + * appear in the [[outputSetBeforeConstraints]]. */ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { constraints .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + constraint.references.nonEmpty && + constraint.references.subsetOf(outputSetBeforeConstraints)) } /** @@ -43,7 +57,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form * `isNotNull(a)` */ - protected def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. For all other cases, we return an empty set of constraints constraints.map { @@ -106,6 +120,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ def outputSet: AttributeSet = AttributeSet(output) + /** + * Returns the set of attributes before affected by constraints that are output by this node. + */ + def outputSetBeforeConstraints: AttributeSet = AttributeSet(outputBeforeConstraints) + /** * All Attributes that appear in expressions from this operator. 
Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 5813b74c770d8..7109d378e07a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -42,7 +42,7 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) +case class LocalRelation(outputBeforeConstraints: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index 578027da776e5..c2f02dce3dd07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre * * @param input the set of expression that should be passed to the script. * @param script the command that should be executed. - * @param output the attributes that are produced by the script. + * @param outputBeforeConstraints the attributes that are produced by the script. * @param ioschema the input and output schema applied in the execution of the script. */ case class ScriptTransformation( input: Seq[Expression], script: String, - output: Seq[Attribute], + outputBeforeConstraints: Seq[Attribute], child: LogicalPlan, ioschema: ScriptInputOutputSchema) extends UnaryNode { override def references: AttributeSet = AttributeSet(input.flatMap(_.references)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8c7a5e51b5279..71405833ee42f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -33,11 +33,11 @@ import org.apache.spark.sql.types._ * at the top of the logical query plan. 
*/ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output } case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override def outputBeforeConstraints: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows override lazy val resolved: Boolean = { @@ -91,7 +91,7 @@ case class Generate( override def producedAttributes: AttributeSet = AttributeSet(generatorOutput) - def output: Seq[Attribute] = { + def outputBeforeConstraints: Seq[Attribute] = { val qualified = qualifier.map(q => // prepend the new qualifier to the existed one generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers)) @@ -103,16 +103,7 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode with PredicateHelper { - - private def notNulls = - constructIsNotNullConstraints(splitConjunctivePredicates(condition).toSet) - .filter(x => x.isInstanceOf[IsNotNull] && x.references.nonEmpty) - .filter(_.references.subsetOf(child.outputSet)) - .flatMap(_.references.map(_.exprId)) - - override def output: Seq[Attribute] = child.output.map { o => - if (notNulls.contains(o.exprId)) o.withNullability(false) else o - } + override def outputBeforeConstraints: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -141,7 +132,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - override def output: Seq[Attribute] = + override def outputBeforeConstraints: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } @@ -175,7 +166,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. 
*/ - override def output: Seq[Attribute] = left.output + override def outputBeforeConstraints: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints @@ -206,7 +197,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } // updating nullability to make all the children consistent - override def output: Seq[Attribute] = + override def outputBeforeConstraints: Seq[Attribute] = children.map(_.output).transpose.map(attrs => attrs.head.withNullability(attrs.exists(_.nullable))) @@ -259,50 +250,18 @@ case class Join( condition: Option[Expression]) extends BinaryNode with PredicateHelper { - private lazy val notNullsFromCondition: Set[Expression] = - constructIsNotNullConstraints(condition.toSet.flatMap(splitConjunctivePredicates)) - .filter(_.references.nonEmpty) - - private lazy val leftNotNulls = { - val constraints = joinType match { - case Inner | LeftSemi => left.constraints.union(notNullsFromCondition) - case _ => left.constraints - } - constraints.filter(_.isInstanceOf[IsNotNull]) - .filter(_.references.subsetOf(left.outputSet)) - .flatMap(_.references.map(_.exprId)) - } - - private lazy val notNullLeftOutput = left.output.map { o => - if (leftNotNulls.contains(o.exprId)) o.withNullability(false) else o - } - - private lazy val rightNotNulls = { - val constraints = joinType match { - case Inner => right.constraints.union(notNullsFromCondition) - case _ => right.constraints - } - constraints.filter(_.isInstanceOf[IsNotNull]) - .filter(_.references.subsetOf(right.outputSet)) - .flatMap(_.references.map(_.exprId)) - } - - private lazy val notNullRightOutput = right.output.map { o => - if (rightNotNulls.contains(o.exprId)) o.withNullability(false) else o - } - - override def output: Seq[Attribute] = { + override def outputBeforeConstraints: Seq[Attribute] = { joinType match { case LeftSemi => - notNullLeftOutput + left.output case LeftOuter => - notNullLeftOutput ++ right.output.map(_.withNullability(true)) + left.output ++ right.output.map(_.withNullability(true)) case RightOuter => - left.output.map(_.withNullability(true)) ++ notNullRightOutput + left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case _ => - notNullLeftOutput ++ notNullRightOutput + left.output ++ right.output } } @@ -351,7 +310,7 @@ case class Join( * A hint for the optimizer that we should broadcast the `child` if used in a join operator. */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output // We manually set statistics of BroadcastHint to smallest value to make sure // the plan wrapped by BroadcastHint will be considered to broadcast later. @@ -367,7 +326,7 @@ case class InsertIntoTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { @@ -386,13 +345,13 @@ case class InsertIntoTable( * value is the CTE definition. 
*/ case class With(child: LogicalPlan, cteRelations: Map[String, SubqueryAlias]) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output } case class WithWindowDefinition( windowDefinitions: Map[String, WindowSpecDefinition], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output } /** @@ -405,7 +364,7 @@ case class Sort( order: Seq[SortOrder], global: Boolean, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows } @@ -422,7 +381,7 @@ case class Range( end: Long, step: Long, numSlices: Int, - output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + outputBeforeConstraints: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { require(step != 0, "step cannot be 0") val numElements: BigInt = { val safeStart = BigInt(start) @@ -459,7 +418,7 @@ case class Aggregate( !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def outputBeforeConstraints: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows override def validConstraints: Set[Expression] = @@ -480,7 +439,7 @@ case class Window( orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = + override def outputBeforeConstraints: Seq[Attribute] = child.output ++ windowExpressions.map(_.toAttribute) def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) @@ -552,12 +511,12 @@ private[sql] object Expand { * a input row. * * @param projections to apply - * @param output of all projections. + * @param outputBeforeConstraints of all projections. * @param child operator. */ case class Expand( projections: Seq[Seq[Expression]], - output: Seq[Attribute], + outputBeforeConstraints: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def references: AttributeSet = @@ -589,7 +548,7 @@ case class GroupingSets( child: LogicalPlan, aggregations: Seq[NamedExpression]) extends UnaryNode { - override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + override def outputBeforeConstraints: Seq[Attribute] = aggregations.map(_.toAttribute) // Needs to be unresolved before its translated to Aggregate + Expand because output attributes // will change in analysis. 
@@ -602,12 +561,14 @@ case class Pivot( pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override def outputBeforeConstraints: Seq[Attribute] = + groupByExprs.map(_.toAttribute) ++ aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } - } } object Limit { @@ -624,7 +585,7 @@ object Limit { } case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) @@ -639,7 +600,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) @@ -655,7 +616,8 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) + override def outputBeforeConstraints: Seq[Attribute] = + child.output.map(_.withQualifiers(alias :: Nil)) } /** @@ -677,7 +639,7 @@ case class Sample( child: LogicalPlan)( val isTableSample: java.lang.Boolean = false) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output override def statistics: Statistics = { val ratio = upperBound - lowerBound @@ -697,7 +659,7 @@ case class Sample( */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def maxRows: Option[Long] = child.maxRows - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output } /** @@ -708,7 +670,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { */ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output } /** @@ -716,7 +678,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) */ case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) - override def output: Seq[Attribute] = Nil + override def outputBeforeConstraints: Seq[Attribute] = Nil /** * Computes [[Statistics]] for this plan. 
The default implementation assumes the output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 47b34d1fa2e49..b0e45d7d09ab0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -37,7 +37,7 @@ private[sql] case class DescribeFunction( isExtended: Boolean) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty - override val output: Seq[Attribute] = Seq( + override val outputBeforeConstraints: Seq[Attribute] = Seq( AttributeReference("function_desc", StringType, nullable = false)()) } @@ -48,6 +48,6 @@ private[sql] case class DescribeFunction( private[sql] case class ShowFunctions( db: Option[String], pattern: Option[String]) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty - override val output: Seq[Attribute] = Seq( + override val outputBeforeConstraints: Seq[Attribute] = Seq( AttributeReference("function", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index da7f81c785461..2e622b3884191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,7 +30,7 @@ trait ObjectOperator extends LogicalPlan { /** The serializer that is used to produce the output of this operator. */ def serializer: Seq[NamedExpression] - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def outputBeforeConstraints: Seq[Attribute] = serializer.map(_.toAttribute) /** * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects. @@ -117,7 +117,7 @@ case class AppendColumns( serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = child.output ++ newColumns + override def outputBeforeConstraints: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index a5bdee1b854ce..7878a2f95852c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrd * result have expectations about the distribution and ordering of partitioned input data. 
*/ abstract class RedistributeData extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output } case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 1b297525bdafb..8cde005717dcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -88,7 +88,7 @@ case class TestFunction( case class UnresolvedTestPlan() extends LeafNode { override lazy val resolved = false - override def output: Seq[Attribute] = Nil + override def outputBeforeConstraints: Seq[Attribute] = Nil } class AnalysisErrorSuite extends AnalysisTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 6a188e7e55126..24f7fc73d4de6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -35,7 +35,7 @@ case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFall case class ComplexPlan(exprs: Seq[Seq[Expression]]) extends org.apache.spark.sql.catalyst.plans.logical.LeafNode { - override def output: Seq[Attribute] = Nil + override def outputBeforeConstraints: Seq[Attribute] = Nil } case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index e97c6be7f177a..d34966b0eee67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -72,7 +72,7 @@ object RDDConversions { /** Logical plan node for scanning data from an RDD. */ private[sql] case class LogicalRDD( - output: Seq[Attribute], + outputBeforeConstraints: Seq[Attribute], rdd: RDD[InternalRow])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { @@ -99,7 +99,7 @@ private[sql] case class LogicalRDD( /** Physical plan node for scanning data from an RDD. */ private[sql] case class PhysicalRDD( - output: Seq[Attribute], + override val output: Seq[Attribute], rdd: RDD[InternalRow], override val nodeName: String) extends LeafNode { @@ -124,7 +124,7 @@ private[sql] case class PhysicalRDD( /** Physical plan node for scanning data from a relation. */ private[sql] case class DataSourceScan( - output: Seq[Attribute], + override val output: Seq[Attribute], rdd: RDD[InternalRow], @transient relation: BaseRelation, override val metadata: Map[String, String] = Map.empty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a84e180ad1dd8..396516687ea09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -30,12 +30,12 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * multiple output rows for a input row. 
* @param projections The group of expressions, all of the group expressions should * output the same schema specified bye the parameter `output` - * @param output The output Schema + * @param outputBeforeConstraints The output Schema * @param child Child operator */ case class Expand( projections: Seq[Seq[Expression]], - output: Seq[Attribute], + override val output: Seq[Attribute], child: SparkPlan) extends UnaryNode with CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 9938d2169f1c3..5419f702bafbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -44,14 +44,15 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param output the output attributes of this node, which constructed in analysis phase, - * and we can not change it, as the parent node bound with it already. + * @param outputBeforeConstraints the output attributes of this node, which are constructed in the + * analysis phase and cannot be changed, as the parent node is + * already bound to them. */ case class Generate( generator: Generator, join: Boolean, outer: Boolean, - output: Seq[Attribute], + override val output: Seq[Attribute], child: SparkPlan) extends UnaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index f8aec9e7a1d1b..803bfe7bf60e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * Physical plan node for scanning data from a local collection. */ private[sql] case class LocalTableScan( - output: Seq[Attribute], + override val output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { private[sql] override lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index e04683c499a32..8327dac988df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -42,6 +42,9 @@ import org.apache.spark.util.ThreadUtils */ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { + // Constraints are not used on SparkPlan, so we leave this empty. + override protected def outputBeforeConstraints: Seq[Attribute] = Seq.empty[Attribute] + /** * A handle to the SQL Context that was used to create this plan. 
Since many operators need * access to the sqlContext for RDD operations or configuration this field is automatically diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6e2a5aa4f97c7..10ba5e9eb7c3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -201,7 +201,7 @@ case class Range( step: Long, numSlices: Int, numElements: BigInt, - output: Seq[Attribute]) + override val output: Seq[Attribute]) extends LeafNode with CodegenSupport { private[sql] override lazy val metrics = Map( @@ -394,7 +394,7 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { * (hopefully structurally equivalent) tree from a different optimization sequence into an already * resolved tree. */ -case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { +case class OutputFaker(override val output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil protected override def doExecute(): RDD[InternalRow] = child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 1f964b1fc1dce..8b9805ed6f9a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -55,7 +55,7 @@ private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( - output: Seq[Attribute], + override val outputBeforeConstraints: Seq[Attribute], useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 44b07e4613263..cc1f3a83824fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ * wrapped in `ExecutedCommand` during execution. 
*/ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty def run(sqlContext: SQLContext): Seq[Row] } @@ -212,7 +212,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm (keyValueOutput, runFunc) } - override val output: Seq[Attribute] = _output + override val outputBeforeConstraints: Seq[Attribute] = _output override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) @@ -226,7 +226,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm */ case class ExplainCommand( logicalPlan: LogicalPlan, - override val output: Seq[Attribute] = + override val outputBeforeConstraints: Seq[Attribute] = Seq(AttributeReference("plan", StringType, nullable = true)()), extended: Boolean = false) extends RunnableCommand { @@ -264,7 +264,7 @@ case class CacheTableCommand( Seq.empty[Row] } - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty } @@ -275,7 +275,7 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { Seq.empty[Row] } - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty } /** @@ -288,13 +288,13 @@ case object ClearCacheCommand extends RunnableCommand { Seq.empty[Row] } - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty } case class DescribeCommand( table: TableIdentifier, - override val output: Seq[Attribute], + override val outputBeforeConstraints: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { @@ -319,7 +319,7 @@ case class DescribeCommand( case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { // The result of SHOW TABLES has two columns, tableName and isTemporary. 
- override val output: Seq[Attribute] = { + override val outputBeforeConstraints: Seq[Attribute] = { val schema = StructType( StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) @@ -347,7 +347,7 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma * TODO currently we are simply ignore the db */ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { - override val output: Seq[Attribute] = { + override val outputBeforeConstraints: Seq[Attribute] = { val schema = StructType( StructField("function", StringType, nullable = false) :: Nil) @@ -380,7 +380,7 @@ case class DescribeFunction( functionName: String, isExtended: Boolean) extends RunnableCommand { - override val output: Seq[Attribute] = { + override val outputBeforeConstraints: Seq[Attribute] = { val schema = StructType( StructField("function_desc", StringType, nullable = false) :: Nil) @@ -421,5 +421,5 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { Seq.empty[Row] } - override val output: Seq[Attribute] = Seq.empty + override val outputBeforeConstraints: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 0e0748ff32df3..d151bb3234e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -35,7 +35,7 @@ case class LogicalRelation( metastoreTableIdentifier: Option[TableIdentifier] = None) extends LeafNode with MultiInstanceRelation { - override val output: Seq[AttributeReference] = { + override val outputBeforeConstraints: Seq[AttributeReference] = { val attrs = relation.schema.toAttributes expectedOutputAttributes.map { expectedAttrs => assert(expectedAttrs.length == attrs.length) @@ -49,6 +49,8 @@ case class LogicalRelation( }.getOrElse(attrs) } + override def output: Seq[AttributeReference] = outputBeforeConstraints + // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 7ca0e8859a03e..c03efa5dfa150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -39,7 +39,7 @@ case class DescribeCommand( override def children: Seq[LogicalPlan] = Seq.empty - override val output: Seq[Attribute] = Seq( + override val outputBeforeConstraints: Seq[Attribute] = Seq( // Column names are based on Hive. 
AttributeReference("col_name", StringType, nullable = false, new MetadataBuilder().putString("comment", "name of the column").build())(), @@ -65,7 +65,7 @@ case class CreateTableUsing( allowExisting: Boolean, managedIfNoPath: Boolean) extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty } @@ -84,7 +84,7 @@ case class CreateTableUsingAsSelect( mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends logical.UnaryNode { - override def output: Seq[Attribute] = Seq.empty[Attribute] + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty[Attribute] } case class CreateTempTableUsing( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 034bf152620de..30afedc1ec463 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -70,7 +70,7 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { - def output: Seq[Attribute] = child.output + override def outputBeforeConstraints: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { def zero(initialValue: HashSet[String]): HashSet[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index 79e4491026b65..dd3eeb90f427b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -40,7 +40,10 @@ import org.apache.spark.sql.types.{StructField, StructType} * we drain the queue to find the original input row. Note that if the Python process is way too * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) +case class BatchPythonEvaluation( + udf: PythonUDF, + override val outputBeforeConstraints: Seq[Attribute], + child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index da28ec4f53412..583d2d2421046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -44,7 +44,7 @@ case class EvaluatePython( resultAttribute: AttributeReference) extends logical.UnaryNode { - def output: Seq[Attribute] = child.output :+ resultAttribute + def outputBeforeConstraints: Seq[Attribute] = child.output :+ resultAttribute // References should not include the produced attribute. 
override def references: AttributeSet = udf.references diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index e35c444348f48..50e186c8499c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -29,6 +29,7 @@ object StreamingRelation { * Used to link a streaming [[Source]] of data into a * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ -case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode { +case class StreamingRelation(source: Source, outputBeforeConstraints: Seq[Attribute]) + extends LeafNode { override def toString: String = source.toString } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index b1987c690811d..c4e8232c80b15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.SharedSQLContext -case class FastOperator(output: Seq[Attribute]) extends SparkPlan { +case class FastOperator(override val output: Seq[Attribute]) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b6c78691e4827..6eb04d07e1239 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -772,7 +772,7 @@ private[hive] case class InsertIntoHiveTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -944,7 +944,7 @@ private[hive] case class MetastoreRelation( /** Non-partitionKey attributes */ val attributes = table.schema.map(_.toAttribute) - val output = attributes ++ partitionKeys + override val outputBeforeConstraints = attributes ++ partitionKeys /** An attribute map that can be used to lookup original attributes based on expression id. 
*/ val attributeMap = AttributeMap(output.map(o => (o, o))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 00fc8af5781ad..07111a9d281aa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -50,7 +50,7 @@ import org.apache.spark.sql.AnalysisException */ private[hive] case object NativePlaceholder extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq.empty - override def output: Seq[Attribute] = Seq.empty + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty } private[hive] case class CreateTableAsSelect( @@ -58,7 +58,7 @@ private[hive] case class CreateTableAsSelect( child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty[Attribute] override lazy val resolved: Boolean = tableDesc.name.database.isDefined && tableDesc.schema.nonEmpty && @@ -74,7 +74,7 @@ private[hive] case class CreateViewAsSelect( allowExisting: Boolean, replace: Boolean, sql: String) extends UnaryNode with Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] + override def outputBeforeConstraints: Seq[Attribute] = Seq.empty[Attribute] override lazy val resolved: Boolean = false } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index f3446a364b9fa..071a7e4b2dde9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -445,7 +445,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case class SQLTable( database: String, table: String, - output: Seq[Attribute], + override val outputBeforeConstraints: Seq[Attribute], sample: Option[(Double, Double)] = None) extends LeafNode { def withSample(lowerBound: Double, upperBound: Double): SQLTable = this.copy(sample = Some(lowerBound -> upperBound)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 4ffd868242b86..83cec80ae4893 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -47,7 +47,7 @@ case class InsertIntoHiveTable( @transient private lazy val hiveContext = new Context(sc.hiveconf) @transient private lazy val catalog = sc.sessionState.catalog - def output: Seq[Attribute] = Seq.empty + override def output: Seq[Attribute] = Seq.empty private def saveAsHiveFile( rdd: RDD[InternalRow], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 62e7c1223cd96..6483ed829494e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -54,7 +54,7 @@ private[hive] case class ScriptTransformation( input: Seq[Expression], script: String, - output: Seq[Attribute], + override val output: Seq[Attribute], child: SparkPlan, ioschema: 
HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { From 13ad1721fa2f8af2e8701e743633bfacb47484f4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Mar 2016 08:54:57 +0000 Subject: [PATCH 16/17] Reset nullability for InMemoryScans. --- .../spark/sql/execution/SparkPlanner.scala | 17 ++++++++++++----- .../sql/hive/execution/HiveComparisonTest.scala | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 5296175960e3e..c3e8d9781dd10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -60,14 +60,21 @@ class SparkPlanner( * provided `scanBuilder` function so that it can avoid unnecessary column materialization. */ def pruneFilterProject( - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], + inputProjectList: Seq[NamedExpression], + inputFilterPredicates: Seq[Expression], prunePushedDownFilters: Seq[Expression] => Seq[Expression], scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - val projectSet = AttributeSet(projectList.flatMap(_.references).map(_.withNullability(true))) - val filterSet = - AttributeSet(filterPredicates.flatMap(_.references).map(_.withNullability(true))) + val projectList = inputProjectList.map { _.transform { + case a: Attribute => a.withNullability(true) + }}.asInstanceOf[Seq[NamedExpression]] + + val filterPredicates = inputFilterPredicates.map { _.transform { + case a: Attribute => a.withNullability(true) + }} + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition: Option[Expression] = prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cfca93bbf0659..73adfbfa1d966 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -480,7 +480,7 @@ abstract class HiveComparisonTest val executions = queryList.map(new TestHive.QueryExecution(_)) executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { - case (q, e) => e.sparkPlan.collect { + case (q, e) => e.executedPlan.collect { case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => (q, e, i) } From c6dabcd38f4822e19f989b2196d52ea9e806af3b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 27 Mar 2016 07:10:13 +0000 Subject: [PATCH 17/17] Check attribute resolved status. 
--- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 477b27225be48..846007f1e2633 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -26,10 +26,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def output: Seq[Attribute] = { val isNotNulls = constraints.collect { - case IsNotNull(e) if e.isInstanceOf[Attribute] => e.asInstanceOf[Attribute].exprId + case IsNotNull(a: Attribute) if a.resolved => a.exprId } outputBeforeConstraints.map { o => - if (isNotNulls.contains(o.exprId)) { + if (o.resolved && isNotNulls.contains(o.exprId)) { o.withNullability(false) } else { o @@ -80,7 +80,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // Second, we infer additional constraints from non-nullable attributes that are part of the // operator's output - val nonNullableAttributes = output.filterNot(_.nullable) + val nonNullableAttributes = outputBeforeConstraints.filter(_.resolved).filterNot(_.nullable) isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet isNotNullConstraints -- constraints