From 1a3c5ff54c43d60e34e7591e7f175840b0e91513 Mon Sep 17 00:00:00 2001 From: Spiro Michaylov Date: Thu, 2 Jul 2015 04:34:51 -0700 Subject: [PATCH 1/2] Added several UDF unit tests for Spark SQL --- .../scala/org/apache/spark/sql/TestData.scala | 11 +++++ .../scala/org/apache/spark/sql/UDFSuite.scala | 43 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 207d7a352c7b3..d09b20d4beb5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -196,4 +196,15 @@ object TestData { :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") + + case class GroupData(g: String, v: Int) + val groupData = + TestSQLContext.sparkContext.parallelize( + GroupData("red", 1) :: + GroupData("red", 2) :: + GroupData("blue", 10) :: + GroupData("green", 100) :: + GroupData("green", 200) :: Nil).toDF() + groupData.registerTempTable("groupData") + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 703a34c47ec20..464d125663406 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ @@ -82,6 +83,48 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } + test("UDF in a WHERE") { + testData.sqlContext.udf.register("oneArgFilter", (n:Int) => { n > 80 }) + + val result = + testData.sqlContext.sql("SELECT * FROM testData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } + + test("UDF in a HAVING") { + testData.sqlContext.udf.register("havingFilter", (n:Long) => { n > 5 }) + + val result = + testData.sqlContext.sql("SELECT g, SUM(v) as s FROM groupData GROUP BY g HAVING havingFilter(s)") + assert(result.count() === 2) + } + + test("UDF in a GROUP BY") { + testData.sqlContext.udf.register("groupFunction", (n:Int) => { n > 10 }) + + val result = + testData.sqlContext.sql("SELECT SUM(v) FROM groupData GROUP BY groupFunction(v)") + assert(result.count() === 2) + } + + test("UDFs everywhere") { + ctx.udf.register("groupFunction", (n:Int) => { n > 10 }) + ctx.udf.register("havingFilter", (n:Long) => { n > 2000 }) + ctx.udf.register("whereFilter", (n:Int) => { n < 150 }) + ctx.udf.register("timesHundred", (n:Long) => { n * 100 }) + + val result = + testData.sqlContext.sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } + test("struct UDF") { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) From 6bbba9efbb4ffbf8cd4e8f980600272d66b4cadc Mon Sep 17 00:00:00 2001 From: Spiro Michaylov Date: Fri, 3 Jul 2015 10:56:55 -0700 Subject: [PATCH 2/2] Responded to review comments on UDF unit tests --- .../scala/org/apache/spark/sql/TestData.scala | 11 ---- .../scala/org/apache/spark/sql/UDFSuite.scala | 51 ++++++++++++++----- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index d09b20d4beb5b..207d7a352c7b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -196,15 +196,4 @@ object TestData { :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") - - case class GroupData(g: String, v: Int) - val groupData = - TestSQLContext.sparkContext.parallelize( - GroupData("red", 1) :: - GroupData("red", 2) :: - GroupData("blue", 10) :: - GroupData("green", 100) :: - GroupData("green", 200) :: Nil).toDF() - groupData.registerTempTable("groupData") - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 464d125663406..8e5da3ac14da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -21,7 +21,6 @@ package org.apache.spark.sql case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { - import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ @@ -84,37 +83,65 @@ class UDFSuite extends QueryTest { } test("UDF in a WHERE") { - testData.sqlContext.udf.register("oneArgFilter", (n:Int) => { n > 80 }) + ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + + val df = ctx.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("integerData") val result = - testData.sqlContext.sql("SELECT * FROM testData WHERE oneArgFilter(key)") + ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") assert(result.count() === 20) } test("UDF in a HAVING") { - testData.sqlContext.udf.register("havingFilter", (n:Long) => { n > 5 }) + ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") val result = - testData.sqlContext.sql("SELECT g, SUM(v) as s FROM groupData GROUP BY g HAVING havingFilter(s)") + ctx.sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + assert(result.count() === 2) } test("UDF in a GROUP BY") { - testData.sqlContext.udf.register("groupFunction", (n:Int) => { n > 10 }) + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") val result = - testData.sqlContext.sql("SELECT SUM(v) FROM groupData GROUP BY groupFunction(v)") + ctx.sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) assert(result.count() === 2) } test("UDFs everywhere") { - ctx.udf.register("groupFunction", (n:Int) => { n > 10 }) - ctx.udf.register("havingFilter", (n:Long) => { n > 2000 }) - ctx.udf.register("whereFilter", (n:Int) => { n < 150 }) - ctx.udf.register("timesHundred", (n:Long) => { n * 100 }) + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) + ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) + ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") val result = - testData.sqlContext.sql( + ctx.sql( """ | SELECT timesHundred(SUM(v)) as v100 | FROM groupData