refactor _get_argspec
icexelloss committed Mar 6, 2018
1 parent 722ed50 · commit c74ed05
Showing 2 changed files with 18 additions and 11 deletions.
25 changes: 16 additions & 9 deletions python/pyspark/sql/udf.py
@@ -17,6 +17,8 @@
"""
User-defined function related classes and functions
"""
import sys
import inspect
import functools

from pyspark import SparkContext, since
@@ -35,24 +37,29 @@ def _wrap_function(sc, func, returnType):
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)
 
 
+def _get_argspec(f):
+    """
+    Get argspec of a function.
+    """
+    # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+    # See SPARK-23569.
+    if sys.version_info[0] < 3:
+        argspec = inspect.getargspec(f)
+    else:
+        argspec = inspect.getfullargspec(f)
+    return argspec
+
+
 def _create_udf(f, returnType, evalType):
 
     if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
 
-        import inspect
-        import sys
         from pyspark.sql.utils import require_minimum_pyarrow_version
 
         require_minimum_pyarrow_version()
 
-        if sys.version_info[0] < 3:
-            # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-            # See SPARK-23569.
-            argspec = inspect.getargspec(f)
-        else:
-            argspec = inspect.getfullargspec(f)
+        argspec = _get_argspec(f)
 
         if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
                 argspec.varargs is None:
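For context: `inspect.getargspec` was deprecated in Python 3.0 and raises `ValueError` for functions that use annotations, which is why the helper falls back to `inspect.getfullargspec` on Python 3 (SPARK-23569). Below is a minimal sketch of the helper in use; `pandas_mean` is a hypothetical example function, not part of this commit:

import inspect
import sys


def _get_argspec(f):
    # Same branching as the helper above: Python 2 lacks getfullargspec,
    # while getargspec is deprecated (and rejects annotations) on Python 3.
    if sys.version_info[0] < 3:
        return inspect.getargspec(f)
    return inspect.getfullargspec(f)


def pandas_mean(v):  # hypothetical UDF body
    return v.mean()


argspec = _get_argspec(pandas_mean)
print(argspec.args)     # ['v']
print(argspec.varargs)  # None; _create_udf checks both fields for zero-arg UDFs

Both `ArgSpec` (Python 2) and `FullArgSpec` (Python 3) expose `args` and `varargs`, so callers can stay version-agnostic.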
4 changes: 2 additions & 2 deletions python/pyspark/worker.py
@@ -19,7 +19,6 @@
 Worker that receives input from Piped RDD.
 """
 from __future__ import print_function
-import inspect
 import os
 import sys
 import time
@@ -35,6 +34,7 @@
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
+from pyspark.sql.udf import _get_argspec
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -94,7 +94,7 @@ def verify_result_length(*a):
 def wrap_grouped_map_pandas_udf(f, return_type):
     def wrapped(key_series, value_series):
         import pandas as pd
-        argspec = inspect.getargspec(f)
+        argspec = _get_argspec(f)
 
         if len(argspec.args) == 1:
             result = f(pd.concat(value_series, axis=1))
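worker.py now reuses the shared helper, so the arity check for grouped map pandas UDFs behaves identically on Python 2 and 3. A hedged sketch of the dispatch that check supports; `apply_grouped_map` is a made-up name, and the two-argument `f(key, pdf)` branch is an inference from the `len(argspec.args) == 1` test rather than something shown in this hunk:

import pandas as pd

from pyspark.sql.udf import _get_argspec  # the helper introduced above


def apply_grouped_map(f, key, value_series):
    # Sketch only: concatenate the group's column series into one DataFrame,
    # then dispatch on the UDF's declared arity.
    pdf = pd.concat(value_series, axis=1)
    argspec = _get_argspec(f)
    if len(argspec.args) == 1:
        return f(pdf)       # matches the branch visible in the hunk
    return f(key, pdf)      # assumed two-argument form taking the group key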
