# [SPARK-23177][SQL][PYSPARK] Extract zero-parameter UDFs from aggregate
## What changes were proposed in this pull request?

The `ExtractPythonUDFFromAggregate` rule extracts Python UDFs in a logical aggregate that depend on an aggregate expression or a grouping key. However, Python UDFs that don't depend on any of those expressions (e.g., zero-parameter UDFs) should also be extracted, to avoid the issue reported in the JIRA.
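For contrast, a UDF that does reference a grouping key is already extracted by the rule, so no error occurs. A minimal sketch, assuming a running `spark` session (`dep_udf` is an illustrative name, not from the original report):

```python
import pyspark.sql.functions as f

df = spark.createDataFrame([(1, 2), (3, 4)])
# dep_udf reads column _1, which is a grouping key of distinct(), so the
# existing rule already pulls it into a projection above the aggregate.
dep_udf = f.udf(lambda x: "val_" + str(x))
df.distinct().withColumn("b", dep_udf(df["_1"])).show()
```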

A small code snippet that reproduces the issue:
```python
import pyspark.sql.functions as f

df = spark.createDataFrame([(1,2), (3,4)])
f_udf = f.udf(lambda: str("const_str"))
df2 = df.distinct().withColumn("a", f_udf())
df2.show()
```

The following exception is raised:
```
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#50
        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:91)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:90)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:90)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:514)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:513)
```

The exception is raised because `HashAggregateExec` tries to bind the aliased Python UDF expression (e.g., `pythonUDF0#50 AS a#44`) to the grouping key.
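To make the mechanics of the fix concrete, here is a small self-contained model of the extraction idea in plain Python. `Attr`, `PyUDF`, `must_extract`, and `extract` are illustrative names, not Spark APIs; the real rule operates on Catalyst expression trees:

```python
from dataclasses import dataclass
from typing import List, Set, Tuple, Union

@dataclass(frozen=True)
class Attr:
    name: str
    def references(self) -> Set[str]:
        return {self.name}

@dataclass(frozen=True)
class PyUDF:
    args: Tuple["Expr", ...] = ()  # empty tuple models a zero-parameter UDF
    def references(self) -> Set[str]:
        refs: Set[str] = set()
        for a in self.args:
            refs |= a.references()
        return refs

Expr = Union[Attr, PyUDF]

@dataclass
class Aggregate:
    grouping: List[Attr]
    output: List[Expr]

def must_extract(e: Expr, agg: Aggregate) -> bool:
    """Mirror of the fixed predicate: extract a UDF that reads the
    aggregate's grouping keys, OR one that reads nothing at all --
    the `e.references.isEmpty` case this commit adds."""
    if not isinstance(e, PyUDF):
        return False
    agg_refs = {a.name for a in agg.grouping}
    return not e.references() or e.references() <= agg_refs

def extract(agg: Aggregate):
    keep = [e for e in agg.output if not must_extract(e, agg)]
    moved = [e for e in agg.output if must_extract(e, agg)]
    # Evaluate the moved UDFs in a projection *after* the (now UDF-free)
    # aggregate, so the aggregate never has to bind them.
    return {"project": keep + moved, "aggregate": Aggregate(agg.grouping, keep)}

# The failing plan shape from the snippet above: distinct() aggregates on
# (_1, _2) and the zero-parameter UDF column "a" sits on top of it.
plan = Aggregate(grouping=[Attr("_1"), Attr("_2")],
                 output=[Attr("_1"), Attr("_2"), PyUDF()])
print(extract(plan))
# -> the UDF ends up under 'project', outside the aggregate's output
```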

## How was this patch tested?

Added a test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #20360 from viirya/SPARK-23177.
viirya authored and HyukjinKwon committed Jan 24, 2018
1 parent 15adcc8 commit a3911cf
Showing 2 changed files with 11 additions and 2 deletions.
8 changes: 8 additions & 0 deletions python/pyspark/sql/tests.py
```diff
@@ -1106,6 +1106,14 @@ def myudf(x):
         rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
         self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
 
+    def test_nonparam_udf_with_aggregate(self):
+        import pyspark.sql.functions as f
+
+        df = self.spark.createDataFrame([(1, 2), (1, 2)])
+        f_udf = f.udf(lambda: "const_str")
+        rows = df.distinct().withColumn("a", f_udf()).collect()
+        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
```
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 /**
  * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or
- * grouping key, evaluate them after aggregate.
+ * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate.
  */
 object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 
@@ -45,7 +45,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 
   private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
     expr.find {
-      e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined
+      e => PythonUDF.isScalarPythonUDF(e) &&
+        (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined)
     }.isDefined
   }
 
```
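The key change is the `e.references.isEmpty` disjunct: a zero-parameter UDF references no input attributes, so it previously failed the extraction predicate and was left inside the aggregate for `HashAggregateExec` to bind.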
