diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b3d221495d621..c9e35c0a33921 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2161,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2226,11 +2226,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): 3. GROUP_AGG A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar - The returnType should be a primitive data type, e.g, `DoubleType()`. + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - StructType and ArrayType are currently not supported. + :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as + output types. Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` @@ -2249,9 +2250,6 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. note:: There is no partial aggregation with group aggregate UDFs, i.e., - a full shuffle is required. - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` .. note:: The user-defined functions are considered deterministic by default. Due to diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 96dace079b353..fa71abf59ff7c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -69,18 +69,23 @@ def agg(self, *exprs): 1. 
built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` - 2. group aggregate pandas UDFs + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` .. note:: There is no partial aggregation with group aggregate UDFs, i.e., - a full shuffle is required. + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -220,16 +225,18 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` objects are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. All the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. 
:param udf: a group map user-defined function returned by - :meth:`pyspark.sql.functions.pandas_udf`. + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1543d6d1fe78b..7661a6406dd83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -381,7 +381,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { resultExpressions, planLater(child))) - case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) => + case PhysicalAggregation(_) => // If cannot match the two cases above, then it's an error throw new AnalysisException( "Cannot use a mixture of aggregate function and group aggregate pandas UDF")