39 changes: 35 additions & 4 deletions python/pyspark/serializers.py
@@ -64,6 +64,7 @@
from itertools import izip as zip, imap as map
else:
import pickle
basestring = unicode = str
Member:
@HyukjinKwon I thought we had other places for this kind of thing (or is it in your new PR for cloudpickle)?

Member:
Yes, there are some places that use this here and there. IIRC, we discussed the Python 2 drop on the dev mailing list. I can get rid of it soon anyway.

Member Author:
Yeah, this and below are just for Python 2 support. Are we dropping that for Spark 3.0?

Member:
I don't think we'll drop it in Spark 3.0. I will cc you on the related PRs later.

xrange = range
pickle_protocol = pickle.HIGHEST_PROTOCOL

@@ -244,7 +245,7 @@ def __repr__(self):
return "ArrowStreamSerializer"


def _create_batch(series, timezone, safecheck):
def _create_batch(series, timezone, safecheck, assign_cols_by_name):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.

@@ -254,6 +255,7 @@ def _create_batch(series, timezone, safecheck):
"""
import decimal
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
@@ -295,7 +297,34 @@ def create_array(s, t):
raise RuntimeError(error_msg % (s.dtype, t), e)
return array

arrs = [create_array(s, t) for s, t in series]
arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
Member:
Looks good @BryanCutler, and cc @ueshin FYI. Just out of curiosity, WDYT about putting these PySpark-specific conversion logics together somewhere, in a separate PR and JIRA of course? It looks like it's getting difficult to read (to me at least).

raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
Member Author:
I came across a case where there was an empty partition, and when the UDF processed it the result was an empty Pandas DataFrame with no columns; see https://github.com/apache/spark/pull/23900/files#diff-d1bd0bd4ceeedd30cc219293a75ad90fR395

I figured it would be pretty confusing for the user to handle these kinds of cases, and it's pretty simple to just check and add an empty struct when this happens, so that's what this check is for.
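As a purely illustrative sketch (not part of this diff; field names are made up and it assumes pandas and pyarrow >= 0.9.0 are installed), the empty-result case described above looks roughly like this:

import pandas as pd
import pyarrow as pa

struct_type = pa.struct([pa.field("first", pa.string()), pa.field("last", pa.string())])
result = pd.DataFrame()  # what the UDF returns for an empty partition: no rows, no columns

# one empty array per declared field keeps the schema intact for the empty batch
empty_arrays = [pa.array([], type=field.type) for field in struct_type]
names = [field.name for field in struct_type]
print(pa.StructArray.from_arrays(empty_arrays, names))  # empty StructArray with the expected fields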

arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)

# TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
else:
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
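For reference, a rough standalone sketch of the by-name vs. by-position assignment implemented in the new branch above (not part of the diff; column and field names are made up):

import pandas as pd

fields = ["first", "last"]  # field names from the declared StructType return type
result = pd.DataFrame({"last": ["Doe"], "first": ["John"]})  # UDF output with columns out of order

# assign_cols_by_name is true and columns are labeled with strings: match by field name
by_name = [result[name].tolist() for name in fields]  # [['John'], ['Doe']]

# otherwise: match purely by position, ignoring the labels
by_position = [result[result.columns[i]].tolist() for i, _ in enumerate(fields)]  # [['Doe'], ['John']]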


@@ -304,10 +333,11 @@ class ArrowStreamPandasSerializer(Serializer):
Serializes Pandas.Series as Arrow data with Arrow streaming format.
"""

def __init__(self, timezone, safecheck):
def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import from_arrow_type, \
@@ -326,7 +356,8 @@ def dump_stream(self, iterator, stream):
writer = None
try:
for series in iterator:
batch = _create_batch(series, self._timezone, self._safecheck)
batch = _create_batch(series, self._timezone, self._safecheck,
self._assign_cols_by_name)
if writer is None:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
12 changes: 11 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2842,8 +2842,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):

A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.

:class:`MapType`, :class:`StructType` are currently not supported as output types.
:class:`MapType`, nested :class:`StructType` are currently not supported as output types.

Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
@@ -2868,6 +2869,15 @@ def pandas_udf(f=None, returnType=None, functionType=None):
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
>>> @pandas_udf("first string, last string") # doctest: +SKIP
Member:
This is a nice feature!

I am wondering if we have a better way to handle incorrect inputs. For example, if our end users specify incorrect return data types (e.g. @pandas_udf("first string, last int")), do we issue a user-friendly error message?

19/03/21 00:04:08 ERROR Executor: Exception in task 0.0 in stage 8.0 (TID 8)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/xiaoli/IdeaProjects/sparkDelivery/python/lib/pyspark.zip/pyspark/worker.py", line 433, in main
    process()
  File "/Users/xiaoli/IdeaProjects/sparkDelivery/python/lib/pyspark.zip/pyspark/worker.py", line 428, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/Users/xiaoli/IdeaProjects/sparkDelivery/python/lib/pyspark.zip/pyspark/serializers.py", line 360, in dump_stream
    self._assign_cols_by_name)
  File "/Users/xiaoli/IdeaProjects/sparkDelivery/python/lib/pyspark.zip/pyspark/serializers.py", line 316, in _create_batch
    for i, field in enumerate(t)]
  File "/Users/xiaoli/IdeaProjects/sparkDelivery/python/lib/pyspark.zip/pyspark/serializers.py", line 287, in create_array
    return pa.Array.from_pandas(s, mask=mask, type=t)
  File "pyarrow/array.pxi", line 335, in pyarrow.lib.Array.from_pandas (/Users/travis/build/BryanCutler/arrow-dist/arrow/python/build/temp.macosx-10.6-intel-2.7/lib.cxx:30884)
    return array(obj, mask=mask, type=type, memory_pool=memory_pool,
  File "pyarrow/array.pxi", line 170, in pyarrow.lib.array (/Users/travis/build/BryanCutler/arrow-dist/arrow/python/build/temp.macosx-10.6-intel-2.7/lib.cxx:29224)
    return _ndarray_to_array(values, mask, type, from_pandas, pool)
  File "pyarrow/array.pxi", line 70, in pyarrow.lib._ndarray_to_array (/Users/travis/build/BryanCutler/arrow-dist/arrow/python/build/temp.macosx-10.6-intel-2.7/lib.cxx:28465)
    check_status(NdarrayToArrow(pool, values, mask,
  File "pyarrow/error.pxi", line 85, in pyarrow.lib.check_status (/Users/travis/build/BryanCutler/arrow-dist/arrow/python/build/temp.macosx-10.6-intel-2.7/lib.cxx:8570)
    raise ArrowNotImplementedError(message)
ArrowNotImplementedError: No cast implemented from string to int32

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:453)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.read(ArrowPythonRunner.scala:172)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:102)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:100)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:126)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:817)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:817)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:291)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:291)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:291)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:291)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:428)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1341)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:431)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	at java.lang.Thread.run(Thread.java:745)
19/03/21 00:04:08 WARN TaskSetManager: Lost task 0.0 in stage 8.0 (TID 8, localhost, executor driver): org.apache.spark.api.python.PythonException (same traceback as above)

Member:
Also, is it possible the mismatched types could generate wrong results?

Member (@HyukjinKwon, Mar 22, 2019):
Yes, something has to be done. At the very least, I tried to document the casting combinations.

Pandas UDF matrix:

# The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
# future. The table might have to be eventually documented externally.
# Please see SPARK-25798's PR to see the codes in order to generate the table below.
#
# +-----------------------------+----------------------+----------+-------+--------+--------------------+--------------------+--------+---------+---------+---------+------------+------------+------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+-------------+-----------------+------------------+-----------+--------------------------------+ # noqa
# |SQL Type \ Pandas Value(Type)|None(object(NoneType))|True(bool)|1(int8)|1(int16)| 1(int32)| 1(int64)|1(uint8)|1(uint16)|1(uint32)|1(uint64)|1.0(float16)|1.0(float32)|1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))|1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa
# +-----------------------------+----------------------+----------+-------+--------+--------------------+--------------------+--------+---------+---------+---------+------------+------------+------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+-------------+-----------------+------------------+-----------+--------------------------------+ # noqa
# | boolean| None| True| True| True| True| True| True| True| True| True| False| False| False| False| False| X| X| X| False| False| False| X| False| # noqa
# | tinyint| None| 1| 1| 1| 1| 1| X| X| X| X| 1| 1| 1| X| X| X| X| X| X| X| X| 0| X| # noqa
# | smallint| None| 1| 1| 1| 1| 1| 1| X| X| X| 1| 1| 1| X| X| X| X| X| X| X| X| X| X| # noqa
# | int| None| 1| 1| 1| 1| 1| 1| 1| X| X| 1| 1| 1| X| X| X| X| X| X| X| X| X| X| # noqa
# | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| X| 1| 1| 1| 0| 18000000000000| X| X| X| X| X| X| X| X| # noqa
# | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X|1.401298464324817...| X| X| X| X| X| X| # noqa
# | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
# | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| X| X| X| X| X| X| X| X| X| # noqa
# | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X| X| X| X| X| X| X| X| # noqa
# | string| None| u''|u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u''| u''| u''| X| X| u'a'| X| X| u''| u''| u''| X| X| # noqa
# | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa
# | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa
# | map<string,int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# | binary| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# +-----------------------------+----------------------+----------+-------+--------+--------------------+--------------------+--------+---------+---------+---------+------------+------------+------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+-------------+-----------------+------------------+-----------+--------------------------------+ # noqa
#
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
# used in `returnType`.
# Note: The values inside of the table are generated by `repr`.
# Note: Python 2 is used to generate this table since it is used to check the backward
# compatibility often in practice.
# Note: Pandas 0.19.2 and PyArrow 0.9.0 are used.
# Note: Timezone is Singapore timezone.
# Note: 'X' means it throws an exception during the conversion.
# Note: 'binary' type is only supported with PyArrow 0.10.0+ (SPARK-23555).

The problem is that this matrix is different from the regular PySpark UDF matrix, and also from our TypeCoercions:

Regular PySpark UDF matrix:

# The following table shows most of Python data and SQL type conversions in normal UDFs that
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
# future. The table might have to be eventually documented externally.
# Please see SPARK-25666's PR to see the codes in order to generate the table below.
#
# +-----------------------------+--------------+----------+------+-------+---------------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+-----------------+------------+--------------+------------------+----------------------+ # noqa
# |SQL Type \ Python Value(Type)|None(NoneType)|True(bool)|1(int)|1(long)| a(str)| a(unicode)| 1970-01-01(date)|1970-01-01 00:00:00(datetime)|1.0(float)|array('i', [1])(array)|[1](list)| (1,)(tuple)| ABC(bytearray)| 1(Decimal)|{'a': 1}(dict)|Row(kwargs=1)(Row)|Row(namedtuple=1)(Row)| # noqa
# +-----------------------------+--------------+----------+------+-------+---------------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+-----------------+------------+--------------+------------------+----------------------+ # noqa
# | boolean| None| True| None| None| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa
# | tinyint| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa
# | smallint| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa
# | int| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa
# | bigint| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa
# | string| None| u'true'| u'1'| u'1'| u'a'| u'a'|u'java.util.Grego...| u'java.util.Grego...| u'1.0'| u'[I@24a83055'| u'[1]'|u'[Ljava.lang.Obj...| u'[B@49093632'| u'1'| u'{a=1}'| X| X| # noqa
# | date| None| X| X| X| X| X|datetime.date(197...| datetime.date(197...| X| X| X| X| X| X| X| X| X| # noqa
# | timestamp| None| X| X| X| X| X| X| datetime.datetime...| X| X| X| X| X| X| X| X| X| # noqa
# | float| None| None| None| None| None| None| None| None| 1.0| None| None| None| None| None| None| X| X| # noqa
# | double| None| None| None| None| None| None| None| None| 1.0| None| None| None| None| None| None| X| X| # noqa
# | array<int>| None| None| None| None| None| None| None| None| None| [1]| [1]| [1]| [65, 66, 67]| None| None| X| X| # noqa
# | binary| None| None| None| None|bytearray(b'a')|bytearray(b'a')| None| None| None| None| None| None|bytearray(b'ABC')| None| None| X| X| # noqa
# | decimal(10,0)| None| None| None| None| None| None| None| None| None| None| None| None| None|Decimal('1')| None| X| X| # noqa
# | map<string,int>| None| None| None| None| None| None| None| None| None| None| None| None| None| None| {u'a': 1}| X| X| # noqa
# | struct<_1:int>| None| X| X| X| X| X| X| X| X| X|Row(_1=1)| Row(_1=1)| X| X| Row(_1=None)| Row(_1=1)| Row(_1=1)| # noqa
# +-----------------------------+--------------+----------+------+-------+---------------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+-----------------+------------+--------------+------------------+----------------------+ # noqa
#
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
# used in `returnType`.
# Note: The values inside of the table are generated by `repr`.
# Note: Python 2 is used to generate this table since it is used to check the backward
# compatibility often in practice.
# Note: 'X' means it throws an exception during the conversion.

I lost track of the last discussion about whether we should allow such type coercions or not. But basically my gut says: if we allow it, I think it will need a huge amount of code to maintain again (Arrow type <> Pandas type <> Python type <> Spark SQL type), but if we disallow it, it will break many existing apps.

One way is to explicitly document that Pandas's type coercion depends on Arrow (unlike the regular PySpark UDF), and to throw an explicit exception.

Member:
And yes, I think we should also throw better exceptions at the very least.

Member Author:
Sometimes type casting is necessary. For example, if your integer column has null values then Pandas will automatically upcast to floating point to represent the nulls as NaNs. If that column is returned, it doesn't make sense to keep it as floating point because Spark can handle the null values, so using an integer return type will cause type casting but won't cause any problems.
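A small illustration of that upcast-and-cast-back behavior (a sketch only, assuming pandas and pyarrow are installed; not part of this diff):

import pandas as pd
import pyarrow as pa

s = pd.Series([1, None, 3])
print(s.dtype)  # float64: the null forces an upcast, with NaN standing in for the null

# casting back to the declared integer type keeps the null and loses nothing
arr = pa.Array.from_pandas(s, type=pa.int64())
print(arr)  # [1, null, 3] as int64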

... def split_expand(n):
... return n.str.split(expand=True)
>>> df.select(split_expand("name")).show() # doctest: +SKIP
+------------------+
|split_expand(name)|
+------------------+
| [John, Doe]|
+------------------+

.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
column, but is the length of an internal batch used for each call to the function.
3 changes: 2 additions & 1 deletion python/pyspark/sql/session.py
@@ -557,8 +557,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

# Create Arrow record batches
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
Member:
Meaning the value of col_by_name doesn't matter here, right?

Member Author:
Yes, that's right. This will be removed when we take out that conf.

batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
timezone, safecheck)
timezone, safecheck, col_by_name)
for pdf_slice in pdf_slices]

# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -273,6 +273,7 @@ def test_unsupported_types(self):
StructField('map', MapType(StringType(), IntegerType())),
StructField('arr_ts', ArrayType(TimestampType())),
StructField('null', NullType()),
StructField('struct', StructType([StructField('l', LongType())])),
]

# TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
81 changes: 80 additions & 1 deletion python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -23,13 +23,16 @@
import time
import unittest

if sys.version >= '3':
unicode = str
Member:
ditto


from datetime import date, datetime
from decimal import Decimal
from distutils.version import LooseVersion

from pyspark.rdd import PythonEvalType
from pyspark.sql import Column
from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
@@ -265,6 +268,64 @@ def test_vectorized_udf_null_array(self):
result = df.select(array_f(col('array')))
self.assertEquals(df.collect(), result.collect())

def test_vectorized_udf_struct_type(self):
import pandas as pd

df = self.spark.range(10)
return_type = StructType([
StructField('id', LongType()),
StructField('str', StringType())])

def func(id):
return pd.DataFrame({'id': id, 'str': id.apply(unicode)})

f = pandas_udf(func, returnType=return_type)

expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
.alias('struct')).collect()

actual = df.select(f(col('id')).alias('struct')).collect()
self.assertEqual(expected, actual)

g = pandas_udf(func, 'id: long, str: string')
actual = df.select(g(col('id')).alias('struct')).collect()
self.assertEqual(expected, actual)

def test_vectorized_udf_struct_complex(self):
import pandas as pd

df = self.spark.range(10)
return_type = StructType([
StructField('ts', TimestampType()),
StructField('arr', ArrayType(LongType()))])

@pandas_udf(returnType=return_type)
def f(id):
return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
'arr': id.apply(lambda i: [i, i + 1])})

actual = df.withColumn('f', f(col('id'))).collect()
for i, row in enumerate(actual):
id, f = row
self.assertEqual(i, id)
self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
self.assertListEqual([i, i + 1], f[1])

def test_vectorized_udf_nested_struct(self):
nested_type = StructType([
StructField('id', IntegerType()),
StructField('nested', StructType([
StructField('foo', StringType()),
StructField('bar', FloatType())
]))
])

with QuietTest(self.sc):
with self.assertRaisesRegexp(
Exception,
'Invalid returnType with scalar Pandas UDFs'):
pandas_udf(lambda x: x, returnType=nested_type)

def test_vectorized_udf_complex(self):
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
@@ -331,6 +392,20 @@ def test_vectorized_udf_empty_partition(self):
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_struct_with_empty_partition(self):
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
.withColumn('name', lit('John Doe'))

@pandas_udf("first string, last string")
def split_expand(n):
return n.str.split(expand=True)

result = df.select(split_expand('name')).collect()
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual('John', row[0]['first'])
self.assertEqual('Doe', row[0]['last'])

def test_vectorized_udf_varargs(self):
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
f = pandas_udf(lambda *v: v[0], LongType())
@@ -343,6 +418,10 @@ def test_vectorized_udf_unsupported_types(self):
NotImplementedError,
'Invalid returnType.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))

def test_vectorized_udf_dates(self):
schema = StructType().add("idx", LongType()).add("date", DateType())
8 changes: 7 additions & 1 deletion python/pyspark/sql/types.py
@@ -1613,9 +1613,15 @@ def to_arrow_type(dt):
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == ArrayType:
if type(dt.elementType) == TimestampType:
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
elif type(dt) == StructType:
if any(type(field.dataType) == StructType for field in dt):
raise TypeError("Nested StructType not supported in conversion to Arrow")
Member:
Is ArrayType(elementType = StructType) supported?

Member Author:
Good catch, that should not be supported right now. I added a check and put that type in a test.

Contributor:
I am curious why ArrayType(elementType = StructType) support was removed from here? @BryanCutler

Member Author:
@cfmcgrady support wasn't removed; ArrayType(elementType = StructType) was never allowed, and I don't think there was an explicit check before this. It might be possible to add it in the future, but it's a little tricky to represent efficiently in Pandas, I believe.
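To illustrate why (a hypothetical sketch, not from this PR): an array<struct> column on the Pandas side would have to be a Series of Python lists of dicts, which has no efficient columnar layout:

import pandas as pd

s = pd.Series([[{"a": 1, "b": "x"}, {"a": 2, "b": "y"}], []])
print(s.dtype)  # object: every element is a plain Python list of dicts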

Contributor:
Thank you for your reply.

fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in dt]
arrow_type = pa.struct(fields)
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type
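As a quick sanity check of the new StructType branch above (a sketch that assumes this patch is applied and pyarrow is installed):

from pyspark.sql.types import StructType, StructField, LongType, StringType, to_arrow_type

spark_type = StructType([StructField("id", LongType()), StructField("name", StringType())])
print(to_arrow_type(spark_type))  # an Arrow struct type with int64 and string fields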
5 changes: 4 additions & 1 deletion python/pyspark/sql/udf.py
@@ -123,7 +123,7 @@ def returnType(self):
elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
try:
to_arrow_schema(self._returnType_placeholder)
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with grouped map Pandas UDFs: "
@@ -133,6 +133,9 @@
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
# StructType is not yet allowed as a return type, explicitly check here to fail fast
if isinstance(self._returnType_placeholder, StructType):
raise TypeError
Member:
Hm, @BryanCutler, sorry if I missed something but why do we throw a type error here?

Member Author:
Grouped Agg UDFs don't allow a StructType return yet, and previously relied on the call to to_arrow_type to raise an error. Since that no longer happens, we need to raise it here.

Member:
I see. Can you add some message while we're here? If this is going to be fixed soon, I am okay as is as well.

Member Author:
Sure. I will try it as a followup, but a message for now will be good. I just noticed that grouped map wasn't catching a nested struct type, so I need to fix that anyway.

Member:
Yea, either way sounds good to me. I'll leave it to you.

to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
12 changes: 9 additions & 3 deletions python/pyspark/worker.py
@@ -39,7 +39,7 @@
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle

@@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type):
def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
raise TypeError("Return type of the user-defined function should be "
"Pandas.Series, but is {}".format(type(result)))
"{}, but is {}".format(pd_type, type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
"expected %d, got %d" % (len(a[0]), len(result)))
@@ -254,7 +255,12 @@ def read_udfs(pickleSer, infile, eval_type):
timezone = runner_conf.get("spark.sql.session.timeZone", None)
safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
"false").lower() == 'true'
ser = ArrowStreamPandasSerializer(timezone, safecheck)
# NOTE: this is duplicated from wrap_grouped_map_pandas_udf
assign_cols_by_name = runner_conf.get(
Member:
@BryanCutler, BTW, would you be willing to work on removing spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName in Spark 3.0? I think it makes the code complicated.

Also, would you mind working on upgrading the minimum Arrow version to 0.12.0 as well, as we discussed? (It should probably be asked on the dev mailing list first, to be 100% sure.)

If you're currently busy, I will take one or both.

Member:
Of course those should be separate JIRAs.

Member Author:
Yeah, definitely! I can take those two tasks. I was thinking of holding off a little while on bumping the minimum Arrow version, just to see if anything major comes up in the interim releases. 0.12.1 will be out in a couple of days, but I don't think it has major bug fixes for us. Maybe wait just a little bit longer?

"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
Member Author:
The groupedMap part of the conf doesn't really make sense here, but I'm not sure if it's worth changing.

Member:
Based on the above comment https://github.com/apache/spark/pull/23900/files#r260874304, if we are going to remove spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName, do we need to use this config here?

Member Author:
I left it in to be consistent. I'd rather remove both of them in a separate PR in case there is some discussion about it.

.lower() == "true"

ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
else:
ser = BatchedSerializer(PickleSerializer(), 100)
