Skip to content

Commit

Permalink
debug for array of array
Browse files Browse the repository at this point in the history
  • Loading branch information
wjxiz1992 committed Apr 28, 2021
1 parent 07f5713 commit 1b8cea3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 22 deletions.
29 changes: 23 additions & 6 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,28 @@ def main_df(spark):
return df.select(array_contains(col('a'), chk_val))
assert_gpu_and_cpu_are_equal_collect(main_df)

@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
FloatGen(no_nans=True), DoubleGen(no_nans=True),
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
def test_array_element_at(data_gen):
arr_gen = ArrayGen(data_gen)
assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df(
spark, arr_gen).select(element_at(col('a'), 1),
element_at(col('a'), -1)), no_nans_conf)
spark, data_gen).select(element_at(col('a'), 1),
element_at(col('a'), -1)),
conf={'spark.sql.ansi.enabled':False})


@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
def test_array_element_at_null(data_gen):
array_gen = ArrayGen(data_gen)
assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df(
spark, data_gen).select(element_at(col('a'), 1),
element_at(col('a'), -1)),
conf={'spark.sql.ansi.enabled':False,
'spark.sql.legacy.allowNegativeScaleOfDecimal': True})

@pytest.mark.parametrize('data_gen', [ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10)], ids=idfn)
def test_array_element_at_test(data_gen):
array_gen = ArrayGen(data_gen)
assert_gpu_and_cpu_are_equal_collect(lambda spark: unary_op_df(
spark, data_gen).select(element_at(col('a'), 1),
element_at(col('a'), -1)),
conf={'spark.sql.ansi.enabled':False,
'spark.sql.legacy.allowNegativeScaleOfDecimal': True})
3 changes: 2 additions & 1 deletion integration_tests/src/main/python/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ def test_simple_map_element_at(data_gen):
'element_at(a, "null")',
'element_at(a, "key_9")',
'element_at(a, "NOT_FOUND")',
'element_at(a, "key_5")'))
'element_at(a, "key_5")'),
conf={'spark.sql.ansi.enabled':False})
Original file line number Diff line number Diff line change
Expand Up @@ -2248,11 +2248,32 @@ object GpuOverrides {
"Returns element of array at given(1-based) index in value if column is array. " +
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProjectNotLambda(
TypeSig.commonCudfTypes, TypeSig.all,
("left", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes) +
TypeSig.MAP.nested(TypeSig.STRING), TypeSig.all),
("right", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
(in, conf, p, r) => new GpuElementAtMeta(in, conf, p, r)),
TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL + TypeSig.MAP, TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported. " +
"Extra check is inside the expression metadata"), TypeSig.all),
("index/key", TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
in.left.dataType match {
case MapType(_,valueType,_) => {
valueType match {
case StringType => // minimum support
case _ => willNotWorkOnGpu(s"${valueType.simpleString} is not supported for" +
s" Map value")
}
}
case ArrayType(_,_) => // Array supports more
}
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs)
}
}),
expr[CreateNamedStruct](
"Creates a struct with the given field names and values",
CreateNamedStructCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,6 @@ case class GpuGetMapValue(child: Expression, key: Expression)
override def right: Expression = key
}

class GpuElementAtMeta(
expr: ElementAt,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends BinaryExprMeta[ElementAt](expr, conf, parent, rule) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuElementAt(lhs, rhs)
}
}

case class GpuElementAt(left: Expression, right: Expression)
extends GpuBinaryExpression with ExpectsInputTypes {
Expand Down

0 comments on commit 1b8cea3

Please sign in to comment.