python udf
Davies Liu committed Feb 4, 2015
1 parent 58dee20 commit 7bccc3b
Showing 3 changed files with 98 additions and 1 deletion.
49 changes: 49 additions & 0 deletions python/pyspark/sql.py
@@ -2554,6 +2554,45 @@ def _(col):
    return staticmethod(_)


class UserDefinedFunction(object):
    """Wraps a Python function so it can be applied to Columns as a SQL UDF."""

    def __init__(self, func, returnType):
        self.func = func
        self.returnType = returnType
        self._judf = self._create_judf()

    def _create_judf(self):
        f = self.func
        sc = SparkContext._active_spark_context
        # TODO(davies): refactor
        # Wrap f so that, on the worker, it maps over an iterator of per-row
        # argument tuples; this mirrors the command layout PySpark ships for tasks.
        func = lambda _, it: imap(lambda x: f(*x), it)
        command = (func, None,
                   AutoBatchedSerializer(PickleSerializer()),
                   AutoBatchedSerializer(PickleSerializer()))
        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # 1M
            # Ship very large closures through a broadcast instead of inlining them.
            broadcast = sc.broadcast(pickled_command)
            pickled_command = ser.dumps(broadcast)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in sc._pickled_broadcast_vars],
            sc._gateway._gateway_client)
        sc._pickled_broadcast_vars.clear()
        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
        includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
        jdt = ssql_ctx.parseDataType(self.returnType.json())
        judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes,
                                     sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt)
        return judf

    def __call__(self, *cols):
        sc = SparkContext._active_spark_context
        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                        sc._gateway._gateway_client)
        jc = self._judf.apply(sc._jvm.Dsl.toColumns(jcols))
        return Column(jc)


class Dsl(object):
    """
    A collection of builtin aggregators
@@ -2612,6 +2651,16 @@ def approxCountDistinct(col, rsd=None):
        jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
        return Column(jc)

    @staticmethod
    def udf(f, returnType=StringType()):
        """Create a user defined function (UDF)

        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
        >>> df.select(slen(df.name).As('slen')).collect()
        [Row(slen=5), Row(slen=3)]
        """
        return UserDefinedFunction(f, returnType)


def _test():
    import doctest
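The heart of the Python change is the command tuple built in _create_judf: the user's function is wrapped so that, on a worker, it maps over an iterator of per-row argument tuples. Below is a minimal, Spark-free sketch of that wrapper shape (Python 2, since sql.py uses imap here); the sample rows and printed result are illustrative only, not output from this commit.

    from itertools import imap  # Python 2, matching the import used in sql.py

    f = lambda s: len(s)                            # the user's UDF body
    func = lambda _, it: imap(lambda x: f(*x), it)  # same wrapper shape as in _create_judf

    rows = [("Alice",), ("Bob",)]                   # one tuple of column values per row
    print(list(func(None, iter(rows))))             # -> [5, 3]

Because every row arrives as a tuple and is unpacked with f(*x), the same mechanism naturally supports UDFs over several columns.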
23 changes: 22 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -17,7 +17,11 @@

package org.apache.spark.sql

import java.util.{List => JList}
import java.util.{List => JList, Map => JMap}

import org.apache.spark.Accumulator
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast

import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
@@ -177,6 +181,23 @@ object Dsl {
    cols.toList.toSeq
  }

  /**
   * This is a private API for Python
   * TODO: move this to a private package
   */
  def pythonUDF(
      name: String,
      command: Array[Byte],
      envVars: JMap[String, String],
      pythonIncludes: JList[String],
      pythonExec: String,
      broadcastVars: JList[Broadcast[PythonBroadcast]],
      accumulator: Accumulator[JList[Array[Byte]]],
      dataType: DataType): UserDefinedPythonFunction = {
    UserDefinedPythonFunction(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
      accumulator, dataType)
  }

  //////////////////////////////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////////////////////////////

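On the JVM side, Dsl.pythonUDF is the single entry point that _create_judf reaches over the Py4J gateway. One detail worth calling out is how the return type crosses the language boundary: the Python DataType is serialized to JSON and rebuilt by parseDataType on a JVM SQLContext, exactly the calls visible in the Python diff above. A small sketch of that round trip, assuming an already-running SparkContext (sketch only; the printed form of the JVM object is a guess):

    from pyspark import SparkContext
    from pyspark.sql import IntegerType

    sc = SparkContext._active_spark_context
    ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
    # DataType.json() on the Python side, parseDataType on the Scala side
    jdt = ssql_ctx.parseDataType(IntegerType().json())
    print(jdt)  # expected to show the Catalyst type, e.g. IntegerType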
27 changes: 27 additions & 0 deletions
@@ -17,7 +17,13 @@

package org.apache.spark.sql

import java.util.{List => JList, Map => JMap}

import org.apache.spark.Accumulator
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType

/**
@@ -37,3 +43,24 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
    Column(ScalaUdf(f, dataType, exprs.map(_.expr)))
  }
}

/**
 * A user-defined Python function. To create one, use the `pythonUDF` function in [[Dsl]].
 * This is used by the Python API.
 */
private[sql] case class UserDefinedPythonFunction(
    name: String,
    command: Array[Byte],
    envVars: JMap[String, String],
    pythonIncludes: JList[String],
    pythonExec: String,
    broadcastVars: JList[Broadcast[PythonBroadcast]],
    accumulator: Accumulator[JList[Array[Byte]]],
    dataType: DataType) {

  def apply(exprs: Column*): Column = {
    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
      accumulator, dataType, exprs.map(_.expr))
    Column(udf)
  }
}
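Putting the two halves together: UserDefinedFunction.__call__ accepts any number of Columns and UserDefinedPythonFunction.apply takes Column*, so a UDF built through Dsl.udf can consume several columns per row, with each row's values passed as positional arguments. A small usage sketch; df, fname, and lname are hypothetical names, and the return type is left at its StringType() default:

    from pyspark.sql import Dsl

    # Assumes a DataFrame df with string columns fname and lname.
    full_name = Dsl.udf(lambda fname, lname: fname + " " + lname)
    df.select(full_name(df.fname, df.lname)).collect()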
