From 58dee20e38ddc2499fab6b2aa50e0d07d82b4235 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 3 Feb 2015 21:54:46 -0800 Subject: [PATCH 1/6] clean up --- python/pyspark/sql.py | 110 ++++++------------ .../scala/org/apache/spark/sql/Column.scala | 17 ++- 2 files changed, 51 insertions(+), 76 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 268c7ef97cffc..091bc6a53a777 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2077,9 +2077,9 @@ def dtypes(self): """Return all column names and their data types as a list. >>> df.dtypes - [(u'age', 'IntegerType'), (u'name', 'StringType')] + [('age', 'integer'), ('name', 'string')] """ - return [(f.name, str(f.dataType)) for f in self.schema().fields] + return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] @property def columns(self): @@ -2263,18 +2263,6 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - def sample(self, withReplacement, fraction, seed=None): - """ Return a new DataFrame by sampling a fraction of rows. - - >>> df.sample(False, 0.5, 10).collect() - [Row(age=2, name=u'Alice')] - """ - if seed is None: - jdf = self._jdf.sample(withReplacement, fraction) - else: - jdf = self._jdf.sample(withReplacement, fraction, seed) - return DataFrame(jdf, self.sql_ctx) - def addColumn(self, colName, col): """ Return a new :class:`DataFrame` by adding a column. @@ -2376,28 +2364,6 @@ def sum(self): group.""" -SCALA_METHOD_MAPPINGS = { - '=': '$eq', - '>': '$greater', - '<': '$less', - '+': '$plus', - '-': '$minus', - '*': '$times', - '/': '$div', - '!': '$bang', - '@': '$at', - '#': '$hash', - '%': '$percent', - '^': '$up', - '&': '$amp', - '~': '$tilde', - '?': '$qmark', - '|': '$bar', - '\\': '$bslash', - ':': '$colon', -} - - def _create_column_from_literal(literal): sc = SparkContext._active_spark_context return sc._jvm.Dsl.lit(literal) @@ -2416,23 +2382,18 @@ def _to_java_column(col): return jcol -def _scalaMethod(name): - """ Translate operators into methodName in Scala - - >>> _scalaMethod('+') - '$plus' - >>> _scalaMethod('>=') - '$greater$eq' - >>> _scalaMethod('cast') - 'cast' - """ - return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) - - def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): - jc = getattr(self._jc, _scalaMethod(name))() + jc = getattr(self._jc, name)() + return Column(jc, self.sql_ctx) + _.__doc__ = doc + return _ + + +def _dsl_op(name, doc=''): + def _(self): + jc = getattr(self._sc._jvm.Dsl, name)(self._jc) return Column(jc, self.sql_ctx) _.__doc__ = doc return _ @@ -2443,7 +2404,7 @@ def _bin_op(name, doc="binary operator"): """ def _(self, other): jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, _scalaMethod(name))(jc) + njc = getattr(self._jc, name)(jc) return Column(njc, self.sql_ctx) _.__doc__ = doc return _ @@ -2454,7 +2415,7 @@ def _reverse_op(name, doc="binary operator"): """ def _(self, other): jother = _create_column_from_literal(other) - jc = getattr(jother, _scalaMethod(name))(self._jc) + jc = getattr(jother, name)(self._jc) return Column(jc, self.sql_ctx) _.__doc__ = doc return _ @@ -2481,34 +2442,33 @@ def __init__(self, jc, sql_ctx=None): super(Column, self).__init__(jc, sql_ctx) # arithmetic operators - __neg__ = _unary_op("unary_-") - __add__ = _bin_op("+") - __sub__ = _bin_op("-") - __mul__ = _bin_op("*") - __div__ = _bin_op("/") - __mod__ = _bin_op("%") - __radd__ = _bin_op("+") - 
__rsub__ = _reverse_op("-") - __rmul__ = _bin_op("*") - __rdiv__ = _reverse_op("/") - __rmod__ = _reverse_op("%") - __abs__ = _unary_op("abs") + __neg__ = _dsl_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") # logistic operators - __eq__ = _bin_op("===") - __ne__ = _bin_op("!==") - __lt__ = _bin_op("<") - __le__ = _bin_op("<=") - __ge__ = _bin_op(">=") - __gt__ = _bin_op(">") + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") # `and`, `or`, `not` cannot be overloaded in Python, # so use bitwise operators as boolean operators - __and__ = _bin_op('&&') - __or__ = _bin_op('||') - __invert__ = _unary_op('unary_!') - __rand__ = _bin_op("&&") - __ror__ = _bin_op("||") + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _dsl_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") # container operators __contains__ = _bin_op("contains") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ddce77deb83e1..33ec03e84d965 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -128,7 +128,6 @@ trait Column extends DataFrame { */ def unary_! : Column = exprToColumn(Not(expr)) - /** * Equality test. * {{{ @@ -173,6 +172,22 @@ trait Column extends DataFrame { Not(EqualTo(expr, o.expr)) } + /** + * Inequality test. + * {{{ + * // Scala: + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df.filter( col("colA").notEqual(col("colB")) ); + * }}} + */ + def notEqual(other: Any): Column = constructColumn(other) { o => + Not(EqualTo(expr, o.expr)) + } + /** * Greater than. 
* {{{ From 7bccc3bb7d1e829a3f8d91f508e12899d225235e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 3 Feb 2015 22:37:32 -0800 Subject: [PATCH 2/6] python udf --- python/pyspark/sql.py | 49 +++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dsl.scala | 23 ++++++++- .../spark/sql/UserDefinedFunction.scala | 27 ++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 091bc6a53a777..8b9d96e0b880e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2554,6 +2554,45 @@ def _(col): return staticmethod(_) +class UserDefinedFunction(object): + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func + sc = SparkContext._active_spark_context + # TODO(davies): refactor + func = lambda _, it: imap(lambda x: f(*x), it) + command = (func, None, + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes, + sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) + return judf + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.Dsl.toColumns(jcols)) + return Column(jc) + + class Dsl(object): """ A collections of builtin aggregators @@ -2612,6 +2651,16 @@ def approxCountDistinct(col, rsd=None): jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) return Column(jc) + @staticmethod + def udf(f, returnType=StringType()): + """Create a user defined function (UDF) + + >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).As('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala index 8cf59f0a1f099..3480b57a67296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql -import java.util.{List => JList} +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -177,6 +181,23 @@ object Dsl { cols.toList.toSeq } + /** + * This is a private API for Python + * TODO: move this to a private package + */ + def pythonUDF( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + broadcastVars: 
JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType): UserDefinedPythonFunction = { + UserDefinedPythonFunction(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, + accumulator, dataType) + } + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 8d7c2a1b8339e..c60d4070942a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -17,7 +17,13 @@ package org.apache.spark.sql +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType /** @@ -37,3 +43,24 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) { Column(ScalaUdf(f, dataType, exprs.map(_.expr))) } } + +/** + * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]]. + * This is used by Python API. + */ +private[sql] case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType) { + + def apply(exprs: Column*): Column = { + val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, + accumulator, dataType, exprs.map(_.expr)) + Column(udf) + } +} From f99b2e12ccb03fa9e5803d7379535f7dc54dcab4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 4 Feb 2015 01:01:45 -0800 Subject: [PATCH 3/6] address comments --- python/pyspark/rdd.py | 34 ++++++++++++++---------- python/pyspark/sql.py | 61 ++++++++++++------------------------------- 2 files changed, 36 insertions(+), 59 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2f8a0edfe9644..c67eee7dfd09e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2162,6 +2162,25 @@ def toLocalIterator(self): yield row +def _prepare_for_python_RDD(sc, command, obj=None): + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + # tracking the life cycle by obj + if obj is not None: + obj._broadcast = broadcast + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + return pickled_command, broadcast_vars, env, includes + + class PipelinedRDD(RDD): """ @@ -2228,20 +2247,7 @@ def _jrdd(self): command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - # the serialized command will be compressed by broadcast - ser = 
CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - self._broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(self._broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx._gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - env = MapConverter().convert(self.ctx.environment, - self.ctx._gateway._gateway_client) - includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self.ctx, command) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), bytearray(pickled_command), env, includes, self.preservesPartitioning, diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 244baf7639da9..404384225dee5 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -51,7 +51,7 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.context import SparkContext -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _prepare_for_python_RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ CloudPickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -1274,22 +1274,9 @@ def registerFunction(self, name, f, returnType=StringType()): [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, - AutoBatchedSerializer(PickleSerializer()), - AutoBatchedSerializer(PickleSerializer())) - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - broadcast = self._sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self._sc._pickled_broadcast_vars], - self._sc._gateway._gateway_client) - self._sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(self._sc.environment, - self._sc._gateway._gateway_client) - includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self._sc, command) self._ssql_ctx.udf().registerPython(name, bytearray(pickled_command), env, @@ -2187,7 +2174,7 @@ def select(self, *cols): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.select('name', 'age').collect() [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] - >>> df.select(df.name, (df.age + 10).As('age')).collect() + >>> df.select(df.name, (df.age + 10).alias('age')).collect() [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ if not cols: @@ -2268,7 +2255,7 @@ def addColumn(self, colName, col): >>> df.addColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ - return self.select('*', col.As(colName)) + return self.select('*', col.alias(colName)) # Having SchemaRDD for backward compatibility (for docs) @@ -2509,24 +2496,20 @@ def substr(self, startPos, length): isNull = _unary_op("isNull", "True if the current expression is null.") isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - # `as` is keyword def alias(self, alias): """Return a alias for this column - >>> df.age.As("age2").collect() - [Row(age2=2), Row(age2=5)] >>> df.age.alias("age2").collect() 
[Row(age2=2), Row(age2=5)] """ return Column(getattr(self._jc, "as")(alias), self.sql_ctx) - As = alias def cast(self, dataType): """ Convert the column into type `dataType` - >>> df.select(df.age.cast("string").As('ages')).collect() + >>> df.select(df.age.cast("string").alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).As('ages')).collect() + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] """ if self.sql_ctx is None: @@ -2560,24 +2543,12 @@ def __init__(self, func, returnType): self._judf = self._create_judf() def _create_judf(self): - f = self.func - sc = SparkContext._active_spark_context - # TODO(davies): refactor + f = self.func # put it in closure `func` func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, - AutoBatchedSerializer(PickleSerializer()), - AutoBatchedSerializer(PickleSerializer())) - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - broadcast = sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in sc._pickled_broadcast_vars], - sc._gateway._gateway_client) - sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) - includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes, @@ -2625,7 +2596,7 @@ def countDistinct(col, *cols): """ Return a new Column for distinct count of (col, *cols) >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect() + >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() [Row(c=2)] """ sc = SparkContext._active_spark_context @@ -2640,7 +2611,7 @@ def approxCountDistinct(col, rsd=None): """ Return a new Column for approxiate distinct count of (col, *cols) >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect() + >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() [Row(c=2)] """ sc = SparkContext._active_spark_context @@ -2655,7 +2626,7 @@ def udf(f, returnType=StringType()): """Create a user defined function (UDF) >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) - >>> df.select(slen(df.name).As('slen')).collect() + >>> df.select(slen(df.name).alias('slen')).collect() [Row(slen=5), Row(slen=3)] """ return UserDefinedFunction(f, returnType) From f0a31217ed7d837dc98ff974f8417bb456fd49af Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 4 Feb 2015 09:52:43 -0800 Subject: [PATCH 4/6] track life cycle of broadcast --- python/pyspark/rdd.py | 6 +++--- python/pyspark/sql.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c67eee7dfd09e..6e029bf7f13fc 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2247,12 +2247,12 @@ def _jrdd(self): command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self.ctx, 
command) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), + bytearray(pickled_cmd), env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator) + bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() if profiler: diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 404384225dee5..02313412206b7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1276,13 +1276,13 @@ def registerFunction(self, name, f, returnType=StringType()): func = lambda _, it: imap(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(self._sc, command) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_command), + bytearray(pickled_cmd), env, includes, self._sc.pythonExec, - broadcast_vars, + bvars, self._sc._javaAccumulator, returnType.json()) @@ -2540,6 +2540,7 @@ class UserDefinedFunction(object): def __init__(self, func, returnType): self.func = func self.returnType = returnType + self._broadcast = None self._judf = self._create_judf() def _create_judf(self): @@ -2548,13 +2549,18 @@ def _create_judf(self): ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) sc = SparkContext._active_spark_context - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + def __call__(self, *cols): sc = SparkContext._active_spark_context jcols = ListConverter().convert([_to_java_column(c) for c in cols], From 440f76922bba64b701c1cab2f762e6811d0a558e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 4 Feb 2015 10:09:38 -0800 Subject: [PATCH 5/6] address comments --- python/pyspark/sql.py | 5 +++-- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../main/scala/org/apache/spark/sql/Dsl.scala | 17 ----------------- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 02313412206b7..27588e8c5bbb5 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2552,8 +2552,9 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes, - sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) + judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) return judf def __del__(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 33ec03e84d965..4c2aeadae9492 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -165,7 +165,7 @@ trait Column extends DataFrame { * * // Java: * import static org.apache.spark.sql.Dsl.*; - * df.filter( not(col("colA").equalTo(col("colB"))) ); + * df.filter( col("colA").notEqual(col("colB")) ); * }}} */ def !== (other: Any): Column = constructColumn(other) { o => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala index 3480b57a67296..a2fd97783ff7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -181,23 +181,6 @@ object Dsl { cols.toList.toSeq } - /** - * This is a private API for Python - * TODO: move this to a private package - */ - def pythonUDF( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType): UserDefinedPythonFunction = { - UserDefinedPythonFunction(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, - accumulator, dataType) - } - ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// From d25069257d6a195f6d7c3b848bc32f9764a7f6b1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 4 Feb 2015 13:56:23 -0800 Subject: [PATCH 6/6] fix conflict --- python/pyspark/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 5ee8e7a954612..5b56b36bdcdb7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2599,7 +2599,7 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.Dsl.toColumns(jcols)) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) return Column(jc)
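
A note on the operator dispatch in PATCH 1: rather than translating Python operator symbols into Scala's encoded method names ($plus, $greater$eq, ...), the Column dunder methods now call the plain JVM method names (plus, minus, equalTo, lt, ...) directly, and the two operations with no instance method on the JVM column (negation and logical not) are routed through the Dsl object by _dsl_op. Below is a minimal, self-contained sketch of that factory pattern; FakeJVMColumn is a stand-in for the py4j proxy of org.apache.spark.sql.Column and exists only for illustration.

```python
# Sketch of the _bin_op factory pattern from PATCH 1. FakeJVMColumn stands in
# for the py4j proxy of the JVM Column; its method names (plus, multiply,
# equalTo) mirror the names the patch switches to, but the class is fake.

class FakeJVMColumn(object):
    def __init__(self, value):
        self.value = value

    def plus(self, other):
        return FakeJVMColumn(self.value + other)

    def multiply(self, other):
        return FakeJVMColumn(self.value * other)

    def equalTo(self, other):
        return FakeJVMColumn(self.value == other)


def _bin_op(name, doc="binary operator"):
    """Create a dunder method that forwards to the named JVM method."""
    def _(self, other):
        jc = other._jc if isinstance(other, Column) else other
        return Column(getattr(self._jc, name)(jc))
    _.__doc__ = doc
    return _


class Column(object):
    def __init__(self, jc):
        self._jc = jc

    __add__ = _bin_op("plus")
    __mul__ = _bin_op("multiply")
    __eq__ = _bin_op("equalTo")


col = Column(FakeJVMColumn(3))
print((col + 2)._jc.value)   # 5
print((col * 4)._jc.value)   # 12
print((col == 3)._jc.value)  # True
```

Dispatching on the real method names is what lets the patch delete the SCALA_METHOD_MAPPINGS table and the _scalaMethod helper on the Python side.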
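
The Python UDF plumbing added in PATCH 2 and refactored in PATCH 3 serializes the wrapped function with CloudPickle and, when the pickled command exceeds roughly 1 MB, ships it through a broadcast variable instead of inlining it; _prepare_for_python_RDD centralizes that logic so PipelinedRDD._jrdd, registerFunction and UserDefinedFunction all share it. The sketch below reproduces only the size-based fallback decision, using the standard pickle module and a stub broadcast because the real path needs a live SparkContext; FakeBroadcast and prepare_command are illustrative names, not Spark API.

```python
import pickle


class FakeBroadcast(object):
    """Stand-in for sc.broadcast(value); only the interface matters here."""
    def __init__(self, value):
        self.value = value


def prepare_command(command, broadcast=FakeBroadcast, limit=1 << 20):
    """Pickle `command`; above `limit` bytes, wrap it in a broadcast first.

    Mirrors the decision made by _prepare_for_python_RDD (PATCH 3): Spark's
    real broadcast serializes to a small handle, so re-pickling it keeps the
    job description small even for big closures.
    """
    pickled = pickle.dumps(command)
    bvar = None
    if len(pickled) > limit:
        bvar = broadcast(pickled)     # large command rides a broadcast
        pickled = pickle.dumps(bvar)  # what gets sent in its place
    return pickled, bvar


small, b1 = prepare_command(("func", None, "ser", "ser"))
assert b1 is None                     # small commands are sent inline

big, b2 = prepare_command(("x" * (2 << 20),))
assert b2 is not None                 # oversized commands go via broadcast
```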
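
For the user-facing side, the doctests added in these patches boil down to the following usage. This block is not self-contained: it assumes the pyspark.sql doctest fixtures of that era (a sqlCtx and a df with rows Row(age=2, name=u'Alice') and Row(age=5, name=u'Bob')), and it assumes IntegerType is importable from the single-module pyspark.sql as it was at the time.

```python
from pyspark.sql import Dsl, IntegerType

# Column-expression style (PATCH 2): wrap a Python function as a UDF and
# apply it like any other column expression.
slen = Dsl.udf(lambda s: len(s), IntegerType())
df.select(slen(df.name).alias('slen')).collect()
# [Row(slen=5), Row(slen=3)]

# SQL style: registerFunction goes through the same _prepare_for_python_RDD
# path after PATCH 3 and makes the function callable from SQL text.
sqlCtx.registerFunction("slen", lambda s: len(s), IntegerType())
```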
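
PATCH 4 ties the fallback broadcast's lifetime to the object that owns it: _prepare_for_python_RDD stores the broadcast on the obj it is given, and UserDefinedFunction.__del__ unpersists it once the Python wrapper is garbage collected, so large pickled closures do not linger on the executors. A minimal sketch of that ownership pattern, with stub classes in place of Spark's (FakeBroadcast and FakeUDF are illustrative only):

```python
class FakeBroadcast(object):
    """Stand-in for a Spark broadcast variable."""
    def __init__(self, value):
        self.value = value
        self.persisted = True

    def unpersist(self):
        self.persisted = False


class FakeUDF(object):
    def __init__(self, big_payload):
        self._broadcast = None
        # stand-in for _prepare_for_python_RDD(sc, command, obj=self), which
        # attaches the broadcast to `obj` when the ~1 MB limit is exceeded
        if len(big_payload) > (1 << 20):
            self._broadcast = FakeBroadcast(big_payload)

    def __del__(self):
        # release the shipped copy once the Python wrapper goes away
        if self._broadcast is not None:
            self._broadcast.unpersist()
            self._broadcast = None


udf = FakeUDF("x" * (2 << 20))
bvar = udf._broadcast
del udf                      # triggers __del__ under CPython refcounting
print(bvar.persisted)        # False
```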