diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index ef3de4738c75c..698ece4f9e69f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -53,6 +53,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case p: LogicalPlan => p transformExpressions { case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) case cw @ CaseWhen(branches, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index eb52c5b74772c..6fc31c94e47eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -48,6 +48,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test("replace null inside filter and join conditions") { testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) } test("Not expected type - replaceNullWithFalse") { @@ -64,6 +65,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(null, BooleanType)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested expressions in branches of If") { @@ -73,6 +75,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { UnresolvedAttribute("b") && Literal(null, BooleanType)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in elseValue of CaseWhen") { @@ -83,6 +86,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedCond = CaseWhen(branches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in branch values of CaseWhen") { @@ -92,6 +96,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val originalCond = CaseWhen(branches, Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside CaseWhen") { @@ -108,6 +113,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in complex CaseWhen expressions") { @@ -127,6 +133,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in Or") { @@ -134,12 +141,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedCond = UnresolvedAttribute("b") testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in And") { val originalCond = And(UnresolvedAttribute("b"), Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested And/Or expressions") { @@ -148,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in And inside branches of If") { @@ -157,6 +167,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { And(UnresolvedAttribute("b"), Literal(null, BooleanType))) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside And") { @@ -168,6 +179,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { And(FalseLiteral, UnresolvedAttribute("b")))) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside another If") { @@ -177,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in CaseWhen inside another CaseWhen") { @@ -184,6 +197,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("inability to replace null in non-boolean branches of If") { @@ -196,6 +210,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean values of CaseWhen") { @@ -210,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val condition = CaseWhen(branches) testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -222,6 +238,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) } test("replace null in If used as a join condition") { @@ -353,6 +370,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) } + private def testDelete(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, expr) => DeleteFromTable(rel, Some(expr)), originalCond, expectedCond) + } + private def testHigherOrderFunc( argument: Expression, createExpr: (Expression, Expression) => Expression,