From b6cb6218e539589f37ff8648dff068bef6e810e5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 23 Jan 2018 05:56:45 +0000 Subject: [PATCH 1/3] Extract parameter-less UDFs from aggregate. --- python/pyspark/sql/tests.py | 8 ++++++++ .../spark/sql/execution/python/ExtractPythonUDFs.scala | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4fee2ecde391b..dc8b821c12d04 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1100,6 +1100,14 @@ def myudf(x): rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1,2), (1,2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d0..67802f4e92e1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -43,8 +43,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { - expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + expr.find { e => e.isInstanceOf[PythonUDF] && + (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined) }.isDefined } From 5c3afbbdf762411023b06348b2bfe3dbc2ff4287 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 23 Jan 2018 07:23:43 +0000 Subject: [PATCH 2/3] Fix python style. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 27938c81e0013..a466ab87d882d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1109,7 +1109,7 @@ def myudf(x): def test_nonparam_udf_with_aggregate(self): import pyspark.sql.functions as f - df = self.spark.createDataFrame([(1,2), (1,2)]) + df = self.spark.createDataFrame([(1, 2), (1, 2)]) f_udf = f.udf(lambda: "const_str") rows = df.distinct().withColumn("a", f_udf()).collect() self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) From 74684a7d10009ef970d7d674d9c695b695c5da5c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 23 Jan 2018 23:29:02 +0000 Subject: [PATCH 3/3] Fix doc. --- .../apache/spark/sql/execution/python/ExtractPythonUDFs.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index db10fd9a21734..4ae4e164830be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or - * grouping key, evaluate them after aggregate. + * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate. */ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {