Skip to content

Commit

Permalink
[SPARK-25714][BACKPORT-2.3] Fix Null Handling in the Optimizer rule B…
Browse files Browse the repository at this point in the history
…ooleanSimplification

This PR is to backport #22702 to branch 2.3.

---

## What changes were proposed in this pull request?
```Scala
    val df1 = Seq(("abc", 1), (null, 3)).toDF("col1", "col2")
    df1.write.mode(SaveMode.Overwrite).parquet("/tmp/test1")
    val df2 = spark.read.parquet("/tmp/test1")
    df2.filter("col1 = 'abc' OR (col1 != 'abc' AND col2 == 3)").show()
```

Before this PR, the query returns both rows. After the fix, it returns only `Row("abc", 1)`. This fixes a NULL-handling bug in BooleanSimplification that was introduced in the Spark 1.6 release.

## How was this patch tested?
Added test cases

Closes #22718 from gatorsmile/cherrypickSPARK-25714.

Authored-by: gatorsmile <gatorsmile@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
gatorsmile authored and cloud-fan committed Oct 16, 2018
1 parent 1e15998 commit d87896b
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 33 deletions.
Expand Up @@ -120,6 +120,13 @@ case class Not(child: Expression)

// Declares a single expected input type, BooleanType, for this one-child expression.
override def inputTypes: Seq[DataType] = Seq(BooleanType)

// +---------+-----------+
// | CHILD | NOT CHILD |
// +---------+-----------+
// | TRUE | FALSE |
// | FALSE | TRUE |
// | UNKNOWN | UNKNOWN |
// +---------+-----------+
// Casts the already-evaluated child value to Boolean and returns its negation.
// NOTE(review): the UNKNOWN (null) row of the table above is presumably handled by the caller
// before this method is reached, as the "nullSafe" name suggests — confirm against the base class.
protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -374,6 +381,13 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with

// Operator text for this expression: the SQL keyword AND.
override def sqlOperator: String = "AND"

// +---------+---------+---------+---------+
// | AND | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | UNKNOWN | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == false) {
Expand Down Expand Up @@ -437,6 +451,13 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

// Operator text for this expression: the SQL keyword OR.
override def sqlOperator: String = "OR"

// +---------+---------+---------+---------+
// | OR | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | FALSE | TRUE | FALSE | UNKNOWN |
// | UNKNOWN | TRUE | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == true) {
Expand Down Expand Up @@ -560,6 +581,13 @@ case class EqualTo(left: Expression, right: Expression)

// Symbol used for this comparison: the plain equality sign.
override def symbol: String = "="

// +---------+---------+---------+---------+
// | = | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | UNKNOWN |
// | FALSE | FALSE | TRUE | UNKNOWN |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+
// Compares the two already-evaluated operand values with `ordering.equiv`.
// NOTE(review): `ordering` is defined outside this hunk; presumably it is the ordering for the
// operands' data type, and null operands never reach this method — confirm.
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -597,6 +625,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp

// EqualNullSafe is declared non-nullable: this override always returns false.
override def nullable: Boolean = false

// +---------+---------+---------+---------+
// | <=> | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
// | TRUE | TRUE | FALSE | FALSE |
// | FALSE | FALSE | TRUE | FALSE |
// | UNKNOWN | FALSE | FALSE | TRUE |
// +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
val input2 = right.eval(input)
Expand Down
Expand Up @@ -268,15 +268,37 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a

case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)

case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
// The following optimizations are applicable only when the operands are not nullable,
// since, under three-valued logic, AND and OR handle NULL differently.
// See the chart:
// +---------+---------+---------+---------+
// | operand | operand | OR | AND |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | TRUE | FALSE | TRUE | FALSE |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+

// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)

// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)

// Common factor elimination for conjunction
case and @ (left And right) =>
Expand Down
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand Down Expand Up @@ -71,6 +71,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
comparePlans(actual, correctAnswer)
}

/**
 * Filters `testNotNullableRelationWithData` by `input`, runs the analyzed plan through
 * `Optimize`, and compares the result with the same relation filtered by `expected`
 * (analyzed only, not optimized).
 *
 * @param input    predicate handed to the optimizer
 * @param expected predicate the optimized filter is expected to carry
 */
private def checkConditionInNotNullableRelation(
    input: Expression, expected: Expression): Unit = {
  val optimized = Optimize.execute(testNotNullableRelationWithData.where(input).analyze)
  val expectedPlan = testNotNullableRelationWithData.where(expected).analyze
  comparePlans(optimized, expectedPlan)
}

