From 1cb30baa791a40b6827eb6aaaa9e3cc8c9ca8fdd Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Fri, 6 May 2016 14:21:19 -0700
Subject: [PATCH] Support using SQLUserDefinedType for case classes

---
 .../spark/sql/catalyst/ScalaReflection.scala | 70 +++++++++----------
 .../encoders/ExpressionEncoderSuite.scala    | 34 ++++++++-
 2 files changed, 67 insertions(+), 37 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d158a64a85bc0..16309e15867c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -342,6 +342,23 @@ object ScalaReflection extends ScalaReflection {
           "toScalaMap",
           keyData :: valueData :: Nil)
 
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
 
@@ -382,23 +399,6 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
-
-      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
-        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-        val obj = NewInstance(
-          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-          Nil,
-          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
-
-      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
-          .asInstanceOf[UserDefinedType[_]]
-        val obj = NewInstance(
-          udt.getClass,
-          Nil,
-          dataType = ObjectType(udt.getClass))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
 
@@ -516,17 +516,6 @@ object ScalaReflection extends ScalaReflection {
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
 
-      case t if definedByConstructorParams(t) =>
-        val params = getConstructorParameters(t)
-        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
-          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
-          val clsName = getClassNameFromType(fieldType)
-          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
-        })
-        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
@@ -625,6 +614,17 @@ object ScalaReflection extends ScalaReflection {
           dataType = ObjectType(udt.getClass))
         Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+          val clsName = getClassNameFromType(fieldType)
+          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+        })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
       case other =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -714,13 +714,6 @@ object ScalaReflection extends ScalaReflection {
       val Schema(valueDataType, valueNullable) = schemaFor(valueType)
       Schema(MapType(schemaFor(keyType).dataType, valueDataType,
        valueContainsNull = valueNullable), nullable = true)
-    case t if definedByConstructorParams(t) =>
-      val params = getConstructorParameters(t)
-      Schema(StructType(
-        params.map { case (fieldName, fieldType) =>
-          val Schema(dataType, nullable) = schemaFor(fieldType)
-          StructField(fieldName, dataType, nullable)
-        }), nullable = true)
     case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
     case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
     case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
@@ -742,6 +735,13 @@ object ScalaReflection extends ScalaReflection {
     case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
     case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
    case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+    case t if definedByConstructorParams(t) =>
+      val params = getConstructorParameters(t)
+      Schema(StructType(
+        params.map { case (fieldName, fieldType) =>
+          val Schema(dataType, nullable) = schemaFor(fieldType)
+          StructField(fieldName, dataType, nullable)
+        }), nullable = true)
     case other =>
       throw new UnsupportedOperationException(s"Schema for type $other is not supported")
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index c3b20e2cc00a2..af3fa6e803cd4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -85,6 +86,25 @@ class JavaSerializable(val value: Int) extends Serializable {
   }
 }
 
+/** For testing a UDT for a case class */
+@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
+case class UDTCaseClass(uri: java.net.URI)
+
+class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
+
+  override def sqlType: DataType = StringType
+
+  override def serialize(obj: UDTCaseClass): UTF8String = {
+    UTF8String.fromString(obj.uri.toString)
+  }
+
+  override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]
+
+  override def deserialize(datum: Any): UDTCaseClass = datum match {
+    case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
+  }
+}
+
 class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
@@ -108,7 +128,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
 
   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
-  // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
+  encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
 
   encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
 
@@ -144,6 +164,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
   encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
 
+  encodeDecodeTest(List(1, 2), "list of int")
+  encodeDecodeTest(List("a", null), "list with String and null")
+
+  encodeDecodeTest(
+    UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")
+
   // Kryo encoders
   encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(
@@ -336,6 +362,10 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
         Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
       case (b1: Array[_], b2: Array[_]) =>
         Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+      case (b1: java.math.BigDecimal, b2: java.math.BigDecimal) =>
+        // "java.math.BigDecimal.equals" requires both the value and the scale to be equal
+        // (thus 2.0 is not equal to 2.00 when using "equals"). Hence using "compareTo" instead.
+        b1.compareTo(b2) == 0
       case _ => input == convertedBack
     }
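
Note on what this change does (editorial illustration, not part of the
commit): a case class annotated with @SQLUserDefinedType matches both the
UDT cases and the definedByConstructorParams case in ScalaReflection, and a
Scala pattern match takes the first case that applies. Previously the
constructor-parameter case was tried first, so the derived encoder tried to
encode the annotated case class field by field and failed with "No Encoder
found" on unsupported field types such as java.net.URI. The patch reorders
the deserializer, serializer, and schemaFor matches so that
definedByConstructorParams is tried only after the UDT cases. Below is a
minimal sketch of the resulting behavior, reusing the
UDTCaseClass/UDTForCaseClass fixtures added to the test suite above; the
ExpressionEncoder calls are an assumption about the encoder API of this
era, shown for illustration only:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Deriving an encoder no longer fails on the java.net.URI field: the
    // whole case class is serialized through UDTForCaseClass, whose
    // sqlType is StringType.
    val enc = ExpressionEncoder[UDTCaseClass]()
    val row = enc.toRow(UDTCaseClass(new java.net.URI("http://spark.apache.org/")))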