diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf96248feaef7..da3103b4ebb6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -84,11 +84,11 @@ class CodeGenContext { /** * Holding all the functions those will be added into generated class. */ - val addedFuntions: mutable.Map[String, String] = + val addedFunctions: mutable.Map[String, String] = mutable.Map.empty[String, String] def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFuntions += ((funcName, funcCode)) + addedFunctions += ((funcName, funcCode)) } final val JAVA_BOOLEAN = "boolean" @@ -298,8 +298,8 @@ class CodeGenContext { | $body |} """.stripMargin - addNewFunction(name, code) - name + addNewFunction(name, code) + name } functions.map(name => s"$name($row);").mkString("\n") @@ -337,7 +337,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b4d4df8934bd4..793023b9fbed3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types.DecimalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d822e8..2164ddf03d1b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -48,7 +48,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val columns = expressions.zipWithIndex.map { case (e, i) => s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n ") + }.mkString("\n") val initColumns = expressions.zipWithIndex.map { case (e, i) => @@ -67,18 +67,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val getCases = (0 until expressions.size).map { i => s"case $i: return c$i;" - }.mkString("\n ") + }.mkString("\n") val updateCases = expressions.zipWithIndex.map { case (e, i) => s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n ") + }.mkString("\n") val specificAccessorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { case (e, i) if ctx.javaType(e.dataType) == jt => Some(s"case $i: return c$i;") case _ => None - }.mkString("\n ") + }.mkString("\n") if (cases.length > 0) { val getter = "get" + ctx.primitiveTypeName(jt) s""" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case (e, i) if ctx.javaType(e.dataType) == jt => Some(s"case $i: { c$i = value; return; }") case _ => None - }.mkString("\n ") + }.mkString("\n") if (cases.length > 0) { val setter = "set" + ctx.primitiveTypeName(jt) s""" @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val copyColumns = expressions.zipWithIndex.map { case (e, i) => s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n ") + }.mkString("\n") val code = s""" public SpecificProjection generate($exprType[] expr) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b570fe86db1aa..03c5f449bf9ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -134,7 +134,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") ctx.addMutableState("int", cursor, s"this.$cursor = 0;") - val tmp = ctx.freshName("tmpBuffer") + val tmpBuffer = ctx.freshName("tmpBuffer") val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => val ev = createConvertCode(ctx, input, dt) @@ -144,10 +144,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); if ($buffer.length < $numBytes) { // This will not happen frequently, because the buffer is re-used. - byte[] $tmp = new byte[$numBytes * 2]; + byte[] $tmpBuffer = new byte[$numBytes * 2]; Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, - $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); - $buffer = $tmp; + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; } $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); """ @@ -207,20 +207,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val buffer = ctx.freshName("buffer") ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") val numElements = ctx.freshName("numElements") val fixedSize = ctx.freshName("fixedSize") val numBytes = ctx.freshName("numBytes") val elements = ctx.freshName("elements") val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") + val elementName = ctx.freshName("elementName") - val element = GeneratedExpressionCode( - code = "", - isNull = s"$tmp.isNullAt($index)", - primitive = s"${ctx.getValue(tmp, elementType, index)}" - ) - val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType) + val element = { + val code = s"${ctx.javaType(elementType)} $elementName = " + + s"${ctx.getValue(input.primitive, elementType, index)};" + val isNull = s"${input.primitive}.isNullAt($index)" + GeneratedExpressionCode(code, isNull, elementName) + } + + val convertedElement = createConvertCode(ctx, element, elementType) // go through the input array to calculate how many bytes we need. val calculateNumBytes = elementType match { @@ -272,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" + ${convertedElement.code} Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -280,6 +283,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" + ${convertedElement.code} Platform.putLong( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -307,11 +311,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${input.code} final boolean $outputIsNull = ${input.isNull}; if (!$outputIsNull) { - final ArrayData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeArrayData) { - $output = (UnsafeArrayData) $tmp; + if (${input.primitive} instanceof UnsafeArrayData) { + $output = (UnsafeArrayData) ${input.primitive}; } else { - final int $numElements = $tmp.numElements(); + final int $numElements = ${input.primitive}.numElements(); final int $fixedSize = 4 * $numElements; int $numBytes = $fixedSize; @@ -350,29 +353,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro valueType: DataType): GeneratedExpressionCode = { val output = ctx.freshName("convertedMap") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") - - val keyArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.keyArray()" - ) - val valueArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.valueArray()" - ) - val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType) - val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType) + val keyArrayName = ctx.freshName("keyArrayName") + val valueArrayName = ctx.freshName("valueArrayName") + + val keyArray = { + val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();" + val isNull = "false" + GeneratedExpressionCode(code, isNull, keyArrayName) + } + + val valueArray = { + val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();" + val isNull = "false" + GeneratedExpressionCode(code, isNull, valueArrayName) + } + + val convertedKeys = createCodeForArray(ctx, keyArray, keyType) + val convertedValues = createCodeForArray(ctx, valueArray, valueType) val code = s""" ${input.code} final boolean $outputIsNull = ${input.isNull}; UnsafeMapData $output = null; if (!$outputIsNull) { - final MapData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeMapData) { - $output = (UnsafeMapData) $tmp; + if (${input.primitive} instanceof UnsafeMapData) { + $output = (UnsafeMapData) ${input.primitive}; } else { ${convertedKeys.code} ${convertedValues.code} @@ -393,22 +398,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => val output = ctx.freshName("convertedStruct") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") val fieldTypes = t.fields.map(_.dataType) val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val getFieldCode = ctx.getValue(tmp, dt, i.toString) - val fieldIsNull = s"$tmp.isNullAt($i)" - GeneratedExpressionCode("", fieldIsNull, getFieldCode) + val fieldName = ctx.freshName("fieldName") + val code = s"${ctx.javaType(dt)} $fieldName = " + + s"${ctx.getValue(input.primitive, dt, i.toString)};" + val isNull = s"${input.primitive}.isNullAt($i)" + GeneratedExpressionCode(code, isNull, fieldName) } - val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; final boolean $outputIsNull = ${input.isNull}; if (!$outputIsNull) { - final InternalRow $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeRow) { - $output = (UnsafeRow) $tmp; + if (${input.primitive} instanceof UnsafeRow) { + $output = (UnsafeRow) ${input.primitive}; } else { ${converter.code} $output = ${converter.primitive}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1c546719730b7..82eab5fb3d03a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") s""" final boolean ${ev.isNull} = false; - final Object[] values = new Object[${children.size}]; + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - values[$i] = null; + $values[$i] = null; } else { - values[$i] = ${eval.primitive}; + $values[$i] = ${eval.primitive}; } """ }.mkString("\n") + - s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" + s"final ArrayData ${ev.primitive} = new $arrayClass($values);" } override def prettyName: String = "array" @@ -94,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.primitive} = new $rowClass($values);" } override def prettyName: String = "struct" @@ -161,21 +164,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + final Object[] $values = new Object[${valExprs.size}]; """ + valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.primitive} = new $rowClass($values);" } override def prettyName: String = "named_struct"