From 1fb1e1ad0d37578cb5a5bf25b22b886994fadc43 Mon Sep 17 00:00:00 2001
From: ahshahid
Date: Mon, 2 Oct 2017 16:26:17 -0700
Subject: [PATCH 1/3] Update SQLContext.scala

---
 .../org/apache/spark/sql/SQLContext.scala | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index af6018472cb03..77017d7962f4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1094,6 +1094,7 @@ object SQLContext {
    * bean info & schema. This is not related to the singleton, but is a static
    * method for internal use.
    */
+  /*
   private[sql] def beansToRows(
       data: Iterator[_],
       beanClass: Class[_],
@@ -1109,6 +1110,43 @@ object SQLContext {
       ): InternalRow
     }
   }
+  */
+  /**
+   * Converts an iterator of Java Beans to InternalRow using the provided
+   * bean info & schema. This is not related to the singleton, but is a static
+   * method for internal use.
+   */
+  private[sql] def beansToRows(
+      data: Iterator[_],
+      beanClass: Class[_],
+      attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
+    val converters = getExtractors(beanClass, attrs)
+    data.map{ element =>
+      new GenericInternalRow(
+        converters.map { case (e, convert) => convert(e.invoke(element)) }
+      ): InternalRow
+    }
+  }
+
+  def getExtractors( beanClass: Class[_],
+      attrs: Seq[AttributeReference]): Array[(Method, Any => Any)] = {
+    val methodsToConverts = JavaTypeInference.getJavaBeanReadableProperties(beanClass).
+      map(_.getReadMethod).zip(attrs)
+    methodsToConverts.map { case (e, attr) =>
+      attr.dataType match {
+        case strct: StructType => {
+          val extractors = getExtractors(e.getReturnType,
+            strct.map(sf => AttributeReference(sf.name, sf.dataType, sf.nullable)()))
+          (e, (x: Any) => {
+            val arr = Array.tabulate[Any](strct.length)(i =>
+              extractors(i)._2(extractors(i)._1.invoke(x)))
+            new GenericInternalRow(arr)
+          })
+        }
+        case _ => (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
+      }
+    }
+  }
 
   /**
    * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].

From 3074347360e8c18e947e52aea89491bb9e844d86 Mon Sep 17 00:00:00 2001
From: Asif Shahid
Date: Tue, 3 Oct 2017 12:23:57 -0700
Subject: [PATCH 2/3] Fix for bug SPARK-22192.

Recursively convert the nested POJOs into GenericInternalRow to avoid
scala.MatchError
---
 .../org/apache/spark/sql/SQLContext.scala    |  8 +--
 .../apache/spark/sql/SQLContextSuite.scala   | 63 +++++++++++++++++++
 2 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 77017d7962f4e..0fa3daed85662 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.lang.reflect.Method
 import java.util.Properties
 
 import scala.collection.immutable
 import scala.reflect.runtime.universe.TypeTag
-
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
@@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.command.ShowTablesCommand
-import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
+import org.apache.spark.sql.internal.{SQLConf, SessionState, SharedState}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQueryManager}
 import org.apache.spark.sql.types._
@@ -1134,7 +1134,7 @@ object SQLContext {
       map(_.getReadMethod).zip(attrs)
     methodsToConverts.map { case (e, attr) =>
       attr.dataType match {
-        case strct: StructType => {
+        case strct: StructType =>
           val extractors = getExtractors(e.getReturnType,
             strct.map(sf => AttributeReference(sf.name, sf.dataType, sf.nullable)()))
           (e, (x: Any) => {
@@ -1142,7 +1142,7 @@ object SQLContext {
               extractors(i)._2(extractors(i)._1.invoke(x)))
             new GenericInternalRow(arr)
           })
-        }
+
         case _ => (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index a1799829932b8..ed92aa5c47e41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -48,6 +50,23 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
       "SQLContext.getOrCreate after explicitly setActive() did not return the active context")
   }
 
+  test("Bug SPARK-22192 Nested POJO object not handled when creating DataFrame from RDD") {
+    val sqlContext = SQLContext.getOrCreate(sc)
+    val personsCollection = for (k <- 1 until 100) yield {
+      new Person(k, "name_" + k, k.toLong, k.toShort,
+        k.toByte, k.toDouble *86.7543d, k.toFloat *7.31f,
+        true, Array.fill[Byte](k)(k.toByte),
+        new java.sql.Date(7836*k*1000), new Timestamp(7896*k*1000),
+        new Address("12320 sw horizon", 97007))
+    }
+
+    // create a pair RDD from the collection
+    val personsRDD = sc.parallelize(personsCollection)
+    val df = sqlContext.createDataFrame(personsRDD, classOf[Person])
+    df.printSchema()
+    df.collect()
+  }
+
   test("Sessions of SQLContext") {
     val sqlContext = SQLContext.getOrCreate(sc)
     val session1 = sqlContext.newSession()
@@ -143,3 +162,47 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 }
+
+
+class Person(var id: Int, var name: String, var longField: Long, var shortField: Short,
+    var byteField: Byte, var doubleField: Double, var floatField: Float,
+    var booleanField: Boolean, var binaryField: Array[Byte],
+    var datee: Date, var timeeStamp: Timestamp,
+    var address: Address ) extends java.io.Serializable{
+  def this() = this(0, null, 0, 0, 0, 0d, 0f, false, null, null, null, null)
+  def getName: String = name
+  def getId: Int = id
+  def getLongField: Long = longField
+  def getShortField: Short = shortField
+  def getByteField: Byte = byteField
+  def getDoubleField: Double = doubleField
+  def getFloatField: Float = floatField
+  def getBooleanField: Boolean = booleanField
+  def getBinaryField: Array[Byte] = binaryField
+  def getDatee: Date = datee
+  def getTimeeStamp: Timestamp = timeeStamp
+  def getAddress: Address = address
+
+  def setName(name: String): Unit = {this.name = name}
+  def setId(id: Int): Unit = {this.id = id}
+  def setLongField(longField: Long): Unit = {this.longField = longField}
+  def setShortField(shortField: Short): Unit = {this.shortField = shortField}
+  def setByteField(byteField: Byte): Unit = {this.byteField = byteField}
+  def setDoubleField(doubleField: Double): Unit = {this.doubleField = doubleField}
+  def setFloatField(floatField: Float): Unit = {this.floatField = floatField}
+  def setBooleanField(booleanField: Boolean): Unit = {this.booleanField = booleanField}
+  def setBinaryField(binaryField: Array[Byte]): Unit = {this.binaryField = binaryField}
+  def setDatee(datee: Date): Unit = {this.datee = datee}
+  def setTimeeStamp(ts: Timestamp): Unit = {this.timeeStamp = ts}
+  def setAddress(address: Address): Unit = {this.address = address}
+}
+
+
+
+class Address(var street: String, var zip: Int) extends java.io.Serializable {
+  def this() = this(null, -1)
+  def getStreet: String = this.street
+  def getZip: Int = this.zip
+  def setStreet(street: String): Unit = {this.street = street}
+  def setZip(zip: Int): Unit = {this.zip = zip}
+}

From fb207df9f1939c1dbba577ce3b21025dda938be5 Mon Sep 17 00:00:00 2001
From: Asif Shahid
Date: Tue, 3 Oct 2017 14:45:31 -0700
Subject: [PATCH 3/3] added data validation in the test

---
 .../apache/spark/sql/SQLContextSuite.scala | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index ed92aa5c47e41..a75ec92ae8b8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
 
+
 @deprecated("This suite is deprecated to silent compiler deprecation warnings", "2.0.0")
 class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
 
@@ -57,14 +58,36 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
         k.toByte, k.toDouble *86.7543d, k.toFloat *7.31f,
         true, Array.fill[Byte](k)(k.toByte),
         new java.sql.Date(7836*k*1000), new Timestamp(7896*k*1000),
-        new Address("12320 sw horizon", 97007))
+        new Address("12320 sw horizon," + k, 97007*k))
     }
 
     // create a pair RDD from the collection
     val personsRDD = sc.parallelize(personsCollection)
     val df = sqlContext.createDataFrame(personsRDD, classOf[Person])
-    df.printSchema()
-    df.collect()
+    val rows = df.collect()
+    val keys = scala.collection.mutable.Set[Int]()
+    for(i <- 1 until 100) keys.add(i)
+    for(row <- rows) {
+      assert(keys.remove(row.getAs[Int]("id")))
+      val k = row.getAs[Int]("id")
+      assert("name_" + k == row.getAs[String]("name"), "String field match not as expected")
+      assert(k.toLong == row.getAs[Long]("longField"), "Long field match not as expected")
+      assert(k.toShort == row.getAs[Short]("shortField"), "Short field match not as expected")
+      assert(k.toByte == row.getAs[Byte]("byteField"), "Byte field match not as expected")
+      assert(k*86.7543d == row.getAs[Double]("doubleField"), "Double field match not as expected")
+      assert(k*7.31f == row.getAs[Float]("floatField"), "Float field match not as expected")
+      assert(true == row.getAs[Boolean]("booleanField"), "Boolean field match not as expected")
+      assertResult(Array.fill[Byte](k)(k.toByte).seq) {row.getAs[Array[Byte]]("binaryField").toSeq}
+      assert(new java.sql.Date(7836*k*1000).toString == row.getAs[Date]("datee").toString,
+        "Date field match not as expected")
+      assert(new Timestamp(7896*k*1000).toString == row.getAs[Timestamp]("timeeStamp").toString,
+        "TimeStamp field match not as expected")
+      val addressStruct = row.getAs[Row]("address")
+      assert("12320 sw horizon," + k == addressStruct.getAs[String]("street"),
+        "struct field match not as expected")
+      assert(97007*k == addressStruct.getAs[Int]("zip"), "struct field match not as expected")
+    }
+    assert(keys.isEmpty)
   }
 
   test("Sessions of SQLContext") {
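
The scenario the new test exercises can also be reproduced with a minimal standalone
sketch: an RDD of bean-style objects whose bean carries another bean as a field, passed
to SQLContext.createDataFrame(rdd, beanClass). The program below is illustrative only and
not part of the patch series; the Inner/Outer bean names, the Spark22192Sketch object and
the local[2] master are assumptions made for the example. Per the commit message, this
call previously failed with a scala.MatchError while converting the nested bean; with the
recursive extractors the nested bean comes back as a struct column readable via Row.getAs.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}

// Hypothetical nested bean used only for this sketch.
class Inner(var zip: Int) extends java.io.Serializable {
  def this() = this(0)
  def getZip: Int = zip
  def setZip(z: Int): Unit = { zip = z }
}

// Hypothetical outer bean with a nested bean property ("inner").
class Outer(var id: Int, var inner: Inner) extends java.io.Serializable {
  def this() = this(0, null)
  def getId: Int = id
  def setId(i: Int): Unit = { id = i }
  def getInner: Inner = inner
  def setInner(in: Inner): Unit = { inner = in }
}

object Spark22192Sketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("SPARK-22192"))
    val sqlContext = SQLContext.getOrCreate(sc)

    // RDD of beans, each carrying a nested bean in its "inner" property.
    val rdd = sc.parallelize(1 to 5).map(i => new Outer(i, new Inner(97000 + i)))

    // With the recursive extractors the nested bean becomes a nested struct column.
    val df = sqlContext.createDataFrame(rdd, classOf[Outer])
    df.printSchema()
    df.collect().foreach { row =>
      val inner = row.getAs[Row]("inner")
      val id = row.getAs[Int]("id")
      val zip = inner.getAs[Int]("zip")
      println(s"id=$id zip=$zip")
    }

    sc.stop()
  }
}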