Skip to content

Commit

Permalink
[SPARK-15471][SQL] ScalaReflection cleanup
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

1. Simplify the logic of serializing Option types.
2. Simplify the logic of serializing array types, and remove `silentSchemaFor`.
3. Remove some unnecessary code.

## How was this patch tested?

Covered by existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13250 from cloud-fan/encoder.

(cherry picked from commit 07c36a2)
Signed-off-by: Michael Armbrust <michael@databricks.com>
  • Loading branch information
cloud-fan authored and marmbrus committed May 23, 2016
1 parent 6eb8ec6 commit 655d882
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}


/**
Expand Down Expand Up @@ -72,6 +72,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ByteTpe => ByteType
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType
case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
case _ =>
val className = getClassNameFromType(tpe)
Expand Down Expand Up @@ -189,7 +190,6 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}

val className = getClassNameFromType(tpe)
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath

Expand Down Expand Up @@ -239,16 +239,14 @@ object ScalaReflection extends ScalaReflection {
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Date]),
"toJavaDate",
getPath :: Nil,
propagateNull = true)
getPath :: Nil)

case t if t <:< localTypeOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.sql.Timestamp]),
"toJavaTimestamp",
getPath :: Nil,
propagateNull = true)
getPath :: Nil)

case t if t <:< localTypeOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
Expand Down Expand Up @@ -437,17 +435,17 @@ object ScalaReflection extends ScalaReflection {
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {

def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
val externalDataType = dataTypeFor(elementType)
val Schema(catalystType, nullable) = silentSchemaFor(elementType)
if (isNativeType(externalDataType)) {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(catalystType, nullable))
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
dataTypeFor(elementType) match {
case dt: ObjectType =>
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
MapObjects(serializerFor(_, elementType, newPath), input, dt)

case dt =>
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(dt, schemaFor(elementType).nullable))
}
}

Expand All @@ -457,63 +455,10 @@ object ScalaReflection extends ScalaReflection {
tpe match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
optType match {
// For primitive types we must manually unbox the value of the object.
case t if t <:< definitions.IntTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
"intValue",
IntegerType)
case t if t <:< definitions.LongTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
"longValue",
LongType)
case t if t <:< definitions.DoubleTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
"doubleValue",
DoubleType)
case t if t <:< definitions.FloatTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
"floatValue",
FloatType)
case t if t <:< definitions.ShortTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
"shortValue",
ShortType)
case t if t <:< definitions.ByteTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
"byteValue",
ByteType)
case t if t <:< definitions.BooleanTpe =>
Invoke(
UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
"booleanValue",
BooleanType)

// For non-primitives, we can just extract the object from the Option and then recurse.
case other =>
val className = getClassNameFromType(optType)
val newPath = s"""- option value class: "$className"""" +: walkedTypePath

val optionObjectType: DataType = other match {
// Special handling is required for arrays, as getClassFromType(<Array>) will fail
// since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to
// the Java type "[I".
case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t)
case cls => ObjectType(getClassFromType(cls))
}
val unwrapped = UnwrapOption(optionObjectType, inputObject)

expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
serializerFor(unwrapped, optType, newPath))
}
val className = getClassNameFromType(optType)
val newPath = s"""- option value class: "$className"""" +: walkedTypePath
val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
serializerFor(unwrapped, optType, newPath)

// Since List[_] also belongs to localTypeOf[Product], we put this case before
// "case t if definedByConstructorParams(t)" to make sure it will match to the
Expand Down Expand Up @@ -704,18 +649,6 @@ object ScalaReflection extends ScalaReflection {
/**
* Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
*
* Resolves `T` with `localTypeOf` and delegates to the `Type`-based `schemaFor` overload.
*/
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])

/**
* Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
*
* Unlike `schemaFor`, this method does not throw for an unsupported type: it catches the
* `UnsupportedOperationException` raised by `schemaFor` and silently returns a nullable
* `NullType` schema instead.
*/
def silentSchemaFor(tpe: `Type`): Schema = try {
schemaFor(tpe)
} catch {
// Only this exception is treated as "unsupported type"; any other exception propagates.
case _: UnsupportedOperationException => Schema(NullType, nullable = true)
}

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
tpe match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ case class UnwrapOption(
${inputObject.code}

final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} =
${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
$javaType ${ev.value} = ${ev.isNull} ?
${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get();
"""
ev.copy(code = code)
}
Expand Down

0 comments on commit 655d882

Please sign in to comment.