@@ -17,10 +17,14 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._

object DeserializerBuildHelper {
@@ -193,4 +197,246 @@ object DeserializerBuildHelper {
UpCast(expr, DecimalType, walkedTypePath.getPaths)
case _ => UpCast(expr, expected, walkedTypePath.getPaths)
}

/**
* Returns an expression for deserializing the Spark SQL representation of an object into its
* external form. The mapping between the internal and external representations is
* described by encoder `enc`. The Spark SQL representation is located at ordinal 0 of
* a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
* `UnresolvedExtractValue`.
*
* The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
* deserializer expression when using it.
*
* @param enc encoder that describes the mapping between the Spark SQL representation and the
* external representation.
*/
def createDeserializer[T](enc: AgnosticEncoder[T]): Expression = {
val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
// Assumes we are deserializing the first column of a row.
val input = GetColumnByOrdinal(0, enc.dataType)
enc match {
case AgnosticEncoders.RowEncoder(fields) =>
val children = fields.zipWithIndex.map { case (f, i) =>
createDeserializer(f.enc, GetStructField(input, i), walkedTypePath)
}
CreateExternalRow(children, enc.schema)
case _ =>
val deserializer = createDeserializer(
enc,
upCastToExpectedType(input, enc.dataType, walkedTypePath),
walkedTypePath)
expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
}
}
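
// Usage sketch (illustrative only; `SomeClass` is a hypothetical type): the expression built
// here is what `ExpressionEncoder` later resolves and binds against the input schema, e.g.
//
//   val enc: AgnosticEncoder[SomeClass] = ...   // some existing agnostic encoder
//   val deserializer: Expression = createDeserializer(enc)
//
// Until that binding happens the result may still contain unresolved nodes such as
// GetColumnByOrdinal and UnresolvedExtractValue.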

/**
* Returns an expression for deserializing the value of an input expression into its external
* representation. The mapping between the internal and external representations is
* described by encoder `enc`.
*
* @param enc encoder that describes the mapping between the Spark SQL representation and the
* external representation.
* @param path The expression that can be used to extract the serialized value.
* @param walkedTypePath The paths from top to bottom to access the current field when deserializing.
*/
private def createDeserializer(
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
case _ if isNativeEncoder(enc) =>
path
case _: BoxedLeafEncoder[_, _] =>
createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
case JavaEnumEncoder(tag) =>
val toString = createDeserializerForString(path, returnNullable = false)
createDeserializerForTypesSupportValueOf(toString, tag.runtimeClass)
case ScalaEnumEncoder(parent, tag) =>
StaticInvoke(
parent,
ObjectType(tag.runtimeClass),
"withName",
createDeserializerForString(path, returnNullable = false) :: Nil,
returnNullable = false)
case StringEncoder =>
createDeserializerForString(path, returnNullable = false)
case _: ScalaDecimalEncoder =>
createDeserializerForScalaBigDecimal(path, returnNullable = false)
case _: JavaDecimalEncoder =>
createDeserializerForJavaBigDecimal(path, returnNullable = false)
case ScalaBigIntEncoder =>
createDeserializerForScalaBigInt(path)
case JavaBigIntEncoder =>
createDeserializerForJavaBigInteger(path, returnNullable = false)
case DayTimeIntervalEncoder =>
createDeserializerForDuration(path)
case YearMonthIntervalEncoder =>
createDeserializerForPeriod(path)
case _: DateEncoder =>
createDeserializerForSqlDate(path)
case _: LocalDateEncoder =>
createDeserializerForLocalDate(path)
case _: TimestampEncoder =>
createDeserializerForSqlTimestamp(path)
case _: InstantEncoder =>
createDeserializerForInstant(path)
case LocalDateTimeEncoder =>
createDeserializerForLocalDateTime(path)
case UDTEncoder(udt, udtClass) =>
val obj = NewInstance(udtClass, Nil, ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
case OptionEncoder(valueEnc) =>
val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName)
val deserializer = createDeserializer(valueEnc, path, newTypePath)
WrapOption(deserializer, externalDataTypeFor(valueEnc))

case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) =>
Invoke(
deserializeArray(
path,
elementEnc,
containsNull,
None,
walkedTypePath),
toArrayMethodName(elementEnc),
ObjectType(enc.clsTag.runtimeClass),
returnNullable = false)

case IterableEncoder(clsTag, elementEnc, containsNull, _) =>
deserializeArray(
path,
elementEnc,
containsNull,
Option(clsTag.runtimeClass),
walkedTypePath)

case MapEncoder(tag, keyEncoder, valueEncoder, _)
if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) =>
// TODO (hvanhovell) this can be improved.
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
valueEncoder.clsTag.runtimeClass.getName)

val keyData =
Invoke(
UnresolvedMapObjects(
p => createDeserializer(keyEncoder, p, newTypePath),
MapKeys(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => createDeserializer(valueEncoder, p, newTypePath),
MapValues(path)),
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
ArrayBasedMapData.getClass,
ObjectType(classOf[java.util.Map[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil,
returnNullable = false)

case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
valueEncoder.clsTag.runtimeClass.getName)
UnresolvedCatalystToExternalMap(
path,
createDeserializer(keyEncoder, _, newTypePath),
createDeserializer(valueEncoder, _, newTypePath),
tag.runtimeClass)

case ProductEncoder(tag, fields) =>
val cls = tag.runtimeClass
val dt = ObjectType(cls)
val isTuple = cls.getName.startsWith("scala.Tuple")
val arguments = fields.zipWithIndex.map {
case (field, i) =>
val newTypePath = walkedTypePath.recordField(
field.enc.clsTag.runtimeClass.getName,
field.name)
// For tuples, we grab the inner fields by ordinal instead of name.
val getter = if (isTuple) {
addToPathOrdinal(path, i, field.enc.dataType, newTypePath)
} else {
addToPath(path, field.name, field.enc.dataType, newTypePath)
}
expressionWithNullSafety(
createDeserializer(field.enc, getter, newTypePath),
field.enc.nullable,
newTypePath)
}
exprs.If(
IsNull(path),
exprs.Literal.create(null, dt),
NewInstance(cls, arguments, dt, propagateNull = false))

case AgnosticEncoders.RowEncoder(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val newTypePath = walkedTypePath.recordField(
f.enc.clsTag.runtimeClass.getName,
f.name)
exprs.If(
Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
exprs.Literal.create(null, externalDataTypeFor(f.enc)),
createDeserializer(f.enc, GetStructField(path, i), newTypePath))
}
exprs.If(IsNull(path),
exprs.Literal.create(null, externalDataTypeFor(enc)),
CreateExternalRow(convertedFields, enc.schema))

case JavaBeanEncoder(tag, fields) =>
val setters = fields.map { f =>
val newTypePath = walkedTypePath.recordField(
f.enc.clsTag.runtimeClass.getName,
f.name)
val setter = expressionWithNullSafety(
createDeserializer(
f.enc,
addToPath(path, f.name, f.enc.dataType, newTypePath),
newTypePath),
nullable = f.nullable,
newTypePath)
f.writeMethod.get -> setter
}

val cls = tag.runtimeClass
val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
}
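
// Illustrative example (not part of the implementation): for OptionEncoder(StringEncoder) the
// match above first builds the value deserializer and then wraps it, yielding roughly
//
//   WrapOption(
//     createDeserializerForString(path, returnNullable = false),
//     externalDataTypeFor(StringEncoder))
//
// so a null column deserializes to None and a non-null UTF8String to Some(String).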

private def deserializeArray(
path: Expression,
elementEnc: AgnosticEncoder[_],
containsNull: Boolean,
cls: Option[Class[_]],
walkedTypePath: WalkedTypePath): Expression = {
val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expects.
deserializerForWithNullSafetyAndUpcast(
element,
elementEnc.dataType,
nullable = containsNull,
newTypePath,
createDeserializer(elementEnc, _, newTypePath))
}
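// UnresolvedMapObjects is resolved to MapObjects during analysis; it applies `mapFunction`
// to every element of the input array and, when `cls` is given, collects the results into
// that collection class instead of a plain catalyst array.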
UnresolvedMapObjects(mapFunction, path, cls)
}

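// Selects the extraction method invoked on the deserialized ArrayData: the specialized
// primitive variants (toIntArray, toDoubleArray, ...) avoid boxing, while `array` returns an
// Array[Any] for all other element types.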
private def toArrayMethodName(enc: AgnosticEncoder[_]): String = enc match {
case PrimitiveBooleanEncoder => "toBooleanArray"
case PrimitiveByteEncoder => "toByteArray"
case PrimitiveShortEncoder => "toShortArray"
case PrimitiveIntEncoder => "toIntArray"
case PrimitiveLongEncoder => "toLongArray"
case PrimitiveFloatEncoder => "toFloatArray"
case PrimitiveDoubleEncoder => "toDoubleArray"
case _ => "array"
}
}