diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a816922f49aee..51d2a73ea97b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -118,19 +118,23 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
         // Replace null with default value for joining key, then those rows with null in it could
         // be joined together
         case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
-          Some((Coalesce(Seq(l, Literal.default(l.dataType))),
-            Coalesce(Seq(r, Literal.default(r.dataType)))))
+          Seq((Coalesce(Seq(l, Literal.default(l.dataType))),
+            Coalesce(Seq(r, Literal.default(r.dataType)))),
+            (IsNull(l), IsNull(r))
+          )
         case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
-          Some((Coalesce(Seq(r, Literal.default(r.dataType))),
-            Coalesce(Seq(l, Literal.default(l.dataType)))))
+          Seq((Coalesce(Seq(r, Literal.default(r.dataType))),
+            Coalesce(Seq(l, Literal.default(l.dataType)))),
+            (IsNull(r), IsNull(l))
+          )
         case other => None
       }
       val otherPredicates = predicates.filterNot {
         case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
-        case EqualTo(l, r) =>
+        case Equality(l, r) =>
           canEvaluate(l, left) && canEvaluate(r, right) ||
             canEvaluate(l, right) && canEvaluate(r, left)
-        case other => false
+        case _ => false
       }
 
       if (joinKeys.nonEmpty) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
index 5f616da2978bb..f5af416602c9d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
+import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -78,5 +78,18 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
 
     comparePlans(doubleOptimized, correctAnswer)
   }
+
+  test("normalize floating points in join keys (equal null safe) - idempotence") {
+    val query = testRelation1.join(testRelation2, condition = Some(a <=> b))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val joinCond = IsNull(a) === IsNull(b) &&
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) ===
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0)))
+    val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
 
 }
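
Aside (not part of the patch): the extra (IsNull(l), IsNull(r)) key pair exists because mapping null to Literal.default alone would let a null join key collide with a row whose value genuinely equals the default (e.g. 0.0 for DOUBLE); requiring the IsNull keys to match as well restores exact <=> semantics, which is what the new test's expected condition IsNull(a) === IsNull(b) && ... encodes. Similarly, Catalyst's Equality extractor matches both EqualTo and EqualNullSafe, so the decomposed null-safe predicates are also dropped from otherPredicates. Below is a minimal plain-Scala sketch of the idea; NullSafeJoinKeySketch, joinKey, and nullSafeEq are illustrative names, not Spark APIs.

// Standalone sketch: a null-safe equality l <=> r is decomposed into the
// two-part hashable key (coalesce(v, default), isnull(v)) per side.
// Matching on BOTH parts is equivalent to <=>: two nulls agree on both,
// and a null can no longer collide with a genuine default value.
object NullSafeJoinKeySketch {
  type Val = Option[Double] // models a nullable DOUBLE column value

  // The key pair the patched extraction corresponds to for one side.
  def joinKey(v: Val, default: Double = 0.0): (Double, Boolean) =
    (v.getOrElse(default), v.isEmpty)

  // Reference semantics of SQL null-safe equality (<=>).
  def nullSafeEq(l: Val, r: Val): Boolean = (l, r) match {
    case (None, None)       => true
    case (Some(a), Some(b)) => a == b
    case _                  => false
  }

  def main(args: Array[String]): Unit = {
    val samples = Seq[(Val, Val)](
      (None, None),           // null <=> null -> true
      (Some(0.0), None),      // 0.0 <=> null  -> false (coalesce alone would say true)
      (None, Some(0.0)),      // null <=> 0.0  -> false
      (Some(1.5), Some(1.5))) // 1.5 <=> 1.5   -> true
    for ((l, r) <- samples) {
      // Equality of the composite keys must agree with <=> semantics.
      val matched = joinKey(l) == joinKey(r)
      assert(matched == nullSafeEq(l, r))
      println(s"$l <=> $r : $matched")
    }
  }
}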