# [SPARK-19666][SQL] Skip a property without getter in Java schema inference and allow empty bean in encoder creation

## What changes were proposed in this pull request?

This PR proposes two fixes.

**Skip a property without a getter in beans**

Currently, if we use a JavaBean without a getter, as below:

```java
public static class BeanWithoutGetter implements Serializable {
  private String a;

  public void setA(String a) {
    this.a = a;
  }
}

BeanWithoutGetter bean = new BeanWithoutGetter();
List<BeanWithoutGetter> data = Arrays.asList(bean);
spark.createDataFrame(data, BeanWithoutGetter.class).show();
```

- Before

It throws an exception as below:

```
java.lang.NullPointerException
	at org.spark_project.guava.reflect.TypeToken.method(TypeToken.java:465)
	at org.apache.spark.sql.catalyst.JavaTypeInference$$anonfun$2.apply(JavaTypeInference.scala:126)
	at org.apache.spark.sql.catalyst.JavaTypeInference$$anonfun$2.apply(JavaTypeInference.scala:125)
```

- After

```
++
||
++
||
++
```
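The empty table above is a zero-column DataFrame that still contains one (empty) row for the single bean. The old `NullPointerException` came from passing a null read method to `TypeToken.method`: `Introspector` reports the write-only property `a` with no getter. The following standalone Java sketch (illustrative only, not part of the patch; the class name is made up) shows what the bean introspector returns for such a bean:

```java
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;

public class InspectWriteOnlyBean {
  public static class BeanWithoutGetter implements Serializable {
    private String a;

    public void setA(String a) {
      this.a = a;
    }
  }

  public static void main(String[] args) throws Exception {
    // Lists the synthetic "class" property plus the write-only "a" property;
    // for "a", getReadMethod() returns null while getWriteMethod() does not.
    for (PropertyDescriptor p :
        Introspector.getBeanInfo(BeanWithoutGetter.class).getPropertyDescriptors()) {
      System.out.println(p.getName()
          + " read=" + p.getReadMethod()
          + " write=" + p.getWriteMethod());
    }
  }
}
```

With this fix, such write-only properties are simply skipped during schema inference instead of being handed to `TypeToken.method`.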

**Support an empty bean in encoder creation**

```java
public static class EmptyBean implements Serializable {}

EmptyBean bean = new EmptyBean();
List<EmptyBean> data = Arrays.asList(bean);
spark.createDataset(data, Encoders.bean(EmptyBean.class)).show();
```

- Before

It throws an exception as below:

```
java.lang.UnsupportedOperationException: Cannot infer type for class EmptyBean because it is not bean-compliant
	at org.apache.spark.sql.catalyst.JavaTypeInference$.org$apache$spark$sql$catalyst$JavaTypeInference$$serializerFor(JavaTypeInference.scala:436)
	at org.apache.spark.sql.catalyst.JavaTypeInference$.serializerFor(JavaTypeInference.scala:341)
```

- After

```
++
||
++
||
++
```
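As in the first case, the result is a zero-column Dataset that still holds one row per input bean. A self-contained usage sketch (illustrative, not from the patch; the class name and local-mode session setup are assumptions) confirming the behavior after the fix:

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class EmptyBeanExample {
  public static class EmptyBean implements Serializable {}

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .appName("EmptyBeanExample")
        .getOrCreate();

    List<EmptyBean> data = Arrays.asList(new EmptyBean());
    Dataset<EmptyBean> ds = spark.createDataset(data, Encoders.bean(EmptyBean.class));

    // With the fix, this no longer throws; the schema has zero columns
    // but the dataset still holds one row per input bean.
    System.out.println(ds.schema().length());  // 0
    System.out.println(ds.count());            // 1
    ds.show();

    spark.stop();
  }
}
```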

## How was this patch tested?

Unit test in `JavaDataFrameSuite`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #17013 from HyukjinKwon/SPARK-19666.
HyukjinKwon authored and cloud-fan committed Feb 22, 2017
1 parent 1f86e79 commit 37112fc
Showing 5 changed files with 54 additions and 32 deletions.
**`JavaTypeInference.scala`**

```diff
@@ -117,11 +117,10 @@ object JavaTypeInference {
       val (valueDataType, nullable) = inferDataType(valueType)
       (MapType(keyDataType, valueDataType, nullable), true)
 
-    case _ =>
+    case other =>
       // TODO: we should only collect properties that have getter and setter. However, some tests
       // pass in scala case class as java bean class which doesn't have getter and setter.
-      val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
-      val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+      val properties = getJavaBeanReadableProperties(other)
       val fields = properties.map { property =>
         val returnType = typeToken.method(property.getReadMethod).getReturnType
         val (dataType, nullable) = inferDataType(returnType)
@@ -131,10 +130,15 @@ object JavaTypeInference {
     }
   }
 
-  private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+  def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
     val beanInfo = Introspector.getBeanInfo(beanClass)
-    beanInfo.getPropertyDescriptors
-      .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
+    beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+      .filter(_.getReadMethod != null)
   }
+
+  private def getJavaBeanReadableAndWritableProperties(
+      beanClass: Class[_]): Array[PropertyDescriptor] = {
+    getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
+  }
 
   private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
@@ -298,9 +302,7 @@ object JavaTypeInference {
           keyData :: valueData :: Nil)
 
       case other =>
-        val properties = getJavaBeanProperties(other)
-        assert(properties.length > 0)
-
+        val properties = getJavaBeanReadableAndWritableProperties(other)
         val setters = properties.map { p =>
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
@@ -417,21 +419,16 @@ object JavaTypeInference {
         )
 
       case other =>
-        val properties = getJavaBeanProperties(other)
-        if (properties.length > 0) {
-          CreateNamedStruct(properties.flatMap { p =>
-            val fieldName = p.getName
-            val fieldType = typeToken.method(p.getReadMethod).getReturnType
-            val fieldValue = Invoke(
-              inputObject,
-              p.getReadMethod.getName,
-              inferExternalType(fieldType.getRawType))
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
-          })
-        } else {
-          throw new UnsupportedOperationException(
-            s"Cannot infer type for class ${other.getName} because it is not bean-compliant")
-        }
+        val properties = getJavaBeanReadableAndWritableProperties(other)
+        CreateNamedStruct(properties.flatMap { p =>
+          val fieldName = p.getName
+          val fieldType = typeToken.method(p.getReadMethod).getReturnType
+          val fieldValue = Invoke(
+            inputObject,
+            p.getReadMethod.getName,
+            inferExternalType(fieldType.getRawType))
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
+        })
       }
     }
   }
```
**`sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala`** (3 additions, 3 deletions)

```diff
@@ -1090,14 +1090,14 @@
    */
   private[sql] def beansToRows(
       data: Iterator[_],
-      beanInfo: BeanInfo,
+      beanClass: Class[_],
       attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
     val extractors =
-      beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+      JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
     val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
       (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
     }
-    data.map{ element =>
+    data.map { element =>
       new GenericInternalRow(
         methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
       ): InternalRow
```
**`SparkSession.scala`**

```diff
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql
 
-import java.beans.Introspector
 import java.io.Closeable
 import java.util.concurrent.atomic.AtomicReference
 
@@ -347,8 +346,7 @@ class SparkSession private(
     val className = beanClass.getName
     val rowRdd = rdd.mapPartitions { iter =>
       // BeanInfo is not serializable so we must rediscover it remotely for each partition.
-      val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
-      SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
+      SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq)
     }
     Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
   }
@@ -374,8 +372,7 @@ class SparkSession private(
    */
   def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
     val attrSeq = getSchema(beanClass)
-    val beanInfo = Introspector.getBeanInfo(beanClass)
-    val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+    val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
     Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
   }
 
```
**`JavaDataFrameSuite.java`**

```diff
@@ -397,4 +397,21 @@ public void testBloomFilter() {
       Assert.assertTrue(filter4.mightContain(i * 3));
     }
   }
+
+  public static class BeanWithoutGetter implements Serializable {
+    private String a;
+
+    public void setA(String a) {
+      this.a = a;
+    }
+  }
+
+  @Test
+  public void testBeanWithoutGetter() {
+    BeanWithoutGetter bean = new BeanWithoutGetter();
+    List<BeanWithoutGetter> data = Arrays.asList(bean);
+    Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
+    Assert.assertEquals(df.schema().length(), 0);
+    Assert.assertEquals(df.collectAsList().size(), 1);
+  }
 }
```
**`JavaDatasetSuite.java`**

```diff
@@ -1276,4 +1276,15 @@ public void test() {
         spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
     ds.collectAsList();
   }
+
+  public static class EmptyBean implements Serializable {}
+
+  @Test
+  public void testEmptyBean() {
+    EmptyBean bean = new EmptyBean();
+    List<EmptyBean> data = Arrays.asList(bean);
+    Dataset<EmptyBean> df = spark.createDataset(data, Encoders.bean(EmptyBean.class));
+    Assert.assertEquals(df.schema().length(), 0);
+    Assert.assertEquals(df.collectAsList().size(), 1);
+  }
 }
```
