Skip to content

Commit

Permalink
[SPARK-35329][SQL] Split generated switch code into pieces in ExpandExec
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR splits the generated switch code in `ExpandExec` into smaller methods. On the current master, even a simple query like the one below generates a large method whose size (`maxMethodCodeSize:7448`) is close to `8000` (`CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT`):
```
scala> val df = Seq(("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
scala> val rdf = df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value").orderBy($"window.start".asc, $"value".desc).select("value")
scala> sql("SET spark.sql.adaptive.enabled=false")
scala> import org.apache.spark.sql.execution.debug._
scala> rdf.debugCodegen

Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 (maxMethodCodeSize:7448; maxConstantPoolSize:189(0.29% used); numInnerClasses:0) ==
                                    ^^^^
*(1) Project [window#34.start AS _gen_alias_39#39, value#11]
+- *(1) Filter ((isnotnull(window#34) AND (cast(time#10 as timestamp) >= window#34.start)) AND (cast(time#10 as timestamp) < window#34.end))
   +- *(1) Expand [List(named_struct(start, precisetimestampcon...

/* 028 */   private void expand_doConsume_0(InternalRow localtablescan_row_0, UTF8String expand_expr_0_0, boolean expand_exprIsNull_0_0, int expand_expr_1_0) throws java.io.IOException {
/* 029 */     boolean expand_isNull_0 = true;
/* 030 */     InternalRow expand_value_0 =
/* 031 */     null;
/* 032 */     for (int expand_i_0 = 0; expand_i_0 < 4; expand_i_0 ++) {
/* 033 */       switch (expand_i_0) {
/* 034 */       case 0:
                  (too many code lines)
/* 517 */         break;
/* 518 */
/* 519 */       case 1:
                  (too many code lines)
/* 1002 */         break;
/* 1003 */
/* 1004 */       case 2:
                  (too many code lines)
/* 1487 */         break;
/* 1488 */
/* 1489 */       case 3:
                  (too many code lines)
/* 1972 */         break;
/* 1973 */       }
/* 1974 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[33] /* numOutputRows */).add(1);
/* 1975 */
/* 1976 */       do {
/* 1977 */         boolean filter_value_2 = !expand_isNull_0;
/* 1978 */         if (!filter_value_2) continue;
```
The fix in this PR makes the generated method much smaller (`maxMethodCodeSize` drops from `7448` to `1713`), as follows:
```
Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 (maxMethodCodeSize:1713; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) ==
                                    ^^^^
*(1) Project [window#17.start AS _gen_alias_32#32, value#11]
+- *(1) Filter ((isnotnull(window#17) AND (cast(time#10 as timestamp) >= window#17.start)) AND (cast(time#10 as timestamp) < window#17.end))
   +- *(1) Expand [List(named_struct(start, precisetimestampcon...

/* 032 */   private void expand_doConsume_0(InternalRow localtablescan_row_0, UTF8String expand_expr_0_0, boolean expand_exprIsNull_0_0, int expand_expr_1_0) throws java.io.IOException {
/* 033 */     for (int expand_i_0 = 0; expand_i_0 < 4; expand_i_0 ++) {
/* 034 */       switch (expand_i_0) {
/* 035 */       case 0:
/* 036 */         expand_switchCaseCode_0(expand_exprIsNull_0_0, expand_expr_0_0);
/* 037 */         break;
/* 038 */
/* 039 */       case 1:
/* 040 */         expand_switchCaseCode_1(expand_exprIsNull_0_0, expand_expr_0_0);
/* 041 */         break;
/* 042 */
/* 043 */       case 2:
/* 044 */         expand_switchCaseCode_2(expand_exprIsNull_0_0, expand_expr_0_0);
/* 045 */         break;
/* 046 */
/* 047 */       case 3:
/* 048 */         expand_switchCaseCode_3(expand_exprIsNull_0_0, expand_expr_0_0);
/* 049 */         break;
/* 050 */       }
/* 051 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[33] /* numOutputRows */).add(1);
/* 052 */
/* 053 */       do {
/* 054 */         boolean filter_value_2 = !expand_resultIsNull_0;
/* 055 */         if (!filter_value_2) continue;
/* 056 */
...
```

### Why are the changes needed?

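To generate better code. By default, the HotSpot JVM does not JIT-compile methods whose bytecode size exceeds `8000` (`CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT`), so keeping the generated methods well below that limit avoids falling back to slow interpreted execution.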

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

GitHub Actions (GA) checks passed.

Closes #32457 from maropu/splitSwitchCode.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
maropu authored and viirya committed May 14, 2021
1 parent 160b3be commit 8fa739f
Showing 1 changed file with 66 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf

/**
* Apply all of the GroupExpressions to every input row, hence we will get
Expand Down Expand Up @@ -152,40 +152,82 @@ case class ExpandExec(
// This column is the same across all output rows. Just generate code for it here.
BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val code = code"""
|boolean $isNull = true;
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
| ${CodeGenerator.defaultValue(firstExpr.dataType)};
""".stripMargin
val isNull = ctx.addMutableState(
CodeGenerator.JAVA_BOOLEAN,
"resultIsNull",
v => s"$v = true;")
val value = ctx.addMutableState(
CodeGenerator.javaType(firstExpr.dataType),
"resultValue",
v => s"$v = ${CodeGenerator.defaultValue(firstExpr.dataType)};")

ExprCode(
code,
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, firstExpr.dataType))
}
}

// Part 2: switch/case statements
val cases = projections.zipWithIndex.map { case (exprs, row) =>
var updateCode = ""
for (col <- exprs.indices) {
val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
if (!sameOutput(col)) {
val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
updateCode +=
val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq)
val exprCode = boundExpr.genCode(ctx)
val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1
Some(((col, exprCode), inputVars))
} else {
None
}
}.unzip

val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _)
(row, exprCodesWithIndices, inputVars.toSeq)
}

val updateCodes = switchCaseExprs.map { case (_, exprCodes, _) =>
exprCodes.map { case (col, ev) =>
s"""
|${ev.code}
|${outputColumns(col).isNull} = ${ev.isNull};
|${outputColumns(col).value} = ${ev.value};
""".stripMargin
}.mkString("\n")
}

val splitThreshold = SQLConf.get.methodSplitThreshold
val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) {
switchCaseExprs.zip(updateCodes).map { case ((row, _, inputVars), updateCode) =>
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars)
val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) {
val switchCaseFunc = ctx.freshName("switchCaseCode")
val argList = inputVars.map { v =>
s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
}
ctx.addNewFunction(switchCaseFunc,
s"""
|${ev.code}
|${outputColumns(col).isNull} = ${ev.isNull};
|${outputColumns(col).value} = ${ev.value};
""".stripMargin
|private void $switchCaseFunc(${argList.mkString(", ")}) {
| $updateCode
|}
""".stripMargin)

s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});"
} else {
updateCode
}
s"""
|case $row:
| $maybeSplitUpdateCode
| break;
""".stripMargin
}
} else {
switchCaseExprs.map(_._1).zip(updateCodes).map { case (row, updateCode) =>
s"""
|case $row:
| $updateCode
| break;
""".stripMargin
}

s"""
|case $row:
| ${updateCode.trim}
| break;
""".stripMargin
}

val numOutput = metricTerm(ctx, "numOutputRows")
Expand Down

0 comments on commit 8fa739f

Please sign in to comment.