diff --git a/docs/configs.md b/docs/configs.md
index 8614134631d..805ca0cedd2 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -172,6 +172,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.DayOfWeek|`dayofweek`|Returns the day of the week (1 = Sunday...7=Saturday)|true|None|
spark.rapids.sql.expression.DayOfYear|`dayofyear`|Returns the day of the year from a date or timestamp|true|None|
spark.rapids.sql.expression.Divide|`/`|Division|true|None|
+spark.rapids.sql.expression.ElementAt|`element_at`|Returns the element of an array at the given (1-based) index if the column is an array. Returns the value for the given key if the column is a map.|true|None|
spark.rapids.sql.expression.EndsWith| |Ends with|true|None|
spark.rapids.sql.expression.EqualNullSafe|`<=>`|Check if the values are equal including nulls <=>|true|None|
spark.rapids.sql.expression.EqualTo|`=`, `==`|Check if the values are equal|true|None|
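For reference, the element_at semantics described in the new ElementAt row above can be sketched in PySpark roughly as follows (assuming an existing `spark` session; the behavior shown is standard Spark SQL whether or not the plugin is enabled):

```python
# Minimal sketch of element_at semantics; assumes an existing SparkSession `spark`.
from pyspark.sql import Row
from pyspark.sql.functions import col, element_at

df = spark.createDataFrame([Row(arr=[10, 20, 30], m={'a': 1, 'b': 2})])

# Array column: 1-based index; negative indexes count back from the end.
df.select(element_at(col('arr'), 1).alias('first'),
          element_at(col('arr'), -1).alias('last')).show()
# first = 10, last = 30

# Map column: lookup by key; a missing key yields null (with ANSI mode off).
df.selectExpr("element_at(m, 'b') AS hit",
              "element_at(m, 'missing') AS miss").show()
# hit = 2, miss = null
```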
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index e1f93ea88fb..10ae14bce27 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -5367,6 +5367,164 @@ Accelerator support is described below.
|
+ElementAt |
+`element_at` |
+Returns the element of an array at the given (1-based) index if the column is an array. Returns the value for the given key if the column is a map. |
+None |
+project |
+array/map |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS* (missing nested BINARY, CALENDAR, UDT) |
+PS* (If it's a map, only string is supported; missing nested BINARY, CALENDAR, UDT) |
+ |
+ |
+
+
+index/key |
+NS |
+NS |
+NS |
+PS (ints are only supported as array indexes, not as map keys; Literal value only) |
+NS |
+NS |
+NS |
+NS |
+NS |
+PS (strings are only supported as map keys, not array indexes; Literal value only) |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+result |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+S* |
+S |
+S* |
+S |
+NS |
+NS |
+PS* (missing nested BINARY, CALENDAR, UDT) |
+PS* (missing nested BINARY, CALENDAR, UDT) |
+PS* (missing nested BINARY, CALENDAR, UDT) |
+NS |
+
+
+lambda |
+array/map |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+NS |
+ |
+ |
+
+
+index/key |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+result |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
EndsWith |
|
Ends with |
@@ -5499,32 +5657,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
EqualNullSafe |
`<=>` |
Check if the values are equal including nulls <=> |
@@ -5789,6 +5921,32 @@ Accelerator support is described below.
NS |
+Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
Exp |
`exp` |
Euler's number e raised to a power |
@@ -5879,32 +6037,6 @@ Accelerator support is described below.
|
-Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
Explode |
`explode`, `explode_outer` |
Given an input array produces a sequence of rows for each value in the array. |
diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py
index df4ddcb420c..7005aee766a 100644
--- a/integration_tests/src/main/python/array_test.py
+++ b/integration_tests/src/main/python/array_test.py
@@ -18,7 +18,7 @@
from conftest import is_dataproc_runtime
from data_gen import *
from pyspark.sql.types import *
-from pyspark.sql.functions import array_contains, col, first, isnan, lit
+from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at
# Once we support arrays as literals then we can support a[null] and
# negative indexes for all array gens. When that happens
@@ -109,3 +109,11 @@ def main_df(spark):
chk_val = df.select(col('a')[0].alias('t')).filter(~isnan(col('t'))).collect()[0][0]
return df.select(array_contains(col('a'), chk_val))
assert_gpu_and_cpu_are_equal_collect(main_df)
+
+@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
+def test_array_element_at(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df(
+ spark, data_gen).select(element_at(col('a'), 1),
+ element_at(col('a'), -1)),
+ conf={'spark.sql.ansi.enabled':False,
+ 'spark.sql.legacy.allowNegativeScaleOfDecimal': True})
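The test keeps to indexes 1 and -1 because element_at uses SQL's 1-based convention and index 0 is rejected outright, the same behavior the GPU implementation reproduces. A hypothetical standalone check of that edge case, assuming a `spark` session:

```python
# Index 0 is invalid: both CPU Spark and the GPU ElementAt in this patch raise
# an ArrayIndexOutOfBoundsException saying "SQL array indices start at 1".
from pyspark.sql.functions import col, element_at

df = spark.createDataFrame([([10, 20, 30],)], ['a'])
try:
    df.select(element_at(col('a'), 0)).collect()
except Exception as err:
    print('index 0 rejected as expected:', type(err).__name__)
```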
diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py
index bacb86d1326..f3724f01ac1 100644
--- a/integration_tests/src/main/python/map_test.py
+++ b/integration_tests/src/main/python/map_test.py
@@ -30,3 +30,15 @@ def test_simple_get_map_value(data_gen):
'a["key_9"]',
'a["NOT_FOUND"]',
'a["key_5"]'))
+
+@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
+def test_simple_element_at_map(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr(
+ 'element_at(a, "key_0")',
+ 'element_at(a, "key_1")',
+ 'element_at(a, "null")',
+ 'element_at(a, "key_9")',
+ 'element_at(a, "NOT_FOUND")',
+ 'element_at(a, "key_5")'),
+ conf={'spark.sql.ansi.enabled':False})
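Both new tests also pin spark.sql.ansi.enabled to false: on Spark versions where ElementAt honors ANSI mode, a missing map key (or invalid array index) raises an error instead of returning null, which would break the CPU/GPU comparison. A rough illustration, assuming a `spark` session on such a Spark version:

```python
# With ANSI mode off, looking up a missing key simply yields null.
df = spark.createDataFrame([({'key_0': 'v0'},)], ['a'])
spark.conf.set('spark.sql.ansi.enabled', False)
print(df.selectExpr("element_at(a, 'NOT_FOUND')").collect())  # value is None

# With ANSI mode on (where supported), the same lookup fails at runtime.
spark.conf.set('spark.sql.ansi.enabled', True)
try:
    df.selectExpr("element_at(a, 'NOT_FOUND')").collect()
except Exception as err:
    print('ANSI mode turned the missing key into an error:', type(err).__name__)
finally:
    spark.conf.set('spark.sql.ansi.enabled', False)
```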
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index c4a2a1406d8..fbdfc17a044 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2271,6 +2271,50 @@ object GpuOverrides {
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
+    expr[ElementAt](
+      "Returns the element of an array at the given (1-based) index if the column is an " +
+      "array. Returns the value for the given key if the column is a map.",
+ ExprChecks.binaryProjectNotLambda(
+ (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+ TypeSig.DECIMAL + TypeSig.MAP).nested(), TypeSig.all,
+ ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
+ TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) +
+ TypeSig.MAP.nested(TypeSig.STRING)
+          .withPsNote(TypeEnum.MAP, "If it's a map, only string is supported"),
+ TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
+ ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
+        .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
+          "not as map keys")
+ .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
+ "not array indexes"),
+ TypeSig.all)),
+ (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+          // Type checks differ between Array and Map inputs, so dispatch on the actual type
+ val checks = in.left.dataType match {
+ case _: MapType =>
+ // Match exactly with the checks for GetMapValue
+ ExprChecks.binaryProjectNotLambda(TypeSig.STRING, TypeSig.all,
+ ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
+ ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
+ case _: ArrayType =>
+ // Match exactly with the checks for GetArrayItem
+ ExprChecks.binaryProjectNotLambda(
+ (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+ TypeSig.DECIMAL + TypeSig.MAP).nested(),
+ TypeSig.all,
+ ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
+ TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP),
+ TypeSig.ARRAY.nested(TypeSig.all)),
+ ("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT))
+ case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
+ }
+ checks.tag(this)
+ }
+ override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
+ GpuElementAt(lhs, rhs)
+ }
+ }),
expr[CreateNamedStruct](
"Creates a struct with the given field names and values",
CreateNamedStructCheck,
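Since the checks above only accept a literal INT index for arrays and a literal STRING key for maps, a column-valued index should keep ElementAt on the CPU. A hypothetical way to observe that fallback with the plugin active (the spark.rapids.sql.explain setting and the exact log wording are assumptions here, not part of this patch):

```python
# Hypothetical fallback check: a non-literal index fails the TypeSig.lit(TypeEnum.INT)
# check above, so this ElementAt is expected to stay on the CPU; with explain output
# enabled the plugin logs why the expression could not be replaced.
spark.conf.set('spark.rapids.sql.explain', 'NOT_ON_GPU')

df = spark.createDataFrame([([10, 20, 30], 2)], ['arr', 'idx'])
# Cast the index to INT so the CPU expression itself is valid.
df.selectExpr('element_at(arr, CAST(idx AS INT))').collect()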
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
deleted file mode 100644
index 028941ba40a..00000000000
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (c) 2021, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.nvidia.spark.rapids
-
-import ai.rapids.cudf.{ColumnVector, Scalar}
-
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.types._
-
-case class GpuSize(child: Expression, legacySizeOfNull: Boolean)
- extends GpuUnaryExpression {
-
- require(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType],
- s"The size function doesn't support the operand type ${child.dataType}")
-
- override def dataType: DataType = IntegerType
- override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable
-
- override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
-
- // Compute sizes of cuDF.ListType to get sizes of each ArrayData or MapData, considering
- // MapData is represented as List of Struct in terms of cuDF.
- withResource(input.getBase.countElements()) { collectionSize =>
- if (legacySizeOfNull) {
- withResource(Scalar.fromInt(-1)) { nullScalar =>
- withResource(input.getBase.isNull) { inputIsNull =>
- inputIsNull.ifElse(nullScalar, collectionSize)
- }
- }
- } else {
- collectionSize.incRefCount()
- }
- }
- }
-}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
new file mode 100644
index 00000000000..19b198e3d35
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf.{ColumnVector, Scalar}
+import com.nvidia.spark.rapids.{GpuBinaryExpression, GpuColumnVector, GpuScalar, GpuUnaryExpression}
+
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.types._
+
+case class GpuSize(child: Expression, legacySizeOfNull: Boolean)
+ extends GpuUnaryExpression {
+
+ require(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType],
+ s"The size function doesn't support the operand type ${child.dataType}")
+
+ override def dataType: DataType = IntegerType
+ override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable
+
+ override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
+
+ // Compute sizes of cuDF.ListType to get sizes of each ArrayData or MapData, considering
+ // MapData is represented as List of Struct in terms of cuDF.
+ withResource(input.getBase.countElements()) { collectionSize =>
+ if (legacySizeOfNull) {
+ withResource(Scalar.fromInt(-1)) { nullScalar =>
+ withResource(input.getBase.isNull) { inputIsNull =>
+ inputIsNull.ifElse(nullScalar, collectionSize)
+ }
+ }
+ } else {
+ collectionSize.incRefCount()
+ }
+ }
+ }
+}
+
+case class GpuElementAt(left: Expression, right: Expression)
+ extends GpuBinaryExpression with ExpectsInputTypes {
+
+ override lazy val dataType: DataType = left.dataType match {
+ case ArrayType(elementType, _) => elementType
+ case MapType(_, valueType, _) => valueType
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
+ Seq(arr, IntegerType)
+ case (MapType(keyType, valueType, hasNull), e2) =>
+ TypeCoercion.findTightestCommonType(keyType, e2) match {
+ case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
+ case _ => Seq.empty
+ }
+ case (l, r) => Seq.empty
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (_: ArrayType, e2) if e2 != IntegerType =>
+ TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+ s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " +
+ s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+ case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
+ TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+ s"been ${MapType.simpleString} followed by a value of same key type, but it's " +
+ s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+ case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) =>
+ TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " +
+          s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but it's " +
+ s"${left.dataType.catalogString} type.")
+ case _ => TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ // Eventually we need something more full featured like
+ // GetArrayItemUtil.computeNullabilityFromArray
+ override def nullable: Boolean = true
+
+ override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector =
+ throw new IllegalStateException("This is not supported yet")
+
+ override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
+ lhs.dataType match {
+ case _: ArrayType => {
+ if (rhs.isValid) {
+ val ordinalValue = rhs.getValue.asInstanceOf[Int]
+ if (ordinalValue > 0) {
+ // SQL 1-based index
+ lhs.getBase.extractListElement(ordinalValue - 1)
+ } else if (ordinalValue == 0) {
+ throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
+ } else {
+ lhs.getBase.extractListElement(ordinalValue)
+ }
+ } else {
+ GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType)
+ }
+ }
+ case _: MapType => {
+ lhs.getBase.getMapValue(rhs.getBase)
+ }
+ }
+ }
+
+ override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector =
+ withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs =>
+ doColumnar(expandedLhs, rhs)
+ }
+
+ override def prettyName: String = "element_at"
+}
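One detail worth calling out in doColumnar(GpuColumnVector, GpuScalar): a null literal index or key short-circuits to an all-null result column, matching the CPU expression. A minimal sketch of that behavior (assumes a `spark` session; plain CPU Spark shows the same result):

```python
# A null literal index produces null for every row, which is what the
# GpuColumnVector.columnVectorFromNull branch above returns on the GPU.
df = spark.createDataFrame([([10, 20, 30],), ([40],)], ['a'])
df.selectExpr('element_at(a, CAST(NULL AS INT))').show()
# Both rows come back as null.
```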