Skip to content

Commit

Permalink
[SPARK-23924][SQL] Add element_at function
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR adds the SQL function `element_at`. Its behavior is modeled on the Presto function of the same name.

If the column is an array, this function returns the element of the array at the given index. If the column is a map, it returns the value for the given key.

## How was this patch tested?

Added UTs

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #21053 from kiszk/SPARK-23924.
  • Loading branch information
kiszk authored and ueshin committed Apr 19, 2018
1 parent d5bec48 commit 46bb2b5
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 24 deletions.
24 changes: 24 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,30 @@ def array_position(col, value):
return Column(sc._jvm.functions.array_position(_to_java_column(col), value))


@ignore_unicode_prefix
@since(2.4)
def element_at(col, extraction):
    """
    Collection function: Returns element of array at given index in extraction if col is array.
    Returns value for the given key in extraction if col is map.

    :param col: name of column containing array or map
    :param extraction: index to check for in array or key to check for in map

    .. note:: The position is not zero based, but 1 based index.

    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
    >>> df.select(element_at(df.data, 1)).collect()
    [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]
    >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
    >>> df.select(element_at(df.data, "a")).collect()
    [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
    """
    # Delegate to the JVM-side `functions.element_at` through the active context.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    return Column(jvm_functions.element_at(_to_java_column(col), extraction))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ object FunctionRegistry {
expression[ArrayPosition]("array_position"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[Size]("size"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression)
})
}
}

/**
* Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
*/
@ExpressionDescription(
usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), 2);
2
> SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
"b"
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {

  /** The element type when `left` is an array, the value type when `left` is a map. */
  override def dataType: DataType = left.dataType match {
    case ArrayType(elementType, _) => elementType
    case MapType(_, valueType, _) => valueType
  }

  /**
   * `left` must be an array or a map. `right` must be an integer index for an array, or a
   * value of the map's key type for a map.
   */
  override def inputTypes: Seq[AbstractDataType] = {
    val extractionType: AbstractDataType = left.dataType match {
      case _: ArrayType => IntegerType
      case MapType(keyType, _, _) => keyType
      // `left` has a wrong type: accept anything for `right` so that the mismatch on `left`
      // is reported by the type check instead of failing here with a MatchError.
      case _ => AnyDataType
    }
    Seq(TypeCollection(ArrayType, MapType), extractionType)
  }

  /** NULL is returned for out-of-range indices, missing keys and null elements/values. */
  override def nullable: Boolean = true

  /**
   * Interpreted evaluation.
   *
   * For an array, `ordinal` is a 1-based index; a negative index counts from the end of the
   * array. Returns null when the absolute index exceeds the array length, and throws
   * `ArrayIndexOutOfBoundsException` when the index is 0. For a map, delegates to
   * `getValueEval` with the map's key type.
   */
  override def nullSafeEval(value: Any, ordinal: Any): Any = left.dataType match {
    case arrayType: ArrayType =>
      val array = value.asInstanceOf[ArrayData]
      val index = ordinal.asInstanceOf[Int]
      val numElements = array.numElements()
      // Compare as Long: math.abs(Int.MinValue) is still negative as an Int, which would
      // bypass this bounds check and produce a negative array position below.
      if (numElements < math.abs(index.toLong)) {
        null
      } else if (index == 0) {
        throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
      } else {
        // Translate the 1-based (or negative, from-the-end) index to a 0-based position.
        val idx = if (index > 0) index - 1 else numElements + index
        if (arrayType.containsNull && array.isNullAt(idx)) {
          null
        } else {
          array.get(idx, dataType)
        }
      }
    case mapType: MapType =>
      getValueEval(value, ordinal, mapType.keyType)
  }

  /**
   * Code generation; mirrors `nullSafeEval`. The array case emits the bounds check, the
   * index translation and (only when the array may contain nulls) an element null check.
   * The map case delegates to `doGetValueGenCode`.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = left.dataType match {
    case arrayType: ArrayType =>
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        val index = ctx.freshName("elementAtIndex")
        // The trailing `else` chains onto the value-assignment block emitted after it.
        val nullCheck = if (arrayType.containsNull) {
          s"""
             |if ($eval1.isNullAt($index)) {
             | ${ev.isNull} = true;
             |} else
           """.stripMargin
        } else {
          ""
        }
        // The bounds check is done on a long so that Integer.MIN_VALUE cannot slip through
        // (Math.abs(Integer.MIN_VALUE) is negative as an int).
        s"""
           |int $index = (int) $eval2;
           |if ($eval1.numElements() < Math.abs((long) $index)) {
           | ${ev.isNull} = true;
           |} else {
           | if ($index == 0) {
           | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
           | } else if ($index > 0) {
           | $index--;
           | } else {
           | $index += $eval1.numElements();
           | }
           | $nullCheck
           | {
           | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
           | }
           |}
         """.stripMargin
      })
    case mapType: MapType =>
      doGetValueGenCode(ctx, ev, mapType)
  }

  override def prettyName: String = "element_at"
}
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}

/**
* Returns the value of key `key` in Map `child`.
*
* We need to do type checking here as `key` expression maybe unresolved.
* Common base class for [[GetMapValue]] and [[ElementAt]].
*/
case class GetMapValue(child: Expression, key: Expression)
extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {

private def keyType = child.dataType.asInstanceOf[MapType].keyType

// We have done type checking for child in `ExtractValue`, so only need to check the `key`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)

override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"

override def left: Expression = child
override def right: Expression = key

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true

override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
Expand All @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
val index = ctx.freshName("index")
val length = ctx.freshName("length")
val keys = ctx.freshName("keys")
val found = ctx.freshName("found")
val key = ctx.freshName("key")
val values = ctx.freshName("values")
val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) {
val keyType = mapType.keyType
val nullCheck = if (mapType.valueContainsNull) {
s" || $values.isNullAt($index)"
} else {
""
Expand Down Expand Up @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression)
})
}
}

