[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <davies@databricks.com>

Closes #4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
Davies Liu authored and rxin committed Feb 4, 2015
1 parent e0490e2 commit dc101b0
Showing 4 changed files with 157 additions and 122 deletions.
38 changes: 22 additions & 16 deletions python/pyspark/rdd.py
@@ -2162,6 +2162,25 @@ def toLocalIterator(self):
yield row


def _prepare_for_python_RDD(sc, command, obj=None):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
broadcast = sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
# track the broadcast's life cycle via obj
if obj is not None:
obj._broadcast = broadcast
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in sc._pickled_broadcast_vars],
sc._gateway._gateway_client)
sc._pickled_broadcast_vars.clear()
env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
return pickled_command, broadcast_vars, env, includes
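
For orientation, a minimal sketch of how this new helper is called (hedged: it targets the PySpark internals of this era; the app name, worker function, and variable names are made up for illustration):

from pyspark import SparkContext
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.rdd import _prepare_for_python_RDD

sc = SparkContext("local", "prepare-demo")          # hypothetical local context
ser = AutoBatchedSerializer(PickleSerializer())
func = lambda split, it: map(lambda x: x * 2, it)   # (split, iterator) -> iterable
# Same (func, profiler, deserializer, serializer) shape used by PipelinedRDD below;
# the profiler slot is left as None here.
command = (func, None, ser, ser)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(sc, command)
# pickled_cmd: the cloudpickled command, or a pickled Broadcast handle if it exceeds 1 MiB
# bvars / env / includes: py4j-converted broadcast variables, environment, and Python includes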


class PipelinedRDD(RDD):

"""
@@ -2228,25 +2247,12 @@ def _jrdd(self):

command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
self._broadcast = self.ctx.broadcast(pickled_command)
pickled_command = ser.dumps(self._broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
bytearray(pickled_command),
bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator)
bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()

if profiler:
195 changes: 91 additions & 104 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ def registerFunction(self, name, f, returnType=StringType()):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
command = (func, None,
AutoBatchedSerializer(PickleSerializer()),
AutoBatchedSerializer(PickleSerializer()))
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
broadcast = self._sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self._sc._pickled_broadcast_vars],
self._sc._gateway._gateway_client)
self._sc._pickled_broadcast_vars.clear()
env = MapConverter().convert(self._sc.environment,
self._sc._gateway._gateway_client)
includes = ListConverter().convert(self._sc._python_includes,
self._sc._gateway._gateway_client)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
self._ssql_ctx.udf().registerPython(name,
bytearray(pickled_command),
bytearray(pickled_cmd),
env,
includes,
self._sc.pythonExec,
broadcast_vars,
bvars,
self._sc._javaAccumulator,
returnType.json())
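
As a hedged usage sketch of the rewritten registration path (1.3-era API; the UDF name `strlen` is made up, and a live SparkContext `sc` is assumed, e.g. the one from the sketch in rdd.py above):

from pyspark.sql import SQLContext, IntegerType

sqlCtx = SQLContext(sc)                     # assumes an existing SparkContext `sc`
sqlCtx.registerFunction("strlen", lambda x: len(x), IntegerType())
# The UDF is now callable from SQL; the unnamed result column is auto-named c0,
# as in the doctest above.
print(sqlCtx.sql("SELECT strlen('test')").collect())   # [Row(c0=4)]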

@@ -2077,9 +2064,9 @@ def dtypes(self):
"""Return all column names and their data types as a list.
>>> df.dtypes
[(u'age', 'IntegerType'), (u'name', 'StringType')]
[('age', 'integer'), ('name', 'string')]
"""
return [(f.name, str(f.dataType)) for f in self.schema().fields]
return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]

@property
def columns(self):
@@ -2194,7 +2181,7 @@ def select(self, *cols):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df.select(df.name, (df.age + 10).As('age')).collect()
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
if not cols:
@@ -2295,25 +2282,13 @@ def subtract(self, other):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

def sample(self, withReplacement, fraction, seed=None):
""" Return a new DataFrame by sampling a fraction of rows.
>>> df.sample(False, 0.5, 10).collect()
[Row(age=2, name=u'Alice')]
"""
if seed is None:
jdf = self._jdf.sample(withReplacement, fraction)
else:
jdf = self._jdf.sample(withReplacement, fraction, seed)
return DataFrame(jdf, self.sql_ctx)

def addColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
>>> df.addColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.As(colName))
return self.select('*', col.alias(colName))


# Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ def sum(self):
group."""


SCALA_METHOD_MAPPINGS = {
'=': '$eq',
'>': '$greater',
'<': '$less',
'+': '$plus',
'-': '$minus',
'*': '$times',
'/': '$div',
'!': '$bang',
'@': '$at',
'#': '$hash',
'%': '$percent',
'^': '$up',
'&': '$amp',
'~': '$tilde',
'?': '$qmark',
'|': '$bar',
'\\': '$bslash',
':': '$colon',
}


def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
return jcol


def _scalaMethod(name):
""" Translate operators into methodName in Scala
>>> _scalaMethod('+')
'$plus'
>>> _scalaMethod('>=')
'$greater$eq'
>>> _scalaMethod('cast')
'cast'
"""
return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)


def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
jc = getattr(self._jc, _scalaMethod(name))()
jc = getattr(self._jc, name)()
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _


def _dsl_op(name, doc=''):
def _(self):
jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
"""
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
njc = getattr(self._jc, _scalaMethod(name))(jc)
njc = getattr(self._jc, name)(jc)
return Column(njc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
"""
def _(self, other):
jother = _create_column_from_literal(other)
jc = getattr(jother, _scalaMethod(name))(self._jc)
jc = getattr(jother, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2513,34 +2461,33 @@ def __init__(self, jc, sql_ctx=None):
super(Column, self).__init__(jc, sql_ctx)

# arithmetic operators
__neg__ = _unary_op("unary_-")
__add__ = _bin_op("+")
__sub__ = _bin_op("-")
__mul__ = _bin_op("*")
__div__ = _bin_op("/")
__mod__ = _bin_op("%")
__radd__ = _bin_op("+")
__rsub__ = _reverse_op("-")
__rmul__ = _bin_op("*")
__rdiv__ = _reverse_op("/")
__rmod__ = _reverse_op("%")
__abs__ = _unary_op("abs")
__neg__ = _dsl_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__mod__ = _bin_op("mod")
__radd__ = _bin_op("plus")
__rsub__ = _reverse_op("minus")
__rmul__ = _bin_op("multiply")
__rdiv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")

# comparison operators
__eq__ = _bin_op("===")
__ne__ = _bin_op("!==")
__lt__ = _bin_op("<")
__le__ = _bin_op("<=")
__ge__ = _bin_op(">=")
__gt__ = _bin_op(">")
__eq__ = _bin_op("equalTo")
__ne__ = _bin_op("notEqual")
__lt__ = _bin_op("lt")
__le__ = _bin_op("leq")
__ge__ = _bin_op("geq")
__gt__ = _bin_op("gt")

# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
__and__ = _bin_op('&&')
__or__ = _bin_op('||')
__invert__ = _unary_op('unary_!')
__rand__ = _bin_op("&&")
__ror__ = _bin_op("||")
__and__ = _bin_op('and')
__or__ = _bin_op('or')
__invert__ = _dsl_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")

# container operators
__contains__ = _bin_op("contains")
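
The net effect of this hunk and the removals above is that the Python operator dunders now call JVM Column methods by their plain names (plus, equalTo, lt, ...) instead of translating symbols through SCALA_METHOD_MAPPINGS. A minimal, Spark-free sketch of the same factory pattern (all names below are stand-ins, not pyspark classes):

class FakeJavaColumn(object):
    """Stand-in for the py4j proxy of a JVM Column; just builds an expression string."""
    def __init__(self, expr):
        self.expr = expr
    def plus(self, other):
        return FakeJavaColumn("(%s + %s)" % (self.expr, other))
    def equalTo(self, other):
        return FakeJavaColumn("(%s = %s)" % (self.expr, other))


def _bin_op(name):
    # Same shape as the real _bin_op: look the JVM method up by its plain name.
    def _(self, other):
        jc = other._jc if isinstance(other, FakeColumn) else other
        return FakeColumn(getattr(self._jc, name)(jc))
    return _


class FakeColumn(object):
    def __init__(self, jc):
        self._jc = jc
    __add__ = _bin_op("plus")
    __eq__ = _bin_op("equalTo")


age = FakeColumn(FakeJavaColumn("age"))
print((age + 10)._jc.expr)    # (age + 10)
print((age == 30)._jc.expr)   # (age = 30)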
@@ -2582,24 +2529,20 @@ def substr(self, startPos, length):
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")

# `as` is keyword
def alias(self, alias):
"""Return a alias for this column
>>> df.age.As("age2").collect()
[Row(age2=2), Row(age2=5)]
>>> df.age.alias("age2").collect()
[Row(age2=2), Row(age2=5)]
"""
return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
As = alias

def cast(self, dataType):
""" Convert the column into type `dataType`
>>> df.select(df.age.cast("string").As('ages')).collect()
>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
>>> df.select(df.age.cast(StringType()).As('ages')).collect()
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _(col):
return staticmethod(_)


class UserDefinedFunction(object):
def __init__(self, func, returnType):
self.func = func
self.returnType = returnType
self._broadcast = None
self._judf = self._create_judf()

def _create_judf(self):
f = self.func  # capture in a local so the closure below doesn't reference `self`
func = lambda _, it: imap(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
includes, sc.pythonExec, broadcast_vars,
sc._javaAccumulator, jdt)
return judf

def __del__(self):
if self._broadcast is not None:
self._broadcast.unpersist()
self._broadcast = None

def __call__(self, *cols):
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
return Column(jc)
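
A hedged end-to-end sketch of the new class (1.3-era API: inferSchema is used to build a small test DataFrame; the app name and data are made up for illustration):

from pyspark import SparkContext
from pyspark.sql import SQLContext, IntegerType, Row, UserDefinedFunction

sc = SparkContext("local", "udf-column-demo")       # hypothetical local context
sqlCtx = SQLContext(sc)
df = sqlCtx.inferSchema(sc.parallelize(
    [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]))
slen = UserDefinedFunction(lambda s: len(s), IntegerType())
# Applying the UDF to a column yields a Column expression, usable in select().
print(df.select(slen(df.name).alias('slen')).collect())   # [Row(slen=5), Row(slen=3)]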


class Dsl(object):
"""
A collection of builtin aggregators
@@ -2659,7 +2636,7 @@ def countDistinct(col, *cols):
""" Return a new Column for distinct count of (col, *cols)
>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
>>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ def approxCountDistinct(col, rsd=None):
""" Return a new Column for approxiate distinct count of (col, *cols)
>>> from pyspark.sql import Dsl
>>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
>>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ def approxCountDistinct(col, rsd=None):
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
return Column(jc)

@staticmethod
def udf(f, returnType=StringType()):
"""Create a user defined function (UDF)
>>> slen = Dsl.udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]
"""
return UserDefinedFunction(f, returnType)


def _test():
import doctest
