diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 8b53d988cbc59..e9d9508e5adfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -117,11 +117,10 @@ object JavaTypeInference { val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) - case _ => + case other => // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. - val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val properties = getJavaBeanReadableProperties(other) val fields = properties.map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(returnType) @@ -131,10 +130,15 @@ object JavaTypeInference { } } - private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) - beanInfo.getPropertyDescriptors - .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + .filter(_.getReadMethod != null) + } + + private def getJavaBeanReadableAndWritableProperties( + beanClass: Class[_]): Array[PropertyDescriptor] = { + getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null) } private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { @@ -298,9 +302,7 @@ object JavaTypeInference { keyData :: valueData :: Nil) case other => - val properties = getJavaBeanProperties(other) - assert(properties.length > 0) - + val properties = getJavaBeanReadableAndWritableProperties(other) val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType @@ -417,21 +419,16 @@ object JavaTypeInference { ) case other => - val properties = getJavaBeanProperties(other) - if (properties.length > 0) { - CreateNamedStruct(properties.flatMap { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val fieldValue = Invoke( - inputObject, - p.getReadMethod.getName, - inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil - }) - } else { - throw new UnsupportedOperationException( - s"Cannot infer type for class ${other.getName} because it is not bean-compliant") - } + val properties = getJavaBeanReadableAndWritableProperties(other) + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil + }) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbe55090ea113..234ef2dffc6bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1090,14 +1090,14 @@ object SQLContext { */ private[sql] def beansToRows( data: Iterator[_], - beanInfo: BeanInfo, + beanClass: Class[_], attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = - beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) + JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) } - data.map{ element => + data.map { element => new GenericInternalRow( methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) } ): InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 72af55c1fa147..afc1827e7eece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.beans.Introspector import java.io.Closeable import java.util.concurrent.atomic.AtomicReference @@ -347,8 +346,7 @@ class SparkSession private( val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. - val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) - SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) + SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq) } Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) } @@ -374,8 +372,7 @@ class SparkSession private( */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { val attrSeq = getSchema(beanClass) - val beanInfo = Introspector.getBeanInfo(beanClass) - val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) + val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq) Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c3b94a44c2e91..a8f814bfae530 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -397,4 +397,21 @@ public void testBloomFilter() { Assert.assertTrue(filter4.mightContain(i * 3)); } } + + public static class BeanWithoutGetter implements Serializable { + private String a; + + public void setA(String a) { + this.a = a; + } + } + + @Test + public void testBeanWithoutGetter() { + BeanWithoutGetter bean = new BeanWithoutGetter(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataFrame(data, BeanWithoutGetter.class); + Assert.assertEquals(df.schema().length(), 0); + Assert.assertEquals(df.collectAsList().size(), 1); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 577672ca8e083..4581c6ebe9ef8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1276,4 +1276,15 @@ public void test() { spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class)); ds.collectAsList(); } + + public static class EmptyBean implements Serializable {} + + @Test + public void testEmptyBean() { + EmptyBean bean = new EmptyBean(); + List data = Arrays.asList(bean); + Dataset df = spark.createDataset(data, Encoders.bean(EmptyBean.class)); + Assert.assertEquals(df.schema().length(), 0); + Assert.assertEquals(df.collectAsList().size(), 1); + } }