Support ElementAt #2260

Merged: 24 commits, May 15, 2021
Changes from 23 commits
1 change: 1 addition & 0 deletions docs/configs.md
@@ -172,6 +172,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.DayOfWeek"></a>spark.rapids.sql.expression.DayOfWeek|`dayofweek`|Returns the day of the week (1 = Sunday...7=Saturday)|true|None|
<a name="sql.expression.DayOfYear"></a>spark.rapids.sql.expression.DayOfYear|`dayofyear`|Returns the day of the year from a date or timestamp|true|None|
<a name="sql.expression.Divide"></a>spark.rapids.sql.expression.Divide|`/`|Division|true|None|
<a name="sql.expression.ElementAt"></a>spark.rapids.sql.expression.ElementAt|`element_at`|Returns element of array at given(1-based) index in value if column is array. Returns value for the given key in value if column is map.|true|None|
<a name="sql.expression.EndsWith"></a>spark.rapids.sql.expression.EndsWith| |Ends with|true|None|
<a name="sql.expression.EqualNullSafe"></a>spark.rapids.sql.expression.EqualNullSafe|`<=>`|Check if the values are equal including nulls <=>|true|None|
<a name="sql.expression.EqualTo"></a>spark.rapids.sql.expression.EqualTo|`=`, `==`|Check if the values are equal|true|None|
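For reference, a minimal PySpark sketch of the behavior documented by the `element_at` row added to `docs/configs.md` above: 1-based indexing for arrays, negative indexes counting from the end, and key lookup for maps. The session setup, column names, and data below are illustrative only and are not part of this PR.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, element_at

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [([10, 20, 30], {"k1": "v1", "k2": "v2"})],
    "arr array<int>, m map<string,string>")

df.select(
    element_at(col("arr"), 1),    # 1-based index -> 10
    element_at(col("arr"), -1),   # negative index counts from the end -> 30
    element_at(col("m"), "k2"),   # lookup by key -> 'v2'
).show()
```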
236 changes: 184 additions & 52 deletions docs/supported_ops.md
@@ -5367,6 +5367,164 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="6">ElementAt</td>
<td rowSpan="6">`element_at`</td>
<td rowSpan="6">Returns element of array at given(1-based) index in value if column is array. Returns value for the given key in value if column is map.</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>array/map</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (If it's a map, only string keys and values are supported; missing nested BINARY, CALENDAR, UDT)</em></td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>index/key</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS (ints are only supported as array indexes, not as map keys; Literal value only)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS (strings are only supported as map keys, not array indexes; Literal value only)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>array/map</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>index/key</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Function(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">EndsWith</td>
<td rowSpan="6"> </td>
<td rowSpan="6">Ends with</td>
@@ -5499,32 +5657,6 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Function(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">EqualNullSafe</td>
<td rowSpan="6">`<=>`</td>
<td rowSpan="6">Check if the values are equal including nulls <=></td>
@@ -5789,6 +5921,32 @@ Accelerator support is described below.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Function(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Exp</td>
<td rowSpan="4">`exp`</td>
<td rowSpan="4">Euler's number e raised to a power</td>
@@ -5879,32 +6037,6 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Function(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">Explode</td>
<td rowSpan="2">`explode`, `explode_outer`</td>
<td rowSpan="2">Given an input array produces a sequence of rows for each value in the array.</td>
10 changes: 9 additions & 1 deletion integration_tests/src/main/python/array_test.py
@@ -18,7 +18,7 @@
from conftest import is_dataproc_runtime
from data_gen import *
from pyspark.sql.types import *
from pyspark.sql.functions import array_contains, col, first, isnan, lit
from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at

# Once we support arrays as literals then we can support a[null] and
# negative indexes for all array gens. When that happens
@@ -109,3 +109,11 @@ def main_df(spark):
chk_val = df.select(col('a')[0].alias('t')).filter(~isnan(col('t'))).collect()[0][0]
return df.select(array_contains(col('a'), chk_val))
assert_gpu_and_cpu_are_equal_collect(main_df)

@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
def test_array_element_at(data_gen):
assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df(
spark, data_gen).select(element_at(col('a'), 1),
element_at(col('a'), -1)),
conf={'spark.sql.ansi.enabled':False,
'spark.sql.legacy.allowNegativeScaleOfDecimal': True})
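A hedged, standalone illustration of why the new array test above disables ANSI mode: with `spark.sql.ansi.enabled` off, an index past the end of the array is expected to return NULL rather than raise. The DataFrame below is illustrative only, not the generator-driven data the test uses.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, element_at

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "false")

df = spark.createDataFrame([([1, 2],)], "a array<int>")

df.select(
    element_at(col("a"), 1),   # first element -> 1
    element_at(col("a"), 3),   # past the end -> NULL when ANSI mode is off
).show()
```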
12 changes: 12 additions & 0 deletions integration_tests/src/main/python/map_test.py
@@ -30,3 +30,15 @@ def test_simple_get_map_value(data_gen):
'a["key_9"]',
'a["NOT_FOUND"]',
'a["key_5"]'))

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_element_at_map(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'element_at(a, "key_0")',
'element_at(a, "key_1")',
'element_at(a, "null")',
'element_at(a, "key_9")',
'element_at(a, "NOT_FOUND")',
'element_at(a, "key_5")'),
conf={'spark.sql.ansi.enabled':False})
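Similarly, a standalone sketch of the map lookups exercised by the test above, with inline data instead of `simple_string_to_string_map_gen`; with ANSI mode off, a missing key is expected to yield NULL. The data here is illustrative only.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "false")

df = spark.createDataFrame(
    [({"key_0": "v0", "key_5": "v5"},)], "a map<string,string>")

df.selectExpr(
    'element_at(a, "key_0")',       # present key -> 'v0'
    'element_at(a, "NOT_FOUND")',   # missing key -> NULL when ANSI mode is off
).show()
```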
@@ -2271,6 +2271,50 @@ object GpuOverrides {
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProjectNotLambda(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP, "If it's a map, only string keys and values are supported"),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProjectNotLambda(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProjectNotLambda(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT))
case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
}
checks.tag(this)
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs)
}
}),
expr[CreateNamedStruct](
"Creates a struct with the given field names and values",
CreateNamedStructCheck,
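The `tagExprForGpu` above dispatches on the input's data type: for `MapType` it re-applies the `GetMapValue` checks (literal string keys, string values only), and for `ArrayType` it re-applies the `GetArrayItem` checks (literal int ordinals only). A hedged PySpark sketch of call shapes that do and do not fit those checks; the column names and data are illustrative, and the expectation is that unsupported shapes fall back to the CPU rather than fail.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [([1, 2, 3], {"key_0": "v0"}, 2)],
    "arr_col array<int>, map_col map<string,string>, idx_col int")

# Accepted by the checks above: a literal INT ordinal for the array case
# and a literal STRING key for the map case.
df.selectExpr(
    "element_at(arr_col, 2)",
    "element_at(map_col, 'key_0')",
).show()

# Not accepted by the checks above (non-literal index), so this expression
# would be planned on the CPU instead:
df.selectExpr("element_at(arr_col, idx_col)").show()
```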