From 1b8cea3988001c06a01775c8cf5da38147c51121 Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 28 Apr 2021 20:13:10 +0800 Subject: [PATCH] debug for array of array --- .../src/main/python/array_test.py | 29 +++++++++++++---- integration_tests/src/main/python/map_test.py | 3 +- .../nvidia/spark/rapids/GpuOverrides.scala | 31 ++++++++++++++++--- .../sql/rapids/complexTypeExtractors.scala | 10 ------ 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 3307c15927f..c943cbf02b8 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -110,11 +110,28 @@ def main_df(spark): return df.select(array_contains(col('a'), chk_val)) assert_gpu_and_cpu_are_equal_collect(main_df) -@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen, - FloatGen(no_nans=True), DoubleGen(no_nans=True), - string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) def test_array_element_at(data_gen): - arr_gen = ArrayGen(data_gen) assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( - spark, arr_gen).select(element_at(col('a'), 1), - element_at(col('a'), -1)), no_nans_conf) \ No newline at end of file + spark, data_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), + conf={'spark.sql.ansi.enabled':False}) + + +@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) +def test_array_element_at_null(data_gen): + array_gen = ArrayGen(data_gen) + assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( + spark, data_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), + conf={'spark.sql.ansi.enabled':False, + 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) + +@pytest.mark.parametrize('data_gen', [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10)], ids=idfn) +def test_array_element_at_test(data_gen): + array_gen = ArrayGen(data_gen) + assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df( + spark, data_gen).select(element_at(col('a'), 1), + element_at(col('a'), -1)), + conf={'spark.sql.ansi.enabled':False, + 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) \ No newline at end of file diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py index f67c2cb44a2..372a99fce9c 100644 --- a/integration_tests/src/main/python/map_test.py +++ b/integration_tests/src/main/python/map_test.py @@ -40,4 +40,5 @@ def test_simple_map_element_at(data_gen): 'element_at(a, "null")', 'element_at(a, "key_9")', 'element_at(a, "NOT_FOUND")', - 'element_at(a, "key_5")')) + 'element_at(a, "key_5")'), + conf={'spark.sql.ansi.enabled':False}) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 53defcef336..606ea4ad443 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2248,11 +2248,32 @@ object GpuOverrides { "Returns element of array at given(1-based) index in value if column is array. " + "Returns value for the given key in value if column is map.", ExprChecks.binaryProjectNotLambda( - TypeSig.commonCudfTypes, TypeSig.all, - ("left", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes) + - TypeSig.MAP.nested(TypeSig.STRING), TypeSig.all), - ("right", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), - (in, conf, p, r) => new GpuElementAtMeta(in, conf, p, r)), + TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL + TypeSig.MAP, TypeSig.all, + ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) + + TypeSig.MAP.nested(TypeSig.STRING) + .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported. " + + "Extra check is inside the expression metadata"), TypeSig.all), + ("index/key", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)), + (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { + override def tagExprForGpu(): Unit = { + // To distinguish the supported nested type between Array and Map + in.left.dataType match { + case MapType(_,valueType,_) => { + valueType match { + case StringType => // minimum support + case _ => willNotWorkOnGpu(s"${valueType.simpleString} is not supported for" + + s" Map value") + } + } + case ArrayType(_,_) => // Array supports more + } + } + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { + GpuElementAt(lhs, rhs) + } + }), expr[CreateNamedStruct]( "Creates a struct with the given field names and values", CreateNamedStructCheck, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 600a9bc1de3..90270ce5c7b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -186,16 +186,6 @@ case class GpuGetMapValue(child: Expression, key: Expression) override def right: Expression = key } -class GpuElementAtMeta( - expr: ElementAt, - conf: RapidsConf, - parent: Option[RapidsMeta[_, _, _]], - rule: DataFromReplacementRule) - extends BinaryExprMeta[ElementAt](expr, conf, parent, rule) { - override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { - GpuElementAt(lhs, rhs) - } -} case class GpuElementAt(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes {