Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ object FunctionRegistry {
expression[IsNaN]("isnan"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl]("ifnull"),
expression[Nvl2]("nvl2"),
expression[Least]("least"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,131 @@ case class IsNaN(child: Expression) extends UnaryExpression
}
}

/**
 * An expression that takes two arguments and returns null when they are equal; otherwise the
 * value of the first argument is returned unchanged.
 *
 * A null first argument always produces null. A null second argument never matches a non-null
 * first argument, so the first argument is returned in that case. Both children are always
 * evaluated. The result has the data type of the first argument.
 */
@ExpressionDescription(
  usage = "_FUNC_(a,b) - Returns null if a equals to b, or a otherwise.")
case class NullIf(left: Expression, right: Expression) extends BinaryExpression {
  // Matching inputs yield null, so the result can be null even when neither input is.
  override def nullable: Boolean = true
  override def dataType: DataType = left.dataType

  override def eval(input: InternalRow): Any = {
    val valueLeft = left.eval(input)
    val valueRight = right.eval(input)
    if (valueLeft == null || valueRight == null) {
      // A null left is already the answer; a null right cannot equal a non-null left.
      valueLeft
    } else {
      val same = (valueLeft, valueRight) match {
        // Byte arrays inherit reference equality from Object, so compare them by content.
        case (l: Array[Byte], r: Array[Byte]) => java.util.Arrays.equals(l, r)
        case _ => valueLeft.equals(valueRight)
      }
      if (same) null else valueLeft
    }
  }

  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    val leftGen = left.gen(ctx)
    val rightGen = right.gen(ctx)
    // The generated condition is explicitly parenthesised: the result is null when the left
    // side is null, or when both sides are non-null and compare equal.
    s"""
      ${leftGen.code}
      ${rightGen.code}
      boolean ${ev.isNull} = false;
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
      if (${leftGen.isNull} ||
          (!${rightGen.isNull} &&
            (${ctx.genEqual(dataType, leftGen.value, rightGen.value)}))) {
        ${ev.isNull} = true;
      } else {
        ${ev.value} = ${leftGen.value};
      }
    """
  }
}

/**
 * An expression that takes two arguments and returns the second one when the first is null;
 * a non-null first argument is returned unchanged.
 *
 * The result matches a two-argument Coalesce. The difference is in evaluation: Nvl always
 * evaluates both arguments, while Coalesce may not.
 *
 * The result has the data type of the first argument.
 */
@ExpressionDescription(
  usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
case class Nvl(left: Expression, right: Expression) extends BinaryExpression {
  // The result is null only when both inputs are null.
  override def nullable: Boolean = left.nullable && right.nullable
  override def dataType: DataType = left.dataType

  override def eval(input: InternalRow): Any = {
    // Both sides are evaluated up front, mirroring the generated code below.
    val valueLeft = left.eval(input)
    val valueRight = right.eval(input)
    if (valueLeft == null) valueRight else valueLeft
  }

  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    val leftGen = left.gen(ctx)
    val rightGen = right.gen(ctx)

    // Start from the left result and switch to the right one only when the left is null and
    // the right is not; if both are null the result stays null.
    s"""
      ${leftGen.code}
      ${rightGen.code}
      boolean ${ev.isNull} = ${leftGen.isNull};
      ${ctx.javaType(dataType)} ${ev.value} = ${leftGen.value};
      if (${ev.isNull} && !${rightGen.isNull}) {
        ${ev.isNull} = false;
        ${ev.value} = ${rightGen.value};
      }
    """
  }
}

/**
 * An expression that takes three arguments and returns the second one when the first is not
 * null, or the third one when the first is null.
 *
 * The first argument is only tested for null and is never returned, so the result type comes
 * from the second argument. All three children are always evaluated.
 *
 * NOTE(review): no cast is applied here, so the second and third arguments are assumed to share
 * a data type -- confirm that this is enforced before the expression is evaluated.
 */
@ExpressionDescription(
  usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.")
case class Nvl2(first: Expression, second: Expression, third: Expression)
  extends TernaryExpression {
  // Only `second` or `third` can be returned, so only their nullability matters.
  override def nullable: Boolean = second.nullable || third.nullable
  // The result is `second` or `third`, never `first`, so the type must not come from `first`.
  override def dataType: DataType = second.dataType
  override def children: Seq[Expression] = first :: second :: third :: Nil

  override def eval(input: InternalRow): Any = {
    val valueFirst = first.eval(input)
    val valueSecond = second.eval(input)
    val valueThird = third.eval(input)
    if (valueFirst == null) valueThird else valueSecond
  }

  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    val firstGen = first.gen(ctx)
    val secondGen = second.gen(ctx)
    val thirdGen = third.gen(ctx)
    // The result starts out as null and is overwritten only when the selected branch
    // (second for a non-null first, third for a null first) is itself non-null.
    s"""
      ${firstGen.code}
      ${secondGen.code}
      ${thirdGen.code}
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
      if (!${firstGen.isNull} && !${secondGen.isNull}) {
        ${ev.isNull} = false;
        ${ev.value} = ${secondGen.value};
      } else if (${firstGen.isNull} && !${thirdGen.isNull}) {
        ${ev.isNull} = false;
        ${ev.value} = ${thirdGen.value};
      }
    """
  }
}

/**
* An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise.
* This Expression is useful for mapping NaN values to null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

Expand All @@ -31,11 +33,25 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testFunc(1.0F, FloatType)
testFunc(1.0, DoubleType)
testFunc(Decimal(1.5), DecimalType(2, 1))
testFunc(new java.sql.Date(10), DateType)
testFunc(new java.sql.Timestamp(10), TimestampType)
testFunc(new Date(System.currentTimeMillis()), DateType)
testFunc(new Timestamp(System.currentTimeMillis()), TimestampType)
testFunc("abcd", StringType)
}

/**
 * Invokes `testFunc` once per supported data type, passing two distinct sample values of that
 * type together with the type itself.
 */
def testAllTypes2Values(testFunc: (Any, Any, DataType) => Unit): Unit = {
  // One row per type: (first sample, second sample, data type), visited in declaration order.
  val samples = Seq[(Any, Any, DataType)](
    (false, true, BooleanType),
    (1.toByte, 2.toByte, ByteType),
    (1.toShort, 2.toShort, ShortType),
    (1, 2, IntegerType),
    (1L, 2L, LongType),
    (1.0F, 2.0F, FloatType),
    (1.0, 2.0, DoubleType),
    (Decimal(1.5), Decimal(2.5), DecimalType(2, 1)),
    (new Date(1460745262177L), new Date(1260745262177L), DateType),
    (new Timestamp(10), new Timestamp(20), TimestampType),
    ("abcd", "xyz", StringType))

  samples.foreach { case (firstValue, secondValue, tpe) =>
    testFunc(firstValue, secondValue, tpe)
  }
}

test("isnull and isnotnull") {
testAllTypes { (value: Any, tpe: DataType) =>
checkEvaluation(IsNull(Literal.create(value, tpe)), false)
Expand All @@ -55,6 +71,41 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(IsNaN(Literal(5.5f)), false)
}

test("NullIf") {
  testAllTypes2Values { (value1: Any, value2: Any, tpe: DataType) =>
    val a = Literal.create(value1, tpe)
    val b = Literal.create(value2, tpe)
    val missing = Literal.create(null, tpe)

    // Different non-null operands: the first operand passes through.
    checkEvaluation(NullIf(a, b), value1)
    // Identical operands collapse to null.
    checkEvaluation(NullIf(a, a), null)
    // A null first operand is always null, whatever the second operand is.
    checkEvaluation(NullIf(missing, b), null)
    // A null second operand never matches a non-null first operand.
    checkEvaluation(NullIf(a, missing), value1)
    // Both operands null.
    checkEvaluation(NullIf(missing, missing), null)
  }
}

test("Nvl / IfNull") {
  testAllTypes2Values { (value1: Any, value2: Any, tpe: DataType) =>
    val lit1 = Literal.create(value1, tpe)
    val lit2 = Literal.create(value2, tpe)
    val nullLit = Literal.create(null, tpe)

    // Non-null first argument is returned unchanged.
    checkEvaluation(Nvl(lit1, lit2), value1)
    // Null first argument falls back to the second argument.
    checkEvaluation(Nvl(nullLit, lit2), value2)
    // A null fallback must not disturb a non-null first argument.
    checkEvaluation(Nvl(lit1, nullLit), value1)
    // With both arguments null there is nothing to fall back to.
    checkEvaluation(Nvl(nullLit, nullLit), null)
  }
}

test("Nvl2") {
  testAllTypes2Values { (value1: Any, value2: Any, tpe: DataType) =>
    val a = Literal.create(value1, tpe)
    val b = Literal.create(value2, tpe)
    val missing = Literal.create(null, tpe)

    // Non-null first argument selects the second argument ...
    checkEvaluation(Nvl2(a, a, b), value1)
    // ... even when that second argument is null.
    checkEvaluation(Nvl2(a, missing, b), null)
    // Null first argument selects the third argument ...
    checkEvaluation(Nvl2(missing, a, b), value2)
    // ... even when that third argument is null.
    checkEvaluation(Nvl2(missing, a, missing), null)
  }
}

test("nanvl") {
checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0)
checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null)
Expand Down
34 changes: 34 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,40 @@ object functions {
*/
def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) }

