Skip to content

Commit

Permalink
[SPARK-22549][SQL] Fix 64KB JVM bytecode limit problem with concat_ws
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR changes `concat_ws` code generation to place generated code for expression for arguments into separated methods if these size could be large.
This PR resolved the case of `concat_ws` with a lot of argument

## How was this patch tested?

Added new test cases into `StringExpressionsSuite`

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #19777 from kiszk/SPARK-22549.
  • Loading branch information
kiszk authored and cloud-fan committed Nov 21, 2017
1 parent c13b60e commit 41c6f36
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,34 @@ case class ConcatWs(children: Seq[Expression])
if (children.forall(_.dataType == StringType)) {
// All children are strings. In that case we can construct a fixed size array.
val evals = children.map(_.genCode(ctx))

val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
}.mkString(", ")

ev.copy(evals.map(_.code).mkString("\n") + s"""
UTF8String ${ev.value} = UTF8String.concatWs($inputs);
val separator = evals.head
val strings = evals.tail
val numArgs = strings.length
val args = ctx.freshName("args")

val inputs = strings.zipWithIndex.map { case (eval, index) =>
if (eval.isNull != "true") {
s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
} else {
""
}
}
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(inputs, "valueConcatWs",
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
} else {
inputs.mkString("\n")
}
ev.copy(s"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
$codes
UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
} else {
Expand All @@ -156,32 +177,63 @@ case class ConcatWs(children: Seq[Expression])
child.dataType match {
case StringType =>
("", // we count all the StringType arguments num at once below.
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};")
if (eval.isNull == "true") {
""
} else {
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
})
case _: ArrayType =>
val size = ctx.freshName("n")
(s"""
if (!${eval.isNull}) {
$varargNum += ${eval.value}.numElements();
}
""",
s"""
if (!${eval.isNull}) {
final int $size = ${eval.value}.numElements();
for (int j = 0; j < $size; j ++) {
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
}
if (eval.isNull == "true") {
("", "")
} else {
(s"""
if (!${eval.isNull}) {
$varargNum += ${eval.value}.numElements();
}
""",
s"""
if (!${eval.isNull}) {
final int $size = ${eval.value}.numElements();
for (int j = 0; j < $size; j ++) {
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
}
}
""")
}
""")
}
}.unzip

ev.copy(evals.map(_.code).mkString("\n") +
s"""
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs",
("InternalRow", ctx.INPUT_ROW) :: Nil,
"int",
{ body =>
s"""
int $varargNum = 0;
$body
return $varargNum;
"""
},
_.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs",
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
"int",
{ body =>
s"""
$body
return $idxInVararg;
"""
},
_.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
ev.copy(
s"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxInVararg = 0;
${varargCount.mkString("\n")}
$varargCounts
UTF8String[] $array = new UTF8String[$varargNum];
${varargBuild.mkString("\n")}
$varargBuilds
UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
boolean ${ev.isNull} = ${ev.value} == null;
""")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") {
val N = 5000
val sepExpr = Literal.create("#", StringType)
val strings1 = (1 to N).map(x => s"s$x")
val inputsExpr1 = strings1.map(Literal.create(_, StringType))
checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)

val strings2 = (1 to N).map(x => Seq(s"s$x"))
val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
checkEvaluation(
ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
}

test("elt") {
def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
checkEvaluation(
Expand Down

0 comments on commit 41c6f36

Please sign in to comment.