Skip to content

Commit

Permalink
Modify RowEncoder and MapObjects to preserve array/map nullability.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Jun 23, 2016
1 parent f64f570 commit 093a9fa
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
Expand Up @@ -119,15 +119,32 @@ object RowEncoder {
"fromString",
inputObject :: Nil)

case t @ ArrayType(et, _) => et match {
case t @ ArrayType(et, containsNull) => et match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
// TODO: validate input type for primitive array.
NewInstance(
val nonNullOutput = NewInstance(
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
dataType = t,
propagateNull = false)

if (inputObject.nullable) {
If(IsNull(inputObject),
Literal.create(null, inputType),
nonNullOutput)
} else {
nonNullOutput
}

case _ => MapObjects(
element => serializerFor(ValidateExternalType(element, et), et),
{ element =>
val value = serializerFor(ValidateExternalType(element, et), et)
if (!containsNull) {
AssertNotNull(value, Seq.empty)
} else {
value
}
},
inputObject,
ObjectType(classOf[Object]))
}
Expand All @@ -147,10 +164,19 @@ object RowEncoder {
ObjectType(classOf[scala.collection.Seq[_]]))
val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))

NewInstance(
val nonNullOutput = NewInstance(
classOf[ArrayBasedMapData],
convertedKeys :: convertedValues :: Nil,
dataType = t)
dataType = t,
propagateNull = false)

if (inputObject.nullable) {
If(IsNull(inputObject),
Literal.create(null, inputType),
nonNullOutput)
} else {
nonNullOutput
}

case StructType(fields) =>
val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
Expand Down
Expand Up @@ -376,14 +376,15 @@ case class MapObjects private(
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {

override def nullable: Boolean = true
override def nullable: Boolean = inputData.nullable

override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def dataType: DataType = ArrayType(lambdaFunction.dataType)
override def dataType: DataType =
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVar.dataType)
Expand Down Expand Up @@ -450,6 +451,18 @@ case class MapObjects private(
case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
}

val setValue = if (lambdaFunction.nullable) {
s"""
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
$convertedArray[$loopIndex] = ${genFunction.value};
}
"""
} else {
s"$convertedArray[$loopIndex] = ${genFunction.value};"
}

val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
Expand All @@ -466,11 +479,7 @@ case class MapObjects private(
$loopNullCheck

${genFunction.code}
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
$convertedArray[$loopIndex] = ${genFunction.value};
}
$setValue

$loopIndex += 1;
}
Expand Down

0 comments on commit 093a9fa

Please sign in to comment.