Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-30353][SQL] Add IsNotNull check in SimplifyBinaryComparison optimization #27008

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -387,19 +387,40 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
* 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
* 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to update the statements above? IIUC, this PR just adds more checks for the non-nullable cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reminder; it seems this is not needed. I'll revert this.

*/
object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
object SimplifyBinaryComparison
extends Rule[LogicalPlan] with PredicateHelper with ConstraintHelper {

private def canSimplifyComparison(
plan: LogicalPlan, left: Expression, right: Expression): Boolean = {
if (left.semanticEquals(right)) {
if (!left.nullable && !right.nullable) {
true
} else {
// At least one operand is nullable: simplify only if this Filter's condition has an IsNotNull(left) conjunct
plan match {
case Filter(fc, _) =>
splitConjunctivePredicates(fc).exists { condition =>
condition.semanticEquals(IsNotNull(left))
}
case _ => false
}
}
} else {
false
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this?

  private def canSimplifyComparison(
      plan: LogicalPlan, left: Expression, right: Expression): Boolean = {
    if (!left.nullable && !right.nullable && left.semanticEquals(right)) {
      true
    } else {
      // We do more checks for non-nullable cases
      plan match {
        case Filter(fc, _) =>
          splitConjunctivePredicates(fc).exists { condition =>
            condition.semanticEquals(IsNotNull(left)) && condition.semanticEquals(IsNotNull(right))
          }
        case _ =>
          false
      }
    }
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine.

}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case l: LogicalPlan => l transformExpressionsUp {
// True with equality
case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
TrueLiteral
case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
case a EqualTo b if canSimplifyComparison(l, a, b) => TrueLiteral
case a GreaterThanOrEqual b if canSimplifyComparison(l, a, b) => TrueLiteral
case a LessThanOrEqual b if canSimplifyComparison(l, a, b) => TrueLiteral

// False with inequality
case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
case a GreaterThan b if canSimplifyComparison(l, a, b) => FalseLiteral
case a LessThan b if canSimplifyComparison(l, a, b) => FalseLiteral
}
}
}
Expand Down
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {
Expand All @@ -33,6 +34,8 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("Infer Filters", Once,
InferFiltersFromConstraints) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation,
ConstantFolding,
Expand All @@ -44,12 +47,15 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper
val nullableRelation = LocalRelation('a.int.withNullability(true))
val nonNullableRelation = LocalRelation('a.int.withNullability(false))

test("Preserve nullable exprs in general") {
for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
val plan = nullableRelation.where(e).analyze
val actual = Optimize.execute(plan)
val correctAnswer = plan
comparePlans(actual, correctAnswer)
test("Preserve nullable exprs when constraintPropagation is false") {
withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
val a = Symbol("a")
for (e <- Seq(a === a, a <= a, a >= a, a < a, a > a)) {
val plan = nullableRelation.where(e).analyze
val actual = Optimize.execute(plan)
val correctAnswer = plan
comparePlans(actual, correctAnswer)
}
}
}

Expand Down Expand Up @@ -122,4 +128,51 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper

comparePlans(optimized, correctAnswer)
}

test("Simplify null and nonnull with filter constraints") {
  val a = Symbol("a")

  // Non-nullable input: for every self-comparison the expected optimized plan
  // is the bare relation.
  for (predicate <- Seq(a === a, a <= a, a >= a, a < a, a > a)) {
    val optimized = Optimize.execute(nonNullableRelation.where(predicate).analyze)
    comparePlans(optimized, nonNullableRelation.analyze)
  }

  // Nullable input, true-like comparisons: constraint inference adds IsNotNull,
  // so only the IsNotNull filter is expected to remain.
  for (predicate <- Seq(a === a, a <= a, a >= a)) {
    val optimized = Optimize.execute(nullableRelation.where(predicate).analyze)
    comparePlans(optimized, nullableRelation.where(a.isNotNull).analyze)
  }

  // Nullable input, false-like comparisons: the expected optimized plan is the
  // bare relation.
  for (predicate <- Seq(a < a, a > a)) {
    val optimized = Optimize.execute(nullableRelation.where(predicate).analyze)
    comparePlans(optimized, nullableRelation.analyze)
  }
}

test("Simplify nullable without constraints propagation") {
withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
val a = Symbol("a")
Seq(And(a === a, a.isNotNull),
And(a <= a, a.isNotNull),
And(a >= a, a.isNotNull)).foreach { condition =>
val plan = nullableRelation.where(condition).analyze
val actual = Optimize.execute(plan)
val correctAnswer = nullableRelation.where('a.isNotNull).analyze
comparePlans(actual, correctAnswer)
}

Seq(And(a < a, a.isNotNull), And(a > a, a.isNotNull))
.foreach { condition =>
val plan = nullableRelation.where(condition).analyze
val actual = Optimize.execute(plan)
val correctAnswer = nullableRelation.analyze
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be nullableRelation.where(false)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The false filter has been removed by PruneFilters, so the result is just an empty LocalRelation.

comparePlans(actual, correctAnswer)
}
}
}
}