enable auto convert
Davies Liu committed Apr 18, 2015
1 parent d850b4b commit cb094ff
Showing 8 changed files with 27 additions and 49 deletions.
6 changes: 1 addition & 5 deletions python/pyspark/context.py
@@ -23,8 +23,6 @@
from threading import Lock
from tempfile import NamedTemporaryFile

- from py4j.java_collections import ListConverter

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
@@ -643,7 +641,6 @@ def union(self, rdds):
rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
- rest = ListConverter().convert(rest, self._gateway._gateway_client)
return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)

def broadcast(self, value):
@@ -846,13 +843,12 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
"""
if partitions is None:
partitions = range(rdd._jrdd.partitions().size())
- javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)

# Implementation note: This is implemented as a mapPartitions followed
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+ port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
allowLocal)
return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))

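A minimal usage sketch (not part of this commit) of the updated runJob path, assuming a local Spark installation of roughly this vintage: `partitions` is now handed over as a plain Python list and converted to a Java collection at the Py4J gateway.

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "runJob-demo")
    rdd = sc.parallelize(range(100), 4)
    # partitionFunc receives each partition's iterator; run it on partitions 0 and 2 only
    print(sc.runJob(rdd, lambda it: [sum(it)], partitions=[0, 2]))
    sc.stop()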
11 changes: 10 additions & 1 deletion python/pyspark/java_gateway.py
@@ -23,11 +23,20 @@
import socket
import platform
from subprocess import Popen, PIPE

from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+ from py4j.java_collections import ListConverter

from pyspark.serializers import read_int


+ # patching ListConverter, or it will convert bytearray into Java ArrayList
+ def can_convert_list(self, obj):
+     return isinstance(obj, list)
+
+ ListConverter.can_convert = can_convert_list


def launch_gateway():
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
@@ -92,7 +101,7 @@ def killChild():
atexit.register(killChild)

# Connect to the gateway
- gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+ gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)

# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
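A standalone sketch (not part of this commit) of why the ListConverter.can_convert patch is needed. With the py4j release bundled at the time, can_convert accepts any sequence type, so once auto_convert=True is on, a bytearray would be shipped to the JVM as a java.util.ArrayList instead of a Java byte[]; restricting the check to real lists keeps binary data binary. This only needs py4j installed, no running gateway:

    from py4j.java_collections import ListConverter

    unpatched = ListConverter()
    print(unpatched.can_convert(bytearray(b"data")))  # True with the stock converter
    print(unpatched.can_convert([1, 2, 3]))           # True

    def can_convert_list(self, obj):
        # only real Python lists should become java.util.ArrayList
        return isinstance(obj, list)

    ListConverter.can_convert = can_convert_list

    print(ListConverter().can_convert(bytearray(b"data")))  # False: stays binary
    print(ListConverter().can_convert([1, 2, 3]))           # True: becomes an ArrayList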
4 changes: 2 additions & 2 deletions python/pyspark/mllib/common.py
@@ -23,7 +23,7 @@
import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
- from py4j.java_collections import ListConverter, JavaArray, JavaList
+ from py4j.java_collections import JavaArray, JavaList

from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
@@ -76,7 +76,7 @@ def _py2java(sc, obj):
elif isinstance(obj, SparkContext):
obj = obj._jsc
elif isinstance(obj, list):
- obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+ obj = [_py2java(sc, x) for x in obj]
elif isinstance(obj, JavaObject):
pass
elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
13 changes: 4 additions & 9 deletions python/pyspark/sql/context.py
@@ -25,7 +25,6 @@
from itertools import imap as map

from py4j.protocol import Py4JError
- from py4j.java_collections import MapConverter

from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -442,15 +441,13 @@ def load(self, path=None, source=None, schema=None, **options):
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
- joptions = MapConverter().convert(options,
- self._sc._gateway._gateway_client)
if schema is None:
- df = self._ssql_ctx.load(source, joptions)
+ df = self._ssql_ctx.load(source, options)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- df = self._ssql_ctx.load(source, scala_datatype, joptions)
+ df = self._ssql_ctx.load(source, scala_datatype, options)
return DataFrame(df, self)

def createExternalTable(self, tableName, path=None, source=None,
@@ -471,16 +468,14 @@ def createExternalTable(self, tableName, path=None, source=None,
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
- joptions = MapConverter().convert(options,
- self._sc._gateway._gateway_client)
if schema is None:
- df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+ df = self._ssql_ctx.createExternalTable(tableName, source, options)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
- joptions)
+ options)
return DataFrame(df, self)

@ignore_unicode_prefix
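Illustrative call (not part of this commit; the path is a placeholder): the keyword options now stay a plain Python dict, and Py4J converts it to a java.util.Map when SQLContext.load crosses the gateway.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local[2]", "load-demo")
    sqlContext = SQLContext(sc)
    # extra keyword arguments are collected into the options dict
    df = sqlContext.load(path="people.json", source="json")
    df.printSchema()
    sc.stop()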
18 changes: 4 additions & 14 deletions python/pyspark/sql/dataframe.py
@@ -25,8 +25,6 @@
else:
from itertools import imap as map

- from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -186,9 +184,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options):
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
- joptions = MapConverter().convert(options,
- self.sql_ctx._sc._gateway._gateway_client)
- self._jdf.saveAsTable(tableName, source, jmode, joptions)
+ self._jdf.saveAsTable(tableName, source, jmode, options)

def save(self, path=None, source=None, mode="error", **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
@@ -211,9 +207,7 @@ def save(self, path=None, source=None, mode="error", **options):
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
- joptions = MapConverter().convert(options,
- self._sc._gateway._gateway_client)
- self._jdf.save(source, jmode, joptions)
+ self._jdf.save(source, jmode, options)

@property
def schema(self):
@@ -819,7 +813,6 @@ def fillna(self, value, subset=None):
value = float(value)

if isinstance(value, dict):
- value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
elif subset is None:
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
@@ -932,9 +925,7 @@ def agg(self, *exprs):
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
- jmap = MapConverter().convert(exprs[0],
- self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(jmap)
+ jdf = self._jdf.agg(exprs[0])
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None):
"""
if converter:
cols = [converter(c) for c in cols]
- jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
- return sc._jvm.PythonUtils.toSeq(jcols)
+ return sc._jvm.PythonUtils.toSeq(cols)


def _unary_op(name, doc="unary operator"):
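A short usage sketch (not part of this commit) of the dict-taking DataFrame APIs touched above; the dicts are now passed through unchanged and auto-converted at the gateway.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local[2]", "dict-args-demo")
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([(1, "a", 5.0), (2, None, 3.0)], ["id", "name", "score"])

    # both calls hand a plain dict straight to the JVM-side DataFrame
    df.fillna({"name": "unknown"}).show()
    df.agg({"score": "max"}).show()
    sc.stop()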
11 changes: 3 additions & 8 deletions python/pyspark/streaming/context.py
@@ -20,7 +20,6 @@
import os
import sys

- from py4j.java_collections import ListConverter
from py4j.java_gateway import java_import, JavaObject

from pyspark import RDD, SparkConf
@@ -305,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None):
rdds = [self._sc.parallelize(input) for input in rdds]
self._check_serializers(rdds)

- jrdds = ListConverter().convert([r._jrdd for r in rdds],
- SparkContext._gateway._gateway_client)
- queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+ queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
if default:
default = default._reserialize(rdds[0]._jrdd_deserializer)
jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
@@ -322,8 +319,7 @@ def transform(self, dstreams, transformFunc):
the transform function parameter will be the same as the order
of corresponding DStreams in the list.
"""
- jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
- SparkContext._gateway._gateway_client)
+ jdstreams = [d._jdstream for d in dstreams]
# change the final serializer to sc.serializer
func = TransformFunction(self._sc,
lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
@@ -346,6 +342,5 @@ def union(self, *dstreams):
if len(set(s._slideDuration for s in dstreams)) > 1:
raise ValueError("All DStreams should have same slide duration")
first = dstreams[0]
- jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
- SparkContext._gateway._gateway_client)
+ jrest = [d._jdstream for d in dstreams[1:]]
return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
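A hedged sketch (not part of this commit) of queueStream, which now builds the JVM-side RDD queue from a plain Python list of _jrdd handles; user code is unchanged.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "queue-demo")
    ssc = StreamingContext(sc, batchDuration=1)

    # a list of RDDs becomes the stream's batch queue
    rdd_queue = [sc.parallelize(range(i * 10, (i + 1) * 10)) for i in range(3)]
    ssc.queueStream(rdd_queue).pprint()

    ssc.start()
    ssc.awaitTermination(5)
    ssc.stop(stopSparkContext=True)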
7 changes: 2 additions & 5 deletions python/pyspark/streaming/kafka.py
@@ -15,8 +15,7 @@
# limitations under the License.
#

- from py4j.java_collections import MapConverter
- from py4j.java_gateway import java_import, Py4JError, Py4JJavaError
+ from py4j.java_gateway import Py4JJavaError

from pyspark.storagelevel import StorageLevel
from pyspark.serializers import PairDeserializer, NoOpSerializer
@@ -57,16 +56,14 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
})
if not isinstance(topics, dict):
raise TypeError("topics should be dict")
- jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
- jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
jlevel = ssc._sc._getJavaStorageLevel(storageLevel)

try:
# Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
.loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
helper = helperClass.newInstance()
- jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
+ jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
except Py4JJavaError as e:
# TODO: use --jar once it also work on driver
if 'ClassNotFoundException' in str(e.java_exception):
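An illustrative call (not part of this commit): topics and kafkaParams are plain dicts that Py4J now converts to java.util.Map for KafkaUtilsPythonHelper.createStream. The ZooKeeper address and group id are placeholders, and the spark-streaming-kafka assembly jar plus a reachable Kafka/ZooKeeper setup are assumed.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils

    sc = SparkContext("local[2]", "kafka-demo")
    ssc = StreamingContext(sc, 2)

    stream = KafkaUtils.createStream(ssc, "zkhost:2181", "demo-consumer-group", {"topic1": 1})
    stream.map(lambda kv: kv[1]).pprint()  # messages arrive as (key, value) pairs

    ssc.start()
    ssc.awaitTermination()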
6 changes: 1 addition & 5 deletions python/pyspark/streaming/tests.py
@@ -24,8 +24,6 @@
import struct
from functools import reduce

- from py4j.java_collections import MapConverter

from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.kafka import KafkaUtils
@@ -581,11 +579,9 @@ def test_kafka_stream(self):
"""Test the Python Kafka stream API."""
topic = "topic1"
sendData = {"a": 3, "b": 5, "c": 10}
- jSendData = MapConverter().convert(sendData,
- self.ssc.sparkContext._gateway._gateway_client)

self._kafkaTestUtils.createTopic(topic)
- self._kafkaTestUtils.sendMessages(topic, jSendData)
+ self._kafkaTestUtils.sendMessages(topic, sendData)

stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
"test-streaming-consumer", {topic: 1},
