From 546742a70e88751a82bf81ee1141c6761dfd2e0e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 5 Sep 2016 12:04:35 -0700 Subject: [PATCH] Test DefinedByConstructorParams --- .../spark/sql/catalyst/ScalaReflection.scala | 22 ++++++++++++++++--- .../catalyst/encoders/ExpressionEncoder.scala | 5 +++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8722191f9d3ea..f5793f076f8cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -24,6 +24,15 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils + +/** + * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s + * for classes whose fields are entirely defined by constructor params but should not be + * case classes. + */ +trait DefinedByConstructorParams + + /** * A default version of ScalaReflection that uses the runtime universe. 
*/ @@ -330,7 +339,7 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Product] => + case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) val cls = getClassFromType(tpe) @@ -485,7 +494,7 @@ object ScalaReflection extends ScalaReflection { extractorFor(unwrapped, optType, newPath)) } - case t if t <:< localTypeOf[Product] => + case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) @@ -683,7 +692,7 @@ trait ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< localTypeOf[Product] => + case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) Schema(StructType( params.map { case (fieldName, fieldType) => @@ -769,4 +778,11 @@ trait ScalaReflection { p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) } } + + /** + * Whether the fields of the given type are defined entirely by its constructor parameters. 
+ */ + private[sql] def definedByConstructorParams(tpe: Type): Boolean = { + tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams] + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6c058463b9cf2..28713dd385f52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -47,8 +47,9 @@ object ExpressionEncoder { def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(typeTag[T].tpe) - val flat = !classOf[Product].isAssignableFrom(cls) + val tpe = typeTag[T].tpe + val cls = mirror.runtimeClass(tpe) + val flat = !ScalaReflection.definedByConstructorParams(tpe) val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)