From fb045394b4c83f5f17376a20594c6de9866dcc42 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 11 Apr 2018 07:21:07 +0100 Subject: [PATCH 1/9] initial commit --- python/pyspark/sql/functions.py | 17 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringExpressions.scala | 53 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 24 +++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 18 +++++++ 6 files changed, 127 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d3bb0a5d6b36a..44a9933dc88bb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,6 +1845,23 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def array_position(str, substr): + """ + Collection function: Locates the position of the first occurrence of substr column + in the given string as Decimal. Returns null if either of the arguments are null. + + .. note:: The position is not zero based, but 1 based index. Returns 0 if substr + could not be found in str. + + >>> df = spark.createDataFrame([('abcd',)], ['s',]) + >>> df.select(array_position(df.s, 'b').alias('s')).collect() + [Row(s=Decimal(2))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_position(_to_java_column(str), substr)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38c874ad948e1..74095fe697b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5a02ca0d6862c..de9c2eb68375f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1002,6 +1002,59 @@ case class StringInstr(str: Expression, substr: Expression) } } +/** + * A function that returns the position of the first occurrence of substr in the given string + * as BigInt. Returns 0 if substr could not be found in str. + * Returns null if either of the arguments are null and + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of `substr` in `str`. + """, + examples = """ + Examples: + > SELECT _FUNC_('SparkSQL', 'SQL'); + 6 + """) +// scalastyle:on line.size.limit +case class ArrayPosition(str: Expression, substr: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = str + override def right: Expression = substr + override def dataType: DataType = DecimalType.BigIntDecimal + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + private val stringInstr = StringInstr(str, substr) + + override def nullSafeEval(string: Any, sub: Any): Any = { + val r = stringInstr.nullSafeEval(string, sub) + if (r == null) null else { + new Decimal().setOrNull(r.asInstanceOf[Int].toLong, DecimalType.MAX_PRECISION, 0) + } + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val value = ctx.freshName("arrayPositionValue") + val evPosition = stringInstr.doGenCode( + ctx, ExprCode("", ev.isNull, VariableValue(value, CodeGenerator.JAVA_INT))) + ev.copy( + code = evPosition.code + + s""" + |Decimal ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${evPosition.isNull}) { + | ${ev.value} = Decimal.apply((long)$value); + |} + """.stripMargin, + isNull = evPosition.isNull) + } +} + /** * Returns the substring from string str before count occurrences of the delimiter delim. * If count is positive, everything the left of the final delimiter (counting from left) is diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 517639dbc7232..34beb5772064e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -43,6 +44,29 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) } + test("Array Position") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("aaads", "aa", "zz") + val nullString = Literal.create(null, StringType) + + checkEvaluation(ArrayPosition(Literal("aaads"), Literal("aa")), Decimal(BigInt(1)), row1) + checkEvaluation(ArrayPosition(Literal("aaads"), Literal("de")), Decimal(BigInt(0)), row1) + checkEvaluation(ArrayPosition(nullString, Literal("de")), null, row1) + checkEvaluation(ArrayPosition(Literal("aaads"), nullString), null, row1) + + checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(1)), row1) + checkEvaluation(ArrayPosition(s1, s3), Decimal(BigInt(0)), row1) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(3)), create_row("花花世界", "世界")) + checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(1)), create_row("花花世界", "花")) + checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(0)), create_row("花花世界", "小")) + // scalastyle:on + } + test("MapKeys/MapValues") { val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a55a800f48245..3466ce150aa82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3038,6 +3038,20 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * @note The position is not zero based, but 1 based index. Returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 2.4.0 + */ + def array_position(str: Column, substring: String): Column = withExpr { + ArrayPosition(str.expr, lit(substring).expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 74c42f2599dca..6759ce3a02240 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -326,6 +326,24 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { }.getMessage().contains("only supports array input")) } + test("array position function") { + val df = Seq(("aaads", "aa", "zz", null)).toDF("a", "b", "c", "nul") + + checkAnswer( + df.select( + array_position($"a", "aa"), + array_position($"a", "gg"), + array_position($"nul", "gg")), + Row(BigInt(1), BigInt(0), null)) + checkAnswer( + df.selectExpr( + "array_position(a, b)", + "array_position(a, c)", + "array_position(a, nul)", + "array_position(nul, c)"), + Row(BigInt(1), BigInt(0), null, null)) + } + test("array size function") { val df = Seq( (Seq[Int](1, 2), "x"), From 3d8069d45fef7d88f8603550f5d8a4c09a74e130 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 11 Apr 2018 07:26:10 +0100 Subject: [PATCH 2/9] remove unnecessary scalastyle pragma --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index de9c2eb68375f..fb433e406ff7e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1009,7 +1009,6 @@ case class StringInstr(str: Expression, substr: Expression) * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ -// scalastyle:off line.size.limit @ExpressionDescription( usage = """ _FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of `substr` in `str`. @@ -1019,7 +1018,6 @@ case class StringInstr(str: Expression, substr: Expression) > SELECT _FUNC_('SparkSQL', 'SQL'); 6 """) -// scalastyle:on line.size.limit case class ArrayPosition(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes { From 5339a65ad4cf7ffbb98552873276e16dac90d900 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 11 Apr 2018 15:43:40 +0100 Subject: [PATCH 3/9] address review comment --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index fb433e406ff7e..289375175ad00 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1017,7 +1017,8 @@ case class StringInstr(str: Expression, substr: Expression) Examples: > SELECT _FUNC_('SparkSQL', 'SQL'); 6 - """) + """, + since = "2.4.0") case class ArrayPosition(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes { From f1238b6b4f6e47ea2349fc3e905bdac79d6d4557 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 11 Apr 2018 21:03:09 +0100 Subject: [PATCH 4/9] rebase with master --- .../expressions/collectionOperations.scala | 53 +++++++++++++++++++ .../expressions/stringExpressions.scala | 52 ------------------ 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 76b71f5b86074..7aca0c0dc2dbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -505,3 +505,56 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + + +/** + * A function that returns the position of the first occurrence of substr in the given string + * as BigInt. Returns 0 if substr could not be found in str. + * Returns null if either of the arguments are null and + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of `substr` in `str`. + """, + examples = """ + Examples: + > SELECT _FUNC_('SparkSQL', 'SQL'); + 6 + """, + since = "2.4.0") +case class ArrayPosition(str: Expression, substr: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = str + override def right: Expression = substr + override def dataType: DataType = DecimalType.BigIntDecimal + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + private val stringInstr = StringInstr(str, substr) + + override def nullSafeEval(string: Any, sub: Any): Any = { + val r = stringInstr.nullSafeEval(string, sub) + if (r == null) null else { + new Decimal().setOrNull(r.asInstanceOf[Int].toLong, DecimalType.MAX_PRECISION, 0) + } + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val value = ctx.freshName("arrayPositionValue") + val evPosition = stringInstr.doGenCode( + ctx, ExprCode(ev.isNull, JavaCode.variable(value, IntegerType))) + ev.copy( + code = evPosition.code + + s""" + |Decimal ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${evPosition.isNull}) { + | ${ev.value} = Decimal.apply((long)$value); + |} + """.stripMargin, + isNull = evPosition.isNull) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 289375175ad00..5a02ca0d6862c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1002,58 +1002,6 @@ case class StringInstr(str: Expression, substr: Expression) } } -/** - * A function that returns the position of the first occurrence of substr in the given string - * as BigInt. Returns 0 if substr could not be found in str. - * Returns null if either of the arguments are null and - * - * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. - */ -@ExpressionDescription( - usage = """ - _FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of `substr` in `str`. - """, - examples = """ - Examples: - > SELECT _FUNC_('SparkSQL', 'SQL'); - 6 - """, - since = "2.4.0") -case class ArrayPosition(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = substr - override def dataType: DataType = DecimalType.BigIntDecimal - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - private val stringInstr = StringInstr(str, substr) - - override def nullSafeEval(string: Any, sub: Any): Any = { - val r = stringInstr.nullSafeEval(string, sub) - if (r == null) null else { - new Decimal().setOrNull(r.asInstanceOf[Int].toLong, DecimalType.MAX_PRECISION, 0) - } - } - - override def prettyName: String = "array_position" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val value = ctx.freshName("arrayPositionValue") - val evPosition = stringInstr.doGenCode( - ctx, ExprCode("", ev.isNull, VariableValue(value, CodeGenerator.JAVA_INT))) - ev.copy( - code = evPosition.code + - s""" - |Decimal ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${evPosition.isNull}) { - | ${ev.value} = Decimal.apply((long)$value); - |} - """.stripMargin, - isNull = evPosition.isNull) - } -} - /** * Returns the substring from string str before count occurrences of the delimiter delim. * If count is positive, everything the left of the final delimiter (counting from left) is From 07305db03773d12d36e1314466f3433556f7abcf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 07:27:51 +0100 Subject: [PATCH 5/9] reimplement this for desired behavior --- python/pyspark/sql/functions.py | 10 ++-- .../expressions/collectionOperations.scala | 56 ++++++++++--------- .../CollectionExpressionsSuite.scala | 44 +++++++-------- .../org/apache/spark/sql/functions.scala | 12 ++-- .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++++++++++------ 5 files changed, 97 insertions(+), 77 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 44a9933dc88bb..ca0377dc0f534 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1846,7 +1846,7 @@ def array_contains(col, value): @since(2.4) -def array_position(str, substr): +def array_position(col, value): """ Collection function: Locates the position of the first occurrence of substr column in the given string as Decimal. Returns null if either of the arguments are null. @@ -1854,12 +1854,12 @@ def array_position(str, substr): .. note:: The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str. - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(array_position(df.s, 'b').alias('s')).collect() - [Row(s=Decimal(2))] + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.s, "a")).collect() + [Row(array_position(data, a)=Decimal(3)), Row(array_position(data, a)=Decimal(0))] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.array_position(_to_java_column(str), substr)) + return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) @since(1.4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7aca0c0dc2dbf..e3f023f146ae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -516,45 +516,51 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast */ @ExpressionDescription( usage = """ - _FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of `substr` in `str`. + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array. """, examples = """ Examples: - > SELECT _FUNC_('SparkSQL', 'SQL'); - 6 + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 """, since = "2.4.0") -case class ArrayPosition(str: Expression, substr: Expression) +case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { - override def left: Expression = str - override def right: Expression = substr override def dataType: DataType = DecimalType.BigIntDecimal - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) - private val stringInstr = StringInstr(str, substr) + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } - override def nullSafeEval(string: Any, sub: Any): Any = { - val r = stringInstr.nullSafeEval(string, sub) - if (r == null) null else { - new Decimal().setOrNull(r.asInstanceOf[Int].toLong, DecimalType.MAX_PRECISION, 0) - } + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) { + return new Decimal().setOrNull((i + 1).toLong, DecimalType.MAX_PRECISION, 0) + } + ) + new Decimal().setOrNull(0.toLong, DecimalType.MAX_PRECISION, 0) } override def prettyName: String = "array_position" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val value = ctx.freshName("arrayPositionValue") - val evPosition = stringInstr.doGenCode( - ctx, ExprCode(ev.isNull, JavaCode.variable(value, IntegerType))) - ev.copy( - code = evPosition.code + - s""" - |Decimal ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${evPosition.isNull}) { - | ${ev.value} = Decimal.apply((long)$value); - |} - """.stripMargin, - isNull = evPosition.isNull) + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int ${pos} = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${pos} = $i + 1; + | break; + | } + |} + |${ev.value} = Decimal.apply((long)$pos); + """ + }) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 34beb5772064e..49b50add0e67f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -44,29 +44,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) } - test("Array Position") { - val s1 = 'a.string.at(0) - val s2 = 'b.string.at(1) - val s3 = 'c.string.at(2) - val row1 = create_row("aaads", "aa", "zz") - val nullString = Literal.create(null, StringType) - - checkEvaluation(ArrayPosition(Literal("aaads"), Literal("aa")), Decimal(BigInt(1)), row1) - checkEvaluation(ArrayPosition(Literal("aaads"), Literal("de")), Decimal(BigInt(0)), row1) - checkEvaluation(ArrayPosition(nullString, Literal("de")), null, row1) - checkEvaluation(ArrayPosition(Literal("aaads"), nullString), null, row1) - - checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(1)), row1) - checkEvaluation(ArrayPosition(s1, s3), Decimal(BigInt(0)), row1) - - // scalastyle:off - // non ascii characters are not allowed in the source code, so we disable the scalastyle. - checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(3)), create_row("花花世界", "世界")) - checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(1)), create_row("花花世界", "花")) - checkEvaluation(ArrayPosition(s1, s2), Decimal(BigInt(0)), create_row("花花世界", "小")) - // scalastyle:on - } - test("MapKeys/MapValues") { val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) @@ -193,4 +170,25 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Reverse(as7), null) checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) } + + test("Array Position") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPosition(a0, Literal(1)), Decimal(BigInt(1))) + checkEvaluation(ArrayPosition(a0, Literal(0)), Decimal(BigInt(0))) + checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayPosition(a1, Literal("")), Decimal(BigInt(2))) + checkEvaluation(ArrayPosition(a1, Literal("a")), Decimal(BigInt(0))) + checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayPosition(a2, Literal(1L)), Decimal(BigInt(0))) + checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayPosition(a3, Literal("")), null) + checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3466ce150aa82..60258407e3d07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,17 +3039,17 @@ object functions { } /** - * Locate the position of the first occurrence of substr column in the given string. + * Locate the position of the first occurrence of the value in the given array. * Returns null if either of the arguments are null. * - * @note The position is not zero based, but 1 based index. Returns 0 if substr - * could not be found in str. + * @note The position is not zero based, but 1 based index. Returns 0 if value + * could not be found in array. * - * @group string_funcs + * @group collection_funcs * @since 2.4.0 */ - def array_position(str: Column, substring: String): Column = withExpr { - ArrayPosition(str.expr, lit(substring).expr) + def array_position(column: Column, value: Any): Column = withExpr { + ArrayPosition(column.expr, Literal(value)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6759ce3a02240..5e7112eea1c5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -326,24 +326,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { }.getMessage().contains("only supports array input")) } - test("array position function") { - val df = Seq(("aaads", "aa", "zz", null)).toDF("a", "b", "c", "nul") - - checkAnswer( - df.select( - array_position($"a", "aa"), - array_position($"a", "gg"), - array_position($"nul", "gg")), - Row(BigInt(1), BigInt(0), null)) - checkAnswer( - df.selectExpr( - "array_position(a, b)", - "array_position(a, c)", - "array_position(a, nul)", - "array_position(nul, c)"), - Row(BigInt(1), BigInt(0), null, null)) - } - test("array size function") { val df = Seq( (Seq[Int](1, 2), "x"), @@ -553,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array position function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + checkAnswer( + df.select(array_position(df("a"), 1)), + Seq(Row(BigInt(1)), Row(BigInt(0))) + ) + checkAnswer( + df.selectExpr("array_position(a, 1)"), + Seq(Row(BigInt(1)), Row(BigInt(0))) + ) + + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("array_position(a, null)"), + Seq(Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(BigInt(1)), Row(BigInt(1))) + ) + checkAnswer( + df.selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(BigInt(1)), Row(BigInt(1))) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 3a16231ce94d63be2d21c22fba51b12efb9a3ce6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 11:35:28 +0100 Subject: [PATCH 6/9] use long for result instead of BigInt --- python/pyspark/sql/functions.py | 8 ++++---- .../expressions/collectionOperations.scala | 14 +++++++------- .../expressions/CollectionExpressionsSuite.scala | 10 +++++----- .../scala/org/apache/spark/sql/functions.scala | 2 +- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 8 ++++---- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ca0377dc0f534..19d3e347940ca 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1848,15 +1848,15 @@ def array_contains(col, value): @since(2.4) def array_position(col, value): """ - Collection function: Locates the position of the first occurrence of substr column - in the given string as Decimal. Returns null if either of the arguments are null. + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. .. note:: The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str. >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) - >>> df.select(array_position(df.s, "a")).collect() - [Row(array_position(data, a)=Decimal(3)), Row(array_position(data, a)=Decimal(0))] + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e3f023f146ae7..9dbf7e55207fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -508,15 +508,15 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast /** - * A function that returns the position of the first occurrence of substr in the given string - * as BigInt. Returns 0 if substr could not be found in str. + * A function that returns the position of the first occurrence of element in the given array + * as long. Returns 0 if substr could not be found in str. * Returns null if either of the arguments are null and * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ @ExpressionDescription( usage = """ - _FUNC_(array, element) - Returns the (1-based) index of the first element of the array. + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. """, examples = """ Examples: @@ -527,7 +527,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { - override def dataType: DataType = DecimalType.BigIntDecimal + override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) @@ -538,10 +538,10 @@ case class ArrayPosition(left: Expression, right: Expression) override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == value) { - return new Decimal().setOrNull((i + 1).toLong, DecimalType.MAX_PRECISION, 0) + return (i + 1).toLong } ) - new Decimal().setOrNull(0.toLong, DecimalType.MAX_PRECISION, 0) + 0L } override def prettyName: String = "array_position" @@ -559,7 +559,7 @@ case class ArrayPosition(left: Expression, right: Expression) | break; | } |} - |${ev.value} = Decimal.apply((long)$pos); + |${ev.value} = (long) $pos; """ }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 49b50add0e67f..42984b2adbd1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -177,15 +177,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) - checkEvaluation(ArrayPosition(a0, Literal(1)), Decimal(BigInt(1))) - checkEvaluation(ArrayPosition(a0, Literal(0)), Decimal(BigInt(0))) + checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) + checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) - checkEvaluation(ArrayPosition(a1, Literal("")), Decimal(BigInt(2))) - checkEvaluation(ArrayPosition(a1, Literal("a")), Decimal(BigInt(0))) + checkEvaluation(ArrayPosition(a1, Literal("")), 2L) + checkEvaluation(ArrayPosition(a1, Literal("a")), 0L) checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) - checkEvaluation(ArrayPosition(a2, Literal(1L)), Decimal(BigInt(0))) + checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L) checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) checkEvaluation(ArrayPosition(a3, Literal("")), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 60258407e3d07..3a09ec4f1982e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,7 +3039,7 @@ object functions { } /** - * Locate the position of the first occurrence of the value in the given array. + * Locates the position of the first occurrence of the value in the given array as long. * Returns null if either of the arguments are null. * * @note The position is not zero based, but 1 based index. Returns 0 if value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5e7112eea1c5c..13161e7e24cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -543,11 +543,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(array_position(df("a"), 1)), - Seq(Row(BigInt(1)), Row(BigInt(0))) + Seq(Row(1L), Row(0L)) ) checkAnswer( df.selectExpr("array_position(a, 1)"), - Seq(Row(BigInt(1)), Row(BigInt(0))) + Seq(Row(1L), Row(0L)) ) checkAnswer( @@ -561,11 +561,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_position(array(array(1), null)[0], 1)"), - Seq(Row(BigInt(1)), Row(BigInt(1))) + Seq(Row(1L), Row(1L)) ) checkAnswer( df.selectExpr("array_position(array(1, null), array(1, null)[0])"), - Seq(Row(BigInt(1)), Row(BigInt(1))) + Seq(Row(1L), Row(1L)) ) } From d4cebedf6830dcecd33a6b94e146e9f3a26c7918 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 05:31:26 +0100 Subject: [PATCH 7/9] address review comment --- .../sql/catalyst/expressions/collectionOperations.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9dbf7e55207fc..09b2c398d8718 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -508,9 +508,8 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast /** - * A function that returns the position of the first occurrence of element in the given array - * as long. Returns 0 if substr could not be found in str. - * Returns null if either of the arguments are null and + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if substr could not be found in str. Returns null if either of the arguments are null * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ From 9a0321d67cd6e7c65cbf54e22099cc0cfae03463 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 16 Apr 2018 19:00:45 +0100 Subject: [PATCH 8/9] address review comments --- .../catalyst/expressions/collectionOperations.scala | 10 +++++----- .../expressions/CollectionExpressionsSuite.scala | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 09b2c398d8718..fc50f47635478 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -531,7 +531,7 @@ case class ArrayPosition(left: Expression, right: Expression) Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) override def nullable: Boolean = { - left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + left.nullable || right.nullable } override def nullSafeEval(arr: Any, value: Any): Any = { @@ -551,15 +551,15 @@ case class ArrayPosition(left: Expression, right: Expression) val i = ctx.freshName("i") val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" - |int ${pos} = 0; + |int $pos = 0; |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | if (${ctx.genEqual(right.dataType, value, getValue)}) { - | ${pos} = $i + 1; + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; | break; | } |} |${ev.value} = (long) $pos; - """ + """.stripMargin }) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 42984b2adbd1e..5a951b0858e28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -172,11 +172,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Array Position") { - val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + checkEvaluation(ArrayPosition(a0, Literal(3)), 4L) checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) From 7362b1c699ba0126b92944b6894aa111e31ea9da Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 14:20:55 +0100 Subject: [PATCH 9/9] address review comments --- python/pyspark/sql/functions.py | 4 ++-- .../catalyst/expressions/collectionOperations.scala | 10 ++++------ .../expressions/CollectionExpressionsSuite.scala | 1 - 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 19d3e347940ca..36dcabc6766d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1851,8 +1851,8 @@ def array_position(col, value): Collection function: Locates the position of the first occurrence of the given value in the given array. Returns null if either of the arguments are null. - .. note:: The position is not zero based, but 1 based index. Returns 0 if substr - could not be found in str. + .. note:: The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) >>> df.select(array_position(df.data, "a")).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fc50f47635478..e6a05f535cb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -509,9 +509,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast /** * Returns the position of the first occurrence of element in the given array as long. - * Returns 0 if substr could not be found in str. Returns null if either of the arguments are null + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null * - * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. */ @ExpressionDescription( usage = """ @@ -530,10 +532,6 @@ case class ArrayPosition(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) - override def nullable: Boolean = { - left.nullable || right.nullable - } - override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == value) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5a951b0858e28..916cd3bb4cca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {