Skip to content

Commit

Permalink
Make ArrayExists follow the three-valued boolean logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Jun 14, 2019
1 parent abe370f commit 0d30c66
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ case class ArrayExists(
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

// The result may be null if the inherited nullability holds or if the lambda
// itself can return null (nullSafeEval yields null when the lambda returns
// null for some element and true for none).
override def nullable: Boolean = super.nullable || function.nullable

// The expression evaluates to a boolean (or null, see `nullable`).
override def dataType: DataType = BooleanType

// Declares BooleanType as the function type.
// NOTE(review): presumably the type the lambda's result must have — confirm
// against the parent trait, which is not visible here.
override def functionType: AbstractDataType = BooleanType
Expand All @@ -409,16 +411,23 @@ case class ArrayExists(
// NOTE(review): this span is a rendered diff hunk with the +/- markers
// stripped, so lines from the pre-change and post-change versions are
// interleaved and the text is not compilable as shown. The comments below mark
// which version each group of lines presumably belongs to.
//
// Post-change contract: evaluates the lambda on each array element and returns
//   - true  as soon as the lambda returns true for an element,
//   - null  if no element yielded true but at least one yielded null,
//   - false otherwise (including an array with no elements).
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val arr = argumentValue.asInstanceOf[ArrayData]
val f = functionForEval
// Pre-change: flag set once the lambda returned true.
var exists = false
// Post-change: remembers that the lambda returned null for some element.
var foundNull = false
var i = 0
// Pre-change loop header: stops as soon as `exists` becomes true.
while (i < arr.numElements && !exists) {
// Post-change loop header: scans every element; the early exit is the
// `return true` below.
while (i < arr.numElements) {
// Bind the current element to the lambda variable before evaluating.
elementVar.value.set(arr.get(i, elementVar.dataType))
// Pre-change: the result was cast straight to Boolean, so a null result was
// unboxed to false and treated like a non-match.
if (f.eval(inputRow).asInstanceOf[Boolean]) {
exists = true
// Post-change: keep the raw result so null can be told apart from false.
val ret = f.eval(inputRow)
if (ret == null) {
foundNull = true
} else if (ret.asInstanceOf[Boolean]) {
// A true result decides the outcome regardless of any nulls seen.
return true
}
i += 1
}
// Pre-change result expression.
exists
// Post-change result: no element yielded true, so the answer is unknown
// (null) if any evaluation was null, and false otherwise.
if (foundNull) {
null
} else {
false
}
}

// Reports "exists" as this expression's pretty name.
// NOTE(review): presumably the user-facing SQL function name — confirm.
override def prettyName: String = "exists"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
af.copy(function = newLambda)
case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
ae.copy(function = newLambda)
case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
val newLambda = lf.copy(function = replaceNullWithFalse(func))
mf.copy(function = newLambda)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper

val isEven: Expression => Expression = x => x % 2 === 0
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral
val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType)

checkEvaluation(exists(ai0, isEven), true)
checkEvaluation(exists(ai0, isNullOrOdd), true)
checkEvaluation(exists(ai1, isEven), false)
checkEvaluation(exists(ai0, alwaysFalse), false)
checkEvaluation(exists(ai0, alwaysNull), null)
checkEvaluation(exists(ai1, isEven), null)
checkEvaluation(exists(ai1, isNullOrOdd), true)
checkEvaluation(exists(ai1, alwaysFalse), false)
checkEvaluation(exists(ai1, alwaysNull), null)
checkEvaluation(exists(ain, isEven), null)
checkEvaluation(exists(ain, isNullOrOdd), null)
checkEvaluation(exists(ain, alwaysFalse), null)
checkEvaluation(exists(ain, alwaysNull), null)

val as0 =
Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
Expand All @@ -271,8 +279,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val startsWithA: Expression => Expression = x => x.startsWith("a")

checkEvaluation(exists(as0, startsWithA), true)
checkEvaluation(exists(as1, startsWithA), false)
checkEvaluation(exists(as0, alwaysFalse), false)
checkEvaluation(exists(as0, alwaysNull), null)
checkEvaluation(exists(as1, startsWithA), null)
checkEvaluation(exists(as1, alwaysFalse), false)
checkEvaluation(exists(as1, alwaysNull), null)
checkEvaluation(exists(asn, startsWithA), null)
checkEvaluation(exists(asn, alwaysFalse), null)
checkEvaluation(exists(asn, alwaysNull), null)

val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2246,6 +2246,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
test("exists function - array for primitive type containing null") {
val df = Seq[Seq[Integer]](
Seq(1, 9, 8, null, 7),
Seq(1, 3, 5),
Seq(5, null, null, 9, 7, null),
Seq.empty,
null
Expand All @@ -2256,6 +2257,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(
Row(true),
Row(false),
Row(null),
Row(false),
Row(null)))
}
Expand Down

0 comments on commit 0d30c66

Please sign in to comment.