Skip to content

Commit

Permalink
[SPARK-25068][SQL] Add exists function.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This pr adds `exists` function which tests whether a predicate holds for one or more elements in the array.

```sql
> SELECT exists(array(1, 2, 3), x -> x % 2 == 0);
 true
```

## How was this patch tested?

Added tests.

Closes #22052 from ueshin/issues/SPARK-25068/exists.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Xiao Li <gatorsmile@gmail.com>
  • Loading branch information
ueshin authored and gatorsmile committed Aug 9, 2018
1 parent fec67ed commit 9b8521e
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ object FunctionRegistry {
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
CreateStruct.registryEntry,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,53 @@ case class ArrayFilter(
override def prettyName: String = "filter"
}

/**
* Tests whether a predicate holds for one or more elements in the array.
*/
@ExpressionDescription(usage =
"_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0);
true
""",
since = "2.4.0")
case class ArrayExists(
input: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

override def dataType: DataType = BooleanType

override def expectingFunctionType: AbstractDataType = BooleanType

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = {
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
copy(function = f(function, elem :: Nil))
}

@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val arr = value.asInstanceOf[ArrayData]
val f = functionForEval
var exists = false
var i = 0
while (i < arr.numElements && !exists) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (f.eval(inputRow).asInstanceOf[Boolean]) {
exists = true
}
i += 1
}
exists
}

override def prettyName: String = "exists"
}

/**
* Applies a binary operator to a start value and all elements in the array.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,43 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Seq(1, 3), null, Seq(5)))
}

test("ArrayExists") {
def exists(expr: Expression, f: Expression => Expression): Expression = {
val at = expr.dataType.asInstanceOf[ArrayType]
ArrayExists(expr, createLambda(at.elementType, at.containsNull, f))
}

val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))

val isEven: Expression => Expression = x => x % 2 === 0
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1

checkEvaluation(exists(ai0, isEven), true)
checkEvaluation(exists(ai0, isNullOrOdd), true)
checkEvaluation(exists(ai1, isEven), false)
checkEvaluation(exists(ai1, isNullOrOdd), true)
checkEvaluation(exists(ain, isEven), null)
checkEvaluation(exists(ain, isNullOrOdd), null)

val as0 =
Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true))
val asn = Literal.create(null, ArrayType(StringType, containsNull = false))

val startsWithA: Expression => Expression = x => x.startsWith("a")

checkEvaluation(exists(as0, startsWithA), true)
checkEvaluation(exists(as1, startsWithA), false)
checkEvaluation(exists(asn, startsWithA), null)

val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
checkEvaluation(transform(aai, ix => exists(ix, isNullOrOdd)),
Seq(true, null, true))
}

test("ArrayAggregate") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as

-- Aggregate a null array
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;

-- Check for element existence
select exists(ys, y -> y > 30) as v from nested;

-- Check for element existence in a null array
select exists(cast(null as array<int>), y -> y > 30) as v;
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,21 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
struct<v:int>
-- !query 14 output
NULL


-- !query 15
select exists(ys, y -> y > 30) as v from nested
-- !query 15 schema
struct<v:boolean>
-- !query 15 output
false
true
true


-- !query 16
select exists(cast(null as array<int>), y -> y > 30) as v
-- !query 16 schema
struct<v:boolean>
-- !query 16 output
NULL
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
}

test("exists function - array for primitive type not containing null") {
val df = Seq(
Seq(1, 9, 8, 7),
Seq(5, 9, 7),
Seq.empty,
null
).toDF("i")

def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"),
Seq(
Row(true),
Row(false),
Row(false),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testArrayOfPrimitiveTypeNotContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testArrayOfPrimitiveTypeNotContainsNull()
}

test("exists function - array for primitive type containing null") {
val df = Seq[Seq[Integer]](
Seq(1, 9, 8, null, 7),
Seq(5, null, null, 9, 7, null),
Seq.empty,
null
).toDF("i")

def testArrayOfPrimitiveTypeContainsNull(): Unit = {
checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"),
Seq(
Row(true),
Row(false),
Row(false),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testArrayOfPrimitiveTypeContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testArrayOfPrimitiveTypeContainsNull()
}

test("exists function - array for non-primitive type") {
val df = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Seq.empty,
null
).toDF("s")

def testNonPrimitiveType(): Unit = {
checkAnswer(df.selectExpr("exists(s, x -> x is null)"),
Seq(
Row(false),
Row(true),
Row(false),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testNonPrimitiveType()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testNonPrimitiveType()
}

test("exists function - invalid") {
val df = Seq(
(Seq("c", "a", "b"), 1),
(Seq("b", null, "c", null), 2),
(Seq.empty, 3),
(null, 4)
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("exists(s, (x, y) -> x + y)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("exists(i, x -> x)")
}
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))

val ex3 = intercept[AnalysisException] {
df.selectExpr("exists(s, x -> x)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
}

test("aggregate function - array for primitive type not containing null") {
val df = Seq(
Seq(1, 9, 8, 7),
Expand Down

0 comments on commit 9b8521e

Please sign in to comment.