/**
 * Extracts the value stored under `key` from the map produced by `child`.
 *
 * The `key` expression may still be unresolved when this node is created, so its type is
 * checked here through `inputTypes`.
 */
case class GetMapValue(child: Expression, key: Expression)
  extends GetMapValueUtil with ExtractValue with NullIntolerant {

  // Single place for the cast of the child's type to a map type.
  private def mapType: MapType = child.dataType.asInstanceOf[MapType]

  override def left: Expression = child
  override def right: Expression = key

  // The child has already been type checked in `ExtractValue`; only `key` is constrained.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, mapType.keyType)

  override def dataType: DataType = mapType.valueType

  /** `Null` is returned for invalid ordinals. */
  override def nullable: Boolean = true

  override def toString: String = s"$child[$key]"
  override def sql: String = s"${child.sql}[${key.sql}]"

  // todo: current search is O(n), improve it.
  override def nullSafeEval(value: Any, ordinal: Any): Any =
    getValueEval(value, ordinal, mapType.keyType)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    doGetValueGenCode(ctx, ev, mapType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayPosition(a3, Literal("")), null)
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
}

test("elementAt") {
  val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
  val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
  val a2 = Literal.create(Seq(null), ArrayType(LongType))
  val a3 = Literal.create(null, ArrayType(StringType))

  // Index 0 is rejected. The result of the message check must be asserted (it was previously
  // computed and discarded). The whole cause chain is inspected because the evaluation
  // helper may wrap the original exception.
  val zeroIndexError = intercept[Exception] {
    checkEvaluation(ElementAt(a0, Literal(0)), null)
  }
  assert(
    Iterator.iterate[Throwable](zeroIndexError)(_.getCause)
      .takeWhile(_ != null)
      .exists(t => Option(t.getMessage).exists(_.contains("SQL array indices start at 1"))))

  // A non-integer index is rejected as well.
  intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }

  // Indices beyond either end of the array yield null.
  checkEvaluation(ElementAt(a0, Literal(4)), null)
  checkEvaluation(ElementAt(a0, Literal(-4)), null)

  // Positive indices are 1-based; negative indices count from the end.
  checkEvaluation(ElementAt(a0, Literal(1)), 1)
  checkEvaluation(ElementAt(a0, Literal(2)), 2)
  checkEvaluation(ElementAt(a0, Literal(3)), 3)
  checkEvaluation(ElementAt(a0, Literal(-3)), 1)
  checkEvaluation(ElementAt(a0, Literal(-2)), 2)
  checkEvaluation(ElementAt(a0, Literal(-1)), 3)

  // Null elements are returned as null.
  checkEvaluation(ElementAt(a1, Literal(1)), null)
  checkEvaluation(ElementAt(a1, Literal(2)), "")
  checkEvaluation(ElementAt(a1, Literal(-2)), null)
  checkEvaluation(ElementAt(a1, Literal(-1)), "")

  checkEvaluation(ElementAt(a2, Literal(1)), null)

  // A null array yields null.
  checkEvaluation(ElementAt(a3, Literal(1)), null)


  val m0 =
    Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType))
  val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
  val m2 = Literal.create(null, MapType(StringType, StringType))

  checkEvaluation(ElementAt(m0, Literal(1.0)), null)

  // Missing keys yield null.
  checkEvaluation(ElementAt(m0, Literal("d")), null)

  checkEvaluation(ElementAt(m1, Literal("a")), null)

  checkEvaluation(ElementAt(m0, Literal("a")), "1")
  checkEvaluation(ElementAt(m0, Literal("b")), "2")
  checkEvaluation(ElementAt(m0, Literal("c")), null)

  // A null map yields null.
  checkEvaluation(ElementAt(m2, Literal("a")), null)
}
}
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3052,6 +3052,17 @@ object functions {
ArrayPosition(column.expr, Literal(value))
}

