diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index d5027ff6ad23f..5ecee28a1f0b8 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -131,13 +131,6 @@ - - org.scalatest - scalatest-maven-plugin - - -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5c68f9ffc691c..228f4b756c8b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -988,7 +988,7 @@ case class ScalaUDF( val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, converterTerm, - s"$converterTerm = ($converterClassName)$typeConvertersClassName" + + s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") converterTerm @@ -1005,7 +1005,7 @@ case class ScalaUDF( // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") ctx.addMutableState(converterClassName, catalystConverterTerm, - s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1019,7 +1019,7 @@ case class ScalaUDF( val funcTerm = ctx.freshName("udf") ctx.addMutableState(funcClassName, funcTerm, - s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") + s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4954cf8bc1177..f8da78b5f5e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -113,7 +113,7 @@ class CodegenContext { val idx = references.length references += obj val clsName = Option(className).getOrElse(obj.getClass.getName) - addMutableState(clsName, term, s"$term = ($clsName) references[$idx];") + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") term } @@ -202,6 +202,16 @@ class CodegenContext { partitionInitializationStatements.mkString("\n") } + /** + * Holding all the functions those will be added into generated class. + */ + val addedFunctions: mutable.Map[String, String] = + mutable.Map.empty[String, String] + + def addNewFunction(funcName: String, funcCode: String): Unit = { + addedFunctions += ((funcName, funcCode)) + } + /** * Holds expressions that are equivalent. Used to perform subexpression elimination * during codegen. @@ -223,118 +233,10 @@ class CodegenContext { // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] - private val outerClassName = "OuterClass" - - /** - * Holds the class and instance names to be generated, where `OuterClass` is a placeholder - * standing for whichever class is generated as the outermost class and which will contain any - * nested sub-classes. All other classes and instance names in this list will represent private, - * nested sub-classes. - */ - private val classes: mutable.ListBuffer[(String, String)] = - mutable.ListBuffer[(String, String)](outerClassName -> null) - - // A map holding the current size in bytes of each class to be generated. - private val classSize: mutable.Map[String, Int] = - mutable.Map[String, Int](outerClassName -> 0) - - // Nested maps holding function names and their code belonging to each class. - private val classFunctions: mutable.Map[String, mutable.Map[String, String]] = - mutable.Map(outerClassName -> mutable.Map.empty[String, String]) - - // Returns the size of the most recently added class. - private def currClassSize(): Int = classSize(classes.head._1) - - // Returns the class name and instance name for the most recently added class. - private def currClass(): (String, String) = classes.head - - // Adds a new class. Requires the class' name, and its instance name. - private def addClass(className: String, classInstance: String): Unit = { - classes.prepend(className -> classInstance) - classSize += className -> 0 - classFunctions += className -> mutable.Map.empty[String, String] + def declareAddedFunctions(): String = { + addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } - /** - * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the - * function will be inlined into a new private, nested class, and a instance-qualified name for - * the function will be returned. Otherwise, the function will be inined to the `OuterClass` the - * simple `funcName` will be returned. - * - * @param funcName the class-unqualified name of the function - * @param funcCode the body of the function - * @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This - * can be necessary when a function is declared outside of the context - * it is eventually referenced and a returned qualified function name - * cannot otherwise be accessed. - * @return the name of the function, qualified by class if it will be inlined to a private, - * nested sub-class - */ - def addNewFunction( - funcName: String, - funcCode: String, - inlineToOuterClass: Boolean = false): String = { - // The number of named constants that can exist in the class is limited by the Constant Pool - // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a - // threshold of 1600k bytes to determine when a function should be inlined to a private, nested - // sub-class. - val (className, classInstance) = if (inlineToOuterClass) { - outerClassName -> "" - } else if (currClassSize > 1600000) { - val className = freshName("NestedClass") - val classInstance = freshName("nestedClassInstance") - - addClass(className, classInstance) - - className -> classInstance - } else { - currClass() - } - - classSize(className) += funcCode.length - classFunctions(className) += funcName -> funcCode - - if (className == outerClassName) { - funcName - } else { - - s"$classInstance.$funcName" - } - } - - /** - * Instantiates all nested, private sub-classes as objects to the `OuterClass` - */ - private[sql] def initNestedClasses(): String = { - // Nested, private sub-classes have no mutable state (though they do reference the outer class' - // mutable state), so we declare and initialize them inline to the OuterClass. - classes.filter(_._1 != outerClassName).map { - case (className, classInstance) => - s"private $className $classInstance = new $className();" - }.mkString("\n") - } - - /** - * Declares all function code that should be inlined to the `OuterClass`. - */ - private[sql] def declareAddedFunctions(): String = { - classFunctions(outerClassName).values.mkString("\n") - } - - /** - * Declares all nested, private sub-classes and the function code that should be inlined to them. - */ - private[sql] def declareNestedClasses(): String = { - classFunctions.filterKeys(_ != outerClassName).map { - case (className, functions) => - s""" - |private class $className { - | ${functions.values.mkString("\n")} - |} - """.stripMargin - } - }.mkString("\n") - final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -654,7 +556,8 @@ class CodegenContext { return 0; } """ - s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -670,7 +573,8 @@ class CodegenContext { return 0; } """ - s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => @@ -785,6 +689,7 @@ class CodegenContext { |} """.stripMargin addNewFunction(name, code) + name } foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) @@ -868,6 +773,8 @@ class CodegenContext { |} """.stripMargin + addNewFunction(fnName, fn) + // Add a state and a mapping of the common subexpressions that are associate with this // state. Adding this expression to subExprEliminationExprMap means it will call `fn` // when it is code generated. This decision should be a cost based one. @@ -885,7 +792,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" + subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 635766835029b..4d732445544a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"$isNull = true;") + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") ctx.addMutableState(ctx.javaType(e.dataType), value, - s"$value = ${ctx.defaultValue(e.dataType)};") + s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - $isNull = ${ev.isNull}; - $value = ${ev.value}; + this.$isNull = ${ev.isNull}; + this.$value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, - s"$value = ${ctx.defaultValue(e.dataType)};") + s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - $value = ${ev.value}; + this.$value = ${ev.value}; """ } } @@ -87,7 +87,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(index).map { case (e, i) => - val ev = ExprCode("", s"isNull_$i", s"value_$i") + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } @@ -135,9 +135,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP $allUpdates return mutableRow; } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index a31943255b995..f7fc2d54a047b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -179,9 +179,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR $comparisons return 0; } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index b400783bb5e55..dcd1ed96a298e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -72,9 +72,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${eval.code} return !${eval.isNull} && ${eval.value}; } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f708aeff2b146..b1cb6edefb852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"$values = null;") + ctx.addMutableState("Object[]", values, s"this.$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -65,10 +65,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - $values = new Object[${schema.length}]; + this.$values = new Object[${schema.length}]; $allFields final InternalRow $output = new $rowClass($values); - $values = null; + this.$values = null; """ ExprCode(code, "false", output) @@ -184,9 +184,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $allExpressions return mutableRow; } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index febfe3124f2bd..b358102d914bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -82,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") ctx.addMutableState(rowWriterClass, rowWriter, - s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -182,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, - s"$arrayWriter = new $arrayWriterClass();") + s"this.$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val element = ctx.freshName("element") @@ -321,7 +321,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, holder, - s"$holder = new $holderClass($result, ${numVarLenFields * 32});") + s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" @@ -402,9 +402,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${eval.code.trim} return ${eval.value}; } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 98c4cbee38dee..b6675a84ece48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -93,7 +93,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[${numElements}];") + s"this.$arrayName = new Object[${numElements}];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -340,7 +340,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"$values = null;") + ctx.addMutableState("Object[]", values, s"this.$values = null;") ev.copy(code = s""" $values = new Object[${valExprs.size}];""" + @@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc }) + s""" final InternalRow ${ev.value} = new $rowClass($values); - $values = null; + this.$values = null; """, isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ae8efb673f91c..ee365fe636614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -131,8 +131,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi | $globalValue = ${ev.value}; |} """.stripMargin - val fullFuncName = ctx.addNewFunction(funcName, funcBody) - (fullFuncName, globalIsNull, globalValue) + ctx.addNewFunction(funcName, funcBody) + (funcName, globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index baa5ba68dcb30..e84796f2edad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -181,7 +181,7 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") + ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => @@ -190,7 +190,7 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\n$rowData[$row] = ${eval.value};" + s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" }) // Create the collection. @@ -198,7 +198,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") + s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2bd752c82e6c1..1a202ecf745c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -981,7 +981,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val code = s""" ${instanceGen.code} - ${javaBeanInstance} = ${instanceGen.value}; + this.${javaBeanInstance} = ${instanceGen.value}; if (!${instanceGen.isNull}) { $initializeCode } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index d7ba57a697b08..b69b74b4240bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -33,10 +33,10 @@ class GeneratedProjectionSuite extends SparkFunSuite { test("generated projections on wider table") { val N = 1000 - val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) val wideRow2 = new GenericInternalRow( - (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) val schema2 = StructType((1 to N).map(i => StructField("", StringType))) val joined = new JoinedRow(wideRow1, wideRow2) val joinedSchema = StructType(schema1 ++ schema2) @@ -48,12 +48,12 @@ class GeneratedProjectionSuite extends SparkFunSuite { val unsafeProj = UnsafeProjection.create(nestedSchema) val unsafe: UnsafeRow = unsafeProj(nested) (0 until N).foreach { i => - val s = UTF8String.fromString(i.toString) - assert(i === unsafe.getInt(i + 2)) + val s = UTF8String.fromString((i + 1).toString) + assert(i + 1 === unsafe.getInt(i + 2)) assert(s === unsafe.getUTF8String(i + 2 + N)) - assert(i === unsafe.getStruct(0, N * 2).getInt(i)) + assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i === unsafe.getStruct(1, N * 2).getInt(i)) + assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) } @@ -62,63 +62,13 @@ class GeneratedProjectionSuite extends SparkFunSuite { val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => - val s = UTF8String.fromString(i.toString) - assert(i === result.getInt(i + 2)) + val r = i + 1 + val s = UTF8String.fromString((i + 1).toString) + assert(r === result.getInt(i + 2)) assert(s === result.getUTF8String(i + 2 + N)) - assert(i === result.getStruct(0, N * 2).getInt(i)) + assert(r === result.getStruct(0, N * 2).getInt(i)) assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i === result.getStruct(1, N * 2).getInt(i)) - assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) - } - - // test generated MutableProjection - val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => - BoundReference(i, f.dataType, true) - } - val mutableProj = GenerateMutableProjection.generate(exprs) - val row1 = mutableProj(result) - assert(result === row1) - val row2 = mutableProj(result) - assert(result === row2) - } - - test("generated projections on wider table requiring class-splitting") { - val N = 4000 - val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) - val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) - val wideRow2 = new GenericInternalRow( - (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) - val schema2 = StructType((1 to N).map(i => StructField("", StringType))) - val joined = new JoinedRow(wideRow1, wideRow2) - val joinedSchema = StructType(schema1 ++ schema2) - val nested = new JoinedRow(InternalRow(joined, joined), joined) - val nestedSchema = StructType( - Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) - - // test generated UnsafeProjection - val unsafeProj = UnsafeProjection.create(nestedSchema) - val unsafe: UnsafeRow = unsafeProj(nested) - (0 until N).foreach { i => - val s = UTF8String.fromString(i.toString) - assert(i === unsafe.getInt(i + 2)) - assert(s === unsafe.getUTF8String(i + 2 + N)) - assert(i === unsafe.getStruct(0, N * 2).getInt(i)) - assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i === unsafe.getStruct(1, N * 2).getInt(i)) - assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) - } - - // test generated SafeProjection - val safeProj = FromUnsafeProjection(nestedSchema) - val result = safeProj(unsafe) - // Can't compare GenericInternalRow with JoinedRow directly - (0 until N).foreach { i => - val s = UTF8String.fromString(i.toString) - assert(i === result.getInt(i + 2)) - assert(s === result.getUTF8String(i + 2 + N)) - assert(i === result.getStruct(0, N * 2).getInt(i)) - assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i === result.getStruct(1, N * 2).getInt(i)) + assert(r === result.getStruct(1, N * 2).getInt(i)) assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 74a47da2deef2..e86116680a57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -93,7 +93,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val nextBatch = ctx.freshName("nextBatch") - val nextBatchFuncName = ctx.addNewFunction(nextBatch, + ctx.addNewFunction(nextBatch, s""" |private void $nextBatch() throws java.io.IOException { | long getBatchStart = System.nanoTime(); @@ -121,7 +121,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } s""" |if ($batch == null) { - | $nextBatchFuncName(); + | $nextBatch(); |} |while ($batch != null) { | int $numRows = $batch.numRows(); @@ -133,7 +133,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | } | $idx = $numRows; | $batch = null; - | $nextBatchFuncName(); + | $nextBatch(); |} |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); |$scanTimeTotalNs = 0; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ff71fd4dc7bb7..f98ae82574d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -141,7 +141,7 @@ case class SortExec( ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") val addToSorter = ctx.freshName("addToSorter") - val addToSorterFuncName = ctx.addNewFunction(addToSorter, + ctx.addNewFunction(addToSorter, s""" | private void $addToSorter() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} @@ -160,7 +160,7 @@ case class SortExec( s""" | if ($needToSort) { | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorterFuncName(); + | $addToSorter(); | $sortedIterator = $sorterVariable.sort(); | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c7e9d25bd2cc0..c1e1a631c677e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -357,9 +357,6 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co protected void processNext() throws java.io.IOException { ${code.trim} } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} } """.trim diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index bf7fa07765b9a..68c8e6ce62cbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -209,7 +209,7 @@ case class HashAggregateExec( } val doAgg = ctx.freshName("doAggregateWithoutKey") - val doAggFuncName = ctx.addNewFunction(doAgg, + ctx.addNewFunction(doAgg, s""" | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer @@ -226,7 +226,7 @@ case class HashAggregateExec( | while (!$initAgg) { | $initAgg = true; | long $beforeAgg = System.nanoTime(); - | $doAggFuncName(); + | $doAgg(); | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); | | // output the result @@ -592,7 +592,7 @@ case class HashAggregateExec( } else "" } - val doAggFuncName = ctx.addNewFunction(doAgg, + ctx.addNewFunction(doAgg, s""" ${generateGenerateCode} private void $doAgg() throws java.io.IOException { @@ -672,7 +672,7 @@ case class HashAggregateExec( if (!$initAgg) { $initAgg = true; long $beforeAgg = System.nanoTime(); - $doAggFuncName(); + $doAgg(); $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index bb24489ade1b3..bd7a5c5d914c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -281,8 +281,10 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") ctx.copyResult = true + ctx.addMutableState(s"$samplerClass", sampler, + s"$initSampler();") - val initSamplerFuncName = ctx.addNewFunction(initSampler, + ctx.addNewFunction(initSampler, s""" | private void $initSampler() { | $sampler = new $samplerClass($upperBound - $lowerBound, false); @@ -297,8 +299,6 @@ case class SampleExec( | } """.stripMargin.trim) - ctx.addMutableState(s"$samplerClass", sampler, s"$initSamplerFuncName();") - val samplingCount = ctx.freshName("samplingCount") s""" | int $samplingCount = $sampler.sample(); @@ -394,7 +394,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // The default size of a batch, which must be positive integer val batchSize = 1000 - val initRangeFuncName = ctx.addNewFunction("initRange", + ctx.addNewFunction("initRange", s""" | private void initRange(int idx) { | $BigInt index = $BigInt.valueOf(idx); @@ -451,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | // initialize Range | if (!$initTerm) { | $initTerm = true; - | $initRangeFuncName(partitionIndex); + | initRange(partitionIndex); | } | | while (true) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index a66e8e2b46e3d..14024d6c10558 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -128,7 +128,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } else { val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) - val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) => + var groupedAccessorsLength = 0 + groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => + groupedAccessorsLength += 1 val funcName = s"accessors$i" val funcCode = s""" |private void $funcName() { @@ -137,7 +139,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) => + groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { @@ -146,8 +148,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"), - extractorNames.map { extractorName => s"$extractorName();" }.mkString("\n")) + ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), + (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) } val codeBody = s""" @@ -222,9 +224,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } - - ${ctx.initNestedClasses()} - ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 8445c26eeee58..26fb6103953fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -478,7 +478,7 @@ case class SortMergeJoinExec( | } | return false; // unreachable |} - """.stripMargin, inlineToOuterClass = true) + """.stripMargin) (leftRow, matches) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 73a0f8735ed45..757fe2185d302 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -75,7 +75,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { protected boolean stopEarly() { return $stopEarly; } - """, inlineToOuterClass = true) + """) val countTerm = ctx.freshName("count") ctx.addMutableState("int", countTerm, s"$countTerm = 0;") s"""