Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-22498][SQL] Fix 64KB JVM bytecode limit problem with concat #19728

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -790,23 +790,7 @@ class CodegenContext {
returnType: String = "void",
makeSplitFunction: String => String = identity,
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
val blocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
var length = 0
for (code <- expressions) {
// We can't know how many bytecode will be generated, so use the length of source code
// as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
// also not be too small, or it will have many function calls (for wide table), see the
// results in BenchmarkWideTable.
if (length > 1024) {
blocks += blockBuilder.toString()
blockBuilder.clear()
length = 0
}
blockBuilder.append(code)
length += CodeFormatter.stripExtraNewLinesAndComments(code).length
}
blocks += blockBuilder.toString()
val blocks = splitCodes(expressions)

if (blocks.length == 1) {
// inline execution if only one block
Expand Down Expand Up @@ -841,6 +825,27 @@ class CodegenContext {
}
}

def splitCodes(expressions: Seq[String]): Seq[String] = {
val blocks = new ArrayBuffer[String]()
val blockBuilder = new StringBuilder()
var length = 0
for (code <- expressions) {
// We can't know how many bytecode will be generated, so use the length of source code
// as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
// also not be too small, or it will have many function calls (for wide table), see the
// results in BenchmarkWideTable.
if (length > 1024) {
blocks += blockBuilder.toString()
blockBuilder.clear()
length = 0
}
blockBuilder.append(code)
length += CodeFormatter.stripExtraNewLinesAndComments(code).length
}
blocks += blockBuilder.toString()
blocks
}

/**
* Here we handle all the methods which have been added to the inner classes and
* not to the outer class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,28 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val inputs = evals.map { eval =>
s"${eval.isNull} ? null : ${eval.value}"
}.mkString(", ")
ev.copy(evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = UTF8String.concat($inputs);
if (${ev.value} == null) {
${ev.isNull} = true;
}
val numArgs = evals.length
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can inline it

val args = ctx.freshName("args")

val inputs = evals.zipWithIndex.map { case (eval, index) =>
s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If eval.isNull is evaluated to null at runtime, eval.value is useless. We should assign null in that case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

}
"""
}
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(inputs, "valueConcat",
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
} else {
inputs.mkString("\n")
}
ev.copy(s"""
UTF8String[] $args = new UTF8String[$numArgs];
$codes
UTF8String ${ev.value} = UTF8String.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
}
}
Expand Down Expand Up @@ -125,19 +138,43 @@ case class ConcatWs(children: Seq[Expression])
if (children.forall(_.dataType == StringType)) {
// All children are strings. In that case we can construct a fixed size array.
val evals = children.map(_.genCode(ctx))

val inputs = evals.map { eval =>
s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
}.mkString(", ")

ev.copy(evals.map(_.code).mkString("\n") + s"""
UTF8String ${ev.value} = UTF8String.concatWs($inputs);
val separator = evals.head
val strings = evals.tail
val argNums = strings.length
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: numArgs

val args = ctx.freshName("args")

val inputs = strings.zipWithIndex.map { case (eval, index) =>
if (eval.isNull != "true") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
} else {
""
}
}
val codes = s"${separator.code}\n" +
(if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(inputs, "valueConcatWs",
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
} else {
inputs.mkString("\n")
})
ev.copy(s"""
UTF8String[] $args = new UTF8String[$argNums];
$codes
UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

val code = if (ctx.INPUT_ROW != null && ctx.currentVars == null)
...
ev.copy(code = s"""
  UTF8String[] $args = new UTF8String[$argNums];
  ${separator.code}
  ...
""")

boolean ${ev.isNull} = ${ev.value} == null;
""")
} else {
val array = ctx.freshName("array")
ctx.addMutableState("UTF8String[]", array, "")
val varargNum = ctx.freshName("varargNum")
ctx.addMutableState("int", varargNum, "")
val idxInVararg = ctx.freshName("idxInVararg")
ctx.addMutableState("int", idxInVararg, "")

val evals = children.map(_.genCode(ctx))
val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) =>
Expand All @@ -163,13 +200,17 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip

ev.copy(evals.map(_.code).mkString("\n") +
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
val varargCounts = ctx.splitExpressions(ctx.INPUT_ROW, varargCount)
val varargBuilds = ctx.splitExpressions(ctx.INPUT_ROW, varargBuild)
ev.copy(
s"""
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxInVararg = 0;
${varargCount.mkString("\n")}
UTF8String[] $array = new UTF8String[$varargNum];
${varargBuild.mkString("\n")}
$codes
$varargNum = ${children.count(_.dataType == StringType) - 1};
$idxInVararg = 0;
$varargCounts
$array = new UTF8String[$varargNum];
$varargBuilds
UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
boolean ${ev.isNull} = ${ev.value} == null;
""")
Expand Down Expand Up @@ -224,22 +265,55 @@ case class Elt(children: Seq[Expression])
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val index = indexExpr.genCode(ctx)
val strings = stringExprs.map(_.genCode(ctx))
val indexVal = ctx.freshName("index")
val stringVal = ctx.freshName("stringVal")
val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
s"""
case ${index + 1}:
${ev.value} = ${eval.isNull} ? null : ${eval.value};
${eval.code}
$stringVal = ${eval.isNull} ? null : ${eval.value};
break;
"""
}.mkString("\n")
val indexVal = ctx.freshName("index")
val stringArray = ctx.freshName("strings");
}

ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s"""
final int $indexVal = ${index.value};
UTF8String ${ev.value} = null;
switch ($indexVal) {
$assignStringValue
val cases = ctx.splitCodes(assignStringValue)
val codes = if (cases.length == 1) {
s"""
UTF8String $stringVal = null;
switch ($indexVal) {
${cases.head}
}
"""
} else {
var fullFuncName = ""
cases.reverse.zipWithIndex.map { case (s, index) =>
val prevFunc = if (index == 0) {
"null"
} else {
s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)"
}
val funcName = ctx.freshName("eltFunc")
val funcBody = s"""
private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) {
UTF8String $stringVal = null;
switch ($indexVal) {
$s
default:
return $prevFunc;
}
return $stringVal;
}
"""
fullFuncName = ctx.addNewFunction(funcName, funcBody)
}
s"UTF8String $stringVal = $fullFuncName(${ctx.INPUT_ROW}, ${indexVal});"
}

ev.copy(index.code + "\n" +
s"""
final int $indexVal = ${index.value};
$codes
UTF8String ${ev.value} = $stringVal;
final boolean ${ev.isNull} = ${ev.value} == null;
""")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("SPARK-22498: Concat should not generate codes beyond 64KB") {
val N = 5000
val strs = (1 to N).map(x => s"s$x")
checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow)
}

test("concat_ws") {
def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
val inputExprs = inputs.map {
Expand Down Expand Up @@ -74,6 +80,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}

test("SPARK-22498: ConcatWs should not generate codes beyond 64KB") {
val N = 5000
val sepExpr = Literal.create("#", StringType)
val strings1 = (1 to N).map(x => s"s$x")
val inputsExpr1 = strings1.map(Literal.create(_, StringType))
checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)

val strings2 = (1 to N).map(x => Seq(s"s$x"))
val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
checkEvaluation(
ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
}

test("elt") {
def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
checkEvaluation(
Expand All @@ -97,6 +116,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
}

test("SPARK-22498: Elt should not generate codes beyond 64KB") {
val N = 10000
val strings = (1 to N).map(x => s"s$x")
val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType))
checkEvaluation(Elt(args), s"s$N")
}

test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
Expand Down