Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ case class Coalesce(children: Seq[Expression])
copy(children = newChildren)
}

/**
 * A null literal that borrows the data type of `child` rather than resolving to an
 * untyped null. This lets an expression such as `NullIf` build a null branch whose
 * type already matches the other operand before type coercion runs. Being
 * [[RuntimeReplaceable]], it is rewritten by the optimizer into a plain
 * `Literal(null, child.dataType)`, so it never survives into physical execution.
 */
private case class TypedNullLiteral(child: Expression)
  extends UnaryExpression with RuntimeReplaceable {

  // The optimizer substitutes this typed null literal for the whole expression.
  override lazy val replacement: Expression = Literal.create(null, child.dataType)

  // A null literal is, by definition, nullable.
  override def nullable: Boolean = true

  // Take on the type of the wrapped expression.
  override def dataType: DataType = child.dataType

  // Render exactly as a plain null would in plan strings and generated SQL.
  override def toString: String = "null"

  override def sql: String = "NULL"

  override protected def withNewChildInternal(newChild: Expression): TypedNullLiteral =
    copy(child = newChild)
}

@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.",
Expand All @@ -162,10 +177,10 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
this(left, right,
if (!SQLConf.get.getConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR)) {
With(left) { case Seq(ref) =>
If(EqualTo(ref, right), Literal.create(null, left.dataType), ref)
If(EqualTo(ref, right), TypedNullLiteral(ref), ref)
}
} else {
If(EqualTo(left, right), Literal.create(null, left.dataType), left)
If(EqualTo(left, right), TypedNullLiteral(left), left)
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType)
}

test("NullIf replacement preserves its data type before type coercion") {
  // The replacement must keep the operand's type under both common-expression
  // inlining modes, since NullIf builds a different replacement tree for each.
  for (inlineCommonExpr <- Seq(true, false)) {
    withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> inlineCommonExpr.toString) {
      val expr = new NullIf(Literal(1), Literal(1))
      assert(expr.dataType == IntegerType)
      assert(expr.replacement.dataType == IntegerType)
    }
  }
}

test("AtLeastNNonNulls") {
val mix = Seq(Literal("x"),
Literal.create(null, StringType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, NullIf, Remainder, RuntimeReplaceable}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, MapType, StructField, StructType}

/**
* A dummy optimizer rule for testing that decrements integer literals until 0.
Expand Down Expand Up @@ -334,4 +334,24 @@ class OptimizerSuite extends PlanTest {
assert(optimized2.schema ===
StructType(StructField("map", MapType(IntegerType, IntegerType, false), false) :: Nil))
}

test("NullIf typed null branch is replaced with a null literal") {
  // Run only ReplaceExpressions so we observe exactly what that rule produces.
  val testOptimizer = new SimpleTestOptimizer() {
    override def defaultBatches: Seq[Batch] =
      Batch("test", fixedPoint, ReplaceExpressions) :: Nil
  }

  withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> "true") {
    val expr = new NullIf(Literal(true), Literal(true))
    val analyzed = Project(Alias(expr, "out")() :: Nil, OneRowRelation()).analyze
    val optimizedPlan = testOptimizer.execute(analyzed)

    // The typed null placeholder must have been lowered to a concrete,
    // correctly-typed null literal...
    val hasTypedNull = optimizedPlan.expressions.exists(_.exists {
      case Literal(null, BooleanType) => true
      case _ => false
    })
    assert(hasTypedNull)
    // ...and no RuntimeReplaceable node may survive optimization.
    assert(!optimizedPlan.expressions.exists(_.exists(_.isInstanceOf[RuntimeReplaceable])))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,13 @@ class DataFrameFunctionsSuite extends SharedSparkSession {
"expression" -> "\"id\"",
"expressionAnyValue" -> "\"any_value(id)\"")
)

val nestedDf = Seq("error_multiple_providers", "openai")
.toDF("provider")
.select(struct(col("provider")).as("c"))
checkAnswer(
nestedDf.select(nullif(col("c.provider"), lower(lit("ERROR_MULTIPLE_PROVIDERS")))),
Seq(Row(null), Row("openai")))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
checkKeywordsExistsInExplain(df,
"Project [id#xL AS ifnull(id, 1)#xL, if ((id#xL = 1)) null " +
"else id#xL AS nullif(id, 1)#xL, id#xL AS nvl(id, 1)#xL, 1 AS nvl2(id, 1, 2)#x]")
checkKeywordsNotExistsInExplain(df, ExtendedMode, "typednullliteral")
}

test("SPARK-26659: explain of DataWritingCommandExec should not contain duplicate cmd.nodeName") {
Expand Down