Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ case class Coalesce(children: Seq[Expression])
copy(children = newChildren)
}

/**
 * Stand-in for a NULL value whose type is borrowed from `child`.
 *
 * The concrete `Literal` is only built the first time `replacement` is accessed (it is a
 * `lazy val`), so constructing this node does not read `child.dataType`.
 *
 * NOTE(review): presumably this is what allows `NullIf` to be constructed over children that
 * are not resolved yet -- confirm against the analyzer rules.
 */
private case class TypedNullLiteral(child: Expression)
  extends UnaryExpression with RuntimeReplaceable {

  // The reported type always mirrors the wrapped expression's type.
  override def dataType: DataType = child.dataType

  // This node stands for NULL, so it is always nullable.
  override def nullable: Boolean = true

  // Deferred on purpose: the child's data type is read only on first access.
  override lazy val replacement: Expression = {
    val nullType = child.dataType
    Literal.create(null, nullType)
  }

  override def sql: String = "NULL"

  override def toString: String = "null"

  override protected def withNewChildInternal(newChild: Expression): TypedNullLiteral = {
    copy(child = newChild)
  }
}

@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.",
Expand All @@ -154,7 +169,7 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {

// Two-argument constructor: derives the third argument (`replacement`) as an `If` over
// `EqualTo(left, right)` that yields a null value when both sides are equal and `left` otherwise.
// NOTE(review): this text comes from a diff view, so the two `this(...)` lines below appear to
// be the pre-change and post-change versions of the same statement -- confirm against the real
// file, where only one of them should exist.
def this(left: Expression, right: Expression) = {
// Presumably the pre-change line: reads `left.dataType` while the expression is being constructed.
this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left))
// Presumably the post-change line: wraps `left` in TypedNullLiteral instead of reading its type here.
this(left, right, If(EqualTo(left, right), TypedNullLiteral(left), left))
}

// The two caller-supplied arguments, in call order; the derived `replacement` is not included.
override def parameters: Seq[Expression] = left :: right :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleAnalyzer, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
Expand Down Expand Up @@ -140,6 +141,22 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType)
}

test("NullIf replacement preserves its data type before type coercion") {
  // Built directly through the two-argument constructor, with no analyzer pass in between.
  val one = Literal(1)
  val expr = new NullIf(one, one)

  // Both the expression itself and its derived replacement must report the integer type.
  assert(expr.dataType == IntegerType)
  assert(expr.replacement.dataType == IntegerType)
}

test("NullIf accepts unresolved nested fields during function construction") {
  // First argument is an unresolved nested field reference (`c.provider`).
  val unresolvedField = UnresolvedAttribute(Seq("c", "provider"))
  val loweredLiteral = Lower(Literal("ERROR_MULTIPLE_PROVIDERS"))

  // Going through the builtin registry exercises the same construction path as a SQL call.
  val built = FunctionRegistry.builtin.lookupFunction(
    FunctionIdentifier("nullif"),
    Seq(unresolvedField, loweredLiteral))

  assert(built.isInstanceOf[NullIf])
}

test("AtLeastNNonNulls") {
val mix = Seq(Literal("x"),
Literal.create(null, StringType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, IntegerLiteral, Literal, NullIf, RuntimeReplaceable}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

/**
* A dummy optimizer rule for testing that decrements integer literals until 0.
Expand Down Expand Up @@ -71,4 +72,21 @@ class OptimizerSuite extends PlanTest {
s"test, please set '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}' to a larger value."))
}
}
test("NullIf typed null branch is replaced with a null literal") {
  // Optimizer restricted to a single batch that only runs ReplaceExpressions.
  val replaceOnly = new SimpleTestOptimizer() {
    override def defaultBatches: Seq[Batch] =
      Seq(Batch("test", fixedPoint, ReplaceExpressions))
  }

  val expr = new NullIf(Literal(true), Literal(true))
  val analyzed = Project(Seq(Alias(expr, "out")()), OneRowRelation()).analyze
  val result = replaceOnly.execute(analyzed)

  // A boolean-typed null literal must appear somewhere in the optimized expressions.
  assert(result.expressions.exists(_.exists {
    case Literal(null, BooleanType) => true
    case _ => false
  }))
  // No RuntimeReplaceable node may survive the replacement batch.
  assert(!result.expressions.exists(_.exists(_.isInstanceOf[RuntimeReplaceable])))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {

checkAnswer(df.selectExpr("nullif(a, a)"), Seq(Row(null)))
checkAnswer(df.select(nullif(lit(5), lit(5))), Seq(Row(null)))

val nestedDf = Seq("error_multiple_providers", "openai")
.toDF("provider")
.select(struct(col("provider")).as("c"))
checkAnswer(
nestedDf.select(nullif(col("c.provider"), lower(lit("ERROR_MULTIPLE_PROVIDERS")))),
Seq(Row(null), Row("openai")))
}

test("nvl") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
checkKeywordsExistsInExplain(df,
"Project [id#xL AS ifnull(id, 1)#xL, if ((id#xL = 1)) null " +
"else id#xL AS nullif(id, 1)#xL, id#xL AS nvl(id, 1)#xL, 1 AS nvl2(id, 1, 2)#x]")
checkKeywordsNotExistsInExplain(df, ExtendedMode, "typednullliteral")
}

test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") {
Expand Down