From 692b54ff0b38f26fb8eb59b1d958df45dbb59aa3 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 4 Dec 2017 18:19:26 -0500 Subject: [PATCH 01/35] Initial commit: wip --- .../spark/api/python/PythonRunner.scala | 3 + python/pyspark/rdd.py | 1 + python/pyspark/sql/functions.py | 5 +- python/pyspark/sql/group.py | 13 +- python/pyspark/sql/tests.py | 83 +++++++++++ python/pyspark/sql/udf.py | 6 +- python/pyspark/worker.py | 22 ++- .../sql/catalyst/optimizer/Optimizer.scala | 2 + .../logical/pythonLogicalOperators.scala | 10 ++ .../spark/sql/RelationalGroupedDataset.scala | 31 ++++ .../spark/sql/execution/SparkStrategies.scala | 2 + .../python/AggregateInPandasExec.scala | 135 ++++++++++++++++++ .../execution/python/ExtractPythonUDFs.scala | 1 + 13 files changed, 306 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1ec0e717fac29..ffc0777c0aa46 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,13 +39,16 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGG_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" } + } /** diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..bfeb034e8e242 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGG_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 961b3267b44cf..77aaf3d19127f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2089,6 +2089,8 @@ class PandasUDFType(object): GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + @since(1.3) def udf(f=None, returnType=StringType()): @@ -2267,7 +2269,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): raise ValueError("Invalid returnType: returnType can not be None") if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 22061b83eb78c..8a25f14ac2993 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,7 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction +from pyspark.sql.udf import UserDefinedFunction, UDFColumn from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -89,8 +89,15 @@ def agg(self, *exprs): else: # Columns assert all(isinstance(c, Column) 
for c in exprs), "all exprs should be Column" - jdf = self._jgd.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + if isinstance(exprs[0], UDFColumn): + assert all(isinstance(c, UDFColumn) for c in exprs) + jdf = self._jgd.aggInPandas( + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs])) + else: + jdf = self._jgd.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + + return DataFrame(jdf, self.sql_ctx) @dfapi diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..6c1a0eccb7c26 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4352,6 +4352,89 @@ def test_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyAggTests(ReusedSQLTestCase): + def assertFramesEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + + ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) + self.assertTrue(expected.equals(result), msg=msg) + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + def test_basic(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + + df = self.data + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def mean_udf(v, w): + return np.average(v, weights=w) + + result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected1, result1) + + result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() + expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected2, result2) + + result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() + expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected3, result3) + + result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() + expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected4, result4) + + def test_multiple(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + df = self.data + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def mean_udf(v): + return v.mean() + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum_udf(v): + return v.sum() + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean_udf(v, w): + return np.average(v, weights=w) + + # TODO: Fix alias + #result1 = df.groupBy('id') \ + # .agg(mean_udf(df.v).alias('mean'), \ + # sum_udf(df.v).alias('sum'), \ + # weighted_mean_udf(df.v, df.w).alias('wm')) \ + # .sort('id') \ + # .toPandas() + + result1 = df.groupBy('id') \ + .agg(mean_udf(df.v), \ + sum_udf(df.v), \ + weighted_mean_udf(df.v, df.w)) \ + .sort('id') \ + .toPandas() + + expected1 = df.groupBy('id') \ + 
.agg(mean(df.v).alias('mean_udf'), \ + sum(df.v).alias('sum_udf'), \ + mean(df.v).alias('weighted_mean_udf')) \ + .sort('id') \ + .toPandas() + self.assertFramesEqual(expected1, result1) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c77f19f89a442..471326dda0d06 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -63,6 +63,10 @@ def _create_udf(f, returnType, evalType): return udf_obj._wrapped() +class UDFColumn(Column): + pass + + class UserDefinedFunction(object): """ User defined function in Python @@ -141,7 +145,7 @@ def _create_judf(self): def __call__(self, *cols): judf = self._judf sc = SparkContext._active_spark_context - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + return UDFColumn(judf.apply(_to_seq(sc, cols, _to_java_column))) # This function is for improving the online help system in the interactive interpreter. # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e6737ae1c1285..173d8fb2856fa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -110,6 +110,17 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_agg_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,8 +137,12 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - else: + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: + return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): @@ -148,8 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c794ba8619322..f26608fae8a5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -452,6 +452,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) + case a @ AggregateInPandas(_, _, _, child) if (child.outputSet -- a.references).nonEmpty => + a.copy(child = prunedChild(child, 
a.references)) case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 254687ec00880..059e304ac4b55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -38,3 +38,13 @@ case class FlatMapGroupsInPandas( */ override val producedAttributes = AttributeSet(output) } + +case class AggregateInPandas( + groupingAttributes: Seq[Attribute], + functionExprs: Seq[Expression], + output: Seq[Attribute], + child: LogicalPlan +) extends UnaryNode { + override val references: AttributeSet = child.outputSet + override val producedAttributes = AttributeSet(groupingAttributes ++ output) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a009c00b0abc5..afd96b62e3a36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -437,6 +437,37 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan)) } + + private[sql] def aggInPandas(columns: Seq[Column]): DataFrame = { + val exprs = columns.map(column => column.expr.asInstanceOf[PythonUDF]) + + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + + val child = df.logicalPlan + + val childrenExpressions = exprs.flatMap(expr => + expr.children.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + }) + + val project = Project(groupingNamedExpressions ++ childrenExpressions, child) + + val udfOutputs = exprs.flatMap(expr => + Seq(AttributeReference(expr.name, expr.dataType)()) + ) + + val output: Seq[Attribute] = groupingAttributes ++ udfOutputs + + val plan = AggregateInPandas(groupingAttributes, exprs, output, project) + + Dataset.ofRows(df.sparkSession, plan) + } /** * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. 
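
For reference, a minimal usage sketch of the group-aggregate pandas UDF that the hunks above wire up (it mirrors the new GroupbyAggTests added in this patch, and is not itself part of the diff): a GROUP_AGG pandas_udf receives each argument column as a pandas.Series restricted to one group and returns a single scalar, which aggInPandas turns into one output row per group. The sketch assumes an active SparkSession bound to `spark`; the column names 'v' and 'w' are illustrative.

    import numpy as np
    from pyspark.sql.functions import pandas_udf, PandasUDFType, lit

    # Weighted mean per group: v and w arrive as pandas.Series for one group,
    # and the returned scalar becomes that group's aggregated value.
    @pandas_udf('double', PandasUDFType.GROUP_AGG)
    def weighted_mean(v, w):
        return np.average(v, weights=w)

    # Illustrative data: assumes a live SparkSession named `spark`.
    df = spark.range(10).withColumn('v', lit(1.0)).withColumn('w', lit(2.0))
    df.groupby('id').agg(weighted_mean(df.v, df.w)).show()
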
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910294853c318..c0b342eb3de12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -452,6 +452,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.AggregateInPandas(grouping, func, output, child) => + execution.python.AggregateInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala new file mode 100644 index 0000000000000..f8913bb9b616d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, JoinedRow, SortOrder, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +case class AggregateInPandasExec( + groupingAttributes: Seq[Attribute], + func: Seq[Expression], + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode { + private val udfs = func.map(expr => expr.asInstanceOf[PythonUDF]) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + // val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) + val schema = StructType(child.schema.drop(groupingAttributes.length)) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + val allInputs = new ArrayBuffer[Expression] + + val argOffsets = inputs.map { input => + input.map { e => + allInputs += e + allInputs.length - 1 + }.toArray + }.toArray + + inputRDD.mapPartitionsInternal { iter => + val grouped = if (groupingAttributes.isEmpty) { + Iterator((null, iter)) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (k, groupedRowIter) => (k, groupedRowIter.map(dropGrouping)) + } + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. 
+ val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingAttributes.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, schema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(projectedRowIter, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + + columnarBatchIter.map(_.rowIterator.next()).map{ outputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, outputRow) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d0..171619d247bf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -113,6 +113,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { def apply(plan: SparkPlan): SparkPlan = plan transformUp { // FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs + case plan: AggregateInPandasExec => plan case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) } From 11321f68e83d5f3f6571954f74d4dab5a2b74f71 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 5 Dec 2017 13:09:22 -0500 Subject: [PATCH 02/35] wip --- python/pyspark/sql/tests.py | 128 ++++++++++-------- .../expressions/aggregate/interfaces.scala | 1 + .../sql/catalyst/rules/RuleExecutor.scala | 2 +- 3 files changed, 73 insertions(+), 58 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6c1a0eccb7c26..dae09cfcfe6e3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4369,71 +4369,85 @@ def data(self): .drop('vs') \ .withColumn('w', lit(1.0)) - def test_basic(self): + # def test_basic(self): + # import numpy as np + # from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + # + # df = self.data + # + # @pandas_udf('double', PandasUDFType.GROUP_AGG) + # def mean_udf(v, w): + # return np.average(v, weights=w) + # + # result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() + # expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + # self.assertFramesEqual(expected1, result1) + # + # result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() + # expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + # self.assertFramesEqual(expected2, result2) + # + # result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() + # expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + # self.assertFramesEqual(expected3, result3) + # + # result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() + # expected4 = df.groupby((col('id') + 
1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + # self.assertFramesEqual(expected4, result4) + + def test_alias(self): import numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + from pyspark.sql.functions import pandas_udf, PandasUDFType, mean - df = self.data - - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def mean_udf(v, w): - return np.average(v, weights=w) - - result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected1, result1) - - result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected2, result2) - - result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() - expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected3, result3) - - result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() - expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected4, result4) + #df = self.data - def test_multiple(self): - import numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean - df = self.data + df = self.spark.range(10) @pandas_udf('double', PandasUDFType.GROUP_AGG) def mean_udf(v): return v.mean() - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def sum_udf(v): - return v.sum() - - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def weighted_mean_udf(v, w): - return np.average(v, weights=w) - - # TODO: Fix alias - #result1 = df.groupBy('id') \ - # .agg(mean_udf(df.v).alias('mean'), \ - # sum_udf(df.v).alias('sum'), \ - # weighted_mean_udf(df.v, df.w).alias('wm')) \ - # .sort('id') \ - # .toPandas() - - result1 = df.groupBy('id') \ - .agg(mean_udf(df.v), \ - sum_udf(df.v), \ - weighted_mean_udf(df.v, df.w)) \ - .sort('id') \ - .toPandas() - - expected1 = df.groupBy('id') \ - .agg(mean(df.v).alias('mean_udf'), \ - sum(df.v).alias('sum_udf'), \ - mean(df.v).alias('weighted_mean_udf')) \ - .sort('id') \ - .toPandas() - self.assertFramesEqual(expected1, result1) + result1 = df.groupby('id').agg(mean_udf(df.id).alias('mean')) + + result1.show() + + # result2 = df.groupby('id').agg(mean_udf(df.v).name('mean')) + + # result2.show() + + print(result1._jdf.queryExecution().analyzed().toString()) + + # def test_multiple(self): + # import numpy as np + # from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + # df = self.data + # + # @pandas_udf('double', PandasUDFType.GROUP_AGG) + # def mean_udf(v): + # return v.mean() + # + # @pandas_udf('double', PandasUDFType.GROUP_AGG) + # def sum_udf(v): + # return v.sum() + # + # @pandas_udf('double', PandasUDFType.GROUP_AGG) + # def weighted_mean_udf(v, w): + # return np.average(v, weights=w) + # + # result1 = df.groupBy('id') \ + # .agg(mean_udf(df.v), \ + # sum_udf(df.v), \ + # weighted_mean_udf(df.v, df.w)) \ + # .sort('id') \ + # .toPandas() + # + # expected1 = df.groupBy('id') \ + # .agg(mean(df.v).alias('mean_udf'), \ + # sum(df.v).alias('sum_udf'), \ + # mean(df.v).alias('weighted_mean_udf')) \ + # .sort('id') \ + # .toPandas() + # self.assertFramesEqual(expected1, result1) if __name__ == 
"__main__": diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 19abce01a26cf..9db7f94682cba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -80,6 +80,7 @@ object AggregateExpression { } } + /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 7e4b784033bfc..2d0aa218617fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -94,7 +94,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { - logTrace( + print( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} From 8cf0ccd50a2b2fa10c7977fa2ff5ad3bde9c44b9 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 7 Dec 2017 19:35:05 -0500 Subject: [PATCH 03/35] Tests pass --- python/pyspark/sql/group.py | 11 +- python/pyspark/sql/tests.py | 159 ++++++++++-------- python/pyspark/sql/udf.py | 6 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../sql/catalyst/expressions}/PythonUDF.scala | 8 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 - .../sql/catalyst/planning/patterns.scala | 7 +- .../logical/pythonLogicalOperators.scala | 9 - .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 72 +++++--- .../sql/execution/aggregate/AggUtils.scala | 4 +- .../python/AggregateInPandasExec.scala | 30 ++-- .../sql/execution/python/EvalPythonExec.scala | 1 + .../execution/python/ExtractPythonUDFs.scala | 18 ++ .../python/UserDefinedPythonFunction.scala | 2 +- 16 files changed, 197 insertions(+), 141 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/python => catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions}/PythonUDF.scala (88%) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 8a25f14ac2993..fe73d6ac7bcb6 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,7 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction, UDFColumn +from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -89,13 +89,8 @@ def agg(self, *exprs): else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - if isinstance(exprs[0], UDFColumn): - assert all(isinstance(c, UDFColumn) for c in exprs) - jdf = self._jgd.aggInPandas( - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs])) - else: - jdf = self._jgd.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + jdf = 
self._jgd.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dae09cfcfe6e3..5f6552686fcbe 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4369,85 +4369,106 @@ def data(self): .drop('vs') \ .withColumn('w', lit(1.0)) - # def test_basic(self): - # import numpy as np - # from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean - # - # df = self.data - # - # @pandas_udf('double', PandasUDFType.GROUP_AGG) - # def mean_udf(v, w): - # return np.average(v, weights=w) - # - # result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - # expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - # self.assertFramesEqual(expected1, result1) - # - # result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - # expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - # self.assertFramesEqual(expected2, result2) - # - # result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() - # expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - # self.assertFramesEqual(expected3, result3) - # - # result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() - # expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - # self.assertFramesEqual(expected4, result4) + def test_basic(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + + df = self.data + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def mean_udf(v, w): + return np.average(v, weights=w) + + result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected1, result1) + + result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() + expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected2, result2) + + result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id') + result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() + expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected3, result3) + + result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() + expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + self.assertFramesEqual(expected4, result4) def test_alias(self): - import numpy as np from pyspark.sql.functions import pandas_udf, PandasUDFType, mean - #df = self.data + df = self.data - df = self.spark.range(10) + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def mean_udf(v): + return v.mean() + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')).toPandas() + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')).toPandas() + + # result1._jdf.queryExecution().analyzed() + # print() + # print("*************** Analyzed *****************") + # print(result1._jdf.queryExecution().analyzed().toString()) + # print("******************************************") + # + 
# print(result1._jdf.queryExecution().analyzed().aggregateExpressions().apply(1).getClass()) + # + # result1._jdf.queryExecution().optimizedPlan() + # print() + # print("*************** Optimized *****************") + # print(result1._jdf.queryExecution().optimizedPlan().toString()) + # print("******************************************") + # + # result1._jdf.queryExecution().sparkPlan() + # print() + # print("*************** Spark Plan *****************") + # print(result1._jdf.queryExecution().sparkPlan().toString()) + # print("********************************************") + # + # + # result1._jdf.queryExecution().executedPlan() + # print() + # print("*************** Executed Plan *****************") + # print(result1._jdf.queryExecution().executedPlan().toString()) + # print("********************************************") + + self.assertFramesEqual(expected1, result1) + + def test_multiple(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + df = self.data @pandas_udf('double', PandasUDFType.GROUP_AGG) def mean_udf(v): return v.mean() - result1 = df.groupby('id').agg(mean_udf(df.id).alias('mean')) - - result1.show() - - # result2 = df.groupby('id').agg(mean_udf(df.v).name('mean')) - - # result2.show() - - print(result1._jdf.queryExecution().analyzed().toString()) - - # def test_multiple(self): - # import numpy as np - # from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean - # df = self.data - # - # @pandas_udf('double', PandasUDFType.GROUP_AGG) - # def mean_udf(v): - # return v.mean() - # - # @pandas_udf('double', PandasUDFType.GROUP_AGG) - # def sum_udf(v): - # return v.sum() - # - # @pandas_udf('double', PandasUDFType.GROUP_AGG) - # def weighted_mean_udf(v, w): - # return np.average(v, weights=w) - # - # result1 = df.groupBy('id') \ - # .agg(mean_udf(df.v), \ - # sum_udf(df.v), \ - # weighted_mean_udf(df.v, df.w)) \ - # .sort('id') \ - # .toPandas() - # - # expected1 = df.groupBy('id') \ - # .agg(mean(df.v).alias('mean_udf'), \ - # sum(df.v).alias('sum_udf'), \ - # mean(df.v).alias('weighted_mean_udf')) \ - # .sort('id') \ - # .toPandas() - # self.assertFramesEqual(expected1, result1) + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum_udf(v): + return v.sum() + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean_udf(v, w): + return np.average(v, weights=w) + + result1 = df.groupBy('id') \ + .agg(mean_udf(df.v), \ + sum_udf(df.v), \ + weighted_mean_udf(df.v, df.w)) \ + .sort('id') \ + .toPandas() + + expected1 = df.groupBy('id') \ + .agg(mean(df.v).alias('mean_udf'), \ + sum(df.v).alias('sum_udf'), \ + mean(df.v).alias('weighted_mean_udf')) \ + .sort('id') \ + .toPandas() + + self.assertFramesEqual(expected1, result1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 471326dda0d06..c77f19f89a442 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -63,10 +63,6 @@ def _create_udf(f, returnType, evalType): return udf_obj._wrapped() -class UDFColumn(Column): - pass - - class UserDefinedFunction(object): """ User defined function in Python @@ -145,7 +141,7 @@ def _create_judf(self): def __call__(self, *cols): judf = self._judf sc = SparkContext._active_spark_context - return UDFColumn(judf.apply(_to_seq(sc, cols, _to_java_column))) + return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) # This function is for improving the online help system in the interactive interpreter. 
# For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bbcec5627bd49..a1662112e1639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -171,6 +171,7 @@ trait CheckAnalysis extends PredicateHelper { s"appear in the arguments of an aggregate function.") } } + case _: PythonUDF => // OK case e: Attribute if groupingExprs.isEmpty => // Collect all [[AggregateExpressions]]s. val aggExprs = aggregateExprs.filter(_.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index d3f743d9eb61e..6dce19ca54c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -15,10 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.python +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.types.DataType /** @@ -35,7 +35,9 @@ case class PythonUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override def nullable: Boolean = true + override def toString: String = s"$name(${children.mkString(", ")})" - override def nullable: Boolean = true + lazy val resultAttribute: Attribute = AttributeReference(name, dataType)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f26608fae8a5b..c794ba8619322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -452,8 +452,6 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) - case a @ AggregateInPandas(_, _, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = prunedChild(child, a.references)) case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cc391aae55787..0a2268f4a188f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { object PhysicalAggregation { // groupingExpressions, aggregateExpressions, resultExpressions, child type ReturnType = - (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { case logical.Aggregate(groupingExpressions, resultExpressions, child) => @@ -213,7 +214,9 @@ object PhysicalAggregation { expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression - if (!equivalentAggregateExpressions.addExpr(agg)) => agg + if !equivalentAggregateExpressions.addExpr(agg) => agg + case agg @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF) + if !equivalentAggregateExpressions.addExpr(agg) => agg } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 059e304ac4b55..381d8f5aaeb47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -39,12 +39,3 @@ case class FlatMapGroupsInPandas( override val producedAttributes = AttributeSet(output) } -case class AggregateInPandas( - groupingAttributes: Seq[Attribute], - functionExprs: Seq[Expression], - output: Seq[Attribute], - child: LogicalPlan -) extends UnaryNode { - override val references: AttributeSet = child.outputSet - override val producedAttributes = AttributeSet(groupingAttributes ++ output) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 2d0aa218617fb..7e4b784033bfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -94,7 +94,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { - print( + logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index afd96b62e3a36..ea900c95a815a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import 
org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -86,6 +85,8 @@ class RelationalGroupedDataset protected[sql]( case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) + case udf: PythonUDF => + UnresolvedAlias(udf, Some(_ => udf.name)) case expr: Expression => Alias(expr, toPrettySQL(expr))() } @@ -438,6 +439,7 @@ class RelationalGroupedDataset protected[sql]( } + /* private[sql] def aggInPandas(columns: Seq[Column]): DataFrame = { val exprs = columns.map(column => column.expr.asInstanceOf[PythonUDF]) @@ -468,6 +470,8 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + */ + /** * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c0b342eb3de12..05fe33d28f9ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -290,7 +291,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, - aggregateExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, planLater(child)) @@ -334,34 +335,51 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. - sys.error("You hit a query analyzer bug. 
Please report your query to " + - "Spark user mailing list.") - } + groupingExpressions, aggExpressions, resultExpressions, child) => + + if (aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) { + + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - resultExpressions, - planLater(child)) - } else { - aggregate.AggUtils.planAggregateWithOneDistinct( - groupingExpressions, - functionsWithDistinct, - functionsWithoutDistinct, - resultExpressions, - planLater(child)) + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") } - aggregateOperator + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.AggUtils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) + } else { + aggregate.AggUtils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) + } + + aggregateOperator + } else if (aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF])) { + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + } else { + throw new IllegalArgumentException( + "Cannot use mixture of aggregation function and pandas group aggregation UDF") + } case _ => Nil } @@ -452,8 +470,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil - case logical.AggregateInPandas(grouping, func, output, child) => - execution.python.AggregateInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..19e53c3cf4fc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -74,13 +74,15 @@ object AggUtils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], + aggExpressions: Seq[Expression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use HashAggregate. // 1. Create an Aggregate Operator for partial aggregations. 
+ val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression]) + val groupingAttributes = groupingExpressions.map(_.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index f8913bb9b616d..7206b8827e6b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -25,19 +25,20 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, JoinedRow, SortOrder, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, JoinedRow, NamedExpression, PythonUDF, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils case class AggregateInPandasExec( - groupingAttributes: Seq[Attribute], - func: Seq[Expression], - output: Seq[Attribute], + groupingAttributes: Seq[Expression], + func: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { - private val udfs = func.map(expr => expr.asInstanceOf[PythonUDF]) + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -71,22 +72,28 @@ case class AggregateInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - // val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone - val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + val (pyFuncs, inputs) = func.map(collectFunctions).unzip val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + + allInputs.appendAll(groupingAttributes) val argOffsets = inputs.map { input => input.map { e => allInputs += e - allInputs.length - 1 + dataTypes += e.dataType + allInputs.length - 1 - groupingAttributes.length }.toArray }.toArray + val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator((null, iter)) @@ -94,7 +101,8 @@ case class AggregateInPandasExec( val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + 
UnsafeProjection.create(allInputs.drop(groupingAttributes.length), child.output) + groupedIter.map { case (k, groupedRowIter) => (k, groupedRowIter.map(dropGrouping)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 860dc78c1dd1b..9bd3c0bb15de2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -117,6 +117,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } }.toArray }.toArray + val projection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 171619d247bf3..e6016bf234fd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -48,9 +48,26 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { }.isDefined } + private def isPandasGroupAggUdf(expr: Expression): Boolean = expr match { + case _ @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF ) => true + case Alias(expr, _) => isPandasGroupAggUdf(expr) + case _ => false + } + + private def hasPandasGroupAggUdf(agg: Aggregate): Boolean = { + val actualAggExpr = agg.aggregateExpressions.drop(agg.groupingExpressions.length) + actualAggExpr.exists(isPandasGroupAggUdf) + } + + private def extract(agg: Aggregate): LogicalPlan = { val projList = new ArrayBuffer[NamedExpression]() val aggExpr = new ArrayBuffer[NamedExpression]() + + if (hasPandasGroupAggUdf(agg)) { + Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child) + } else { + agg.aggregateExpressions.foreach { expr => if (hasPythonUdfOverAggregate(expr, agg)) { // Python UDF can only be evaluated after aggregate @@ -71,6 +88,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } // There is no Python UDF over aggregate expression Project(projList, agg.copy(aggregateExpressions = aggExpr)) + } } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 50dca32cb7861..f4c2d02ee9420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} import org.apache.spark.sql.types.DataType /** From c1f6cf97a15419351a185fade753ce276f834dc8 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 7 Dec 2017 20:04:43 -0500 Subject: [PATCH 04/35] Clean up --- python/pyspark/sql/group.py | 2 -- .../sql/catalyst/expressions/PythonUDF.scala | 4 ++-- .../expressions/aggregate/interfaces.scala | 1 - 
.../logical/pythonLogicalOperators.scala | 1 - .../sql/execution/aggregate/AggUtils.scala | 4 +--- .../python/AggregateInPandasExec.scala | 24 +++++++++---------- 6 files changed, 15 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fe73d6ac7bcb6..22061b83eb78c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -91,8 +91,6 @@ def agg(self, *exprs): assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) - - return DataFrame(jdf, self.sql_ctx) @dfapi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 6dce19ca54c81..30a9ed1732f77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -35,9 +35,9 @@ case class PythonUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) - override def nullable: Boolean = true - override def toString: String = s"$name(${children.mkString(", ")})" lazy val resultAttribute: Attribute = AttributeReference(name, dataType)() + + override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 9db7f94682cba..19abce01a26cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -80,7 +80,6 @@ object AggregateExpression { } } - /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 381d8f5aaeb47..254687ec00880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -38,4 +38,3 @@ case class FlatMapGroupsInPandas( */ override val producedAttributes = AttributeSet(output) } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 19e53c3cf4fc3..ebbdf1aaa024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -74,15 +74,13 @@ object AggUtils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], - aggExpressions: Seq[Expression], + aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use HashAggregate. // 1. Create an Aggregate Operator for partial aggregations. 
- val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression]) - val groupingAttributes = groupingExpressions.map(_.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 7206b8827e6b5..276dc25177960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils case class AggregateInPandasExec( - groupingAttributes: Seq[Expression], - func: Seq[PythonUDF], + groupingExpressions: Seq[Expression], + udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { @@ -45,10 +45,10 @@ case class AggregateInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) override def requiredChildDistribution: Seq[Distribution] = { - if (groupingAttributes.isEmpty) { + if (groupingExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(groupingAttributes) :: Nil + ClusteredDistribution(groupingExpressions) :: Nil } } @@ -65,7 +65,7 @@ case class AggregateInPandasExec( } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) + Seq(groupingExpressions.map(SortOrder(_, Ascending))) override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() @@ -75,18 +75,18 @@ case class AggregateInPandasExec( val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone - val (pyFuncs, inputs) = func.map(collectFunctions).unzip + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] - allInputs.appendAll(groupingAttributes) + allInputs.appendAll(groupingExpressions) val argOffsets = inputs.map { input => input.map { e => allInputs += e dataTypes += e.dataType - allInputs.length - 1 - groupingAttributes.length + allInputs.length - 1 - groupingExpressions.length }.toArray }.toArray @@ -95,13 +95,13 @@ case class AggregateInPandasExec( }) inputRDD.mapPartitionsInternal { iter => - val grouped = if (groupingAttributes.isEmpty) { + val grouped = if (groupingExpressions.isEmpty) { Iterator((null, iter)) } else { - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val groupedIter = GroupedIterator(iter, groupingExpressions, child.output) val dropGrouping = - UnsafeProjection.create(allInputs.drop(groupingAttributes.length), child.output) + UnsafeProjection.create(allInputs.drop(groupingExpressions.length), child.output) groupedIter.map { case (k, groupedRowIter) => (k, groupedRowIter.map(dropGrouping)) @@ -113,7 +113,7 @@ case class AggregateInPandasExec( // The queue used to buffer input rows so we can drain it to // combine input with output from Python. 
val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingAttributes.length) + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) context.addTaskCompletionListener { _ => queue.close() } From 3791b28178001b5ccf8074d16d6d8824f6252adf Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 7 Dec 2017 20:06:55 -0500 Subject: [PATCH 05/35] More clean up --- .../spark/sql/RelationalGroupedDataset.scala | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index ea900c95a815a..4655e86391b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -438,40 +438,6 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan)) } - - /* - private[sql] def aggInPandas(columns: Seq[Column]): DataFrame = { - val exprs = columns.map(column => column.expr.asInstanceOf[PythonUDF]) - - val groupingNamedExpressions = groupingExprs.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - } - - val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - - val child = df.logicalPlan - - val childrenExpressions = exprs.flatMap(expr => - expr.children.map { - case ne: NamedExpression => ne - case other => Alias(other, other.toString)() - }) - - val project = Project(groupingNamedExpressions ++ childrenExpressions, child) - - val udfOutputs = exprs.flatMap(expr => - Seq(AttributeReference(expr.name, expr.dataType)()) - ) - - val output: Seq[Attribute] = groupingAttributes ++ udfOutputs - - val plan = AggregateInPandas(groupingAttributes, exprs, output, project) - - Dataset.ofRows(df.sparkSession, plan) - } - */ - /** * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. 
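With the commented-out aggInPandas prototype removed above, grouped aggregation with a pandas UDF is meant to flow through the ordinary GroupedData.agg path and be planned with the new AggregateInPandasExec node. A minimal usage sketch, assuming pandas and Arrow are installed, the GROUP_AGG changes from this series are applied, and a local SparkSession is used purely for illustration:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    # Assumed setup: a throwaway local session for the example.
    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

    # Group aggregate UDF: each call receives one group's values of column `v`
    # as a pandas.Series and returns a single scalar.
    @pandas_udf("double", PandasUDFType.GROUP_AGG)
    def mean_udf(v):
        return v.mean()

    # Routed through the regular agg() code path; no partial aggregation is
    # performed for this eval type, so each group is fully shuffled to one
    # partition before the Python worker runs.
    df.groupby("id").agg(mean_udf(df["v"])).show()
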
From df9f6b33f55c2f00ab028c25aae9263f0df45155 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 7 Dec 2017 20:09:29 -0500 Subject: [PATCH 06/35] more clean up --- .../org/apache/spark/sql/catalyst/expressions/PythonUDF.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 30a9ed1732f77..f0585c6553413 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.types.DataType /** From 58c21c1d892632605be8b88143ad738ddc94d376 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Dec 2017 14:36:44 -0500 Subject: [PATCH 07/35] Clean up code; Address PR comments --- python/pyspark/sql/tests.py | 106 +++++++----------- python/pyspark/sql/udf.py | 6 +- .../sql/catalyst/expressions/PythonUDF.scala | 4 +- .../python/AggregateInPandasExec.scala | 2 +- .../execution/python/ExtractPythonUDFs.scala | 43 ++++--- 5 files changed, 72 insertions(+), 89 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5f6552686fcbe..c2cb55550435b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -197,6 +197,12 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -3354,12 +3360,6 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3397,7 +3397,7 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3408,11 +3408,11 @@ def test_toPandas_respect_session_timezone(self): self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") try: pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) + self.assertPandasEqual(pdf_arrow_la, pdf_la) finally: self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3422,7 +3422,7 @@ def test_toPandas_respect_session_timezone(self): if isinstance(field.dataType, 
TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) @@ -3430,7 +3430,7 @@ def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3488,7 +3488,7 @@ def test_createDataFrame_with_schema(self): df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() @@ -4181,12 +4181,6 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) - @property def data(self): from pyspark.sql.functions import array, explode, col, lit @@ -4210,7 +4204,7 @@ def test_simple(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_register_group_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4234,7 +4228,7 @@ def foo(pdf): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4249,7 +4243,7 @@ def test_coerce(self): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4268,7 +4262,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4287,7 +4281,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4301,7 +4295,7 @@ def test_datatype_string(self): result = 
df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4354,11 +4348,6 @@ def test_unsupported_types(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyAggTests(ReusedSQLTestCase): - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) @property def data(self): @@ -4381,20 +4370,38 @@ def mean_udf(v, w): result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected1, result1) + self.assertPandasEqual(expected1, result1) result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected2, result2) + self.assertPandasEqual(expected2, result2) - result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id') result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected3, result3) + self.assertPandasEqual(expected3, result3) result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() - self.assertFramesEqual(expected4, result4) + self.assertPandasEqual(expected4, result4) + + def test_array(self): + from pyspark.sql.types import ArrayType, DoubleType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return [v.mean(), v.std()] + + def test_struct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return (v.mean(), v.std()) def test_alias(self): from pyspark.sql.functions import pandas_udf, PandasUDFType, mean @@ -4408,34 +4415,7 @@ def mean_udf(v): result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')).toPandas() expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')).toPandas() - # result1._jdf.queryExecution().analyzed() - # print() - # print("*************** Analyzed *****************") - # print(result1._jdf.queryExecution().analyzed().toString()) - # print("******************************************") - # - # print(result1._jdf.queryExecution().analyzed().aggregateExpressions().apply(1).getClass()) - # - # result1._jdf.queryExecution().optimizedPlan() - # print() - # print("*************** Optimized *****************") - # print(result1._jdf.queryExecution().optimizedPlan().toString()) - # print("******************************************") - # - # 
result1._jdf.queryExecution().sparkPlan() - # print() - # print("*************** Spark Plan *****************") - # print(result1._jdf.queryExecution().sparkPlan().toString()) - # print("********************************************") - # - # - # result1._jdf.queryExecution().executedPlan() - # print() - # print("*************** Executed Plan *****************") - # print(result1._jdf.queryExecution().executedPlan().toString()) - # print("********************************************") - - self.assertFramesEqual(expected1, result1) + self.assertPandasEqual(expected1, result1) def test_multiple(self): import numpy as np @@ -4468,7 +4448,7 @@ def weighted_mean_udf(v, w): .sort('id') \ .toPandas() - self.assertFramesEqual(expected1, result1) + self.assertPandasEqual(expected1, result1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c77f19f89a442..e9f795810af71 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,7 +22,7 @@ from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, ArrayType, StructType, _parse_datatype_string __all__ = ["UDFRegistration"] @@ -113,6 +113,10 @@ def returnType(self): and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " "pandas_udf with function type GROUP_MAP") + elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + and isinstance(self._returnType_placeholder, (StructType, ArrayType)): + raise NotImplementedError( + "StructType and ArrayType are not supported with PandasUDFType.GROUP_AGG") return self._returnType_placeholder diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index f0585c6553413..e112582a8fdd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,9 +34,9 @@ case class PythonUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) - override def toString: String = s"$name(${children.mkString(", ")})" - lazy val resultAttribute: Attribute = AttributeReference(name, dataType)() + override def toString: String = s"$name(${children.mkString(", ")})" + override def nullable: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 276dc25177960..c73eee2038b2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -133,7 +133,7 @@ case class AggregateInPandasExec( val joined = new JoinedRow val resultProj = UnsafeProjection.create(output, output) - columnarBatchIter.map(_.rowIterator.next()).map{ outputRow => + columnarBatchIter.map(_.rowIterator.next()).map { outputRow => val leftRow = queue.remove() val joinedRow = joined(leftRow, outputRow) resultProj(joinedRow) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e6016bf234fd0..1941af61f3e7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -49,8 +49,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } private def isPandasGroupAggUdf(expr: Expression): Boolean = expr match { - case _ @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF ) => true - case Alias(expr, _) => isPandasGroupAggUdf(expr) + case PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF) => true + case Alias(child, _) => isPandasGroupAggUdf(child) case _ => false } @@ -67,27 +67,26 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { if (hasPandasGroupAggUdf(agg)) { Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child) } else { - - agg.aggregateExpressions.foreach { expr => - if (hasPythonUdfOverAggregate(expr, agg)) { - // Python UDF can only be evaluated after aggregate - val newE = expr transformDown { - case e: Expression if belongAggregate(e, agg) => - val alias = e match { - case a: NamedExpression => a - case o => Alias(e, "agg")() - } - aggExpr += alias - alias.toAttribute + agg.aggregateExpressions.foreach { expr => + if (hasPythonUdfOverAggregate(expr, agg)) { + // Python UDF can only be evaluated after aggregate + val newE = expr transformDown { + case e: Expression if belongAggregate(e, agg) => + val alias = e match { + case a: NamedExpression => a + case o => Alias(e, "agg")() + } + aggExpr += alias + alias.toAttribute + } + projList += newE.asInstanceOf[NamedExpression] + } else { + aggExpr += expr + projList += expr.toAttribute } - projList += newE.asInstanceOf[NamedExpression] - } else { - aggExpr += expr - projList += expr.toAttribute } - } - // There is no Python UDF over aggregate expression - Project(projList, agg.copy(aggregateExpressions = aggExpr)) + // There is no Python UDF over aggregate expression + Project(projList, agg.copy(aggregateExpressions = aggExpr)) } } @@ -129,7 +128,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas can be evaluated directly in python worker + // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: AggregateInPandasExec => plan case plan: FlatMapGroupsInPandasExec => plan From d79464cabf9d25d81c1450f0777c1edcf1ff51a1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Dec 2017 14:40:30 -0500 Subject: [PATCH 08/35] Fix python style --- python/pyspark/sql/tests.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c2cb55550435b..5669af1f95b00 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4346,6 +4346,7 @@ def test_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyAggTests(ReusedSQLTestCase): @@ -4372,16 +4373,20 @@ def mean_udf(v, w): expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() 
self.assertPandasEqual(expected1, result1) - result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0)))\ + .sort('id').toPandas() + expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf'))\ + .sort('id').toPandas() self.assertPandasEqual(expected2, result2) result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() self.assertPandasEqual(expected3, result3) - result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w)).sort('id').toPandas() - expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w))\ + .sort('id').toPandas() + expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf'))\ + .sort('id').toPandas() self.assertPandasEqual(expected4, result4) def test_array(self): @@ -4435,15 +4440,15 @@ def weighted_mean_udf(v, w): return np.average(v, weights=w) result1 = df.groupBy('id') \ - .agg(mean_udf(df.v), \ - sum_udf(df.v), \ + .agg(mean_udf(df.v), + sum_udf(df.v), weighted_mean_udf(df.v, df.w)) \ .sort('id') \ .toPandas() expected1 = df.groupBy('id') \ - .agg(mean(df.v).alias('mean_udf'), \ - sum(df.v).alias('sum_udf'), \ + .agg(mean(df.v).alias('mean_udf'), + sum(df.v).alias('sum_udf'), mean(df.v).alias('weighted_mean_udf')) \ .sort('id') \ .toPandas() From 51f478259f9421666eb66385a7f97df3f4810fbf Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Dec 2017 17:16:22 -0500 Subject: [PATCH 09/35] Add docs and more tests --- python/pyspark/sql/functions.py | 31 +++++++ python/pyspark/sql/group.py | 16 +++- python/pyspark/sql/tests.py | 92 +++++++++++++++---- .../spark/sql/RelationalGroupedDataset.scala | 2 - 4 files changed, 121 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 77aaf3d19127f..b800e7e264daa 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2223,6 +2223,37 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + 3. GROUP_AGG + + A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + The returnType should be a primitive data type, e.g, `DoubleType()`. + The returned scalar can be either a python primitive type, e.g., `int` or `float` + or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. + + StructType and ArrayType are currently not supported. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) + ... def mean_udf(v): + ... return v.mean() + >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + .. 
note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 22061b83eb78c..e26ab58748e9b 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -65,7 +65,14 @@ def __init__(self, jgd, df): def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + The available aggregate functions can be: + + 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` + + 2. group aggregate pandas UDFs + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. @@ -82,6 +89,13 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf('double', PandasUDFType.GROUP_AGG) + ... def min_udf(v): + ... return v.min() + >>> sorted(gdf.agg(min_udf(df.age))).collect()) + [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5669af1f95b00..5fd397289e867 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4370,22 +4370,22 @@ def mean_udf(v, w): return np.average(v, weights=w) result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf(v, 1.0)')).sort('id').toPandas() self.assertPandasEqual(expected1, result1) result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0)))\ .sort('id').toPandas() - expected2 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf'))\ - .sort('id').toPandas() + expected2 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('mean_udf(v, 1.0)')).sort('id').toPandas() self.assertPandasEqual(expected2, result2) result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() - expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf')).sort('id').toPandas() + expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf(v, w)')).sort('id').toPandas() self.assertPandasEqual(expected3, result3) result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w))\ .sort('id').toPandas() - expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf'))\ + expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf(v, w)'))\ .sort('id').toPandas() self.assertPandasEqual(expected4, result4) @@ -4439,21 +4439,79 @@ def sum_udf(v): def weighted_mean_udf(v, w): return np.average(v, weights=w) - result1 = df.groupBy('id') \ - .agg(mean_udf(df.v), - sum_udf(df.v), - weighted_mean_udf(df.v, df.w)) \ - .sort('id') \ - .toPandas() 
+ result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + + expected1 = (df.groupBy('id') + .agg(mean(df.v).alias('mean_udf(v)'), + sum(df.v).alias('sum_udf(v)'), + mean(df.v).alias('weighted_mean_udf(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, mean, udf, pandas_udf, PandasUDFType + + df = self.data + + @udf('double') + def plus_one(v): + return v + 1 - expected1 = df.groupBy('id') \ - .agg(mean(df.v).alias('mean_udf'), - sum(df.v).alias('sum_udf'), - mean(df.v).alias('weighted_mean_udf')) \ - .sort('id') \ - .toPandas() + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def mean_udf(v): + return v.mean() + + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(mean_udf(col('v')), + mean_udf(col('v1')), + mean_udf(col('v2'))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(mean(col('v')).alias('mean_udf(v)'), + mean(col('v1')).alias('mean_udf(v1)'), + mean(col('v2')).alias('mean_udf(v2)')) + .sort('id') + .toPandas()) + + result2 = (df.groupby('id') + .agg(mean_udf(col('v')), + mean_udf(plus_one(df.v)).alias('mean_udf(v1)'), + mean_udf(df.v + 2).alias('mean_udf(v2)')) + .sort('id') + .toPandas()) + + expected2 = expected1 + + result3 = (df.groupby('id') + .agg(mean_udf(df.v).alias('v')) + .groupby('id') + .agg(mean_udf(col('v')).alias('mean_v')) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(mean(df.v).alias('v')) + .groupby('id') + .agg(mean(col('v')).alias('mean_v')) + .sort('id') + .toPandas()) self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) if __name__ == "__main__": diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 4655e86391b83..d320c1c359411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -85,8 +85,6 @@ class RelationalGroupedDataset protected[sql]( case expr: NamedExpression => expr case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => UnresolvedAlias(a, Some(Column.generateAlias)) - case udf: PythonUDF => - UnresolvedAlias(udf, Some(_ => udf.name)) case expr: Expression => Alias(expr, toPrettySQL(expr))() } From 505acdb1108f10c163178b42d725ae3aae0f1b12 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Dec 2017 17:40:59 -0500 Subject: [PATCH 10/35] Style fix --- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 4 +++- .../spark/sql/execution/python/AggregateInPandasExec.scala | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b800e7e264daa..b3d221495d621 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2250,7 +2250,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): +---+-----------+ .. note:: There is no partial aggregation with group aggregate UDFs, i.e., - a full shuffle is required. + a full shuffle is required. .. 
seealso:: :meth:`pyspark.sql.GroupedData.agg` diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index e26ab58748e9b..3df3cbb3f8125 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -70,8 +70,10 @@ def agg(self, *exprs): 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` 2. group aggregate pandas UDFs + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., - a full shuffle is required. + a full shuffle is required. + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index c73eee2038b2f..e2e645ff5dd34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -84,9 +84,9 @@ case class AggregateInPandasExec( val argOffsets = inputs.map { input => input.map { e => - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 - groupingExpressions.length + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 - groupingExpressions.length }.toArray }.toArray From 4856e821b973b8bff6fe0a88ab816f86e57fe609 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Dec 2017 17:47:09 -0500 Subject: [PATCH 11/35] Remove whitespace --- .../org/apache/spark/sql/execution/python/EvalPythonExec.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 9bd3c0bb15de2..860dc78c1dd1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -117,7 +117,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } }.toArray }.toArray - val projection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) From 3287e6e8a7cb9371fa37fccc9fa1a70e6637bae2 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Dec 2017 23:35:57 -0500 Subject: [PATCH 12/35] Fix doctest --- python/pyspark/sql/group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 3df3cbb3f8125..857914b991c97 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -93,10 +93,10 @@ def agg(self, *exprs): [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf('double', PandasUDFType.GROUP_AGG) + >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) ... def min_udf(v): ... 
return v.min() - >>> sorted(gdf.agg(min_udf(df.age))).collect()) + >>> sorted(gdf.agg(min_udf(df.age)).collect()) [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" From 7dcdd3a542bdd987f42ce8a41ed2f57c0d89e18d Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 27 Dec 2017 15:52:30 -0500 Subject: [PATCH 13/35] Tests pass --- python/pyspark/sql/tests.py | 302 ++++++++++++++---- .../sql/catalyst/expressions/PythonUDF.scala | 9 +- .../sql/catalyst/planning/patterns.scala | 5 +- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../python/AggregateInPandasExec.scala | 21 +- .../execution/python/ExtractPythonUDFs.scala | 115 ++++--- 6 files changed, 336 insertions(+), 120 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5fd397289e867..c907c5ef8377b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -203,6 +203,25 @@ def assertPandasEqual(self, expected, result): "\n\nResult:\n%s\n%s" % (result, result.dtypes)) self.assertTrue(expected.equals(result), msg=msg) + def printPlans(self, df): + df._jdf.queryExecution().optimizedPlan() + print() + print("****************** Optimized ********************") + print(df._jdf.queryExecution().optimizedPlan().toString()) + print("*************************************************") + + df._jdf.queryExecution().sparkPlan() + print() + print("****************** Spark Plan *******************") + print(df._jdf.queryExecution().sparkPlan().toString()) + print("*************************************************") + + df._jdf.queryExecution().executedPlan() + print() + print("**************** Executed Plan ******************") + print(df._jdf.queryExecution().executedPlan().toString()) + print("*************************************************") + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -543,6 +562,7 @@ def test_udf_with_aggregate_function(self): sel = df.groupBy(my_copy(col("key")).alias("k"))\ .agg(sum(my_strlen(col("value"))).alias("s"))\ .select(my_add(col("k"), col("s")).alias("t")) + self.printPlans(sel) self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) def test_udf_in_generate(self): @@ -4359,35 +4379,84 @@ def data(self): .drop('vs') \ .withColumn('w', lit(1.0)) - def test_basic(self): - import numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean + @property + def plus_one(self): + from pyspark.sql.functions import udf - df = self.data + @udf('double') + def plus_one(v): + assert isinstance(v, float) + return v + 1 + return plus_one + + @property + def plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUP_AGG) - def mean_udf(v, w): - return np.average(v, weights=w) + def mean_udf(v): + return v.mean() + return mean_udf - result1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0))).sort('id').toPandas() - expected1 = df.groupby('id').agg(mean(df.v).alias('mean_udf(v, 1.0)')).sort('id').toPandas() - self.assertPandasEqual(expected1, result1) + @property + def sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType - result2 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, lit(1.0)))\ - 
.sort('id').toPandas() - expected2 = df.groupby((col('id') + 1).alias('id'))\ - .agg(mean(df.v).alias('mean_udf(v, 1.0)')).sort('id').toPandas() - self.assertPandasEqual(expected2, result2) + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum_udf(v): + return v.sum() + return sum_udf - result3 = df.groupby('id').agg(mean_udf(df.v, df.w)).sort('id').toPandas() - expected3 = df.groupby('id').agg(mean(df.v).alias('mean_udf(v, w)')).sort('id').toPandas() - self.assertPandasEqual(expected3, result3) + @property + def weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType - result4 = df.groupby((col('id') + 1).alias('id')).agg(mean_udf(df.v, df.w))\ - .sort('id').toPandas() - expected4 = df.groupby((col('id') + 1).alias('id')).agg(mean(df.v).alias('mean_udf(v, w)'))\ - .sort('id').toPandas() - self.assertPandasEqual(expected4, result4) + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean_udf(v, w): + return np.average(v, weights=w) + return weighted_mean_udf + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + self.spark.conf.set("spark.sql.codegen.wholeStage", False) + + df = self.data + weighted_mean_udf = self.weighted_mean_udf + + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean_udf(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_array(self): from pyspark.sql.types import ArrayType, DoubleType @@ -4409,35 +4478,93 @@ def mean_and_std_udf(v): return (v.mean(), v.std()) def test_alias(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, mean + from pyspark.sql.functions import mean df = self.data + mean_udf = self.mean_udf - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def mean_udf(v): - return v.mean() + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) - result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')).toPandas() - expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')).toPandas() + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected1, result1) + def test_mixed_sql(self): + from pyspark.sql.functions import sum, mean - def test_multiple(self): - import numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, sum, mean df = self.data + sum_udf = self.sum_udf - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def mean_udf(v): - return v.mean() + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) - @pandas_udf('double', 
PandasUDFType.GROUP_AGG) - def sum_udf(v): - return v.sum() + expected1 = (df.groupby('id') + .agg((sum(df.v) + 1).alias('(sum_udf(v) + 1)')) + .sort('id')) - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def weighted_mean_udf(v, w): - return np.average(v, weights=w) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1).alias('sum_udf((v + 1))')) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.plus_one + plus_two = self.plus_two + sum_udf = self.sum_udf + + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v)).alias("plus_one(sum_udf(v))")) + .sort('id')) + + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1).alias("sum_udf(plus_one(v))")) + .sort('id')) + + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + + expected3 = (df.groupby('id') + .agg(sum(df.v + 2).alias("sum_udf(plus_two(v))")) + .sort('id')) + + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v)).alias("plus_two(sum_udf(v))")) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_multiple(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.mean_udf + sum_udf = self.sum_udf + weighted_mean_udf = self.weighted_mean_udf result1 = (df.groupBy('id') .agg(mean_udf(df.v), @@ -4453,59 +4580,84 @@ def weighted_mean_udf(v, w): .sort('id') .toPandas()) - self.assertPandasEqual(expected1, result1) + result2 = (df.groupBy('id', 'v') + .agg(mean_udf(df.v), + sum_udf(df.id)) + .sort('id', 'v') + .toPandas()) - def test_complex_expressions(self): - from pyspark.sql.functions import col, mean, udf, pandas_udf, PandasUDFType + expected2 = (df.groupBy('id', 'v') + .agg(mean_udf(df.v).alias('mean_udf(v)'), + sum_udf(df.id).alias('sum_udf(id)')) + .sort('id', 'v') + .toPandas()) - df = self.data + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) - @udf('double') - def plus_one(v): - return v + 1 + def test_complex(self): + from pyspark.sql.functions import col, sum - @pandas_udf('double', PandasUDFType.GROUP_AGG) - def mean_udf(v): - return v.mean() + df = self.data + plus_one = self.plus_one + plus_two = self.plus_two + sum_udf = self.sum_udf result1 = (df.withColumn('v1', plus_one(df.v)) .withColumn('v2', df.v + 2) .groupby('id') - .agg(mean_udf(col('v')), - mean_udf(col('v1')), - mean_udf(col('v2'))) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) .sort('id') .toPandas()) expected1 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) .groupby('id') - .agg(mean(col('v')).alias('mean_udf(v)'), - mean(col('v1')).alias('mean_udf(v1)'), - mean(col('v2')).alias('mean_udf(v2)')) + .agg(sum(col('v')).alias('sum_udf(v)'), + sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), + 
(sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), + plus_one(sum(col('v1'))).alias('plus_one(sum_udf(v1))'), + sum(col('v2') + 1).alias('sum_udf(plus_one(v2))')) .sort('id') .toPandas()) - result2 = (df.groupby('id') - .agg(mean_udf(col('v')), - mean_udf(plus_one(df.v)).alias('mean_udf(v1)'), - mean_udf(df.v + 2).alias('mean_udf(v2)')) + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) .sort('id') .toPandas()) - expected2 = expected1 + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(sum(col('v')).alias('sum_udf(v)'), + sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), + (sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), + plus_two(sum(col('v1'))).alias('plus_two(sum_udf(v1))'), + sum(col('v2') + 2).alias('sum_udf(plus_two(v2))')) + .sort('id') + .toPandas()) result3 = (df.groupby('id') - .agg(mean_udf(df.v).alias('v')) + .agg(sum_udf(df.v).alias('v')) .groupby('id') - .agg(mean_udf(col('v')).alias('mean_v')) + .agg(sum_udf(col('v')).alias('sum_v')) .sort('id') .toPandas()) expected3 = (df.groupby('id') - .agg(mean(df.v).alias('v')) + .agg(sum(df.v).alias('v')) .groupby('id') - .agg(mean(col('v')).alias('mean_v')) + .agg(sum(col('v')).alias('sum_v')) .sort('id') .toPandas()) @@ -4513,6 +4665,24 @@ def mean_udf(v): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_retain_group_columns(self): + from pyspark.sql.functions import sum, lit, col + orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) + self.spark.conf.set("spark.sql.retainGroupColumns", False) + try: + df = self.data + sum_udf = self.sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v).alias('sum_udf(v)')) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.retainGroupColumns") + else: + self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index e112582a8fdd0..e311659b18e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType /** @@ -29,14 +30,16 @@ case class PythonUDF( dataType: DataType, children: Seq[Expression], evalType: Int, - udfDeterministic: Boolean) + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) - lazy val resultAttribute: Attribute = AttributeReference(name, dataType)() - override def toString: String = s"$name(${children.mkString(", ")})" + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + override def nullable: 
Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 0a2268f4a188f..a7e32dc8c4fdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -215,7 +215,7 @@ object PhysicalAggregation { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg - case agg @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF) + case agg @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, _) if !equivalentAggregateExpressions.addExpr(agg) => agg } } @@ -244,6 +244,9 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute + case ue: PythonUDF => + equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..f665852e0795f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource -import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate +import org.apache.spark.sql.execution.python.{ExtractGroupAggPandasUDFFromAggregate, ExtractPythonUDFFromAggregate} class SparkOptimizer( catalog: SessionCatalog, @@ -32,6 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("Extract group aggregate Pandas UDF from Aggregate", + Once, ExtractGroupAggPandasUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index e2e645ff5dd34..73c78324bc54f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -25,20 +25,20 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, JoinedRow, NamedExpression, PythonUDF, SortOrder, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils case class AggregateInPandasExec( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def outputPartitioning: Partitioning = child.outputPartitioning @@ -80,13 +80,11 @@ case class AggregateInPandasExec( val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] - allInputs.appendAll(groupingExpressions) - val argOffsets = inputs.map { input => input.map { e => allInputs += e dataTypes += e.dataType - allInputs.length - 1 - groupingExpressions.length + allInputs.length - 1 }.toArray }.toArray @@ -94,17 +92,16 @@ case class AggregateInPandasExec( StructField(s"_$i", dt) }) + val input = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingExpressions.isEmpty) { Iterator((null, iter)) } else { val groupedIter = GroupedIterator(iter, groupingExpressions, child.output) - - val dropGrouping = - UnsafeProjection.create(allInputs.drop(groupingExpressions.length), child.output) - + val proj = UnsafeProjection.create(allInputs, child.output) groupedIter.map { - case (k, groupedRowIter) => (k, groupedRowIter.map(dropGrouping)) + case (k, groupedRowIter) => (k, groupedRowIter.map(proj)) } } @@ -131,7 +128,7 @@ case class AggregateInPandasExec( .compute(projectedRowIter, context.partitionId(), context) val joined = new JoinedRow - val resultProj = UnsafeProjection.create(output, output) + val resultProj = UnsafeProjection.create(resultExpressions, input) columnarBatchIter.map(_.rowIterator.next()).map { outputRow => val leftRow = queue.remove() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1941af61f3e7e..0d3713c8d731e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -34,6 +34,12 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} */ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { + private def isPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + Set(PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_PANDAS_SCALAR_UDF + ).contains(e.asInstanceOf[PythonUDF].evalType) + } + /** * Returns whether the expression could only be evaluated within aggregate. 
*/ @@ -44,50 +50,34 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + e => isPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined }.isDefined } - private def isPandasGroupAggUdf(expr: Expression): Boolean = expr match { - case PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF) => true - case Alias(child, _) => isPandasGroupAggUdf(child) - case _ => false - } - - private def hasPandasGroupAggUdf(agg: Aggregate): Boolean = { - val actualAggExpr = agg.aggregateExpressions.drop(agg.groupingExpressions.length) - actualAggExpr.exists(isPandasGroupAggUdf) - } - - private def extract(agg: Aggregate): LogicalPlan = { val projList = new ArrayBuffer[NamedExpression]() val aggExpr = new ArrayBuffer[NamedExpression]() - if (hasPandasGroupAggUdf(agg)) { - Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child) - } else { - agg.aggregateExpressions.foreach { expr => - if (hasPythonUdfOverAggregate(expr, agg)) { - // Python UDF can only be evaluated after aggregate - val newE = expr transformDown { - case e: Expression if belongAggregate(e, agg) => - val alias = e match { - case a: NamedExpression => a - case o => Alias(e, "agg")() - } - aggExpr += alias - alias.toAttribute - } - projList += newE.asInstanceOf[NamedExpression] - } else { - aggExpr += expr - projList += expr.toAttribute + agg.aggregateExpressions.foreach { expr => + if (hasPythonUdfOverAggregate(expr, agg)) { + // Python UDF can only be evaluated after aggregate + val newE = expr transformDown { + case e: Expression if belongAggregate(e, agg) => + val alias = e match { + case a: NamedExpression => a + case o => Alias(e, "agg")() + } + aggExpr += alias + alias.toAttribute } + projList += newE.asInstanceOf[NamedExpression] + } else { + aggExpr += expr + projList += expr.toAttribute } - // There is no Python UDF over aggregate expression - Project(projList, agg.copy(aggregateExpressions = aggExpr)) } + // There is no Python UDF over aggregate expression + Project(projList, agg.copy(aggregateExpressions = aggExpr)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -109,8 +99,14 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { + private def isPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + Set(PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_PANDAS_SCALAR_UDF + ).contains(e.asInstanceOf[PythonUDF].evalType) + } + private def hasPythonUDF(e: Expression): Boolean = { - e.find(_.isInstanceOf[PythonUDF]).isDefined + e.find(isPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -123,14 +119,13 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if isPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs - case plan: AggregateInPandasExec => plan case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => 
extract(plan) } @@ -233,3 +228,49 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } } } + + +/** + * Extract all the group aggregate Pandas UDFs in logical aggregation, evaluate the UDFs first + * and then the expressions that depend on the result of the UDFs. + */ +object ExtractGroupAggPandasUDFFromAggregate extends Rule[LogicalPlan] { + + private def isPandasGroupAggUdf(expr: Expression): Boolean = { + expr.isInstanceOf[PythonUDF] && + expr.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + } + + private def hasPandasGroupAggUdf(expr: Expression): Boolean = { + expr.find(isPandasGroupAggUdf).isDefined + } + + private def extract(agg: Aggregate): LogicalPlan = { + val projList = new ArrayBuffer[NamedExpression]() + val aggExpr = new ArrayBuffer[NamedExpression]() + + agg.aggregateExpressions.foreach { expr => + if (hasPandasGroupAggUdf(expr)) { + val newE = expr transformDown { + case e: PythonUDF if isPandasGroupAggUdf(e) => + // Wrap the UDF with alias to make it a NamedExpression + // The alias is intermediate, its attribute name doesn't affect the final result + val alias = Alias(e, "agg")(exprId = e.resultId) + aggExpr += alias + alias.toAttribute + } + projList += newE.asInstanceOf[NamedExpression] + } else { + aggExpr += expr + projList += expr.toAttribute + } + } + + Project(projList, agg.copy(aggregateExpressions = aggExpr)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case agg: Aggregate if agg.aggregateExpressions.exists(hasPandasGroupAggUdf) => + extract(agg) + } +} From 1c834f7b34b912e84ff1826db16284ea3f5cafa0 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 27 Dec 2017 16:54:10 -0500 Subject: [PATCH 14/35] Fix merge error --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index a7e32dc8c4fdb..1f5d2eb123792 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -215,7 +215,7 @@ object PhysicalAggregation { // addExpr() always returns false for non-deterministic expressions and do not add them. 
case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg - case agg @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, _) + case agg @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, _, _) if !equivalentAggregateExpressions.addExpr(agg) => agg } } From 066783e38e6418e5546f8f681ab564dc3696930f Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 28 Dec 2017 11:00:36 -0500 Subject: [PATCH 15/35] Fix test --- python/pyspark/sql/tests.py | 22 ------------------- .../sql/catalyst/planning/patterns.scala | 7 +++--- 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c907c5ef8377b..673621073b943 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -203,25 +203,6 @@ def assertPandasEqual(self, expected, result): "\n\nResult:\n%s\n%s" % (result, result.dtypes)) self.assertTrue(expected.equals(result), msg=msg) - def printPlans(self, df): - df._jdf.queryExecution().optimizedPlan() - print() - print("****************** Optimized ********************") - print(df._jdf.queryExecution().optimizedPlan().toString()) - print("*************************************************") - - df._jdf.queryExecution().sparkPlan() - print() - print("****************** Spark Plan *******************") - print(df._jdf.queryExecution().sparkPlan().toString()) - print("*************************************************") - - df._jdf.queryExecution().executedPlan() - print() - print("**************** Executed Plan ******************") - print(df._jdf.queryExecution().executedPlan().toString()) - print("*************************************************") - class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -562,7 +543,6 @@ def test_udf_with_aggregate_function(self): sel = df.groupBy(my_copy(col("key")).alias("k"))\ .agg(sum(my_strlen(col("value"))).alias("s"))\ .select(my_add(col("k"), col("s")).alias("t")) - self.printPlans(sel) self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) def test_udf_in_generate(self): @@ -4431,8 +4411,6 @@ def weighted_mean_udf(v, w): def test_basic(self): from pyspark.sql.functions import col, lit, sum, mean - self.spark.conf.set("spark.sql.codegen.wholeStage", False) - df = self.data weighted_mean_udf = self.weighted_mean_udf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1f5d2eb123792..520080a467571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -215,8 +215,8 @@ object PhysicalAggregation { // addExpr() always returns false for non-deterministic expressions and do not add them. 
case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg - case agg @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, _, _) - if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF if udf.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -244,7 +244,8 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute - case ue: PythonUDF => + // Similar to AggregateExpression + case ue: PythonUDF if ue.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => From cd164857b2b78ea5004a1ebb091786f43d6012c7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Dec 2017 15:20:43 -0500 Subject: [PATCH 16/35] Add complex_grouping test --- .../spark/api/python/PythonRunner.scala | 1 - python/pyspark/sql/tests.py | 50 +++++++++++++++++-- .../sql/catalyst/expressions/PythonUDF.scala | 20 +++++++- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../python/AggregateInPandasExec.scala | 13 ++--- .../execution/python/ExtractPythonUDFs.scala | 35 +++---------- 6 files changed, 81 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ffc0777c0aa46..29148a7ee558b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -48,7 +48,6 @@ private[spark] object PythonEvalType { case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" } - } /** diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 673621073b943..15eff4481e688 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -577,7 +577,6 @@ def test_udf_with_order_by_and_limit(self): my_copy = udf(lambda x: x, IntegerType()) df = self.spark.range(10).orderBy("id") res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) - res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_udf_registration_returns_udf(self): @@ -4365,7 +4364,7 @@ def plus_one(self): @udf('double') def plus_one(v): - assert isinstance(v, float) + assert isinstance(v, (int, float)) return v + 1 return plus_one @@ -4531,10 +4530,18 @@ def test_mixed_udf(self): .agg(plus_two(sum(df.v)).alias("plus_two(sum_udf(v))")) .sort('id')) + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v))).alias('plus_one(sum_udf(plus_one(v)))')) + .sort('plus_one(id)')) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) def test_multiple(self): from pyspark.sql.functions import col, lit, sum, mean @@ -4573,7 +4580,40 @@ def test_multiple(self): self.assertPandasEqual(expected1, result1) 
self.assertPandasEqual(expected2, result2) - def test_complex(self): + def test_complex_grouping(self): + from pyspark.sql.functions import lit, sum + + df = self.data + sum_udf = self.sum_udf + plus_one = self.plus_one + plus_two = self.plus_two + + result1 = df.groupby(df.id + 1).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id + 1).agg(sum(df.v).alias('sum_udf(v)')) + + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v).alias('sum_udf(v)')) + + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v).alias('sum_udf(v)')) + + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v).alias('sum_udf(v)')) + + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v).alias('sum_udf(v)')) + + result6 = df.groupby(df.id, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.id, plus_one(df.id)).agg(sum(df.v).alias('sum_udf(v)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_complex_expression(self): from pyspark.sql.functions import col, sum df = self.data @@ -4599,7 +4639,7 @@ def test_complex(self): sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), (sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), plus_one(sum(col('v1'))).alias('plus_one(sum_udf(v1))'), - sum(col('v2') + 1).alias('sum_udf(plus_one(v2))')) + sum(plus_one(col('v2'))).alias('sum_udf(plus_one(v2))')) .sort('id') .toPandas()) @@ -4621,7 +4661,7 @@ def test_complex(self): sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), (sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), plus_two(sum(col('v1'))).alias('plus_two(sum_udf(v1))'), - sum(col('v2') + 2).alias('sum_udf(plus_two(v2))')) + sum(plus_two(col('v2'))).alias('sum_udf(plus_two(v2))')) .sort('id') .toPandas()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index e311659b18e53..274f2cd1cd791 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -17,10 +17,28 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.api.python.PythonFunction +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType +/** + * Helper functions for PythonUDF + */ +object PythonUDF { + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF + ).contains(e.asInstanceOf[PythonUDF].evalType) + } + + def isGroupAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + } +} + /** * A serialized version of a Python lambda function. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index f665852e0795f..0c78df705134c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -32,8 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Extract group aggregate Pandas UDF from Aggregate", - Once, ExtractGroupAggPandasUDFFromAggregate) :+ + // Batch("Extract group aggregate Pandas UDF from Aggregate", + // Once, ExtractGroupAggPandasUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 73c78324bc54f..63b04b271991f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -95,14 +95,15 @@ case class AggregateInPandasExec( val input = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) inputRDD.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(allInputs, child.output) + val grouped = if (groupingExpressions.isEmpty) { - Iterator((null, iter)) + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) } else { - val groupedIter = GroupedIterator(iter, groupingExpressions, child.output) - val proj = UnsafeProjection.create(allInputs, child.output) - groupedIter.map { - case (k, groupedRowIter) => (k, groupedRowIter.map(proj)) - } + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(proj)) } val context = TaskContext.get() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 0d3713c8d731e..fb43bccb03ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -27,30 +27,24 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} - /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or * grouping key, evaluate them after aggregate. */ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { - private def isPythonUDF(e: Expression): Boolean = { - e.isInstanceOf[PythonUDF] && - Set(PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_PANDAS_SCALAR_UDF - ).contains(e.asInstanceOf[PythonUDF].evalType) - } - /** * Returns whether the expression could only be evaluated within aggregate. 
*/ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => isPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined }.isDefined } @@ -99,14 +93,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { - private def isPythonUDF(e: Expression): Boolean = { - e.isInstanceOf[PythonUDF] && - Set(PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_PANDAS_SCALAR_UDF - ).contains(e.asInstanceOf[PythonUDF].evalType) - } - private def hasPythonUDF(e: Expression): Boolean = { - e.find(isPythonUDF).isDefined + e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -119,7 +107,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if isPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } @@ -162,10 +150,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(udf => - udf.evalType == PythonEvalType.SQL_BATCHED_UDF || - udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF - ), "Can only extract scalar vectorized udf or sql batch udf") + require(validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() @@ -236,13 +222,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { */ object ExtractGroupAggPandasUDFFromAggregate extends Rule[LogicalPlan] { - private def isPandasGroupAggUdf(expr: Expression): Boolean = { - expr.isInstanceOf[PythonUDF] && - expr.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF - } - private def hasPandasGroupAggUdf(expr: Expression): Boolean = { - expr.find(isPandasGroupAggUdf).isDefined + expr.find(PythonUDF.isGroupAggPandasUDF).isDefined } private def extract(agg: Aggregate): LogicalPlan = { @@ -252,7 +233,7 @@ object ExtractGroupAggPandasUDFFromAggregate extends Rule[LogicalPlan] { agg.aggregateExpressions.foreach { expr => if (hasPandasGroupAggUdf(expr)) { val newE = expr transformDown { - case e: PythonUDF if isPandasGroupAggUdf(e) => + case e: PythonUDF if PythonUDF.isGroupAggPandasUDF(e) => // Wrap the UDF with alias to make it a NamedExpression // The alias is intermediate, its attribute name doesn't affect the final result val alias = Alias(e, "agg")(exprId = e.resultId) From 959f3eb21ebf2ebd1a41b131f50d07a792aa9258 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Dec 2017 17:13:15 -0500 Subject: [PATCH 17/35] Address PR comments --- python/pyspark/sql/tests.py | 24 +++++++++++++++++++ .../sql/catalyst/analysis/CheckAnalysis.scala | 16 +++++++++---- .../spark/sql/execution/SparkStrategies.scala | 7 +++++- .../python/AggregateInPandasExec.scala | 11 +++++---- 4 files changed, 49 insertions(+), 9 deletions(-) 
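For context, the user-facing behavior that the extraction rules and tests in this series exercise looks roughly like the sketch below. This is an illustration only, not part of any patch; it assumes an existing SparkSession named `spark` and the GROUP_AGG pandas UDF type added earlier in the series.

    from pyspark.sql.functions import col, pandas_udf, PandasUDFType

    df = spark.range(0, 10).withColumn('v', col('id') * 1.0)

    # Group aggregate pandas UDF: receives a pandas Series per group, returns a scalar.
    @pandas_udf('double', PandasUDFType.GROUP_AGG)
    def sum_udf(v):
        return v.sum()

    # Grouping by an expression works, and SQL expressions can wrap the UDF result.
    # The UDF itself is evaluated inside the aggregate; the surrounding `+ 1` is
    # evaluated afterwards, which is what the extraction logic above arranges.
    df.groupby(df.id % 2).agg(sum_udf(df.v) + 1).show()
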
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 15eff4481e688..fa82e1d61a09c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4701,6 +4701,30 @@ def test_retain_group_columns(self): else: self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.plus_one + mean_udf = self.mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a1662112e1639..02f71f19a3390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -153,11 +153,20 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.simpleString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => + def isAggregateExpression(expr: Expression) = { + expr.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(expr) + } + def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case aggExpr: AggregateExpression => - aggExpr.aggregateFunction.children.foreach { child => + case expr: Expression if isAggregateExpression(expr) => + val aggFunction = expr match { + case agg: AggregateExpression => agg.aggregateFunction + case udf: PythonUDF => udf + } + aggFunction.children.foreach { child => child.foreach { - case agg: AggregateExpression => + case expr: Expression if isAggregateExpression(expr) => failAnalysis( s"It is not allowed to use an aggregate function in the argument of " + s"another aggregate function. Please use the inner aggregate function " + @@ -171,7 +180,6 @@ trait CheckAnalysis extends PredicateHelper { s"appear in the arguments of an aggregate function.") } } - case _: PythonUDF => // OK case e: Attribute if groupingExprs.isEmpty => // Collect all [[AggregateExpressions]]s. 
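For context, the patterns rejected by the analysis checks in this patch are the ones `test_invalid_args` above exercises. A sketch, not part of the diff, using `df`, `plus_one`, `mean_udf` and `mean` as defined in that test class; each statement fails once an action such as `.collect()` is run:

    # Rejected: a plain Python UDF is not an aggregate function, so it cannot
    # appear by itself in agg().
    df.groupby(df.id).agg(plus_one(df.v))

    # Rejected: an aggregate nested inside another aggregate, the same rule that
    # applies to built-in aggregate functions.
    df.groupby(df.id).agg(mean_udf(mean_udf(df.v)))

    # Rejected at this point in the series: mixing a built-in aggregate function
    # with a group aggregate pandas UDF in the same agg().
    df.groupby(df.id).agg(mean_udf(df.v), mean(df.v))
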
val aggExprs = aggregateExprs.filter(_.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 05fe33d28f9ad..563ce04eb4fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -289,6 +289,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + require( + !aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF), + "Streaming aggregation doesn't support group aggregate pandas UDF" + ) + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), @@ -378,7 +383,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { planLater(child))) } else { throw new IllegalArgumentException( - "Cannot use mixture of aggregation function and pandas group aggregation UDF") + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") } case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 63b04b271991f..8fa5888995c8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -79,12 +79,15 @@ case class AggregateInPandasExec( val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => input.map { e => - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } }.toArray }.toArray From 3cda9b870c81169348c6f9a899e12b8f39cdfd0e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Dec 2017 17:13:54 -0500 Subject: [PATCH 18/35] Remove ExtractGroupAggPandasUDFFromAggregate --- .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../execution/python/ExtractPythonUDFs.scala | 40 ------------------- 2 files changed, 1 insertion(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 0c78df705134c..1c8e4050978dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource -import org.apache.spark.sql.execution.python.{ExtractGroupAggPandasUDFFromAggregate, ExtractPythonUDFFromAggregate} +import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( catalog: SessionCatalog, @@ -32,8 +32,6 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only 
Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - // Batch("Extract group aggregate Pandas UDF from Aggregate", - // Once, ExtractGroupAggPandasUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fb43bccb03ec5..23ba1c2857990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -215,43 +215,3 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } } - -/** - * Extract all the group aggregate Pandas UDFs in logical aggregation, evaluate the UDFs first - * and then the expressions that depend on the result of the UDFs. - */ -object ExtractGroupAggPandasUDFFromAggregate extends Rule[LogicalPlan] { - - private def hasPandasGroupAggUdf(expr: Expression): Boolean = { - expr.find(PythonUDF.isGroupAggPandasUDF).isDefined - } - - private def extract(agg: Aggregate): LogicalPlan = { - val projList = new ArrayBuffer[NamedExpression]() - val aggExpr = new ArrayBuffer[NamedExpression]() - - agg.aggregateExpressions.foreach { expr => - if (hasPandasGroupAggUdf(expr)) { - val newE = expr transformDown { - case e: PythonUDF if PythonUDF.isGroupAggPandasUDF(e) => - // Wrap the UDF with alias to make it a NamedExpression - // The alias is intermediate, its attribute name doesn't affect the final result - val alias = Alias(e, "agg")(exprId = e.resultId) - aggExpr += alias - alias.toAttribute - } - projList += newE.asInstanceOf[NamedExpression] - } else { - aggExpr += expr - projList += expr.toAttribute - } - } - - Project(projList, agg.copy(aggregateExpressions = aggExpr)) - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case agg: Aggregate if agg.aggregateExpressions.exists(hasPandasGroupAggUdf) => - extract(agg) - } -} From 1696bdb7a9c094d5931d836d65591bfc6bf3b154 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 29 Dec 2017 17:33:59 -0500 Subject: [PATCH 19/35] Fix style --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 7 ++++--- .../spark/sql/execution/python/ExtractPythonUDFs.scala | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 520080a467571..132241061d510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -215,8 +215,9 @@ object PhysicalAggregation { // addExpr() always returns false for non-deterministic expressions and do not add them. 
case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg - case udf: PythonUDF if udf.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF && - !equivalentAggregateExpressions.addExpr(udf) => udf + case udf: PythonUDF + if PythonUDF.isGroupAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if ue.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF => + case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 23ba1c2857990..8bfbfaba5cf43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -150,8 +150,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(PythonUDF.isScalarPythonUDF), - "Can only extract scalar vectorized udf or sql batch udf") + require( + validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() @@ -214,4 +215,3 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } } } - From 4e713a425aafb5652582417e7345e92728c5f900 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 2 Jan 2018 10:15:29 -0500 Subject: [PATCH 20/35] Add doctest SKIP for passing build with pypy --- python/pyspark/sql/group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 857914b991c97..8b77c827a2914 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -89,7 +89,7 @@ def agg(self, *exprs): [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> sorted(gdf.agg(F.min(df.age)).collect()) + >>> sorted(gdf.agg(F.min(df.age)).collect()) # doctest: +SKIP [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType From a89416fa7e79ebf4ea8a37bc6760762557608cbc Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 2 Jan 2018 10:28:48 -0500 Subject: [PATCH 21/35] Fix incorrect doctest SKIP --- python/pyspark/sql/group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 8b77c827a2914..96dace079b353 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -89,14 +89,14 @@ def agg(self, *exprs): [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> sorted(gdf.agg(F.min(df.age)).collect()) # doctest: +SKIP + >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> @pandas_udf('int', 
PandasUDFType.GROUP_AGG) ... def min_udf(v): ... return v.min() - >>> sorted(gdf.agg(min_udf(df.age)).collect()) + >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" From 4253caa5d4ce3e4f1cc2c9bff2be79a4e579e41d Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 2 Jan 2018 16:21:46 -0500 Subject: [PATCH 22/35] Add docs for AggregateInPandasExec --- .../sql/execution/python/AggregateInPandasExec.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8fa5888995c8d..ab232f81ce51f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -31,6 +31,14 @@ import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils +/** + * Physical node for aggregation with group aggregate Pandas UDF. + * + * This plan works by sending the necessary (projected) input grouped data as Arrow record batches + * to the python worker, the python worker invokes the UDF and sends the results to the executor, + * finally the executor evaluates any post-aggregation expressions and join the result with the + * grouped key. + */ case class AggregateInPandasExec( groupingExpressions: Seq[NamedExpression], udfExpressions: Seq[PythonUDF], From f91d9ba3f46e3e17d1d6e3c3bf7224cbefcf0406 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 10 Jan 2018 14:37:41 -0500 Subject: [PATCH 23/35] Address PR comments --- python/pyspark/sql/tests.py | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 +-- .../sql/catalyst/expressions/PythonUDF.scala | 11 +++++----- .../spark/sql/execution/SparkStrategies.scala | 3 +-- .../python/AggregateInPandasExec.scala | 21 +++++++++++-------- 5 files changed, 21 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fa82e1d61a09c..f7b7c71081fd5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4722,7 +4722,7 @@ def test_invalid_args(self): with QuietTest(self.sc): with self.assertRaisesRegexp( - Exception, + IllegalArgumentException, 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 02f71f19a3390..ef91d79f3302c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -154,8 +154,7 @@ trait CheckAnalysis extends PredicateHelper { case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 274f2cd1cd791..9c5bc318020fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -25,12 +25,13 @@ import org.apache.spark.sql.types.DataType * Helper functions for PythonUDF */ object PythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) + def isScalarPythonUDF(e: Expression): Boolean = { - e.isInstanceOf[PythonUDF] && - Set( - PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF - ).contains(e.asInstanceOf[PythonUDF].evalType) + e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } def isGroupAggPandasUDF(e: Expression): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 563ce04eb4fc1..8a52e75077f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -291,8 +291,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { require( !aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF), - "Streaming aggregation doesn't support group aggregate pandas UDF" - ) + "Streaming aggregation doesn't support group aggregate pandas UDF") aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index ab232f81ce51f..18e5f8605c60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -85,6 +85,8 @@ case class AggregateInPandasExec( val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. 
val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] val argOffsets = inputs.map { input => @@ -99,14 +101,13 @@ case class AggregateInPandasExec( }.toArray }.toArray - val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) }) - val input = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) - inputRDD.mapPartitionsInternal { iter => - val proj = UnsafeProjection.create(allInputs, child.output) + val prunedProj = UnsafeProjection.create(allInputs, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key @@ -114,7 +115,7 @@ case class AggregateInPandasExec( } else { GroupedIterator(iter, groupingExpressions, child.output) }.map { case (key, rows) => - (key, rows.map(proj)) + (key, rows.map(prunedProj)) } val context = TaskContext.get() @@ -135,16 +136,18 @@ case class AggregateInPandasExec( val columnarBatchIter = new ArrowPythonRunner( pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, schema, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(projectedRowIter, context.partitionId(), context) + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) val joined = new JoinedRow - val resultProj = UnsafeProjection.create(resultExpressions, input) + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) - columnarBatchIter.map(_.rowIterator.next()).map { outputRow => + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => val leftRow = queue.remove() - val joinedRow = joined(leftRow, outputRow) + val joinedRow = joined(leftRow, aggOutputRow) resultProj(joinedRow) } } From 9085ca6adf9cc3bbdfdc09d9515ce14a8644a403 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 16 Jan 2018 14:01:36 -0500 Subject: [PATCH 24/35] Address PR comments --- python/pyspark/sql/tests.py | 215 +++++++++++------- python/pyspark/sql/udf.py | 5 +- .../sql/catalyst/expressions/PythonUDF.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 95 ++++---- 4 files changed, 179 insertions(+), 138 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f7b7c71081fd5..db00409d41d2a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3699,7 +3699,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedSQLTestCase): +class ScalarPandasUDF(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4178,7 +4178,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedSQLTestCase): +class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4347,7 +4347,7 @@ def test_unsupported_types(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyAggTests(ReusedSQLTestCase): +class GroupbyAggPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4359,7 +4359,7 @@ def data(self): .withColumn('w', lit(1.0)) @property - def plus_one(self): + def python_plus_one(self): from pyspark.sql.functions import udf 
@udf('double') @@ -4369,7 +4369,7 @@ def plus_one(v): return plus_one @property - def plus_two(self): + def pandas_scalar_plus_two(self): import pandas as pd from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4384,18 +4384,18 @@ def mean_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUP_AGG) - def mean_udf(v): + def avg(v): return v.mean() - return mean_udf + return avg @property def sum_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUP_AGG) - def sum_udf(v): + def sum(v): return v.sum() - return sum_udf + return sum @property def weighted_mean_udf(self): @@ -4403,9 +4403,9 @@ def weighted_mean_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUP_AGG) - def weighted_mean_udf(v, w): + def weighted_mean(v, w): return np.average(v, weights=w) - return weighted_mean_udf + return weighted_mean def test_basic(self): from pyspark.sql.functions import col, lit, sum, mean @@ -4413,46 +4413,53 @@ def test_basic(self): df = self.data weighted_mean_udf = self.weighted_mean_udf + # Groupby one column and aggregate one UDF with literal result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') - expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + # Groupby one expression and aggregate one UDF with literal result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ .sort(df.id + 1) expected2 = df.groupby((col('id') + 1))\ - .agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort(df.id + 1) + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + # Groupby one column and aggregate one UDF without literal result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') - expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, w)')).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + # Groupby one expression and aggregate one UDF without literal result4 = df.groupby((col('id') + 1).alias('id'))\ .agg(weighted_mean_udf(df.v, df.w))\ .sort('id') expected4 = df.groupby((col('id') + 1).alias('id'))\ - .agg(mean(df.v).alias('weighted_mean_udf(v, w)'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ .sort('id') self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - def test_array(self): - from pyspark.sql.types import ArrayType, DoubleType + def test_unsupported_types(self): + from pyspark.sql.types import ArrayType, DoubleType, MapType from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] - def test_struct(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): @pandas_udf('mean double, std double', 
PandasUDFType.GROUP_AGG) def mean_and_std_udf(v): - return (v.mean(), v.std()) + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} def test_alias(self): from pyspark.sql.functions import mean @@ -4466,84 +4473,114 @@ def test_alias(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ from pyspark.sql.functions import sum, mean df = self.data sum_udf = self.sum_udf + # Mix group aggregate pandas UDF with sql expression result1 = (df.groupby('id') .agg(sum_udf(df.v) + 1) .sort('id')) - expected1 = (df.groupby('id') - .agg((sum(df.v) + 1).alias('(sum_udf(v) + 1)')) + .agg(sum(df.v) + 1) .sort('id')) + # Mix group aggregate pandas UDF with sql expression (order swapped) result2 = (df.groupby('id') .agg(sum_udf(df.v + 1)) .sort('id')) expected2 = (df.groupby('id') - .agg(sum(df.v + 1).alias('sum_udf((v + 1))')) + .agg(sum(df.v + 1)) .sort('id')) + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - def test_mixed_udf(self): + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. + """ from pyspark.sql.functions import sum, mean df = self.data - plus_one = self.plus_one - plus_two = self.plus_two + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two sum_udf = self.sum_udf + # Mix group aggregate pandas UDF and python UDF result1 = (df.groupby('id') .agg(plus_one(sum_udf(df.v))) .sort('id')) - expected1 = (df.groupby('id') - .agg(plus_one(sum(df.v)).alias("plus_one(sum_udf(v))")) + .agg(plus_one(sum(df.v))) .sort('id')) + # Mix group aggregate pandas UDF and python UDF (order swapped) result2 = (df.groupby('id') .agg(sum_udf(plus_one(df.v))) .sort('id')) - expected2 = (df.groupby('id') - .agg(sum(df.v + 1).alias("sum_udf(plus_one(v))")) + .agg(sum(plus_one(df.v))) .sort('id')) + # Mix group aggregate pandas UDF and scalar pandas UDF result3 = (df.groupby('id') .agg(sum_udf(plus_two(df.v))) .sort('id')) - expected3 = (df.groupby('id') - .agg(sum(df.v + 2).alias("sum_udf(plus_two(v))")) + .agg(sum(plus_two(df.v))) .sort('id')) + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) result4 = (df.groupby('id') .agg(plus_two(sum_udf(df.v))) .sort('id')) - expected4 = (df.groupby('id') - .agg(plus_two(sum(df.v)).alias("plus_two(sum_udf(v))")) + .agg(plus_two(sum(df.v))) .sort('id')) + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby result5 = (df.groupby(plus_one(df.id)) .agg(plus_one(sum_udf(plus_one(df.v)))) .sort('plus_one(id)')) expected5 = (df.groupby(plus_one(df.id)) - .agg(plus_one(sum(plus_one(df.v))).alias('plus_one(sum_udf(plus_one(v)))')) + .agg(plus_one(sum(plus_one(df.v)))) .sort('plus_one(id)')) + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + 
.sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) - def test_multiple(self): + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. + """ from pyspark.sql.functions import col, lit, sum, mean df = self.data @@ -4557,54 +4594,50 @@ def test_multiple(self): weighted_mean_udf(df.v, df.w)) .sort('id') .toPandas()) - expected1 = (df.groupBy('id') - .agg(mean(df.v).alias('mean_udf(v)'), - sum(df.v).alias('sum_udf(v)'), - mean(df.v).alias('weighted_mean_udf(v, w)')) + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) .sort('id') .toPandas()) - result2 = (df.groupBy('id', 'v') - .agg(mean_udf(df.v), - sum_udf(df.id)) - .sort('id', 'v') - .toPandas()) - - expected2 = (df.groupBy('id', 'v') - .agg(mean_udf(df.v).alias('mean_udf(v)'), - sum_udf(df.id).alias('sum_udf(id)')) - .sort('id', 'v') - .toPandas()) - self.assertPandasEqual(expected1, result1) - self.assertPandasEqual(expected2, result2) def test_complex_grouping(self): from pyspark.sql.functions import lit, sum df = self.data sum_udf = self.sum_udf - plus_one = self.plus_one - plus_two = self.plus_two + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two - result1 = df.groupby(df.id + 1).agg(sum_udf(df.v)) - expected1 = df.groupby(df.id + 1).agg(sum(df.v).alias('sum_udf(v)')) + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + # empty groupby result2 = df.groupby().agg(sum_udf(df.v)) - expected2 = df.groupby().agg(sum(df.v).alias('sum_udf(v)')) + expected2 = df.groupby().agg(sum(df.v)) + # groupby one column and one sql expression result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v).alias('sum_udf(v)')) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) - expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v).alias('sum_udf(v)')) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + # groupby one scalar pandas UDF result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) - expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v).alias('sum_udf(v)')) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) - result6 = df.groupby(df.id, plus_one(df.id)).agg(sum_udf(df.v)) - expected6 = df.groupby(df.id, plus_one(df.id)).agg(sum(df.v).alias('sum_udf(v)')) + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) @@ -4612,18 +4645,21 @@ def test_complex_grouping(self): 
self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) - def test_complex_expression(self): + def test_complex_expressions(self): from pyspark.sql.functions import col, sum df = self.data - plus_one = self.plus_one - plus_two = self.plus_two + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two sum_udf = self.sum_udf + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF result1 = (df.withColumn('v1', plus_one(df.v)) .withColumn('v2', df.v + 2) - .groupby('id') + .groupby(df.id, df.v % 2) .agg(sum_udf(col('v')), sum_udf(col('v1') + 3), sum_udf(col('v2')) + 5, @@ -4634,18 +4670,20 @@ def test_complex_expression(self): expected1 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) - .groupby('id') - .agg(sum(col('v')).alias('sum_udf(v)'), - sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), - (sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), - plus_one(sum(col('v1'))).alias('plus_one(sum_udf(v1))'), - sum(plus_one(col('v2'))).alias('sum_udf(plus_one(v2))')) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) .sort('id') .toPandas()) + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF result2 = (df.withColumn('v1', plus_one(df.v)) .withColumn('v2', df.v + 2) - .groupby('id') + .groupby(df.id, df.v % 2) .agg(sum_udf(col('v')), sum_udf(col('v1') + 3), sum_udf(col('v2')) + 5, @@ -4656,26 +4694,27 @@ def test_complex_expression(self): expected2 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) - .groupby('id') - .agg(sum(col('v')).alias('sum_udf(v)'), - sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), - (sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), - plus_two(sum(col('v1'))).alias('plus_two(sum_udf(v1))'), - sum(plus_two(col('v2'))).alias('sum_udf(plus_two(v2))')) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) .sort('id') .toPandas()) + # Test sequential groupby aggregate result3 = (df.groupby('id') .agg(sum_udf(df.v).alias('v')) .groupby('id') - .agg(sum_udf(col('v')).alias('sum_v')) + .agg(sum_udf(col('v'))) .sort('id') .toPandas()) expected3 = (df.groupby('id') .agg(sum(df.v).alias('v')) .groupby('id') - .agg(sum(col('v')).alias('sum_v')) + .agg(sum(col('v'))) .sort('id') .toPandas()) @@ -4692,7 +4731,7 @@ def test_retain_group_columns(self): sum_udf = self.sum_udf result1 = df.groupby(df.id).agg(sum_udf(df.v)) - expected1 = df.groupby(df.id).agg(sum(df.v).alias('sum_udf(v)')) + expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) finally: @@ -4705,7 +4744,7 @@ def test_invalid_args(self): from pyspark.sql.functions import mean df = self.data - plus_one = self.plus_one + plus_one = self.python_plus_one mean_udf = self.mean_udf with QuietTest(self.sc): @@ -4722,7 +4761,7 @@ def test_invalid_args(self): with QuietTest(self.sc): with self.assertRaisesRegexp( - IllegalArgumentException, + AnalysisException, 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 
e9f795810af71..6424dee1f9b9b 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,7 +22,8 @@ from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, ArrayType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ + _parse_datatype_string __all__ = ["UDFRegistration"] @@ -114,7 +115,7 @@ def returnType(self): raise ValueError("Invalid returnType: returnType must be a StructType for " "pandas_udf with function type GROUP_MAP") elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ - and isinstance(self._returnType_placeholder, (StructType, ArrayType)): + and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): raise NotImplementedError( "StructType and ArrayType are not supported with PandasUDFType.GROUP_AGG") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 9c5bc318020fb..4ba8ff6e3802f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType /** - * Helper functions for PythonUDF + * Helper functions for [[PythonUDF]] */ object PythonUDF { private[this] val SCALAR_TYPES = Set( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8a52e75077f8c..be104e5a2ab1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -289,9 +289,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - require( - !aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF), - "Streaming aggregation doesn't support group aggregate pandas UDF") + if (!aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") + } aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, @@ -338,52 +339,52 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalAggregation( - groupingExpressions, aggExpressions, resultExpressions, child) => - - if (aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) { - - val aggregateExpressions = aggExpressions.map(expr => - expr.asInstanceOf[AggregateExpression]) + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if 
(functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") + } - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. - sys.error("You hit a query analyzer bug. Please report your query to " + - "Spark user mailing list.") + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.AggUtils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) + } else { + aggregate.AggUtils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) } - - val aggregateOperator = - if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planAggregateWithoutDistinct( - groupingExpressions, - aggregateExpressions, - resultExpressions, - planLater(child)) - } else { - aggregate.AggUtils.planAggregateWithOneDistinct( - groupingExpressions, - functionsWithDistinct, - functionsWithoutDistinct, - resultExpressions, - planLater(child)) - } - - aggregateOperator - } else if (aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF])) { - val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) - - Seq(execution.python.AggregateInPandasExec( - groupingExpressions, - udfExpressions, - resultExpressions, - planLater(child))) - } else { - throw new IllegalArgumentException( - "Cannot use a mixture of aggregate function and group aggregate pandas UDF") - } + aggregateOperator + + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) => + // If cannot match the two cases above, then it's an error + throw new AnalysisException( + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") case _ => Nil } From ebc49cc4bae5dcfd7f44a16cf11c40407946717c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 16 Jan 2018 14:05:01 -0500 Subject: [PATCH 25/35] Minor style change --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index db00409d41d2a..2079c9b968333 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4603,7 +4603,7 @@ def test_multiple_udfs(self): self.assertPandasEqual(expected1, result1) - def test_complex_grouping(self): + def test_complex_groupby(self): from pyspark.sql.functions import lit, sum df = self.data From bf084ffa9160f92438a1547e8e72d87dbeab8d7f Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 16 Jan 2018 14:11:25 -0500 Subject: [PATCH 26/35] Fix error message] --- python/pyspark/sql/udf.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6424dee1f9b9b..7a61c5ac83153 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -117,7 +117,7 @@ def returnType(self): elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): raise NotImplementedError( - "StructType and ArrayType are not supported with PandasUDFType.GROUP_AGG") + "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") return self._returnType_placeholder From 7745b0a1b98970670d8a22b32d312e02015dd719 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 16 Jan 2018 15:19:45 -0500 Subject: [PATCH 27/35] Fix Streaming aggregation check --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index be104e5a2ab1b..ec7045132307b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -289,7 +289,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (!aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } From cf9e7dcd2973aecedf29ca673bb5bbf71fdb9be1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 17 Jan 2018 10:55:56 -0500 Subject: [PATCH 28/35] Minor style fix --- python/pyspark/sql/tests.py | 28 +++++++++---------- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../execution/python/ExtractPythonUDFs.scala | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2079c9b968333..1774369cb8def 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4380,7 +4380,7 @@ def plus_two(v): return plus_two @property - def mean_udf(self): + def pandas_agg_mean_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUP_AGG) @@ -4389,7 +4389,7 @@ def avg(v): return avg @property - def sum_udf(self): + def pandas_agg_sum_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUP_AGG) @@ -4398,7 +4398,7 @@ def sum(v): return sum @property - def weighted_mean_udf(self): + def pandas_agg_weighted_mean_udf(self): import numpy as np from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4411,7 +4411,7 @@ def test_basic(self): from pyspark.sql.functions import col, lit, sum, mean df = self.data - weighted_mean_udf = self.weighted_mean_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf # Groupby one column and aggregate one UDF with literal result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') @@ -4465,7 +4465,7 @@ def test_alias(self): from pyspark.sql.functions import mean df = self.data - mean_udf = self.mean_udf + mean_udf = self.pandas_agg_mean_udf result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) expected1 = 
df.groupby('id').agg(mean(df.v).alias('mean_alias')) @@ -4479,7 +4479,7 @@ def test_mixed_sql(self): from pyspark.sql.functions import sum, mean df = self.data - sum_udf = self.sum_udf + sum_udf = self.pandas_agg_sum_udf # Mix group aggregate pandas UDF with sql expression result1 = (df.groupby('id') @@ -4519,7 +4519,7 @@ def test_mixed_udfs(self): df = self.data plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two - sum_udf = self.sum_udf + sum_udf = self.pandas_agg_sum_udf # Mix group aggregate pandas UDF and python UDF result1 = (df.groupby('id') @@ -4584,9 +4584,9 @@ def test_multiple_udfs(self): from pyspark.sql.functions import col, lit, sum, mean df = self.data - mean_udf = self.mean_udf - sum_udf = self.sum_udf - weighted_mean_udf = self.weighted_mean_udf + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf result1 = (df.groupBy('id') .agg(mean_udf(df.v), @@ -4607,7 +4607,7 @@ def test_complex_groupby(self): from pyspark.sql.functions import lit, sum df = self.data - sum_udf = self.sum_udf + sum_udf = self.pandas_agg_sum_udf plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two @@ -4653,7 +4653,7 @@ def test_complex_expressions(self): df = self.data plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two - sum_udf = self.sum_udf + sum_udf = self.pandas_agg_sum_udf # Test complex expressions with sql expression, python UDF and # group aggregate pandas UDF @@ -4728,7 +4728,7 @@ def test_retain_group_columns(self): self.spark.conf.set("spark.sql.retainGroupColumns", False) try: df = self.data - sum_udf = self.sum_udf + sum_udf = self.pandas_agg_sum_udf result1 = df.groupby(df.id).agg(sum_udf(df.v)) expected1 = df.groupby(df.id).agg(sum(df.v)) @@ -4745,7 +4745,7 @@ def test_invalid_args(self): df = self.data plus_one = self.python_plus_one - mean_udf = self.mean_udf + mean_udf = self.pandas_agg_mean_udf with QuietTest(self.sc): with self.assertRaisesRegexp( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ec7045132307b..855b44cfbd117 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -351,7 +351,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. 
Please report your query to " + - "Spark user mailing list.") + "Spark user mailing list.") } val aggregateOperator = @@ -369,6 +369,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { resultExpressions, planLater(child)) } + aggregateOperator case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 8bfbfaba5cf43..1862e3f6e12ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} + /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or * grouping key, evaluate them after aggregate. @@ -51,7 +52,6 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def extract(agg: Aggregate): LogicalPlan = { val projList = new ArrayBuffer[NamedExpression]() val aggExpr = new ArrayBuffer[NamedExpression]() - agg.aggregateExpressions.foreach { expr => if (hasPythonUdfOverAggregate(expr, agg)) { // Python UDF can only be evaluated after aggregate From 6d505d339335b0c3c0821f21721ee240a156d40b Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 17 Jan 2018 11:02:22 -0500 Subject: [PATCH 29/35] Minor style fix --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 855b44cfbd117..1543d6d1fe78b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,7 +341,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => - val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression]) From 8d2d9439354a5cc0adf1eecf3ee2a88f84844103 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 17 Jan 2018 11:05:37 -0500 Subject: [PATCH 30/35] Revert accidental removal --- python/pyspark/sql/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1774369cb8def..7015efae0e088 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -577,6 +577,7 @@ def test_udf_with_order_by_and_limit(self): my_copy = udf(lambda x: x, IntegerType()) df = self.spark.range(10).orderBy("id") res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) + res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_udf_registration_returns_udf(self): From 17fad5c0f83edb142471e2c4a1ffad08d7a29c5d Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 11:45:12 -0500 Subject: [PATCH 31/35] Fix docs. Address PR comments. 
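
The doc changes below describe group aggregate pandas UDFs and the rule that
they cannot be mixed with built-in aggregate functions in a single agg call.
As an illustrative sketch only (a DataFrame `df` with columns `id` and `v` and
the `sum_udf` helper are assumed here, mirroring the tests in this series):

    from pyspark.sql.functions import mean, pandas_udf, PandasUDFType

    @pandas_udf('double', PandasUDFType.GROUP_AGG)
    def sum_udf(v):
        # Receives all values of a group as one pandas.Series
        # (no partial aggregation) and returns a scalar.
        return v.sum()

    # Supported: group aggregate pandas UDFs combined with sql expressions.
    df.groupby('id').agg(sum_udf(df.v) + 1)

    # Not supported: mixing with built-in aggregate functions in one call;
    # with this change it raises AnalysisException.
    df.groupby('id').agg(sum_udf(df.v), mean(df.v))
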
--- python/pyspark/sql/functions.py | 10 ++++----- python/pyspark/sql/group.py | 21 ++++++++++++------- .../spark/sql/execution/SparkStrategies.scala | 2 +- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b3d221495d621..c9e35c0a33921 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2161,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2226,11 +2226,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): 3. GROUP_AGG A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar - The returnType should be a primitive data type, e.g, `DoubleType()`. + The `returnType` should be a primitive data type, e.g, :class:`DoubleType`. The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - StructType and ArrayType are currently not supported. + :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as + output types. Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` @@ -2249,9 +2250,6 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. note:: There is no partial aggregation with group aggregate UDFs, i.e., - a full shuffle is required. - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` .. note:: The user-defined functions are considered deterministic by default. Due to diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 96dace079b353..fa71abf59ff7c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -69,18 +69,23 @@ def agg(self, *exprs): 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` - 2. group aggregate pandas UDFs + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` .. note:: There is no partial aggregation with group aggregate UDFs, i.e., - a full shuffle is required. + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -220,16 +225,18 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. 
For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. all the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. :param udf: a group map user-defined function returned by - :meth:`pyspark.sql.functions.pandas_udf`. + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1543d6d1fe78b..7661a6406dd83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -381,7 +381,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { resultExpressions, planLater(child))) - case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) => + case PhysicalAggregation(_) => // If cannot match the two cases above, then it's an error throw new AnalysisException( "Cannot use a mixture of aggregate function and group aggregate pandas UDF") From 0fec5cf86619f0a42647c1c53b4cb5b3d449ecd8 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 12:02:27 -0500 Subject: [PATCH 32/35] Fix SparkStrategies --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7661a6406dd83..ce512bc46563a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -381,7 +381,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { resultExpressions, planLater(child))) - case PhysicalAggregation(_) => + case PhysicalAggregation(_, _, _, _) => // If cannot match the two cases above, then it's an error throw new AnalysisException( "Cannot use a mixture of aggregate function and group aggregate pandas UDF") From 4d22107cabb9683d9d1dcd8c03a4a6e45f34a909 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 13:29:10 -0500 Subject: [PATCH 33/35] Add a manual test --- python/pyspark/sql/tests.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7015efae0e088..095828bdfb051 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4408,6 +4408,27 @@ def weighted_mean(v, w): return np.average(v, weights=w) return weighted_mean + def test_manual(self): + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 
245.0, 24.5], + [1, 255.0, 25.5], + [2, 265.0, 26.5], + [3, 275.0, 27.5], + [4, 285.0, 28.5], + [5, 295.0, 29.5], + [6, 305.0, 30.5], + [7, 315.0, 31.5], + [8, 325.0, 32.5], + [9, 335.0, 33.5]], + ['id', 'sum(v)', 'avg(v)']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_basic(self): from pyspark.sql.functions import col, lit, sum, mean From 91885e5dbca02daf30f4d7c8dd1560c8b1bbad47 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 22 Jan 2018 10:59:07 -0500 Subject: [PATCH 34/35] Address comments --- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/udf.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c9e35c0a33921..b4440c675c272 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2226,7 +2226,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 3. GROUP_AGG A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar - The `returnType` should be a primitive data type, e.g, :class:`DoubleType`. + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 7a61c5ac83153..094b8c663737e 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -37,8 +37,10 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ - evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + import inspect from pyspark.sql.utils import require_minimum_pyarrow_version From cc659bc2487d81a9497bd032049c2c4272660716 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 22 Jan 2018 15:00:34 -0500 Subject: [PATCH 35/35] Add doctest SKIP --- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b4440c675c272..a291c9b71913f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2239,7 +2239,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) + >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fa71abf59ff7c..f90a909d7c2b1 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -98,7 +98,7 @@ def agg(self, *exprs): [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) + >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP