diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1dc2fec0ae5c8..ddd1109afdb9b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -23,8 +23,6 @@ from threading import Lock from tempfile import NamedTemporaryFile -from py4j.java_collections import ListConverter - from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -643,7 +641,6 @@ def union(self, rdds): rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self._gateway._gateway_client) return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): @@ -846,13 +843,12 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ if partitions is None: partitions = range(rdd._jrdd.partitions().size()) - javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions, allowLocal) return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 45bc38f7e61f8..1d238231a8377 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -23,11 +23,20 @@ import socket import platform from subprocess import Popen, PIPE + from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_collections import ListConverter from pyspark.serializers import read_int +# patching ListConverter, or it will convert bytearray into Java ArrayList +def can_convert_list(self, obj): + return isinstance(obj, list) + +ListConverter.can_convert = can_convert_list + + def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) @@ -92,7 +101,7 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index ba6058978880a..61a6e4776d9e5 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -23,7 +23,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import ListConverter, JavaArray, JavaList +from py4j.java_collections import JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -76,7 +76,7 @@ def _py2java(sc, obj): elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): - obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, long, float, bool, bytes, unicode)): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c90afc326ca0e..58a728ec2996f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -25,7 +25,6 @@ from itertools import imap as map from py4j.protocol import Py4JError -from py4j.java_collections import MapConverter from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer @@ -442,15 +441,13 @@ def load(self, path=None, source=None, schema=None, **options): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.load(source, joptions) + df = self._ssql_ctx.load(source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, joptions) + df = self._ssql_ctx.load(source, scala_datatype, options) return DataFrame(df, self) def createExternalTable(self, tableName, path=None, source=None, @@ -471,16 +468,14 @@ def createExternalTable(self, tableName, path=None, source=None, if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + df = self._ssql_ctx.createExternalTable(tableName, source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, - joptions) + options) return DataFrame(df, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d70c5b0a6930c..e8962c76852ee 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,8 +25,6 @@ else: from itertools import imap as map -from py4j.java_collections import ListConverter, MapConverter - from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -186,9 +184,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self.sql_ctx._sc._gateway._gateway_client) - self._jdf.saveAsTable(tableName, source, jmode, joptions) + self._jdf.saveAsTable(tableName, source, jmode, options) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. @@ -211,9 +207,7 @@ def save(self, path=None, source=None, mode="error", **options): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) - self._jdf.save(source, jmode, joptions) + self._jdf.save(source, jmode, options) @property def schema(self): @@ -819,7 +813,6 @@ def fillna(self, value, subset=None): value = float(value) if isinstance(value, dict): - value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) return DataFrame(self._jdf.na().fill(value), self.sql_ctx) elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) @@ -932,9 +925,7 @@ def agg(self, *exprs): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) + jdf = self._jdf.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" @@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None): """ if converter: cols = [converter(c) for c in cols] - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - return sc._jvm.PythonUtils.toSeq(jcols) + return sc._jvm.PythonUtils.toSeq(cols) def _unary_op(name, doc="unary operator"): diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4590c58839266..ac5ba69e8dbbb 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -20,7 +20,6 @@ import os import sys -from py4j.java_collections import ListConverter from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf @@ -305,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): rdds = [self._sc.parallelize(input) for input in rdds] self._check_serializers(rdds) - jrdds = ListConverter().convert([r._jrdd for r in rdds], - SparkContext._gateway._gateway_client) - queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) @@ -322,8 +319,7 @@ def transform(self, dstreams, transformFunc): the transform function parameter will be the same as the order of corresponding DStreams in the list. """ - jdstreams = ListConverter().convert([d._jdstream for d in dstreams], - SparkContext._gateway._gateway_client) + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, lambda t, *rdds: transformFunc(rdds).map(lambda x: x), @@ -346,6 +342,5 @@ def union(self, *dstreams): if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") first = dstreams[0] - jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], - SparkContext._gateway._gateway_client) + jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 7a7b6e1d9a527..8d610d6569b4a 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,8 +15,7 @@ # limitations under the License. # -from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JError, Py4JJavaError +from py4j.java_gateway import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer @@ -57,8 +56,6 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, }) if not isinstance(topics, dict): raise TypeError("topics should be dict") - jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) - jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: @@ -66,7 +63,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 06d22154373bc..33f958a601f3a 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,8 +24,6 @@ import struct from functools import reduce -from py4j.java_collections import MapConverter - from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -581,11 +579,9 @@ def test_kafka_stream(self): """Test the Python Kafka stream API.""" topic = "topic1" sendData = {"a": 3, "b": 5, "c": 10} - jSendData = MapConverter().convert(sendData, - self.ssc.sparkContext._gateway._gateway_client) self._kafkaTestUtils.createTopic(topic) - self._kafkaTestUtils.sendMessages(topic, jSendData) + self._kafkaTestUtils.sendMessages(topic, sendData) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1},