diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index d5027ff6ad23f..5ecee28a1f0b8 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -131,13 +131,6 @@
-
- org.scalatest
- scalatest-maven-plugin
-
- -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
-
-
org.antlr
antlr4-maven-plugin
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 5c68f9ffc691c..228f4b756c8b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -988,7 +988,7 @@ case class ScalaUDF(
val converterTerm = ctx.freshName("converter")
val expressionIdx = ctx.references.size - 1
ctx.addMutableState(converterClassName, converterTerm,
- s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
+ s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
converterTerm
@@ -1005,7 +1005,7 @@ case class ScalaUDF(
// Generate codes used to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
ctx.addMutableState(converterClassName, catalystConverterTerm,
- s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
+ s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToCatalystConverter($scalaUDF.dataType());")
val resultTerm = ctx.freshName("result")
@@ -1019,7 +1019,7 @@ case class ScalaUDF(
val funcTerm = ctx.freshName("udf")
ctx.addMutableState(funcClassName, funcTerm,
- s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
+ s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
// codegen for children expressions
val evals = children.map(_.genCode(ctx))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4954cf8bc1177..f8da78b5f5e3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -113,7 +113,7 @@ class CodegenContext {
val idx = references.length
references += obj
val clsName = Option(className).getOrElse(obj.getClass.getName)
- addMutableState(clsName, term, s"$term = ($clsName) references[$idx];")
+ addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];")
term
}
@@ -202,6 +202,16 @@ class CodegenContext {
partitionInitializationStatements.mkString("\n")
}
+ /**
+ * Holding all the functions those will be added into generated class.
+ */
+ val addedFunctions: mutable.Map[String, String] =
+ mutable.Map.empty[String, String]
+
+ def addNewFunction(funcName: String, funcCode: String): Unit = {
+ addedFunctions += ((funcName, funcCode))
+ }
+
/**
* Holds expressions that are equivalent. Used to perform subexpression elimination
* during codegen.
@@ -223,118 +233,10 @@ class CodegenContext {
// The collection of sub-expression result resetting methods that need to be called on each row.
val subexprFunctions = mutable.ArrayBuffer.empty[String]
- private val outerClassName = "OuterClass"
-
- /**
- * Holds the class and instance names to be generated, where `OuterClass` is a placeholder
- * standing for whichever class is generated as the outermost class and which will contain any
- * nested sub-classes. All other classes and instance names in this list will represent private,
- * nested sub-classes.
- */
- private val classes: mutable.ListBuffer[(String, String)] =
- mutable.ListBuffer[(String, String)](outerClassName -> null)
-
- // A map holding the current size in bytes of each class to be generated.
- private val classSize: mutable.Map[String, Int] =
- mutable.Map[String, Int](outerClassName -> 0)
-
- // Nested maps holding function names and their code belonging to each class.
- private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
- mutable.Map(outerClassName -> mutable.Map.empty[String, String])
-
- // Returns the size of the most recently added class.
- private def currClassSize(): Int = classSize(classes.head._1)
-
- // Returns the class name and instance name for the most recently added class.
- private def currClass(): (String, String) = classes.head
-
- // Adds a new class. Requires the class' name, and its instance name.
- private def addClass(className: String, classInstance: String): Unit = {
- classes.prepend(className -> classInstance)
- classSize += className -> 0
- classFunctions += className -> mutable.Map.empty[String, String]
+ def declareAddedFunctions(): String = {
+ addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
}
- /**
- * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the
- * function will be inlined into a new private, nested class, and a instance-qualified name for
- * the function will be returned. Otherwise, the function will be inined to the `OuterClass` the
- * simple `funcName` will be returned.
- *
- * @param funcName the class-unqualified name of the function
- * @param funcCode the body of the function
- * @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This
- * can be necessary when a function is declared outside of the context
- * it is eventually referenced and a returned qualified function name
- * cannot otherwise be accessed.
- * @return the name of the function, qualified by class if it will be inlined to a private,
- * nested sub-class
- */
- def addNewFunction(
- funcName: String,
- funcCode: String,
- inlineToOuterClass: Boolean = false): String = {
- // The number of named constants that can exist in the class is limited by the Constant Pool
- // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a
- // threshold of 1600k bytes to determine when a function should be inlined to a private, nested
- // sub-class.
- val (className, classInstance) = if (inlineToOuterClass) {
- outerClassName -> ""
- } else if (currClassSize > 1600000) {
- val className = freshName("NestedClass")
- val classInstance = freshName("nestedClassInstance")
-
- addClass(className, classInstance)
-
- className -> classInstance
- } else {
- currClass()
- }
-
- classSize(className) += funcCode.length
- classFunctions(className) += funcName -> funcCode
-
- if (className == outerClassName) {
- funcName
- } else {
-
- s"$classInstance.$funcName"
- }
- }
-
- /**
- * Instantiates all nested, private sub-classes as objects to the `OuterClass`
- */
- private[sql] def initNestedClasses(): String = {
- // Nested, private sub-classes have no mutable state (though they do reference the outer class'
- // mutable state), so we declare and initialize them inline to the OuterClass.
- classes.filter(_._1 != outerClassName).map {
- case (className, classInstance) =>
- s"private $className $classInstance = new $className();"
- }.mkString("\n")
- }
-
- /**
- * Declares all function code that should be inlined to the `OuterClass`.
- */
- private[sql] def declareAddedFunctions(): String = {
- classFunctions(outerClassName).values.mkString("\n")
- }
-
- /**
- * Declares all nested, private sub-classes and the function code that should be inlined to them.
- */
- private[sql] def declareNestedClasses(): String = {
- classFunctions.filterKeys(_ != outerClassName).map {
- case (className, functions) =>
- s"""
- |private class $className {
- | ${functions.values.mkString("\n")}
- |}
- """.stripMargin
- }
- }.mkString("\n")
-
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -654,7 +556,8 @@ class CodegenContext {
return 0;
}
"""
- s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
+ addNewFunction(compareFunc, funcCode)
+ s"this.$compareFunc($c1, $c2)"
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
@@ -670,7 +573,8 @@ class CodegenContext {
return 0;
}
"""
- s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
+ addNewFunction(compareFunc, funcCode)
+ s"this.$compareFunc($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
@@ -785,6 +689,7 @@ class CodegenContext {
|}
""".stripMargin
addNewFunction(name, code)
+ name
}
foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})"))
@@ -868,6 +773,8 @@ class CodegenContext {
|}
""".stripMargin
+ addNewFunction(fnName, fn)
+
// Add a state and a mapping of the common subexpressions that are associate with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
@@ -885,7 +792,7 @@ class CodegenContext {
addMutableState(javaType(expr.dataType), value,
s"$value = ${defaultValue(expr.dataType)};")
- subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
+ subexprFunctions += s"$fnName($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
e.foreach(subExprEliminationExprs.put(_, state))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 635766835029b..4d732445544a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
if (e.nullable) {
val isNull = s"isNull_$i"
val value = s"value_$i"
- ctx.addMutableState("boolean", isNull, s"$isNull = true;")
+ ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"$value = ${ctx.defaultValue(e.dataType)};")
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
${ev.code}
- $isNull = ${ev.isNull};
- $value = ${ev.value};
+ this.$isNull = ${ev.isNull};
+ this.$value = ${ev.value};
"""
} else {
val value = s"value_$i"
ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"$value = ${ctx.defaultValue(e.dataType)};")
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
s"""
${ev.code}
- $value = ${ev.value};
+ this.$value = ${ev.value};
"""
}
}
@@ -87,7 +87,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val updates = validExpr.zip(index).map {
case (e, i) =>
- val ev = ExprCode("", s"isNull_$i", s"value_$i")
+ val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i")
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
@@ -135,9 +135,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
$allUpdates
return mutableRow;
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index a31943255b995..f7fc2d54a047b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -179,9 +179,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
$comparisons
return 0;
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}"""
val code = CodeFormatter.stripOverlappingComments(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index b400783bb5e55..dcd1ed96a298e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -72,9 +72,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
${eval.code}
return !${eval.isNull} && ${eval.value};
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}"""
val code = CodeFormatter.stripOverlappingComments(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f708aeff2b146..b1cb6edefb852 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
// These expressions could be split into multiple functions
- ctx.addMutableState("Object[]", values, s"$values = null;")
+ ctx.addMutableState("Object[]", values, s"this.$values = null;")
val rowClass = classOf[GenericInternalRow].getName
@@ -65,10 +65,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val allFields = ctx.splitExpressions(tmp, fieldWriters)
val code = s"""
final InternalRow $tmp = $input;
- $values = new Object[${schema.length}];
+ this.$values = new Object[${schema.length}];
$allFields
final InternalRow $output = new $rowClass($values);
- $values = null;
+ this.$values = null;
"""
ExprCode(code, "false", output)
@@ -184,9 +184,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
$allExpressions
return mutableRow;
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index febfe3124f2bd..b358102d914bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -82,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val rowWriterClass = classOf[UnsafeRowWriter].getName
val rowWriter = ctx.freshName("rowWriter")
ctx.addMutableState(rowWriterClass, rowWriter,
- s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+ s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
@@ -182,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
val arrayWriter = ctx.freshName("arrayWriter")
ctx.addMutableState(arrayWriterClass, arrayWriter,
- s"$arrayWriter = new $arrayWriterClass();")
+ s"this.$arrayWriter = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
val element = ctx.freshName("element")
@@ -321,7 +321,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val holder = ctx.freshName("holder")
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, holder,
- s"$holder = new $holderClass($result, ${numVarLenFields * 32});")
+ s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
val resetBufferHolder = if (numVarLenFields == 0) {
""
@@ -402,9 +402,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${eval.code.trim}
return ${eval.value};
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 98c4cbee38dee..b6675a84ece48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -93,7 +93,7 @@ private [sql] object GenArrayData {
if (!ctx.isPrimitiveType(elementType)) {
val genericArrayClass = classOf[GenericArrayData].getName
ctx.addMutableState("Object[]", arrayName,
- s"$arrayName = new Object[${numElements}];")
+ s"this.$arrayName = new Object[${numElements}];")
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
@@ -340,7 +340,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"$values = null;")
+ ctx.addMutableState("Object[]", values, s"this.$values = null;")
ev.copy(code = s"""
$values = new Object[${valExprs.size}];""" +
@@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
}) +
s"""
final InternalRow ${ev.value} = new $rowClass($values);
- $values = null;
+ this.$values = null;
""", isNull = "false")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index ae8efb673f91c..ee365fe636614 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -131,8 +131,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
| $globalValue = ${ev.value};
|}
""".stripMargin
- val fullFuncName = ctx.addNewFunction(funcName, funcBody)
- (fullFuncName, globalIsNull, globalValue)
+ ctx.addNewFunction(funcName, funcBody)
+ (funcName, globalIsNull, globalValue)
}
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index baa5ba68dcb30..e84796f2edad0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -181,7 +181,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Rows - we write these into an array.
val rowData = ctx.freshName("rows")
- ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];")
+ ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
val values = children.tail
val dataTypes = values.take(numFields).map(_.dataType)
val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
@@ -190,7 +190,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
if (index < values.length) values(index) else Literal(null, dataTypes(col))
}
val eval = CreateStruct(fields).genCode(ctx)
- s"${eval.code}\n$rowData[$row] = ${eval.value};"
+ s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
})
// Create the collection.
@@ -198,7 +198,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
ctx.addMutableState(
s"$wrapperClass",
ev.value,
- s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);")
+ s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
ev.copy(code = code, isNull = "false")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 2bd752c82e6c1..1a202ecf745c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -981,7 +981,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
val code = s"""
${instanceGen.code}
- ${javaBeanInstance} = ${instanceGen.value};
+ this.${javaBeanInstance} = ${instanceGen.value};
if (!${instanceGen.isNull}) {
$initializeCode
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index d7ba57a697b08..b69b74b4240bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -33,10 +33,10 @@ class GeneratedProjectionSuite extends SparkFunSuite {
test("generated projections on wider table") {
val N = 1000
- val wideRow1 = new GenericInternalRow((0 until N).toArray[Any])
+ val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
val wideRow2 = new GenericInternalRow(
- (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+ (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
val joined = new JoinedRow(wideRow1, wideRow2)
val joinedSchema = StructType(schema1 ++ schema2)
@@ -48,12 +48,12 @@ class GeneratedProjectionSuite extends SparkFunSuite {
val unsafeProj = UnsafeProjection.create(nestedSchema)
val unsafe: UnsafeRow = unsafeProj(nested)
(0 until N).foreach { i =>
- val s = UTF8String.fromString(i.toString)
- assert(i === unsafe.getInt(i + 2))
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(i + 1 === unsafe.getInt(i + 2))
assert(s === unsafe.getUTF8String(i + 2 + N))
- assert(i === unsafe.getStruct(0, N * 2).getInt(i))
+ assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
- assert(i === unsafe.getStruct(1, N * 2).getInt(i))
+ assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
}
@@ -62,63 +62,13 @@ class GeneratedProjectionSuite extends SparkFunSuite {
val result = safeProj(unsafe)
// Can't compare GenericInternalRow with JoinedRow directly
(0 until N).foreach { i =>
- val s = UTF8String.fromString(i.toString)
- assert(i === result.getInt(i + 2))
+ val r = i + 1
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(r === result.getInt(i + 2))
assert(s === result.getUTF8String(i + 2 + N))
- assert(i === result.getStruct(0, N * 2).getInt(i))
+ assert(r === result.getStruct(0, N * 2).getInt(i))
assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
- assert(i === result.getStruct(1, N * 2).getInt(i))
- assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
- }
-
- // test generated MutableProjection
- val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
- BoundReference(i, f.dataType, true)
- }
- val mutableProj = GenerateMutableProjection.generate(exprs)
- val row1 = mutableProj(result)
- assert(result === row1)
- val row2 = mutableProj(result)
- assert(result === row2)
- }
-
- test("generated projections on wider table requiring class-splitting") {
- val N = 4000
- val wideRow1 = new GenericInternalRow((0 until N).toArray[Any])
- val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
- val wideRow2 = new GenericInternalRow(
- (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
- val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
- val joined = new JoinedRow(wideRow1, wideRow2)
- val joinedSchema = StructType(schema1 ++ schema2)
- val nested = new JoinedRow(InternalRow(joined, joined), joined)
- val nestedSchema = StructType(
- Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
-
- // test generated UnsafeProjection
- val unsafeProj = UnsafeProjection.create(nestedSchema)
- val unsafe: UnsafeRow = unsafeProj(nested)
- (0 until N).foreach { i =>
- val s = UTF8String.fromString(i.toString)
- assert(i === unsafe.getInt(i + 2))
- assert(s === unsafe.getUTF8String(i + 2 + N))
- assert(i === unsafe.getStruct(0, N * 2).getInt(i))
- assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
- assert(i === unsafe.getStruct(1, N * 2).getInt(i))
- assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
- }
-
- // test generated SafeProjection
- val safeProj = FromUnsafeProjection(nestedSchema)
- val result = safeProj(unsafe)
- // Can't compare GenericInternalRow with JoinedRow directly
- (0 until N).foreach { i =>
- val s = UTF8String.fromString(i.toString)
- assert(i === result.getInt(i + 2))
- assert(s === result.getUTF8String(i + 2 + N))
- assert(i === result.getStruct(0, N * 2).getInt(i))
- assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
- assert(i === result.getStruct(1, N * 2).getInt(i))
+ assert(r === result.getStruct(1, N * 2).getInt(i))
assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 74a47da2deef2..e86116680a57a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -93,7 +93,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
}
val nextBatch = ctx.freshName("nextBatch")
- val nextBatchFuncName = ctx.addNewFunction(nextBatch,
+ ctx.addNewFunction(nextBatch,
s"""
|private void $nextBatch() throws java.io.IOException {
| long getBatchStart = System.nanoTime();
@@ -121,7 +121,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
}
s"""
|if ($batch == null) {
- | $nextBatchFuncName();
+ | $nextBatch();
|}
|while ($batch != null) {
| int $numRows = $batch.numRows();
@@ -133,7 +133,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
| }
| $idx = $numRows;
| $batch = null;
- | $nextBatchFuncName();
+ | $nextBatch();
|}
|$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
|$scanTimeTotalNs = 0;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index ff71fd4dc7bb7..f98ae82574d20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -141,7 +141,7 @@ case class SortExec(
ctx.addMutableState("scala.collection.Iterator", sortedIterator, "")
val addToSorter = ctx.freshName("addToSorter")
- val addToSorterFuncName = ctx.addNewFunction(addToSorter,
+ ctx.addNewFunction(addToSorter,
s"""
| private void $addToSorter() throws java.io.IOException {
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -160,7 +160,7 @@ case class SortExec(
s"""
| if ($needToSort) {
| long $spillSizeBefore = $metrics.memoryBytesSpilled();
- | $addToSorterFuncName();
+ | $addToSorter();
| $sortedIterator = $sorterVariable.sort();
| $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000);
| $peakMemory.add($sorterVariable.getPeakMemoryUsage());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index c7e9d25bd2cc0..c1e1a631c677e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -357,9 +357,6 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
protected void processNext() throws java.io.IOException {
${code.trim}
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}
""".trim
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index bf7fa07765b9a..68c8e6ce62cbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -209,7 +209,7 @@ case class HashAggregateExec(
}
val doAgg = ctx.freshName("doAggregateWithoutKey")
- val doAggFuncName = ctx.addNewFunction(doAgg,
+ ctx.addNewFunction(doAgg,
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
@@ -226,7 +226,7 @@ case class HashAggregateExec(
| while (!$initAgg) {
| $initAgg = true;
| long $beforeAgg = System.nanoTime();
- | $doAggFuncName();
+ | $doAgg();
| $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);
|
| // output the result
@@ -592,7 +592,7 @@ case class HashAggregateExec(
} else ""
}
- val doAggFuncName = ctx.addNewFunction(doAgg,
+ ctx.addNewFunction(doAgg,
s"""
${generateGenerateCode}
private void $doAgg() throws java.io.IOException {
@@ -672,7 +672,7 @@ case class HashAggregateExec(
if (!$initAgg) {
$initAgg = true;
long $beforeAgg = System.nanoTime();
- $doAggFuncName();
+ $doAgg();
$aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index bb24489ade1b3..bd7a5c5d914c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -281,8 +281,10 @@ case class SampleExec(
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
ctx.copyResult = true
+ ctx.addMutableState(s"$samplerClass", sampler,
+ s"$initSampler();")
- val initSamplerFuncName = ctx.addNewFunction(initSampler,
+ ctx.addNewFunction(initSampler,
s"""
| private void $initSampler() {
| $sampler = new $samplerClass($upperBound - $lowerBound, false);
@@ -297,8 +299,6 @@ case class SampleExec(
| }
""".stripMargin.trim)
- ctx.addMutableState(s"$samplerClass", sampler, s"$initSamplerFuncName();")
-
val samplingCount = ctx.freshName("samplingCount")
s"""
| int $samplingCount = $sampler.sample();
@@ -394,7 +394,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
// The default size of a batch, which must be positive integer
val batchSize = 1000
- val initRangeFuncName = ctx.addNewFunction("initRange",
+ ctx.addNewFunction("initRange",
s"""
| private void initRange(int idx) {
| $BigInt index = $BigInt.valueOf(idx);
@@ -451,7 +451,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
- | $initRangeFuncName(partitionIndex);
+ | initRange(partitionIndex);
| }
|
| while (true) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index a66e8e2b46e3d..14024d6c10558 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -128,7 +128,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
} else {
val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold)
val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold)
- val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) =>
+ var groupedAccessorsLength = 0
+ groupedAccessorsItr.zipWithIndex.foreach { case (body, i) =>
+ groupedAccessorsLength += 1
val funcName = s"accessors$i"
val funcCode = s"""
|private void $funcName() {
@@ -137,7 +139,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
""".stripMargin
ctx.addNewFunction(funcName, funcCode)
}
- val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) =>
+ groupedExtractorsItr.zipWithIndex.foreach { case (body, i) =>
val funcName = s"extractors$i"
val funcCode = s"""
|private void $funcName() {
@@ -146,8 +148,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
""".stripMargin
ctx.addNewFunction(funcName, funcCode)
}
- (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"),
- extractorNames.map { extractorName => s"$extractorName();" }.mkString("\n"))
+ ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"),
+ (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n"))
}
val codeBody = s"""
@@ -222,9 +224,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
unsafeRow.setTotalSize(bufferHolder.totalSize());
return unsafeRow;
}
-
- ${ctx.initNestedClasses()}
- ${ctx.declareNestedClasses()}
}"""
val code = CodeFormatter.stripOverlappingComments(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 8445c26eeee58..26fb6103953fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -478,7 +478,7 @@ case class SortMergeJoinExec(
| }
| return false; // unreachable
|}
- """.stripMargin, inlineToOuterClass = true)
+ """.stripMargin)
(leftRow, matches)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 73a0f8735ed45..757fe2185d302 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -75,7 +75,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
protected boolean stopEarly() {
return $stopEarly;
}
- """, inlineToOuterClass = true)
+ """)
val countTerm = ctx.freshName("count")
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
s"""