/**
 * Returns element of array at given index in value if column is array. Returns value for
 * the given key in value if column is map.
 *
 * For arrays the index is 1-based; a negative index accesses elements from the last to the
 * first. The result is null if the index exceeds the length of the array, or if the key is
 * not contained in the map.
 *
 * `value` is converted with `lit`, so a plain value is used as a literal and a `Column` can
 * be passed to take the index or key from another expression.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def element_at(column: Column, value: Any): Column = withExpr {
  ElementAt(column.expr, lit(value).expr)
}

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("element_at function") {
  val df = Seq(
    (Seq[String]("1", "2", "3")),
    (Seq[String](null, "")),
    (Seq[String]())
  ).toDF("a")

  // Index 0 is rejected. The result of the message check must be asserted (it was previously
  // computed and discarded). The whole cause chain is inspected because the failure may be
  // wrapped by the query execution or by the test helper.
  val zeroIndexError = intercept[Exception] {
    checkAnswer(
      df.select(element_at(df("a"), 0)),
      Seq(Row(null), Row(null), Row(null))
    )
  }
  assert(
    Iterator.iterate[Throwable](zeroIndexError)(_.getCause)
      .takeWhile(_ != null)
      .exists(t => Option(t.getMessage).exists(_.contains("SQL array indices start at 1"))))

  // A non-integer index is rejected as well.
  intercept[Exception] {
    checkAnswer(
      df.select(element_at(df("a"), 1.1)),
      Seq(Row(null), Row(null), Row(null))
    )
  }

  // An index beyond the end of every array yields null.
  checkAnswer(
    df.select(element_at(df("a"), 4)),
    Seq(Row(null), Row(null), Row(null))
  )

  checkAnswer(
    df.select(element_at(df("a"), 1)),
    Seq(Row("1"), Row(null), Row(null))
  )
  checkAnswer(
    df.select(element_at(df("a"), -1)),
    Seq(Row("3"), Row(""), Row(null))
  )

  // Same checks through the SQL expression form.
  checkAnswer(
    df.selectExpr("element_at(a, 4)"),
    Seq(Row(null), Row(null), Row(null))
  )

  checkAnswer(
    df.selectExpr("element_at(a, 1)"),
    Seq(Row("1"), Row(null), Row(null))
  )
  checkAnswer(
    df.selectExpr("element_at(a, -1)"),
    Seq(Row("3"), Row(""), Row(null))
  )
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down

0 comments on commit 46bb2b5

Please sign in to comment.