From 8fa739fb9d3bd49efa7df7525df18b7111b0131e Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 13 May 2021 17:53:46 -0700
Subject: [PATCH] [SPARK-35329][SQL] Split generated switch code into pieces in
 ExpandExec

### What changes were proposed in this pull request?

This PR intends to split the generated switch code into smaller pieces in `ExpandExec`. In the current master, even a simple query like the one below generates a large method whose size (`maxMethodCodeSize:7448`) is close to `8000` (`CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT`):
```
scala> val df = Seq(("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
scala> val rdf = df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value").orderBy($"window.start".asc, $"value".desc).select("value")
scala> sql("SET spark.sql.adaptive.enabled=false")
scala> import org.apache.spark.sql.execution.debug._
scala> rdf.debugCodegen

Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 (maxMethodCodeSize:7448; maxConstantPoolSize:189(0.29% used); numInnerClasses:0) ==
                                    ^^^^
*(1) Project [window#34.start AS _gen_alias_39#39, value#11]
+- *(1) Filter ((isnotnull(window#34) AND (cast(time#10 as timestamp) >= window#34.start)) AND (cast(time#10 as timestamp) < window#34.end))
   +- *(1) Expand [List(named_struct(start, precisetimestampcon...

/* 028 */   private void expand_doConsume_0(InternalRow localtablescan_row_0, UTF8String expand_expr_0_0, boolean expand_exprIsNull_0_0, int expand_expr_1_0) throws java.io.IOException {
/* 029 */     boolean expand_isNull_0 = true;
/* 030 */     InternalRow expand_value_0 =
/* 031 */     null;
/* 032 */     for (int expand_i_0 = 0; expand_i_0 < 4; expand_i_0 ++) {
/* 033 */       switch (expand_i_0) {
/* 034 */         case 0:
                    (too many code lines)
/* 517 */           break;
/* 518 */
/* 519 */         case 1:
                    (too many code lines)
/* 1002 */          break;
/* 1003 */
/* 1004 */        case 2:
                    (too many code lines)
/* 1487 */          break;
/* 1488 */
/* 1489 */        case 3:
                    (too many code lines)
/* 1972 */          break;
/* 1973 */      }
/* 1974 */      ((org.apache.spark.sql.execution.metric.SQLMetric) references[33] /* numOutputRows */).add(1);
/* 1975 */
/* 1976 */      do {
/* 1977 */        boolean filter_value_2 = !expand_isNull_0;
/* 1978 */        if (!filter_value_2) continue;
```
The fix in this PR makes the generated method much smaller:
```
Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 (maxMethodCodeSize:1713; maxConstantPoolSize:210(0.32% used); numInnerClasses:0) ==
                                    ^^^^
*(1) Project [window#17.start AS _gen_alias_32#32, value#11]
+- *(1) Filter ((isnotnull(window#17) AND (cast(time#10 as timestamp) >= window#17.start)) AND (cast(time#10 as timestamp) < window#17.end))
   +- *(1) Expand [List(named_struct(start, precisetimestampcon...
/* 032 */   private void expand_doConsume_0(InternalRow localtablescan_row_0, UTF8String expand_expr_0_0, boolean expand_exprIsNull_0_0, int expand_expr_1_0) throws java.io.IOException {
/* 033 */     for (int expand_i_0 = 0; expand_i_0 < 4; expand_i_0 ++) {
/* 034 */       switch (expand_i_0) {
/* 035 */         case 0:
/* 036 */           expand_switchCaseCode_0(expand_exprIsNull_0_0, expand_expr_0_0);
/* 037 */           break;
/* 038 */
/* 039 */         case 1:
/* 040 */           expand_switchCaseCode_1(expand_exprIsNull_0_0, expand_expr_0_0);
/* 041 */           break;
/* 042 */
/* 043 */         case 2:
/* 044 */           expand_switchCaseCode_2(expand_exprIsNull_0_0, expand_expr_0_0);
/* 045 */           break;
/* 046 */
/* 047 */         case 3:
/* 048 */           expand_switchCaseCode_3(expand_exprIsNull_0_0, expand_expr_0_0);
/* 049 */           break;
/* 050 */       }
/* 051 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[33] /* numOutputRows */).add(1);
/* 052 */
/* 053 */       do {
/* 054 */         boolean filter_value_2 = !expand_resultIsNull_0;
/* 055 */         if (!filter_value_2) continue;
/* 056 */         ...
```

### Why are the changes needed?

For better generated code.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

GA passed.

Closes #32457 from maropu/splitSwitchCode.

Authored-by: Takeshi Yamamuro
Signed-off-by: Liang-Chi Hsieh
---
 .../spark/sql/execution/ExpandExec.scala      | 90 ++++++++++++++-----
 1 file changed, 66 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 3fd653130e57c..c087fdf5f962b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -21,9 +21,9 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Apply all of the GroupExpressions to every input row, hence we will get
@@ -152,40 +152,82 @@ case class ExpandExec(
       // This column is the same across all output rows. Just generate code for it here.
         BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
       } else {
-        val isNull = ctx.freshName("isNull")
-        val value = ctx.freshName("value")
-        val code = code"""
-          |boolean $isNull = true;
-          |${CodeGenerator.javaType(firstExpr.dataType)} $value =
-          |  ${CodeGenerator.defaultValue(firstExpr.dataType)};
-        """.stripMargin
+        val isNull = ctx.addMutableState(
+          CodeGenerator.JAVA_BOOLEAN,
+          "resultIsNull",
+          v => s"$v = true;")
+        val value = ctx.addMutableState(
+          CodeGenerator.javaType(firstExpr.dataType),
+          "resultValue",
+          v => s"$v = ${CodeGenerator.defaultValue(firstExpr.dataType)};")
+
         ExprCode(
-          code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, firstExpr.dataType))
+          JavaCode.isNullVariable(isNull),
+          JavaCode.variable(value, firstExpr.dataType))
       }
     }
 
     // Part 2: switch/case statements
-    val cases = projections.zipWithIndex.map { case (exprs, row) =>
-      var updateCode = ""
-      for (col <- exprs.indices) {
+    val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
+      val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
         if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
-          updateCode +=
+          val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq)
+          val exprCode = boundExpr.genCode(ctx)
+          val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1
+          Some(((col, exprCode), inputVars))
+        } else {
+          None
+        }
+      }.unzip
+
+      val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _)
+      (row, exprCodesWithIndices, inputVars.toSeq)
+    }
+
+    val updateCodes = switchCaseExprs.map { case (_, exprCodes, _) =>
+      exprCodes.map { case (col, ev) =>
+        s"""
+           |${ev.code}
+           |${outputColumns(col).isNull} = ${ev.isNull};
+           |${outputColumns(col).value} = ${ev.value};
+         """.stripMargin
+      }.mkString("\n")
+    }
+
+    val splitThreshold = SQLConf.get.methodSplitThreshold
+    val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) {
+      switchCaseExprs.zip(updateCodes).map { case ((row, _, inputVars), updateCode) =>
+        val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars)
+        val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) {
+          val switchCaseFunc = ctx.freshName("switchCaseCode")
+          val argList = inputVars.map { v =>
+            s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
+          }
+          ctx.addNewFunction(switchCaseFunc,
             s"""
-               |${ev.code}
-               |${outputColumns(col).isNull} = ${ev.isNull};
-               |${outputColumns(col).value} = ${ev.value};
-             """.stripMargin
+               |private void $switchCaseFunc(${argList.mkString(", ")}) {
+               |  $updateCode
+               |}
+             """.stripMargin)
+
+          s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});"
+        } else {
+          updateCode
         }
+        s"""
+           |case $row:
+           |  $maybeSplitUpdateCode
+           |  break;
+         """.stripMargin
+      }
+    } else {
+      switchCaseExprs.map(_._1).zip(updateCodes).map { case (row, updateCode) =>
+        s"""
+           |case $row:
+           |  $updateCode
+           |  break;
+         """.stripMargin
       }
-
-      s"""
-         |case $row:
-         |  ${updateCode.trim}
-         |  break;
-       """.stripMargin
     }
 
     val numOutput = metricTerm(ctx, "numOutputRows")
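
For readers who want the core pattern without the codegen plumbing, here is a minimal, self-contained sketch of what the diff does: accumulate the per-case update code and, once the combined size crosses the configured threshold, move each case body into its own private method and call it from the `case` arm, passing along the local variables the body reads. The names below (`SwitchSplitSketch`, `CaseBody`, `genSwitch`) are illustrative only and do not exist in Spark; the real logic lives in `ExpandExec.doConsume` and uses `ctx.addNewFunction` and `SQLConf.get.methodSplitThreshold`.
```
import scala.collection.mutable.ArrayBuffer

object SwitchSplitSketch {
  // One switch case: its index, the Java statements it runs, and the local
  // variables those statements read, as (javaType, name) pairs. The inputs
  // become method parameters when the body is moved into a helper.
  case class CaseBody(row: Int, body: String, inputs: Seq[(String, String)])

  // Returns the switch statement plus any extracted helper methods.
  def genSwitch(cases: Seq[CaseBody], splitThreshold: Int): (String, Seq[String]) = {
    val helpers = ArrayBuffer.empty[String]
    // Mirror the PR's check: split only when the combined case bodies are
    // large enough to push the enclosing method toward the JIT's size limit.
    val split = cases.map(_.body.length).sum > splitThreshold
    val arms = cases.map { c =>
      val stmt = if (split) {
        val func = s"switchCaseCode_${c.row}"
        val params = c.inputs.map { case (t, n) => s"$t $n" }.mkString(", ")
        helpers += s"private void $func($params) {\n  ${c.body}\n}"
        s"$func(${c.inputs.map(_._2).mkString(", ")});"
      } else {
        c.body
      }
      s"case ${c.row}:\n  $stmt\n  break;"
    }
    (s"switch (i) {\n${arms.mkString("\n")}\n}", helpers.toSeq)
  }
}
```
With a small input such as `genSwitch(Seq(CaseBody(0, "out_0 = in_0 + 1;", Seq("int" -> "in_0"))), splitThreshold = 1024)` the body stays inlined, while a large enough body yields `case 0: switchCaseCode_0(in_0); break;` plus the helper definition, which is the shape of the "after" generated code shown above.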
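
The `isValidParamLength` guard in the diff matters because the JVM caps a method at 255 parameter slots, where `long` and `double` occupy two slots each and `this` occupies one for instance methods (JVMS §4.3.3); a case body that reads too many locals therefore cannot be extracted and stays inlined. Below is a standalone sketch of that slot arithmetic, assuming only those JVM rules; Spark's actual checks are `CodeGenerator.calculateParamLengthFromExprValues` and `CodeGenerator.isValidParamLength`, and `ParamSlotSketch` is a hypothetical name.
```
object ParamSlotSketch {
  // JVM limit on a single method's parameter slots (JVMS §4.3.3).
  val MaxJvmMethodParamSlots = 255

  // "this" takes one slot in instance methods; long and double take two
  // slots each; every other parameter type takes one.
  def paramSlots(javaTypes: Seq[String]): Int =
    1 + javaTypes.map(t => if (t == "long" || t == "double") 2 else 1).sum

  // When this is false, the generated case body is kept inlined instead of
  // being extracted (the `else { updateCode }` branch in the diff above).
  def fitsInJvmMethod(javaTypes: Seq[String]): Boolean =
    paramSlots(javaTypes) <= MaxJvmMethodParamSlots
}
```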