From b8c781d5072764c973a43addf8b16f84b66ef07a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 14:21:22 +0100 Subject: [PATCH 01/12] initial commit --- python/pyspark/sql/functions.py | 21 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 105 +++++++++++++++++- .../expressions/complexTypeExtractors.scala | 64 +++++++---- .../CollectionExpressionsSuite.scala | 48 ++++++++ .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 71 ++++++++++++ 7 files changed, 296 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 36dcabc6766d8..cd70bc57beec3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,27 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@since(2.4) +def element_at(col, value): + """ + Collection function: returns element of array at given index in value if col is array. + returns value for the given key in value if col is map. + + :param col: name of column containing array or map + :param value: value to check for in array or key to check for in map + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 74095fe697b6a..a44f2d5272b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -405,6 +405,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6a05f535cb1c..9257def1aff20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -506,7 +506,6 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } - /** * Returns the position of the first occurrence of element in the given array as long. * Returns 0 if the given value could not be found in the array. 
Returns null if either of @@ -529,6 +528,7 @@ case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) @@ -561,3 +561,106 @@ case class ArrayPosition(left: Expression, right: Expression) }) } } + +/** + * Returns the value of index `right` in Array `left` or key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given index. If index < 0, accesses elements + from the last to the first. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case _: ArrayType => left.dataType.asInstanceOf[ArrayType].elementType + case _: MapType => left.dataType.asInstanceOf[MapType].valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """ + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """ + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6cdad19168dce..ce38c3335af5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression) }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil + with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. 
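+  // Evaluation is delegated to the shared GetMapValueUtil base class: the
+  // interpreted path reuses getValueEval and the generated-code path reuses
+  // doGetValueGenCode, the same helpers ElementAt calls for map inputs.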
+ override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 916cd3bb4cca5..7d8fe211858b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a09ec4f1982e..9c8580378303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3052,6 +3052,17 @@ object functions { ArrayPosition(column.expr, Literal(value)) } + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. 
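+   * The index is 1-based; a negative index counts back from the end of the
+   * array, and null is returned if the index is out of range or the key is
+   * not contained in the map.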
+ * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 13161e7e24cfe..403497af9d7c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -566,6 +566,77 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) + } + + test("element at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), 2)), + Seq(Row("2"), Row(""), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -2)), + Seq(Row("2"), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, 2)"), + Seq(Row("2"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -2)"), + Seq(Row("2"), Row(null), Row(null)) ) } From 77b572a6fd7f074fdf455b1d063771570345b563 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 05:26:09 +0100 Subject: [PATCH 02/12] fix test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cd70bc57beec3..d4cd16d66e470 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1873,7 +1873,7 @@ def element_at(col, value): >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(element_at(df.data, 1)).collect() - [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + [Row(element_at(data, 1)='a'), Row(element_at(data, 1)=None)] >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) >>> df.select(element_at(df.data, "a")).collect() From 9d38e4bc29e2dda9d1af5ce8adfcb4ad024e29ad Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 11:49:15 +0100 Subject: [PATCH 03/12] fix test failure --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d4cd16d66e470..cd70bc57beec3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1873,7 +1873,7 @@ def element_at(col, value): >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(element_at(df.data, 1)).collect() - [Row(element_at(data, 1)='a'), Row(element_at(data, 1)=None)] + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) >>> df.select(element_at(df.data, "a")).collect() From 61c904ba69e4bbcdb927c393292010ad8f930f34 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 13 Apr 2018 19:39:16 +0100 Subject: [PATCH 04/12] address review comment --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9257def1aff20..c26d5a1e7a2d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -563,7 +563,7 @@ case class ArrayPosition(left: Expression, right: Expression) } /** - * Returns the value of index `right` in Array `left` or key `right` in Map `left`. + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. */ @ExpressionDescription( usage = """ From 872a5006e16935528677be78d64e75bff55e96fe Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 16 Apr 2018 03:32:34 +0100 Subject: [PATCH 05/12] address review comment --- .../spark/sql/DataFrameFunctionsSuite.scala | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 403497af9d7c4..fa5173fd64f24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -535,6 +535,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } +<<<<<<< HEAD test("array position function") { val df = Seq( (Seq[Int](1, 2), "x"), @@ -568,7 +569,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(1L), Row(1L)) } - test("element at function") { + test("element_at function") { val df = Seq( (Seq[String]("1", "2", "3")), (Seq[String](null, "")), @@ -591,53 +592,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) - checkAnswer( - df.select(element_at(df("a"), -4)), - Seq(Row(null), Row(null), Row(null)) - ) checkAnswer( df.select(element_at(df("a"), 1)), Seq(Row("1"), Row(null), Row(null)) ) - checkAnswer( - df.select(element_at(df("a"), 2)), - Seq(Row("2"), Row(""), Row(null)) - ) checkAnswer( df.select(element_at(df("a"), -1)), Seq(Row("3"), Row(""), Row(null)) ) - checkAnswer( - df.select(element_at(df("a"), -2)), - Seq(Row("2"), Row(null), Row(null)) - ) checkAnswer( df.selectExpr("element_at(a, 4)"), Seq(Row(null), Row(null), Row(null)) ) - checkAnswer( - df.selectExpr("element_at(a, -4)"), - Seq(Row(null), 
Row(null), Row(null)) - ) checkAnswer( df.selectExpr("element_at(a, 1)"), Seq(Row("1"), Row(null), Row(null)) ) - checkAnswer( - df.selectExpr("element_at(a, 2)"), - Seq(Row("2"), Row(""), Row(null)) - ) checkAnswer( df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) - checkAnswer( - df.selectExpr("element_at(a, -2)"), - Seq(Row("2"), Row(null), Row(null)) - ) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { From d167b4b8699b4dfaab5249b1c85a0506e78623b8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 16 Apr 2018 17:22:13 +0100 Subject: [PATCH 06/12] address review comments --- python/pyspark/sql/functions.py | 1 + .../sql/catalyst/expressions/collectionOperations.scala | 8 ++++---- .../sql/catalyst/expressions/complexTypeExtractors.scala | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cd70bc57beec3..df0ab73c6740c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,6 +1845,7 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@ignore_unicode_prefix @since(2.4) def array_position(col, value): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c26d5a1e7a2d7..f8ab31656b27b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -568,7 +568,7 @@ case class ArrayPosition(left: Expression, right: Expression) @ExpressionDescription( usage = """ _FUNC_(array, index) - Returns element of array at given index. If index < 0, accesses elements - from the last to the first. + from the last to the first. Returns NULL if the index exceeds the length of the array. 
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map """, @@ -613,7 +613,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } else { array.numElements() + index } - if (array.isNullAt(idx)) { + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { null } else { array.get(idx, dataType) @@ -634,7 +634,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; |} else - """ + """.stripMargin } else { "" } @@ -655,7 +655,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; | } |} - """ + """.stripMargin }) case _: MapType => doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index ce38c3335af5d..3fba52d745453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -342,8 +342,8 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy * * We need to do type checking here as `key` expression maybe unresolved. */ -case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil - with ExtractValue with NullIntolerant { +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { private def keyType = child.dataType.asInstanceOf[MapType].keyType From 5163679c8aa4bc58cedcd96c13ef503b2b1f0598 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 13:25:32 +0100 Subject: [PATCH 07/12] address review comments --- python/pyspark/sql/functions.py | 6 ++++-- .../sql/catalyst/expressions/collectionOperations.scala | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index df0ab73c6740c..1caa1762c21af 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1864,13 +1864,15 @@ def array_position(col, value): @since(2.4) -def element_at(col, value): +def element_at(col, index): """ Collection function: returns element of array at given index in value if col is array. returns value for the given key in value if col is map. :param col: name of column containing array or map - :param value: value to check for in array or key to check for in map + :param index: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. 
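+        If index is negative, it accesses elements from the last to the
+        first. None is returned if the index exceeds the length of the array.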
>>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(element_at(df.data, 1)).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f8ab31656b27b..3cac0f3cfe25b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -567,8 +567,9 @@ case class ArrayPosition(left: Expression, right: Expression) */ @ExpressionDescription( usage = """ - _FUNC_(array, index) - Returns element of array at given index. If index < 0, accesses elements - from the last to the first. Returns NULL if the index exceeds the length of the array. + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map """, @@ -583,8 +584,8 @@ case class ArrayPosition(left: Expression, right: Expression) case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { override def dataType: DataType = left.dataType match { - case _: ArrayType => left.dataType.asInstanceOf[ArrayType].elementType - case _: MapType => left.dataType.asInstanceOf[MapType].valueType + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType } override def inputTypes: Seq[AbstractDataType] = { From a70539727261a111fc5d886f2509229f913dd00f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 18 Apr 2018 17:14:34 +0100 Subject: [PATCH 08/12] fix python test failure --- python/pyspark/sql/functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1caa1762c21af..e0c103c34b9e4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1864,13 +1864,13 @@ def array_position(col, value): @since(2.4) -def element_at(col, index): +def element_at(col, extraction): """ - Collection function: returns element of array at given index in value if col is array. - returns value for the given key in value if col is map. + Collection function: returns element of array at given index in extraction if col is array. + returns value for the given key in extraction if col is map. :param col: name of column containing array or map - :param index: index to check for in array or key to check for in map + :param extraction: index to check for in array or key to check for in map .. note:: The position is not zero based, but 1 based index. 
@@ -1883,7 +1883,7 @@ def element_at(col, index): [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.element_at(_to_java_column(col), value)) + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) @since(1.4) From 2fbb1e8fbc17e95560a18170fbd3ad0e1f4d78ff Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 04:48:18 +0100 Subject: [PATCH 09/12] rebase with master --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3cac0f3cfe25b..dba426e999dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -506,6 +506,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + /** * Returns the position of the first occurrence of element in the given array as long. * Returns 0 if the given value could not be found in the array. Returns null if either of @@ -528,7 +529,6 @@ case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fa5173fd64f24..e0203407c5beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -535,7 +535,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } -<<<<<<< HEAD test("array position function") { val df = Seq( (Seq[Int](1, 2), "x"), From 06fb27e68ddda7075c0edef14df50cf56a370b4e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 06:16:18 +0100 Subject: [PATCH 10/12] fix mistakes during rebase --- python/pyspark/sql/functions.py | 3 +-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e0c103c34b9e4..cf7e4511b01ef 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,7 +1845,6 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) -@ignore_unicode_prefix @since(2.4) def array_position(col, value): """ @@ -1866,7 +1865,7 @@ def array_position(col, value): @since(2.4) def element_at(col, extraction): """ - Collection function: returns element of array at given index in extraction if col is array. + Collection function: Returns element of array at given index in extraction if col is array. returns value for the given key in extraction if col is map. 
:param col: name of column containing array or map diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e0203407c5beb..7c976c1b7f915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -566,6 +566,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) + ) } test("element_at function") { From 96dd82b6385e5114496d0541d744c9fce48f8db0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 07:30:52 +0100 Subject: [PATCH 11/12] add @ignore_unicode_prefix --- python/pyspark/sql/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cf7e4511b01ef..2753f12269023 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,7 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@ignore_unicode_prefix @since(2.4) def element_at(col, extraction): """ From 90e026e9b2d58e17995d21e44c5a68fa7f0f7d52 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 07:51:01 +0100 Subject: [PATCH 12/12] Improve comment --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2753f12269023..1be68f2a4a448 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1867,7 +1867,7 @@ def array_position(col, value): def element_at(col, extraction): """ Collection function: Returns element of array at given index in extraction if col is array. - returns value for the given key in extraction if col is map. + Returns value for the given key in extraction if col is map. :param col: name of column containing array or map :param extraction: index to check for in array or key to check for in map
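
With the series applied, `element_at` indexes arrays 1-based, counts a negative index back from the end, returns NULL when the index is out of range or a map key is absent, and raises "SQL array indices start at 1" for index 0. Below is a minimal spark-shell sketch of that behavior; the DataFrame and column names are illustrative only, and an active `spark` session with `spark.implicits._` in scope is assumed, neither of which comes from the patches themselves:

    import org.apache.spark.sql.functions.element_at
    import spark.implicits._

    val arrDf = Seq(Seq("a", "b", "c"), Seq.empty[String]).toDF("data")

    arrDf.select(element_at($"data", 1)).show()   // first element: "a"; null for the empty array
    arrDf.select(element_at($"data", -1)).show()  // last element: "c"; null for the empty array
    arrDf.select(element_at($"data", 4)).show()   // null: the index exceeds the array length
    // element_at($"data", 0) fails at runtime: "SQL array indices start at 1"

    val mapDf = Seq(Map("a" -> 1.0, "b" -> 2.0)).toDF("data")

    mapDf.select(element_at($"data", "a")).show() // 1.0
    mapDf.select(element_at($"data", "z")).show() // null: the key is not contained in the map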