[SPARK-15888] [SQL] fix Python UDF with aggregate
## What changes were proposed in this pull request?

After we moved the ExtractPythonUDFs rule into the physical plan, Python UDFs no longer worked on top of an aggregate, because they cannot be evaluated before the aggregate; they must be evaluated after it. This PR adds another rule that extracts these kinds of Python UDFs from the logical Aggregate and creates a Project on top of the Aggregate to evaluate them.
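
For illustration, here is a minimal PySpark query of the kind this rule handles, mirroring the regression test below (`df` is assumed to be a DataFrame with `key` and `value` columns):

```
from pyspark.sql.functions import col, sum, udf
from pyspark.sql.types import IntegerType

# my_add is applied to the aggregate results, so it cannot be evaluated before
# the aggregate; the new rule extracts it into a Project on top of the Aggregate.
my_strlen = udf(lambda x: len(x), IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())

result = df.groupBy("key") \
    .agg(sum(my_strlen(col("value"))).alias("s")) \
    .select(my_add(col("key"), col("s")).alias("t"))
```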

## How was this patch tested?

Added regression tests. The plan of the added test query looks like this (note that the UDF over the aggregate results is evaluated by a BatchEvalPython/Project pair above the HashAggregate, while the UDFs feeding the grouping key and the aggregate input are evaluated below it):
```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
   +- LogicalRDD [key#5L, value#6]

== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
   +- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
      +- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
         +- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L])
            +- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
               +- Scan ExistingRDD[key#5L,value#6]
```

Author: Davies Liu <davies@databricks.com>

Closes apache#13682 from davies/fix_py_udf.
Davies Liu authored and davies committed Jun 15, 2016
1 parent 279bd4a commit 5389013
Showing 4 changed files with 77 additions and 11 deletions.
10 changes: 9 additions & 1 deletion python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ def test_broadcast_in_udf(self):

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-       from pyspark.sql.functions import udf, col
+       from pyspark.sql.functions import udf, col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
        self.assertEqual(sel.collect(), [Row(key=1)])

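        # Exercise Python UDFs below, inside, and on top of the aggregate (SPARK-15888).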
        my_copy = udf(lambda x: x, IntegerType())
        my_add = udf(lambda a, b: int(a + b), IntegerType())
        my_strlen = udf(lambda x: len(x), IntegerType())
        sel = df.groupBy(my_copy(col("key")).alias("k"))\
            .agg(sum(my_strlen(col("value"))).alias("s"))\
            .select(my_add(col("k"), col("s")).alias("t"))
        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
    experimentalMethods: ExperimentalMethods)
  extends Optimizer(catalog, conf) {

-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
}
sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)

  def children: Seq[SparkPlan] = child :: Nil

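  // The UDF result columns appended after the child's output are produced by this node
  // itself, rather than obtained from the child.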
  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
Expand Up @@ -18,12 +18,68 @@
package org.apache.spark.sql.execution.python

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan


/**
 * Extracts the Python UDFs in a logical Aggregate that depend on an aggregate expression or a
 * grouping key, and evaluates them in a Project after the aggregate.
 */
private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {

  /**
   * Returns whether the expression can only be evaluated within the aggregate.
   */
  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
    e.isInstanceOf[AggregateExpression] ||
      agg.groupingExpressions.exists(_.semanticEquals(e))
  }

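  /**
   * Returns true if any Python UDF in `expr` takes an aggregate expression or a grouping key
   * as input, i.e. the UDF can only be evaluated after the aggregate.
   */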
  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
    expr.find {
      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
    }.isDefined
  }

  private def extract(agg: Aggregate): LogicalPlan = {
    val projList = new ArrayBuffer[NamedExpression]()
    val aggExpr = new ArrayBuffer[NamedExpression]()
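    // `aggExpr` collects the expressions that the rewritten Aggregate still evaluates;
    // `projList` builds the output of the Project placed on top of it.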
    agg.aggregateExpressions.foreach { expr =>
      if (hasPythonUdfOverAggregate(expr, agg)) {
        // This Python UDF takes aggregate results as input, so it can only be evaluated after
        // the aggregate: keep its aggregate-side inputs in the Aggregate and replace them with
        // attribute references for the Project on top.
        val newE = expr transformDown {
          case e: Expression if belongAggregate(e, agg) =>
            val alias = e match {
              case a: NamedExpression => a
              case o => Alias(e, "agg")()
            }
            aggExpr += alias
            alias.toAttribute
        }
        projList += newE.asInstanceOf[NamedExpression]
      } else {
        aggExpr += expr
        projList += expr.toAttribute
      }
    }
    // The rewritten Aggregate no longer contains any Python UDF over an aggregate expression;
    // those UDFs are now evaluated by the Project on top of it.
    Project(projList, agg.copy(aggregateExpressions = aggExpr))
  }

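  // Rewrite the plan bottom-up, splitting every Aggregate whose result expressions contain a
  // Python UDF over aggregate results or grouping keys.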
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
      extract(agg)
  }
}


/**
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
* alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
}

  /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
      // Ignore the PythonUDFs that come from the second/third aggregate; they are not used here.
      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
      // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
      udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }

      val rewritten = plan.transformExpressions {
