From 5ec810a7c36691df1877ffc11e6f06392d438485 Mon Sep 17 00:00:00 2001
From: "Michael (Stu) Stewart"
Date: Sun, 11 Mar 2018 13:38:29 -0700
Subject: [PATCH] [SPARK-23645][PYTHON] Allow python udfs to be called with
 keyword arguments

---
 python/pyspark/sql/udf.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 24dd06c26089c..4ea78fd97e4d5 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -20,6 +20,7 @@
 import sys
 import inspect
 import functools
+import itertools
 import sys
 
 from pyspark import SparkContext, since
@@ -165,7 +166,20 @@ def _create_judf(self):
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)
         return judf
 
-    def __call__(self, *cols):
+    def __call__(self, *cols, **kwcols):
+        # Handle keyword arguments
+        required = _get_argspec(self.func).args
+        if len(cols) < len(required):
+            # Extract remaining required arguments (from kwcols) in proper order
+            # Ensure no duplicate or unused arguments were passed
+            cols = tuple(itertools.chain.from_iterable(
+                [cols, (kwcols.pop(c) for c in required[len(cols):])]))
+        kwargs_remaining = list(kwcols.keys())
+        if kwargs_remaining:
+            raise TypeError(self._name + "() " +
+                            "got unexpected (or duplicated) keyword arguments: " +
+                            str(kwargs_remaining))
+
         judf = self._judf
         sc = SparkContext._active_spark_context
         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
@@ -187,8 +201,8 @@ def _wrapped(self):
             a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
 
         @functools.wraps(self.func, assigned=assignments)
-        def wrapper(*args):
-            return self(*args)
+        def wrapper(*args, **kwargs):
+            return self(*args, **kwargs)
 
         wrapper.__name__ = self._name
         wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')