diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d69169604d..c631ad8a4618d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ def test_broadcast_in_udf(self):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
        from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 08b2d7fcd4882..12a10cba20fe9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 061d7c7f79de8..d9bf4d3ccf698 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
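The Python test above exercises the new capability end to end: UDFs appear in the grouping key, inside the aggregate expression, and on top of the aggregate result. For reference, a minimal standalone sketch of the same pattern outside the test harness (assuming a local PySpark installation that includes this patch; the session setup and names below are illustrative and not part of the patch):

# Standalone sketch of the pattern covered by test_udf_with_aggregate_function.
# Assumptions: local Spark with this patch applied; names here are illustrative.
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, sum as sum_
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[2]").appName("udf-over-aggregate").getOrCreate()
df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

my_copy = udf(lambda x: x, IntegerType())             # UDF in the grouping key
my_strlen = udf(lambda x: len(x), IntegerType())      # UDF inside the aggregate expression
my_add = udf(lambda a, b: int(a + b), IntegerType())  # UDF over the aggregate result

sel = (df.groupBy(my_copy(col("key")).alias("k"))
         .agg(sum_(my_strlen(col("value"))).alias("s"))
         .select(my_add(col("k"), col("s")).alias("t")))

# Expected result per the new test: [Row(t=4), Row(t=3)] (row order may vary)
print(sel.collect())
spark.stop()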
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index ab192360e1c1f..668470ee6a29a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,12 +18,68 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
+
+
+/**
+ * Extracts all the Python UDFs in a logical Aggregate that depend on an aggregate expression or
+ * grouping key, so that they can be evaluated after the aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression can only be evaluated within the aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // A Python UDF over an aggregate can only be evaluated after the aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // The rewritten Aggregate no longer contains any Python UDF over an aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
 
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // ignore PythonUDFs that come from a second/third aggregate, which are not used here
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
         // Other cases are disallowed as they are ambiguous or would require a cartesian
         // product.
         udfs.filterNot(attributeMap.contains).foreach { udf =>
-          if (udf.references.subsetOf(plan.inputSet)) {
-            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-          } else {
-            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-          }
+          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
         }
 
         val rewritten = plan.transformExpressions {
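Taken together, ExtractPythonUDFFromAggregate splits each aggregate expression that contains a Python UDF over an aggregate or grouping key: the aggregate parts stay in the Aggregate (aliased as "agg" when not already named), and the UDF itself moves into a Project placed on top, where BatchEvalPythonExec can later evaluate it. A rough way to observe the rewrite from PySpark, reusing sel and spark from the sketch earlier in this patch (operator names in the printed plans vary by Spark version, so treat the shape shown in the comments as illustrative rather than verbatim output):

# Assumes sel and spark from the earlier sketch. After this patch, the optimized
# logical plan should roughly take this shape (illustrative, not verbatim):
#
#   Project [my_add(k, s) AS t]       <- Python UDF evaluated after the aggregate
#   +- Aggregate [k], [sum(...) AS s] <- only aggregate expressions remain here
#      +- ...
#
# explain(True) prints the parsed, analyzed, and optimized logical plans plus the
# physical plan; the extracted Python UDFs should appear in a separate
# batch-evaluation operator above the aggregate rather than inside it.
sel.explain(True)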