From 3c58b0d31578aaa79a619d014ebc6882effb4a52 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Sun, 19 Oct 2014 19:58:33 -0700
Subject: [PATCH 1/3] add 3 types for java SQL context

---
 .../org/apache/spark/sql/api/java/JavaSQLContext.scala | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index f8171c3be3207..d53199114b70a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -226,6 +226,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
           (org.apache.spark.sql.FloatType, true)
         case c: Class[_] if c == classOf[java.lang.Boolean] =>
           (org.apache.spark.sql.BooleanType, true)
+        case c: Class[_] if c == classOf[java.math.BigDecimal] =>
+          (org.apache.spark.sql.DecimalType, true)
+        case c: Class[_] if c == classOf[java.sql.Date] =>
+          (org.apache.spark.sql.DateType, true)
+        case c: Class[_] if c == classOf[java.sql.Timestamp] =>
+          (org.apache.spark.sql.TimestampType, true)
       }
       AttributeReference(property.getName, dataType, nullable)()
     }
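The three new cases extend the JavaBean-property-to-Catalyst mapping that getSchema builds by pattern matching on each property's class. A minimal sketch of the effect, assuming a hypothetical bean (the EventBean name and its fields are illustrative, not part of the patch):

    import scala.beans.BeanProperty

    class EventBean extends Serializable {
      @BeanProperty var eventDate: java.sql.Date = _
      @BeanProperty var startedAt: java.sql.Timestamp = _
      @BeanProperty var amount: java.math.BigDecimal = _
    }
    // Before this patch the match had no case for these three classes, so
    // applySchema(rdd, classOf[EventBean]) would fail with a scala.MatchError
    // during schema inference; with it, the properties map to DateType,
    // TimestampType, and DecimalType, each marked nullable.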
From bb0508f1382186c20ddb80b6032f3fce5c6cf6aa Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Mon, 20 Oct 2014 22:32:31 -0700
Subject: [PATCH 2/3] add test cases

---
 .../spark/sql/api/java/JavaSQLSuite.scala | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index 203ff847e94cc..15cee0c2bc353 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -45,6 +45,9 @@ class AllTypesBean extends Serializable {
   @BeanProperty var shortField: java.lang.Short = _
   @BeanProperty var byteField: java.lang.Byte = _
   @BeanProperty var booleanField: java.lang.Boolean = _
+  @BeanProperty var dateField: java.sql.Date = _
+  @BeanProperty var timestampField: java.sql.Timestamp = _
+  @BeanProperty var bigDecimalField: java.math.BigDecimal = _
 }
 
 class JavaSQLSuite extends FunSuite {
@@ -73,6 +76,9 @@ class JavaSQLSuite extends FunSuite {
     bean.setShortField(0.toShort)
     bean.setByteField(0.toByte)
     bean.setBooleanField(false)
+    bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
+    bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
+    bean.setBigDecimalField(new java.math.BigDecimal(0))
 
     val rdd = javaCtx.parallelize(bean :: Nil)
     val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -82,10 +88,11 @@ class JavaSQLSuite extends FunSuite {
       javaSqlCtx.sql(
         """
          |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
-         |       booleanField
+         |       booleanField, dateField, timestampField, bigDecimalField
          |FROM allTypes
        """.stripMargin).collect.head.row ===
-        Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
+        Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false, java.sql.Date.valueOf("2014-10-10"),
+          java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), new java.math.BigDecimal(0)))
   }
 
   test("all types null in JavaBeans") {
@@ -98,6 +105,9 @@ class JavaSQLSuite extends FunSuite {
     bean.setShortField(null)
     bean.setByteField(null)
     bean.setBooleanField(null)
+    bean.setDateField(null)
+    bean.setTimestampField(null)
+    bean.setBigDecimalField(null)
 
     val rdd = javaCtx.parallelize(bean :: Nil)
     val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -107,10 +117,10 @@ class JavaSQLSuite extends FunSuite {
       javaSqlCtx.sql(
         """
          |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
-         |       booleanField
+         |       booleanField, dateField, timestampField, bigDecimalField
          |FROM allTypes
        """.stripMargin).collect.head.row ===
-        Seq.fill(8)(null))
+        Seq.fill(11)(null))
   }
 
   test("loads JSON datasets") {
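AllTypesBean grows from 8 fields to 11, which is why the all-null assertion changes from Seq.fill(8)(null) to Seq.fill(11)(null). A quick standalone sketch of the three literal values the positive test now expects (plain Scala, no Spark required; the val names are illustrative):

    val date = java.sql.Date.valueOf("2014-10-10")                // day-precision SQL date
    val ts = java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0")  // SQL timestamp at midnight
    val dec = new java.math.BigDecimal(0)                         // arbitrary-precision zero
    // Both valueOf calls parse in the JVM's default time zone, so the
    // midnight timestamp and the date agree on their epoch offset:
    assert(ts.getTime == date.getTime)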
From 4c4292ccbd23cac1cb511ca9581bb9ac17ac037f Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Wed, 29 Oct 2014 03:16:17 -0700
Subject: [PATCH 3/3] change underlying type of JavaSchemaRDD to Scala types

---
 .../spark/sql/api/java/JavaSQLContext.scala  |  5 +++-
 .../sql/types/util/DataTypeConversions.scala | 12 +++++++++
 .../spark/sql/api/java/JavaSQLSuite.scala    | 25 ++++++++++++++++++-
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index d53199114b70a..082ae03eef03f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.util.DataTypeConversions
 import org.apache.spark.sql.{SQLContext, StructType => SStructType}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
 import org.apache.spark.sql.parquet.ParquetRelation
@@ -97,7 +98,9 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
         localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
 
       iter.map { row =>
-        new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
+        new GenericRow(
+          extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
+        ): ScalaRow
       }
     }
     new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index e44cb08309523..609f7db562a31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -110,4 +110,16 @@ protected[sql] object DataTypeConversions {
     case structType: org.apache.spark.sql.api.java.StructType =>
       StructType(structType.getFields.map(asScalaStructField))
   }
+
+  /** Converts Java objects to catalyst rows / types */
+  def convertJavaToCatalyst(a: Any): Any = a match {
+    case d: java.math.BigDecimal => BigDecimal(d)
+    case other => other
+  }
+
+  /** Converts catalyst rows / types to Java objects */
+  def convertCatalystToJava(a: Any): Any = a match {
+    case d: scala.math.BigDecimal => d.underlying()
+    case other => other
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index 15cee0c2bc353..d83f3e23a9468 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -92,7 +92,30 @@ class JavaSQLSuite extends FunSuite {
          |FROM allTypes
        """.stripMargin).collect.head.row ===
         Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false, java.sql.Date.valueOf("2014-10-10"),
-          java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), new java.math.BigDecimal(0)))
+          java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), scala.math.BigDecimal(0)))
+  }
+
+  test("decimal types in JavaBeans") {
+    val bean = new AllTypesBean
+    bean.setStringField("")
+    bean.setIntField(0)
+    bean.setLongField(0)
+    bean.setFloatField(0.0F)
+    bean.setDoubleField(0.0)
+    bean.setShortField(0.toShort)
+    bean.setByteField(0.toByte)
+    bean.setBooleanField(false)
+    bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
+    bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
+    bean.setBigDecimalField(new java.math.BigDecimal(0))
+
+    val rdd = javaCtx.parallelize(bean :: Nil)
+    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+    schemaRDD.registerTempTable("decimalTypes")
+
+    assert(javaSqlCtx.sql(
+      "select bigDecimalField + bigDecimalField from decimalTypes"
+    ).collect.head.row === Seq(scala.math.BigDecimal(0)))
   }
 
   test("all types null in JavaBeans") {
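The two conversion helpers only special-case BigDecimal; java.sql.Date and java.sql.Timestamp pass through the `case other` branch untouched, since Catalyst already stores those types as-is at this point in the codebase. A round-trip sketch of the intended behavior (plain Scala against the definitions above; the val names are illustrative):

    val javaDec = new java.math.BigDecimal("3.14")
    // Java -> Catalyst wraps the value in scala.math.BigDecimal...
    val catalystDec = DataTypeConversions.convertJavaToCatalyst(javaDec)
    assert(catalystDec == scala.math.BigDecimal("3.14"))
    // ...and Catalyst -> Java unwraps it again via underlying().
    assert(DataTypeConversions.convertCatalystToJava(catalystDec) == javaDec)

This conversion is also why the first test's expected row now ends with scala.math.BigDecimal(0) instead of new java.math.BigDecimal(0): once convertJavaToCatalyst has run, rows carry the Scala wrapper internally.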