Skip to content

Commit

Permalink
[SPARK-23645][PYTHON] Allow python udfs to be called with keyword arg…
Browse files Browse the repository at this point in the history
…uments
  • Loading branch information
mstewart141 committed Mar 11, 2018
1 parent b6f837c commit 5ec810a
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions python/pyspark/sql/udf.py
Expand Up @@ -20,6 +20,7 @@
import sys
import inspect
import functools
import itertools
import sys

from pyspark import SparkContext, since
Expand Down Expand Up @@ -165,7 +166,20 @@ def _create_judf(self):
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
return judf

def __call__(self, *cols):
def __call__(self, *cols, **kwcols):
    """Apply this UDF to the given columns and return the resulting Column.

    Columns may be passed positionally, by keyword (using the parameter
    names of the wrapped Python function), or as a mix of both. Keyword
    columns are re-ordered into the wrapped function's positional order,
    because the JVM UDF is applied to a plain positional sequence.

    :param cols: columns passed positionally.
    :param kwcols: columns passed by the wrapped function's parameter names.
    :return: the Column produced by applying the JVM UDF to the columns.
    :raises TypeError: if a parameter without a default is not supplied,
        if a keyword follows an omitted parameter (the gap cannot be
        expressed positionally), or if a keyword is unknown or duplicates
        a positionally supplied column.
    """
    argspec = _get_argspec(self.func)
    params = argspec.args
    # NOTE(review): assumes _get_argspec returns an inspect-style argspec
    # whose ``defaults`` line up with the tail of ``args`` -- confirm. If the
    # attribute is absent, every parameter is treated as required.
    n_optional = len(getattr(argspec, "defaults", None) or ())
    first_optional = len(params) - n_optional

    by_keyword = []
    skipped = None  # first omitted (defaulted) parameter, if any
    for index in range(len(cols), len(params)):
        name = params[index]
        if name in kwcols:
            if skipped is not None:
                # Columns are sent positionally, so a later parameter cannot
                # be filled while an earlier one is left to its default.
                raise TypeError(self._name + "() "
                                + "cannot accept keyword argument '" + name
                                + "' while preceding argument '" + skipped
                                + "' is omitted")
            by_keyword.append(kwcols.pop(name))
        elif index < first_optional:
            # Report a missing argument the way Python itself does, instead
            # of leaking a KeyError from the keyword lookup.
            raise TypeError(self._name + "() "
                            + "missing required argument: '" + name + "'")
        elif skipped is None:
            skipped = name

    # Anything left was either not a parameter name or was already supplied
    # positionally; check this on every call, not only when keywords were
    # needed to fill positions.
    kwargs_remaining = list(kwcols.keys())
    if kwargs_remaining:
        raise TypeError(self._name + "() "
                        + "got unexpected (or duplicated) keyword arguments: "
                        + str(kwargs_remaining))

    cols = tuple(cols) + tuple(by_keyword)

    judf = self._judf
    sc = SparkContext._active_spark_context
    return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
Expand All @@ -187,8 +201,8 @@ def _wrapped(self):
a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')

@functools.wraps(self.func, assigned=assignments)
def wrapper(*args):
return self(*args)
def wrapper(*args, **kwargs):
return self(*args, **kwargs)

wrapper.__name__ = self._name
wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
Expand Down

0 comments on commit 5ec810a

Please sign in to comment.