Skip to content

Commit

Permalink
Recognize array types when inferring data types from JavaBeans
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 23, 2015
1 parent c94d062 commit 4f2df5e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 32 deletions.
80 changes: 49 additions & 31 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
  // Infer the bean's Catalyst type; a bean class always infers to a StructType,
  // so the cast below is safe for any well-formed JavaBean.
  val (inferred, _) = inferDataType(beanClass)
  val struct = inferred.asInstanceOf[StructType]
  // One AttributeReference per inferred bean property.
  struct.fields.map(field => AttributeReference(field.name, field.dataType, field.nullable)())
}

/**
 * Infers the corresponding SQL data type of a Java class.
 *
 * Primitive types map to non-nullable SQL types; boxed and reference types are nullable.
 * Array classes map to [[ArrayType]] (with element nullability inferred from the component
 * type), and any other class is treated as a JavaBean and mapped to a [[StructType]] built
 * from its readable properties.
 *
 * @param clazz Java class
 * @return (SQL data type, nullable)
 */
private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
  // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
  clazz match {
    // A user-defined type annotation takes precedence over all built-in mappings.
    case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
      (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

    case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
    // Java primitives can never hold null, hence nullable = false.
    case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
    case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
    case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
    case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
    case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
    case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
    case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

    // Boxed wrappers are reference types, so they are nullable.
    case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
    case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
    case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
    case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
    case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
    case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
    case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

    case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
    case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
    case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

    // Arrays recurse on the component type; element nullability follows the component's.
    case c: Class[_] if c.isArray =>
      val (dataType, nullable) = inferDataType(c.getComponentType)
      (ArrayType(dataType, nullable), true)

    // Anything else is treated as a JavaBean: one StructField per readable property.
    // NOTE(review): a self-referential bean type would recurse without terminating;
    // callers are expected to pass acyclic bean classes.
    case _ =>
      val beanInfo = Introspector.getBeanInfo(clazz)
      // Note: The ordering of elements may differ from when the schema is inferred in Scala.
      // This is because beanInfo.getPropertyDescriptors gives no guarantees about
      // element ordering. "class" is the synthetic getClass property and is skipped.
      val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
      val fields = properties.map { property =>
        val (dataType, nullable) = inferDataType(property.getPropertyType)
        new StructField(property.getName, dataType, nullable)
      }
      (new StructType(fields), true)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,37 @@

package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.Arrays;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import static org.apache.spark.sql.functions.*;


public class JavaDataFrameSuite {
private transient JavaSparkContext jsc;
private transient SQLContext context;

// Initializes the shared Spark/SQL test fixtures before each test case.
@Before
public void setUp() {
// Trigger static initializer of TestData
TestData$.MODULE$.testData();
// Wrap the shared Scala SparkContext from TestSQLContext in a Java-friendly handle.
jsc = new JavaSparkContext(TestSQLContext.sparkContext());
// TestSQLContext$.MODULE$ is the Scala singleton instance used as the SQLContext.
context = TestSQLContext$.MODULE$;
}

// Drops references to the shared contexts after each test so state does not
// leak between test cases (the underlying contexts themselves are shared and
// are not stopped here).
@After
public void tearDown() {
jsc = null;
context = null;
}

Expand Down Expand Up @@ -90,4 +98,27 @@ public void testShow() {
df.show();
df.show(1000);
}

/**
 * Simple serializable JavaBean fixture with one primitive property ("a")
 * and one boxed-array property ("b"), used to exercise schema inference
 * from bean classes.
 */
public static class Bean implements Serializable {
  private double a = 0.0;
  private Integer[] b = {0, 1};

  /** Bean property "a": a primitive double, initialized to 0.0. */
  public double getA() { return a; }

  /** Bean property "b": a boxed-integer array, initialized to {0, 1}. */
  public Integer[] getB() { return b; }
}

/**
 * Verifies that createDataFrame infers a schema from a JavaBean class,
 * including an array-typed property, and round-trips the bean's values.
 */
@Test
public void testCreateDataFrameFromJavaBeans() {
  Bean bean = new Bean();
  JavaRDD<Bean> beanRdd = jsc.parallelize(Arrays.asList(bean));
  DataFrame dataFrame = context.createDataFrame(beanRdd, Bean.class);
  Row firstRow = dataFrame.select("a", "b").first();
  Assert.assertEquals(bean.getA(), firstRow.getDouble(0), 0.0);
  Assert.assertArrayEquals(bean.getB(), firstRow.<Integer[]>getAs(1));
}
}

0 comments on commit 4f2df5e

Please sign in to comment.