[SPARK-19666][SQL] Skip a property without getter in Java schema inference and allow empty bean in encoder creation #17013

@@ -117,11 +117,10 @@ object JavaTypeInference {
         val (valueDataType, nullable) = inferDataType(valueType)
         (MapType(keyDataType, valueDataType, nullable), true)
 
-      case _ =>
+      case other =>
         // TODO: we should only collect properties that have getter and setter. However, some tests
         // pass in scala case class as java bean class which doesn't have getter and setter.
-        val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
-        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+        val properties = getJavaBeanReadableProperties(other)
         val fields = properties.map { property =>
           val returnType = typeToken.method(property.getReadMethod).getReturnType
           val (dataType, nullable) = inferDataType(returnType)
@@ -131,10 +130,15 @@
     }
   }
 
-  private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+  def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
     val beanInfo = Introspector.getBeanInfo(beanClass)
-    beanInfo.getPropertyDescriptors
-      .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
+    beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+      .filter(_.getReadMethod != null)
   }
+
+  private def getJavaBeanReadableAndWritableProperties(
+      beanClass: Class[_]): Array[PropertyDescriptor] = {
+    getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
+  }
 
   private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
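
The two new helpers split bean properties the way the rest of the patch needs them: schema inference and row extraction only require a getter, while the bean encoder requires both a getter and a setter so values can be read back in. For context, here is a small standalone Java sketch of the same java.beans filtering; BeanPropertySketch, MixedBean and the method names are invented for illustration and are not part of the patch:

import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class BeanPropertySketch {
  // Hypothetical bean: "a" is write-only, "b" is read-write, "c" is read-only.
  public static class MixedBean {
    private String a;
    private int b;
    private long c = 42L;
    public void setA(String a) { this.a = a; }
    public int getB() { return b; }
    public void setB(int b) { this.b = b; }
    public long getC() { return c; }
  }

  // Properties with a getter, excluding "class": the filter schema inference needs.
  static List<PropertyDescriptor> readableProperties(Class<?> beanClass) throws Exception {
    return Arrays.stream(Introspector.getBeanInfo(beanClass).getPropertyDescriptors())
        .filter(p -> !p.getName().equals("class"))
        .filter(p -> p.getReadMethod() != null)
        .collect(Collectors.toList());
  }

  // Properties with both a getter and a setter: the filter the bean encoder needs.
  static List<PropertyDescriptor> readableAndWritableProperties(Class<?> beanClass) throws Exception {
    return readableProperties(beanClass).stream()
        .filter(p -> p.getWriteMethod() != null)
        .collect(Collectors.toList());
  }

  public static void main(String[] args) throws Exception {
    System.out.println(readableProperties(MixedBean.class).stream()
        .map(PropertyDescriptor::getName).collect(Collectors.toList()));            // [b, c]
    System.out.println(readableAndWritableProperties(MixedBean.class).stream()
        .map(PropertyDescriptor::getName).collect(Collectors.toList()));            // [b]
  }
}

Running this shows the distinction directly: the write-only "a" is invisible to both helpers, the read-only "c" is usable for inference but not for the encoder, and "b" appears in both lists.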
@@ -298,9 +302,7 @@ object JavaTypeInference {
           keyData :: valueData :: Nil)
 
       case other =>
-        val properties = getJavaBeanProperties(other)
-        assert(properties.length > 0)
-
+        val properties = getJavaBeanReadableAndWritableProperties(other)
         val setters = properties.map { p =>
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
@@ -417,21 +419,16 @@ object JavaTypeInference {
         )
 
       case other =>
-        val properties = getJavaBeanProperties(other)
-        if (properties.length > 0) {
-          CreateNamedStruct(properties.flatMap { p =>
-            val fieldName = p.getName
-            val fieldType = typeToken.method(p.getReadMethod).getReturnType
-            val fieldValue = Invoke(
-              inputObject,
-              p.getReadMethod.getName,
-              inferExternalType(fieldType.getRawType))
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
-          })
-        } else {
-          throw new UnsupportedOperationException(
-            s"Cannot infer type for class ${other.getName} because it is not bean-compliant")
-        }
+        val properties = getJavaBeanReadableAndWritableProperties(other)
+        CreateNamedStruct(properties.flatMap { p =>
+          val fieldName = p.getName
+          val fieldType = typeToken.method(p.getReadMethod).getReturnType
+          val fieldValue = Invoke(
+            inputObject,
+            p.getReadMethod.getName,
+            inferExternalType(fieldType.getRawType))
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
+        })
     }
   }
 }
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1090,14 +1090,14 @@
    */
   private[sql] def beansToRows(
       data: Iterator[_],
-      beanInfo: BeanInfo,
+      beanClass: Class[_],
       attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
     val extractors =
-      beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+      JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
     val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
       (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
     }
-    data.map{ element =>
+    data.map { element =>
       new GenericInternalRow(
         methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
       ): InternalRow
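
beansToRows above now derives its extractors from the bean class itself, keeping only readable properties, and invokes each getter reflectively per element. A simplified, Spark-free Java sketch of that extraction idea follows; the class and method names are hypothetical, and it returns plain Object arrays instead of Catalyst rows:

import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BeanToRowSketch {
  // Hypothetical bean with a single readable property.
  public static class Person {
    private String name;
    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
  }

  // Collect the read methods of the readable properties (skipping "class"),
  // then invoke each getter reflectively on every bean to build one row of values.
  static List<Object[]> beansToRows(List<?> beans, Class<?> beanClass) throws Exception {
    Method[] getters = Arrays.stream(Introspector.getBeanInfo(beanClass).getPropertyDescriptors())
        .filter(p -> !p.getName().equals("class") && p.getReadMethod() != null)
        .map(PropertyDescriptor::getReadMethod)
        .toArray(Method[]::new);
    List<Object[]> rows = new ArrayList<>();
    for (Object bean : beans) {
      Object[] row = new Object[getters.length];
      for (int i = 0; i < getters.length; i++) {
        row[i] = getters[i].invoke(bean);  // like e.invoke(element) in the Scala code above
      }
      rows.add(row);
    }
    return rows;
  }

  public static void main(String[] args) throws Exception {
    Person p = new Person();
    p.setName("ada");
    System.out.println(Arrays.deepToString(beansToRows(Arrays.asList(p), Person.class).toArray()));
    // prints [[ada]]
  }
}

A bean with no getters simply produces zero getters here, hence zero-width rows, which matches the zero-column DataFrame asserted by the new test further down.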
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql
 
-import java.beans.Introspector
 import java.io.Closeable
 import java.util.concurrent.atomic.AtomicReference
 
@@ -347,8 +346,7 @@ class SparkSession private(
     val className = beanClass.getName
     val rowRdd = rdd.mapPartitions { iter =>
       // BeanInfo is not serializable so we must rediscover it remotely for each partition.
-      val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
-      SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
+      SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq)
     }
     Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
   }
@@ -374,8 +372,7 @@
    */
   def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
     val attrSeq = getSchema(beanClass)
-    val beanInfo = Introspector.getBeanInfo(beanClass)
-    val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+    val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
     Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
   }
 
@@ -397,4 +397,21 @@ public void testBloomFilter() {
       Assert.assertTrue(filter4.mightContain(i * 3));
     }
   }
+
+  public static class BeanWithoutGetter implements Serializable {
+    private String a;
+
+    public void setA(String a) {
+      this.a = a;
+    }
+  }
+
+  @Test
+  public void testBeanWithoutGetter() {
+    BeanWithoutGetter bean = new BeanWithoutGetter();
+    List<BeanWithoutGetter> data = Arrays.asList(bean);
+    Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
+    Assert.assertEquals(df.schema().length(), 0);
+    Assert.assertEquals(df.collectAsList().size(), 1);
+  }
 }
@@ -1276,4 +1276,15 @@ public void test() {
     spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
     ds.collectAsList();
   }
+
+  public static class EmptyBean implements Serializable {}
+
+  @Test
+  public void testEmptyBean() {
+    EmptyBean bean = new EmptyBean();
+    List<EmptyBean> data = Arrays.asList(bean);
+    Dataset<EmptyBean> df = spark.createDataset(data, Encoders.bean(EmptyBean.class));
+    Assert.assertEquals(df.schema().length(), 0);
+    Assert.assertEquals(df.collectAsList().size(), 1);
+  }
 }