From 3c64750bdd4c2d0a5562f90aead37be81627cc9d Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Mon, 5 May 2014 22:59:42 -0700
Subject: [PATCH] [SQL] SPARK-1732 - Support for null primitive values.

I also removed a println that I bumped into.

Author: Michael Armbrust

Closes #658 from marmbrus/nullPrimitives and squashes the following commits:

a3ec4f3 [Michael Armbrust] Remove println.
695606b [Michael Armbrust] Support for null primatives from using scala and java reflection.
---
 .../spark/sql/catalyst/ScalaReflection.scala  | 14 ++++-
 .../spark/sql/api/java/JavaSQLContext.scala   |  8 +++
 .../org/apache/spark/sql/api/java/Row.scala   |  2 +-
 .../spark/sql/execution/basicOperators.scala  |  3 +-
 .../sql/ScalaReflectionRelationSuite.scala    | 34 +++++++++++
 .../spark/sql/api/java/JavaSQLSuite.scala     | 61 +++++++++++++++++++
 .../spark/sql/columnar/ColumnTypeSuite.scala  |  5 +-
 7 files changed, 122 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 446d0e0bd7f54..792ef6cee6f5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -44,7 +44,8 @@ object ScalaReflection {
     case t if t <:< typeOf[Product] =>
       val params = t.member("<init>": TermName).asMethod.paramss
       StructType(
-        params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+        params.head.map(p =>
+          StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => BinaryType
     case t if t <:< typeOf[Array[_]] =>
@@ -58,6 +59,17 @@ object ScalaReflection {
     case t if t <:< typeOf[String] => StringType
     case t if t <:< typeOf[Timestamp] => TimestampType
     case t if t <:< typeOf[BigDecimal] => DecimalType
+    case t if t <:< typeOf[Option[_]] =>
+      val TypeRef(_, _, Seq(optType)) = t
+      schemaFor(optType)
+    case t if t <:< typeOf[java.lang.Integer] => IntegerType
+    case t if t <:< typeOf[java.lang.Long] => LongType
+    case t if t <:< typeOf[java.lang.Double] => DoubleType
+    case t if t <:< typeOf[java.lang.Float] => FloatType
+    case t if t <:< typeOf[java.lang.Short] => ShortType
+    case t if t <:< typeOf[java.lang.Byte] => ByteType
+    case t if t <:< typeOf[java.lang.Boolean] => BooleanType
+    // TODO: The following datatypes could be marked as non-nullable.
     case t if t <:< definitions.IntTpe => IntegerType
     case t if t <:< definitions.LongTpe => LongType
     case t if t <:< definitions.DoubleTpe => DoubleType
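Taken together, the new cases above let Scala reflection derive a schema for case classes that mix primitive, boxed, and Option fields: Option[T] contributes the schema of T, the boxed java.lang types map to the same DataTypes as their primitives, and every field stays nullable. A minimal illustrative sketch follows; the Record class and the commented mapping are not part of the patch, they just restate the rule.

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types._

// Hypothetical case class mixing a primitive, a boxed value, and an Option.
case class Record(id: Int, count: java.lang.Integer, score: Option[Double])

// With the cases added above, every field resolves to a Catalyst DataType:
//   Int               -> IntegerType  (existing definitions.IntTpe case)
//   java.lang.Integer -> IntegerType  (new boxed case)
//   Option[Double]    -> DoubleType   (new Option case unwraps the element type)
// and each StructField is created with nullable = true.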
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index a7347088794a8..57facbe10fc96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -132,6 +132,14 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
       case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
       case c: Class[_] if c == java.lang.Float.TYPE => FloatType
       case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+      case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+      case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+      case c: Class[_] if c == classOf[java.lang.Long] => LongType
+      case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+      case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+      case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+      case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
     }
     // TODO: Nullability could be stricter.
     AttributeReference(property.getName, dataType, nullable = true)()
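The distinction driving the new cases above: `java.lang.Integer.TYPE` (and friends) are the Class tokens for JVM primitives, which is what a getter declared with a primitive return type reports, while `classOf[java.lang.Integer]` is the boxed wrapper class, the only form a JavaBean property can use if it is ever going to hold null. A small illustrative check, not part of the patch (the object name is hypothetical):

object BoxedVsPrimitive extends App {
  // The primitive token and the boxed wrapper are distinct Class objects,
  // which is why the property-type match above needs a case for each.
  assert(java.lang.Integer.TYPE != classOf[java.lang.Integer])
  assert(java.lang.Integer.TYPE == classOf[Int]) // Scala's Int maps to the primitive int
}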
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index 362fe769581d7..9b0dd2176149b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
 /**
  * A result row from a SparkSQL query.
  */
-class Row(row: ScalaRow) extends Serializable {
+class Row(private[spark] val row: ScalaRow) extends Serializable {

   /** Returns the number of columns present in this Row. */
   def length: Int = row.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index d807187a5ffb8..8969794c69933 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -164,6 +164,7 @@ case class Sort(
 @DeveloperApi
 object ExistingRdd {
   def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
     case s: Seq[Any] => s.map(convertToCatalyst)
     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
     case other => other
@@ -180,7 +181,7 @@ object ExistingRdd {
       bufferedIterator.map { r =>
         var i = 0
         while (i < mutableRow.length) {
-          mutableRow(i) = r.productElement(i)
+          mutableRow(i) = convertToCatalyst(r.productElement(i))
           i += 1
         }
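The convertToCatalyst change is what makes Option fields work end to end: when ExistingRdd copies a case class into a mutable row, every element now passes through convertToCatalyst, so Some(x) is stored as x and None as null. A standalone sketch of just that rule, illustrative only (the object and method names below are hypothetical, and the Seq/Product cases of the real method are omitted):

object OptionUnwrapSketch extends App {
  // Minimal restatement of the unwrapping rule added above.
  def unwrap(a: Any): Any = a match {
    case o: Option[_] => o.orNull // Some(x) => x, None => null
    case other => other
  }

  assert(unwrap(Some(1.5)) == 1.5)
  assert(unwrap(None) == null)
  assert(unwrap("already a plain value") == "already a plain value")
}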
test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index def0e046a3831..9fff7222fe840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -35,6 +35,17 @@ class PersonBean extends Serializable { var age: Int = _ } +class AllTypesBean extends Serializable { + @BeanProperty var stringField: String = _ + @BeanProperty var intField: java.lang.Integer = _ + @BeanProperty var longField: java.lang.Long = _ + @BeanProperty var floatField: java.lang.Float = _ + @BeanProperty var doubleField: java.lang.Double = _ + @BeanProperty var shortField: java.lang.Short = _ + @BeanProperty var byteField: java.lang.Byte = _ + @BeanProperty var booleanField: java.lang.Boolean = _ +} + class JavaSQLSuite extends FunSuite { val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) val javaSqlCtx = new JavaSQLContext(javaCtx) @@ -50,4 +61,54 @@ class JavaSQLSuite extends FunSuite { schemaRDD.registerAsTable("people") javaSqlCtx.sql("SELECT * FROM people").collect() } + + test("all types in JavaBeans") { + val bean = new AllTypesBean + bean.setStringField("") + bean.setIntField(0) + bean.setLongField(0) + bean.setFloatField(0.0F) + bean.setDoubleField(0.0) + bean.setShortField(0.toShort) + bean.setByteField(0.toByte) + bean.setBooleanField(false) + + val rdd = javaCtx.parallelize(bean :: Nil) + val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) + schemaRDD.registerAsTable("allTypes") + + assert( + javaSqlCtx.sql( + """ + |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField, + | booleanField + |FROM allTypes + """.stripMargin).collect.head.row === + Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false)) + } + + test("all types null in JavaBeans") { + val bean = new AllTypesBean + bean.setStringField(null) + bean.setIntField(null) + bean.setLongField(null) + bean.setFloatField(null) + bean.setDoubleField(null) + bean.setShortField(null) + bean.setByteField(null) + bean.setBooleanField(null) + + val rdd = javaCtx.parallelize(bean :: Nil) + val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) + schemaRDD.registerAsTable("allTypes") + + assert( + javaSqlCtx.sql( + """ + |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField, + | booleanField + |FROM allTypes + """.stripMargin).collect.head.row === + Seq.fill(8)(null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 325173cf95fdf..71be41056768f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite +import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer -class ColumnTypeSuite extends FunSuite { +class ColumnTypeSuite extends FunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { @@ -163,7 +164,7 @@ class ColumnTypeSuite extends FunSuite { buffer.rewind() seq.foreach { 
-      println("buffer = " + buffer + ", expected = " + expected)
+      logger.info("buffer = " + buffer + ", expected = " + expected)
       val extracted = columnType.extract(buffer)
       assert(
         expected === extracted,