From 8d050df82225b36d06f5b29fff9df2a49b39f551 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 19 Jan 2016 00:48:17 -0800 Subject: [PATCH 01/13] initial framework --- .../spark/sql/catalyst/plans/QueryPlan.scala | 30 +++++++++++++++-- .../catalyst/plans/logical/LogicalPlan.scala | 6 +++- .../plans/logical/basicOperators.scala | 32 +++++++++++++++++++ .../sql/catalyst/plans/LogicalPlanSuite.scala | 13 ++++++++ 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b43b7ee71e7aa..a2c424dccb38f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,15 +17,41 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { +abstract class QueryPlan[PlanType <: TreeNode[PlanType]] + extends TreeNode[PlanType] with PredicateHelper { self: PlanType => def output: Seq[Attribute] + /** + * Extracts the output property from a given child. + */ + def extractConstraintFromChild(child: QueryPlan[PlanType]): Option[Expression] = { + child.constraint.flatMap { predicate => + val conjunctivePredicates = splitConjunctivePredicates(predicate) + conjunctivePredicates.flatMap { p => + if (p.references.subsetOf(outputSet)) { + // We only keep predicates that are composed by attributes in the outputSet. + Some(p) + } else { + None + } + }.reduceOption(And) + } + } + + /** + * An expression that describes the data property of the output rows of this operator. + * For example, if the output of this operator is column `a`, an example `constraint` + * expression can be `a > 10`. + */ + def constraint: Option[Expression] = None + + /** * Returns the set of attributes that are output by this node. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6d859551f8c52..36b18c32c05dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -301,10 +301,14 @@ abstract class LeafNode extends LogicalPlan { /** * A logical plan node with single child. 
*/ -abstract class UnaryNode extends LogicalPlan { +abstract class UnaryNode extends LogicalPlan with PredicateHelper { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil + + override def constraint: Option[Expression] = { + extractConstraintFromChild(child) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f4a3d85d2a8a4..bbe783b53a2d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -37,6 +37,21 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } + + override def constraint: Option[Expression] = { + extractConstraintFromChild(child) match { + case Some(constraint) => + splitConjunctivePredicates(constraint).flatMap { p => + if (p.references.subsetOf(outputSet)) { + Some(p) + } else { + None + } + }.reduceOption(And) + case None => + None + } + } } /** @@ -88,6 +103,23 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def constraint: Option[Expression] = { + val conjunctivePredicates = splitConjunctivePredicates(condition) + val newProperty = conjunctivePredicates.flatMap { p => + if (p.references.subsetOf(outputSet)) { + Some(p) + } else { + None + } + }.reduceOption(And) + (newProperty, extractConstraintFromChild(child)) match { + case (Some(p1), Some(p2)) => Some(And(p1, p2)) + case (None, Some(p2)) => Some(p2) + case (Some(p1), None) => Some(p1) + case (None, None) => None + } + } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 455a3810c719e..5abfb9b1178ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -21,6 +21,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ /** * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly @@ -70,4 +72,15 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 1) } + + test("propagating constraint in filter") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + assert(tr.analyze.constraint.isEmpty) + assert(tr.select('a.attr).analyze.constraint.isEmpty) + assert(tr.where('a.attr > 10).analyze.constraint.get == ('a > 10)) + assert(tr.where('a.attr > 10).select('c.attr).analyze.constraint.get == ('a > 10)) + assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) + .analyze.constraint.get == And('a > 10, 'c < 100)) + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraint.isEmpty) + } } From fed48b84e4e40a58bddb1ee742bda738a0b63ab1 Mon Sep 17 00:00:00 2001 From: 
Sameer Agarwal Date: Tue, 19 Jan 2016 11:28:21 -0800 Subject: [PATCH 02/13] Initial set of constraints --- .../spark/sql/catalyst/plans/QueryPlan.scala | 10 +++++----- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicOperators.scala | 20 +++++++++++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a2c424dccb38f..f54c35e7467c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -30,7 +30,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] /** * Extracts the output property from a given child. */ - def extractConstraintFromChild(child: QueryPlan[PlanType]): Option[Expression] = { + def extractConstraintFromChild(child: QueryPlan[PlanType]): Seq[Expression] = { child.constraint.flatMap { predicate => val conjunctivePredicates = splitConjunctivePredicates(predicate) conjunctivePredicates.flatMap { p => @@ -45,11 +45,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] } /** - * An expression that describes the data property of the output rows of this operator. - * For example, if the output of this operator is column `a`, an example `constraint` - * expression can be `a > 10`. + * An sequence of expressions that describes the data property of the output rows of this + * operator. For example, if the output of this operator is column `a`, an example `constraint` + * can be `Seq(a > 10, a < 20)`. */ - def constraint: Option[Expression] = None + def constraint: Seq[Expression] = Nil /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 36b18c32c05dc..0285558551306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -306,7 +306,7 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper { override def children: Seq[LogicalPlan] = child :: Nil - override def constraint: Option[Expression] = { + override def constraint: Seq[Expression] = { extractConstraintFromChild(child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index bbe783b53a2d2..2887314f597f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -144,6 +144,15 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes Statistics(sizeInBytes = sizeInBytes) } + + override def constraint: Option[Expression] = { + (extractConstraintFromChild(left), extractConstraintFromChild(right)) match { + case (Some(p1), Some(p2)) => Some(Or(p1, p2)) + case (None, Some(p2)) => None + case (Some(p1), None) => None + case (None, None) => None + } + } } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { @@ -152,11 +161,22 @@ case class 
Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + + override def constraint: Option[Expression] = { + (extractConstraintFromChild(left), extractConstraintFromChild(right)) match { + case (Some(p1), Some(p2)) => Some(And(p1, p2)) + case (None, Some(p2)) => Some(p2) + case (Some(p1), None) => Some(p1) + case (None, None) => None + } + } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + + override def constraint: Option[Expression] = extractConstraintFromChild(left) } case class Join( From 76d27275d9cc0452da2d2fd0c765b8fc5d586474 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 19 Jan 2016 18:03:18 -0800 Subject: [PATCH 03/13] Constraint propagation in Set and Binary operators --- .../spark/sql/catalyst/plans/QueryPlan.scala | 19 +---- .../catalyst/plans/logical/LogicalPlan.scala | 4 +- .../plans/logical/basicOperators.scala | 85 +++++++++---------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 15 ++-- 4 files changed, 55 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index f54c35e7467c9..b8765a853e1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -30,27 +30,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] /** * Extracts the output property from a given child. */ - def extractConstraintFromChild(child: QueryPlan[PlanType]): Seq[Expression] = { - child.constraint.flatMap { predicate => - val conjunctivePredicates = splitConjunctivePredicates(predicate) - conjunctivePredicates.flatMap { p => - if (p.references.subsetOf(outputSet)) { - // We only keep predicates that are composed by attributes in the outputSet. - Some(p) - } else { - None - } - }.reduceOption(And) - } + def extractConstraintsFromChild(child: QueryPlan[PlanType]): Seq[Expression] = { + child.constraints.filter(_.references.subsetOf(outputSet)) } /** * An sequence of expressions that describes the data property of the output rows of this - * operator. For example, if the output of this operator is column `a`, an example `constraint` + * operator. For example, if the output of this operator is column `a`, an example `constraints` * can be `Seq(a > 10, a < 20)`. */ - def constraint: Seq[Expression] = Nil - + def constraints: Seq[Expression] = Nil /** * Returns the set of attributes that are output by this node. 
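
The `extractConstraintsFromChild` hunk above is the load-bearing rule of the whole series: a child's constraint is propagated upward only when every attribute it references is still visible in the parent's `outputSet`. A minimal, self-contained sketch of that rule, using illustrative `Attr`/`Predicate` stand-ins rather than the real Catalyst `Attribute`/`Expression` types (a later patch in this series switches the container from `Seq` to `Set`, which the sketch anticipates):

object ConstraintPruningSketch {
  final case class Attr(name: String)
  final case class Predicate(sql: String, references: Set[Attr])

  // Mirrors extractConstraintsFromChild: a child constraint survives only if
  // every attribute it references is still part of the operator's output.
  def prune(childConstraints: Set[Predicate], outputSet: Set[Attr]): Set[Predicate] =
    childConstraints.filter(_.references.subsetOf(outputSet))

  def main(args: Array[String]): Unit = {
    val a = Attr("a"); val b = Attr("b"); val c = Attr("c")
    val fromChild = Set(
      Predicate("a > 10", Set(a)),
      Predicate("b = c", Set(b, c)))
    // A projection keeping only `a` and `c` drops `b = c` but keeps `a > 10`.
    println(prune(fromChild, Set(a, c)))
  }
}

This is also why the earlier Option-of-And encoding could be dropped: once constraints are a collection of conjuncts, pruning becomes a plain filter instead of splitConjunctivePredicates plus reduceOption(And).
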
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0285558551306..6193a0f4a2cf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -306,8 +306,8 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper { override def children: Seq[LogicalPlan] = child :: Nil - override def constraint: Seq[Expression] = { - extractConstraintFromChild(child) + override def constraints: Seq[Expression] = { + extractConstraintsFromChild(child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 2887314f597f3..b8cce00891b6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -37,21 +37,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - - override def constraint: Option[Expression] = { - extractConstraintFromChild(child) match { - case Some(constraint) => - splitConjunctivePredicates(constraint).flatMap { p => - if (p.references.subsetOf(outputSet)) { - Some(p) - } else { - None - } - }.reduceOption(And) - case None => - None - } - } } /** @@ -104,25 +89,30 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def constraint: Option[Expression] = { - val conjunctivePredicates = splitConjunctivePredicates(condition) - val newProperty = conjunctivePredicates.flatMap { p => - if (p.references.subsetOf(outputSet)) { - Some(p) - } else { - None - } - }.reduceOption(And) - (newProperty, extractConstraintFromChild(child)) match { - case (Some(p1), Some(p2)) => Some(And(p1, p2)) - case (None, Some(p2)) => Some(p2) - case (Some(p1), None) => Some(p1) - case (None, None) => None - } + override def constraints: Seq[Expression] = { + val newConstraint = splitConjunctivePredicates(condition).filter( + _.references.subsetOf(outputSet)) + newConstraint.union(extractConstraintsFromChild(child)) } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) + } + + protected def leftConstraints: Seq[Expression] = extractConstraintsFromChild(left) + + protected def rightConstraints: Seq[Expression] = { + require(left.output.size == right.output.size) + val attributeRewrites = AttributeMap(left.output.zip(right.output)) + extractConstraintsFromChild(right).map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + final override lazy val resolved: Boolean = childrenResolved && left.output.length == right.output.length && @@ -145,13 +135,8 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef Statistics(sizeInBytes = sizeInBytes) } - override def constraint: Option[Expression] = { - (extractConstraintFromChild(left), 
extractConstraintFromChild(right)) match { - case (Some(p1), Some(p2)) => Some(Or(p1, p2)) - case (None, Some(p2)) => None - case (Some(p1), None) => None - case (None, None) => None - } + override def constraints: Seq[Expression] = { + leftConstraints.intersect(rightConstraints) } } @@ -162,13 +147,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override def constraint: Option[Expression] = { - (extractConstraintFromChild(left), extractConstraintFromChild(right)) match { - case (Some(p1), Some(p2)) => Some(And(p1, p2)) - case (None, Some(p2)) => Some(p2) - case (Some(p1), None) => Some(p1) - case (None, None) => None - } + override def constraints: Seq[Expression] = { + leftConstraints.union(rightConstraints) } } @@ -176,7 +156,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output - override def constraint: Option[Expression] = extractConstraintFromChild(left) + override def constraints: Seq[Expression] = leftConstraints } case class Join( @@ -200,6 +180,21 @@ case class Join( } } + override def constraints: Seq[Expression] = { + joinType match { + case LeftSemi => + extractConstraintsFromChild(left) + case LeftOuter => + extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + case RightOuter => + extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + case FullOuter => + extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + case _ => + extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + } + } + def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. 
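
The asymmetry between the set operators above is worth spelling out. Union emits rows from either child, so only constraints that hold on both sides survive (a set intersection); Intersect emits rows present in both children, so constraints from either side keep holding (a set union); Except emits a subset of the left child's rows, so exactly the left child's constraints carry over. Below is a toy sketch with plain strings standing in for expressions (hypothetical names, not the Catalyst API); the real code must additionally rewrite the right child's constraints to be expressed over the left child's attributes, which the sketch sidesteps by assuming both sides already share attribute names:

object SetOpConstraintAlgebra {
  type Constraint = String

  // Union: only constraints common to both children can be promised.
  def unionConstraints(left: Set[Constraint], right: Set[Constraint]): Set[Constraint] =
    left intersect right

  // Intersect: every surviving row satisfies both children's constraints.
  def intersectConstraints(left: Set[Constraint], right: Set[Constraint]): Set[Constraint] =
    left union right

  // Except: the output is a subset of the left child's rows.
  def exceptConstraints(left: Set[Constraint], right: Set[Constraint]): Set[Constraint] =
    left

  def main(args: Array[String]): Unit = {
    val l = Set("a > 10", "b < 5")
    val r = Set("a > 10", "c = 1")
    println(unionConstraints(l, r))     // Set(a > 10)
    println(intersectConstraints(l, r)) // Set(a > 10, b < 5, c = 1)
    println(exceptConstraints(l, r))    // Set(a > 10, b < 5)
  }
}
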
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 5abfb9b1178ce..9d41b72db7e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -75,12 +76,14 @@ class LogicalPlanSuite extends SparkFunSuite { test("propagating constraint in filter") { val tr = LocalRelation('a.int, 'b.string, 'c.int) - assert(tr.analyze.constraint.isEmpty) - assert(tr.select('a.attr).analyze.constraint.isEmpty) - assert(tr.where('a.attr > 10).analyze.constraint.get == ('a > 10)) - assert(tr.where('a.attr > 10).select('c.attr).analyze.constraint.get == ('a > 10)) + assert(tr.analyze.constraints.isEmpty) + assert(tr.select('a.attr).analyze.constraints.isEmpty) + val logicalPlan = tr.where('a.attr > 10).analyze + assert(logicalPlan.constraints == + Seq(logicalPlan.resolve(Seq('a > 10), caseInsensitiveResolution)) + assert(tr.where('a.attr > 10).select('c.attr).analyze.constraints.get == ('a > 10)) assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) - .analyze.constraint.get == And('a > 10, 'c < 100)) - assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraint.isEmpty) + .analyze.constraints.get == And('a > 10, 'c < 100)) + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) } } From 67e138d480b97f78e00229b2e407eececd56a599 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 19 Jan 2016 18:53:57 -0800 Subject: [PATCH 04/13] modify test --- .../sql/catalyst/plans/LogicalPlanSuite.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 9d41b72db7e02..2eae20809bc8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -75,15 +75,27 @@ class LogicalPlanSuite extends SparkFunSuite { } test("propagating constraint in filter") { + + def resolve(plan: LogicalPlan, constraints: Seq[String]): Seq[Expression] = { + Seq(plan.resolve(constraints.map(_.toString), caseInsensitiveResolution).get) + } + val tr = LocalRelation('a.int, 'b.string, 'c.int) assert(tr.analyze.constraints.isEmpty) assert(tr.select('a.attr).analyze.constraints.isEmpty) - val logicalPlan = tr.where('a.attr > 10).analyze + assert(tr.where('a.attr > 10).analyze.constraints.zip(Seq('a.attr > 10)) + .forall(e => e._1.semanticEquals(e._2))) + /* + assert(tr.where('a.attr > 10).analyze.constraints == resolve(tr.where('a.attr > 10).analyze, + Seq("a > 10"))) + */ +/* assert(logicalPlan.constraints == Seq(logicalPlan.resolve(Seq('a > 10), caseInsensitiveResolution)) assert(tr.where('a.attr > 10).select('c.attr).analyze.constraints.get == ('a > 10)) assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) .analyze.constraints.get == And('a > 10, 'c < 100)) 
assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) +*/ } } From 04ff99ab96957f57c74ae24835b4bfcdd27e06b8 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 20 Jan 2016 00:04:54 -0800 Subject: [PATCH 05/13] fix tests --- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicOperators.scala | 16 +++++------ .../sql/catalyst/plans/LogicalPlanSuite.scala | 28 +++++-------------- 4 files changed, 19 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b8765a853e1c6..288fdf5926b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -30,16 +30,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] /** * Extracts the output property from a given child. */ - def extractConstraintsFromChild(child: QueryPlan[PlanType]): Seq[Expression] = { + def extractConstraintsFromChild(child: QueryPlan[PlanType]): Set[Expression] = { child.constraints.filter(_.references.subsetOf(outputSet)) } /** * An sequence of expressions that describes the data property of the output rows of this * operator. For example, if the output of this operator is column `a`, an example `constraints` - * can be `Seq(a > 10, a < 20)`. + * can be `Set(a > 10, a < 20)`. */ - def constraints: Seq[Expression] = Nil + def constraints: Set[Expression] = Set.empty /** * Returns the set of attributes that are output by this node. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6193a0f4a2cf9..268eb2897a240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -306,7 +306,7 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper { override def children: Seq[LogicalPlan] = child :: Nil - override def constraints: Seq[Expression] = { + override def constraints: Set[Expression] = { extractConstraintsFromChild(child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index b8cce00891b6e..2159f149bfcd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -89,9 +89,9 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def constraints: Seq[Expression] = { + override def constraints: Set[Expression] = { val newConstraint = splitConjunctivePredicates(condition).filter( - _.references.subsetOf(outputSet)) + _.references.subsetOf(outputSet)).toSet newConstraint.union(extractConstraintsFromChild(child)) } } @@ -103,9 +103,9 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) } - protected def leftConstraints: Seq[Expression] = 
extractConstraintsFromChild(left) + protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left) - protected def rightConstraints: Seq[Expression] = { + protected def rightConstraints: Set[Expression] = { require(left.output.size == right.output.size) val attributeRewrites = AttributeMap(left.output.zip(right.output)) extractConstraintsFromChild(right).map(_ transform { @@ -135,7 +135,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef Statistics(sizeInBytes = sizeInBytes) } - override def constraints: Seq[Expression] = { + override def constraints: Set[Expression] = { leftConstraints.intersect(rightConstraints) } } @@ -147,7 +147,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override def constraints: Seq[Expression] = { + override def constraints: Set[Expression] = { leftConstraints.union(rightConstraints) } } @@ -156,7 +156,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output - override def constraints: Seq[Expression] = leftConstraints + override def constraints: Set[Expression] = leftConstraints } case class Join( @@ -180,7 +180,7 @@ case class Join( } } - override def constraints: Seq[Expression] = { + override def constraints: Set[Expression] = { joinType match { case LeftSemi => extractConstraintsFromChild(left) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 2eae20809bc8b..3bde69050cd0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ /** * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly @@ -75,27 +74,14 @@ class LogicalPlanSuite extends SparkFunSuite { } test("propagating constraint in filter") { - - def resolve(plan: LogicalPlan, constraints: Seq[String]): Seq[Expression] = { - Seq(plan.resolve(constraints.map(_.toString), caseInsensitiveResolution).get) - } - val tr = LocalRelation('a.int, 'b.string, 'c.int) + def resolveColumn(columnName: String): Expression = + tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get assert(tr.analyze.constraints.isEmpty) assert(tr.select('a.attr).analyze.constraints.isEmpty) - assert(tr.where('a.attr > 10).analyze.constraints.zip(Seq('a.attr > 10)) - .forall(e => e._1.semanticEquals(e._2))) - /* - assert(tr.where('a.attr > 10).analyze.constraints == resolve(tr.where('a.attr > 10).analyze, - Seq("a > 10"))) - */ -/* - assert(logicalPlan.constraints == - Seq(logicalPlan.resolve(Seq('a > 10), caseInsensitiveResolution)) - assert(tr.where('a.attr > 
10).select('c.attr).analyze.constraints.get == ('a > 10)) - assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) - .analyze.constraints.get == And('a > 10, 'c < 100)) + assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn("a") > 10)) assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) -*/ + assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) + .analyze.constraints == Set(resolveColumn("a") > 10, resolveColumn("c") < 100)) } } From 7bde51d811fef7b422e31864a0613413ce523f28 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 25 Jan 2016 11:35:12 -0800 Subject: [PATCH 06/13] outstanding changes --- .../plans/logical/basicOperators.scala | 6 ++- .../plans/ConstraintPropagationSuite.scala | 49 +++++++++++++++++++ .../sql/catalyst/plans/LogicalPlanSuite.scala | 16 +----- 3 files changed, 55 insertions(+), 16 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 2159f149bfcd2..ed1df98f2e8b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -107,7 +107,8 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar protected def rightConstraints: Set[Expression] = { require(left.output.size == right.output.size) - val attributeRewrites = AttributeMap(left.output.zip(right.output)) + val attributeRewrites = AttributeMap(right.output.zip(left.output)) + println(extractConstraintsFromChild(right), attributeRewrites) extractConstraintsFromChild(right).map(_ transform { case a: Attribute => attributeRewrites(a) }) @@ -136,6 +137,8 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef } override def constraints: Set[Expression] = { + println("left", leftConstraints) + println("right", rightConstraints) leftConstraints.intersect(rightConstraints) } } @@ -182,6 +185,7 @@ case class Join( override def constraints: Set[Expression] = { joinType match { + case Inner => case LeftSemi => extractConstraintsFromChild(left) case LeftOuter => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala new file mode 100644 index 0000000000000..f656db217612a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class ConstraintPropagationSuite extends SparkFunSuite { + + private def resolveColumn(tr: LocalRelation, columnName: String): Expression = + tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + + test("propagating constraints in filter") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + assert(tr.analyze.constraints.isEmpty) + assert(tr.select('a.attr).analyze.constraints.isEmpty) + assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn(tr, "a") > 10)) + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) + .analyze.constraints == Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100)) + } + + test("propagating constraints in union") { + val tr1 = LocalRelation('a.int, 'b.string, 'c.int) + val tr2 = LocalRelation('a.int, 'b.string, 'c.int) + assert(tr1.analyze.constraints.isEmpty && tr2.analyze.constraints.isEmpty) + assert(tr1.where('a.attr > 10).unionAll(tr2.where('a.attr > 10)) + .analyze.constraints == Set(resolveColumn(tr1, "a") > 10)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 3bde69050cd0f..455a3810c719e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ /** * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly @@ -72,16 +70,4 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 1) } - - test("propagating constraint in filter") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) - def resolveColumn(columnName: String): Expression = - tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get - assert(tr.analyze.constraints.isEmpty) - assert(tr.select('a.attr).analyze.constraints.isEmpty) - assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn("a") > 10)) - assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) - assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) - .analyze.constraints == Set(resolveColumn("a") > 10, resolveColumn("c") < 100)) - } } From 0c4c78bfc3fb71479dfce282177e184f95354a2c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 25 Jan 2016 13:17:42 -0800 Subject: [PATCH 07/13] support union with multiple children --- .../catalyst/plans/logical/basicOperators.scala | 16 +++++++++++++--- 1 file changed, 13 
insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d71e3ac7d4085..9d82b8fc6e318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -176,10 +176,20 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { Statistics(sizeInBytes = sizeInBytes) } + def rewriteConstraints( + planA: LogicalPlan, + planB: LogicalPlan, + constraints: Set[Expression]): Set[Expression] = { + require(planA.output.size == planB.output.size) + val attributeRewrites = AttributeMap(planB.output.zip(planA.output)) + constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + override def constraints: Set[Expression] = { - println("left", leftConstraints) - println("right", rightConstraints) - leftConstraints.intersect(rightConstraints) + children.map(child => rewriteConstraints(children.head, child, + extractConstraintsFromChild(child))).reduce(_ intersect _) } } From 7fb2f9ce01aafe45d720212cb13120eb693a6c06 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 25 Jan 2016 23:53:56 -0800 Subject: [PATCH 08/13] join propagation --- .../plans/logical/basicOperators.scala | 47 ++++++++++++++----- .../plans/ConstraintPropagationSuite.scala | 28 ++++++++--- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9d82b8fc6e318..720afc28ea03e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -97,27 +97,24 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + final override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } - override def output: Seq[Attribute] = - left.output.zip(right.output).map { case (leftAttr, rightAttr) => - leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) - } + override def extractConstraintsFromChild(child: QueryPlan[LogicalPlan]): Set[Expression] = { + child.constraints.filter(_.references.subsetOf(child.outputSet)) + } protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left) protected def rightConstraints: Set[Expression] = { require(left.output.size == right.output.size) val attributeRewrites = AttributeMap(right.output.zip(left.output)) - println(extractConstraintsFromChild(right), attributeRewrites) extractConstraintsFromChild(right).map(_ transform { case a: Attribute => attributeRewrites(a) }) } - - final override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } private[sql] object SetOperation { @@ -176,6 +173,10 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { Statistics(sizeInBytes = sizeInBytes) } + override def 
extractConstraintsFromChild(child: QueryPlan[LogicalPlan]): Set[Expression] = { + child.constraints.filter(_.references.subsetOf(child.outputSet)) + } + def rewriteConstraints( planA: LogicalPlan, planB: LogicalPlan, @@ -214,11 +215,35 @@ case class Join( } } + def extractNullabilityConstraintsFromJoinCondition(): Set[Expression] = { + var constraints = Set.empty[Expression] + if (condition.isDefined) { + splitConjunctivePredicates(condition.get).foreach { + case EqualTo(l, r) => + constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + case GreaterThan(l, r) => + constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + case GreaterThanOrEqual(l, r) => + constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + case LessThan(l, r) => + constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + case LessThanOrEqual(l, r) => + constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + } + } + // Currently we only propagate constraints if the condition consists of equality + // and ranges. For all other cases, we return an empty set of constraints + constraints + } + override def constraints: Set[Expression] = { joinType match { case Inner => + extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + .union(extractNullabilityConstraintsFromJoinCondition()) case LeftSemi => extractConstraintsFromChild(left) + .union(extractNullabilityConstraintsFromJoinCondition()) case LeftOuter => extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) case RightOuter => @@ -226,7 +251,7 @@ case class Join( case FullOuter => extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) case _ => - extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + Set.empty } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index f656db217612a..7dde726f8f4a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -29,7 +29,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get - test("propagating constraints in filter") { + test("propagating constraints in filter/project") { val tr = LocalRelation('a.int, 'b.string, 'c.int) assert(tr.analyze.constraints.isEmpty) assert(tr.select('a.attr).analyze.constraints.isEmpty) @@ -40,10 +40,26 @@ class ConstraintPropagationSuite extends SparkFunSuite { } test("propagating constraints in union") { - val tr1 = LocalRelation('a.int, 'b.string, 'c.int) - val tr2 = LocalRelation('a.int, 'b.string, 'c.int) - assert(tr1.analyze.constraints.isEmpty && tr2.analyze.constraints.isEmpty) - assert(tr1.where('a.attr > 10).unionAll(tr2.where('a.attr > 10)) - .analyze.constraints == Set(resolveColumn(tr1, "a") > 10)) + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + assert(tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10) + .unionAll(tr3.where('i.attr > 10))).analyze.constraints.isEmpty) + assert(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 
10))).analyze.constraints == Set(resolveColumn(tr1, "a") > 10)) + } + + test("propagating constraints in intersect") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + assert(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)).analyze.constraints == + Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100)) + } + + test("propagating constraints in except") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + assert(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints == + Set(resolveColumn(tr1, "a") > 10)) } } From f15ef96657603b79a815853fba991835fe3ca50f Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 26 Jan 2016 18:17:32 -0800 Subject: [PATCH 09/13] support all joins --- .../plans/logical/basicOperators.scala | 40 +++++----- .../plans/ConstraintPropagationSuite.scala | 76 +++++++++++++++++-- 2 files changed, 91 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 720afc28ea03e..ede56b37e20df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -215,10 +215,13 @@ case class Join( } } - def extractNullabilityConstraintsFromJoinCondition(): Set[Expression] = { + override def constraints: Set[Expression] = { var constraints = Set.empty[Expression] - if (condition.isDefined) { - splitConjunctivePredicates(condition.get).foreach { + + // Currently we only propagate constraints if the condition consists of equality + // and ranges. For all other cases, we return an empty set of constraints + def extractIsNotNullConstraints(condition: Expression): Set[Expression] = { + splitConjunctivePredicates(condition).foreach { case EqualTo(l, r) => constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) case GreaterThan(l, r) => @@ -230,29 +233,32 @@ case class Join( case LessThanOrEqual(l, r) => constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) } + constraints } - // Currently we only propagate constraints if the condition consists of equality - // and ranges. 
For all other cases, we return an empty set of constraints - constraints - } - override def constraints: Set[Expression] = { - joinType match { - case Inner => + def extractIsNullConstraints(plan: LogicalPlan): Set[Expression] = { + constraints = constraints.union(plan.output.map(IsNull).toSet) + constraints + } + + constraints = joinType match { + case Inner if condition.isDefined => extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) - .union(extractNullabilityConstraintsFromJoinCondition()) - case LeftSemi => - extractConstraintsFromChild(left) - .union(extractNullabilityConstraintsFromJoinCondition()) - case LeftOuter => + .union(extractIsNotNullConstraints(condition.get)) + case LeftSemi if condition.isDefined => extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + .union(extractIsNotNullConstraints(condition.get)) + case LeftOuter => + extractConstraintsFromChild(left).union(extractIsNullConstraints(right)) case RightOuter => - extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + extractConstraintsFromChild(right).union(extractIsNullConstraints(left)) case FullOuter => - extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) + extractIsNullConstraints(left).union(extractIsNullConstraints(right)) case _ => Set.empty } + + constraints.filter(_.references.subsetOf(outputSet)) } def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 7dde726f8f4a2..36c63a484ae2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -29,14 +29,19 @@ class ConstraintPropagationSuite extends SparkFunSuite { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + private def verifyConstraints(a: Set[Expression], b: Set[Expression]): Unit = { + assert(a.forall(i => b.map(_.semanticEquals(i)).reduce(_ || _))) + assert(b.forall(i => a.map(_.semanticEquals(i)).reduce(_ || _))) + } + test("propagating constraints in filter/project") { val tr = LocalRelation('a.int, 'b.string, 'c.int) assert(tr.analyze.constraints.isEmpty) assert(tr.select('a.attr).analyze.constraints.isEmpty) - assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn(tr, "a") > 10)) assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) - assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) - .analyze.constraints == Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100)) + verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10)) + verifyConstraints(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) + .analyze.constraints, Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100)) } test("propagating constraints in union") { @@ -45,21 +50,76 @@ class ConstraintPropagationSuite extends SparkFunSuite { val tr3 = LocalRelation('g.int, 'h.int, 'i.int) assert(tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10) .unionAll(tr3.where('i.attr > 10))).analyze.constraints.isEmpty) - assert(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 
10) - .unionAll(tr3.where('g.attr > 10))).analyze.constraints == Set(resolveColumn(tr1, "a") > 10)) + verifyConstraints(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 10))).analyze.constraints, Set(resolveColumn(tr1, "a") > 10)) } test("propagating constraints in intersect") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('a.int, 'b.int, 'c.int) - assert(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)).analyze.constraints == - Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100)) + verifyConstraints(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)) + .analyze.constraints, Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100)) } test("propagating constraints in except") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('a.int, 'b.int, 'c.int) - assert(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints == + verifyConstraints(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints, Set(resolveColumn(tr1, "a") > 10)) } + + test("propagating constraints in inner join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), Inner, + Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-semi join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), LeftSemi, + Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + } + + test("propagating constraints in left-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), LeftOuter, + Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get), + IsNull(tr2.resolveQuoted("e", caseInsensitiveResolution).get))) + } + + test("propagating constraints in right-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), RightOuter, + Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get), + IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get))) + } + + test("propagating constraints in full-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 
'e.int).subquery('tr2) + verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), FullOuter, + Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + Set(IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get), + IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get), + IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get), + IsNull(tr2.resolveQuoted("e", caseInsensitiveResolution).get))) + } } From 53be8379ef7f2b86bcd4398361b6748abe2800d8 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 29 Jan 2016 01:36:45 -0800 Subject: [PATCH 10/13] Michael's comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 11 ++- .../catalyst/plans/logical/LogicalPlan.scala | 4 +- .../plans/logical/basicOperators.scala | 83 ++++++++++--------- .../plans/ConstraintPropagationSuite.scala | 68 ++++++++++----- 4 files changed, 99 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 288fdf5926b1d..ed32efe5d7bc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -28,18 +28,21 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] def output: Seq[Attribute] /** - * Extracts the output property from a given child. + * Extracts the relevant per-row constraints from a given child while removing those that + * don't apply anymore. */ - def extractConstraintsFromChild(child: QueryPlan[PlanType]): Set[Expression] = { + protected def getRelevantConstraints(child: QueryPlan[PlanType]): Set[Expression] = { child.constraints.filter(_.references.subsetOf(outputSet)) } /** - * An sequence of expressions that describes the data property of the output rows of this + * A sequence of expressions that describes the data property of the output rows of this * operator. For example, if the output of this operator is column `a`, an example `constraints` * can be `Set(a > 10, a < 20)`. */ - def constraints: Set[Expression] = Set.empty + lazy val constraints: Set[Expression] = validConstraints + + protected def validConstraints: Set[Expression] = Set.empty /** * Returns the set of attributes that are output by this node. 
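
The split above between the public `lazy val constraints` and the protected `validConstraints` hook is a small template-method refactoring with two payoffs: the constraint set is computed at most once per plan node and then cached, and each operator now overrides a single well-named extension point instead of redefining the cached field itself. A compact sketch of the pattern (illustrative `Plan`/`FilterLike` names, not the Catalyst classes):

object CachedConstraintsSketch {
  abstract class Plan {
    // Computed at most once per node, then cached; mirrors the
    // `lazy val constraints = validConstraints` split in the hunk above.
    lazy val constraints: Set[String] = validConstraints
    // The single protected hook that each operator overrides.
    protected def validConstraints: Set[String] = Set.empty
  }

  class LeafLike extends Plan

  class FilterLike(condition: String, child: Plan) extends Plan {
    override protected def validConstraints: Set[String] =
      child.constraints + condition
  }

  def main(args: Array[String]): Unit = {
    val plan = new FilterLike("a > 10", new LeafLike)
    println(plan.constraints) // Set(a > 10)
  }
}

Because a Scala lazy val is evaluated at most once, repeated queries against the same node (for example, from several optimizer rules) share one computation rather than re-walking the subtree each time.
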
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 268eb2897a240..791489d193dc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -306,8 +306,8 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper { override def children: Seq[LogicalPlan] = child :: Nil - override def constraints: Set[Expression] = { - extractConstraintsFromChild(child) + override protected def validConstraints: Set[Expression] = { + getRelevantConstraints(child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ede56b37e20df..cd99f6960d2c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -89,10 +89,11 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def constraints: Set[Expression] = { - val newConstraint = splitConjunctivePredicates(condition).filter( - _.references.subsetOf(outputSet)).toSet - newConstraint.union(extractConstraintsFromChild(child)) + override protected def validConstraints: Set[Expression] = { + val newConstraint = splitConjunctivePredicates(condition) + .filter(_.references.subsetOf(outputSet)) + .toSet + newConstraint.union(getRelevantConstraints(child)) } } @@ -102,16 +103,16 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar left.output.length == right.output.length && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } - override def extractConstraintsFromChild(child: QueryPlan[LogicalPlan]): Set[Expression] = { + override def getRelevantConstraints(child: QueryPlan[LogicalPlan]): Set[Expression] = { child.constraints.filter(_.references.subsetOf(child.outputSet)) } - protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left) + protected def leftConstraints: Set[Expression] = getRelevantConstraints(left) protected def rightConstraints: Set[Expression] = { require(left.output.size == right.output.size) val attributeRewrites = AttributeMap(right.output.zip(left.output)) - extractConstraintsFromChild(right).map(_ transform { + getRelevantConstraints(right).map(_ transform { case a: Attribute => attributeRewrites(a) }) } @@ -128,7 +129,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override def constraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = { leftConstraints.union(rightConstraints) } } @@ -137,7 +138,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output - override def constraints: Set[Expression] = leftConstraints + override protected def validConstraints: Set[Expression] = leftConstraints } /** Factory for constructing new `Union` nodes. 
*/ @@ -173,7 +174,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { Statistics(sizeInBytes = sizeInBytes) } - override def extractConstraintsFromChild(child: QueryPlan[LogicalPlan]): Set[Expression] = { + override def getRelevantConstraints(child: QueryPlan[LogicalPlan]): Set[Expression] = { child.constraints.filter(_.references.subsetOf(child.outputSet)) } @@ -188,9 +189,10 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { }) } - override def constraints: Set[Expression] = { - children.map(child => rewriteConstraints(children.head, child, - extractConstraintsFromChild(child))).reduce(_ intersect _) + override protected def validConstraints: Set[Expression] = { + children + .map(child => rewriteConstraints(children.head, child, getRelevantConstraints(child))) + .reduce(_ intersect _) } } @@ -215,50 +217,51 @@ case class Join( } } - override def constraints: Set[Expression] = { - var constraints = Set.empty[Expression] - + override protected def validConstraints: Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. For all other cases, we return an empty set of constraints - def extractIsNotNullConstraints(condition: Expression): Set[Expression] = { - splitConjunctivePredicates(condition).foreach { + def constructIsNotNullConstraints(condition: Expression): Set[Expression] = { + splitConjunctivePredicates(condition).map { case EqualTo(l, r) => - constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + Set(IsNotNull(l), IsNotNull(r)) case GreaterThan(l, r) => - constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + Set(IsNotNull(l), IsNotNull(r)) case GreaterThanOrEqual(l, r) => - constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + Set(IsNotNull(l), IsNotNull(r)) case LessThan(l, r) => - constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) + Set(IsNotNull(l), IsNotNull(r)) case LessThanOrEqual(l, r) => - constraints = constraints.union(Set(IsNotNull(l), IsNotNull(r))) - } - constraints + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) } - def extractIsNullConstraints(plan: LogicalPlan): Set[Expression] = { - constraints = constraints.union(plan.output.map(IsNull).toSet) - constraints + def constructIsNullConstraints(plan: LogicalPlan): Set[Expression] = { + plan.output.map(IsNull).toSet } - constraints = joinType match { + (joinType match { case Inner if condition.isDefined => - extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) - .union(extractIsNotNullConstraints(condition.get)) + getRelevantConstraints(left) + .union(getRelevantConstraints(right)) + .union(constructIsNotNullConstraints(condition.get)) case LeftSemi if condition.isDefined => - extractConstraintsFromChild(left).union(extractConstraintsFromChild(right)) - .union(extractIsNotNullConstraints(condition.get)) + getRelevantConstraints(left) + .union(getRelevantConstraints(right)) + .union(constructIsNotNullConstraints(condition.get)) case LeftOuter => - extractConstraintsFromChild(left).union(extractIsNullConstraints(right)) + getRelevantConstraints(left) + .union(constructIsNullConstraints(right)) case RightOuter => - extractConstraintsFromChild(right).union(extractIsNullConstraints(left)) + getRelevantConstraints(right) + .union(constructIsNullConstraints(left)) case FullOuter => - extractIsNullConstraints(left).union(extractIsNullConstraints(right)) + 
constructIsNullConstraints(left) + .union(constructIsNullConstraints(right)) case _ => - Set.empty - } - - constraints.filter(_.references.subsetOf(outputSet)) + Set.empty[Expression] + }).filter(_.references.subsetOf(outputSet)) } def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 36c63a484ae2f..ed03c0364f694 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -34,45 +34,63 @@ class ConstraintPropagationSuite extends SparkFunSuite { assert(b.forall(i => a.map(_.semanticEquals(i)).reduce(_ || _))) } - test("propagating constraints in filter/project") { + test("propagating constraints in filters") { val tr = LocalRelation('a.int, 'b.string, 'c.int) assert(tr.analyze.constraints.isEmpty) - assert(tr.select('a.attr).analyze.constraints.isEmpty) assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10)) - verifyConstraints(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100) - .analyze.constraints, Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100)) + verifyConstraints(tr + .where('a.attr > 10) + .select('c.attr, 'a.attr) + .where('c.attr < 100) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100)) } test("propagating constraints in union") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('d.int, 'e.int, 'f.int) val tr3 = LocalRelation('g.int, 'h.int, 'i.int) - assert(tr1.where('a.attr > 10).unionAll(tr2.where('e.attr > 10) - .unionAll(tr3.where('i.attr > 10))).analyze.constraints.isEmpty) - verifyConstraints(tr1.where('a.attr > 10).unionAll(tr2.where('d.attr > 10) - .unionAll(tr3.where('g.attr > 10))).analyze.constraints, Set(resolveColumn(tr1, "a") > 10)) + assert(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('e.attr > 10) + .unionAll(tr3.where('i.attr > 10))) + .analyze.constraints.isEmpty) + verifyConstraints(tr1 + .where('a.attr > 10) + .unionAll(tr2.where('d.attr > 10) + .unionAll(tr3.where('g.attr > 10))) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10)) } test("propagating constraints in intersect") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('a.int, 'b.int, 'c.int) - verifyConstraints(tr1.where('a.attr > 10).intersect(tr2.where('b.attr < 100)) - .analyze.constraints, Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100)) + verifyConstraints(tr1 + .where('a.attr > 10) + .intersect(tr2.where('b.attr < 100)) + .analyze.constraints, + Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100)) } test("propagating constraints in except") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('a.int, 'b.int, 'c.int) - verifyConstraints(tr1.where('a.attr > 10).except(tr2.where('b.attr < 100)).analyze.constraints, + verifyConstraints(tr1 + .where('a.attr > 10) + .except(tr2.where('b.attr < 100)) + .analyze.constraints, Set(resolveColumn(tr1, "a") > 10)) } test("propagating constraints in inner join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 
'e.int).subquery('tr2) - verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), Inner, - Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), @@ -82,8 +100,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { test("propagating constraints in left-semi join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), LeftSemi, - Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) } @@ -91,8 +111,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { test("propagating constraints in left-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), LeftOuter, - Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get), @@ -102,8 +124,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { test("propagating constraints in right-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), RightOuter, - Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get), @@ -113,8 +137,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { test("propagating constraints in full-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1.where('a.attr > 10).join(tr2.where('d.attr < 100), FullOuter, - Some("tr1.a".attr === "tr2.a".attr)).analyze.constraints, + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, Set(IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get), IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get), From 302444f5d7940c1e0327fe4826453c1b628a97b9 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 29 Jan 2016 
18:34:04 -0800 Subject: [PATCH 11/13] Michael's comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 19 +++-- .../catalyst/plans/logical/LogicalPlan.scala | 4 +- .../plans/logical/basicOperators.scala | 75 +++++++------------ .../plans/ConstraintPropagationSuite.scala | 21 +----- 4 files changed, 48 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ed32efe5d7bc2..e50a9ae043a40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -28,11 +28,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] def output: Seq[Attribute] /** - * Extracts the relevant per-row constraints from a given child while removing those that - * don't apply anymore. + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. */ - protected def getRelevantConstraints(child: QueryPlan[PlanType]): Set[Expression] = { - child.constraints.filter(_.references.subsetOf(outputSet)) + private def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints.filter(_.references.subsetOf(outputSet)) } /** @@ -40,8 +40,14 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] * operator. For example, if the output of this operator is column `a`, an example `constraints` * can be `Set(a > 10, a < 20)`. */ - lazy val constraints: Set[Expression] = validConstraints + lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints) + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]] + */ protected def validConstraints: Set[Expression] = Set.empty /** @@ -77,6 +83,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { @@ -85,6 +92,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] /** * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { @@ -117,6 +125,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] /** * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. 
* @return */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 791489d193dc0..80dc8a43ef843 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -306,9 +306,7 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper { override def children: Seq[LogicalPlan] = child :: Nil - override protected def validConstraints: Set[Expression] = { - getRelevantConstraints(child) - } + override protected def validConstraints: Set[Expression] = child.constraints } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 67df495382b40..65c135322c04f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -94,22 +94,18 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { val newConstraint = splitConjunctivePredicates(condition) .filter(_.references.subsetOf(outputSet)) .toSet - newConstraint.union(getRelevantConstraints(child)) + newConstraint.union(child.constraints) } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def getRelevantConstraints(child: QueryPlan[LogicalPlan]): Set[Expression] = { - child.constraints.filter(_.references.subsetOf(child.outputSet)) - } - - protected def leftConstraints: Set[Expression] = getRelevantConstraints(left) + protected def leftConstraints: Set[Expression] = left.constraints protected def rightConstraints: Set[Expression] = { require(left.output.size == right.output.size) val attributeRewrites = AttributeMap(right.output.zip(left.output)) - getRelevantConstraints(right).map(_ transform { + right.constraints.map(_ transform { case a: Attribute => attributeRewrites(a) }) } @@ -186,10 +182,6 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { Statistics(sizeInBytes = sizeInBytes) } - override def getRelevantConstraints(child: QueryPlan[LogicalPlan]): Set[Expression] = { - child.constraints.filter(_.references.subsetOf(child.outputSet)) - } - def rewriteConstraints( planA: LogicalPlan, planB: LogicalPlan, @@ -203,7 +195,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override protected def validConstraints: Set[Expression] = { children - .map(child => rewriteConstraints(children.head, child, getRelevantConstraints(child))) + .map(child => rewriteConstraints(children.head, child, child.constraints)) .reduce(_ intersect _) } } @@ -229,51 +221,42 @@ case class Join( } } - override protected def validConstraints: Set[Expression] = { + private def constructIsNotNullConstraints(condition: Expression): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. 
For all other cases, we return an empty set of constraints - def constructIsNotNullConstraints(condition: Expression): Set[Expression] = { - splitConjunctivePredicates(condition).map { - case EqualTo(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case _ => - Set.empty[Expression] - }.foldLeft(Set.empty[Expression])(_ union _.toSet) - } - - def constructIsNullConstraints(plan: LogicalPlan): Set[Expression] = { - plan.output.map(IsNull).toSet - } + splitConjunctivePredicates(condition).map { + case EqualTo(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) + } - (joinType match { + override protected def validConstraints: Set[Expression] = { + joinType match { case Inner if condition.isDefined => - getRelevantConstraints(left) - .union(getRelevantConstraints(right)) + left.constraints + .union(right.constraints) .union(constructIsNotNullConstraints(condition.get)) case LeftSemi if condition.isDefined => - getRelevantConstraints(left) - .union(getRelevantConstraints(right)) + left.constraints + .union(right.constraints) .union(constructIsNotNullConstraints(condition.get)) case LeftOuter => - getRelevantConstraints(left) - .union(constructIsNullConstraints(right)) + left.constraints case RightOuter => - getRelevantConstraints(right) - .union(constructIsNullConstraints(left)) + right.constraints case FullOuter => - constructIsNullConstraints(left) - .union(constructIsNullConstraints(right)) - case _ => Set.empty[Expression] - }).filter(_.references.subsetOf(outputSet)) + } } def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index ed03c0364f694..31995c3c8ad08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -115,10 +115,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, - IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), - IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get), - IsNull(tr2.resolveQuoted("e", caseInsensitiveResolution).get))) + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10)) } test("propagating constraints in right-outer join") { @@ -128,24 +125,14 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, - IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), - 
IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get), - IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get))) + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100)) } test("propagating constraints in full-outer join") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - verifyConstraints(tr1 - .where('a.attr > 10) + assert(tr1.where('a.attr > 10) .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) - .analyze.constraints, - Set(IsNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), - IsNull(tr1.resolveQuoted("b", caseInsensitiveResolution).get), - IsNull(tr1.resolveQuoted("c", caseInsensitiveResolution).get), - IsNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), - IsNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get), - IsNull(tr2.resolveQuoted("e", caseInsensitiveResolution).get))) + .analyze.constraints.isEmpty) } } From b52742a93581b6aa899b652cf1a346960a2888b7 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 1 Feb 2016 16:48:51 -0800 Subject: [PATCH 12/13] move constructIsNotNullConstraints in QueryPlan --- .../spark/sql/catalyst/plans/QueryPlan.scala | 26 ++++++++- .../plans/logical/basicOperators.scala | 34 +++-------- .../plans/ConstraintPropagationSuite.scala | 57 +++++++++++++++---- 3 files changed, 77 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e50a9ae043a40..546e635700739 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -31,8 +31,30 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] * Extracts the relevant constraints from a given set of constraints based on the attributes that * appear in the [[outputSet]]. */ - private def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { - constraints.filter(_.references.subsetOf(outputSet)) + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + } + + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // Currently we only propagate constraints if the condition consists of equality + // and ranges. 
For all other cases, we return an empty set of constraints + constraints.map { + case EqualTo(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case GreaterThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThan(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case LessThanOrEqual(l, r) => + Set(IsNotNull(l), IsNotNull(r)) + case _ => + Set.empty[Expression] + }.foldLeft(Set.empty[Expression])(_ union _.toSet) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 65c135322c04f..83551325fd5ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -91,10 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override protected def validConstraints: Set[Expression] = { - val newConstraint = splitConjunctivePredicates(condition) - .filter(_.references.subsetOf(outputSet)) - .toSet - newConstraint.union(child.constraints) + child.constraints.union(splitConjunctivePredicates(condition).toSet) } } @@ -221,35 +218,20 @@ case class Join( } } - private def constructIsNotNullConstraints(condition: Expression): Set[Expression] = { - // Currently we only propagate constraints if the condition consists of equality - // and ranges. For all other cases, we return an empty set of constraints - splitConjunctivePredicates(condition).map { - case EqualTo(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case _ => - Set.empty[Expression] - }.foldLeft(Set.empty[Expression])(_ union _.toSet) - } - override protected def validConstraints: Set[Expression] = { joinType match { case Inner if condition.isDefined => left.constraints .union(right.constraints) - .union(constructIsNotNullConstraints(condition.get)) + .union(splitConjunctivePredicates(condition.get).toSet) case LeftSemi if condition.isDefined => left.constraints .union(right.constraints) - .union(constructIsNotNullConstraints(condition.get)) + .union(splitConjunctivePredicates(condition.get).toSet) + case Inner => + left.constraints.union(right.constraints) + case LeftSemi => + left.constraints.union(right.constraints) case LeftOuter => left.constraints case RightOuter => @@ -259,8 +241,6 @@ case class Join( } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. 
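A rough sketch (again standalone, with hypothetical names) of the null-intolerance inference that this patch centralizes in QueryPlan's getRelevantConstraints: every equality or range comparison implies that both of its operands are non-null. The toy expression ADT below only hints at Catalyst's real Expression hierarchy.

object IsNotNullInferenceSketch {
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Lit(value: Int) extends Expr
  final case class EqualTo(left: Expr, right: Expr) extends Expr
  final case class GreaterThan(left: Expr, right: Expr) extends Expr
  final case class IsNotNull(child: Expr) extends Expr

  // Mirrors constructIsNotNullConstraints: comparisons contribute IsNotNull on
  // both sides; unrecognized shapes contribute nothing. In the patch, a later
  // references.nonEmpty filter drops IsNotNull over literals.
  def inferIsNotNull(constraints: Set[Expr]): Set[Expr] =
    constraints.flatMap {
      case EqualTo(l, r)     => Set[Expr](IsNotNull(l), IsNotNull(r))
      case GreaterThan(l, r) => Set[Expr](IsNotNull(l), IsNotNull(r))
      case _                 => Set.empty[Expr]
    }

  def main(args: Array[String]): Unit = {
    // From `a > 5` we learn IsNotNull(a); IsNotNull(Lit(5)) is also produced
    // here and would be discarded by the references check in the patch.
    println(inferIsNotNull(Set[Expr](GreaterThan(Attr("a"), Lit(5)))))
  }
}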
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 31995c3c8ad08..b5cf91394d910 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -29,49 +29,78 @@ class ConstraintPropagationSuite extends SparkFunSuite { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get - private def verifyConstraints(a: Set[Expression], b: Set[Expression]): Unit = { - assert(a.forall(i => b.map(_.semanticEquals(i)).reduce(_ || _))) - assert(b.forall(i => a.map(_.semanticEquals(i)).reduce(_ || _))) + private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { + val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) + val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _)) + if (missing.nonEmpty || extra.nonEmpty) { + fail( + s""" + |== FAIL: Constraints do not match === + |Found: ${found.mkString(",")} + |Expected: ${expected.mkString(",")} + |== Result == + |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")} + |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")} + """.stripMargin) + } } test("propagating constraints in filters") { val tr = LocalRelation('a.int, 'b.string, 'c.int) + assert(tr.analyze.constraints.isEmpty) + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) - verifyConstraints(tr.where('a.attr > 10).analyze.constraints, Set(resolveColumn(tr, "a") > 10)) + + verifyConstraints(tr + .where('a.attr > 10) + .analyze.constraints, + Set(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a")))) + verifyConstraints(tr .where('a.attr > 10) .select('c.attr, 'a.attr) .where('c.attr < 100) .analyze.constraints, - Set(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100)) + Set(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "c") < 100, + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c")))) } test("propagating constraints in union") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('d.int, 'e.int, 'f.int) val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + assert(tr1 .where('a.attr > 10) .unionAll(tr2.where('e.attr > 10) .unionAll(tr3.where('i.attr > 10))) .analyze.constraints.isEmpty) + verifyConstraints(tr1 .where('a.attr > 10) .unionAll(tr2.where('d.attr > 10) .unionAll(tr3.where('g.attr > 10))) .analyze.constraints, - Set(resolveColumn(tr1, "a") > 10)) + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) } test("propagating constraints in intersect") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + verifyConstraints(tr1 .where('a.attr > 10) .intersect(tr2.where('b.attr < 100)) .analyze.constraints, - Set(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100)) + Set(resolveColumn(tr1, "a") > 10, + resolveColumn(tr1, "b") < 100, + IsNotNull(resolveColumn(tr1, "a")), + IsNotNull(resolveColumn(tr1, "b")))) } test("propagating constraints in except") { @@ -81,7 +110,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .except(tr2.where('b.attr < 100)) .analyze.constraints, - 
Set(resolveColumn(tr1, "a") > 10)) + Set(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a")))) } test("propagating constraints in inner join") { @@ -93,8 +123,11 @@ class ConstraintPropagationSuite extends SparkFunSuite { .analyze.constraints, Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + tr1.resolveQuoted("a", caseInsensitiveResolution).get === + tr2.resolveQuoted("a", caseInsensitiveResolution).get, IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), - IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) } test("propagating constraints in left-semi join") { @@ -115,7 +148,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10)) + Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) } test("propagating constraints in right-outer join") { @@ -125,7 +159,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100)) + Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) } test("propagating constraints in full-outer join") { From 2bd27356a58eefc52b1f9529e7c708e7cbc4beee Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 2 Feb 2016 14:35:56 -0800 Subject: [PATCH 13/13] Yin's comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 8 +++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicOperators.scala | 32 +++++++++++-------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 546e635700739..05f5bdbfc0769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: TreeNode[PlanType]] - extends TreeNode[PlanType] with PredicateHelper { +abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType => def output: Seq[Attribute] @@ -38,6 +37,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) } + /** + * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions. + * For example, if an expression is of the form (`a > 5`), this returns a constraint of the form + * `isNotNull(a)`. + */ private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges.
For all other cases, we return an empty set of constraints diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 80dc8a43ef843..d8944a424156e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -301,7 +301,7 @@ abstract class LeafNode extends LogicalPlan { /** * A logical plan node with single child. */ -abstract class UnaryNode extends LogicalPlan with PredicateHelper { +abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 83551325fd5ba..8150ff8434762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -87,7 +87,8 @@ case class Generate( } } -case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { +case class Filter(condition: Expression, child: LogicalPlan) + extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override protected def validConstraints: Set[Expression] = { @@ -179,12 +180,17 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { Statistics(sizeInBytes = sizeInBytes) } - def rewriteConstraints( - planA: LogicalPlan, - planB: LogicalPlan, + /** + * Maps the constraints containing a given (original) sequence of attributes to those with a + * given (reference) sequence of attributes. Given the nature of union, we expect that the + * mapping between the original and reference sequences is symmetric.
+ */ + private def rewriteConstraints( + reference: Seq[Attribute], + original: Seq[Attribute], constraints: Set[Expression]): Set[Expression] = { - require(planA.output.size == planB.output.size) - val attributeRewrites = AttributeMap(planB.output.zip(planA.output)) + require(reference.size == original.size) + val attributeRewrites = AttributeMap(original.zip(reference)) constraints.map(_ transform { case a: Attribute => attributeRewrites(a) }) @@ -192,16 +198,17 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override protected def validConstraints: Set[Expression] = { children - .map(child => rewriteConstraints(children.head, child, child.constraints)) + .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) .reduce(_ intersect _) } } case class Join( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { joinType match { @@ -226,12 +233,11 @@ case class Join( .union(splitConjunctivePredicates(condition.get).toSet) case LeftSemi if condition.isDefined => left.constraints - .union(right.constraints) .union(splitConjunctivePredicates(condition.get).toSet) case Inner => left.constraints.union(right.constraints) case LeftSemi => - left.constraints.union(right.constraints) + left.constraints case LeftOuter => left.constraints case RightOuter =>