Skip to content

Commit

Permalink
[SPARK-21255][SQL] simplify encoder for java enum
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This is a follow-up for #18488, to simplify the code.

The major change is, we should map java enum to string type, instead of a struct type with a single string field.

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19066 from cloud-fan/fix.
  • Loading branch information
cloud-fan authored and gatorsmile committed Aug 29, 2017
1 parent 8fcbda9 commit 6327ea5
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
* Type-inference utilities for POJOs and Java collections.
Expand Down Expand Up @@ -120,8 +119,7 @@ object JavaTypeInference {
(MapType(keyDataType, valueDataType, nullable), true)

case other if other.isEnum =>
(StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
StringType, nullable = false))), true)
(StringType, true)

case other =>
if (seenTypeSet.contains(other)) {
Expand Down Expand Up @@ -310,9 +308,12 @@ object JavaTypeInference {
returnNullable = false)

case other if other.isEnum =>
StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName",
expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other))
:: getPath :: Nil)
StaticInvoke(
other,
ObjectType(other),
"valueOf",
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
Expand Down Expand Up @@ -356,30 +357,6 @@ object JavaTypeInference {
}
}

/** Returns a mapping from enum value to int for given enum type */
def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
assert(enum.isEnum)
inputObject: T =>
UTF8String.fromString(inputObject.name())
}

/** Returns value index for given enum type and value */
def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = {
enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
}

/** Returns a mapping from int to enum value for given enum type */
def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
assert(enum.isEnum)
value: InternalRow =>
Enum.valueOf(enum, value.getUTF8String(0).toString)
}

/** Returns enum value for given enum type and value index */
def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = {
enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
}

private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {

def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
Expand Down Expand Up @@ -465,9 +442,12 @@ object JavaTypeInference {
)

case other if other.isEnum =>
CreateNamedStruct(expressions.Literal("enum") ::
StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName",
expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil)
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -81,19 +81,9 @@ object ExpressionEncoder {
ClassTag[T](cls))
}

def javaEnumSchema[T](beanClass: Class[T]): DataType = {
StructType(Seq(StructField("enum",
StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))),
nullable = false)))
}

// TODO: improve error message for java bean encoder.
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
val schema = if (beanClass.isEnum) {
javaEnumSchema(beanClass)
} else {
JavaTypeInference.inferDataType(beanClass)._1
}
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])

val serializer = JavaTypeInference.serializerFor(beanClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ case class StaticInvoke(
val evaluate = if (returnNullable) {
if (ctx.defaultValue(dataType) == "null") {
s"""
${ev.value} = (($javaType) ($callFunc));
${ev.value} = $callFunc;
${ev.isNull} = ${ev.value} == null;
"""
} else {
val boxedResult = ctx.freshName("boxedResult")
s"""
${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
${ctx.boxedType(dataType)} $boxedResult = $callFunc;
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1283,13 +1283,13 @@ public void test() {
ds.collectAsList();
}

public enum EnumBean {
public enum MyEnum {
A("www.elgoog.com"),
B("www.google.com");

private String url;

EnumBean(String url) {
MyEnum(String url) {
this.url = url;
}

Expand All @@ -1302,16 +1302,8 @@ public void setUrl(String url) {
}
}

@Test
public void testEnum() {
List<EnumBean> data = Arrays.asList(EnumBean.B);
Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
Dataset<EnumBean> ds = spark.createDataset(data, encoder);
Assert.assertEquals(ds.collectAsList(), data);
}

public static class BeanWithEnum {
EnumBean enumField;
MyEnum enumField;
String regularField;

public String getRegularField() {
Expand All @@ -1322,15 +1314,15 @@ public void setRegularField(String regularField) {
this.regularField = regularField;
}

public EnumBean getEnumField() {
public MyEnum getEnumField() {
return enumField;
}

public void setEnumField(EnumBean field) {
public void setEnumField(MyEnum field) {
this.enumField = field;
}

public BeanWithEnum(EnumBean enumField, String regularField) {
public BeanWithEnum(MyEnum enumField, String regularField) {
this.enumField = enumField;
this.regularField = regularField;
}
Expand All @@ -1353,8 +1345,8 @@ public boolean equals(Object other) {

@Test
public void testBeanWithEnum() {
List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"),
new BeanWithEnum(EnumBean.B, "flower boulevard"));
List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(MyEnum.A, "mira avenue"),
new BeanWithEnum(MyEnum.B, "flower boulevard"));
Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
Assert.assertEquals(ds.collectAsList(), data);
Expand Down

0 comments on commit 6327ea5

Please sign in to comment.