Incomplete / Show issue with partial fn in pandas_udf
mstewart141 committed Mar 18, 2018
1 parent 5ec810a commit 969f907
Showing 3 changed files with 40 additions and 15 deletions.
6 changes: 6 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2155,6 +2155,9 @@ def udf(f=None, returnType=StringType()):
in boolean expressions; they end up being executed entirely internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
.. note:: The user-defined functions may take keyword arguments, e.g. ``(a=7)``, in Python 3,
but not in Python 2.
:param f: python function if used as a standalone function
:param returnType: the return type of the user-defined function. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
@@ -2338,6 +2341,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. note:: The user-defined functions do not support conditional expressions or short-circuiting
in boolean expressions; they end up being executed entirely internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
.. note:: The user-defined functions may take keyword arguments, e.g. ``(a=7)``, in Python 3,
but not in Python 2.
"""
# decorator @pandas_udf(returnType, functionType)
is_decorator = f is None or isinstance(f, (str, DataType))
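A minimal sketch of what the new note describes, assuming a Python 3 session and a DataFrame df with an integer column x (df, x, and add_seven are hypothetical names, not part of this commit):

    from pyspark.sql.functions import udf, col

    @udf('integer')
    def add_seven(a):
        return a + 7

    # Python 3: the wrapped function accepts keyword arguments at call time.
    # Under Python 2 the same call raises, since keyword arguments are not
    # supported there.
    df.select(add_seven(a=col('x')))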
47 changes: 33 additions & 14 deletions python/pyspark/sql/udf.py
@@ -52,7 +52,8 @@ def _create_udf(f, returnType, evalType):
    argspec = _get_argspec(f)

    if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
            argspec.varargs is None:
            argspec.varargs is None and \
            not (sys.version_info[0] > 2 and len(argspec.kwonlyargs) > 0):
        raise ValueError(
            "Invalid function: 0-arg pandas_udfs are not supported. "
            "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
@@ -167,19 +168,37 @@ def _create_judf(self):
        return judf

    def __call__(self, *cols, **kwcols):
        # Handle keyword arguments
        required = _get_argspec(self.func).args
        if len(cols) < len(required):
            # Extract remaining required arguments (from kwcols) in proper order
            # Ensure no duplicate or unused arguments were passed
            cols = tuple(itertools.chain.from_iterable(
                [cols, (kwcols.pop(c) for c in required[len(cols):])]))
            kwargs_remaining = list(kwcols.keys())
            if kwargs_remaining:
                raise TypeError(self._name + "() "
                                + "got unexpected (or duplicated) keyword arguments: "
                                + str(kwargs_remaining))

        # Handle keyword arguments for python3
        if sys.version_info[0] > 2:
            spec = _get_argspec(self.func)
            required = spec.args + spec.kwonlyargs
            defaults = spec.kwonlydefaults or {}
            if len(cols) < len(required):

                def _normalize_args(cols_, kwcols_):
                    """
                    Extract remaining required arguments (from kwcols) in proper order.
                    Ensure no duplicate or unused arguments were passed.
                    """
                    updated_cols = tuple(itertools.chain.from_iterable(
                        [cols_,
                         (kwcols_.pop(c) for c in required[len(cols_):] if c not in defaults)]))
                    kwargs_remaining = list(set(kwcols_.keys()) - set(defaults.keys()))
                    if kwargs_remaining:
                        raise TypeError(self._name + "() "
                                        + "got unexpected (or duplicated) keyword arguments: "
                                        + str(kwargs_remaining))
                    return updated_cols

                # Pass a copy so the pops inside _normalize_args cannot
                # mutate the caller-visible keyword arguments.
                cols = _normalize_args(cols, dict(kwcols))

        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
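A usage sketch of what the reworked __call__ aims to permit, assuming a Python 3 session and a DataFrame df with columns v and w (df and weighted are hypothetical names):

    from pyspark.sql.functions import pandas_udf, col

    @pandas_udf('double')
    def weighted(v, w):
        return v * w

    # Equivalent calls: keyword columns are popped out of kwcols and
    # appended to cols in the function's declared argument order.
    df.select(weighted(col('v'), col('w')))
    df.select(weighted(v=col('v'), w=col('w')))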
2 changes: 1 addition & 1 deletion python/pyspark/worker.py
@@ -79,7 +79,7 @@ def wrap_scalar_pandas_udf(f, return_type):
    arrow_return_type = to_arrow_type(return_type)

    def verify_result_length(*a):
        result = f(*a)
        result = f(*a)  # NOTE: f may be a functools.partial with some args already
                        # bound; f(*a) has no notion of this, so partial fns blow up here
        if not hasattr(result, "__len__"):
            raise TypeError("Return type of the user-defined function should be "
                            "Pandas.Series, but is {}".format(type(result)))
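A plain-Python sketch of the failure mode flagged in the comment above (scale and doubled are hypothetical names): when some arguments are already bound in a functools.partial, forwarding every column positionally collides with the bound ones.

    import functools

    def scale(v, factor):
        return v * factor

    doubled = functools.partial(scale, factor=2)

    print(doubled(10))  # 20 -- the partial needs only one positional arg

    # A wrapper that forwards an extra positional value, as f(*a) does when
    # the arg count was derived from the unbound signature:
    try:
        doubled(10, 3)
    except TypeError as e:
        print(e)  # scale() got multiple values for argument 'factor'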
