From 319c80447c4fc1baa3167c889d1d8c072ee5b31c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 6 Nov 2017 09:04:52 +0000 Subject: [PATCH 1/3] ScalaReflection should produce correct field names for special characters. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 +++++---- .../catalyst/expressions/objects/objects.scala | 11 +++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 18 +++++++++++++++++- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 17e595f9c5d8d..f62553ddd3971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -146,7 +146,7 @@ object ScalaReflection extends ScalaReflection { def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + .getOrElse(UnresolvedAttribute.quoted(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -675,7 +675,7 @@ object ScalaReflection extends ScalaReflection { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - constructParams(t).map(_.name.toString) + constructParams(t).map(_.name.decodedName.toString) } /** @@ -855,11 +855,12 @@ trait ScalaReflection { // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) if (actualTypeArgs.nonEmpty) { params.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + p.name.decodedName.toString -> + p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) } } else { params.map { p => - p.name.toString -> p.typeSignature + p.name.decodedName.toString -> p.typeSignature } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6ae3490a3f863..f2eee991c9865 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -214,11 +215,13 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val encodedFunctionName = TermName(functionName).encodedName.toString + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find(_.getName == functionName) + val m = cls.getMethods.find(_.getName == encodedFunctionName) if (m.isEmpty) { - sys.error(s"Couldn't find $functionName on $cls") + sys.error(s"Couldn't find $encodedFunctionName on $cls") } else { m } @@ -247,7 +250,7 @@ case class 
Invoke( } val evaluate = if (returnPrimitive) { - getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") + getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)") } else { val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still @@ -265,7 +268,7 @@ case class Invoke( } s""" Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")} $assignResult """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index a5b9855e959d4..f9a5573288e1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -79,6 +80,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } +case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String) + object TestingUDT { @SQLUserDefinedType(udt = classOf[NestedStructUDT]) class NestedStruct(val a: Integer, val b: Long, val c: Double) @@ -335,4 +338,17 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } + test("SPARK-22442: Generate correct field names for special characters") { + val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val deserializer = deserializerFor[SpecialCharAsFieldData] + assert(serializer.dataType(0).name == "field.1") + assert(serializer.dataType(1).name == "field 2") + + val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { + case UpCast(u: UnresolvedAttribute, _, _) => u.name + }} + assert(argumentsFields(0) == "`field.1`") + assert(argumentsFields(1) == "field 2") + } } From a671a83683fb3b5a7df0a5e213b7d01d7f7736e4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 6 Nov 2017 12:18:46 +0000 Subject: [PATCH 2/3] Add test. 
--- .../scala/org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 1537ce3313c09..c67165c7abca6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1398,6 +1398,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val actual = kvDataset.toString assert(expected === actual) } + + test("SPARK-22442: Generate correct field names for special characters") { + withTempPath { dir => + val path = dir.getCanonicalPath + val data = """{"field.1": 1, "field 2": 2}""" + Seq(data).toDF().repartition(1).write.text(path) + val ds = spark.read.json(path).as[SpecialCharClass] + checkDataset(ds, SpecialCharClass("1", "2")) + } + } } case class SingleData(id: Int) @@ -1487,3 +1497,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA) case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) + +case class SpecialCharClass(`field.1`: String, `field 2`: String) From 10db6b4ba2ea099554743a2ebcfcb19c46ed264e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Nov 2017 03:23:36 +0000 Subject: [PATCH 3/3] Address comment. --- .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f9a5573288e1e..f77af5db3279b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -346,9 +346,9 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(serializer.dataType(1).name == "field 2") val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { - case UpCast(u: UnresolvedAttribute, _, _) => u.name + case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts }} - assert(argumentsFields(0) == "`field.1`") - assert(argumentsFields(1) == "field 2") + assert(argumentsFields(0) == Seq("field.1")) + assert(argumentsFields(1) == Seq("field 2")) } }
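
Background on the name mangling this series deals with: scalac stores identifiers in "encoded" (NAME-mangled) form, so a backtick-quoted field such as `field.1` becomes field$u002E1 in the symbol table and in the compiled bytecode. Scala reflection exposes both forms, which is why ScalaReflection now calls decodedName when producing field names for the schema, while Invoke must go the opposite way and re-encode the name before looking up the accessor via Java reflection. Below is a minimal sketch of both directions, assuming Scala 2.11 runtime reflection; the SpecialChar class and NameEncodingDemo object are illustrative only, not part of the patch.

    import scala.reflect.runtime.{universe => ru}

    case class SpecialChar(`field.1`: String)

    object NameEncodingDemo extends App {
      // The primary-constructor parameter, mirroring what
      // ScalaReflection.constructParams walks over.
      val ctor  = ru.typeOf[SpecialChar].decl(ru.termNames.CONSTRUCTOR).asMethod
      val param = ctor.paramLists.head.head

      println(param.name.toString)             // field$u002E1 -- encoded, what the old code emitted
      println(param.name.decodedName.toString) // field.1      -- decoded, what the schema should show

      // Invoke's direction: the JVM only knows the encoded accessor name,
      // so a decoded name must be re-encoded before the Java reflection lookup.
      val encoded = ru.TermName("field.1").encodedName.toString // field$u002E1
      println(classOf[SpecialChar].getMethods.exists(_.getName == encoded))   // true
      println(classOf[SpecialChar].getMethods.exists(_.getName == "field.1")) // false
    }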
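
On the switch to UnresolvedAttribute.quoted and the PATCH 3/3 assertion change: the plain UnresolvedAttribute(String) companion splits its argument on dots to support nested-field access, so a flat column literally named "field.1" would be misread as a reference to field "1" inside a column "field". UnresolvedAttribute#name also re-quotes any part containing a dot, which is why the PATCH 1/3 test had to expect the backticked string "`field.1`" and PATCH 3/3 moves to the unambiguous nameParts instead. A small sketch against catalyst's API as of this change, with the expected values shown as comments:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

    // apply() parses dots as nesting: column "field", then inner field "1".
    UnresolvedAttribute("field.1").nameParts        // Seq("field", "1")
    // quoted() keeps the whole string as a single flat column name.
    UnresolvedAttribute.quoted("field.1").nameParts // Seq("field.1")
    // name re-quotes parts containing a dot, hence the earlier assertion;
    // a name with only spaces comes back unquoted.
    UnresolvedAttribute.quoted("field.1").name      // "`field.1`"
    UnresolvedAttribute.quoted("field 2").name      // "field 2"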