Skip to content

Commit

Permalink
[SPARK-23925][SQL] Styling
Browse files Browse the repository at this point in the history
  • Loading branch information
pepinoflo committed May 15, 2018
1 parent 7601ea0 commit 3bd11e2
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,8 @@ case class Flatten(child: Expression) extends UnaryExpression {
Examples:
> SELECT _FUNC_('123', 2);
['123', '123']
""")
""",
since = "2.4.0")
case class ArrayRepeat(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {

Expand All @@ -1496,7 +1497,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
null
} else {
if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to create array with $count elements" +
throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
}
val element = left.eval(input)
Expand All @@ -1507,7 +1508,6 @@ case class ArrayRepeat(left: Expression, right: Expression)
override def prettyName: String = "array_repeat"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val element = leftGen.value
Expand All @@ -1523,25 +1523,27 @@ case class ArrayRepeat(left: Expression, right: Expression)

ev.copy(code =
s"""
|boolean ${ev.isNull} = false;
|${leftGen.code}
|${rightGen.code}
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$resultCode
""".stripMargin)
|boolean ${ev.isNull} = false;
|${leftGen.code}
|${rightGen.code}
|${CodeGenerator.javaType(dataType)} ${ev.value} =
| ${CodeGenerator.defaultValue(dataType)};
|$resultCode
""".stripMargin)
}

private def nullElementsProtection(ev: ExprCode,
rightIsNull: String,
coreLogic: String): String = {
private def nullElementsProtection(
ev: ExprCode,
rightIsNull: String,
coreLogic: String): String = {
if (nullable) {
s"""
|if ($rightIsNull) {
| ${ev.isNull} = true;
|} else {
| ${coreLogic}
|}
""".stripMargin
|if ($rightIsNull) {
| ${ev.isNull} = true;
|} else {
| ${coreLogic}
|}
""".stripMargin
} else {
coreLogic
}
Expand All @@ -1551,66 +1553,67 @@ case class ArrayRepeat(left: Expression, right: Expression)
val numElements = ctx.freshName("numElements")
val numElementsCode =
s"""
|int $numElements = 0;
|if ($count > 0) {
| $numElements = $count;
|}
|if ($numElements > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
| " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
|}
""".stripMargin
|int $numElements = 0;
|if ($count > 0) {
| $numElements = $count;
|}
|if ($numElements > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
| " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
|}
""".stripMargin

(numElements, numElementsCode)
}

private def genCodeForPrimitiveElement(ctx: CodegenContext,
elementType: DataType,
element: String,
count: String,
leftIsNull: String,
arrayDataName: String): String = {

private def genCodeForPrimitiveElement(
ctx: CodegenContext,
elementType: DataType,
element: String,
count: String,
leftIsNull: String,
arrayDataName: String): String = {
val tempArrayDataName = ctx.freshName("tempArrayData")
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val errorMessage = s" $prettyName failed."
val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)

s"""
|$numElemCode
|${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")}
|if (!$leftIsNull) {
| for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
| $tempArrayDataName.set$primitiveValueTypeName(k, $element);
| }
|} else {
| for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
| $tempArrayDataName.setNullAt(k);
| }
|}
|$arrayDataName = $tempArrayDataName;
""".stripMargin
|$numElemCode
|${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)}
|if (!$leftIsNull) {
| for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
| $tempArrayDataName.set$primitiveValueTypeName(k, $element);
| }
|} else {
| for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
| $tempArrayDataName.setNullAt(k);
| }
|}
|$arrayDataName = $tempArrayDataName;
""".stripMargin
}

private def genCodeForNonPrimitiveElement(ctx: CodegenContext,
element: String,
count: String,
leftIsNull: String,
arrayDataName: String): String = {

private def genCodeForNonPrimitiveElement(
ctx: CodegenContext,
element: String,
count: String,
leftIsNull: String,
arrayDataName: String): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val arrayName = ctx.freshName("arrayObject")
val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)

s"""
|$numElemCode
|Object[] $arrayName = new Object[(int)$numElemName];
|if (!$leftIsNull) {
| for (int k = 0; k < $numElemName; k++) {
| $arrayName[k] = $element;
| }
|}
|$arrayDataName = new $genericArrayClass($arrayName);
""".stripMargin
|$numElemCode
|Object[] $arrayName = new Object[(int)$numElemName];
|if (!$leftIsNull) {
| for (int k = 0; k < $numElemName; k++) {
| $arrayName[k] = $element;
| }
|}
|$arrayDataName = new $genericArrayClass($arrayName);
""".stripMargin
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -846,8 +846,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
test("array_repeat function") {
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on
val strDF = Seq(
("hi", 2),
(null, 2)
("hi", 2),
(null, 2)
).toDF("a", "b")

val strDFTwiceResult = Seq(
Expand Down

0 comments on commit 3bd11e2

Please sign in to comment.