[SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf
## What changes were proposed in this pull request?

This PR adds an apply() function on df.groupby(). apply() takes a pandas udf that is a transformation from `pandas.DataFrame` to `pandas.DataFrame`.

Static schema
-------------------
```
schema = df.schema

@pandas_udf(schema)
def normalize(df):
    df = df.assign(v1=(df.v1 - df.v1.mean()) / df.v1.std())
    return df

df.groupBy('id').apply(normalize)
```
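
For reference, a self-contained version of the snippet above might look like the following. This is only a sketch: the input DataFrame, the `local[*]` session, and the column values are assumptions made for illustration, not part of the change itself.

```
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Example input with a grouping key and one value column.
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v1"))

# The output schema is the same as the input schema here.
@pandas_udf(returnType=df.schema)
def normalize(pdf):
    # pdf is a pandas.DataFrame holding all rows of one group.
    return pdf.assign(v1=(pdf.v1 - pdf.v1.mean()) / pdf.v1.std())

df.groupBy('id').apply(normalize).show()
```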
Dynamic schema
-----------------------
**This use case has been removed from the PR and will be discussed as a follow-up. See discussion #18732 (review).**

Another example uses the pd.DataFrame dtypes as the output schema of the udf:

```
sample_df = df.filter(df.id == 1).toPandas()

def foo(df):
    ret = ...  # Some transformation on the input pd.DataFrame
    return ret

foo_udf = pandas_udf(foo, foo(sample_df).dtypes)

df.groupBy('id').apply(foo_udf)
```
In the interactive use case, users usually have a sample pd.DataFrame to test the function `foo` in their notebook. Being able to use `foo(sample_df).dtypes` frees the user from specifying the output schema of `foo`.
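
Until dtype-based schema inference lands, a rough workaround under the static-schema API is to build the `StructType` from the sample output's dtypes by hand. The sketch below is illustrative only: `schema_from_pandas` is a hypothetical helper, the dtype-to-Spark-type mapping is deliberately simplified, and `df` and `foo` are the objects from the snippet above.

```
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (StructType, StructField,
                               LongType, DoubleType, StringType)

def schema_from_pandas(sample_pdf):
    """Best-effort mapping from pandas dtypes to a Spark schema (illustrative only)."""
    # Only a few common numpy dtype kinds are handled; everything else falls back to string.
    kind_map = {'i': LongType(), 'f': DoubleType(), 'O': StringType()}
    fields = [StructField(str(name), kind_map.get(dtype.kind, StringType()))
              for name, dtype in zip(sample_pdf.columns, sample_pdf.dtypes)]
    return StructType(fields)

sample_df = df.filter(df.id == 1).toPandas()
foo_udf = pandas_udf(foo, schema_from_pandas(foo(sample_df)))
df.groupBy('id').apply(foo_udf)
```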

Design doc: https://github.com/icexelloss/spark/blob/pandas-udf-doc/docs/pyspark-pandas-udf.md

## How was this patch tested?
* Added GroupbyApplyTest
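
For illustration, a minimal sketch of the kind of check such a test can make, comparing Spark's groupby().apply() against a plain pandas groupby-apply. This is hypothetical code, not the actual GroupbyApplyTest; the `spark` fixture, column names, and values are assumptions.

```
import pandas as pd
from pyspark.sql.functions import pandas_udf

def test_groupby_apply_matches_pandas(spark):
    pdf = pd.DataFrame({'id': [1, 1, 2, 2, 2],
                        'v': [1.0, 2.0, 3.0, 5.0, 10.0]})
    df = spark.createDataFrame(pdf)

    @pandas_udf(returnType=df.schema)
    def normalize(gdf):
        return gdf.assign(v=(gdf.v - gdf.v.mean()) / gdf.v.std())

    # Spark result, brought back to pandas and put into a canonical order.
    result = (df.groupby('id').apply(normalize)
                .toPandas().sort_values(['id', 'v']).reset_index(drop=True))
    # Reference result computed entirely in pandas.
    expected = (pdf.groupby('id', group_keys=False)
                   .apply(lambda g: g.assign(v=(g.v - g.v.mean()) / g.v.std()))
                   .sort_values(['id', 'v']).reset_index(drop=True))
    pd.testing.assert_frame_equal(result, expected)
```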

Author: Li Jin <ice.xelloss@gmail.com>
Author: Takuya UESHIN <ueshin@databricks.com>
Author: Bryan Cutler <cutlerb@gmail.com>

Closes #18732 from icexelloss/groupby-apply-SPARK-20396.
icexelloss authored and HyukjinKwon committed Oct 10, 2017
1 parent 2028e5a commit bfc7e1f
Showing 14 changed files with 561 additions and 69 deletions.
6 changes: 3 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -1227,7 +1227,7 @@ def groupBy(self, *cols):
"""
jgd = self._jdf.groupBy(self._jcols(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self.sql_ctx)
return GroupedData(jgd, self)

@since(1.4)
def rollup(self, *cols):
@@ -1248,7 +1248,7 @@ def rollup(self, *cols):
"""
jgd = self._jdf.rollup(self._jcols(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self.sql_ctx)
return GroupedData(jgd, self)

@since(1.4)
def cube(self, *cols):
@@ -1271,7 +1271,7 @@ def cube(self, *cols):
"""
jgd = self._jdf.cube(self._jcols(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self.sql_ctx)
return GroupedData(jgd, self)

@since(1.3)
def agg(self, *exprs):
98 changes: 72 additions & 26 deletions python/pyspark/sql/functions.py
@@ -2058,7 +2058,7 @@ def __init__(self, func, returnType, name=None, vectorized=False):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self._vectorized = vectorized
self.vectorized = vectorized

@property
def returnType(self):
@@ -2090,7 +2090,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self._vectorized)
self._name, wrapped_func, jdt, self.vectorized)
return judf

def __call__(self, *cols):
@@ -2118,8 +2118,10 @@ def wrapper(*args):
wrapper.__name__ = self._name
wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
else self.func.__class__.__module__)

wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.vectorized = self.vectorized

return wrapper

@@ -2129,8 +2131,12 @@ def _create_udf(f, returnType, vectorized):
def _udf(f, returnType=StringType(), vectorized=vectorized):
if vectorized:
import inspect
if len(inspect.getargspec(f).args) == 0:
raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
argspec = inspect.getargspec(f)
if len(argspec.args) == 0 and argspec.varargs is None:
raise ValueError(
"0-arg pandas_udfs are not supported. "
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
return udf_obj._wrapped()

@@ -2146,7 +2152,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized):

@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a :class:`Column` expression representing a user defined function (UDF).
"""Creates a user defined function (UDF).
.. note:: The user-defined functions must be deterministic. Due to optimization,
duplicate invocations may be eliminated or the function may even be invoked more times than
@@ -2181,30 +2187,70 @@ def udf(f=None, returnType=StringType()):
@since(2.3)
def pandas_udf(f=None, returnType=StringType()):
"""
Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
`Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.
Creates a vectorized user defined function (UDF).
:param f: python function if used as a standalone function
:param f: user-defined function. A python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object
>>> from pyspark.sql.types import IntegerType, StringType
>>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
>>> @pandas_udf(returnType=StringType())
... def to_upper(s):
... return s.str.upper()
...
>>> @pandas_udf(returnType="integer")
... def add_one(x):
... return x + 1
...
>>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
... .show() # doctest: +SKIP
+----------+--------------+------------+
|slen(name)|to_upper(name)|add_one(age)|
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
The user-defined function can define one of the following transformations:
1. One or more `pandas.Series` -> A `pandas.Series`
This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
The returnType should be a primitive data type, e.g., `DoubleType()`.
The length of the returned `pandas.Series` must be the same as the input `pandas.Series`.
>>> from pyspark.sql.types import IntegerType, StringType
>>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
>>> @pandas_udf(returnType=StringType())
... def to_upper(s):
... return s.str.upper()
...
>>> @pandas_udf(returnType="integer")
... def add_one(x):
... return x + 1
...
>>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
... .show() # doctest: +SKIP
+----------+--------------+------------+
|slen(name)|to_upper(name)|add_one(age)|
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
2. A `pandas.DataFrame` -> A `pandas.DataFrame`
This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
The returnType should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`.
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf(returnType=df.schema)
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby('id').apply(normalize).show() # doctest: +SKIP
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
.. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
because it defines a `DataFrame` transformation rather than a `Column`
transformation.
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
.. note:: The user-defined function must be deterministic.
"""
return _create_udf(f, returnType=returnType, vectorized=True)

88 changes: 84 additions & 4 deletions python/pyspark/sql/group.py
@@ -54,9 +54,10 @@ class GroupedData(object):
.. versionadded:: 1.3
"""

def __init__(self, jgd, sql_ctx):
def __init__(self, jgd, df):
self._jgd = jgd
self.sql_ctx = sql_ctx
self._df = df
self.sql_ctx = df.sql_ctx

@ignore_unicode_prefix
@since(1.3)
@@ -170,7 +171,7 @@ def sum(self, *cols):
@since(1.6)
def pivot(self, pivot_col, values=None):
"""
Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
Pivots a column of the current :class:`DataFrame` and perform the specified aggregation.
There are two versions of pivot function: one that requires the caller to specify the list
of distinct values to pivot on, and one that does not. The latter is more concise but less
efficient, because Spark needs to first compute the list of distinct values internally.
@@ -192,7 +193,85 @@ def pivot(self, pivot_col, values=None):
jgd = self._jgd.pivot(pivot_col)
else:
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self.sql_ctx)
return GroupedData(jgd, self._df)

@since(2.3)
def apply(self, udf):
"""
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
as a `DataFrame`.
The user-defined function should take a `pandas.DataFrame` and return another
`pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
to the user-function and the returned `pandas.DataFrame`s are combined as a
:class:`DataFrame`.
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.
This function does not support partial aggregation, and requires shuffling all the data in
the :class:`DataFrame`.
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
>>> from pyspark.sql.functions import pandas_udf
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf(returnType=df.schema)
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby('id').apply(normalize).show() # doctest: +SKIP
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
"""
from pyspark.sql.functions import pandas_udf

# Columns are special because hasattr always returns True
if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
raise ValueError("The argument to apply must be a pandas_udf")
if not isinstance(udf.returnType, StructType):
raise ValueError("The returnType of the pandas_udf must be a StructType")

df = self._df
func = udf.func
returnType = udf.returnType

# The python executors expect the function to use pd.Series as input and output,
# so we create a wrapper function that turns the input pd.Series into a pd.DataFrame
# before passing it to the user function, then turns the resulting pd.DataFrame back
# into pd.Series.
columns = df.columns

def wrapped(*cols):
from pyspark.sql.types import to_arrow_type
import pandas as pd
result = func(pd.concat(cols, axis=1, keys=columns))
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"Pandas.DataFrame, but is {}".format(type(result)))
if not len(result.columns) == len(returnType):
raise RuntimeError(
"Number of columns of the returned Pandas.DataFrame "
"doesn't match specified schema. "
"Expected: {} Actual: {}".format(len(returnType), len(result.columns)))
arrow_return_types = (to_arrow_type(field.dataType) for field in returnType)
return [(result[result.columns[i]], arrow_type)
for i, arrow_type in enumerate(arrow_return_types)]

wrapped_udf_obj = pandas_udf(wrapped, returnType)
udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)


def _test():
@@ -206,6 +285,7 @@ def _test():
.getOrCreate()
sc = spark.sparkContext
globs['sc'] = sc
globs['spark'] = spark
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
