Skip to content

Commit

Permalink
[SPARK-13094][SQL] Add encoders for seq/array of primitives
Browse files Browse the repository at this point in the history
Author: Michael Armbrust <michael@databricks.com>

Closes #11014 from marmbrus/seqEncoders.
  • Loading branch information
marmbrus committed Feb 2, 2016
1 parent 12a20c1 commit 29d9218
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 2 deletions.
63 changes: 62 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()

// Primitives

/** @since 1.6.0 */
implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()

Expand All @@ -56,13 +58,72 @@ abstract class SQLImplicits {

/** @since 1.6.0 */
implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
/** @since 1.6.0 */

/** @since 1.6.0 */
implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()

/** @since 1.6.0 */
implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()

// Seqs

/** @since 1.6.1 */
implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()

// Arrays

/** @since 1.6.1 */
implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] =
ExpressionEncoder()

/**
* Creates a [[Dataset]] from an RDD.
* @since 1.6.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
agged,
"1", "abc", "3", "xyz", "5", "hello")
}

test("Arrays and Lists") {
checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))

checkAnswer(Seq(Array(1)).toDS(), Array(1))
checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
checkAnswer(Seq(Array(true)).toDS(), Array(true))
checkAnswer(Seq(Array("test")).toDS(), Array("test"))
checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}
}
8 changes: 7 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest {
""".stripMargin, e)
}

if (decoded != expectedAnswer.toSet) {
// Handle the case where the return type is an array
val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
def normalEquality = decoded == expectedAnswer.toSet
def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)

if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted

Expand Down

0 comments on commit 29d9218

Please sign in to comment.