Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23924][SQL] Add element_at function #21053

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,30 @@ def array_position(col, value):
return Column(sc._jvm.functions.array_position(_to_java_column(col), value))


@ignore_unicode_prefix
@since(2.4)
def element_at(col, extraction):
    """
    Collection function: looks up a single entry of an array or map column.
    For an array column, returns the element located at the index given in extraction.
    For a map column, returns the value stored under the key given in extraction.

    :param col: name of column containing array or map
    :param extraction: index to check for in array or key to check for in map

    .. note:: The position is not zero based, but 1 based index.

    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
    >>> df.select(element_at(df.data, 1)).collect()
    [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]

    >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
    >>> df.select(element_at(df.data, "a")).collect()
    [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
    """
    # Resolve the JVM-side `functions` object through the active SparkContext.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    java_col = _to_java_column(col)
    return Column(jvm_functions.element_at(java_col, extraction))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ object FunctionRegistry {
expression[ArrayPosition]("array_position"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[Size]("size"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression)
})
}
}

/**
* Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
*/
@ExpressionDescription(
usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.

_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), 2);
2
> SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
"b"
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {

// The result type is the array's element type or the map's value type.
override def dataType: DataType = left.dataType match {
  case array: ArrayType => array.elementType
  case map: MapType => map.valueType
}

/**
 * Expected input types: `left` must be an array or a map; `right` must be an integer index
 * when `left` is an array, or a value of the map's key type when `left` is a map.
 */
override def inputTypes: Seq[AbstractDataType] = {
  val extractionType: AbstractDataType = left.dataType match {
    case _: ArrayType => IntegerType
    case mapType: MapType => mapType.keyType
    // Without this fallback a `left` that is neither an array nor a map raises a MatchError
    // right here. Accepting any type for `right` leaves the mismatch to be reported against
    // the TypeCollection entry for `left` instead.
    case _ => AnyDataType
  }
  Seq(TypeCollection(ArrayType, MapType), extractionType)
}

// Always nullable: even when both inputs are non-null, an index beyond the array length or a
// key that is not present in the map evaluates to null (see nullSafeEval).
override def nullable: Boolean = true
Copy link
Member

@viirya viirya Apr 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe?

```scala
override def nullable: Boolean = left.dataType match {
  case a: ArrayType => a.containsNull
  case m: MapType => m.valueContainsNull
} || left.nullable || right.nullable
```

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid it's wrong because this returns null when the given index is "out of bounds" (array.numElements() < math.abs(index)) for array type or the given key doesn't exist for map type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. An invalid `right` can cause a null result too.

Copy link
Member Author

@kiszk kiszk Apr 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, may depend on right value, too.


/**
 * Interpreted evaluation; both inputs are already known to be non-null.
 *
 * For an array, `ordinal` is a 1-based index; a negative index counts from the end of the
 * array. Returns null when the index lies beyond the array length or the addressed element
 * is null, and throws ArrayIndexOutOfBoundsException for index 0.
 * For a map, delegates the key lookup to `getValueEval`.
 */
override def nullSafeEval(value: Any, ordinal: Any): Any = {
  left.dataType match {
    case arrayType: ArrayType =>
      val array = value.asInstanceOf[ArrayData]
      val index = ordinal.asInstanceOf[Int]
      val length = array.numElements()
      if (index == 0) {
        throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
      } else if (length < math.abs(index.toLong)) {
        // The comparison is done on Long because math.abs(Int.MinValue) is still negative as
        // an Int, which would slip past this bounds check and yield a negative position.
        null
      } else {
        // Translate the 1-based (or negative, from-the-end) index to a 0-based position.
        val position = if (index > 0) index - 1 else length + index
        if (arrayType.containsNull && array.isNullAt(position)) {
          null
        } else {
          array.get(position, dataType)
        }
      }
    case mapType: MapType =>
      getValueEval(value, ordinal, mapType.keyType)
  }
}

/**
 * Generates Java code equivalent to `nullSafeEval`.
 *
 * For an array the generated code converts the 1-based (or negative) index to a 0-based
 * position, sets the result to null when the index is beyond the array length or the element
 * is null, and throws ArrayIndexOutOfBoundsException for index 0.
 * For a map, code generation is delegated to `doGetValueGenCode`.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  left.dataType match {
    case arrayType: ArrayType =>
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        val index = ctx.freshName("elementAtIndex")
        // When elements may be null this emits an `if (...) { ... } else` prefix; the block
        // that reads the value in the template below becomes its else-branch.
        val nullCheck = if (arrayType.containsNull) {
          s"""
             |if ($eval1.isNullAt($index)) {
             |  ${ev.isNull} = true;
             |} else
           """.stripMargin
        } else {
          ""
        }
        // The index is widened to long before Math.abs: Math.abs(Integer.MIN_VALUE) is still
        // negative as an int and would bypass the bounds check.
        s"""
           |int $index = (int) $eval2;
           |if ($eval1.numElements() < Math.abs((long) $index)) {
           |  ${ev.isNull} = true;
           |} else {
           |  if ($index == 0) {
           |    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
           |  } else if ($index > 0) {
           |    $index--;
           |  } else {
           |    $index += $eval1.numElements();
           |  }
           |  $nullCheck
           |  {
           |    ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
           |  }
           |}
         """.stripMargin
      })
    case mapType: MapType =>
      doGetValueGenCode(ctx, ev, mapType)
  }
}

// Same name under which this expression is registered in FunctionRegistry.
override def prettyName: String = "element_at"
}
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}

/**
* Returns the value of key `key` in Map `child`.
*
* We need to do type checking here as `key` expression maybe unresolved.
* Common base class for [[GetMapValue]] and [[ElementAt]].
*/
case class GetMapValue(child: Expression, key: Expression)
extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {

private def keyType = child.dataType.asInstanceOf[MapType].keyType

// We have done type checking for child in `ExtractValue`, so only need to check the `key`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)

override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"

override def left: Expression = child
override def right: Expression = key

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
Expand All @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
val index = ctx.freshName("index")
val length = ctx.freshName("length")
val keys = ctx.freshName("keys")
val found = ctx.freshName("found")
val key = ctx.freshName("key")
val values = ctx.freshName("values")
val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) {
val keyType = mapType.keyType
val nullCheck = if (mapType.valueContainsNull) {
s" || $values.isNullAt($index)"
} else {
""
Expand Down Expand Up @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression)
})
}
}

/**
 * Looks up the value stored under `key` in the map produced by `child`.
 *
 * The key's type has to be checked here because the `key` expression may still be unresolved.
 */
case class GetMapValue(child: Expression, key: Expression)
  extends GetMapValueUtil with ExtractValue with NullIntolerant {

  // `child` is a map by the time any of the members below are used.
  private def childMapType: MapType = child.dataType.asInstanceOf[MapType]

  override def left: Expression = child
  override def right: Expression = key

  override def dataType: DataType = childMapType.valueType

  // `ExtractValue` has already type-checked `child`, so only `key` still needs checking.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, childMapType.keyType)

  /** `Null` is returned for invalid ordinals. */
  override def nullable: Boolean = true

  override def toString: String = s"$child[$key]"
  override def sql: String = s"${child.sql}[${key.sql}]"

  // todo: current search is O(n), improve it.
  override def nullSafeEval(value: Any, ordinal: Any): Any =
    getValueEval(value, ordinal, childMapType.keyType)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    doGetValueGenCode(ctx, ev, childMapType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayPosition(a3, Literal("")), null)
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
}

test("elementAt") {
  val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
  val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
  val a2 = Literal.create(Seq(null), ArrayType(LongType))
  val a3 = Literal.create(null, ArrayType(StringType))

  // Index 0 is rejected. The message is actually asserted (the bare `.contains(...)` result
  // used to be discarded), searching the whole cause chain because the evaluation helper may
  // wrap the original exception.
  val zeroIndexError = intercept[Exception] {
    checkEvaluation(ElementAt(a0, Literal(0)), null)
  }
  assert(
    Iterator.iterate[Throwable](zeroIndexError)(_.getCause)
      .takeWhile(_ != null)
      .exists(e => Option(e.getMessage).exists(_.contains("SQL array indices start at 1"))))

  // A non-integer index is not accepted for arrays.
  intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }

  // Indices beyond the array length, in either direction, yield null.
  checkEvaluation(ElementAt(a0, Literal(4)), null)
  checkEvaluation(ElementAt(a0, Literal(-4)), null)

  // Positive indices are 1-based; negative indices count from the end.
  checkEvaluation(ElementAt(a0, Literal(1)), 1)
  checkEvaluation(ElementAt(a0, Literal(2)), 2)
  checkEvaluation(ElementAt(a0, Literal(3)), 3)
  checkEvaluation(ElementAt(a0, Literal(-3)), 1)
  checkEvaluation(ElementAt(a0, Literal(-2)), 2)
  checkEvaluation(ElementAt(a0, Literal(-1)), 3)

  // Null elements are returned as null.
  checkEvaluation(ElementAt(a1, Literal(1)), null)
  checkEvaluation(ElementAt(a1, Literal(2)), "")
  checkEvaluation(ElementAt(a1, Literal(-2)), null)
  checkEvaluation(ElementAt(a1, Literal(-1)), "")

  checkEvaluation(ElementAt(a2, Literal(1)), null)

  // A null array yields null.
  checkEvaluation(ElementAt(a3, Literal(1)), null)

  val m0 =
    Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType))
  val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
  val m2 = Literal.create(null, MapType(StringType, StringType))

  checkEvaluation(ElementAt(m0, Literal(1.0)), null)

  // Keys that are not present yield null.
  checkEvaluation(ElementAt(m0, Literal("d")), null)

  checkEvaluation(ElementAt(m1, Literal("a")), null)

  checkEvaluation(ElementAt(m0, Literal("a")), "1")
  checkEvaluation(ElementAt(m0, Literal("b")), "2")
  checkEvaluation(ElementAt(m0, Literal("c")), null)

  // A null map yields null.
  checkEvaluation(ElementAt(m2, Literal("a")), null)
}
}
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3052,6 +3052,17 @@ object functions {
ArrayPosition(column.expr, Literal(value))
}

