From 191f7859437689e01520c3ad4c500f53dfdc14a5 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Sat, 16 May 2026 17:50:56 -0700 Subject: [PATCH 1/2] [SPARK-56840][SQL][3.5] Avoid unresolved NullIf type lookup ### Why are the changes needed? `NULLIF` builds its replacement expression before analysis has resolved all child expressions. For nested field references, the existing implementation can read the left operand data type too early while constructing the null branch, which can fail analysis even though the SQL shape is valid. SPARK-56840 tracks this analyzer failure. ### What changes were proposed in this PR? - Build the `NULLIF` null branch with a lazy typed-null placeholder so construction does not eagerly read the unresolved left operand type, while `NullIf.replacement.dataType` remains valid once the operand type is available. - Make that placeholder `RuntimeReplaceable`, so `ReplaceExpressions` restores an ordinary typed `Literal(null, ...)` before later optimizer rules run and existing null-literal simplifications continue to apply. - Backport the focused regressions, including the real nested-field `nullif(c.provider, lower(...))` repro test, into the same branch-3.5 commit. ### Does this PR introduce _any_ user-facing change? Yes. Valid `NULLIF` expressions over unresolved nested field references that could fail during analysis now resolve and execute successfully. ### How was this patch tested? - Not run yet. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Codex (GPT-5.5) --- .../expressions/nullExpressions.scala | 26 ++++++++++- .../expressions/NullExpressionsSuite.scala | 25 ++++++++++- .../catalyst/optimizer/OptimizerSuite.scala | 22 ++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 45 +++++++++++++++++++ .../org/apache/spark/sql/ExplainSuite.scala | 1 + 5 files changed, 116 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 948cb6fbedd32..4e57c7ecea1fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -140,6 +141,21 @@ case class Coalesce(children: Seq[Expression]) copy(children = newChildren) } +private case class TypedNullLiteral(child: Expression) + extends UnaryExpression with RuntimeReplaceable { + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType + + override def toString: String = "null" + + override def sql: String = "NULL" + + override lazy val replacement: Expression = Literal.create(null, child.dataType) + + override protected def withNewChildInternal(newChild: Expression): TypedNullLiteral = + copy(child = newChild) +} @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.", @@ -154,7 +170,15 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = { - this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left)) + this(left, right, + if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) { + With(left) { case Seq(ref) => + If(EqualTo(ref, right), TypedNullLiteral(ref), ref) + } + } else { + If(EqualTo(left, right), TypedNullLiteral(left), left) + } + ) } override def parameters: Seq[Expression] = Seq(left, right) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index da8e11c0433eb..6732d8e74c0ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleAnalyzer, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} @@ -140,6 +141,28 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) } + test("NullIf replacement preserves its data type before type coercion") { + Seq(true, false).foreach { alwaysInlineCommonExpr => + withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> alwaysInlineCommonExpr.toString) { + val nullIf = new NullIf(Literal(1), Literal(1)) + assert(nullIf.dataType == IntegerType) + assert(nullIf.replacement.dataType == IntegerType) + } + } + } + + test("NullIf accepts unresolved nested fields during inlined function construction") { + withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> "true") { + val nullIf = FunctionRegistry.builtin.lookupFunction( + FunctionIdentifier("nullif"), + Seq( + UnresolvedAttribute(Seq("c", "provider")), + Lower(Literal("ERROR_MULTIPLE_PROVIDERS")))) + + assert(nullIf.isInstanceOf[NullIf]) + } + } + test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala index 6b63f860b7da9..42a1edb34ee89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, IntegerLiteral, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, IntegerLiteral, Literal, NullIf, RuntimeReplaceable} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType /** * A dummy optimizer rule for testing that decrements integer literals until 0. @@ -71,4 +72,23 @@ class OptimizerSuite extends PlanTest { s"test, please set '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}' to a larger value.")) } } + test("NullIf typed null branch is replaced with a null literal") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("test", fixedPoint, + ReplaceExpressions) :: Nil + } + + withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> "true") { + val nullIf = new NullIf(Literal(true), Literal(true)) + val plan = Project(Alias(nullIf, "out")() :: Nil, OneRowRelation()).analyze + val optimized = optimizer.execute(plan) + + assert(optimized.expressions.exists(_.exists { + case Literal(null, BooleanType) => true + case _ => false + })) + assert(optimized.expressions.forall(!_.exists(_.isInstanceOf[RuntimeReplaceable]))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 71ad4a25578eb..57595e4c4ae2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -309,6 +309,51 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(isnotnull(col("a"))), Seq(Row(false))) } + test("nullif function") { + Seq(true, false).foreach { alwaysInlineCommonExpr => + withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> alwaysInlineCommonExpr.toString) { + Seq( + "SELECT NULLIF(1, 1)" -> Seq(Row(null)), + "SELECT NULLIF(1, 2)" -> Seq(Row(1)), + "SELECT NULLIF(NULL, 1)" -> Seq(Row(null)), + "SELECT NULLIF(1, NULL)" -> Seq(Row(1)), + "SELECT NULLIF(NULL, NULL)" -> Seq(Row(null)), + "SELECT NULLIF('abc', 'abc')" -> Seq(Row(null)), + "SELECT NULLIF('abc', 'xyz')" -> Seq(Row("abc")), + "SELECT NULLIF(id, 1) " + + "FROM range(10) " + + "GROUP BY NULLIF(id, 1)" -> Seq(Row(null), Row(2), Row(3), Row(4), Row(5), Row(6), + Row(7), Row(8), Row(9), Row(0)), + "SELECT NULLIF(id, 1), COUNT(*)" + + "FROM range(10) " + + "GROUP BY NULLIF(id, 1) " + + "HAVING COUNT(*) > 1" -> Seq.empty[Row] + ).foreach { + case (sqlText, expected) => checkAnswer(sql(sqlText), expected) + } + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT NULLIF(id, 1), COUNT(*) " + + "FROM range(10) " + + "GROUP BY NULLIF(id, 2)") + }, + condition = "MISSING_AGGREGATION", + parameters = Map( + "expression" -> "\"id\"", + "expressionAnyValue" -> "\"any_value(id)\"") + ) + + val nestedDf = Seq("error_multiple_providers", "openai") + .toDF("provider") + .select(struct(col("provider")).as("c")) + checkAnswer( + nestedDf.select(nullif(col("c.provider"), lower(lit("ERROR_MULTIPLE_PROVIDERS")))), + Seq(Row(null), Row("openai"))) + } + } + } + test("equal_null function") { val df = Seq[(Integer, Integer)]((null, 8)).toDF("a", "b") checkAnswer(df.selectExpr("equal_null(a, b)"), Seq(Row(false))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index a206e97c35362..739557bef3016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -248,6 +248,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite checkKeywordsExistsInExplain(df, "Project [id#xL AS ifnull(id, 1)#xL, if ((id#xL = 1)) null " + "else id#xL AS nullif(id, 1)#xL, id#xL AS nvl(id, 1)#xL, 1 AS nvl2(id, 1, 2)#x]") + checkKeywordsNotExistsInExplain(df, ExtendedMode, "typednullliteral") } test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") { From a2b780095d42b1be05fcd502d18ab33d419ee029 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Sun, 17 May 2026 09:21:23 -0700 Subject: [PATCH 2/2] [SPARK-56840][SQL][3.5] Fix branch-3.5 NullIf backport --- .../expressions/nullExpressions.scala | 11 +--- .../expressions/NullExpressionsSuite.scala | 26 ++++------ .../catalyst/optimizer/OptimizerSuite.scala | 18 +++---- .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++---------------- 4 files changed, 26 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 4e57c7ecea1fe..edf8ee00e708d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -170,15 +169,7 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = { - this(left, right, - if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) { - With(left) { case Seq(ref) => - If(EqualTo(ref, right), TypedNullLiteral(ref), ref) - } - } else { - If(EqualTo(left, right), TypedNullLiteral(left), left) - } - ) + this(left, right, If(EqualTo(left, right), TypedNullLiteral(left), left)) } override def parameters: Seq[Expression] = Seq(left, right) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 6732d8e74c0ce..f97af0835f3fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -142,25 +142,19 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("NullIf replacement preserves its data type before type coercion") { - Seq(true, false).foreach { alwaysInlineCommonExpr => - withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> alwaysInlineCommonExpr.toString) { - val nullIf = new NullIf(Literal(1), Literal(1)) - assert(nullIf.dataType == IntegerType) - assert(nullIf.replacement.dataType == IntegerType) - } - } + val nullIf = new NullIf(Literal(1), Literal(1)) + assert(nullIf.dataType == IntegerType) + assert(nullIf.replacement.dataType == IntegerType) } - test("NullIf accepts unresolved nested fields during inlined function construction") { - withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> "true") { - val nullIf = FunctionRegistry.builtin.lookupFunction( - FunctionIdentifier("nullif"), - Seq( - UnresolvedAttribute(Seq("c", "provider")), - Lower(Literal("ERROR_MULTIPLE_PROVIDERS")))) + test("NullIf accepts unresolved nested fields during function construction") { + val nullIf = FunctionRegistry.builtin.lookupFunction( + FunctionIdentifier("nullif"), + Seq( + UnresolvedAttribute(Seq("c", "provider")), + Lower(Literal("ERROR_MULTIPLE_PROVIDERS")))) - assert(nullIf.isInstanceOf[NullIf]) - } + assert(nullIf.isInstanceOf[NullIf]) } test("AtLeastNNonNulls") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala index 42a1edb34ee89..fb9a0f6f6e6c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala @@ -79,16 +79,14 @@ class OptimizerSuite extends PlanTest { ReplaceExpressions) :: Nil } - withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> "true") { - val nullIf = new NullIf(Literal(true), Literal(true)) - val plan = Project(Alias(nullIf, "out")() :: Nil, OneRowRelation()).analyze - val optimized = optimizer.execute(plan) + val nullIf = new NullIf(Literal(true), Literal(true)) + val plan = Project(Alias(nullIf, "out")() :: Nil, OneRowRelation()).analyze + val optimized = optimizer.execute(plan) - assert(optimized.expressions.exists(_.exists { - case Literal(null, BooleanType) => true - case _ => false - })) - assert(optimized.expressions.forall(!_.exists(_.isInstanceOf[RuntimeReplaceable]))) - } + assert(optimized.expressions.exists(_.exists { + case Literal(null, BooleanType) => true + case _ => false + })) + assert(optimized.expressions.forall(!_.exists(_.isInstanceOf[RuntimeReplaceable]))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 57595e4c4ae2e..251b5429102cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -309,51 +309,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(isnotnull(col("a"))), Seq(Row(false))) } - test("nullif function") { - Seq(true, false).foreach { alwaysInlineCommonExpr => - withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> alwaysInlineCommonExpr.toString) { - Seq( - "SELECT NULLIF(1, 1)" -> Seq(Row(null)), - "SELECT NULLIF(1, 2)" -> Seq(Row(1)), - "SELECT NULLIF(NULL, 1)" -> Seq(Row(null)), - "SELECT NULLIF(1, NULL)" -> Seq(Row(1)), - "SELECT NULLIF(NULL, NULL)" -> Seq(Row(null)), - "SELECT NULLIF('abc', 'abc')" -> Seq(Row(null)), - "SELECT NULLIF('abc', 'xyz')" -> Seq(Row("abc")), - "SELECT NULLIF(id, 1) " + - "FROM range(10) " + - "GROUP BY NULLIF(id, 1)" -> Seq(Row(null), Row(2), Row(3), Row(4), Row(5), Row(6), - Row(7), Row(8), Row(9), Row(0)), - "SELECT NULLIF(id, 1), COUNT(*)" + - "FROM range(10) " + - "GROUP BY NULLIF(id, 1) " + - "HAVING COUNT(*) > 1" -> Seq.empty[Row] - ).foreach { - case (sqlText, expected) => checkAnswer(sql(sqlText), expected) - } - - checkError( - exception = intercept[AnalysisException] { - sql("SELECT NULLIF(id, 1), COUNT(*) " + - "FROM range(10) " + - "GROUP BY NULLIF(id, 2)") - }, - condition = "MISSING_AGGREGATION", - parameters = Map( - "expression" -> "\"id\"", - "expressionAnyValue" -> "\"any_value(id)\"") - ) - - val nestedDf = Seq("error_multiple_providers", "openai") - .toDF("provider") - .select(struct(col("provider")).as("c")) - checkAnswer( - nestedDf.select(nullif(col("c.provider"), lower(lit("ERROR_MULTIPLE_PROVIDERS")))), - Seq(Row(null), Row("openai"))) - } - } - } - test("equal_null function") { val df = Seq[(Integer, Integer)]((null, 8)).toDF("a", "b") checkAnswer(df.selectExpr("equal_null(a, b)"), Seq(Row(false))) @@ -370,6 +325,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.selectExpr("nullif(a, a)"), Seq(Row(null))) checkAnswer(df.select(nullif(lit(5), lit(5))), Seq(Row(null))) + + val nestedDf = Seq("error_multiple_providers", "openai") + .toDF("provider") + .select(struct(col("provider")).as("c")) + checkAnswer( + nestedDf.select(nullif(col("c.provider"), lower(lit("ERROR_MULTIPLE_PROVIDERS")))), + Seq(Row(null), Row("openai"))) } test("nvl") {