[SPARK-21190][PYSPARK] Python Vectorized UDFs #18659

22 changes: 17 additions & 5 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -83,10 +83,23 @@ private[spark] case class PythonFunction(
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

/**
* Enumerate the type of command that will be sent to the Python worker
*/
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 1
val SQL_PANDAS_UDF = 2
}

private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
Seq(ChainedPythonFunctions(Seq(func))),
bufferSize,
reuse_worker,
PythonEvalType.NON_UDF,
Array(Array(0)))
}
}

@@ -100,7 +113,7 @@ private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
isUDF: Boolean,
evalType: Int,
argOffsets: Array[Array[Int]])
extends Logging {

@@ -309,8 +322,8 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(evalType)
if (evalType != PythonEvalType.NON_UDF) {
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
@@ -324,7 +337,6 @@
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
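For reference, a simplified Python sketch of how a worker could interpret the new command framing written by PythonRunner above. This is not the actual worker.py code; the helper names are hypothetical, and it only mirrors the writes visible in this hunk: the eval type is always written first, and only NON_UDF falls back to a single length-prefixed serialized command.

```python
# Hedged sketch of the command framing written by PythonRunner above.
# NOT worker.py; read_command_section is a hypothetical helper.
import struct

NON_UDF, SQL_BATCHED_UDF, SQL_PANDAS_UDF = 0, 1, 2

def read_int(stream):
    # DataOutputStream.writeInt is big-endian
    return struct.unpack(">i", stream.read(4))[0]

def read_command_section(stream):
    eval_type = read_int(stream)
    if eval_type != NON_UDF:
        udfs = []
        for _ in range(read_int(stream)):                  # number of chained UDFs
            num_offsets = read_int(stream)
            arg_offsets = [read_int(stream) for _ in range(num_offsets)]
            # ... the pickled chained functions follow (elided in the hunk above)
            udfs.append(arg_offsets)
        return eval_type, udfs
    # NON_UDF: a single pickled command, length-prefixed
    length = read_int(stream)
    return eval_type, stream.read(length)
```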
65 changes: 63 additions & 2 deletions python/pyspark/serializers.py
@@ -81,6 +81,12 @@ class SpecialLengths(object):
NULL = -5


class PythonEvalType(object):
NON_UDF = 0
SQL_BATCHED_UDF = 1
SQL_PANDAS_UDF = 2


class Serializer(object):

def dump_stream(self, iterator, stream):
@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
Serializes an Arrow stream.
"""

def dumps(self, obj):
raise NotImplementedError
def dumps(self, batch):
import pyarrow as pa
import io
sink = io.BytesIO()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
return sink.getvalue()

def loads(self, obj):
import pyarrow as pa
@@ -199,6 +211,55 @@ def __repr__(self):
return "ArrowSerializer"


class ArrowPandasSerializer(ArrowSerializer):
"""
Serializes Pandas.Series as Arrow data.
"""

def __init__(self):
super(ArrowPandasSerializer, self).__init__()
Member: Do we need this?

Member Author: No, that was leftovers; I'll remove it in a follow-up.


def dumps(self, series):
"""
Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

# If a nullable integer series has been promoted to floating point with NaNs, need to cast
# NOTE: this is not necessary with Arrow >= 0.7
def cast_series(s, t):
if t is None or s.dtype == t.to_pandas_dtype():
return s
else:
return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)

arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
return super(ArrowPandasSerializer, self).dumps(batch)

def loads(self, obj):
"""
Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
followed by a dictionary containing length of the loaded batches.
"""
import pyarrow as pa
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
# NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
num_rows = sum((batch.num_rows for batch in batches))
table = pa.Table.from_batches(batches)
return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]

def __repr__(self):
return "ArrowPandasSerializer"


class BatchedSerializer(Serializer):

"""
49 changes: 37 additions & 12 deletions python/pyspark/sql/functions.py
@@ -2032,7 +2032,7 @@ class UserDefinedFunction(object):

.. versionadded:: 1.3
"""
def __init__(self, func, returnType, name=None):
def __init__(self, func, returnType, name=None, vectorized=False):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2046,6 +2046,7 @@ def __init__(self, func, returnType, name=None):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self._vectorized = vectorized

@property
def returnType(self):
@@ -2077,7 +2078,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt)
self._name, wrapped_func, jdt, self._vectorized)
return judf

def __call__(self, *cols):
@@ -2111,6 +2112,22 @@ def wrapper(*args):
return wrapper


def _create_udf(f, returnType, vectorized):

def _udf(f, returnType=StringType(), vectorized=vectorized):
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
return udf_obj._wrapped()

# decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
else:
return _udf(f=f, returnType=returnType, vectorized=vectorized)
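
As the comment inside _create_udf notes, both udf and pandas_udf can be used as plain functions or as decorators. A small sketch of the decorator forms it dispatches (the decorated function names here are illustrative):

```python
# Sketch of the decorator forms handled by _create_udf.
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf                      # bare decorator: f is the function, returnType defaults to StringType()
def to_upper(s):
    return s.upper() if s is not None else None

@udf(IntegerType())       # f is a DataType, so it is taken as the returnType
def add_one(x):
    return x + 1 if x is not None else None
```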


@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a :class:`Column` expression representing a user defined function (UDF).
@@ -2142,18 +2159,26 @@ def udf(f=None, returnType=StringType()):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
"""
def _udf(f, returnType=StringType()):
udf_obj = UserDefinedFunction(f, returnType)
return udf_obj._wrapped()
return _create_udf(f, returnType=returnType, vectorized=False)

# decorator @udf, @udf() or @udf(dataType())
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type)

@since(2.3)
def pandas_udf(f=None, returnType=StringType()):
"""
Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
`Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.

:param f: python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object

# TODO: doctest
"""
import inspect
# If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
if inspect.getargspec(f).keywords is None:
return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
Member: We wrap a kwargs placeholder around the function, but we don't actually pass it into the function. So, unlike the 0-argument pandas udf in the SPIP, we explicitly ask it to define kwargs? In other words, there is no truly 0-argument pandas udf, because it at least has kwargs defined?

Member: Should we make the kwargs mandatory for a 0-argument pandas udf? A 0-argument pandas udf without the kwargs doesn't seem to make sense, as it can't guess the size of the returned Series.

Member Author: If the user function doesn't define the keyword args, then it is wrapped with a placeholder so that worker.py can expect the function to always have keywords. I thought this was better than trying to do inspection on the worker while running the UDF.

I'm not crazy about the 0-parameter pandas_udf, but if we have to support it here then it does need to get the required length of output somehow, unless we repeat/slice the output to make the length correct.

I'm ok with making **kwargs mandatory for 0-parameter UDFs and optional for others.

Member: How about disallowing it for now? It could be an option if 0-parameter UDFs alone can't be supported consistently. return pd.Series(1).repeat(kwargs['length']) still looks a little bit weird.

Member Author: I agree it is still a bit weird. Did you mean disallowing 0-parameter pandas_udfs, or requiring 0-parameter pandas_udfs to accept kwargs?

Member: Ah, I was thinking that disallowing 0-parameter pandas_udf could be an option.

Member Author: It would be a lot cleaner to just not allow 0 parameters. Is it an option to not allow 0-parameter UDFs for pandas_udfs @ueshin @cloud-fan?

Contributor: I'm fine with disallowing 0-parameter pandas udfs, as it's not a common case. We can add it when people request it.
else:
return _udf(f=f, returnType=returnType)
return _create_udf(f, returnType=returnType, vectorized=True)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
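
Finally, a hedged end-to-end sketch of the new API (assuming a Spark build that includes this patch, with pandas and pyarrow installed; the DataFrame and column names are made up). The 0-parameter form mirrors the pd.Series(1).repeat(kwargs['length']) pattern from the review discussion above and may be disallowed in a follow-up.

```python
# Usage sketch for the vectorized (pandas) UDF introduced in this PR.
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType

spark = SparkSession.builder.appName("pandas_udf_sketch").getOrCreate()
df = spark.range(0, 4).toDF("x")

# The function receives pandas.Series and must return a Series of the same length.
plus_one = pandas_udf(lambda x: x + 1, LongType())
df.select(plus_one(df.x)).show()

# 0-parameter form discussed in the review above: the injected **kwargs carries
# the required output length. The reviewers lean toward disallowing this case.
ones = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs["length"]), LongType())
```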