# [SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch

## What changes were proposed in this pull request?

This PR adds support for executing multiple Python UDFs within a single batch, and also improves performance.

```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$

== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
   +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
      +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
         +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
            +- OneRowRelation$

== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
   +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- OneRowRelation$

== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
:     +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
   +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- Scan OneRowRelation[]
```
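To make the plan above concrete, here is a rough sketch in plain Python (an illustration only, not Spark internals) of how the two evaluation batches handle one row: `double(add(1, 2))` is a chained UDF and runs as a single composed function in the first batch, while `add(double(2), 1)` must wait for `double(2)` from that batch.

```python
double = lambda x: x * 2
add = lambda x, y: x + y

# Batch 1: UDFs whose arguments contain no other Python UDF.
# double(add(1, 2)) is *chained* (add feeds double directly), so it is
# evaluated as one composed function; double(2) is an independent UDF.
chained = lambda x, y: double(add(x, y))
pythonUDF0 = chained(1, 2)  # -> 6
pythonUDF1 = double(2)      # -> 4

# Batch 2: add(double(2), 1) consumes pythonUDF1 produced by batch 1.
result = (pythonUDF0, add(pythonUDF1, 1))
assert result == (6, 5)
```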

## How was this patch tested?

Added new tests.

The following script was used to benchmark 1, 2, and 3 UDFs (imports added here for completeness):
```python
from pyspark.sql import functions as F
from pyspark.sql.types import LongType

df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```
Here are the results:

N | Before | After | Speedup
---- | ------------ | ------------- | ------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X

This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch (one per UDF per task) but only 4 processes after it (one per task, evaluating all three UDFs). After this patch, evaluating multiple UDFs also uses less memory than before, since there is less buffering.

Author: Davies Liu <davies@databricks.com>

Closes #12057 from davies/multi_udfs.
Davies Liu authored and davies committed Mar 31, 2016
1 parent 8de201b commit f0afafd
Showing 8 changed files with 233 additions and 101 deletions.
64 changes: 49 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
accumulator: Accumulator[JList[Array[Byte]]])

/**
* A helper class to run Python UDFs in Spark.
* A wrapper for chained Python functions (from bottom to top).
* @param funcs
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
}
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
* funcs is a list of independent Python functions, each one of them is a list of chained Python
* functions (from bottom to top).
*/
private[spark] class PythonRunner(
funcs: Seq[PythonFunction],
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
rowBased: Boolean)
isUDF: Boolean,
argOffsets: Array[Array[Int]])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.envVars
private val pythonExec = funcs.head.pythonExec
private val pythonVer = funcs.head.pythonVer
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer

private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

setDaemon(true)

@@ -284,11 +304,25 @@
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(if (rowBased) 1 else 0)
dataOut.writeInt(funcs.length)
funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
offsets.foreach { offset =>
dataOut.writeInt(offset)
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
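For reference, the `writeInt`/`write` calls above define the framing the Python worker sees on the UDF path. Below is a minimal sketch reconstructed from those calls, not an official protocol spec; `write_udf_section` is a hypothetical helper, and all ints are big-endian to match Java's `DataOutputStream.writeInt`.

```python
import struct

def write_int(out, v):
    # DataOutputStream.writeInt writes a big-endian 32-bit int
    out.write(struct.pack(">i", v))

def write_udf_section(out, udfs):
    """udfs: list of (arg_offsets, chained_commands) per independent UDF,
    where chained_commands are pickled (func, returnType) byte strings."""
    write_int(out, 1)              # isUDF flag
    write_int(out, len(udfs))      # number of independent UDFs
    for arg_offsets, chained_commands in udfs:
        write_int(out, len(arg_offsets))
        for offset in arg_offsets:  # positions in the flattened input row
            write_int(out, offset)
        write_int(out, len(chained_commands))
        for command in chained_commands:
            write_int(out, len(command))
            out.write(command)
```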
3 changes: 1 addition & 2 deletions python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, returnType, ser)
command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
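The effect of this change is that the pickled UDF command is now just the `(func, returnType)` pair; the worker picks the batch serializer itself. A small sketch of the round trip, assuming a module-level function and the standalone `cloudpickle` package (PySpark vendors an equivalent):

```python
import cloudpickle
from pyspark.sql.types import IntegerType

def double(x):
    return x * 2

# The command no longer carries a serializer, only (func, returnType).
command = (double, IntegerType())
payload = cloudpickle.dumps(command)

func, return_type = cloudpickle.loads(payload)
assert func(21) == 42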
12 changes: 11 additions & 1 deletion python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
def test_chained_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ def test_chained_python_udf(self):
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_multiple_udfs(self):
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
self.assertEqual(tuple(row), (4, 12))
self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
self.assertEqual(tuple(row), (6, 5))

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
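As a sanity check, the expected tuples in `test_multiple_udfs` above follow directly from the lambda definitions (plain Python, no Spark needed):

```python
double = lambda x: x * 2
add = lambda x, y: x + y

assert (double(1), double(2)) == (2, 4)
assert (double(double(1)), double(double(2) + 2)) == (4, 12)
assert (double(add(1, 2)), add(double(2), 1)) == (6, 5)
```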
68 changes: 52 additions & 16 deletions python/pyspark/worker.py
@@ -29,7 +29,7 @@
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):

def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))
return lambda *a: g(f(*a))


def wrap_udf(f, return_type):
if return_type.needConversion():
toInternal = return_type.toInternal
return lambda *a: toInternal(f(*a))
else:
return lambda *a: f(*a)


def read_single_udf(pickleSer, infile):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
return arg_offsets, wrap_udf(row_func, return_type)


def read_udfs(pickleSer, infile):
num_udfs = read_int(infile)
if num_udfs == 1:
# fast path for single UDF
_, udf = read_single_udf(pickleSer, infile)
mapper = lambda a: udf(*a)
else:
udfs = {}
call_udf = []
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
# Create function like this:
# lambda a: (f0(a0), f1(a1, a2), f2(a3))
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)
ser = BatchedSerializer(PickleSerializer(), 100)
# profiling is not supported for UDF
return func, None, ser, ser


def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
row_based = read_int(infile)
num_commands = read_int(infile)
if row_based:
profiler = None # profiling is not supported for UDF
row_func = None
for i in range(num_commands):
f, returnType, deserializer = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
is_sql_udf = read_int(infile)
if is_sql_udf:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)

init_time = time.time()
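Pulling the worker-side pieces above together: `chain` composes a chained UDF bottom-to-top, and `read_udfs` builds a single row mapper by `eval`-ing a generated lambda that indexes each UDF's arguments out of the flattened input row. A standalone sketch (plain Python, mirroring the diff above):

```python
def chain(f, g):
    """Chain two functions: feed f's result into g (bottom to top)."""
    return lambda *a: g(f(*a))

double = lambda x: x * 2
add = lambda x, y: x + y

# A chained UDF like double(add(x, y)) runs as one composed function:
chained = chain(add, double)
assert chained(1, 2) == 6

# For multiple independent UDFs, the worker generates one mapper whose
# body applies each UDF to its argument offsets, e.g.
# "lambda a: (f0(a[0], a[1]), f1(a[2]))":
udfs = {'f0': chained, 'f1': double}
mapper = eval("lambda a: (f0(a[0], a[1]), f1(a[2]))", udfs)
assert mapper((1, 2, 2)) == (6, 4)
```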
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case e @ python.EvaluatePython(udfs, child, _) =>
python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually run out of memory.
*/
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

@@ -69,39 +70,78 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)

val pickle = new Pickler
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}.toArray
}.toArray
val projection = newMutableProjection(allInputs, child.output)()
val schema = StructType(dataTypes.map(dt => StructField("", dt)))
val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

// enable memo iff we serialize the row with schema (schema and class should be memorized)
val pickle = new Pickler(needConversion)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
queue.add(row)
EvaluatePython.toJava(currentRow(row), schema)
val toBePickled = inputRows.map { inputRow =>
queue.add(inputRow)
val row = projection(inputRow)
if (needConversion) {
EvaluatePython.toJava(row, schema)
} else {
// fast path for these types that does not need conversion in Python
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
fields
}
}.toArray
pickle.dumps(toBePickled)
}

val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
val resultType = if (udfs.length == 1) {
udfs.head.dataType
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
val resultProj = UnsafeProjection.create(output, output)

outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
val row = if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
mutableRow
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
resultProj(joined(queue.poll(), row))
}
}
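The argument-flattening loop above deduplicates input expressions so that each distinct column is shipped to Python once and shared between UDFs by offset. A plain-Python sketch of the same logic (string equality standing in for `semanticEquals`):

```python
def flatten_args(inputs_per_udf):
    all_inputs, arg_offsets = [], []
    for inputs in inputs_per_udf:
        offsets = []
        for e in inputs:
            if e in all_inputs:
                # already sent to Python; reuse its position
                offsets.append(all_inputs.index(e))
            else:
                all_inputs.append(e)
                offsets.append(len(all_inputs) - 1)
        arg_offsets.append(offsets)
    return all_inputs, arg_offsets

# Two UDFs, double(id) and add(id, id + 1), share the "id" column:
all_inputs, arg_offsets = flatten_args([["id"], ["id", "id + 1"]])
assert all_inputs == ["id", "id + 1"]
assert arg_offsets == [[0], [0, 1]]
```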