/**
 * Returns null if col1 equals col2, otherwise returns col1.
 *
 * Builds a `NullIf` expression from the expressions of the two columns.
 *
 * @group normal_funcs
 * @since 2.0.0
 */
def nullif(col1: Column, col2: Column): Column = withExpr { NullIf(col1.expr, col2.expr)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we don't need to add these to functions.scala. they are too sql specific


/**
 * Returns col2 if col1 is null, otherwise returns col1.
 *
 * Builds an `Nvl` expression from the expressions of the two columns.
 *
 * @group normal_funcs
 * @since 2.0.0
 */
def nvl(col1: Column, col2: Column): Column = withExpr { Nvl(col1.expr, col2.expr)}

/**
 * Returns col2 if col1 is null, otherwise returns col1. Same as `nvl`: both build the same
 * `Nvl` expression.
 *
 * @group normal_funcs
 * @since 2.0.0
 */
def ifnull(col1: Column, col2: Column): Column = withExpr { Nvl(col1.expr, col2.expr)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs


/**
 * Returns col3 if col1 is null, otherwise returns col2.
 *
 * Builds an `Nvl2` expression from the expressions of the three columns.
 *
 * @group normal_funcs
 * @since 2.0.0
 */
def nvl2(col1: Column, col2: Column, col3: Column): Column = withExpr {
Nvl2(col1.expr, col2.expr, col3.expr)
}

/**
* Unary minus, i.e. negate the expression.
* {{{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,32 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
)
}

test("nullif") {
  // Two string columns: one row with matching values, one row with different values.
  val rows = sparkContext.parallelize(Row("a", "a") :: Row("a", "b") :: Nil)
  val schema = StructType(Seq(StructField("c1", StringType), StructField("c2", StringType)))
  sqlContext.createDataFrame(rows, schema).registerTempTable("t")

  val expected = Row(null) :: Row("a") :: Nil
  checkAnswer(sql("select nullif(c1, c2) from t"), expected)
}

test("nvl") {
  // First row has a null c1 (falls back to c2); second row has a non-null c1.
  val rows = sparkContext.parallelize(Row(null, "b") :: Row("a", "b") :: Nil)
  val schema = StructType(Seq(StructField("c1", StringType), StructField("c2", StringType)))
  sqlContext.createDataFrame(rows, schema).registerTempTable("t")

  // nvl and ifnull are registered for the same expression, so the answers are the same.
  val expected = Row("b") :: Row("a") :: Nil
  checkAnswer(sql("select nvl(c1, c2) from t"), expected)
  checkAnswer(sql("select ifnull(c1, c2) from t"), expected)
}

test("nvl2") {
  // First row has a null c1 (selects c3); second row has a non-null c1 (selects c2).
  val rows = sparkContext.parallelize(Row(null, "b", "c") :: Row("a", "b", "c") :: Nil)
  val schema = StructType(Seq(
    StructField("c1", StringType),
    StructField("c2", StringType),
    StructField("c3", StringType)))
  sqlContext.createDataFrame(rows, schema).registerTempTable("t")

  val expected = Row("b") :: Row("c") :: Nil
  checkAnswer(sql("select nvl2(c1, c2, c3) from t"), expected)
}

test("===") {
checkAnswer(
testData2.filter($"a" === 1),
Expand Down