diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 73df8e6df2bab..2d6637438015a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -387,19 +387,42 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
  * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
  * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
+ *
+ * Inside a [[Filter]], 2) and 3) also apply to a nullable operand when `IsNotNull` of that
+ * operand is one of the top-level conjuncts of the filter condition.
  */
-object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
+object SimplifyBinaryComparison
+  extends Rule[LogicalPlan] with PredicateHelper with ConstraintHelper {
+
+  /**
+   * Returns true when comparing `left` with `right` can be replaced by a constant: the two
+   * expressions are semantically equal and either both are non-nullable or, when `plan` is a
+   * [[Filter]], `IsNotNull(left)` is one of the top-level conjuncts of its condition.
+   */
+  private def canSimplifyComparison(
+      plan: LogicalPlan, left: Expression, right: Expression): Boolean = {
+    // Extra check for nullable operands: a Filter that also requires the operand to be
+    // non-null rejects every row on which the comparison would evaluate to null.
+    def guardedByIsNotNull: Boolean = plan match {
+      case Filter(fc, _) =>
+        splitConjunctivePredicates(fc).exists(_.semanticEquals(IsNotNull(left)))
+      case _ => false
+    }
+
+    left.semanticEquals(right) &&
+      ((!left.nullable && !right.nullable) || guardedByIsNotNull)
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsUp {
+    case l: LogicalPlan => l transformExpressionsUp {
       // True with equality
       case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
-      case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
-      case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
-        TrueLiteral
-      case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+      case a EqualTo b if canSimplifyComparison(l, a, b) => TrueLiteral
+      case a GreaterThanOrEqual b if canSimplifyComparison(l, a, b) => TrueLiteral
+      case a LessThanOrEqual b if canSimplifyComparison(l, a, b) => TrueLiteral
 
       // False with inequality
-      case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
-      case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+      case a GreaterThan b if canSimplifyComparison(l, a, b) => FalseLiteral
+      case a LessThan b if canSimplifyComparison(l, a, b) => FalseLiteral
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
index 5794691a365a9..9c71cc8e0d291 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {
@@ -33,6 +34,8 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
     val batches =
       Batch("AnalysisNodes", Once,
         EliminateSubqueryAliases) ::
+      Batch("Infer Filters", Once,
+        InferFiltersFromConstraints) ::
       Batch("Constant Folding", FixedPoint(50),
         NullPropagation,
         ConstantFolding,
@@ -44,12 +47,15 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
   val nullableRelation = LocalRelation('a.int.withNullability(true))
   val nonNullableRelation = LocalRelation('a.int.withNullability(false))
 
-  test("Preserve nullable exprs in general") {
-    for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
-      val plan = nullableRelation.where(e).analyze
-      val actual = Optimize.execute(plan)
-      val correctAnswer = plan
-      comparePlans(actual, correctAnswer)
+  test("Preserve nullable exprs when constraintPropagation is false") {
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+      val a = Symbol("a")
+      for (e <- Seq(a === a, a <= a, a >= a, a < a, a > a)) {
+        val plan = nullableRelation.where(e).analyze
+        val actual = Optimize.execute(plan)
+        val correctAnswer = plan
+        comparePlans(actual, correctAnswer)
+      }
     }
   }
 
@@ -122,4 +128,51 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("Simplify null and nonnull with filter constraints") {
+    val a = Symbol("a")
+    Seq(a === a, a <= a, a >= a, a < a, a > a).foreach { condition =>
+      val plan = nonNullableRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = nonNullableRelation.analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // InferFiltersFromConstraints adds IsNotNull(a); the comparison then folds to true
+    Seq(a === a, a <= a, a >= a).foreach { condition =>
+      val plan = nullableRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = nullableRelation.where(a.isNotNull).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    Seq(a < a, a > a).foreach { condition =>
+      val plan = nullableRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = nullableRelation.analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+  test("Simplify nullable without constraints propagation") {
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+      val a = Symbol("a")
+      Seq(And(a === a, a.isNotNull),
+        And(a <= a, a.isNotNull),
+        And(a >= a, a.isNotNull)).foreach { condition =>
+        val plan = nullableRelation.where(condition).analyze
+        val actual = Optimize.execute(plan)
+        val correctAnswer = nullableRelation.where(a.isNotNull).analyze
+        comparePlans(actual, correctAnswer)
+      }
+
+      Seq(And(a < a, a.isNotNull), And(a > a, a.isNotNull))
+        .foreach { condition =>
+        val plan = nullableRelation.where(condition).analyze
+        val actual = Optimize.execute(plan)
+        val correctAnswer = nullableRelation.analyze
+        comparePlans(actual, correctAnswer)
+      }
+    }
+  }
 }