diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
index e457542c647e7..5a5a6cd5c812e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
@@ -79,6 +79,11 @@ public class DataTypes {
    */
   public static final DataType ShortType = ShortType$.MODULE$;
 
+  /**
+   * Gets the CharType object.
+   */
+  public static final DataType CharType = CharType$.MODULE$;
+
   /**
    * Gets the NullType object.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d6126c24fc50d..06227fbaf3137 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -170,6 +170,7 @@ trait ScalaReflection {
       case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
       case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
       case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+      case t if t <:< typeOf[java.lang.Character] => Schema(CharType, nullable = true)
       case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
       case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
       case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
@@ -177,6 +178,7 @@ trait ScalaReflection {
       case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
       case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
       case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+      case t if t <:< definitions.CharTpe => Schema(CharType, nullable = false)
       case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
     }
@@ -189,6 +191,7 @@ trait ScalaReflection {
     case obj: StringType.JvmType => StringType
     case obj: ByteType.JvmType => ByteType
     case obj: ShortType.JvmType => ShortType
+    case obj: CharType.JvmType => CharType
     case obj: IntegerType.JvmType => IntegerType
     case obj: LongType.JvmType => LongType
     case obj: FloatType.JvmType => FloatType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index bf39603d13bd5..2dcca6e260ee1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -610,6 +610,36 @@ class ShortType private() extends IntegralType {
 
 case object ShortType extends ShortType
 
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Char` values. Please use the singleton [[DataTypes.CharType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class CharType private() extends IntegralType {
+  // The companion object and this class is separated so the companion object also subclasses
+  // this type. Otherwise, the companion object would be of type "CharType$" in byte code.
+  // Defined with a private constructor so the companion object is the only possible instantiation.
+  private[sql] type JvmType = Char
+  @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+  private[sql] val numeric = implicitly[Numeric[Char]]
+  private[sql] val integral = implicitly[Integral[Char]]
+  private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the CharType is 2 bytes.
+   */
+  override def defaultSize: Int = 2
+
+  override def simpleString = "char"
+
+  private[spark] override def asNullable: CharType = this
+}
+
+case object CharType extends CharType
+
+
 /**
  * :: DeveloperApi ::
  * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]].
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index eee00e3f7ea76..6a002191e6869 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.types._
 
 case class PrimitiveData(
+    charField: Char,
     intField: Int,
     longField: Long,
     doubleField: Double,
@@ -82,6 +83,7 @@ class ScalaReflectionSuite extends FunSuite {
     val schema = schemaFor[PrimitiveData]
     assert(schema === Schema(
       StructType(Seq(
+        StructField("charField", CharType, nullable = false),
         StructField("intField", IntegerType, nullable = false),
         StructField("longField", LongType, nullable = false),
         StructField("doubleField", DoubleType, nullable = false),
@@ -157,6 +159,7 @@ class ScalaReflectionSuite extends FunSuite {
       StructField(
         "structField",
         StructType(Seq(
+          StructField("charField", CharType, nullable = false),
           StructField("intField", IntegerType, nullable = false),
           StructField("longField", LongType, nullable = false),
           StructField("doubleField", DoubleType, nullable = false),
@@ -257,19 +260,19 @@ class ScalaReflectionSuite extends FunSuite {
   }
 
   test("convert PrimitiveData to catalyst") {
-    val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
-    val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
+    val data = PrimitiveData(1, 1, 1, 1, 1, 1, 1, true)
+    val convertedData = Row(1.toChar, 1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
     val dataType = schemaFor[PrimitiveData].dataType
     assert(convertToCatalyst(data, dataType) === convertedData)
   }
 
   test("convert Option[Product] to catalyst") {
-    val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
+    val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, 1, true)
     val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
       Some(primitiveData))
     val dataType = schemaFor[OptionalData].dataType
     val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
-      Row(1, 1, 1, 1, 1, 1, true))
+      Row(1, 1, 1, 1, 1, 1, 1, true))
     assert(convertToCatalyst(data, dataType) === convertedData)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9c49e84bf9680..20939fbd701ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -302,6 +302,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
       }
       DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
     }
+
+    /** Creates a DataFrame from an RDD[Row]. */
+    implicit def rowRddToDataFrameHolder(data: RDD[Row]): DataFrameHolder = {
+      val schema = data.first().schema
+      DataFrameHolder(self.createDataFrame(data, schema))
+    }
   }
 
   /**
@@ -1183,6 +1189,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
       case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+      case c: Class[_] if c == java.lang.Character.TYPE => (CharType, false)
       case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
       case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
       case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
@@ -1191,6 +1198,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
       case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
 
+      case c: Class[_] if c == classOf[java.lang.Character] => (CharType, true)
       case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
       case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
       case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 248dc1512b4d3..1f85a04f41936 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -38,7 +38,7 @@ object RDDConversions {
       } else {
         val bufferedIterator = iterator.buffered
         val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
-        val schemaFields = schema.fields.toArray
+        val schemaFields = schema.fields
         bufferedIterator.map { r =>
           var i = 0
           while (i < mutableRow.length) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index 2d2367d6e7292..f786854fe526e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -52,4 +52,11 @@ class DataFrameImplicitsSuite extends QueryTest {
       sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
       (1 to 10).map(i => Row(i.toString)))
   }
+
+  test("RDD[Row]") {
+    val rdd = (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol").rdd
+    checkAnswer(
+      rdd.toDF("intCol", "strCol"),
+      (1 to 10).map(i => Row(i, i.toString)))
+  }
 }
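
For orientation only, not part of the patch above: a rough usage sketch of what these changes would enable, assuming the Spark 1.3-era API shown in the diff and an existing SparkContext named `sc`; the case class and column names below are made up for illustration.

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// With the new ScalaReflection cases, a Char field in a case class should map to CharType
// instead of failing schema inference.
case class Grade(student: Int, mark: Char)
val df = sc.parallelize(Seq(Grade(1, 'A'), Grade(2, 'B'))).toDF()
df.printSchema()  // expected to report `mark` as char, nullable = false

// With the new implicit, an RDD[Row] can be turned back into a DataFrame; the schema is taken
// from the first row, so the RDD must be non-empty and its rows must carry a schema.
val roundTripped = df.rdd.toDF("student", "mark")
roundTripped.show()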