From 67882d2d4ebfad955b07cf0020c726ea5a153864 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 13 Dec 2016 03:47:10 +0000
Subject: [PATCH 1/3] Consume the returned local iterator immediately to prevent a timeout on the socket serving the data.

---
 .../org/apache/spark/api/python/PythonRDD.scala | 11 ++++++++++-
 python/pyspark/rdd.py                           | 10 +++++++++-
 python/pyspark/sql/dataframe.py                 |  9 ++++++++-
 python/pyspark/sql/tests.py                     | 12 ++++++++++++
 python/pyspark/tests.py                         | 12 ++++++++++++
 5 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0ca91b9bf86c6..492f754e51352 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -454,7 +454,16 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
-    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+    // To avoid a timeout while writing the first element when materialization of the RDD
+    // takes too long, we fetch the first element eagerly and reconstruct the iterator.
+    val localIter = rdd.toLocalIterator
+    val iter = if (localIter.hasNext) {
+      val peek = localIter.next()
+      Seq(peek).iterator ++ localIter
+    } else {
+      localIter
+    }
+    serveIterator(iter, s"serve toLocalIterator")
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 9e05da89af082..95d597017a204 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -135,6 +135,9 @@ def _load_from_socket(port, serializer):
         break
     if not sock:
         raise Exception("could not open socket")
+    # The RDD materialization time is unpredictable; if we set a timeout on the socket
+    # read, it is very likely to fail. See SPARK-18281.
+    sock.settimeout(None)
     try:
         rf = sock.makefile("rb", 65536)
         for item in serializer.load_stream(rf):
@@ -2349,7 +2352,12 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self.context) as css:
             port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        return _load_from_socket(port, self._jrdd_deserializer)
+        # A timeout is set on the connecting socket, and the connection only begins when we
+        # start to consume the first element. If the returned iterator is not consumed
+        # immediately, the connection will time out.
+        iter = _load_from_socket(port, self._jrdd_deserializer)
+        peek = next(iter)
+        return chain([peek], iter)
 
 
 def _prepare_for_python_RDD(sc, command):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b9d90384e3e2c..c769ba3f94994 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -26,6 +26,8 @@
 else:
     from itertools import imap as map
 
+from itertools import chain
+
 from pyspark import copy_func, since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -403,7 +405,12 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self._sc) as css:
             port = self._jdf.toPythonIterator()
-        return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+        # A timeout is set on the connecting socket, and the connection only begins when we
+        # start to consume the first element. If the returned iterator is not consumed
+        # immediately, the connection will time out.
+        iter = _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+        peek = next(iter)
+        return chain([peek], iter)
 
     @ignore_unicode_prefix
     @since(1.3)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index af7d52cdace87..568105caf0678 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -558,6 +558,18 @@ def test_create_dataframe_from_objects(self):
         self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
         self.assertEqual(df.first(), Row(key=1, value="1"))
 
+    def test_to_localiterator_for_dataframe(self):
+        from time import sleep
+        df = self.spark.createDataFrame([[1], [2], [3]])
+        it = df.toLocalIterator()
+        sleep(5)
+        self.assertEqual([Row(_1=1), Row(_1=2), Row(_1=3)], sorted(it))
+
+        df2 = df.repartition(1000)
+        it2 = df2.toLocalIterator()
+        sleep(5)
+        self.assertEqual([Row(_1=1), Row(_1=2), Row(_1=3)], sorted(it2))
+
     def test_select_null_literal(self):
         df = self.spark.sql("select null as col")
         self.assertEqual(Row(col=None), df.first())
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 89fce8ab25baf..fe314c54a1b18 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -502,6 +502,18 @@ def test_sum(self):
         self.assertEqual(0, self.sc.emptyRDD().sum())
         self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
 
+    def test_to_localiterator(self):
+        from time import sleep
+        rdd = self.sc.parallelize([1, 2, 3])
+        it = rdd.toLocalIterator()
+        sleep(5)
+        self.assertEqual([1, 2, 3], sorted(it))
+
+        rdd2 = rdd.repartition(1000)
+        it2 = rdd2.toLocalIterator()
+        sleep(5)
+        self.assertEqual([1, 2, 3], sorted(it2))
+
     def test_save_as_textfile_with_unicode(self):
         # Regression test for SPARK-970
         x = u"\u00A1Hola, mundo!"

From e3e20727241bca70173e344d6bf13955bcfc8ce7 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 17 Dec 2016 03:19:35 +0000
Subject: [PATCH 2/3] Address comment.

---
 .../org/apache/spark/api/python/PythonRDD.scala | 11 +----------
 python/pyspark/rdd.py                           | 16 ++++------------
 python/pyspark/sql/dataframe.py                 |  9 +--------
 3 files changed, 6 insertions(+), 30 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 492f754e51352..0ca91b9bf86c6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -454,16 +454,7 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
-    // To avoid a timeout while writing the first element when materialization of the RDD
-    // takes too long, we fetch the first element eagerly and reconstruct the iterator.
-    val localIter = rdd.toLocalIterator
-    val iter = if (localIter.hasNext) {
-      val peek = localIter.next()
-      Seq(peek).iterator ++ localIter
-    } else {
-      localIter
-    }
-    serveIterator(iter, s"serve toLocalIterator")
+    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 95d597017a204..45ec7f4c34844 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -138,12 +138,9 @@ def _load_from_socket(port, serializer):
     # The RDD materialization time is unpredictable; if we set a timeout on the socket
     # read, it is very likely to fail. See SPARK-18281.
     sock.settimeout(None)
-    try:
-        rf = sock.makefile("rb", 65536)
-        for item in serializer.load_stream(rf):
-            yield item
-    finally:
-        sock.close()
+    # The socket will be automatically closed when garbage-collected.
+    rf = sock.makefile("rb", 65536)
+    return serializer.load_stream(sock.makefile("rb", 65536))
 
 
 def ignore_unicode_prefix(f):
@@ -2352,12 +2349,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self.context) as css:
             port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        # A timeout is set on the connecting socket, and the connection only begins when we
-        # start to consume the first element. If the returned iterator is not consumed
-        # immediately, the connection will time out.
-        iter = _load_from_socket(port, self._jrdd_deserializer)
-        peek = next(iter)
-        return chain([peek], iter)
+        return _load_from_socket(port, self._jrdd_deserializer)
 
 
 def _prepare_for_python_RDD(sc, command):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c769ba3f94994..b9d90384e3e2c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -26,8 +26,6 @@
 else:
     from itertools import imap as map
 
-from itertools import chain
-
 from pyspark import copy_func, since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -405,12 +403,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self._sc) as css:
             port = self._jdf.toPythonIterator()
-        # A timeout is set on the connecting socket, and the connection only begins when we
-        # start to consume the first element. If the returned iterator is not consumed
-        # immediately, the connection will time out.
-        iter = _load_from_socket(port, BatchedSerializer(PickleSerializer()))
-        peek = next(iter)
-        return chain([peek], iter)
+        return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
 
     @ignore_unicode_prefix
     @since(1.3)

From e02695efab1c6b42578e52c74c18b60cb74dec78 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 19 Dec 2016 23:43:00 +0000
Subject: [PATCH 3/3] Address comments.

---
 python/pyspark/rdd.py       |  1 -
 python/pyspark/sql/tests.py | 12 ------------
 2 files changed, 13 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 45ec7f4c34844..b384b2b507332 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -139,7 +139,6 @@ def _load_from_socket(port, serializer):
     # read, it is very likely to fail. See SPARK-18281.
     sock.settimeout(None)
     # The socket will be automatically closed when garbage-collected.
-    rf = sock.makefile("rb", 65536)
     return serializer.load_stream(sock.makefile("rb", 65536))
 
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 568105caf0678..af7d52cdace87 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -558,18 +558,6 @@ def test_create_dataframe_from_objects(self):
         self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
         self.assertEqual(df.first(), Row(key=1, value="1"))
 
-    def test_to_localiterator_for_dataframe(self):
-        from time import sleep
-        df = self.spark.createDataFrame([[1], [2], [3]])
-        it = df.toLocalIterator()
-        sleep(5)
-        self.assertEqual([Row(_1=1), Row(_1=2), Row(_1=3)], sorted(it))
-
-        df2 = df.repartition(1000)
-        it2 = df2.toLocalIterator()
-        sleep(5)
-        self.assertEqual([Row(_1=1), Row(_1=2), Row(_1=3)], sorted(it2))
-
     def test_select_null_literal(self):
         df = self.spark.sql("select null as col")
         self.assertEqual(Row(col=None), df.first())
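
Note for readers following the series: the first revision's Python-side fix is a "peek and re-chain" pattern, which eagerly pulls the first element out of a lazy iterator so that the underlying socket connection is established right away, then splices that element back in front of the remaining stream. Below is a minimal, self-contained sketch of that pattern in plain Python; the helper name eagerly_connected_iterator and the toy slow_source generator are made up for illustration and are not part of the PySpark API. The later revisions drop this approach on the Python side and instead disable the read timeout with sock.settimeout(None).

    from itertools import chain

    def eagerly_connected_iterator(lazy_iter):
        # Force the first element now, so any connection or handshake work tied to the
        # first read happens immediately rather than whenever the caller starts
        # iterating (by which time the serving side may already have timed out).
        try:
            first = next(lazy_iter)
        except StopIteration:
            return iter([])
        # Splice the consumed element back in front of the untouched remainder.
        return chain([first], lazy_iter)

    # Example: the generator only "connects" (prints) when its first value is pulled.
    def slow_source():
        print("connection established")
        for i in range(3):
            yield i

    it = eagerly_connected_iterator(slow_source())  # prints "connection established" immediately
    print(list(it))                                 # [0, 1, 2]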