Skip to content

Commit

Permalink
[SPARK-32999][SQL] Use Utils.getSimpleName to avoid hitting Malformed…
Browse files Browse the repository at this point in the history
… class name in TreeNode

### What changes were proposed in this pull request?

Use `Utils.getSimpleName` to avoid hitting `Malformed class name` error in `TreeNode`.

### Why are the changes needed?

On older JDK versions (e.g. JDK8u), nested Scala classes may trigger `java.lang.Class.getSimpleName` to throw an `java.lang.InternalError: Malformed class name` error.

Similar to #29050, we should use  Spark's `Utils.getSimpleName` utility function in place of `Class.getSimpleName` to avoid hitting the issue.

### Does this PR introduce _any_ user-facing change?

Fixes a bug that throws an error when invoking `TreeNode.nodeName`, otherwise no changes.

### How was this patch tested?

Added new unit test case in `TreeNodeSuite`. Note that the test case assumes the test code can trigger the expected error, otherwise it'll skip the test safely, for compatibility with newer JDKs.

Manually tested on JDK8u and JDK11u and observed expected behavior:
- JDK8u: the test case triggers the "Malformed class name" issue and the fix works;
- JDK11u: the test case does not trigger the "Malformed class name" issue, and the test case is safely skipped.

Closes #29875 from rednaxelafx/spark-32999-getsimplename.

Authored-by: Kris Mok <kris.mok@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
rednaxelafx authored and dongjoon-hyun committed Sep 26, 2020
1 parent 934a91f commit 9a155d4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
Expand Up @@ -41,6 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
Expand Down Expand Up @@ -521,11 +522,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
mapChildren(_.clone(), forceCopy = true)
}

private def simpleClassName: String = Utils.getSimpleName(this.getClass)

/**
* Returns the name of this type of TreeNode. Defaults to the class name.
* Note that we remove the "Exec" suffix for physical operators here.
*/
def nodeName: String = getClass.getSimpleName.replaceAll("Exec$", "")
def nodeName: String = simpleClassName.replaceAll("Exec$", "")

/**
* The arguments that should be included in the arg string. Defaults to the `productIterator`.
Expand Down Expand Up @@ -747,7 +750,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
protected def jsonFields: List[JField] = {
val fieldNames = getConstructorParameterNames(getClass)
val fieldValues = productIterator.toSeq ++ otherCopyArgs
assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))

fieldNames.zip(fieldValues).map {
Expand Down Expand Up @@ -801,7 +804,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
try {
val fieldNames = getConstructorParameterNames(p.getClass)
val fieldValues = p.productIterator.toSeq
assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))
("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
case (name, value) => name -> parseToJson(value)
Expand Down
Expand Up @@ -736,4 +736,30 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
assertDifferentInstance(leaf, leafCloned)
assert(leaf.child.eq(leafCloned.asInstanceOf[FakeLeafPlan].child))
}

object MalformedClassObject extends Serializable {
case class MalformedNameExpression(child: Expression) extends TaggingExpression
}

test("SPARK-32999: TreeNode.nodeName should not throw malformed class name error") {
val testTriggersExpectedError = try {
classOf[MalformedClassObject.MalformedNameExpression].getSimpleName
false
} catch {
case ex: java.lang.InternalError if ex.getMessage.contains("Malformed class name") =>
true
case ex: Throwable => throw ex
}
// This test case only applies on older JDK versions (e.g. JDK8u), and doesn't trigger the
// issue on newer JDK versions (e.g. JDK11u).
assume(testTriggersExpectedError, "the test case didn't trigger malformed class name error")

val expr = MalformedClassObject.MalformedNameExpression(Literal(1))
try {
expr.nodeName
} catch {
case ex: java.lang.InternalError if ex.getMessage.contains("Malformed class name") =>
fail("TreeNode.nodeName should not throw malformed class name error")
}
}
}

0 comments on commit 9a155d4

Please sign in to comment.