[SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) #19872

Status: Closed. Wants to merge 35 commits.
Commits (35)
692b54f Initial commit: wip (icexelloss, Dec 4, 2017)
11321f6 wip (icexelloss, Dec 5, 2017)
8cf0ccd Tests pass (icexelloss, Dec 8, 2017)
c1f6cf9 Clean up (icexelloss, Dec 8, 2017)
3791b28 More clean up (icexelloss, Dec 8, 2017)
df9f6b3 more clean up (icexelloss, Dec 8, 2017)
58c21c1 Clean up code; Address PR comments (icexelloss, Dec 19, 2017)
d79464c Fix python style (icexelloss, Dec 19, 2017)
51f4782 Add docs and more tests (icexelloss, Dec 19, 2017)
505acdb Style fix (icexelloss, Dec 19, 2017)
4856e82 Remove whitespace (icexelloss, Dec 19, 2017)
3287e6e Fix doctest (icexelloss, Dec 20, 2017)
7dcdd3a Tests pass (icexelloss, Dec 27, 2017)
1c834f7 Fix merge error (icexelloss, Dec 27, 2017)
066783e Fix test (icexelloss, Dec 28, 2017)
cd16485 Add complex_grouping test (icexelloss, Dec 29, 2017)
959f3eb Address PR comments (icexelloss, Dec 29, 2017)
3cda9b8 Remove ExtractGroupAggPandasUDFFromAggregate (icexelloss, Dec 29, 2017)
1696bdb Fix style (icexelloss, Dec 29, 2017)
4e713a4 Add doctest SKIP for passing build with pypy (icexelloss, Jan 2, 2018)
a89416f Fix incorrect doctest SKIP (icexelloss, Jan 2, 2018)
4253caa Add docs for AggregateInPandasExec (icexelloss, Jan 2, 2018)
f91d9ba Address PR comments (icexelloss, Jan 10, 2018)
9085ca6 Address PR comments (icexelloss, Jan 16, 2018)
ebc49cc Minor style change (icexelloss, Jan 16, 2018)
bf084ff Fix error message (icexelloss, Jan 16, 2018)
7745b0a Fix Streaming aggregation check (icexelloss, Jan 16, 2018)
cf9e7dc Minor style fix (icexelloss, Jan 17, 2018)
6d505d3 Minor style fix (icexelloss, Jan 17, 2018)
8d2d943 Revert accidental removal (icexelloss, Jan 17, 2018)
17fad5c Fix docs. Address PR comments. (icexelloss, Jan 18, 2018)
0fec5cf Fix SparkStrategies (icexelloss, Jan 18, 2018)
4d22107 Add a manual test (icexelloss, Jan 18, 2018)
91885e5 Address comments (icexelloss, Jan 22, 2018)
cc659bc Add doctest SKIP (icexelloss, Jan 22, 2018)
@@ -39,12 +39,14 @@ private[spark] object PythonEvalType {

val SQL_PANDAS_SCALAR_UDF = 200
val SQL_PANDAS_GROUP_MAP_UDF = 201
+ val SQL_PANDAS_GROUP_AGG_UDF = 202

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
+ case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF"
}
}

python/pyspark/rdd.py (1 change: 1 addition & 0 deletions)
@@ -70,6 +70,7 @@ class PythonEvalType(object):

SQL_PANDAS_SCALAR_UDF = 200
SQL_PANDAS_GROUP_MAP_UDF = 201
+ SQL_PANDAS_GROUP_AGG_UDF = 202


def portable_hash(x):
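These eval-type constants are the wire-level contract between the JVM and the Python worker, and `PandasUDFType` in functions.py (below) aliases them for users. A minimal sketch of that wiring, assuming pandas and pyarrow are installed and that `pandas_udf` exposes the chosen value via an `evalType` attribute on the function it returns:

    from pyspark.rdd import PythonEvalType
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    from pyspark.sql.types import DoubleType

    # Define a group aggregate UDF; passing a DataType instance avoids needing
    # an active SparkSession just to parse a DDL type string.
    @pandas_udf(DoubleType(), PandasUDFType.GROUP_AGG)
    def mean_udf(v):
        return v.mean()

    # PandasUDFType.GROUP_AGG aliases the wire-level constant 202, so both
    # sides agree on how this UDF is to be executed.
    assert PandasUDFType.GROUP_AGG == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF == 202
    assert mean_udf.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF  # assumed attribute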
python/pyspark/sql/functions.py (36 changes: 34 additions & 2 deletions)
@@ -2089,6 +2089,8 @@ class PandasUDFType(object):

GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF

+ GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
Review comment (Contributor): So I'm worried that it isn't clear to the user that this will result in a full shuffle with no partial aggregation. Is there maybe a place we can document this warning?

Reply (Contributor, Author): Added in the docstrings of pandas_udf and groupby().agg().
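One concrete way to see the behavior being discussed is to inspect the physical plan. A sketch, assuming an active SparkSession named `spark` (the `AggregateInPandas` node name is inferred from this PR's `AggregateInPandasExec`; exact `explain()` output varies by version):

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

    @pandas_udf("double", PandasUDFType.GROUP_AGG)
    def mean_udf(v):
        return v.mean()

    # The plan should show an Exchange (the full shuffle) feeding a single
    # AggregateInPandas node, with no partial-aggregation step before it.
    df.groupby("id").agg(mean_udf(df["v"])).explain()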



@since(1.3)
def udf(f=None, returnType=StringType()):
@@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
1. SCALAR

A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
- The returnType should be a primitive data type, e.g., `DoubleType()`.
+ The returnType should be a primitive data type, e.g., :class:`DoubleType`.
The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.

Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
@@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None):

.. seealso:: :meth:`pyspark.sql.GroupedData.apply`

+ 3. GROUP_AGG
+
+ A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
+ The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
+ The returned scalar can be either a python primitive type, e.g., `int` or `float`
+ or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
+
+ :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
+ output types.
+
+ Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> df = spark.createDataFrame(
+ ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ...     ("id", "v"))
+ >>> @pandas_udf("double", PandasUDFType.GROUP_AGG)  # doctest: +SKIP
+ ... def mean_udf(v):
+ ...     return v.mean()
+ >>> df.groupby("id").agg(mean_udf(df['v'])).show()  # doctest: +SKIP
+ +---+-----------+
+ | id|mean_udf(v)|
+ +---+-----------+
+ |  1|        1.5|
+ |  2|        6.0|
+ +---+-----------+
+
+ .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+
.. note:: The user-defined functions are considered deterministic by default. Due to
optimization, duplicate invocations may be eliminated or the function may even be invoked
more times than it is present in the query. If your function is not deterministic, call
@@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
raise ValueError("Invalid returnType: returnType can not be None")

if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF,
- PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]:
+ PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
+ PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")

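Because the docstring above allows "One or more `pandas.Series`", a multi-column aggregate is worth a quick illustration. A sketch, assuming an active SparkSession named `spark`; `weighted_mean_udf` and the `w` column are illustrative, not part of this PR:

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    df = spark.createDataFrame(
        [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
        ("id", "v", "w"))

    @pandas_udf("double", PandasUDFType.GROUP_AGG)
    def weighted_mean_udf(v, w):
        # Each argument arrives as a pandas.Series holding one group's rows;
        # the function must reduce them to a single scalar per group.
        return (v * w).sum() / w.sum()

    df.groupby("id").agg(weighted_mean_udf(df["v"], df["w"])).show()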
python/pyspark/sql/group.py (33 changes: 28 additions & 5 deletions)
@@ -65,13 +65,27 @@ def __init__(self, jgd, df):
def agg(self, *exprs):
"""Compute aggregates and returns the result as a :class:`DataFrame`.

The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
The available aggregate functions can be:

1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

.. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
a full shuffle is required. Also, all the data of a group will be loaded into
memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.

.. seealso:: :func:`pyspark.sql.functions.pandas_udf`

If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
is the column to perform aggregation on, and the value is the aggregate function.

Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

+ .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
+ in a single call to this function.

:param exprs: a dict mapping from column name (string) to aggregate functions (string),
or a list of :class:`Column`.

@@ -82,6 +96,13 @@ def agg(self, *exprs):
>>> from pyspark.sql import functions as F
>>> sorted(gdf.agg(F.min(df.age)).collect())
[Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]

+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> @pandas_udf('int', PandasUDFType.GROUP_AGG)  # doctest: +SKIP
+ ... def min_udf(v):
+ ...     return v.min()
+ >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
+ [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)]
"""
assert exprs, "exprs should not be empty"
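Since built-in aggregations and group aggregate pandas UDFs cannot share a single `agg` call (per the note above), one workaround is to aggregate separately and join on the grouping key. A sketch, not part of this diff, reusing the `df` and `mean_udf` definitions from the examples above; `udf_mean` is an illustrative alias:

    udf_agg = df.groupby("id").agg(mean_udf(df["v"]).alias("udf_mean"))
    builtin_agg = df.groupby("id").count()  # built-in aggregation, in a separate call
    udf_agg.join(builtin_agg, "id").show()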
if len(exprs) == 1 and isinstance(exprs[0], dict):
Expand Down Expand Up @@ -204,16 +225,18 @@ def apply(self, udf):

The user-defined function should take a `pandas.DataFrame` and return another
`pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
- to the user-function and the returned `pandas.DataFrame`s are combined as a
+ to the user-function and the returned `pandas.DataFrame` are combined as a
:class:`DataFrame`.

The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.

- This function does not support partial aggregation, and requires shuffling all the data in
- the :class:`DataFrame`.
+ .. note:: This function requires a full shuffle. all the data of a group will be loaded
+ into memory, so the user should be aware of the potential OOM risk if data is skewed
+ and certain groups are too large to fit in memory.

:param udf: a group map user-defined function returned by
- :meth:`pyspark.sql.functions.pandas_udf`.
+ :func:`pyspark.sql.functions.pandas_udf`.

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
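For contrast with GROUP_AGG, the `apply` path documented above takes a GROUP_MAP function that returns a `pandas.DataFrame` per group rather than a scalar. A sketch, reusing the `id`/`v` DataFrame from the earlier examples:

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
    def subtract_mean(pdf):
        # pdf is a pandas.DataFrame holding every row of one group; the result
        # may have any number of rows but must match the declared schema.
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df.groupby("id").apply(subtract_mean).show()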