Skip to content

Commit

Permalink
[SPARK-28306][SQL][FOLLOWUP] Fix NormalizeFloatingNumbers rule idempo…
Browse files Browse the repository at this point in the history
…tence for equi-join with `<=>` predicates

## What changes were proposed in this pull request?
Idempotence of the `NormalizeFloatingNumbers` rule was broken due to the implementation of `ExtractEquiJoinKeys`. There is no reason that we don't remove `EqualNullSafe` join keys from an equi-join's `otherPredicates`.

## How was this patch tested?
A new UT.

Closes #25126 from yeshengm/spark-28306.

Authored-by: Yesheng Ma <kimi.ysma@gmail.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
  • Loading branch information
yeshengm authored and gatorsmile committed Jul 15, 2019
1 parent 8d1e87a commit 2f3997f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
Expand Up @@ -118,19 +118,23 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
// Replace null with default value for joining key, then those rows with null in it could
// be joined together
case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
Some((Coalesce(Seq(l, Literal.default(l.dataType))),
Coalesce(Seq(r, Literal.default(r.dataType)))))
Seq((Coalesce(Seq(l, Literal.default(l.dataType))),
Coalesce(Seq(r, Literal.default(r.dataType)))),
(IsNull(l), IsNull(r))
)
case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
Some((Coalesce(Seq(r, Literal.default(r.dataType))),
Coalesce(Seq(l, Literal.default(l.dataType)))))
Seq((Coalesce(Seq(r, Literal.default(r.dataType))),
Coalesce(Seq(l, Literal.default(l.dataType)))),
(IsNull(r), IsNull(l))
)
case other => None
}
val otherPredicates = predicates.filterNot {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
case EqualTo(l, r) =>
case Equality(l, r) =>
canEvaluate(l, left) && canEvaluate(r, right) ||
canEvaluate(l, right) && canEvaluate(r, left)
case other => false
case _ => false
}

if (joinKeys.nonEmpty) {
Expand Down
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -78,5 +78,18 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {

comparePlans(doubleOptimized, correctAnswer)
}

test("normalize floating points in join keys (equal null safe) - idempotence") {
val query = testRelation1.join(testRelation2, condition = Some(a <=> b))

val optimized = Optimize.execute(query)
val doubleOptimized = Optimize.execute(optimized)
val joinCond = IsNull(a) === IsNull(b) &&
KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) ===
KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0)))
val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))

comparePlans(doubleOptimized, correctAnswer)
}
}

0 comments on commit 2f3997f

Please sign in to comment.