Skip to content

Commit

Permalink
[SPARK-27893][SQL][PYTHON][FOLLOW-UP] Allow Scalar Pandas and Python …
Browse files Browse the repository at this point in the history
…UDFs can be tested with Scala test base

## What changes were proposed in this pull request?

After this PR, we can test Pandas and Python UDF as below **in Scala side**:

```scala
import IntegratedUDFTestUtils._
val pandasTestUDF = TestScalarPandasUDF("udf")
spark.range(10).select(pandasTestUDF($"id")).show()
```

## How was this patch tested?

Manually tested.

Closes #24945 from HyukjinKwon/SPARK-27893-followup.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
  • Loading branch information
HyukjinKwon committed Jun 25, 2019
1 parent 1d36b89 commit ac61f7d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
Expand Up @@ -40,32 +40,37 @@ import org.apache.spark.sql.types.StringType
*
* To register Scala UDF in SQL:
* {{{
* val scalaTestUDF = TestScalaUDF(name = "udf_name")
* registerTestUDF(scalaTestUDF, spark)
* }}}
*
* To register Python UDF in SQL:
* {{{
* val pythonTestUDF = TestPythonUDF(name = "udf_name")
* registerTestUDF(pythonTestUDF, spark)
* }}}
*
* To register Scalar Pandas UDF in SQL:
* {{{
* val pandasTestUDF = TestScalarPandasUDF(name = "udf_name")
* registerTestUDF(pandasTestUDF, spark)
* }}}
*
* To use it in Scala API and SQL:
* {{{
* sql("SELECT udf_name(1)")
* spark.range(10).select(expr("udf_name(id)"))
* spark.range(10).select(pandasTestUDF($"id"))
* }}}
*/
object IntegratedUDFTestUtils extends SQLHelper {
import scala.sys.process._

private lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
private lazy val sparkHome = if (sys.props.contains(Tests.IS_TESTING.key)) {
assert(sys.props.contains("spark.test.home"), "spark.test.home is not set.")
sys.props("spark.test.home")
assert(sys.props.contains("spark.test.home") ||
sys.env.contains("SPARK_HOME"), "spark.test.home or SPARK_HOME is not set.")
sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
} else {
assert(sys.env.contains("SPARK_HOME"), "SPARK_HOME is not set.")
sys.env("SPARK_HOME")
Expand Down Expand Up @@ -186,14 +191,18 @@ object IntegratedUDFTestUtils extends SQLHelper {
/**
 * A base trait for various UDFs defined in this object.
 *
 * Each concrete UDF can be applied directly in the Scala Dataset API via
 * `apply`, or registered for SQL via `registerTestUDF`.
 */
sealed trait TestUDF {
  /** Applies this UDF to the given input columns and returns the result column. */
  def apply(exprs: Column*): Column

  // Human-readable kind of this UDF (e.g. "Scala UDF", "Regular Python UDF"),
  // used to label generated test cases.
  val prettyName: String
}

/**
* A Python UDF that takes one column and returns a string column.
* Equivalent to `udf(lambda x: str(x), "string")`
*/
case class TestPythonUDF(name: String) extends TestUDF {
lazy val udf = UserDefinedPythonFunction(
private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction(
name = name,
func = PythonFunction(
command = pythonFunc,
Expand All @@ -206,14 +215,18 @@ object IntegratedUDFTestUtils extends SQLHelper {
dataType = StringType,
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)

def apply(exprs: Column*): Column = udf(exprs: _*)

val prettyName: String = "Regular Python UDF"
}

/**
* A Scalar Pandas UDF that takes one column and returns a string column.
* Equivalent to `pandas_udf(lambda x: x.apply(str), "string", PandasUDFType.SCALAR)`.
*/
case class TestScalarPandasUDF(name: String) extends TestUDF {
lazy val udf = UserDefinedPythonFunction(
private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction(
name = name,
func = PythonFunction(
command = pandasFunc,
Expand All @@ -226,17 +239,25 @@ object IntegratedUDFTestUtils extends SQLHelper {
dataType = StringType,
pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
udfDeterministic = true)

def apply(exprs: Column*): Column = udf(exprs: _*)

val prettyName: String = "Scalar Pandas UDF"
}

/**
 * A Scala UDF that takes one column and returns a string column.
 * Equivalent to `udf((input: Any) => input.toString)`.
 */
case class TestScalaUDF(name: String) extends TestUDF {
  // Restricted to this test-utils object so callers go through `apply` or
  // `registerTestUDF` rather than using the underlying function directly.
  private[IntegratedUDFTestUtils] lazy val udf = SparkUserDefinedFunction(
    (input: Any) => input.toString,
    StringType,
    inputSchemas = Seq.fill(1)(None))

  /** Applies the UDF to the given columns, returning a string column. */
  def apply(exprs: Column*): Column = udf(exprs: _*)

  val prettyName: String = "Scala UDF"
}

/**
Expand Down
Expand Up @@ -383,24 +383,13 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)

if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) {
Seq(
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf =>
UDFTestCase(
s"$testCaseName - Scala UDF",
s"$testCaseName - ${udf.prettyName}",
absPath,
resultFile,
TestScalaUDF(name = "udf")),

UDFTestCase(
s"$testCaseName - Python UDF",
absPath,
resultFile,
TestPythonUDF(name = "udf")),

UDFTestCase(
s"$testCaseName - Scalar Pandas UDF",
absPath,
resultFile,
TestScalarPandasUDF(name = "udf")))
udf)
}
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) {
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
} else {
Expand Down

0 comments on commit ac61f7d

Please sign in to comment.