diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 8003012f30ca5..1ab7bbdcff697 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
 private class MutableInt(var i: Int)
@@ -521,11 +522,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     mapChildren(_.clone(), forceCopy = true)
   }
 
+  private def simpleClassName: String = Utils.getSimpleName(this.getClass)
+
   /**
    * Returns the name of this type of TreeNode. Defaults to the class name.
    * Note that we remove the "Exec" suffix for physical operators here.
    */
-  def nodeName: String = getClass.getSimpleName.replaceAll("Exec$", "")
+  def nodeName: String = simpleClassName.replaceAll("Exec$", "")
 
   /**
    * The arguments that should be included in the arg string. Defaults to the `productIterator`.
@@ -747,7 +750,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   protected def jsonFields: List[JField] = {
     val fieldNames = getConstructorParameterNames(getClass)
     val fieldValues = productIterator.toSeq ++ otherCopyArgs
-    assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+    assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
       fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))
 
     fieldNames.zip(fieldValues).map {
@@ -801,7 +804,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       try {
         val fieldNames = getConstructorParameterNames(p.getClass)
         val fieldValues = p.productIterator.toSeq
-        assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+        assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
           fieldNames.mkString(", ") + s", values: " + fieldValues.mkString(", "))
         ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
           case (name, value) => name -> parseToJson(value)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index ff51bc0071c80..4ad8475a0113c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -736,4 +736,30 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     assertDifferentInstance(leaf, leafCloned)
     assert(leaf.child.eq(leafCloned.asInstanceOf[FakeLeafPlan].child))
   }
+
+  object MalformedClassObject extends Serializable {
+    case class MalformedNameExpression(child: Expression) extends TaggingExpression
+  }
+
+  test("SPARK-32999: TreeNode.nodeName should not throw malformed class name error") {
+    val testTriggersExpectedError = try {
+      classOf[MalformedClassObject.MalformedNameExpression].getSimpleName
+      false
+    } catch {
+      case ex: java.lang.InternalError if ex.getMessage.contains("Malformed class name") =>
+        true
+      case ex: Throwable => throw ex
+    }
+    // This test case only applies on older JDK versions (e.g. JDK8u), and doesn't trigger the
+    // issue on newer JDK versions (e.g. JDK11u).
+    assume(testTriggersExpectedError, "the test case didn't trigger malformed class name error")
+
+    val expr = MalformedClassObject.MalformedNameExpression(Literal(1))
+    try {
+      expr.nodeName
+    } catch {
+      case ex: java.lang.InternalError if ex.getMessage.contains("Malformed class name") =>
+        fail("TreeNode.nodeName should not throw malformed class name error")
+    }
+  }
 }
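
Note on the fix: on JDK8, java.lang.Class.getSimpleName can throw java.lang.InternalError("Malformed class name") for classes that Scala nests inside objects, because the generated binary name (e.g. Outer$Holder$Nested) confuses the JDK's enclosing-class parsing; newer JDKs no longer throw here. The sketch below is only a minimal, self-contained illustration of the fallback idea behind routing nodeName through Utils.getSimpleName, not Spark's actual implementation; the names SimpleNameFallback and stripPackageAndOuterNames are hypothetical.

  object SimpleNameFallback {
    // Prefer the JDK's getSimpleName, but recover from the JDK8
    // "Malformed class name" InternalError by deriving a readable
    // name from the binary class name instead.
    def getSimpleName(cls: Class[_]): String = {
      try {
        cls.getSimpleName
      } catch {
        case _: InternalError => stripPackageAndOuterNames(cls.getName)
      }
    }

    // For a binary name such as "org.example.Outer$Holder$Nested":
    // drop the package prefix, drop a trailing '$' (module classes),
    // then keep only the segment after the last '$' separator.
    private def stripPackageAndOuterNames(binaryName: String): String = {
      val withoutPackage = binaryName.substring(binaryName.lastIndexOf('.') + 1)
      val trimmed = withoutPackage.stripSuffix("$")
      trimmed.substring(trimmed.lastIndexOf('$') + 1)
    }
  }

  object Demo extends App {
    object Holder { case class Nested(x: Int) }
    // On JDK8 Holder.Nested's getSimpleName may throw InternalError;
    // the fallback still yields "Nested" on any JDK version.
    println(SimpleNameFallback.getSimpleName(Holder.Nested(1).getClass))
  }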