diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0f579b4ef509a..6faa03c12b6d3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])

 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs the chain of functions, applied from bottom to top
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartitions functions/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each of which is a chain of Python
+ * functions (applied from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {

+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+  // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer

-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulators in multiple UDFs
+  private val accumulator = funcs.head.funcs.head.accumulator

   def compute(
       inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(

     @volatile private var _exception: Exception = null

-    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

     setDaemon(true)

@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
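For reference, the isUDF branch above frames everything as big-endian 32-bit ints followed by raw command bytes, which is what read_udfs on the Python side (further down in this patch) parses. The following is a minimal illustrative Python sketch of that framing, not Spark code; the <pickled ...> byte strings are placeholders for real pickled commands.

    import struct

    def write_int(buf, value):
        # DataOutputStream.writeInt emits big-endian 32-bit integers
        buf.append(struct.pack(">i", value))

    def frame_udfs(udfs):
        # udfs: list of (arg_offsets, commands), commands ordered bottom to top within a chain
        buf = []
        write_int(buf, 1)            # isUDF flag
        write_int(buf, len(udfs))    # number of independent UDFs
        for arg_offsets, commands in udfs:
            write_int(buf, len(arg_offsets))
            for offset in arg_offsets:
                write_int(buf, offset)
            write_int(buf, len(commands))
            for command in commands:
                write_int(buf, len(command))
                buf.append(command)
        return b"".join(buf)

    # one UDF over column 0, plus a two-function chain over columns 0 and 1
    payload = frame_udfs([([0], [b"<pickled f0>"]),
                          ([0, 1], [b"<pickled g>", b"<pickled f>"])])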
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 321183422641f..3b20ba5177efd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
 # ---------------------------- User Defined Function ----------------------------------

 def _wrap_function(sc, func, returnType):
-    ser = AutoBatchedSerializer(PickleSerializer())
-    command = (func, returnType, ser)
+    command = (func, returnType)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)
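The command tuple shipped to the worker is now just (func, returnType); the batch serializer is chosen on the worker side instead of being pickled into every command. A rough round-trip sketch using PySpark's serializer classes (which serializer _prepare_for_python_RDD actually uses is an implementation detail; this is illustrative only):

    from pyspark.serializers import CloudPickleSerializer, PickleSerializer
    from pyspark.sql.types import IntegerType

    # What _wrap_function now pickles is just the (function, return type) pair.
    command = (lambda x: x * 2, IntegerType())
    blob = CloudPickleSerializer().dumps(command)

    # The worker side (read_command) rebuilds the pair from the pickled bytes.
    func, return_type = PickleSerializer().loads(blob)
    assert func(3) == 6 and isinstance(return_type, IntegerType)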
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84947560e775c..536ef552519e1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ def test_udf2(self):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])

-    def test_chained_python_udf(self):
+    def test_chained_udf(self):
         self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.sqlCtx.sql("SELECT double(1)").collect()
         self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ def test_chained_python_udf(self):
         [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)

+    def test_multiple_udfs(self):
+        self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+        self.assertEqual(tuple(row), (2, 4))
+        [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+        self.assertEqual(tuple(row), (4, 12))
+        self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+        self.assertEqual(tuple(row), (6, 5))
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0f05fe31aa28a..cf47ab8f96c6d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
 from pyspark import shuffle

 pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):

 def chain(f, g):
     """chain two function together """
-    return lambda x: g(f(x))
+    return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+    if return_type.needConversion():
+        toInternal = return_type.toInternal
+        return lambda *a: toInternal(f(*a))
+    else:
+        return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of the UDF
+    return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+    num_udfs = read_int(infile)
+    if num_udfs == 1:
+        # fast path for single UDF
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
+    else:
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        # lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)
+
+    func = lambda _, it: map(mapper, it)
+    ser = BatchedSerializer(PickleSerializer(), 100)
+    # profiling is not supported for UDFs
+    return func, None, ser, ser


 def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)

     _accumulatorRegistry.clear()
-    row_based = read_int(infile)
-    num_commands = read_int(infile)
-    if row_based:
-        profiler = None  # profiling is not supported for UDF
-        row_func = None
-        for i in range(num_commands):
-            f, returnType, deserializer = read_command(pickleSer, infile)
-            if row_func is None:
-                row_func = f
-            else:
-                row_func = chain(row_func, f)
-        serializer = deserializer
-        func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+    is_sql_udf = read_int(infile)
+    if is_sql_udf:
+        func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
     else:
-        assert num_commands == 1
         func, profiler, deserializer, serializer = read_command(pickleSer, infile)

     init_time = time.time()
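The interesting part of read_udfs is the generated mapper: for multiple UDFs it builds a small lambda from a string, binding each f<i> to a deserialized UDF and each a[<offset>] to a column of the flattened argument row. A standalone sketch of the same construction with toy functions (f0, f1 and the offsets below are made up for illustration):

    udfs = {'f0': lambda x: x * 2, 'f1': lambda x, y: x + y}
    arg_offsets = [[0], [0, 1]]   # f0 reads column 0; f1 reads columns 0 and 1

    call_udf = []
    for i, offsets in enumerate(arg_offsets):
        args = ", ".join("a[%d]" % o for o in offsets)
        call_udf.append("f%d(%s)" % (i, args))
    mapper = eval("lambda a: (%s)" % ", ".join(call_udf), udfs)

    # Each input row is the flat tuple of de-duplicated UDF arguments;
    # the mapper returns one result per UDF.
    assert mapper((3, 4)) == (6, 7)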
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7841ff01f93c2..7a2e2b73822f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udf, child, _) =>
-        python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+      case e @ python.EvaluatePython(udfs, child, _) =>
+        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index a76009e7dfeff..c9ab40a0a9abf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python

 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer

 import net.razorvine.pickle.{Pickler, Unpickler}

 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}

 /**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * we drain the queue to find the original input row. Note that if the Python process is way too
  * slow, this could lead to the queue growing unbounded and eventually run out of memory.
  */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {

   def children: Seq[SparkPlan] = child :: Nil

-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
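collectFunctions only merges UDFs that are directly nested (a PythonUDF whose sole child is another PythonUDF); anything else becomes a separate chain with its own inputs. A toy Python model of that recursion, where tuples stand in for PythonUDF nodes and strings for column references (illustrative only, not Spark code):

    def collect(udf):
        name, children = udf
        # a UDF whose only child is another UDF extends that child's chain
        if len(children) == 1 and isinstance(children[0], tuple):
            chained, inputs = collect(children[0])
            return chained + [name], inputs
        # otherwise the chain bottoms out and the children become the inputs
        return [name], children

    # double(double(col)) collapses into one chain applied bottom to top over col
    assert collect(("double", [("double", ["col"])])) == (["double", "double"], ["col"])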
@@ -69,19 +70,47 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

-      val (pyFuncs, children) = collectFunctions(udf)
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(children, child.output)()
-      val fields = children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+      // enable memo iff we serialize the row with schema (schema and class should be memoized)
+      val pickle = new Pickler(needConversion)

       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
+        val toBePickled = inputRows.map { inputRow =>
+          queue.add(inputRow)
+          val row = projection(inputRow)
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for types that do not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
+          }
         }.toArray
         pickle.dumps(toBePickled)
       }
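The argument flattening above projects each distinct input expression once and hands Python only per-UDF offsets into that shared row. A small Python sketch of the same bookkeeping over plain strings (names and inputs below are made up for illustration):

    def flatten_args(per_udf_inputs):
        # collect each UDF's inputs into one de-duplicated projection and
        # remember, per UDF, the offsets into that projection
        all_inputs = []
        arg_offsets = []
        for inputs in per_udf_inputs:
            offsets = []
            for e in inputs:
                if e in all_inputs:
                    offsets.append(all_inputs.index(e))
                else:
                    all_inputs.append(e)
                    offsets.append(len(all_inputs) - 1)
            arg_offsets.append(offsets)
        return all_inputs, arg_offsets

    # For SELECT double(a), add(a, b): column a is projected once and shared.
    assert flatten_args([["a"], ["a", "b"]]) == (["a", "b"], [[0], [0, 1]])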
@@ -89,19 +118,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       val context = TaskContext.get()

       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
         .compute(inputIterator, context.partitionId(), context)

       val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
+      val mutableRow = new GenericMutableRow(1)
       val joined = new JoinedRow
+      val resultType = if (udfs.length == 1) {
+        udfs.head.dataType
+      } else {
+        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+      }
       val resultProj = UnsafeProjection.create(output, output)

       outputIterator.flatMap { pickedResult =>
         val unpickledBatch = unpickle.loads(pickedResult)
         unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
       }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        val row = if (udfs.length == 1) {
+          // fast path for single UDF
+          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+          mutableRow
+        } else {
+          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        }
         resultProj(joined(queue.poll(), row))
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index da28ec4f53412..f3d1c44b25b4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -36,24 +36,28 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

 /**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ * Evaluates a list of [[PythonUDF]]s, appending the results to the end of the input tuple.
  */
 case class EvaluatePython(
-    udf: PythonUDF,
+    udfs: Seq[PythonUDF],
     child: LogicalPlan,
-    resultAttribute: AttributeReference)
+    resultAttribute: Seq[AttributeReference])
   extends logical.UnaryNode {

-  def output: Seq[Attribute] = child.output :+ resultAttribute
+  def output: Seq[Attribute] = child.output ++ resultAttribute

   // References should not include the produced attribute.
-  override def references: AttributeSet = udf.references
+  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
 }

 object EvaluatePython {
-  def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
-    new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+  def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
+    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+      AttributeReference(s"pythonUDF$i", u.dataType)()
+    }
+    new EvaluatePython(udfs, child, resultAttrs)
+  }

   def takeAndServe(df: DataFrame, n: Int): Int = {
     registerPicklers()
@@ -66,6 +70,16 @@ object EvaluatePython {
     }
   }

+  def needConversionInPython(dt: DataType): Boolean = dt match {
+    case DateType | TimestampType => true
+    case _: StructType => true
+    case _: UserDefinedType[_] => true
+    case ArrayType(elementType, _) => needConversionInPython(elementType)
+    case MapType(keyType, valueType, _) =>
+      needConversionInPython(keyType) || needConversionInPython(valueType)
+    case _ => false
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
   */
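needConversionInPython appears to mirror DataType.needConversion() on the PySpark side, which wrap_udf (above) uses to decide whether UDF results must go through toInternal. A quick check against the Python type API; the behavior below is assumed from the pyspark.sql.types implementation of that era:

    from pyspark.sql.types import ArrayType, DateType, IntegerType

    # Only types whose external and internal representations differ force the slow path.
    assert not IntegerType().needConversion()      # plain ints pass through unchanged
    assert DateType().needConversion()             # dates map to internal day counts
    assert ArrayType(DateType()).needConversion()  # containers inherit from their element type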
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index c486ce18e8a1c..0934cd135d480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.execution.python

-import org.apache.spark.sql.catalyst.expressions.Expression
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -47,10 +49,9 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     }
   }

-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
-    expr.collect {
-      case udf: PythonUDF if canEvaluateInPython(udf) => udf
-    }
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
+    case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
+    case e => e.children.flatMap(collectEvaluatableUDF)
   }

   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -59,45 +60,43 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan
       } else {
-        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
-        // If there is more than one, we will add another evaluation operator in a subsequent pass.
-        udfs.find(_.resolved) match {
-          case Some(udf) =>
-            var evaluation: EvaluatePython = null
-
-            // Rewrite the child that has the input required for the UDF
-            val newChildren = plan.children.map { child =>
-              // Check to make sure that the UDF can be evaluated with only the input of this child.
-              // Other cases are disallowed as they are ambiguous or would require a cartesian
-              // product.
-              if (udf.references.subsetOf(child.outputSet)) {
-                evaluation = EvaluatePython(udf, child)
-                evaluation
-              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
-                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-              } else {
-                child
-              }
-            }
-
-            assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
-
-            // Trim away the new UDF value if it was only used for filtering or something.
-            logical.Project(
-              plan.output,
-              plan.transformExpressions {
-                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
-              }.withNewChildren(newChildren))
-
-          case None =>
-            // If there is no Python UDF that is resolved, skip this round.
-            plan
+        val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+        // Rewrite the child that has the input required for the UDF
+        val newChildren = plan.children.map { child =>
+          // Pick the UDFs we are going to evaluate
+          val validUdfs = udfs.filter { case udf =>
+            // Check to make sure that the UDF can be evaluated with only the input of this child.
+            udf.references.subsetOf(child.outputSet)
+          }
+          if (validUdfs.nonEmpty) {
+            val evaluation = EvaluatePython(validUdfs, child)
+            attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
+            evaluation
+          } else {
+            child
+          }
         }
+        // Other cases are disallowed as they are ambiguous or would require a cartesian
+        // product.
+        udfs.filterNot(attributeMap.contains).foreach { udf =>
+          if (udf.references.subsetOf(plan.inputSet)) {
+            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+          } else {
+            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+          }
+        }
+
+        // Trim away the new UDF value if it was only used for filtering or something.
+        logical.Project(
+          plan.output,
+          plan.transformExpressions {
+            case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
+          }.withNewChildren(newChildren))
       }
   }
 }
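Taken together, these changes let one BatchPythonEvaluation evaluate several Python UDFs in a single round trip to the worker instead of planning one evaluation operator per UDF. A usage sketch mirroring the new test_multiple_udfs test, assuming a SQLContext named sqlContext is already available (for example in the PySpark shell):

    from pyspark.sql.types import IntegerType

    # Both UDFs are shipped to the worker together and evaluated row by row
    # in one batch rather than through separate Python evaluations.
    sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
    sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())

    [row] = sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
    assert tuple(row) == (6, 5)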