[SPARK-38235][SQL][TESTS] Add test util for testing grouped aggregate pandas UDF #35615

Closed · wants to merge 7 commits
@@ -31,15 +31,16 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

 /**
- * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and
- * Scalar Pandas UDFs can be tested in SBT & Maven tests.
+ * This object integrates various UDF test cases so that Scalar UDF, Python UDF,
+ * Scalar Pandas UDF and Grouped Aggregate Pandas UDF can be tested in SBT & Maven tests.
  *
- * The available UDFs are special. It defines an UDF wrapped by cast. So, the input column is
- * casted into string, UDF returns strings as are, and then output column is casted back to
- * the input column. In this way, UDF is virtually no-op.
+ * The available UDFs are special. For Scalar UDF, Python UDF and Scalar Pandas UDF,
+ * each defines a UDF wrapped by casts: the input column is cast to string, the UDF
+ * returns the strings as they are, and the output column is cast back to the type of
+ * the input column. In this way, the UDF is virtually a no-op.
  *
  * Note that, due to this implementation limitation, complex types such as map, array and struct
  * types do not work with this UDFs because they cannot be same after the cast roundtrip.
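
The cast-roundtrip trick described above can be sketched in a few lines. This is a simplified illustration rather than the utility's actual code; the helper name `noOpUDF` and the hardcoded `int` input type are assumptions for the example:

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.udf

// Identity function on strings; the surrounding casts make it a virtual no-op.
val identityUDF = udf((s: String) => s)

// Assumes an integer input column; the real utility casts back to whatever
// the original input type was instead of hardcoding it.
def noOpUDF(col: Column): Column =
  identityUDF(col.cast("string")).cast("int")
```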
@@ -69,6 +70,28 @@ import org.apache.spark.sql.types.{DataType, StringType}
  * df.select(expr("udf_name(id)")
  * df.select(pandasTestUDF(df("id")))
  * }}}
+ *
+ * For Grouped Aggregate Pandas UDF, it defines a UDF that calculates the count using pandas.
+ * The UDF returns the count of the given column, so unlike the UDFs above it is not a no-op.
+ *
+ * To register Grouped Aggregate Pandas UDF in SQL:
+ * {{{
+ *   val groupedAggPandasTestUDF = TestGroupedAggPandasUDF(name = "udf_name")
+ *   registerTestUDF(groupedAggPandasTestUDF, spark)
+ * }}}
+ *
+ * To use it in Scala API and SQL:
+ * {{{
+ *   sql("SELECT udf_name(1)")
+ *   val df = Seq(
+ *     (536361, "85123A", 2, 17850),
+ *     (536362, "85123B", 4, 17850),
+ *     (536363, "86123A", 6, 17851)
+ *   ).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")
+ *
+ *   df.groupBy("CustomerID").agg(expr("udf_name(Quantity)"))
+ *   df.groupBy("CustomerID").agg(groupedAggPandasTestUDF(df("Quantity")))
+ * }}}
  */
 object IntegratedUDFTestUtils extends SQLHelper {
   import scala.sys.process._
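
As a sanity check on the Scaladoc example above, the same grouping expressed with Spark's built-in `count` shows what the pandas UDF is expected to return. The data and column names come from the example; a `SparkSession` value named `spark` is assumed:

```scala
import org.apache.spark.sql.functions.count
import spark.implicits._  // assumes a SparkSession value named `spark`

val df = Seq(
  (536361, "85123A", 2, 17850),
  (536362, "85123B", 4, 17850),
  (536363, "86123A", 6, 17851)
).toDF("InvoiceNo", "StockCode", "Quantity", "CustomerID")

// CustomerID 17850 has two rows and 17851 has one, so the counts are 2 and 1.
df.groupBy("CustomerID").agg(count("Quantity")).show()
```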
@@ -190,6 +213,28 @@ object IntegratedUDFTestUtils extends SQLHelper {
     throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
   }

+  private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          "from pyspark.sql.types import IntegerType; " +
+            "from pyspark.serializers import CloudPickleSerializer; " +
+            s"f = open('$path', 'wb');" +
+            "f.write(CloudPickleSerializer().dumps((" +
+            "lambda x: x.agg('count'), IntegerType())))"),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  } else {
+    throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
+  }
+
   // Make sure this map stays mutable - this map gets updated later in Python runners.
   private val workerEnv = new java.util.HashMap[String, String]()
   workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
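
The block above shells out to Python so that CloudPickle, rather than Java serialization, produces the bytes for the `(function, returnType)` pair. A minimal standalone sketch of that pattern follows; the helper name `pickleViaPython` is hypothetical, and it assumes `pythonExec` points at a Python installation that can import pyspark:

```scala
import java.nio.file.Files
import scala.sys.process._

// Run `python -c`, let it cloudpickle an arbitrary Python expression into a
// temp file, then read the raw bytes back into the JVM.
def pickleViaPython(pythonExec: String, pyExpr: String): Array[Byte] = {
  val path = Files.createTempFile("pickled-udf", ".bin")
  try {
    Seq(pythonExec, "-c",
      "from pyspark.serializers import CloudPickleSerializer; " +
        s"open('$path', 'wb').write(CloudPickleSerializer().dumps($pyExpr))").!!
    Files.readAllBytes(path)
  } finally Files.delete(path)
}
```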
@@ -209,6 +254,8 @@ object IntegratedUDFTestUtils extends SQLHelper {
   lazy val shouldTestScalarPandasUDFs: Boolean =
     isPythonAvailable && isPandasAvailable && isPyArrowAvailable
 
+  lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs
+
   /**
    * A base trait for various UDFs defined in this object.
    */
@@ -333,6 +380,46 @@ object IntegratedUDFTestUtils extends SQLHelper {
     val prettyName: String = "Scalar Pandas UDF"
   }

+  /**
+   * A Grouped Aggregate Pandas UDF that takes one column and executes the
+   * Python native function, calculating the count of the column using pandas.
+   *
+   * Virtually equivalent to:
+   *
+   * {{{
+   *   import pandas as pd
+   *   from pyspark.sql.functions import pandas_udf
+   *
+   *   df = spark.createDataFrame(
+   *     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+   *
+   *   @pandas_udf("int")
+   *   def pandas_count(v: pd.Series) -> int:
+   *       return v.count()
+   *
+   *   count_col = pandas_count(df['v'])
+   * }}}
+   */
+  case class TestGroupedAggPandasUDF(name: String) extends TestUDF {
+    private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
+      name = name,
+      func = PythonFunction(
+        command = pandasGroupedAggFunc,
+        envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+        pythonIncludes = List.empty[String].asJava,
+        pythonExec = pythonExec,
+        pythonVer = pythonVer,
+        broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+        accumulator = null),
+      dataType = IntegerType,
+      pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+      udfDeterministic = true)
+
+    def apply(exprs: Column*): Column = udf(exprs: _*)
+
+    val prettyName: String = "Grouped Aggregate Pandas UDF"
+  }
+
   /**
    * A Scala UDF that takes one column, casts into string, executes the
    * Scala native function, and casts back to the type of input column.
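
A sketch of how a test might exercise the new UDF end to end. The UDF name `pandas_count`, the table `t`, and the DataFrame `df` are made up for the example; the availability guard mirrors how the other pandas UDFs are gated:

```scala
// Skip when Python, pandas, or PyArrow is unavailable, like the other pandas UDF tests.
assume(shouldTestGroupedAggPandasUDFs)

val countUDF = TestGroupedAggPandasUDF(name = "pandas_count")
registerTestUDF(countUDF, spark)

// Both the SQL and DataFrame routes dispatch to the same pickled Python function.
spark.sql("SELECT CustomerID, pandas_count(Quantity) FROM t GROUP BY CustomerID")
df.groupBy("CustomerID").agg(countUDF(df("Quantity")))
```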
@@ -387,6 +474,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
   def registerTestUDF(testUDF: TestUDF, session: SparkSession): Unit = testUDF match {
     case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf)
     case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf)
+    case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf)
     case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf)
     case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]")
   }