private def checkConditionInNotNullableRelation(
input: Expression, expected: LogicalPlan): Unit = {
val plan = testNotNullableRelationWithData.where(input).analyze
Expand Down Expand Up @@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
'a === 'b || 'b > 3 && 'a > 3 && 'a < 5)
}

test("e && (!e || f)") {
checkCondition('e && (!'e || 'f ), 'e && 'f)
test("e && (!e || f) - not nullable") {
checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f)

checkCondition('e && ('f || !'e ), 'e && 'f)
checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f)

checkCondition((!'e || 'f ) && 'e, 'f && 'e)
checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e)

checkCondition(('f || !'e ) && 'e, 'f && 'e)
checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e)
}

test("a < 1 && (!(a < 1) || f)") {
checkCondition('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)
// Each predicate below is passed to checkCondition as both the input and the expected
// output, i.e. the optimizer must leave it untouched.
// NOTE(review): per the test name the columns here are presumably nullable — confirm against
// the relation used by checkCondition.
test("e && (!e || f) - nullable") {
  val unchangedPredicates = Seq(
    // conjunction with a disjunction containing the negated operand
    'e && (!'e || 'f),
    'e && ('f || !'e),
    (!'e || 'f) && 'e,
    ('f || !'e) && 'e,
    // disjunction with a conjunction containing the negated operand
    'e || (!'e && 'f),
    'e || ('f && !'e),
    ('e && 'f) || !'e,
    ('f && 'e) || !'e)

  for (predicate <- unchangedPredicates) {
    checkCondition(predicate, predicate)
  }
}

checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)
test("a < 1 && (!(a < 1) || f) - not nullable") {
checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)

checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)

checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)

checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
}

test("a < 1 && ((a >= 1) || f)") {
checkCondition('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)
test("a < 1 && ((a >= 1) || f) - not nullable") {
checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)

checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)

checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)

checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
}

test("DeMorgan's law") {
Expand Down Expand Up @@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze)
checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze)
}

/**
 * Checks that optimizing `e1` yields the same plan as `e2` left unoptimized. Each expression
 * is aliased as "out" and projected over a `OneRowRelation` before the plans are compared.
 *
 * @param e1 expression that is analyzed and then run through `Optimize`
 * @param e2 expression that is only analyzed and serves as the expected result
 */
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
  // Wraps an expression in a single-column projection over a one-row relation.
  def projectionOf(expr: Expression) =
    Project(Alias(expr, "out")() :: Nil, OneRowRelation())

  // The expected plan is built first, keeping the original construction order.
  val expectedPlan = projectionOf(e2).analyze
  val optimizedPlan = Optimize.execute(projectionOf(e1).analyze)
  comparePlans(optimizedPlan, expectedPlan)
}

// Exercises the eight rewrites on two Boolean columns declared notNull: first by comparing
// plans via assertEquivalent, then by evaluating both forms over every true/false input row.
test("filter reduction - positive cases") {
  // References bound to ordinals 0 and 1 of the input row.
  val col1NotNULL = 'col1NotNULL.boolean.notNull.at(0)
  val col2NotNULL = 'col2NotNULL.boolean.notNull.at(1)

  // original predicate -> predicate expected after optimization
  val rewrites = Seq(
    (col1NotNULL && (!col1NotNULL || col2NotNULL)) -> (col1NotNULL && col2NotNULL),
    (col1NotNULL && (col2NotNULL || !col1NotNULL)) -> (col1NotNULL && col2NotNULL),
    ((!col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
    ((col2NotNULL || !col1NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),

    (col1NotNULL || (!col1NotNULL && col2NotNULL)) -> (col1NotNULL || col2NotNULL),
    (col1NotNULL || (col2NotNULL && !col1NotNULL)) -> (col1NotNULL || col2NotNULL),
    ((!col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL),
    ((col2NotNULL && !col1NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL)
  )

  // Plan-level check: optimizing the original must produce the expected expression.
  rewrites.foreach { case (original, rewritten) =>
    assertEquivalent(original, rewritten)
  }

  // Evaluation check: on each of the four input rows, the original predicate must evaluate
  // to the same value as its rewritten form.
  val truthValues = Seq(true, false)
  truthValues.foreach { col1Val =>
    truthValues.foreach { col2Val =>
      rewrites.foreach { case (original, rewritten) =>
        val row = create_row(col1Val, col2Val)
        val rewrittenResult = evaluate(rewritten, row)
        checkEvaluation(original, rewrittenResult, row)
      }
    }
  }
}
}

0 comments on commit d87896b

Please sign in to comment.