From 30e91826d5c1c268689f8e683727a7566c9c17c5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 17 Aug 2015 13:40:50 -0700 Subject: [PATCH 1/2] fix bug in generated unsafe projection when there is binary in ArrayData --- .../codegen/GenerateUnsafeProjection.scala | 12 +++++++--- .../codegen/GeneratedProjectionSuite.scala | 24 ++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b2fb913850794..b570fe86db1aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -224,7 +224,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // go through the input array to calculate how many bytes we need. val calculateNumBytes = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => + case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" @@ -237,6 +237,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => val writer = getWriter(elementType) val elementSize = s"$writer.getSize($elements[$index])" + // TODO(davies): avoid the copy val unsafeType = elementType match { case _: StructType => "UnsafeRow" case _: ArrayType => "UnsafeArrayData" @@ -249,8 +250,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => "" } + val newElements = if (elementType == BinaryType) { + s"new byte[$numElements][]" + } else { + s"new $unsafeType[$numElements]" + } s""" - final $unsafeType[] $elements = new $unsafeType[$numElements]; + final $unsafeType[] $elements = $newElements; for (int $index = 0; $index < $numElements; $index++) { ${convertedElement.code} if (!${convertedElement.isNull}) { @@ -262,7 +268,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeElement = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => + case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 8c7ee8720f7bb..b719cea2460ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -79,4 +79,26 @@ class GeneratedProjectionSuite extends SparkFunSuite { val row2 = mutableProj(result) assert(result === row2) } + + test("generated unsafe projection with array of binary") { + val row = InternalRow( + Array[Byte](1, 2), + new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4)))) + val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType] + + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow: UnsafeRow = unsafeProj(row) + println(s"unsafe row is $unsafeRow") + assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2))) + println(s"array is ${unsafeRow.getArray(1).getBinary(0)}") + unsafeRow.getArray(1).getBinary(0).foreach(println) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2))) + assert(unsafeRow.getArray(1).isNullAt(1)) + assert(unsafeRow.getArray(1).getBinary(1) === null) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) + + val safeProj = FromUnsafeProjection(fields) + val row2 = safeProj(unsafeRow) + assert(row2 === row) + } } From 31141d1fd2e310a4466cb0b12e5597f916bc6d00 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 17 Aug 2015 14:43:21 -0700 Subject: [PATCH 2/2] Update GeneratedProjectionSuite.scala --- .../expressions/codegen/GeneratedProjectionSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index b719cea2460ad..098944a9f4fc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -88,10 +88,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { val unsafeProj = UnsafeProjection.create(fields) val unsafeRow: UnsafeRow = unsafeProj(row) - println(s"unsafe row is $unsafeRow") assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2))) - println(s"array is ${unsafeRow.getArray(1).getBinary(0)}") - unsafeRow.getArray(1).getBinary(0).foreach(println) assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2))) assert(unsafeRow.getArray(1).isNullAt(1)) assert(unsafeRow.getArray(1).getBinary(1) === null)