[SPARK-28962][SQL] Provide index argument to filter lambda functions
### What changes were proposed in this pull request?

Lambda functions passed to array `filter` can now take the element's index as input, in addition to the element itself. This behavior matches array `transform`.

### Why are the changes needed?
See JIRA. It's generally useful, and particularly so if you're working with fixed-length arrays.

### Does this PR introduce any user-facing change?
Previously, `filter` lambdas had to look like
`filter(arr, el -> whatever)`

Now, lambdas can take an index argument as well:
`filter(array, (el, idx) -> whatever)`
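
For example (a minimal sketch; assumes a build that includes this patch, i.e. Spark 3.0.0+, and an active `SparkSession` named `spark`, using the example added to `ArrayFilter`'s docs below):

```scala
// Keep elements that are strictly greater than their 0-based index.
spark.sql("SELECT filter(array(0, 2, 3), (x, i) -> x > i)").show()
// result: [2, 3]
```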

### How was this patch tested?
I added unit tests to `HigherOrderFunctionsSuite`.

Closes #25666 from henrydavidge/filter-idx.

Authored-by: Henry D <henrydavidge@gmail.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
henrydavidge authored and ueshin committed Oct 2, 2019
1 parent 730a178 commit 51d6ba7
Showing 3 changed files with 57 additions and 5 deletions.
@@ -344,8 +344,13 @@ case class MapFilter(
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
[1,3]
+ > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
+  [2,3]
""",
since = "2.4.0")
since = "2.4.0",
note = """
The inner function may use the index argument since 3.0.0.
""")
case class ArrayFilter(
argument: Expression,
function: Expression)
@@ -357,10 +362,19 @@ case class ArrayFilter(

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
val ArrayType(elementType, containsNull) = argument.dataType
- copy(function = f(function, (elementType, containsNull) :: Nil))
+ function match {
+   case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
+     copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
+   case _ =>
+     copy(function = f(function, (elementType, containsNull) :: Nil))
+ }
}

- @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
+ @transient lazy val (elementVar, indexVar) = {
+   val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function
+   val indexVar = tail.headOption.map(_.asInstanceOf[NamedLambdaVariable])
+   (elementVar, indexVar)
+ }

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val arr = argumentValue.asInstanceOf[ArrayData]
@@ -369,6 +383,9 @@ case class ArrayFilter(
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
+ if (indexVar.isDefined) {
+   indexVar.get.value.set(i)
+ }
if (f.eval(inputRow).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
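
For readers skimming the diff: outside of Catalyst, the evaluation loop above amounts to this plain-Scala sketch (`filterWithIndex` is an illustrative helper, not a Spark API):

```scala
import scala.collection.mutable.ArrayBuffer

// Mirrors ArrayFilter.nullSafeEval: walk the array once, exposing the
// current 0-based index to the predicate, and keep the elements that pass.
def filterWithIndex[A](arr: Seq[A])(f: (A, Int) => Boolean): Seq[A] = {
  val buffer = ArrayBuffer.empty[A]
  var i = 0
  while (i < arr.length) {
    if (f(arr(i), i)) buffer += arr(i)
    i += 1
  }
  buffer.toSeq
}

filterWithIndex(Seq(0, 2, 3))((x, i) => x > i)  // Seq(2, 3)
```
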
@@ -89,6 +89,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
}

+ def filter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
+   val ArrayType(et, cn) = expr.dataType
+   ArrayFilter(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
+ }
+
def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val MapType(kt, vt, vcn) = expr.dataType
TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
@@ -218,9 +223,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper

val isEven: Expression => Expression = x => x % 2 === 0
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
+ val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 }

checkEvaluation(filter(ai0, isEven), Seq(2))
checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
+ checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3))
checkEvaluation(filter(ai1, isEven), Seq.empty)
checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3))
checkEvaluation(filter(ain, isEven), null)
@@ -234,13 +241,17 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val startsWithA: Expression => Expression = x => x.startsWith("a")

checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2"))
+ checkEvaluation(filter(as0, indexIsEven), Seq("a0", "a2"))
checkEvaluation(filter(as1, startsWithA), Seq("a"))
+ checkEvaluation(filter(as1, indexIsEven), Seq("a", "c"))
checkEvaluation(filter(asn, startsWithA), null)

val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
Seq(Seq(1, 3), null, Seq(5)))
+ checkEvaluation(transform(aai, ix => filter(ix, indexIsEven)),
+   Seq(Seq(1, 3), null, Seq(4)))
}

test("ArrayExists") {
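
The two `filter` helper overloads above let the suite bind both lambda shapes; at the SQL level the same two shapes look like this (illustrative; results match the expectations above):

```scala
// One-argument lambda: behavior unchanged from 2.4.0.
spark.sql("SELECT filter(array(1, 2, 3), x -> x % 2 == 1)").show()      // [1,3]
// Two-argument lambda: the second argument is the 0-based element index.
spark.sql("SELECT filter(array(1, 2, 3), (x, i) -> i % 2 == 0)").show() // [1,3]
```
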
@@ -2290,6 +2290,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
testNonPrimitiveType()
}

test("filter function - index argument") {
val df = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Seq.empty,
null
).toDF("s")

def testIndexArgument(): Unit = {
checkAnswer(df.selectExpr("filter(s, (x, i) -> i % 2 == 0)"),
Seq(
Row(Seq("c", "b")),
Row(Seq("b", "c")),
Row(Seq.empty),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testIndexArgument()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testIndexArgument()
}

test("filter function - invalid") {
val df = Seq(
(Seq("c", "a", "b"), 1),
@@ -2299,9 +2323,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("filter(s, (x, y) -> x + y)")
df.selectExpr("filter(s, (x, y, z) -> x + y)")
}
- assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
+ assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("filter(i, x -> x)")
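For reference, the user-facing behavior covered by the new test, as a standalone sketch (assumes a `spark-shell` session on a build with this patch, so `spark` and `spark.implicits._` are available):

```scala
import spark.implicits._

val df = Seq(Seq("c", "a", "b"), Seq("b", null, "c", null)).toDF("s")

// Keep the elements at even indices, as in the "index argument" test above.
df.selectExpr("filter(s, (x, i) -> i % 2 == 0)").show()
// rows: [c, b] and [b, c], matching the expected answer in the test
```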
