[SPARK-6475][SQL] recognize array types when infer data types from JavaBeans #5146

Status: Closed. Wants to merge 2 commits.
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (49 additions, 31 deletions)

@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Returns a Catalyst Schema for the given java bean class.
    */
   protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
+    val (dataType, _) = inferDataType(beanClass)
+    dataType.asInstanceOf[StructType].fields.map { f =>
+      AttributeReference(f.name, f.dataType, f.nullable)()
+    }
+  }
+
+  /**
+   * Infers the corresponding SQL data type of a Java class.
+   * @param clazz Java class
+   * @return (SQL data type, nullable)
+   */
+  private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
     // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
-    val beanInfo = Introspector.getBeanInfo(beanClass)
-
-    // Note: The ordering of elements may differ from when the schema is inferred in Scala.
-    // This is because beanInfo.getPropertyDescriptors gives no guarantees about
-    // element ordering.
-    val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
-    fields.map { property =>
-      val (dataType, nullable) = property.getPropertyType match {
-        case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-          (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
-        case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
-        case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
-        case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
-        case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
-        case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
-        case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
-        case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
-        case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
-        case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
-        case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
-        case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
-        case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
-        case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
-        case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
-        case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-        case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
-        case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
-        case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
-      }
-      AttributeReference(property.getName, dataType, nullable)()
+    clazz match {
+      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+      case c: Class[_] if c.isArray =>
+        val (dataType, nullable) = inferDataType(c.getComponentType)
+        (ArrayType(dataType, nullable), true)
+
+      case _ =>
+        val beanInfo = Introspector.getBeanInfo(clazz)
+        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+        val fields = properties.map { property =>
+          val (dataType, nullable) = inferDataType(property.getPropertyType)
+          new StructField(property.getName, dataType, nullable)
+        }
+        (new StructType(fields), true)
     }
   }
 }
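The key behavioral points of the new array case are easy to miss in the diff: the element's inferred nullability becomes the ArrayType's containsNull flag (primitive component types give containsNull = false, boxed ones true), the array reference itself is always nullable, and because the case recurses through inferDataType, nested arrays fall out for free. Below is a standalone sketch of that recursion, reduced to plain strings so it runs without Spark; the class name, string encoding, and the restriction to int/Integer are illustrative, not Spark's API.

import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Map;

public class InferSketch {
  // Mirrors the shape of inferDataType: returns (type description, nullable).
  static Map.Entry<String, Boolean> infer(Class<?> clazz) {
    if (clazz == Integer.TYPE) {
      return new SimpleImmutableEntry<>("IntegerType", false); // primitive: never null
    }
    if (clazz == Integer.class) {
      return new SimpleImmutableEntry<>("IntegerType", true);  // boxed: may be null
    }
    if (clazz.isArray()) {
      Map.Entry<String, Boolean> element = infer(clazz.getComponentType());
      // The element's nullability becomes the array's containsNull flag;
      // the array reference itself is always nullable.
      return new SimpleImmutableEntry<>(
        "ArrayType(" + element.getKey() + ", containsNull = " + element.getValue() + ")", true);
    }
    throw new IllegalArgumentException("sketch covers only int, Integer, and arrays");
  }

  public static void main(String[] args) {
    System.out.println(infer(int[].class).getKey());
    // ArrayType(IntegerType, containsNull = false)
    System.out.println(infer(Integer[].class).getKey());
    // ArrayType(IntegerType, containsNull = true)
    System.out.println(infer(Integer[][].class).getKey());
    // ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = true)
  }
}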
sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

@@ -17,29 +17,39 @@
 package test.org.apache.spark.sql;
 
+import java.io.Serializable;
+import java.util.Arrays;
+
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
 
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.test.TestSQLContext$;
-import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.types.*;
+
+import static org.apache.spark.sql.functions.*;
 
 public class JavaDataFrameSuite {
+  private transient JavaSparkContext jsc;
   private transient SQLContext context;
 
   @Before
   public void setUp() {
     // Trigger static initializer of TestData
     TestData$.MODULE$.testData();
+    jsc = new JavaSparkContext(TestSQLContext.sparkContext());
     context = TestSQLContext$.MODULE$;
   }
 
   @After
   public void tearDown() {
+    jsc = null;
     context = null;
   }

@@ -90,4 +100,33 @@ public void testShow() {
     df.show();
     df.show(1000);
   }
+
+  public static class Bean implements Serializable {
+    private double a = 0.0;
+    private Integer[] b = new Integer[]{0, 1};
+
+    public double getA() {
+      return a;
+    }
+
+    public Integer[] getB() {
+      return b;
+    }
+  }
+
+  @Test
+  public void testCreateDataFrameFromJavaBeans() {
+    Bean bean = new Bean();
+    JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
+    DataFrame df = context.createDataFrame(rdd, Bean.class);
+    StructType schema = df.schema();
+    Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
+      schema.apply("a"));
+    Assert.assertEquals(
+      new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
+      schema.apply("b"));
+    Row first = df.select("a", "b").first();
+    Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
+    Assert.assertArrayEquals(bean.getB(), first.<Integer[]>getAs(1));
Review comment (Contributor), on the assertion above:
Would be better to add assertions for schema field data types.
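If that suggestion were taken up, the assertions could look something like the sketch below. It leans on Spark SQL's public org.apache.spark.sql.types.DataTypes factory, and schema is the variable from the test above; this is a sketch of the reviewer's idea, not code from this PR.

    // Sketch: assert the inferred data types directly instead of
    // comparing whole StructField objects.
    Assert.assertEquals(DataTypes.DoubleType, schema.apply("a").dataType());
    Assert.assertEquals(
      DataTypes.createArrayType(DataTypes.IntegerType, true), // boxed Integer => containsNull = true
      schema.apply("b").dataType());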

+  }
 }
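For context, this is what the change enables from the Java API: the old match in getSchema had no case for array classes and no default, so inferring a schema from a bean with an array property threw a scala.MatchError; with this patch the property maps to an ArrayType. A usage sketch follows. The Point bean and variable names are hypothetical, and a live JavaSparkContext jsc and SQLContext sqlContext are assumed to be in scope.

// Hypothetical bean with an array-typed property:
public static class Point implements Serializable {
  private double[] coordinates = new double[]{1.0, 2.0};
  public double[] getCoordinates() { return coordinates; }
}

JavaRDD<Point> points = jsc.parallelize(Arrays.asList(new Point()));
DataFrame df = sqlContext.createDataFrame(points, Point.class);
df.printSchema();
// root
//  |-- coordinates: array (nullable = true)
//  |    |-- element: double (containsNull = false)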