From ffcf1405cc2cd79b45167977ae02f59684f6df50 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Wed, 18 Apr 2018 10:00:27 +0200 Subject: [PATCH 1/8] [SPARK-24042][SQL] Collection function: zip_with_index --- python/pyspark/sql/functions.py | 21 ++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 137 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 54 +++++++ .../expressions/ExpressionEvalHelper.scala | 3 + .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 101 +++++++++++++ 7 files changed, 317 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..080350a32f1b6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2172,23 +2172,22 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) -@since(1.5) -@ignore_unicode_prefix -def reverse(col): +@since(2.4) +def zip_with_index(col, indexFirst=False): """ - Collection function: returns a reversed string or an array with reverse order of elements. + Collection function: transforms the input array by encapsulating elements into pairs + with indexes indicating the order. :param col: name of column or expression - >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) - >>> df.select(reverse(df.data).alias('s')).collect() - [Row(s=u'LQS krapS')] - >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) - >>> df.select(reverse(df.data).alias('r')).collect() - [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + >>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['data']) + >>> df.select(zip_with_index(df.data).alias('r')).collect() + [Row(r=[[value=2, index=0], [value=5, index=1], [value=3, index=2]]), Row(r=[])] + >>> df.select(zip_with_index(df.data, indexFirst=True).alias('r')).collect() + [Row(r=[[index=0, value=2], [index=1, value=5], [index=2, value=3]]), Row(r=[])] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.reverse(_to_java_column(col))) + return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst)) @since(2.3) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..bcd74624dfece 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -413,6 +413,7 @@ object FunctionRegistry { expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[ZipWithIndex]("zip_with_index"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..1f1c2b6e7a2fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -883,3 +884,139 @@ case class Concat(children: Seq[Expression]) extends Expression { override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } + +/** + * Returns the maximum value in the array. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array[, indexFirst]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.", + examples = """ + Examples: + > SELECT _FUNC_(array("d", "a", null, "b")); + [("d",0),("a",1),(null,2),("b",3)] + > SELECT _FUNC_(array("d", "a", null, "b"), true); + [(0,"d"),(1,"a"),(2,null),(3,"b")] + """, + since = "2.4.0") +case class ZipWithIndex(child: Expression, indexFirst: Expression) + extends UnaryExpression with ExpectsInputTypes { + + def this(e: Expression) = this(e, Literal.FalseLiteral) + + val indexFirstValue: Boolean = indexFirst match { + case Literal(v: Boolean, BooleanType) => v + case _ => throw new AnalysisException("The second argument has to be a boolean constant.") + } + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + lazy val childArrayType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def dataType: DataType = { + val elementField = StructField("value", childArrayType.elementType, childArrayType.containsNull) + val indexField = StructField("index", IntegerType, false) + + val fields = if (indexFirstValue) Seq(indexField, elementField) else Seq(elementField, indexField) + + ArrayType(StructType(fields), false) + } + + override protected def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType) + + val makeStruct = (v: Any, i: Int) => if (indexFirstValue) InternalRow(i, v) else InternalRow(v, i) + val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i)} + + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) { + genCodeForPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForNonPrimitiveElements(ctx, c, ev.value) + } + }) + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String): String = { + val numElements = ctx.freshName("numElements") + val byteArraySize = ctx.freshName("byteArraySize") + val data = ctx.freshName("byteArray") + val unsafeRow = ctx.freshName("unsafeRow") + val structSize = ctx.freshName("structSize") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val longSize = LongType.defaultSize + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(childArrayType.elementType) + val valuePosition = if (indexFirstValue) "1" else "0" + val indexPosition = if (indexFirstValue) "0" else "1" + s""" + |final int $numElements = $childVariableName.numElements(); + |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; + |final long $byteArraySize = $calculateArraySize($numElements, $longSize + 
$structSize); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; + |if ($byteArraySize > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to zip array with index due to exceeding" + + | " the limit $MAX_ARRAY_LENGTH bytes for UnsafeArrayData. " + $byteArraySize + + | " bytes of data are required for performing the operation with the given array."); + |} + |final byte[] $data = new byte[(int)$byteArraySize]; + |UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); + |Platform.putLong($data, $baseOffset, $numElements); + |$unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSize; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSize); + | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); + | if ($childVariableName.isNullAt(z)) { + | $unsafeRow.setNullAt($valuePosition); + | } else { + | $unsafeRow.set$primitiveValueTypeName( + | $valuePosition, + | ${CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")} + | ); + | } + | $unsafeRow.setInt($indexPosition, z); + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + } + + private def genCodeForNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val numberOfElements = ctx.freshName("numElements") + val data = ctx.freshName("internalRowArray") + + val elementValue = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z") + val arguments = if (indexFirstValue) s"z, $elementValue" else s"$elementValue, z" + + s""" + |final int $numberOfElements = $childVariableName.numElements(); + |final Object[] $data = new Object[$numberOfElements]; + |for (int z = 0; z < $numberOfElements; z++) { + | $data[z] = new $rowClass(new Object[]{$arguments}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "zip_with_index" +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..78202cc857f8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -280,4 +281,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } + + test("Zip With Index") { + def r(values: Any*): InternalRow = create_row(values: _*) + val t = Literal.TrueLiteral + val f = Literal.FalseLiteral + + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 8, 4, 7), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(null, 4, null, 2), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai3 = 
Literal.create(Seq(1), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai5 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(ZipWithIndex(ai0, f), Seq(r(2, 0), r(8, 1), r(4, 2), r(7, 3))) + checkEvaluation(ZipWithIndex(ai1, f), Seq(r(null, 0), r(4, 1), r(null, 2), r(2, 3))) + checkEvaluation(ZipWithIndex(ai2, f), Seq(r(null, 0), r(null, 1), r(null, 2))) + checkEvaluation(ZipWithIndex(ai3, f), Seq(r(1, 0))) + checkEvaluation(ZipWithIndex(ai4, f), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, f), null) + + checkEvaluation(ZipWithIndex(ai0, t), Seq(r(0, 2), r(1, 8), r(2, 4), r(3, 7))) + checkEvaluation(ZipWithIndex(ai1, t), Seq(r(0, null), r(1, 4), r(2, null), r(3, 2))) + checkEvaluation(ZipWithIndex(ai2, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(ai3, t), Seq(r(0, 1))) + checkEvaluation(ZipWithIndex(ai4, t), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, t), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "y", "z"), ArrayType(StringType)) + val as1 = Literal.create(Seq(null, "x", null, "y"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as3 = Literal.create(Seq("a"), ArrayType(StringType)) + val as4 = Literal.create(Seq.empty, ArrayType(StringType)) + val as5 = Literal.create(null, ArrayType(StringType)) + val aas = Literal.create(Seq(Seq("e"), Seq("c", "d")), ArrayType(ArrayType(StringType))) + + checkEvaluation(ZipWithIndex(as0, f), Seq(r("b", 0), r("a", 1), r("y", 2), r("z", 3))) + checkEvaluation(ZipWithIndex(as1, f), Seq(r(null, 0), r("x", 1), r(null, 2), r("y", 3))) + checkEvaluation(ZipWithIndex(as2, f), Seq(r(null, 0), r(null, 1), r(null, 2))) + checkEvaluation(ZipWithIndex(as3, f), Seq(r("a", 0))) + checkEvaluation(ZipWithIndex(as4, f), Seq.empty) + checkEvaluation(ZipWithIndex(as5, f), null) + checkEvaluation(ZipWithIndex(aas, f), Seq(r(Seq("e"), 0), r(Seq("c", "d"), 1))) + + checkEvaluation(ZipWithIndex(as0, t), Seq(r(0, "b"), r(1, "a"), r(2, "y"), r(3, "z"))) + checkEvaluation(ZipWithIndex(as1, t), Seq(r(0, null), r(1, "x"), r(2, null), r(3, "y"))) + checkEvaluation(ZipWithIndex(as2, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(as3, t), Seq(r(0, "a"))) + checkEvaluation(ZipWithIndex(as4, t), Seq.empty) + checkEvaluation(ZipWithIndex(as5, t), null) + checkEvaluation(ZipWithIndex(aas, t), Seq(r(0, Seq("e")), r(1, Seq("c", "d")))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..e1423b149ec8a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: InternalRow, expected: InternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..93c9c7411abf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3340,6 +3340,17 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Transforms the input array by encapsulating elements into pairs + * with indexes indicating the order. + * + * @group collection_funcs + * @since 2.4.0 + */ + def zip_with_index(e: Column, indexFirst: Boolean = false): Column = withExpr { + ZipWithIndex(e.expr, Literal(indexFirst)) + } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..76f12b54c8e01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -691,6 +691,107 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("zip_with_index function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + val oneRowDF = Seq(("Spark", 3215, true)).toDF("s", "i", "b") + + // Test cases with primitive-type elements + val idf = Seq( + Seq(1, 9, 8, 7), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(zip_with_index('i)), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(zip_with_index('i)), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.select(zip_with_index('i, true)), + Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i)"), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i, true)"), + Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), false)"), + Seq(Row(Seq(Row(null, 0), Row(2, 1), Row(null, 2)))) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), true)"), + Seq(Row(Seq(Row(0, null), Row(1, 2), Row(2, null)))) + ) + + // Test cases with non-primitive-type elements + val sdf = Seq( + Seq("c", "a", "d", "b"), + Seq(null, "x", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(zip_with_index('s)), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(zip_with_index('s)), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.select(zip_with_index('s, true)), + Seq( + Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), + Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s)"), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 
0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s, true)"), + Seq( + Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), + Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), + Row(Seq.empty), + Row(null)) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(zip_with_index('s)) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), b)") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), 1)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 4b4e02b5e62759d49cc2b30a15bffb1b63326c5d Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sat, 21 Apr 2018 14:38:00 +0200 Subject: [PATCH 2/8] [SPARK-24042][SQL] Returning the python wrapper for reverse function back --- python/pyspark/sql/functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 080350a32f1b6..0231bba3b2c4d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2172,6 +2172,25 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + @since(2.4) def zip_with_index(col, indexFirst=False): """ From aa97cadcf113c79646adc151b4aa40b4d924ba59 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sun, 22 Apr 2018 00:12:04 +0200 Subject: [PATCH 3/8] [SPARK-24042][SQL] Small refactoring after review + fixing failing test --- python/pyspark/sql/functions.py | 2 +- .../catalyst/expressions/collectionOperations.scala | 13 +++++++------ .../catalyst/expressions/ExpressionEvalHelper.scala | 2 +- .../main/scala/org/apache/spark/sql/functions.scala | 11 ++++++++++- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0231bba3b2c4d..69d093c896060 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2204,7 +2204,7 @@ def zip_with_index(col, indexFirst=False): [Row(r=[[value=2, index=0], [value=5, index=1], [value=3, index=2]]), Row(r=[])] >>> df.select(zip_with_index(df.data, indexFirst=True).alias('r')).collect() [Row(r=[[index=0, value=2], [index=1, value=5], [index=2, value=3]]), Row(r=[])] - """ + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1f1c2b6e7a2fd..a75a5ccfe7b24 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -899,12 +899,13 @@ case class Concat(children: Seq[Expression]) extends Expression { [(0,"d"),(1,"a"),(2,null),(3,"b")] """, since = "2.4.0") +// scalastyle:on line.size.limit case class ZipWithIndex(child: Expression, indexFirst: Expression) extends UnaryExpression with ExpectsInputTypes { def this(e: Expression) = this(e, Literal.FalseLiteral) - val indexFirstValue: Boolean = indexFirst match { + private val idxFirst: Boolean = indexFirst match { case Literal(v: Boolean, BooleanType) => v case _ => throw new AnalysisException("The second argument has to be a boolean constant.") } @@ -919,7 +920,7 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) val elementField = StructField("value", childArrayType.elementType, childArrayType.containsNull) val indexField = StructField("index", IntegerType, false) - val fields = if (indexFirstValue) Seq(indexField, elementField) else Seq(elementField, indexField) + val fields = if (idxFirst) Seq(indexField, elementField) else Seq(elementField, indexField) ArrayType(StructType(fields), false) } @@ -927,7 +928,7 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) override protected def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType) - val makeStruct = (v: Any, i: Int) => if (indexFirstValue) InternalRow(i, v) else InternalRow(v, i) + val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v) else InternalRow(v, i) val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i)} new GenericArrayData(resultData) @@ -960,8 +961,8 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) val baseOffset = Platform.BYTE_ARRAY_OFFSET val longSize = LongType.defaultSize val primitiveValueTypeName = CodeGenerator.primitiveTypeName(childArrayType.elementType) - val valuePosition = if (indexFirstValue) "1" else "0" - val indexPosition = if (indexFirstValue) "0" else "1" + val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else ("0", "1") + s""" |final int $numElements = $childVariableName.numElements(); |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; @@ -1005,7 +1006,7 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) val data = ctx.freshName("internalRowArray") val elementValue = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z") - val arguments = if (indexFirstValue) s"z, $elementValue" else s"$elementValue, z" + val arguments = if (idxFirst) s"z, $elementValue" else s"$elementValue, z" s""" |final int $numberOfElements = $childVariableName.numElements(); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index e1423b149ec8a..e739f1a6b4cfd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,7 +98,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result 
- case (result: InternalRow, expected: InternalRow) => + case (result: UnsafeRow, expected: GenericInternalRow) => val structType = exprDataType.asInstanceOf[StructType] result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 93c9c7411abf6..4473d8b0149f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3347,7 +3347,16 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def zip_with_index(e: Column, indexFirst: Boolean = false): Column = withExpr { + def zip_with_index(e: Column): Column = withExpr { ZipWithIndex(e.expr, Literal.FalseLiteral) } + + /** + * Transforms the input array by encapsulating elements into pairs + * with indexes indicating the order. + * + * @group collection_funcs + * @since 2.4.0 + */ + def zip_with_index(e: Column, indexFirst: Boolean): Column = withExpr { ZipWithIndex(e.expr, Literal(indexFirst)) } From 1f11f73c0df64f4559e264cc0653f012b8ce6088 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sun, 22 Apr 2018 10:25:00 +0200 Subject: [PATCH 4/8] [SPARK-24042][SQL] Fixing scala style --- .../main/scala/org/apache/spark/sql/functions.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4473d8b0149f0..d9f9f1895a228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3350,12 +3350,12 @@ object functions { def zip_with_index(e: Column): Column = withExpr { ZipWithIndex(e.expr, Literal.FalseLiteral) } /** - * Transforms the input array by encapsulating elements into pairs - * with indexes indicating the order. - * - * @group collection_funcs - * @since 2.4.0 - */ + * Transforms the input array by encapsulating elements into pairs + * with indexes indicating the order. + * + * @group collection_funcs + * @since 2.4.0 + */ def zip_with_index(e: Column, indexFirst: Boolean): Column = withExpr { ZipWithIndex(e.expr, Literal(indexFirst)) } From 9dac7a40c284af87143ed9616662760801368436 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 24 Apr 2018 23:02:32 +0200 Subject: [PATCH 5/8] [SPARK-24042][SQL] Fixing PySpark test + refactoring according to the feedback from code review. --- python/pyspark/sql/functions.py | 12 ++-- .../expressions/collectionOperations.scala | 70 ++++++++++++------- .../CollectionExpressionsSuite.scala | 56 +++++++-------- .../org/apache/spark/sql/functions.scala | 10 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 60 +++++++++++----- 5 files changed, 127 insertions(+), 81 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 69d093c896060..262ada354043d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2192,7 +2192,7 @@ def reverse(col): @since(2.4) -def zip_with_index(col, indexFirst=False): +def zip_with_index(col, indexFirst=False, startFromZero=False): """ Collection function: transforms the input array by encapsulating elements into pairs with indexes indicating the order. 
@@ -2201,12 +2201,14 @@ def zip_with_index(col, indexFirst=False): >>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['data']) >>> df.select(zip_with_index(df.data).alias('r')).collect() - [Row(r=[[value=2, index=0], [value=5, index=1], [value=3, index=2]]), Row(r=[])] - >>> df.select(zip_with_index(df.data, indexFirst=True).alias('r')).collect() - [Row(r=[[index=0, value=2], [index=1, value=5], [index=2, value=3]]), Row(r=[])] + [Row(r=[Row(value=2, index=1), Row(value=5, index=2), Row(value=3, index=3)]), Row(r=[])] + >>> df.select(zip_with_index(df.data, indexFirst=True, startFromZero=False).alias('r')).collect() + [Row(r=[Row(index=1, value=2), Row(index=2, value=5), Row(index=3, value=3)]), Row(r=[])] + >>> df.select(zip_with_index(df.data, indexFirst=True, startFromZero=True).alias('r')).collect() + [Row(r=[Row(index=0, value=2), Row(index=1, value=5), Row(index=2, value=3)]), Row(r=[])] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst)) + return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst, startFromZero)) @since(2.3) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a75a5ccfe7b24..f37bac205d3d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -886,28 +886,38 @@ case class Concat(children: Seq[Expression]) extends Expression { } /** - * Returns the maximum value in the array. + * Transforms an array by assigning an order number to each element. 
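+ * For example, the default flags map array("d", "a") to
+ * array(("d", 1), ("a", 2)); indexFirst swaps the two struct fields and
+ * startFromZero makes the numbering zero-based (see the examples below).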
*/ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(array[, indexFirst]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.", + usage = "_FUNC_(array[, indexFirst, startFromZero]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.", examples = """ Examples: > SELECT _FUNC_(array("d", "a", null, "b")); - [("d",0),("a",1),(null,2),("b",3)] - > SELECT _FUNC_(array("d", "a", null, "b"), true); + [("d",1),("a",2),(null,3),("b",4)] + > SELECT _FUNC_(array("d", "a", null, "b"), true, false); + [(1,"d"),(2,"a"),(3,null),(4,"b")] + > SELECT _FUNC_(array("d", "a", null, "b"), true, true); [(0,"d"),(1,"a"),(2,null),(3,"b")] """, since = "2.4.0") // scalastyle:on line.size.limit -case class ZipWithIndex(child: Expression, indexFirst: Expression) +case class ZipWithIndex(child: Expression, indexFirst: Expression, startFromZero: Expression) extends UnaryExpression with ExpectsInputTypes { - def this(e: Expression) = this(e, Literal.FalseLiteral) + def this(e: Expression) = this(e, Literal.FalseLiteral, Literal.FalseLiteral) - private val idxFirst: Boolean = indexFirst match { + def exprToFlag(e: Expression, order: String): Boolean = e match { case Literal(v: Boolean, BooleanType) => v - case _ => throw new AnalysisException("The second argument has to be a boolean constant.") + case _ => throw new AnalysisException(s"The $order argument has to be a boolean constant.") + } + + private val idxFirst: Boolean = exprToFlag(indexFirst, "second") + + private val (idxShift, idxGen): (Int, String) = if (exprToFlag(startFromZero, "third")) { + (0, "z") + } else { + (1, "z + 1") } private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH @@ -929,26 +939,31 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType) val makeStruct = (v: Any, i: Int) => if (idxFirst) InternalRow(i, v) else InternalRow(v, i) - val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i)} + val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i + idxShift)} new GenericArrayData(resultData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) { - genCodeForPrimitiveElements(ctx, c, ev.value) + val numElements = ctx.freshName("numElements") + val code = if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) { + genCodeForPrimitiveElements(ctx, c, ev.value, numElements) } else { - genCodeForNonPrimitiveElements(ctx, c, ev.value) + genCodeForAnyElements(ctx, c, ev.value, numElements) } + s""" + |final int $numElements = $c.numElements(); + |$code + """.stripMargin }) } private def genCodeForPrimitiveElements( ctx: CodegenContext, childVariableName: String, - arrayData: String): String = { - val numElements = ctx.freshName("numElements") + arrayData: String, + numElements: String): String = { val byteArraySize = ctx.freshName("byteArraySize") val data = ctx.freshName("byteArray") val unsafeRow = ctx.freshName("unsafeRow") @@ -964,14 +979,11 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) val (valuePosition, indexPosition) = if (idxFirst) ("1", "0") else ("0", "1") s""" - |final int $numElements = $childVariableName.numElements(); |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; |final 
long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize); |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; |if ($byteArraySize > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to zip array with index due to exceeding" + - | " the limit $MAX_ARRAY_LENGTH bytes for UnsafeArrayData. " + $byteArraySize + - | " bytes of data are required for performing the operation with the given array."); + | ${genCodeForAnyElements(ctx, childVariableName, arrayData, numElements)} |} |final byte[] $data = new byte[(int)$byteArraySize]; |UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); @@ -990,28 +1002,32 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression) | ${CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")} | ); | } - | $unsafeRow.setInt($indexPosition, z); + | $unsafeRow.setInt($indexPosition, $idxGen); |} |$arrayData = $unsafeArrayData; """.stripMargin } - private def genCodeForNonPrimitiveElements( + private def genCodeForAnyElements( ctx: CodegenContext, childVariableName: String, - arrayData: String): String = { + arrayData: String, + numElements: String): String = { val genericArrayClass = classOf[GenericArrayData].getName val rowClass = classOf[GenericInternalRow].getName - val numberOfElements = ctx.freshName("numElements") val data = ctx.freshName("internalRowArray") - val elementValue = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z") - val arguments = if (idxFirst) s"z, $elementValue" else s"$elementValue, z" + val getElement = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z") + val elementValue = if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) { + s"$childVariableName.isNullAt(z) ? 
null : (Object)$getElement" + } else { + getElement + } + val arguments = if (idxFirst) s"$idxGen, $elementValue" else s"$elementValue, $idxGen" s""" - |final int $numberOfElements = $childVariableName.numElements(); - |final Object[] $data = new Object[$numberOfElements]; - |for (int z = 0; z < $numberOfElements; z++) { + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { | $data[z] = new $rowClass(new Object[]{$arguments}); |} |$arrayData = new $genericArrayClass($data); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 78202cc857f8b..3b53e51330ef4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -291,23 +291,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val ai0 = Literal.create(Seq(2, 8, 4, 7), ArrayType(IntegerType)) val ai1 = Literal.create(Seq(null, 4, null, 2), ArrayType(IntegerType)) val ai2 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) - val ai3 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2), ArrayType(IntegerType)) val ai4 = Literal.create(Seq.empty, ArrayType(IntegerType)) val ai5 = Literal.create(null, ArrayType(IntegerType)) - checkEvaluation(ZipWithIndex(ai0, f), Seq(r(2, 0), r(8, 1), r(4, 2), r(7, 3))) - checkEvaluation(ZipWithIndex(ai1, f), Seq(r(null, 0), r(4, 1), r(null, 2), r(2, 3))) - checkEvaluation(ZipWithIndex(ai2, f), Seq(r(null, 0), r(null, 1), r(null, 2))) - checkEvaluation(ZipWithIndex(ai3, f), Seq(r(1, 0))) - checkEvaluation(ZipWithIndex(ai4, f), Seq.empty) - checkEvaluation(ZipWithIndex(ai5, f), null) + checkEvaluation(ZipWithIndex(ai0, f, f), Seq(r(2, 1), r(8, 2), r(4, 3), r(7, 4))) + checkEvaluation(ZipWithIndex(ai1, f, f), Seq(r(null, 1), r(4, 2), r(null, 3), r(2, 4))) + checkEvaluation(ZipWithIndex(ai2, f, f), Seq(r(null, 1), r(null, 2), r(null, 3))) + checkEvaluation(ZipWithIndex(ai3, f, f), Seq(r(2, 1))) + checkEvaluation(ZipWithIndex(ai4, f, f), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, f, f), null) - checkEvaluation(ZipWithIndex(ai0, t), Seq(r(0, 2), r(1, 8), r(2, 4), r(3, 7))) - checkEvaluation(ZipWithIndex(ai1, t), Seq(r(0, null), r(1, 4), r(2, null), r(3, 2))) - checkEvaluation(ZipWithIndex(ai2, t), Seq(r(0, null), r(1, null), r(2, null))) - checkEvaluation(ZipWithIndex(ai3, t), Seq(r(0, 1))) - checkEvaluation(ZipWithIndex(ai4, t), Seq.empty) - checkEvaluation(ZipWithIndex(ai5, t), null) + checkEvaluation(ZipWithIndex(ai0, t, t), Seq(r(0, 2), r(1, 8), r(2, 4), r(3, 7))) + checkEvaluation(ZipWithIndex(ai1, t, t), Seq(r(0, null), r(1, 4), r(2, null), r(3, 2))) + checkEvaluation(ZipWithIndex(ai2, t, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(ai3, t, t), Seq(r(0, 2))) + checkEvaluation(ZipWithIndex(ai4, t, t), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, t, t), null) // Non-primitive-type elements val as0 = Literal.create(Seq("b", "a", "y", "z"), ArrayType(StringType)) @@ -318,20 +318,20 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val as5 = Literal.create(null, ArrayType(StringType)) val aas = Literal.create(Seq(Seq("e"), Seq("c", "d")), ArrayType(ArrayType(StringType))) - 
checkEvaluation(ZipWithIndex(as0, f), Seq(r("b", 0), r("a", 1), r("y", 2), r("z", 3))) - checkEvaluation(ZipWithIndex(as1, f), Seq(r(null, 0), r("x", 1), r(null, 2), r("y", 3))) - checkEvaluation(ZipWithIndex(as2, f), Seq(r(null, 0), r(null, 1), r(null, 2))) - checkEvaluation(ZipWithIndex(as3, f), Seq(r("a", 0))) - checkEvaluation(ZipWithIndex(as4, f), Seq.empty) - checkEvaluation(ZipWithIndex(as5, f), null) - checkEvaluation(ZipWithIndex(aas, f), Seq(r(Seq("e"), 0), r(Seq("c", "d"), 1))) - - checkEvaluation(ZipWithIndex(as0, t), Seq(r(0, "b"), r(1, "a"), r(2, "y"), r(3, "z"))) - checkEvaluation(ZipWithIndex(as1, t), Seq(r(0, null), r(1, "x"), r(2, null), r(3, "y"))) - checkEvaluation(ZipWithIndex(as2, t), Seq(r(0, null), r(1, null), r(2, null))) - checkEvaluation(ZipWithIndex(as3, t), Seq(r(0, "a"))) - checkEvaluation(ZipWithIndex(as4, t), Seq.empty) - checkEvaluation(ZipWithIndex(as5, t), null) - checkEvaluation(ZipWithIndex(aas, t), Seq(r(0, Seq("e")), r(1, Seq("c", "d")))) + checkEvaluation(ZipWithIndex(as0, f, f), Seq(r("b", 1), r("a", 2), r("y", 3), r("z", 4))) + checkEvaluation(ZipWithIndex(as1, f, f), Seq(r(null, 1), r("x", 2), r(null, 3), r("y", 4))) + checkEvaluation(ZipWithIndex(as2, f, f), Seq(r(null, 1), r(null, 2), r(null, 3))) + checkEvaluation(ZipWithIndex(as3, f, f), Seq(r("a", 1))) + checkEvaluation(ZipWithIndex(as4, f, f), Seq.empty) + checkEvaluation(ZipWithIndex(as5, f, f), null) + checkEvaluation(ZipWithIndex(aas, f, f), Seq(r(Seq("e"), 1), r(Seq("c", "d"), 2))) + + checkEvaluation(ZipWithIndex(as0, t, t), Seq(r(0, "b"), r(1, "a"), r(2, "y"), r(3, "z"))) + checkEvaluation(ZipWithIndex(as1, t, t), Seq(r(0, null), r(1, "x"), r(2, null), r(3, "y"))) + checkEvaluation(ZipWithIndex(as2, t, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(as3, t, t), Seq(r(0, "a"))) + checkEvaluation(ZipWithIndex(as4, t, t), Seq.empty) + checkEvaluation(ZipWithIndex(as5, t, t), null) + checkEvaluation(ZipWithIndex(aas, t, t), Seq(r(0, Seq("e")), r(1, Seq("c", "d")))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d9f9f1895a228..0399bdb4f5a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3344,10 +3344,14 @@ object functions { * Transforms the input array by encapsulating elements into pairs * with indexes indicating the order. * + * Note: The array index is placed second and starts from one. 
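+   * For example, an input array("a", "b") yields array(("a", 1), ("b", 2)).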
+ * * @group collection_funcs * @since 2.4.0 */ - def zip_with_index(e: Column): Column = withExpr { ZipWithIndex(e.expr, Literal.FalseLiteral) } + def zip_with_index(e: Column): Column = withExpr { + ZipWithIndex(e.expr, Literal.FalseLiteral, Literal.FalseLiteral) + } /** * Transforms the input array by encapsulating elements into pairs @@ -3356,8 +3360,8 @@ object functions { * @group collection_funcs * @since 2.4.0 */ - def zip_with_index(e: Column, indexFirst: Boolean): Column = withExpr { - ZipWithIndex(e.expr, Literal(indexFirst)) + def zip_with_index(e: Column, indexFirst: Boolean, startFromZero: Boolean): Column = withExpr { + ZipWithIndex(e.expr, Literal(indexFirst), Literal(startFromZero)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 76f12b54c8e01..ddd09e798bcf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -704,30 +704,38 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( idf.select(zip_with_index('i)), - Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + Seq(Row(Seq(Row(1, 1), Row(9, 2), Row(8, 3), Row(7, 4))), Row(Seq.empty), Row(null)) ) checkAnswer( idf.filter(dummyFilter('i)).select(zip_with_index('i)), - Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + Seq(Row(Seq(Row(1, 1), Row(9, 2), Row(8, 3), Row(7, 4))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.select(zip_with_index('i, true, false)), + Seq(Row(Seq(Row(1, 1), Row(2, 9), Row(3, 8), Row(4, 7))), Row(Seq.empty), Row(null)) ) checkAnswer( - idf.select(zip_with_index('i, true)), + idf.select(zip_with_index('i, true, true)), Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) ) checkAnswer( idf.selectExpr("zip_with_index(i)"), - Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + Seq(Row(Seq(Row(1, 1), Row(9, 2), Row(8, 3), Row(7, 4))), Row(Seq.empty), Row(null)) ) checkAnswer( - idf.selectExpr("zip_with_index(i, true)"), - Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) + idf.selectExpr("zip_with_index(i, true, false)"), + Seq(Row(Seq(Row(1, 1), Row(2, 9), Row(3, 8), Row(4, 7))), Row(Seq.empty), Row(null)) ) checkAnswer( - oneRowDF.selectExpr("zip_with_index(array(null, 2, null), false)"), + idf.selectExpr("zip_with_index(i, false, true)"), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), false, true)"), Seq(Row(Seq(Row(null, 0), Row(2, 1), Row(null, 2)))) ) checkAnswer( - oneRowDF.selectExpr("zip_with_index(array(null, 2, null), true)"), + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), true, true)"), Seq(Row(Seq(Row(0, null), Row(1, 2), Row(2, null)))) ) @@ -742,21 +750,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( sdf.select(zip_with_index('s)), Seq( - Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), - Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq(Row("c", 1), Row("a", 2), Row("d", 3), Row("b", 4))), + Row(Seq(Row(null, 1), Row("x", 2), Row(null, 3))), Row(Seq.empty), Row(null)) ) checkAnswer( sdf.filter(dummyFilter('s)).select(zip_with_index('s)), 
Seq( - Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), - Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq(Row("c", 1), Row("a", 2), Row("d", 3), Row("b", 4))), + Row(Seq(Row(null, 1), Row("x", 2), Row(null, 3))), Row(Seq.empty), Row(null)) ) checkAnswer( - sdf.select(zip_with_index('s, true)), + sdf.select(zip_with_index('s, true, false)), + Seq( + Row(Seq(Row(1, "c"), Row(2, "a"), Row(3, "d"), Row(4, "b"))), + Row(Seq(Row(1, null), Row(2, "x"), Row(3, null))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.select(zip_with_index('s, true, true)), Seq( Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), @@ -765,6 +781,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) checkAnswer( sdf.selectExpr("zip_with_index(s)"), + Seq( + Row(Seq(Row("c", 1), Row("a", 2), Row("d", 3), Row("b", 4))), + Row(Seq(Row(null, 1), Row("x", 2), Row(null, 3))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s, false, true)"), Seq( Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), @@ -772,10 +796,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null)) ) checkAnswer( - sdf.selectExpr("zip_with_index(s, true)"), + sdf.selectExpr("zip_with_index(s, true, false)"), Seq( - Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), - Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), + Row(Seq(Row(1, "c"), Row(2, "a"), Row(3, "d"), Row(4, "b"))), + Row(Seq(Row(1, null), Row(2, "x"), Row(3, null))), Row(Seq.empty), Row(null)) ) @@ -785,10 +809,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { oneRowDF.select(zip_with_index('s)) } intercept[AnalysisException] { - oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), b)") + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), b, false)") } intercept[AnalysisException] { - oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), 1)") + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), true, 1)") } } From 8692c4d095eb15e4c5127241f8f94dcad6944091 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 24 Apr 2018 23:11:51 +0200 Subject: [PATCH 6/8] [SPARK-24042][SQL] Fix of primitive-type codeGen. 
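
Before this fix, the generated code for primitive element types emitted
the generic fallback when $byteArraySize exceeded the limit and then
fell through into the UnsafeArrayData path anyway, overwriting the
result and still allocating the oversized byte array. Wrapping the
unsafe path in an else branch makes the two paths mutually exclusive.

For reviewers, a standalone Scala sketch of the size math the guard
checks (the helpers are reconstructed from the constants named in the
generated code; treat the exact formulas as illustrative, not as the
authoritative Spark internals):

    // Each struct: an 8-byte null bitset covering its 2 fields plus two
    // word-aligned 8-byte slots, i.e. calculateBitSetWidthInBytes(2) + 16.
    val longSize = 8
    val structSize = 8 + longSize * 2                         // 24 bytes
    // The array adds a header plus one 8-byte offset-and-length word per
    // element, so four primitive elements need:
    val numElements = 4
    val headerInBytes = 8 + ((numElements + 63) / 64) * 8     // 16 bytes
    val byteArraySize = headerInBytes + numElements * (longSize + structSize)
    println(byteArraySize)                                    // 144 bytes
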
--- .../expressions/collectionOperations.scala | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f37bac205d3d4..67ddaa524740f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -984,27 +984,28 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression, startFromZero |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; |if ($byteArraySize > $MAX_ARRAY_LENGTH) { | ${genCodeForAnyElements(ctx, childVariableName, arrayData, numElements)} - |} - |final byte[] $data = new byte[(int)$byteArraySize]; - |UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); - |Platform.putLong($data, $baseOffset, $numElements); - |$unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); - |UnsafeRow $unsafeRow = new UnsafeRow(2); - |for (int z = 0; z < $numElements; z++) { - | long offset = $structsOffset + z * $structSize; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSize); - | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); - | if ($childVariableName.isNullAt(z)) { - | $unsafeRow.setNullAt($valuePosition); - | } else { - | $unsafeRow.set$primitiveValueTypeName( - | $valuePosition, - | ${CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")} - | ); + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); + | Platform.putLong($data, $baseOffset, $numElements); + | $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); + | UnsafeRow $unsafeRow = new UnsafeRow(2); + | for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSize; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSize); + | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); + | if ($childVariableName.isNullAt(z)) { + | $unsafeRow.setNullAt($valuePosition); + | } else { + | $unsafeRow.set$primitiveValueTypeName( + | $valuePosition, + | ${CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")} + | ); + | } + | $unsafeRow.setInt($indexPosition, $idxGen); | } - | $unsafeRow.setInt($indexPosition, $idxGen); + | $arrayData = $unsafeArrayData; |} - |$arrayData = $unsafeArrayData; """.stripMargin } From 17010b2a38bdcaf1df79e3f3f1a110d332ea3910 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 24 Apr 2018 23:17:20 +0200 Subject: [PATCH 7/8] [SPARK-24042][SQL] Fixing python code style. 
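
The rename from `data` to `d` only shortens doctest lines that ran over
the Python style checker's line-length limit; behavior is unchanged.
For reviewers comparing the two APIs, the same examples in Scala form
(a sketch: it assumes an active SparkSession named `spark` and the
zip_with_index overloads added earlier in this series):

    import org.apache.spark.sql.functions.{col, zip_with_index}
    import spark.implicits._

    val df = Seq(Seq(2, 5, 3), Seq.empty[Int]).toDF("d")
    df.select(zip_with_index(col("d"))).collect()
    // [[2,1], [5,2], [3,3]] and []  -- value first, index starts from one
    df.select(zip_with_index(col("d"), indexFirst = true, startFromZero = true)).collect()
    // [[0,2], [1,5], [2,3]] and []  -- index first, numbering from zero
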
---
 python/pyspark/sql/functions.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 262ada354043d..f6a7afd3f6ee0 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2199,12 +2199,12 @@ def zip_with_index(col, indexFirst=False, startFromZero=False):

     :param col: name of column or expression

-    >>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['data'])
-    >>> df.select(zip_with_index(df.data).alias('r')).collect()
+    >>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['d'])
+    >>> df.select(zip_with_index(df.d).alias('r')).collect()
     [Row(r=[Row(value=2, index=1), Row(value=5, index=2), Row(value=3, index=3)]), Row(r=[])]
-    >>> df.select(zip_with_index(df.data, indexFirst=True, startFromZero=False).alias('r')).collect()
+    >>> df.select(zip_with_index(df.d, indexFirst=True, startFromZero=False).alias('r')).collect()
     [Row(r=[Row(index=1, value=2), Row(index=2, value=5), Row(index=3, value=3)]), Row(r=[])]
-    >>> df.select(zip_with_index(df.data, indexFirst=True, startFromZero=True).alias('r')).collect()
+    >>> df.select(zip_with_index(df.d, indexFirst=True, startFromZero=True).alias('r')).collect()
     [Row(r=[Row(index=0, value=2), Row(index=1, value=5), Row(index=2, value=3)]), Row(r=[])]
     """
     sc = SparkContext._active_spark_context

From da270c71c517dae273c8d94c6eeba4b001e98c72 Mon Sep 17 00:00:00 2001
From: Marek Novotny
Date: Wed, 25 Apr 2018 11:39:32 +0200
Subject: [PATCH 8/8] [SPARK-24042][SQL] Optimizing generated code for arrays
 that don't contain nulls.

---
 .../sql/catalyst/expressions/collectionOperations.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 67ddaa524740f..5bbfd23c53345 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -994,7 +994,7 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression, startFromZero
       |    long offset = $structsOffset + z * $structSize;
       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
       |    $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
-      |    if ($childVariableName.isNullAt(z)) {
+      |    if (${childArrayType.containsNull} && $childVariableName.isNullAt(z)) {
       |      $unsafeRow.setNullAt($valuePosition);
       |    } else {
       |      $unsafeRow.set$primitiveValueTypeName(
@@ -1019,7 +1019,8 @@ case class ZipWithIndex(child: Expression, indexFirst: Expression, startFromZero
     val data = ctx.freshName("internalRowArray")

     val getElement = CodeGenerator.getValue(childVariableName, childArrayType.elementType, "z")
-    val elementValue = if (CodeGenerator.isPrimitiveType(childArrayType.elementType)) {
+    val isPrimitiveType = CodeGenerator.isPrimitiveType(childArrayType.elementType)
+    val elementValue = if (childArrayType.containsNull && isPrimitiveType) {
       s"$childVariableName.isNullAt(z) ? null : (Object)$getElement"
     } else {
       getElement