Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23919][SQL] Add array_position function #21037

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/pyspark/sql/functions.py
Expand Up @@ -1845,6 +1845,23 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(2.4)
def array_position(col, value):
"""
Collection function: Locates the position of the first occurrence of the given value
in the given array. Returns null if either of the arguments are null.

.. note:: The position is not zero based, but 1 based index. Returns 0 if the given
value could not be found in the array.

>>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
>>> df.select(array_position(df.data, "a")).collect()
[Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_position(_to_java_column(col), value))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Expand Up @@ -402,6 +402,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArrayPosition]("array_position"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[MapKeys]("map_keys"),
Expand Down
Expand Up @@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast

override def prettyName: String = "array_max"
}


/**
* Returns the position of the first occurrence of element in the given array as long.
* Returns 0 if the given value could not be found in the array. Returns null if either of
* the arguments are null
*
* NOTE: that this is not zero based, but 1-based index. The first element in the array has
* index 1.
*/
@ExpressionDescription(
usage = """
_FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long.
""",
examples = """
Examples:
> SELECT _FUNC_(array(3, 2, 1), 1);
3
""",
since = "2.4.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanted to note that we can use note here too:

note = """
Use RLIKE to match with standard regular expressions.
""")

I am mentioning this because we are adding many functions now :-).

case class ArrayPosition(left: Expression, right: Expression)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the behavior when left contains null element?

Copy link
Member Author

@kiszk kiszk Apr 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, if an array in left contains a null element, as you can see a UT, it works as an usual element.

val left = Literal.create(Seq[String](null, ""), ArrayType(StringType)) // contains null
checkEvaluation(ArrayPosition(left, Literal(""), 2L)
checkEvaluation(ArrayPosition(left, Literal.create(null, StringType)), null)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we can't know the position of null in the array even if the array contains null?

Copy link
Member Author

@kiszk kiszk Apr 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to these UTs in Presto, you are right.

extends BinaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = LongType
override def inputTypes: Seq[AbstractDataType] =
Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)

override def nullSafeEval(arr: Any, value: Any): Any = {
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == value) {
return (i + 1).toLong
}
)
0L
}

override def prettyName: String = "array_position"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (arr, value) => {
val pos = ctx.freshName("arrayPosition")
val i = ctx.freshName("i")
val getValue = CodeGenerator.getValue(arr, right.dataType, i)
s"""
|int $pos = 0;
|for (int $i = 0; $i < $arr.numElements(); $i ++) {
| if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
| $pos = $i + 1;
| break;
| }
|}
|${ev.value} = (long) $pos;
""".stripMargin
})
}
}
Expand Up @@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Reverse(as7), null)
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
}

test("Array Position") {
val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a2 = Literal.create(Seq(null), ArrayType(LongType))
val a3 = Literal.create(null, ArrayType(StringType))

checkEvaluation(ArrayPosition(a0, Literal(3)), 4L)
checkEvaluation(ArrayPosition(a0, Literal(1)), 1L)
checkEvaluation(ArrayPosition(a0, Literal(0)), 0L)
checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null)

checkEvaluation(ArrayPosition(a1, Literal("")), 2L)
checkEvaluation(ArrayPosition(a1, Literal("a")), 0L)
checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null)

checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L)
checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null)

checkEvaluation(ArrayPosition(a3, Literal("")), null)
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
}
}
14 changes: 14 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Expand Up @@ -3038,6 +3038,20 @@ object functions {
ArrayContains(column.expr, Literal(value))
}

/**
* Locates the position of the first occurrence of the value in the given array as long.
* Returns null if either of the arguments are null.
*
* @note The position is not zero based, but 1 based index. Returns 0 if value
* could not be found in array.
*
* @group collection_funcs
* @since 2.4.0
*/
def array_position(column: Column, value: Any): Column = withExpr {
ArrayPosition(column.expr, Literal(value))
}

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Expand Up @@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
}

test("array position function") {
val df = Seq(
(Seq[Int](1, 2), "x"),
(Seq[Int](), "x")
).toDF("a", "b")

checkAnswer(
df.select(array_position(df("a"), 1)),
Seq(Row(1L), Row(0L))
)
checkAnswer(
df.selectExpr("array_position(a, 1)"),
Seq(Row(1L), Row(0L))
)

checkAnswer(
df.select(array_position(df("a"), null)),
Seq(Row(null), Row(null))
)
checkAnswer(
df.selectExpr("array_position(a, null)"),
Seq(Row(null), Row(null))
)

checkAnswer(
df.selectExpr("array_position(array(array(1), null)[0], 1)"),
Seq(Row(1L), Row(1L))
)
checkAnswer(
df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
Seq(Row(1L), Row(1L))
)
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down