diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 703a34c47ec20..8e5da3ac14da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -82,6 +82,76 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } + test("UDF in a WHERE") { + ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + + val df = ctx.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("integerData") + + val result = + ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } + + test("UDF in a HAVING") { + ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } + + test("UDF in a GROUP BY") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } + + test("UDFs everywhere") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) + ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) + ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } + test("struct UDF") { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))