From 9e7b91db49a2949bb17d85beada70ee1f2117a14 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 15 Jun 2016 01:38:02 +0900 Subject: [PATCH 1/5] eliminate unreachable code blocks --- .../expressions/codegen/CodeGenerator.scala | 7 ++ .../codegen/GenerateUnsafeProjection.scala | 95 +++++++++++-------- .../expressions/complexTypeCreator.scala | 85 +++++++++++------ .../spark/sql/DataFrameComplexTypeSuite.scala | 32 +++++++ 4 files changed, 153 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ff97cd321199a..79a5568511486 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -108,6 +108,13 @@ class CodegenContext { */ var copyResult: Boolean = false + /** + * Whether should we write complex type data only to generic structure + * + * If we write complex type data only to a generic structure at projection, set true to this var + */ + var genericWriteBuffer: Boolean = false + /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a * 3-tuple: java type, variable name, code to init it. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a6087..61a840af35a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -56,13 +56,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ExprCode(code, isNull, fieldName) } - s""" - if ($input instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} - } else { - ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} - } - """ + if (ctx.genericWriteBuffer) { + s"${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}" + } else { + s""" + if ($input instanceof UnsafeRow) { + ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} + } else { + ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} + } + """ + } } private def writeExpressionsToBuffer( @@ -222,23 +226,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - s""" - if ($input instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} - } else { - final int $numElements = $input.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); - - for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { - $arrayWriter.setNullAt($index); - } else { - final $jt $element = ${ctx.getValue(input, et, index)}; - $writeElement - } + + val writeSafeArrayToBuffer = s""" + final int $numElements = $input.numElements(); + $arrayWriter.initialize($bufferHolder, 
$numElements, $fixedElementSize); + + for (int $index = 0; $index < $numElements; $index++) { + if ($input.isNullAt($index)) { + $arrayWriter.setNullAt($index); + } else { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement } } - """ + """ + if (ctx.genericWriteBuffer) { + writeSafeArrayToBuffer + } else { + s""" + if ($input instanceof UnsafeArrayData) { + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} + } else { + ${writeSafeArrayToBuffer} + } + """ + } } // TODO: if the nullability of value element is correct, we can use it to save null check. @@ -254,27 +266,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Writes out unsafe map according to the format described in `UnsafeMapData`. - s""" - if ($input instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} - } else { - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); + val writeSafeMapToBuffer = s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); - // preserve 4 bytes to write the key array numBytes later. - $bufferHolder.grow(4); - $bufferHolder.cursor += 4; + // preserve 4 bytes to write the key array numBytes later. + $bufferHolder.grow(4); + $bufferHolder.cursor += 4; - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + // Remember the current cursor so that we can write numBytes of key array later. + final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} - // Write the numBytes of key array into the first 4 bytes. - Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + // Write the numBytes of key array into the first 4 bytes. 
+ Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} - } + ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} """ + if (ctx.genericWriteBuffer) { + writeSafeMapToBuffer + } else { + s""" + if ($input instanceof UnsafeMapData) { + ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} + } else { + ${writeSafeMapToBuffer} + } + """ + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d603d3c73ecbc..bc04e1568fd45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -50,6 +50,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName + ctx.genericWriteBuffer = true val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"this.$values = null;") @@ -60,18 +61,24 @@ case class CreateArray(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; + eval.code + ( + if (eval.isNull == "false") { + s"\n$values[$i] = ${eval.value};" } else { - $values[$i] = ${eval.value}; - } - """ + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + } + """ + }) }) + s""" final ArrayData ${ev.value} = new $arrayClass($values); this.$values = null; - """) + """, + isNull = "false") } override def prettyName: String = "array" @@ -124,6 +131,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { 
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val mapClass = classOf[ArrayBasedMapData].getName + ctx.genericWriteBuffer = true val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;") @@ -152,20 +160,25 @@ case class CreateMap(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, values.zipWithIndex.map { case (value, i) => val eval = value.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - $valueArray[$i] = null; + eval.code + ( + if (eval.isNull == "false") { + s"\n$valueArray[$i] = ${eval.value};" } else { - $valueArray[$i] = ${eval.value}; - } - """ + s""" + if (${eval.isNull}) { + $valueArray[$i] = null; + } else { + $valueArray[$i] = ${eval.value}; + } + """ + }) }) + s""" final MapData ${ev.value} = new $mapClass($keyData, $valueData); this.$keyArray = null; this.$valueArray = null; - """) + """, + isNull = "false") } override def prettyName: String = "map" @@ -200,6 +213,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName + ctx.genericWriteBuffer = true val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"this.$values = null;") @@ -210,17 +224,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; + eval.code + ( + if (eval.isNull == "false") { + s"\n$values[$i] = ${eval.value};" } else { - $values[$i] = ${eval.value}; - }""" + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + } + """ + }) }) + s""" final InternalRow ${ev.value} = new $rowClass($values); this.$values = null; - """) 
+ """, + isNull = "false") } override def prettyName: String = "struct" @@ -290,6 +311,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName + ctx.genericWriteBuffer = true val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"this.$values = null;") @@ -300,17 +322,24 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" + eval.code + ( + if (eval.isNull == "false") { + s"\n$values[$i] = ${eval.value};" + } else { + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + } + """ + }) }) + s""" final InternalRow ${ev.value} = new $rowClass($values); this.$values = null; - """) + """, + isNull = "false") } override def prettyName: String = "named_struct" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 72f676e6225ee..2bc1be3a43c71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -26,6 +26,38 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ + test("primitive type on array") { + val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") + df.selectExpr("Array(v + 2, v + 3)").collect + } + + test("array on array") { + val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") + df.selectExpr("Array(Array(v, v + 1, v + 2)," + + "null," + + "Array(v, v - 1, v - 2))").collect + } + + 
test("primitive type on map") { + val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") + df.selectExpr("map(v + 3, v + 4)").collect + } + + test("map on map") { + val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") + df.selectExpr("map(map(v, v + 3), map(v, v + 4))").collect + } + + test("primitive type on struct") { + val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") + df.selectExpr("struct(v + 3, v + 4)").collect + } + + test("struct on struct") { + val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") + df.selectExpr("struct(struct(v + 3), null, struct(v + 4))").collect + } + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") From 4458f4f9c0b9d08937b1e2f4fe028cdd2969fc9a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 15 Jun 2016 21:39:50 +0900 Subject: [PATCH 2/5] remove a global variable from CodegenContext --- .../expressions/codegen/CodeGenerator.scala | 7 -- .../codegen/GenerateUnsafeProjection.scala | 95 ++++++++----------- .../expressions/complexTypeCreator.scala | 4 - 3 files changed, 38 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 79a5568511486..ff97cd321199a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -108,13 +108,6 @@ class CodegenContext { */ var copyResult: Boolean = false - /** - * Whether should we write complex type data only to generic structure - * - * If we write complex type data only to a generic structure at projection, set true to this var - */ - var genericWriteBuffer: Boolean = false - /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a 
* 3-tuple: java type, variable name, code to init it. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 61a840af35a46..5efba4b3a6087 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -56,17 +56,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ExprCode(code, isNull, fieldName) } - if (ctx.genericWriteBuffer) { - s"${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}" - } else { - s""" - if ($input instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} - } else { - ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} - } - """ - } + s""" + if ($input instanceof UnsafeRow) { + ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} + } else { + ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} + } + """ } private def writeExpressionsToBuffer( @@ -226,31 +222,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - - val writeSafeArrayToBuffer = s""" - final int $numElements = $input.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); - - for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { - $arrayWriter.setNullAt($index); - } else { - final $jt $element = ${ctx.getValue(input, et, index)}; - $writeElement + s""" + if ($input instanceof UnsafeArrayData) { + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} + } else { + final int $numElements = $input.numElements(); + 
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize); + + for (int $index = 0; $index < $numElements; $index++) { + if ($input.isNullAt($index)) { + $arrayWriter.setNullAt($index); + } else { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } } } - """ - if (ctx.genericWriteBuffer) { - writeSafeArrayToBuffer - } else { - s""" - if ($input instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} - } else { - ${writeSafeArrayToBuffer} - } - """ - } + """ } // TODO: if the nullability of value element is correct, we can use it to save null check. @@ -266,34 +254,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Writes out unsafe map according to the format described in `UnsafeMapData`. - val writeSafeMapToBuffer = s""" - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); + s""" + if ($input instanceof UnsafeMapData) { + ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} + } else { + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); - // preserve 4 bytes to write the key array numBytes later. - $bufferHolder.grow(4); - $bufferHolder.cursor += 4; + // preserve 4 bytes to write the key array numBytes later. + $bufferHolder.grow(4); + $bufferHolder.cursor += 4; - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + // Remember the current cursor so that we can write numBytes of key array later. + final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} - // Write the numBytes of key array into the first 4 bytes. - Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); + ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + // Write the numBytes of key array into the first 4 bytes. 
+ Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + } """ - if (ctx.genericWriteBuffer) { - writeSafeMapToBuffer - } else { - s""" - if ($input instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} - } else { - ${writeSafeMapToBuffer} - } - """ - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index bc04e1568fd45..35342e9c840f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -50,7 +50,6 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName - ctx.genericWriteBuffer = true val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"this.$values = null;") @@ -131,7 +130,6 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val mapClass = classOf[ArrayBasedMapData].getName - ctx.genericWriteBuffer = true val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;") @@ -213,7 +211,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName - ctx.genericWriteBuffer = true val values = ctx.freshName("values") 
ctx.addMutableState("Object[]", values, s"this.$values = null;") @@ -311,7 +308,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName - ctx.genericWriteBuffer = true val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"this.$values = null;") From a19300736210ed30663ec00c531b6c40d4127a92 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Jun 2016 03:17:53 +0900 Subject: [PATCH 3/5] remove unreachable code elimination --- .../expressions/complexTypeCreator.scala | 69 +++++++------------ 1 file changed, 24 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 35342e9c840f3..b2a12b6e2bcd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -60,18 +60,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + ( - if (eval.isNull == "false") { - s"\n$values[$i] = ${eval.value};" + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; } else { - s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ - }) + $values[$i] = ${eval.value}; + } + """ }) + s""" final ArrayData ${ev.value} = new $arrayClass($values); @@ -158,18 +153,14 @@ case class CreateMap(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, values.zipWithIndex.map { case (value, i) => val eval = value.genCode(ctx) - eval.code + ( - if (eval.isNull == "false") { - s"\n$valueArray[$i] = ${eval.value};" + 
s""" + ${eval.code} + if (${eval.isNull}) { + $valueArray[$i] = null; } else { - s""" - if (${eval.isNull}) { - $valueArray[$i] = null; - } else { - $valueArray[$i] = ${eval.value}; - } - """ - }) + $valueArray[$i] = ${eval.value}; + } + """ }) + s""" final MapData ${ev.value} = new $mapClass($keyData, $valueData); @@ -221,18 +212,12 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + ( - if (eval.isNull == "false") { - s"\n$values[$i] = ${eval.value};" + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; } else { - s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ - }) + $values[$i] = ${eval.value}; + }""" }) + s""" final InternalRow ${ev.value} = new $rowClass($values); @@ -318,18 +303,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { ctx.INPUT_ROW, valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) - eval.code + ( - if (eval.isNull == "false") { - s"\n$values[$i] = ${eval.value};" - } else { - s""" - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - } - """ - }) + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + }""" }) + s""" final InternalRow ${ev.value} = new $rowClass($values); From 8d7d311c5808f983e5de19f50c5034e89e6b7629 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 16 Jun 2016 18:34:31 +0900 Subject: [PATCH 4/5] remove an unnecessary assignment add unit tests to check elimination of zeroOutNullBytes --- .../expressions/complexTypeCreator.scala | 4 -- .../spark/sql/DataFrameComplexTypeSuite.scala | 46 +++++++++++++------ 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b2a12b6e2bcd4..9e4a8a5a6ab54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -54,7 +54,6 @@ case class CreateArray(children: Seq[Expression]) extends Expression { ctx.addMutableState("Object[]", values, s"this.$values = null;") ev.copy(code = s""" - final boolean ${ev.isNull} = false; this.$values = new Object[${children.size}];""" + ctx.splitExpressions( ctx.INPUT_ROW, @@ -133,7 +132,6 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val keyData = s"new $arrayClass($keyArray)" val valueData = s"new $arrayClass($valueArray)" ev.copy(code = s""" - final boolean ${ev.isNull} = false; $keyArray = new Object[${keys.size}]; $valueArray = new Object[${values.size}];""" + ctx.splitExpressions( @@ -206,7 +204,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { ctx.addMutableState("Object[]", values, s"this.$values = null;") ev.copy(code = s""" - boolean ${ev.isNull} = false; this.$values = new Object[${children.size}];""" + ctx.splitExpressions( ctx.INPUT_ROW, @@ -297,7 +294,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { ctx.addMutableState("Object[]", values, s"this.$values = null;") ev.copy(code = s""" - boolean ${ev.isNull} = false; $values = new Object[${valExprs.size}];""" + ctx.splitExpressions( ctx.INPUT_ROW, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 2bc1be3a43c71..cb1447eb7888f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql +import org.apache.spark.sql.execution.debug.codegenString import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -26,36 +27,53 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ - test("primitive type on array") { + test("create an array") { val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("Array(v + 2, v + 3)").collect + df.selectExpr("Array(v + 3, v + 4)").collect } - test("array on array") { + test("create an map") { val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("Array(Array(v, v + 1, v + 2)," + - "null," + - "Array(v, v - 1, v - 2))").collect + df.selectExpr("map(v + 3, v + 4)").collect } - test("primitive type on map") { + test("create an struct") { val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("map(v + 3, v + 4)").collect + df.selectExpr("struct(v + 3, v + 4)").collect + } + + def validate(df: DataFrame): Unit = { + val logicalPlan = df.logicalPlan + val queryExecution = sqlContext.sessionState.executePlan(logicalPlan) + val cg = codegenString(queryExecution.executedPlan) + + if (cg.contains("Found 0 WholeStageCodegen subtrees")) { + return + } + + if ("zeroOutNullBytes".r.findFirstIn(cg).isDefined) { + fail( + s""" + |=== FAIL: generated code must not include: zeroOutNullBytes === + |$cg + """.stripMargin + ) + } } - test("map on map") { + test ("check elimination of zeroOutNullBytes on array") { val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("map(map(v, v + 3), map(v, v + 4))").collect + validate(df.selectExpr("Array(v + 3, v + 4)")) } - test("primitive type on struct") { + test ("check elimination of zeroOutNullBytes on map") { val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("struct(v + 3, v + 4)").collect + validate(df.selectExpr("map(v + 3, v + 4)")) } - 
test("struct on struct") { + test ("check elimination of zeroOutNullBytes on struct") { val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("struct(struct(v + 3), null, struct(v + 4))").collect + validate(df.selectExpr("struct(v + 3, v + 4)")) } test("UDF on struct") { From 7ad542d81005bf074f6df8404b3cb228d277621c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 22 Jun 2016 01:37:52 +0900 Subject: [PATCH 5/5] drop three unit tests --- .../spark/sql/DataFrameComplexTypeSuite.scala | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index cb1447eb7888f..0f4f9440aa3a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -27,21 +27,6 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ - test("create an array") { - val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("Array(v + 3, v + 4)").collect - } - - test("create an map") { - val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("map(v + 3, v + 4)").collect - } - - test("create an struct") { - val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v") - df.selectExpr("struct(v + 3, v + 4)").collect - } - def validate(df: DataFrame): Unit = { val logicalPlan = df.logicalPlan val queryExecution = sqlContext.sessionState.executePlan(logicalPlan)