/**
 * Looks up a single entry of an array or map column: for an array column, returns the
 * element at the index given in `value`; for a map column, returns the value stored under
 * the key given in `value`.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def element_at(column: Column, value: Any): Column = {
  val extraction = Literal(value)
  withExpr(ElementAt(column.expr, extraction))
}

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("element_at function") {
  val df = Seq(
    (Seq[String]("1", "2", "3")),
    (Seq[String](null, "")),
    (Seq[String]())
  ).toDF("a")

  // Index 0 is rejected. The message is actually asserted (the bare `.contains(...)` result
  // used to be discarded), searching the whole cause chain because the failure may surface
  // wrapped in another exception.
  val zeroIndexError = intercept[Exception] {
    checkAnswer(
      df.select(element_at(df("a"), 0)),
      Seq(Row(null), Row(null), Row(null))
    )
  }
  assert(
    Iterator.iterate[Throwable](zeroIndexError)(_.getCause)
      .takeWhile(_ != null)
      .exists(e => Option(e.getMessage).exists(_.contains("SQL array indices start at 1"))))

  // A non-integer index is not accepted for arrays.
  intercept[Exception] {
    checkAnswer(
      df.select(element_at(df("a"), 1.1)),
      Seq(Row(null), Row(null), Row(null))
    )
  }

  // An index beyond every array's length yields null for each row.
  checkAnswer(
    df.select(element_at(df("a"), 4)),
    Seq(Row(null), Row(null), Row(null))
  )

  checkAnswer(
    df.select(element_at(df("a"), 1)),
    Seq(Row("1"), Row(null), Row(null))
  )
  checkAnswer(
    df.select(element_at(df("a"), -1)),
    Seq(Row("3"), Row(""), Row(null))
  )

  // Same checks through the SQL expression form.
  checkAnswer(
    df.selectExpr("element_at(a, 4)"),
    Seq(Row(null), Row(null), Row(null))
  )

  checkAnswer(
    df.selectExpr("element_at(a, 1)"),
    Seq(Row("1"), Row(null), Row(null))
  )
  checkAnswer(
    df.selectExpr("element_at(a, -1)"),
    Seq(Row("3"), Row(""), Row(null))
  )
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down