From f479fcbe4d5e550896e45f53df538ddfd8ae11f1 Mon Sep 17 00:00:00 2001
From: Jakob Odersky
Date: Mon, 14 Mar 2016 17:35:33 -0700
Subject: [PATCH 1/2] Fix expression generation for optional inner classes.

Also add regression test for Dataset's handling of classes defined in
package objects.
---
 .../spark/sql/catalyst/ScalaReflection.scala  | 20 ++++++++++++++++++--
 .../encoders/ExpressionEncoderSuite.scala     |  2 ++
 .../spark/sql/DatasetPrimitiveSuite.scala     | 10 ++++++++++
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bf07f4557a5b4..51314f453fd83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -476,7 +476,7 @@ object ScalaReflection extends ScalaReflection {
           // For non-primitives, we can just extract the object from the Option and then recurse.
           case other =>
             val className = getClassNameFromType(optType)
-            val classObj = Utils.classForName(className)
+            val classObj = getClassFromType(optType)
             val optionObjectType = ObjectType(classObj)
             val newPath = s"""- option value class: "$className"""" +: walkedTypePath
 
@@ -626,6 +626,9 @@ object ScalaReflection extends ScalaReflection {
     constructParams(t).map(_.name.toString)
   }
 
+  /*
+   * Retrieves the runtime class corresponding to the provided type.
+   */
   def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
 }
 
@@ -676,9 +679,12 @@ trait ScalaReflection {
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
     val className = getClassNameFromType(tpe)
+
     tpe match {
+
       case t if Utils.classIsLoadable(className) &&
         Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+
         // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
         // whereas className is from Scala reflection. This can make it hard to find classes
         // in some cases, such as when a class is enclosed in an object (in which case
@@ -748,7 +754,16 @@ trait ScalaReflection {
     case _: UnsupportedOperationException => Schema(NullType, nullable = true)
   }
 
-  /** Returns the full class name for a type. */
+  /**
+   * Returns the full class name for a type. The returned name is the canonical
+   * Scala name, where each component is separated by a period. It is NOT the
+   * Java-equivalent runtime name (no dollar signs).
+   *
+   * In simple cases, both the Scala and Java names are the same, however when Scala
+   * generates constructs that do not map to a Java equivalent, such as singleton objects
+   * or nested classes in package objects, it uses the dollar sign ($) to create
+   * synthetic classes, emulating behaviour in Java bytecode.
+   */
   def getClassNameFromType(tpe: `Type`): String = {
     tpe.erasure.typeSymbol.asClass.fullName
   }
@@ -792,4 +807,5 @@ trait ScalaReflection {
     }
     params.flatten
   }
+
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index cca320fae9505..3024858b064d0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -152,6 +152,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   productTest(InnerClass(1))
   encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
 
+  encodeDecodeTest(Array(Option(InnerClass(1))), "array of optional inner class")
+
   productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
 
   productTest(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 6e9840e4a7301..ff022b2dc45ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -23,6 +23,10 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 case class IntClass(value: Int)
 
+package object packageobject {
+  case class PackageClass(value: Int)
+}
+
 class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -127,4 +131,10 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(Array("test")).toDS(), Array("test"))
     checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
   }
+
+  test("package objects") {
+    import packageobject._
+    checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
+  }
+
 }

From b925a97dd59b453a80567971444b2c2db5ce4a74 Mon Sep 17 00:00:00 2001
From: Jakob Odersky
Date: Wed, 16 Mar 2016 14:26:02 -0700
Subject: [PATCH 2/2] Handle arrays in options

---
 .../apache/spark/sql/catalyst/ScalaReflection.scala | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 51314f453fd83..5e1672c779dba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -476,11 +476,17 @@ object ScalaReflection extends ScalaReflection {
           // For non-primitives, we can just extract the object from the Option and then recurse.
           case other =>
             val className = getClassNameFromType(optType)
-            val classObj = getClassFromType(optType)
-            val optionObjectType = ObjectType(classObj)
             val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+            val optionObjectType: DataType = other match {
+              // Special handling is required for arrays, as getClassFromType() will fail
+              // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to
+              // the Java type "[I".
+              case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t)
+              case cls => ObjectType(getClassFromType(cls))
+            }
 
             val unwrapped = UnwrapOption(optionObjectType, inputObject)
+
             expressions.If(
               IsNull(unwrapped),
               expressions.Literal.create(null, silentSchemaFor(optType).dataType),
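
Editor's note: a minimal, standalone sketch of the two naming quirks the patches above work
around. It is not part of either patch, and the demo package, Inner case class and NameDemo
object are invented purely for illustration: Scala reflection reports the canonical,
period-separated name of a class, the JVM names classes nested in package objects with a
dollar sign, and Scala arrays erase to native JVM array types such as "[I".

// Illustration only (assumed names, not Spark code); compile with scalac and run.
import scala.reflect.runtime.{universe => ru}

package object demo {
  case class Inner(value: Int)
}

object NameDemo extends App {
  val mirror = ru.runtimeMirror(getClass.getClassLoader)
  val sym = ru.typeOf[demo.Inner].erasure.typeSymbol.asClass

  // Canonical Scala name: period-separated, no dollar signs.
  println(sym.fullName)

  // Runtime name of the same class on the JVM, e.g. "demo.package$Inner".
  // Class.forName(sym.fullName) would not find it, which is why the first patch
  // resolves the class via mirror.runtimeClass (getClassFromType) instead.
  println(mirror.runtimeClass(sym).getName)

  // Scala arrays erase to native JVM array types, e.g. Array[Int] becomes "[I",
  // hence the dedicated Array case added in the second patch.
  println(classOf[Array[Int]].getName)
}

Run this way, the dotted Scala name is printed first and the dollar-sign runtime name second,
which is exactly the mismatch described in the new getClassNameFromType comment.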