Skip to content

Commit

Permalink
[SPARK-18016][SQL][FOLLOW-UP] Code Generation: Constant Pool Limit - reduce entries for mutable state
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

This PR addresses additional review comments in #19811.

## How was this patch tested?

Existing test suites

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #20036 from kiszk/SPARK-18066-followup.
  • Loading branch information
kiszk authored and cloud-fan committed Dec 28, 2017
1 parent 753793b commit 5683984
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 48 deletions.
Expand Up @@ -190,7 +190,7 @@ class CodegenContext {

/**
* Returns the reference of next available slot in current compacted array. The size of each
* compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
* compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
* Once the threshold is reached, a new compacted array is created.
*/
def getNextSlot(): String = {
Expand Down Expand Up @@ -352,7 +352,7 @@ class CodegenContext {
def initMutableStates(): String = {
// It's possible that we add the same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, so we should call `distinct` here to remove the duplicated ones.
val initCodes = mutableStateInitCode.distinct
val initCodes = mutableStateInitCode.distinct.map(_ + "\n")

// The generated initialization code may exceed 64kb function size limit in JVM if there are too
// many mutable states, so split it into multiple functions.
Expand Down
Expand Up @@ -118,9 +118,8 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
// inline mutable state since not many Like operations in a task
val pattern = ctx.addMutableState(patternClass, "patternLike",
v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
v => s"""$v = $patternClass.compile("$regexStr");""")

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand All @@ -143,9 +142,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
val rightStr = ctx.freshName("rightStr")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile($escapeFunc($rightStr));
${ev.value} = $pattern.matcher($eval1.toString()).matches();
"""
})
}
Expand Down Expand Up @@ -194,9 +193,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
// inline mutable state since not many RLike operations in a task
val pattern = ctx.addMutableState(patternClass, "patternRLike",
v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
v => s"""$v = $patternClass.compile("$regexStr");""")

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand All @@ -219,9 +217,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
val pattern = ctx.freshName("pattern")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile($rightStr);
${ev.value} = $pattern.matcher(${eval1}.toString()).find(0);
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile($rightStr);
${ev.value} = $pattern.matcher($eval1.toString()).find(0);
"""
})
}
Expand Down Expand Up @@ -338,25 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
s"""
if (!$regexp.equals(${termLastRegex})) {
if (!$regexp.equals($termLastRegex)) {
// regex value changed
${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
$termLastRegex = $regexp.clone();
$termPattern = $classNamePattern.compile($termLastRegex.toString());
}
if (!$rep.equals(${termLastReplacementInUTF8})) {
if (!$rep.equals($termLastReplacementInUTF8)) {
// replacement string changed
${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
$termLastReplacementInUTF8 = $rep.clone();
$termLastReplacement = $termLastReplacementInUTF8.toString();
}
$classNameStringBuffer ${termResult} = new $classNameStringBuffer();
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
$classNameStringBuffer $termResult = new $classNameStringBuffer();
java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());

while (${matcher}.find()) {
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
while ($matcher.find()) {
$matcher.appendReplacement($termResult, $termLastReplacement);
}
${matcher}.appendTail(${termResult});
${ev.value} = UTF8String.fromString(${termResult}.toString());
${termResult} = null;
$matcher.appendTail($termResult);
${ev.value} = UTF8String.fromString($termResult.toString());
$termResult = null;
$setEvNotNull
"""
})
Expand Down Expand Up @@ -425,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
if (!$regexp.equals(${termLastRegex})) {
if (!$regexp.equals($termLastRegex)) {
// regex value changed
${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
$termLastRegex = $regexp.clone();
$termPattern = $classNamePattern.compile($termLastRegex.toString());
}
java.util.regex.Matcher ${matcher} =
${termPattern}.matcher($subject.toString());
if (${matcher}.find()) {
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
if (${matchResult}.group($idx) == null) {
java.util.regex.Matcher $matcher =
$termPattern.matcher($subject.toString());
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
if ($matchResult.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
${ev.value} = UTF8String.fromString($matchResult.group($idx));
}
$setEvNotNull
} else {
Expand Down
Expand Up @@ -138,7 +138,7 @@ case class SortExec(
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
// inline mutable state since not many Sort operations in a task
// Inline mutable state since not many Sort operations in a task
sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
v => s"$v = $thisPlan.createSorter();", forceInline = true)
val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
Expand Down
Expand Up @@ -283,7 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp

override def doProduce(ctx: CodegenContext): String = {
// Right now, InputAdapter is only used when there is one input RDD.
// inline mutable state since an inputAdaptor in a task
// Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
forceInline = true)
val row = ctx.freshName("row")
Expand Down
Expand Up @@ -587,31 +587,35 @@ case class HashAggregateExec(
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap",
v => s"$v = new $fastHashMapClassName();")
ctx.addMutableState(s"java.util.Iterator<InternalRow>", "vectorizedFastHashMapIter")
v => s"$v = new $fastHashMapClassName();", forceInline = true)
ctx.addMutableState(s"java.util.Iterator<InternalRow>", "vectorizedFastHashMapIter",
forceInline = true)
} else {
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap",
v => s"$v = new $fastHashMapClassName(" +
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());")
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());",
forceInline = true)
ctx.addMutableState(
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
"fastHashMapIter")
"fastHashMapIter", forceInline = true)
}
}

