Skip to content

Commit

Permalink
make sure input.primitive is always variable name not code
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Sep 5, 2015
1 parent bca8c07 commit 96ed788
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ class CodeGenContext {
/**
* Holding all the functions those will be added into generated class.
*/
val addedFuntions: mutable.Map[String, String] =
val addedFunctions: mutable.Map[String, String] =
mutable.Map.empty[String, String]

def addNewFunction(funcName: String, funcCode: String): Unit = {
addedFuntions += ((funcName, funcCode))
addedFunctions += ((funcName, funcCode))
}

final val JAVA_BOOLEAN = "boolean"
Expand Down Expand Up @@ -298,8 +298,8 @@ class CodeGenContext {
| $body
|}
""".stripMargin
addNewFunction(name, code)
name
addNewFunction(name, code)
name
}

functions.map(name => s"$name($row);").mkString("\n")
Expand Down Expand Up @@ -337,7 +337,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}

protected def declareAddedFunctions(ctx: CodeGenContext): String = {
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types.DecimalType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val columns = expressions.zipWithIndex.map {
case (e, i) =>
s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
}.mkString("\n ")
}.mkString("\n")

val initColumns = expressions.zipWithIndex.map {
case (e, i) =>
Expand All @@ -67,18 +67,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

val getCases = (0 until expressions.size).map { i =>
s"case $i: return c$i;"
}.mkString("\n ")
}.mkString("\n")

val updateCases = expressions.zipWithIndex.map { case (e, i) =>
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
}.mkString("\n ")
}.mkString("\n")

val specificAccessorFunctions = ctx.primitiveTypes.map { jt =>
val cases = expressions.zipWithIndex.flatMap {
case (e, i) if ctx.javaType(e.dataType) == jt =>
Some(s"case $i: return c$i;")
case _ => None
}.mkString("\n ")
}.mkString("\n")
if (cases.length > 0) {
val getter = "get" + ctx.primitiveTypeName(jt)
s"""
Expand All @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case (e, i) if ctx.javaType(e.dataType) == jt =>
Some(s"case $i: { c$i = value; return; }")
case _ => None
}.mkString("\n ")
}.mkString("\n")
if (cases.length > 0) {
val setter = "set" + ctx.primitiveTypeName(jt)
s"""
Expand Down Expand Up @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

val copyColumns = expressions.zipWithIndex.map { case (e, i) =>
s"""if (!nullBits[$i]) arr[$i] = c$i;"""
}.mkString("\n ")
}.mkString("\n")

val code = s"""
public SpecificProjection generate($exprType[] expr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
val cursor = ctx.freshName("cursor")
ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
val tmp = ctx.freshName("tmpBuffer")
val tmpBuffer = ctx.freshName("tmpBuffer")

val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
val ev = createConvertCode(ctx, input, dt)
Expand All @@ -144,10 +144,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
if ($buffer.length < $numBytes) {
// This will not happen frequently, because the buffer is re-used.
byte[] $tmp = new byte[$numBytes * 2];
byte[] $tmpBuffer = new byte[$numBytes * 2];
Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
$tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
$buffer = $tmp;
$tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
$buffer = $tmpBuffer;
}
$output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes);
"""
Expand Down Expand Up @@ -207,20 +207,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val outputIsNull = ctx.freshName("isNull")
val tmp = ctx.freshName("tmp")
val numElements = ctx.freshName("numElements")
val fixedSize = ctx.freshName("fixedSize")
val numBytes = ctx.freshName("numBytes")
val elements = ctx.freshName("elements")
val cursor = ctx.freshName("cursor")
val index = ctx.freshName("index")
val elementName = ctx.freshName("elementName")

val element = GeneratedExpressionCode(
code = "",
isNull = s"$tmp.isNullAt($index)",
primitive = s"${ctx.getValue(tmp, elementType, index)}"
)
val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType)
val element = {
val code = s"${ctx.javaType(elementType)} $elementName = " +
s"${ctx.getValue(input.primitive, elementType, index)};"
val isNull = s"${input.primitive}.isNullAt($index)"
GeneratedExpressionCode(code, isNull, elementName)
}

val convertedElement = createConvertCode(ctx, element, elementType)

// go through the input array to calculate how many bytes we need.
val calculateNumBytes = elementType match {
Expand Down Expand Up @@ -272,6 +274,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
${convertedElement.code}
Platform.put${ctx.primitiveTypeName(elementType)}(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
Expand All @@ -280,6 +283,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
${convertedElement.code}
Platform.putLong(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
Expand Down Expand Up @@ -307,11 +311,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${input.code}
final boolean $outputIsNull = ${input.isNull};
if (!$outputIsNull) {
final ArrayData $tmp = ${input.primitive};
if ($tmp instanceof UnsafeArrayData) {
$output = (UnsafeArrayData) $tmp;
if (${input.primitive} instanceof UnsafeArrayData) {
$output = (UnsafeArrayData) ${input.primitive};
} else {
final int $numElements = $tmp.numElements();
final int $numElements = ${input.primitive}.numElements();
final int $fixedSize = 4 * $numElements;
int $numBytes = $fixedSize;

Expand Down Expand Up @@ -350,29 +353,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
valueType: DataType): GeneratedExpressionCode = {
val output = ctx.freshName("convertedMap")
val outputIsNull = ctx.freshName("isNull")
val tmp = ctx.freshName("tmp")

val keyArray = GeneratedExpressionCode(
code = "",
isNull = "false",
primitive = s"$tmp.keyArray()"
)
val valueArray = GeneratedExpressionCode(
code = "",
isNull = "false",
primitive = s"$tmp.valueArray()"
)
val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType)
val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType)
val keyArrayName = ctx.freshName("keyArrayName")
val valueArrayName = ctx.freshName("valueArrayName")

val keyArray = {
val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();"
val isNull = "false"
GeneratedExpressionCode(code, isNull, keyArrayName)
}

val valueArray = {
val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();"
val isNull = "false"
GeneratedExpressionCode(code, isNull, valueArrayName)
}

val convertedKeys = createCodeForArray(ctx, keyArray, keyType)
val convertedValues = createCodeForArray(ctx, valueArray, valueType)

val code = s"""
${input.code}
final boolean $outputIsNull = ${input.isNull};
UnsafeMapData $output = null;
if (!$outputIsNull) {
final MapData $tmp = ${input.primitive};
if ($tmp instanceof UnsafeMapData) {
$output = (UnsafeMapData) $tmp;
if (${input.primitive} instanceof UnsafeMapData) {
$output = (UnsafeMapData) ${input.primitive};
} else {
${convertedKeys.code}
${convertedValues.code}
Expand All @@ -393,22 +398,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: StructType =>
val output = ctx.freshName("convertedStruct")
val outputIsNull = ctx.freshName("isNull")
val tmp = ctx.freshName("tmp")
val fieldTypes = t.fields.map(_.dataType)
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
val getFieldCode = ctx.getValue(tmp, dt, i.toString)
val fieldIsNull = s"$tmp.isNullAt($i)"
GeneratedExpressionCode("", fieldIsNull, getFieldCode)
val fieldName = ctx.freshName("fieldName")
val code = s"${ctx.javaType(dt)} $fieldName = " +
s"${ctx.getValue(input.primitive, dt, i.toString)};"
val isNull = s"${input.primitive}.isNullAt($i)"
GeneratedExpressionCode(code, isNull, fieldName)
}
val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes)
val code = s"""
${input.code}
UnsafeRow $output = null;
final boolean $outputIsNull = ${input.isNull};
if (!$outputIsNull) {
final InternalRow $tmp = ${input.primitive};
if ($tmp instanceof UnsafeRow) {
$output = (UnsafeRow) $tmp;
if (${input.primitive} instanceof UnsafeRow) {
$output = (UnsafeRow) ${input.primitive};
} else {
${converter.code}
$output = ${converter.primitive};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression {

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val arrayClass = classOf[GenericArrayData].getName
val values = ctx.freshName("values")
s"""
final boolean ${ev.isNull} = false;
final Object[] values = new Object[${children.size}];
final Object[] $values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
values[$i] = null;
$values[$i] = null;
} else {
values[$i] = ${eval.primitive};
$values[$i] = ${eval.primitive};
}
"""
}.mkString("\n") +
s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);"
s"final ArrayData ${ev.primitive} = new $arrayClass($values);"
}

override def prettyName: String = "array"
Expand Down Expand Up @@ -94,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val rowClass = classOf[GenericMutableRow].getName
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
s"""
boolean ${ev.isNull} = false;
final $rowClass ${ev.primitive} = new $rowClass(${children.size});
final Object[] $values = new Object[${children.size}];
""" +
children.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
${ev.primitive}.update($i, null);
$values[$i] = null;
} else {
${ev.primitive}.update($i, ${eval.primitive});
$values[$i] = ${eval.primitive};
}
"""
}.mkString("\n")
}.mkString("\n") +
s"final InternalRow ${ev.primitive} = new $rowClass($values);"
}

override def prettyName: String = "struct"
Expand Down Expand Up @@ -161,21 +164,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val rowClass = classOf[GenericMutableRow].getName
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
s"""
boolean ${ev.isNull} = false;
final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size});
final Object[] $values = new Object[${valExprs.size}];
""" +
valExprs.zipWithIndex.map { case (e, i) =>
val eval = e.gen(ctx)
eval.code + s"""
if (${eval.isNull}) {
${ev.primitive}.update($i, null);
$values[$i] = null;
} else {
${ev.primitive}.update($i, ${eval.primitive});
$values[$i] = ${eval.primitive};
}
"""
}.mkString("\n")
}.mkString("\n") +
s"final InternalRow ${ev.primitive} = new $rowClass($values);"
}

override def prettyName: String = "named_struct"
Expand Down

0 comments on commit 96ed788

Please sign in to comment.