From fd3967df272fb84a7df3bf155feeaff189eb141a Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 21 Apr 2021 22:13:06 +0800 Subject: [PATCH 01/21] Initial support for elementAt --- .../nvidia/spark/rapids/GpuOverrides.scala | 9 ++++++ .../sql/rapids/complexTypeExtractors.scala | 29 +++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index ed2d85e26b0..bff15928a1f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2272,6 +2272,15 @@ object GpuOverrides { ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)), (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)), + expr[ElementAt]( + "Returns element of array at given(1-based) index in value if column is array. " + + "Returns value for the given key in value if column is map.", + ExprChecks.binaryProjectNotLambda( + TypeSig.commonCudfTypes, TypeSig.all, + ("left", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes) + + TypeSig.MAP.nested(TypeSig.STRING), TypeSig.all), + ("right", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), + (in, conf, p, r) => new GpuElementAtMeta(in, conf, p, r)), expr[CreateNamedStruct]( "Creates a struct with the given field names and values", CreateNamedStructCheck, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 42b648db7df..d78e9f04ce3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{ColumnVector, Scalar} import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.RapidsPluginImplicits._ - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} +import org.apache.spark.sql.catalyst.util.{TypeUtils, quoteIdentifier} import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, MapType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -186,6 +185,30 @@ case class GpuGetMapValue(child: Expression, key: Expression) override def right: Expression = key } +class GpuElementAtMeta( + expr: ElementAt, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[ElementAt](expr, conf, parent, rule) { + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { + + } +} + +case class GpuElementAt(left: Expression, right: Expression) extends GpuBinaryExpression { + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = ??? + + override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = ??? + + override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = ??? + + override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = ??? + + override def dataType: DataType = ??? +} + /** Checks if the array (left) has the element (right) */ case class GpuArrayContains(left: Expression, right: Expression) From 4e42d39480dd50dc426132af7143395e2f6c0759 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 23 Apr 2021 15:17:35 +0800 Subject: [PATCH 02/21] temp work saving --- .../sql/rapids/complexTypeExtractors.scala | 67 ++++++++++++++++--- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index d78e9f04ce3..96eaa34b0c5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -20,10 +20,10 @@ import ai.rapids.cudf.{ColumnVector, Scalar} import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} import org.apache.spark.sql.catalyst.util.{TypeUtils, quoteIdentifier} -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, MapType, StructType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegerType, IntegralType, LongType, MapType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) @@ -192,21 +192,70 @@ class GpuElementAtMeta( rule: DataFromReplacementRule) extends BinaryExprMeta[ElementAt](expr, conf, parent, rule) { override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { - + GpuElementAt(lhs, rhs) } } -case class GpuElementAt(left: Expression, right: Expression) extends GpuBinaryExpression { +case class GpuElementAt(left: Expression, right: Expression) + extends GpuBinaryExpression with ExpectsInputTypes { + + // ?? need ? + private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + + override lazy val dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) => + Seq(arr, IntegerType) + case (MapType(keyType, valueType, hasNull), e2) => + TypeCoercion.findTightestCommonType(keyType, e2) match { + case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt) + case _ => Seq.empty + } + case (l, r) => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_: ArrayType, e2) if e2 != IntegerType => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${MapType.simpleString} followed by a value of same key type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) => + TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " + + s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " + + s"${left.dataType.catalogString} type.") + case _ => TypeCheckResult.TypeCheckSuccess + } + } + + // Eventually we need something more full featured like + // GetArrayItemUtil.computeNullabilityFromArray + override def nullable: Boolean = true - override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = ??? + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = + throw new IllegalStateException("This is not supported yet") - override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = ??? + override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = + throw new IllegalStateException("This is not supported yet") - override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = ??? + override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = + throw new IllegalStateException("This is not supported yet") - override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = ??? + override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = + throw new IllegalStateException("This is not supported yet") - override def dataType: DataType = ??? + override def prettyName: String = "element_at" } /** Checks if the array (left) has the element (right) From 6b83aadc647ff50b889c2033663e5712f334f184 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Sun, 25 Apr 2021 16:52:13 +0800 Subject: [PATCH 03/21] draft element_at --- .../sql/rapids/complexTypeExtractors.scala | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 96eaa34b0c5..24da82f9057 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -199,10 +199,6 @@ class GpuElementAtMeta( case class GpuElementAt(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes { - // ?? need ? - private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType - private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull - override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType @@ -249,11 +245,35 @@ case class GpuElementAt(left: Expression, right: Expression) override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = throw new IllegalStateException("This is not supported yet") - override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = - throw new IllegalStateException("This is not supported yet") + override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = { + lhs.dataType match { + case _: ArrayType => { + if (rhs.isValid) { + if (rhs.getInt > 0) { + // SQL 1-based index + lhs.getBase.extractListElement(rhs.getInt - 1) + } else if (rhs.getInt == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else { + lhs.getBase.extractListElement(rhs.getInt) + } + } else { + withResource(Scalar.fromNull( + GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) + } + } + } + case _: MapType => { + lhs.getBase.getMapValue(rhs) + } + } + } override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = - throw new IllegalStateException("This is not supported yet") + withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs => + doColumnar(expandedLhs, rhs) + } override def prettyName: String = "element_at" } From 3b6425fdb0ce305c734771acd0b98baca3f6fce4 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Sun, 25 Apr 2021 20:53:46 +0800 Subject: [PATCH 04/21] Support element_at Signed-off-by: Allen Xu --- docs/configs.md | 1 + docs/supported_ops.md | 132 ++++++++++++++++++ .../src/main/python/array_test.py | 11 +- integration_tests/src/main/python/map_test.py | 11 ++ .../sql/rapids/complexTypeExtractors.scala | 3 +- 5 files changed, 156 insertions(+), 2 deletions(-) diff --git a/docs/configs.md b/docs/configs.md index 45b70539676..622c0009ed2 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -166,6 +166,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.DayOfWeek|`dayofweek`|Returns the day of the week (1 = Sunday...7=Saturday)|true|None| spark.rapids.sql.expression.DayOfYear|`dayofyear`|Returns the day of the year from a date or timestamp|true|None| spark.rapids.sql.expression.Divide|`/`|Division|true|None| +spark.rapids.sql.expression.ElementAt|`element_at`|Returns element of array at given(1-based) index in value if column is array. Returns value for the given key in value if column is map.|true|None| spark.rapids.sql.expression.EndsWith| |Ends with|true|None| spark.rapids.sql.expression.EqualNullSafe|`<=>`|Check if the values are equal including nulls <=>|true|None| spark.rapids.sql.expression.EqualTo|`=`, `==`|Check if the values are equal|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 409e619281f..bf835f99b6b 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -5301,6 +5301,138 @@ Accelerator support is described below. +ElementAt +`element_at` +Returns element of array at given(1-based) index in value if column is array. Returns value for the given key in value if column is map. +None +project +left +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +NS +NS + + +right +NS +NS +NS +PS (Literal value only) +NS +NS +NS +NS +NS +PS (Literal value only) +NS +NS +NS +NS +NS +NS +NS +NS + + +result +S +S +S +S +S +S +S +S +S* +S +NS +NS +NS +NS +NS +NS +NS +NS + + +lambda +left +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +right +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +result +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS +NS + + EndsWith Ends with diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index df4ddcb420c..3307c15927f 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -18,7 +18,7 @@ from conftest import is_dataproc_runtime from data_gen import * from pyspark.sql.types import * -from pyspark.sql.functions import array_contains, col, first, isnan, lit +from pyspark.sql.functions import array_contains, col, first, isnan, lit, element_at # Once we support arrays as literals then we can support a[null] and # negative indexes for all array gens. When that happens @@ -109,3 +109,12 @@ def main_df(spark): chk_val = df.select(col('a')[0].alias('t')).filter(~isnan(col('t'))).collect()[0][0] return df.select(array_contains(col('a'), chk_val)) assert_gpu_and_cpu_are_equal_collect(main_df) + +@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen, + FloatGen(no_nans=True), DoubleGen(no_nans=True), + string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn) +def test_array_element_at(data_gen): + arr_gen = ArrayGen(data_gen) + assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( + spark, arr_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), no_nans_conf) \ No newline at end of file diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py index bacb86d1326..f67c2cb44a2 100644 --- a/integration_tests/src/main/python/map_test.py +++ b/integration_tests/src/main/python/map_test.py @@ -30,3 +30,14 @@ def test_simple_get_map_value(data_gen): 'a["key_9"]', 'a["NOT_FOUND"]', 'a["key_5"]')) + +@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn) +def test_simple_map_element_at(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'element_at(a, "key_0")', + 'element_at(a, "key_1")', + 'element_at(a, "null")', + 'element_at(a, "key_9")', + 'element_at(a, "NOT_FOUND")', + 'element_at(a, "key_5")')) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 24da82f9057..600a9bc1de3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{ColumnVector, Scalar} import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.RapidsPluginImplicits._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} -import org.apache.spark.sql.catalyst.util.{TypeUtils, quoteIdentifier} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, TypeUtils} import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegerType, IntegralType, LongType, MapType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch From fc5f6fc2349760ed6b6a3c558aa6d7d5d9388334 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 28 Apr 2021 20:13:10 +0800 Subject: [PATCH 05/21] debug for array of array --- .../src/main/python/array_test.py | 29 +++++++++++++---- integration_tests/src/main/python/map_test.py | 3 +- .../nvidia/spark/rapids/GpuOverrides.scala | 31 ++++++++++++++++--- .../sql/rapids/complexTypeExtractors.scala | 10 ------ 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 3307c15927f..c943cbf02b8 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -110,11 +110,28 @@ def main_df(spark): return df.select(array_contains(col('a'), chk_val)) assert_gpu_and_cpu_are_equal_collect(main_df) -@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen, - FloatGen(no_nans=True), DoubleGen(no_nans=True), - string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) def test_array_element_at(data_gen): - arr_gen = ArrayGen(data_gen) assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( - spark, arr_gen).select(element_at(col('a'), 1), - element_at(col('a'), -1)), no_nans_conf) \ No newline at end of file + spark, data_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), + conf={'spark.sql.ansi.enabled':False}) + + +@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) +def test_array_element_at_null(data_gen): + array_gen = ArrayGen(data_gen) + assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( + spark, data_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), + conf={'spark.sql.ansi.enabled':False, + 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) + +@pytest.mark.parametrize('data_gen', [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10)], ids=idfn) +def test_array_element_at_test(data_gen): + array_gen = ArrayGen(data_gen) + assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( + spark, data_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), + conf={'spark.sql.ansi.enabled':False, + 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) \ No newline at end of file diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py index f67c2cb44a2..372a99fce9c 100644 --- a/integration_tests/src/main/python/map_test.py +++ b/integration_tests/src/main/python/map_test.py @@ -40,4 +40,5 @@ def test_simple_map_element_at(data_gen): 'element_at(a, "null")', 'element_at(a, "key_9")', 'element_at(a, "NOT_FOUND")', - 'element_at(a, "key_5")')) + 'element_at(a, "key_5")'), + conf={'spark.sql.ansi.enabled':False}) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index bff15928a1f..5502882871f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2276,11 +2276,32 @@ object GpuOverrides { "Returns element of array at given(1-based) index in value if column is array. " + "Returns value for the given key in value if column is map.", ExprChecks.binaryProjectNotLambda( - TypeSig.commonCudfTypes, TypeSig.all, - ("left", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes) + - TypeSig.MAP.nested(TypeSig.STRING), TypeSig.all), - ("right", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), - (in, conf, p, r) => new GpuElementAtMeta(in, conf, p, r)), + TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL + TypeSig.MAP, TypeSig.all, + ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) + + TypeSig.MAP.nested(TypeSig.STRING) + .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported. " + + "Extra check is inside the expression metadata"), TypeSig.all), + ("index/key", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), + (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + // To distinguish the supported nested type between Array and Map + in.left.dataType match { + case MapType(_,valueType,_) => { + valueType match { + case StringType => // minimum support + case _ => willNotWorkOnGpu(s"${valueType.simpleString} is not supported for" + + s" Map value") + } + } + case ArrayType(_,_) => // Array supports more + } + } + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { + GpuElementAt(lhs, rhs) + } + }), expr[CreateNamedStruct]( "Creates a struct with the given field names and values", CreateNamedStructCheck, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 600a9bc1de3..90270ce5c7b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -186,16 +186,6 @@ case class GpuGetMapValue(child: Expression, key: Expression) override def right: Expression = key } -class GpuElementAtMeta( - expr: ElementAt, - conf: RapidsConf, - parent: Option[RapidsMeta[_, _, _]], - rule: DataFromReplacementRule) - extends BinaryExprMeta[ElementAt](expr, conf, parent, rule) { - override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { - GpuElementAt(lhs, rhs) - } -} case class GpuElementAt(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes { From 0239e18b1fc3e2a8aace02b08c3384367f6a862f Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 28 Apr 2021 21:06:55 +0800 Subject: [PATCH 06/21] Fix array of array bug --- .../src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 5502882871f..ec1eeb93847 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2276,8 +2276,8 @@ object GpuOverrides { "Returns element of array at given(1-based) index in value if column is array. " + "Returns value for the given key in value if column is map.", ExprChecks.binaryProjectNotLambda( - TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + - TypeSig.DECIMAL + TypeSig.MAP, TypeSig.all, + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL + TypeSig.MAP).nested(), TypeSig.all, ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) + TypeSig.MAP.nested(TypeSig.STRING) From 7a33a901a193b59e9fcd4bf2fc41fd106b3d275f Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 28 Apr 2021 21:11:14 +0800 Subject: [PATCH 07/21] code clean --- integration_tests/src/main/python/array_test.py | 11 +---------- integration_tests/src/main/python/map_test.py | 2 +- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index c943cbf02b8..53872a21f32 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -119,19 +119,10 @@ def test_array_element_at(data_gen): @pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) -def test_array_element_at_null(data_gen): +def test_array_element_at_array(data_gen): array_gen = ArrayGen(data_gen) assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( spark, data_gen).select(element_at(col('a'), 1), element_at(col('a'), -1)), conf={'spark.sql.ansi.enabled':False, 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) - -@pytest.mark.parametrize('data_gen', [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10)], ids=idfn) -def test_array_element_at_test(data_gen): - array_gen = ArrayGen(data_gen) - assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( - spark, data_gen).select(element_at(col('a'), 1), - element_at(col('a'), -1)), - conf={'spark.sql.ansi.enabled':False, - 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) \ No newline at end of file diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py index 372a99fce9c..f3724f01ac1 100644 --- a/integration_tests/src/main/python/map_test.py +++ b/integration_tests/src/main/python/map_test.py @@ -32,7 +32,7 @@ def test_simple_get_map_value(data_gen): 'a["key_5"]')) @pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn) -def test_simple_map_element_at(data_gen): +def test_simple_element_at_map(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr( 'element_at(a, "key_0")', From 2c47adfb11b0dcab3b579509b23d6fb0c3f62ba0 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 28 Apr 2021 21:39:04 +0800 Subject: [PATCH 08/21] Update doc Signed-off-by: Allen Xu --- docs/supported_ops.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index bf835f99b6b..1027d0f1a31 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -5306,7 +5306,7 @@ Accelerator support is described below. Returns element of array at given(1-based) index in value if column is array. Returns value for the given key in value if column is map. None project -left +array/map NS NS NS @@ -5321,13 +5321,13 @@ Accelerator support is described below. NS NS NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (missing nested BINARY, CALENDAR, UDT) +PS* (If it's map, only string is supported. Extra check is inside the expression metadata; missing nested BINARY, CALENDAR, UDT) NS NS -right +index/key NS NS NS @@ -5359,18 +5359,18 @@ Accelerator support is described below. S S* S +S* +S NS NS -NS -NS -NS -NS -NS +PS* (missing nested BINARY, CALENDAR, UDT) +PS* (missing nested BINARY, CALENDAR, UDT) +PS* (missing nested BINARY, CALENDAR, UDT) NS lambda -left +array/map NS NS NS @@ -5391,7 +5391,7 @@ Accelerator support is described below. NS -right +index/key NS NS NS From 7bf27170db133bd5311e33df8e1d9bad1a700158 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Thu, 29 Apr 2021 11:13:27 +0800 Subject: [PATCH 09/21] remove test code --- integration_tests/src/main/python/array_test.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 53872a21f32..7005aee766a 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -112,15 +112,6 @@ def main_df(spark): @pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) def test_array_element_at(data_gen): - assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( - spark, data_gen).select(element_at(col('a'), 1), - element_at(col('a'), -1)), - conf={'spark.sql.ansi.enabled':False}) - - -@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) -def test_array_element_at_array(data_gen): - array_gen = ArrayGen(data_gen) assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( spark, data_gen).select(element_at(col('a'), 1), element_at(col('a'), -1)), From f39ee188e839d41006084dcbe3e07f306c8509a8 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 30 Apr 2021 10:51:18 +0800 Subject: [PATCH 10/21] make Spark input more strict - doc refine Signed-off-by: Allen Xu --- docs/supported_ops.md | 64 +++++++++---------- .../nvidia/spark/rapids/GpuOverrides.scala | 3 +- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 1027d0f1a31..f03988fb3a6 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -5307,24 +5307,24 @@ Accelerator support is described below. None project array/map -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS + + + + + + + + + + + + + + PS* (missing nested BINARY, CALENDAR, UDT) PS* (If it's map, only string is supported. Extra check is inside the expression metadata; missing nested BINARY, CALENDAR, UDT) -NS -NS + + index/key @@ -5371,24 +5371,24 @@ Accelerator support is described below. lambda array/map + + + + + + + + + + + + + + NS NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS + + index/key diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index ec1eeb93847..cb1dd80fccf 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2282,7 +2282,8 @@ object GpuOverrides { TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) + TypeSig.MAP.nested(TypeSig.STRING) .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported. " + - "Extra check is inside the expression metadata"), TypeSig.all), + "Extra check is inside the expression metadata"), + TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)), ("index/key", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { override def tagExprForGpu(): Unit = { From e00d8565dbae9f1e2231e7567e165f1bf8bc00c5 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 30 Apr 2021 14:17:15 +0800 Subject: [PATCH 11/21] resolve comments Signed-off-by: Allen Xu --- .../main/scala/com/nvidia/spark/rapids/GpuOverrides.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index cb1dd80fccf..4745a34a47e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2284,7 +2284,10 @@ object GpuOverrides { .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported. " + "Extra check is inside the expression metadata"), TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)), - ("index/key", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), + ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING)) + .withPsNote(TypeEnum.INT, "If it's the index for array, only INT is supported." + + "If it's the key for map, only STRING is supported"), + TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { override def tagExprForGpu(): Unit = { // To distinguish the supported nested type between Array and Map @@ -2296,7 +2299,7 @@ object GpuOverrides { s" Map value") } } - case ArrayType(_,_) => // Array supports more + case _ => // Array supports more } } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { From 4e0791fba2e13458fbf3afd87acbb186eaaca61b Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Thu, 6 May 2021 16:16:28 +0800 Subject: [PATCH 12/21] Resolve comments Signed-off-by: Allen Xu --- docs/supported_ops.md | 6 +- .../nvidia/spark/rapids/GpuOverrides.scala | 36 +++-- .../spark/rapids/collectionOperations.scala | 49 ------- .../sql/rapids/collectionOperations.scala | 133 ++++++++++++++++++ .../sql/rapids/complexTypeExtractors.scala | 89 +----------- 5 files changed, 162 insertions(+), 151 deletions(-) delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala diff --git a/docs/supported_ops.md b/docs/supported_ops.md index f03988fb3a6..58fa36c8127 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -5322,7 +5322,7 @@ Accelerator support is described below. PS* (missing nested BINARY, CALENDAR, UDT) -PS* (If it's map, only string is supported. Extra check is inside the expression metadata; missing nested BINARY, CALENDAR, UDT) +PS* (If it's map, only string is supported.; missing nested BINARY, CALENDAR, UDT) @@ -5331,13 +5331,13 @@ Accelerator support is described below. NS NS NS -PS (Literal value only) +PS (ints are only supported as array indexes, not as maps keys; Literal value only) NS NS NS NS NS -PS (Literal value only) +PS (strings are only supported as map keys, not array indexes; Literal value only) NS NS NS diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 4745a34a47e..bd9de3a1017 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2281,26 +2281,36 @@ object GpuOverrides { ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) + TypeSig.MAP.nested(TypeSig.STRING) - .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported. " + - "Extra check is inside the expression metadata"), + .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."), TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)), ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING)) - .withPsNote(TypeEnum.INT, "If it's the index for array, only INT is supported." + - "If it's the key for map, only STRING is supported"), + .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " + + "not as maps keys") + .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " + + "not array indexes"), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { override def tagExprForGpu(): Unit = { // To distinguish the supported nested type between Array and Map - in.left.dataType match { - case MapType(_,valueType,_) => { - valueType match { - case StringType => // minimum support - case _ => willNotWorkOnGpu(s"${valueType.simpleString} is not supported for" + - s" Map value") - } - } - case _ => // Array supports more + val checks = in.left.dataType match { + case _: MapType => + // This should match exactly with the checks for GetMapValue + ExprChecks.binaryProjectNotLambda(TypeSig.STRING, TypeSig.all, + ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), + ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)) + case _: ArrayType => + // This should match exactly with the checks for GetArrayItem + ExprChecks.binaryProjectNotLambda( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL + TypeSig.MAP).nested(), + TypeSig.all, + ("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP), + TypeSig.ARRAY.nested(TypeSig.all)), + ("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT)) + case _ => throw new IllegalStateException("Only Array or Map is supported as input.") } + checks.tag(this) } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { GpuElementAt(lhs, rhs) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala deleted file mode 100644 index 19c641d1cd9..00000000000 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import ai.rapids.cudf.ColumnVector - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.types._ - -case class GpuSize(child: Expression, legacySizeOfNull: Boolean) - extends GpuUnaryExpression { - - require(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType], - s"The size function doesn't support the operand type ${child.dataType}") - - override def dataType: DataType = IntegerType - override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable - - override protected def doColumnar(input: GpuColumnVector): ColumnVector = { - - // Compute sizes of cuDF.ListType to get sizes of each ArrayData or MapData, considering - // MapData is represented as List of Struct in terms of cuDF. - withResource(input.getBase.countElements()) { collectionSize => - if (legacySizeOfNull) { - withResource(GpuScalar.from(-1)) { nullScalar => - withResource(input.getBase.isNull) { inputIsNull => - inputIsNull.ifElse(nullScalar, collectionSize) - } - } - } else { - collectionSize.incRefCount() - } - } - } -} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala new file mode 100644 index 00000000000..0f4792c17d4 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf.{ColumnVector, Scalar} +import com.nvidia.spark.rapids.{GpuBinaryExpression, GpuColumnVector, GpuScalar, GpuUnaryExpression} + +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types._ + +case class GpuSize(child: Expression, legacySizeOfNull: Boolean) + extends GpuUnaryExpression { + + require(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType], + s"The size function doesn't support the operand type ${child.dataType}") + + override def dataType: DataType = IntegerType + override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable + + override protected def doColumnar(input: GpuColumnVector): ColumnVector = { + + // Compute sizes of cuDF.ListType to get sizes of each ArrayData or MapData, considering + // MapData is represented as List of Struct in terms of cuDF. + withResource(input.getBase.countElements()) { collectionSize => + if (legacySizeOfNull) { + withResource(GpuScalar.from(-1)) { nullScalar => + withResource(input.getBase.isNull) { inputIsNull => + inputIsNull.ifElse(nullScalar, collectionSize) + } + } + } else { + collectionSize.incRefCount() + } + } + } +} + +case class GpuElementAt(left: Expression, right: Expression) + extends GpuBinaryExpression with ExpectsInputTypes { + + override lazy val dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) => + Seq(arr, IntegerType) + case (MapType(keyType, valueType, hasNull), e2) => + TypeCoercion.findTightestCommonType(keyType, e2) match { + case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt) + case _ => Seq.empty + } + case (l, r) => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_: ArrayType, e2) if e2 != IntegerType => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${MapType.simpleString} followed by a value of same key type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) => + TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " + + s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " + + s"${left.dataType.catalogString} type.") + case _ => TypeCheckResult.TypeCheckSuccess + } + } + + // Eventually we need something more full featured like + // GetArrayItemUtil.computeNullabilityFromArray + override def nullable: Boolean = true + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = + throw new IllegalStateException("This is not supported yet") + + override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = + throw new IllegalStateException("This is not supported yet") + + override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = { + lhs.dataType match { + case _: ArrayType => { + if (rhs.isValid) { + if (rhs.getInt > 0) { + // SQL 1-based index + lhs.getBase.extractListElement(rhs.getInt - 1) + } else if (rhs.getInt == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else { + lhs.getBase.extractListElement(rhs.getInt) + } + } else { + withResource(Scalar.fromNull( + GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) + } + } + } + case _: MapType => { + lhs.getBase.getMapValue(rhs) + } + } + } + + override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = + withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs => + doColumnar(expandedLhs, rhs) + } + + override def prettyName: String = "element_at" +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 90270ce5c7b..42b648db7df 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -21,10 +21,10 @@ import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBina import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExtractValue, GetArrayItem, GetMapValue, ImplicitCastInputTypes, NullIntolerant, UnaryExpression} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, TypeUtils} -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegerType, IntegralType, LongType, MapType, StructType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, MapType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) @@ -186,89 +186,6 @@ case class GpuGetMapValue(child: Expression, key: Expression) override def right: Expression = key } - -case class GpuElementAt(left: Expression, right: Expression) - extends GpuBinaryExpression with ExpectsInputTypes { - - override lazy val dataType: DataType = left.dataType match { - case ArrayType(elementType, _) => elementType - case MapType(_, valueType, _) => valueType - } - - override def inputTypes: Seq[AbstractDataType] = { - (left.dataType, right.dataType) match { - case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) => - Seq(arr, IntegerType) - case (MapType(keyType, valueType, hasNull), e2) => - TypeCoercion.findTightestCommonType(keyType, e2) match { - case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt) - case _ => Seq.empty - } - case (l, r) => Seq.empty - } - } - - override def checkInputDataTypes(): TypeCheckResult = { - (left.dataType, right.dataType) match { - case (_: ArrayType, e2) if e2 != IntegerType => - TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + - s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " + - s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") - case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) => - TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + - s"been ${MapType.simpleString} followed by a value of same key type, but it's " + - s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") - case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) => - TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " + - s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " + - s"${left.dataType.catalogString} type.") - case _ => TypeCheckResult.TypeCheckSuccess - } - } - - // Eventually we need something more full featured like - // GetArrayItemUtil.computeNullabilityFromArray - override def nullable: Boolean = true - - override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = - throw new IllegalStateException("This is not supported yet") - - override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = - throw new IllegalStateException("This is not supported yet") - - override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = { - lhs.dataType match { - case _: ArrayType => { - if (rhs.isValid) { - if (rhs.getInt > 0) { - // SQL 1-based index - lhs.getBase.extractListElement(rhs.getInt - 1) - } else if (rhs.getInt == 0) { - throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") - } else { - lhs.getBase.extractListElement(rhs.getInt) - } - } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => - ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) - } - } - } - case _: MapType => { - lhs.getBase.getMapValue(rhs) - } - } - } - - override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = - withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs => - doColumnar(expandedLhs, rhs) - } - - override def prettyName: String = "element_at" -} - /** Checks if the array (left) has the element (right) */ case class GpuArrayContains(left: Expression, right: Expression) From 272b358a24837b84977cc1949c8ed41d137414d5 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Thu, 6 May 2021 17:05:44 +0800 Subject: [PATCH 13/21] resolve comments, rebase to latest Signed-off-by: Allen Xu --- .../src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index bd9de3a1017..99bfcecf2b7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2294,12 +2294,12 @@ object GpuOverrides { // To distinguish the supported nested type between Array and Map val checks = in.left.dataType match { case _: MapType => - // This should match exactly with the checks for GetMapValue + // Match exactly with the checks for GetMapValue ExprChecks.binaryProjectNotLambda(TypeSig.STRING, TypeSig.all, ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)) case _: ArrayType => - // This should match exactly with the checks for GetArrayItem + // Match exactly with the checks for GetArrayItem ExprChecks.binaryProjectNotLambda( (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP).nested(), From a19ddf6b7f39db889801f2ba2b488f452617da67 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Thu, 6 May 2021 21:33:02 +0800 Subject: [PATCH 14/21] Update support_op_docs Signed-off-by: Allen Xu --- docs/supported_ops.md | 104 +++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 58fa36c8127..6703dddf57e 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -5433,6 +5433,32 @@ Accelerator support is described below. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + EndsWith Ends with @@ -5565,32 +5591,6 @@ Accelerator support is described below. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - EqualNullSafe `<=>` Check if the values are equal including nulls <=> @@ -5855,6 +5855,32 @@ Accelerator support is described below. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Exp `exp` Euler's number e raised to a power @@ -5945,32 +5971,6 @@ Accelerator support is described below. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Explode `explode`, `explode_outer` Given an input array produces a sequence of rows for each value in the array. From be05e2fb9a918388af2f561a7859497604fc0604 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Sat, 8 May 2021 11:36:03 +0800 Subject: [PATCH 15/21] Refactor GetMapValue and GetArrayItem core methods --- .../sql/rapids/collectionOperations.scala | 18 +----- .../sql/rapids/complexTypeExtractors.scala | 56 +++++++++++++++---- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 0f4792c17d4..2c2c83e02db 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -102,24 +102,10 @@ case class GpuElementAt(left: Expression, right: Expression) override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = { lhs.dataType match { case _: ArrayType => { - if (rhs.isValid) { - if (rhs.getInt > 0) { - // SQL 1-based index - lhs.getBase.extractListElement(rhs.getInt - 1) - } else if (rhs.getInt == 0) { - throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") - } else { - lhs.getBase.extractListElement(rhs.getInt) - } - } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => - ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) - } - } + GetArrayItemUtil.evalColumnar(lhs, rhs, dataType, zeroIndexed = false) } case _: MapType => { - lhs.getBase.getMapValue(rhs) + GetMapValueUtil.evalColumnar(lhs, rhs) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 42b648db7df..92c87956f41 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{ColumnVector, Scalar} +import com.nvidia.spark.RebaseHelper.withResource import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -118,15 +119,7 @@ case class GpuGetArrayItem(child: Expression, ordinal: Expression) throw new IllegalStateException("This is not supported yet") override def doColumnar(lhs: GpuColumnVector, ordinal: Scalar): ColumnVector = { - // Need to handle negative indexes... - if (ordinal.isValid && ordinal.getInt >= 0) { - lhs.getBase.extractListElement(ordinal.getInt) - } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => - ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) - } - } + GetArrayItemUtil.evalColumnar(lhs, ordinal, dataType, zeroIndexed = true) } override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = { @@ -167,7 +160,7 @@ case class GpuGetMapValue(child: Expression, key: Expression) override def prettyName: String = "getMapValue" override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = - lhs.getBase.getMapValue(rhs) + GetMapValueUtil.evalColumnar(lhs, rhs) override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = { withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs => @@ -210,4 +203,45 @@ case class GpuArrayContains(left: Expression, right: Expression) lhs.getBase.listContainsColumn(rhs.getBase) override def prettyName: String = "array_contains" -} \ No newline at end of file +} + +object GetArrayItemUtil { + def evalColumnar(array: GpuColumnVector, ordinal: Scalar, dataType: DataType, + zeroIndexed: Boolean): ColumnVector = { + // for array index use case, index starts at 0 + if (zeroIndexed) { + // Need to handle negative indexes... + if (ordinal.isValid && ordinal.getInt >= 0) { + array.getBase.extractListElement(ordinal.getInt) + } else { + withResource(Scalar.fromNull( + GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + ColumnVector.fromScalar(nullScalar, array.getRowCount.toInt) + } + } + } else { + // for element_at use case, index starts at 1 + if (ordinal.isValid) { + if (ordinal.getInt > 0) { + // SQL 1-based index + array.getBase.extractListElement(ordinal.getInt - 1) + } else if (ordinal.getInt == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else { + array.getBase.extractListElement(ordinal.getInt) + } + } else { + withResource(Scalar.fromNull( + GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + ColumnVector.fromScalar(nullScalar, array.getRowCount.toInt) + } + } + } + } +} + +object GetMapValueUtil { + def evalColumnar(map: GpuColumnVector, key: Scalar): ColumnVector = { + map.getBase.getMapValue(key) + } +} From d4b761f32241e894ee14eeebc87cf8224f1fa94e Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 14 May 2021 10:15:08 +0800 Subject: [PATCH 16/21] resolve comments --- .../apache/spark/sql/rapids/complexTypeExtractors.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 92c87956f41..f65846b3d2e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{ColumnVector, Scalar} -import com.nvidia.spark.RebaseHelper.withResource -import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.sql.catalyst.InternalRow @@ -205,7 +204,9 @@ case class GpuArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } -object GetArrayItemUtil { +/** Core static methods for GetArrayItem and ElementAt + */ +object GetArrayItemUtil extends Arm{ def evalColumnar(array: GpuColumnVector, ordinal: Scalar, dataType: DataType, zeroIndexed: Boolean): ColumnVector = { // for array index use case, index starts at 0 @@ -240,6 +241,8 @@ object GetArrayItemUtil { } } +/** Core static methods for GetMapValue and ElementAt + */ object GetMapValueUtil { def evalColumnar(map: GpuColumnVector, key: Scalar): ColumnVector = { map.getBase.getMapValue(key) From ae6f8d08a964fe989650ab824802454226b31a57 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 14 May 2021 21:43:56 +0800 Subject: [PATCH 17/21] revert refactor Signed-off-by: Allen Xu --- .../sql/rapids/collectionOperations.scala | 18 +++++- .../sql/rapids/complexTypeExtractors.scala | 58 ++++--------------- 2 files changed, 27 insertions(+), 49 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 2c2c83e02db..0f4792c17d4 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -102,10 +102,24 @@ case class GpuElementAt(left: Expression, right: Expression) override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = { lhs.dataType match { case _: ArrayType => { - GetArrayItemUtil.evalColumnar(lhs, rhs, dataType, zeroIndexed = false) + if (rhs.isValid) { + if (rhs.getInt > 0) { + // SQL 1-based index + lhs.getBase.extractListElement(rhs.getInt - 1) + } else if (rhs.getInt == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else { + lhs.getBase.extractListElement(rhs.getInt) + } + } else { + withResource(Scalar.fromNull( + GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) + } + } } case _: MapType => { - GetMapValueUtil.evalColumnar(lhs, rhs) + lhs.getBase.getMapValue(rhs) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index f65846b3d2e..2e94fef249d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{ColumnVector, Scalar} -import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.sql.catalyst.InternalRow @@ -118,7 +118,15 @@ case class GpuGetArrayItem(child: Expression, ordinal: Expression) throw new IllegalStateException("This is not supported yet") override def doColumnar(lhs: GpuColumnVector, ordinal: Scalar): ColumnVector = { - GetArrayItemUtil.evalColumnar(lhs, ordinal, dataType, zeroIndexed = true) + // Need to handle negative indexes... + if (ordinal.isValid && ordinal.getInt >= 0) { + lhs.getBase.extractListElement(ordinal.getInt) + } else { + withResource(Scalar.fromNull( + GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) + } + } } override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = { @@ -159,7 +167,7 @@ case class GpuGetMapValue(child: Expression, key: Expression) override def prettyName: String = "getMapValue" override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = - GetMapValueUtil.evalColumnar(lhs, rhs) + lhs.getBase.getMapValue(rhs) override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = { withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs => @@ -204,47 +212,3 @@ case class GpuArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } -/** Core static methods for GetArrayItem and ElementAt - */ -object GetArrayItemUtil extends Arm{ - def evalColumnar(array: GpuColumnVector, ordinal: Scalar, dataType: DataType, - zeroIndexed: Boolean): ColumnVector = { - // for array index use case, index starts at 0 - if (zeroIndexed) { - // Need to handle negative indexes... - if (ordinal.isValid && ordinal.getInt >= 0) { - array.getBase.extractListElement(ordinal.getInt) - } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => - ColumnVector.fromScalar(nullScalar, array.getRowCount.toInt) - } - } - } else { - // for element_at use case, index starts at 1 - if (ordinal.isValid) { - if (ordinal.getInt > 0) { - // SQL 1-based index - array.getBase.extractListElement(ordinal.getInt - 1) - } else if (ordinal.getInt == 0) { - throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") - } else { - array.getBase.extractListElement(ordinal.getInt) - } - } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => - ColumnVector.fromScalar(nullScalar, array.getRowCount.toInt) - } - } - } - } -} - -/** Core static methods for GetMapValue and ElementAt - */ -object GetMapValueUtil { - def evalColumnar(map: GpuColumnVector, key: Scalar): ColumnVector = { - map.getBase.getMapValue(key) - } -} From 03110abd810c094b41b21a0ffbc7b0c0b7233676 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 14 May 2021 21:55:25 +0800 Subject: [PATCH 18/21] final fix Signed-off-by: Allen Xu --- .../org/apache/spark/sql/rapids/collectionOperations.scala | 3 +-- .../org/apache/spark/sql/rapids/complexTypeExtractors.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 0f4792c17d4..685fac68c58 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -112,8 +112,7 @@ case class GpuElementAt(left: Expression, right: Expression) lhs.getBase.extractListElement(rhs.getInt) } } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + withResource(GpuScalar.from(null, dataType)) { nullScalar => ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 2e94fef249d..0c6cad19f22 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -122,8 +122,7 @@ case class GpuGetArrayItem(child: Expression, ordinal: Expression) if (ordinal.isValid && ordinal.getInt >= 0) { lhs.getBase.extractListElement(ordinal.getInt) } else { - withResource(Scalar.fromNull( - GpuColumnVector.getNonNestedRapidsType(dataType))) { nullScalar => + withResource(GpuScalar.from(null, dataType)) { nullScalar => ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) } } From 56982cd14d3d49c72d4a018532ddf5595a96b4ca Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Sat, 15 May 2021 13:52:02 +0800 Subject: [PATCH 19/21] use columnVectorFromNull --- .../org/apache/spark/sql/rapids/collectionOperations.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 5389c01c50f..5c3396ac4d9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -112,11 +112,9 @@ case class GpuElementAt(left: Expression, right: Expression) lhs.getBase.extractListElement(rhs.getBase.getInt) } } else { - withResource(GpuScalar.from(null, dataType)) { nullScalar => - ColumnVector.fromScalar(nullScalar, lhs.getRowCount.toInt) + GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType) } } - } case _: MapType => { lhs.getBase.getMapValue(rhs.getBase) } From 0ce0d353563c7d1df7431e779746a7d0b9e6b3f3 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Sat, 15 May 2021 14:42:24 +0800 Subject: [PATCH 20/21] refine --- .../apache/spark/sql/rapids/collectionOperations.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 5c3396ac4d9..b76e9c7dc81 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -103,13 +103,14 @@ case class GpuElementAt(left: Expression, right: Expression) lhs.dataType match { case _: ArrayType => { if (rhs.isValid) { - if (rhs.getBase.getInt > 0) { + val ordinalValue = rhs.getValue.asInstanceOf[Int] + if (ordinalValue > 0) { // SQL 1-based index - lhs.getBase.extractListElement(rhs.getBase.getInt - 1) - } else if (rhs.getValue == 0) { + lhs.getBase.extractListElement(ordinalValue - 1) + } else if (ordinalValue == 0) { throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") } else { - lhs.getBase.extractListElement(rhs.getBase.getInt) + lhs.getBase.extractListElement(ordinalValue) } } else { GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType) From 0f55a04910f86842b37712bc79835b4d79409cc1 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Sat, 15 May 2021 17:34:01 +0800 Subject: [PATCH 21/21] indentation Signed-off-by: Allen Xu --- .../org/apache/spark/sql/rapids/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index b76e9c7dc81..19b198e3d35 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -114,8 +114,8 @@ case class GpuElementAt(left: Expression, right: Expression) } } else { GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType) - } } + } case _: MapType => { lhs.getBase.getMapValue(rhs.getBase) }