[SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) #19872
Changes from 33 commits
@@ -2089,6 +2089,8 @@ class PandasUDFType(object):

     GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF

+    GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
+
 @since(1.3)
 def udf(f=None, returnType=StringType()):
@@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     1. SCALAR

        A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
-       The returnType should be a primitive data type, e.g., `DoubleType()`.
+       The returnType should be a primitive data type, e.g., :class:`DoubleType`.
        The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.

        Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
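As an aside, the SCALAR contract above (one or more `pandas.Series` in, a `pandas.Series` of the same length out) can be sketched with plain pandas, outside of Spark. `plus_one` below is a hypothetical UDF body for illustration, not code from this PR:

```python
import pandas as pd

# A scalar pandas UDF body is just an elementwise Series -> Series function;
# Spark would invoke it on batches of column values.
def plus_one(v: pd.Series) -> pd.Series:
    return v + 1

s = pd.Series([1.0, 2.0, 3.0])
out = plus_one(s)
print(out.tolist())  # [2.0, 3.0, 4.0]
```

The key constraint documented above is that the output length equals the input length; anything else would break the batch-wise execution model.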
@@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None):

     .. seealso:: :meth:`pyspark.sql.GroupedData.apply`

+    3. GROUP_AGG
+
+       A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
+       The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
+       The returned scalar can be either a python primitive type, e.g., `int` or `float`
+       or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
+
+       :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
+       output types.
+
+       Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))
+       >>> @pandas_udf("double", PandasUDFType.GROUP_AGG)
+       ... def mean_udf(v):
+       ...     return v.mean()
+       >>> df.groupby("id").agg(mean_udf(df['v'])).show()  # doctest: +SKIP
+       +---+-----------+
+       | id|mean_udf(v)|
+       +---+-----------+
+       |  1|        1.5|
+       |  2|        6.0|
+       +---+-----------+
+
+       .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+
     .. note:: The user-defined functions are considered deterministic by default. Due to
         optimization, duplicate invocations may be eliminated or the function may even be invoked
         more times than it is present in the query. If your function is not deterministic, call

Review thread (on the `returnType` line):
- reviewer: very small nit: [details not captured on this page]
- author: Fixed. Thanks!

Review thread (on the `mean_udf` example):
- cloud-fan: shall we include grouping columns?
- author: Sorry @cloud-fan, I don't understand this comment, could you elaborate?
- cloud-fan: similar to the [comment truncated in the page capture]
- viirya: IIUC, this is similar to SQL's aggregation, and this aggregation UDF should only take the column for aggregation without grouping columns.
- cloud-fan: ah makes sense
- reviewer: I agree with @viirya

@@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         raise ValueError("Invalid returnType: returnType can not be None")

     if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF,
-                         PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]:
+                         PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
+                         PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]:
         raise ValueError("Invalid functionType: "
                          "functionType must be one the values from PandasUDFType")
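The GROUP_AGG semantics documented in the hunk above (one or more `pandas.Series` in, one scalar out per group) can be sketched with plain pandas and no Spark session. This is an illustrative analogue of what the docstring's `mean_udf` example computes, not code from the PR:

```python
import pandas as pd

# Data mirroring the docstring example: grouping key "id", values "v".
df = pd.DataFrame({"id": [1, 1, 2, 2, 2],
                   "v": [1.0, 2.0, 3.0, 5.0, 10.0]})

# A group aggregate UDF body is just Series -> scalar;
# Spark applies it once per group and emits one row per group.
def mean_udf(v: pd.Series) -> float:
    return v.mean()

result = df.groupby("id")["v"].apply(mean_udf)
print(result.to_dict())  # {1: 1.5, 2: 6.0}
```

Note how the UDF receives only the aggregation column, not the grouping column, which is the point settled in the review thread above.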
@@ -65,13 +65,27 @@ def __init__(self, jgd, df):
     def agg(self, *exprs):
         """Compute aggregates and returns the result as a :class:`DataFrame`.

-        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
+        The available aggregate functions can be:
+
+        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`
+
+        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`
+
+        .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
+            a full shuffle is required. Also, all the data of a group will be loaded into
+            memory, so the user should be aware of the potential OOM risk if data is skewed
+            and certain groups are too large to fit in memory.
+
+        .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

         If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
         is the column to perform aggregation on, and the value is the aggregate function.

         Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

+        .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
+            in a single call to this function.
+
         :param exprs: a dict mapping from column name (string) to aggregate functions (string),
             or a list of :class:`Column`.
@@ -82,6 +96,13 @@ def agg(self, *exprs):
         >>> from pyspark.sql import functions as F
         >>> sorted(gdf.agg(F.min(df.age)).collect())
         [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]

+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> @pandas_udf('int', PandasUDFType.GROUP_AGG)
+        ... def min_udf(v):
+        ...     return v.min()
+        >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
+        [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):

Review thread (on the skipped doctest):
- author: I don't know a good way of skipping doctest when pyarrow is not available... If others have some ideas, please let me know
- reviewer: That's fine.
- reviewer: I think in the future we should make pandas/arrow a requirement of pyspark, so that we can always assume the pandas/arrow is installed when run doc test.
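For the dict form of `agg` described in the docstring above (column name mapped to aggregate function name), a rough pandas analogue looks like the following. This is illustrative only, with made-up data echoing the docstring's `gdf` example, not PySpark code:

```python
import pandas as pd

# Hypothetical data shaped like the docstring's example DataFrame.
df = pd.DataFrame({"name": ["Alice", "Bob", "Bob"], "age": [2, 5, 7]})

# Dict form: {column to aggregate: aggregate function name}.
out = df.groupby("name").agg({"age": "min"})
print(out["age"].to_dict())  # {'Alice': 2, 'Bob': 5}
```

In PySpark the same shape of call, `gdf.agg({"age": "min"})`, yields one aggregated row per group, just as the built-in-function and `Column`-expression forms do.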
@@ -204,16 +225,18 @@ def apply(self, udf):

         The user-defined function should take a `pandas.DataFrame` and return another
         `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
-        to the user-function and the returned `pandas.DataFrame`s are combined as a
+        to the user-function and the returned `pandas.DataFrame` are combined as a
         :class:`DataFrame`.

         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
         returnType of the pandas udf.

-        This function does not support partial aggregation, and requires shuffling all the data in
-        the :class:`DataFrame`.
+        .. note:: This function requires a full shuffle. all the data of a group will be loaded
+            into memory, so the user should be aware of the potential OOM risk if data is skewed
+            and certain groups are too large to fit in memory.

         :param udf: a group map user-defined function returned by
-            :meth:`pyspark.sql.functions.pandas_udf`.
+            :func:`pyspark.sql.functions.pandas_udf`.

         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df = spark.createDataFrame(
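The `apply` contract documented above (a `pandas.DataFrame` in per group, a `pandas.DataFrame` of arbitrary length out, results combined) can also be sketched in plain pandas. `subtract_mean` is a hypothetical user function chosen for illustration, not part of this PR:

```python
import pandas as pd

# Data with two groups keyed by "id".
df = pd.DataFrame({"id": [1, 1, 2, 2], "v": [1.0, 2.0, 3.0, 4.0]})

# A group map UDF body: DataFrame in, DataFrame out; the output
# may have any number of rows, as long as it matches the declared schema.
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf["v"] - pdf["v"].mean())

# Spark concatenates the per-group results into one DataFrame;
# pd.concat over the groups models that combination step.
result = pd.concat(subtract_mean(g) for _, g in df.groupby("id"))
print(result["v"].tolist())  # [-0.5, 0.5, -0.5, 0.5]
```

This also makes the memory caveat in the note concrete: every row of a group must be materialized as one in-memory `pandas.DataFrame` before the UDF runs.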
Review thread (on apply):
- reviewer: So I'm worried that it isn't clear to the user that this will result in a full shuffle with no partial aggregation. Is there maybe a place we can document this warning?
- author: Added in the docstrings of pandas_udf and groupby().agg().