// Create a name for the iterator from the regular hash map.
// inline mutable state since not many aggregation operations in a task
// Inline mutable state since not many aggregation operations in a task
val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
"mapIter", forceInline = true)
// create hashMap
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap",
v => s"$v = $thisPlan.createHashMap();")
v => s"$v = $thisPlan.createHashMap();", forceInline = true)
sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter",
forceInline = true)

Expand Down
Expand Up @@ -284,7 +284,7 @@ case class SampleExec(
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")

// inline mutable state since not many Sample operations in a task
// Inline mutable state since not many Sample operations in a task
val sampler = ctx.addMutableState(s"$samplerClass<UnsafeRow>", "sampleReplace",
v => {
val initSamplerFuncName = ctx.addNewFunction(initSampler,
Expand Down Expand Up @@ -371,7 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName

// inline mutable state since not many Range operations in a task
// Inline mutable state since not many Range operations in a task
val taskContext = ctx.addMutableState("TaskContext", "taskContext",
v => s"$v = TaskContext.get();", forceInline = true)
val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
Expand Down
Expand Up @@ -139,7 +139,7 @@ case class BroadcastHashJoinExec(
// At the end of the task, we update the avg hash probe.
val avgHashProbe = metricTerm(ctx, "avgHashProbe")

// inline mutable state since not many join operations in a task
// Inline mutable state since not many join operations in a task
val relationTerm = ctx.addMutableState(clsName, "relation",
v => s"""
| $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
Expand Down
Expand Up @@ -422,7 +422,7 @@ case class SortMergeJoinExec(
*/
private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides.
// inline mutable state since not many join operations in a task
// Inline mutable state since not many join operations in a task
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)

Expand All @@ -440,8 +440,9 @@ case class SortMergeJoinExec(
val spillThreshold = getSpillThreshold
val inMemoryThreshold = getInMemoryThreshold

// Inline mutable state since not many join operations in a task
val matches = ctx.addMutableState(clsName, "matches",
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);")
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)

Expand Down Expand Up @@ -576,7 +577,7 @@ case class SortMergeJoinExec(
override def needCopyResult: Boolean = true

override def doProduce(ctx: CodegenContext): String = {
// inline mutable state since not many join operations in a task
// Inline mutable state since not many join operations in a task
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
v => s"$v = inputs[0];", forceInline = true)
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
Expand Down

0 comments on commit 5683984

Please sign in to comment.