[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect() #4923

Status: Closed (wants to merge 4 commits)

Changes from 3 commits
59 changes: 42 additions & 17 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,27 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}

import org.apache.spark.input.PortableDataStream
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}

import org.apache.spark._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

import scala.util.control.NonFatal

private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
@@ -341,21 +342,28 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
* This method will return an iterator of an array that contains all elements in the RDD
* This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*/
Review comment (Contributor):

Even with the updated description, this method's return type could be confusing to new readers.

It might help to add an explicit description of the return type, e.g.

@return the port number of a local socket which serves the data collected from this job.

We should also document the lifecycle of the server socket created by this method: what happens if a client does not consume the whole iterator? Is it the caller's responsibility to close the socket at that point?
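
A sketch of the kind of doc block this comment is asking for; the wording is illustrative only, and the socket-lifecycle details are inferred from the serveIterator implementation further down rather than from any merged text:

  /**
   * Adapter for calling SparkContext#runJob from Python.
   *
   * Runs the job on the given partitions, collects the results on the driver,
   * and serves them to a single local client over an ephemeral server socket.
   *
   * @return the port number of a local socket which serves the data collected
   *         from this job. The socket accepts exactly one connection; if no
   *         client connects within the accept timeout, or the client stops
   *         reading before the iterator is exhausted, the serving thread logs
   *         the error and closes the socket itself, so the caller does not
   *         need to close anything explicitly.
   */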

def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
allowLocal: Boolean): Iterator[Array[Byte]] = {
allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
flattenedPartition.iterator
serveIterator(flattenedPartition.iterator)
}

/**
* A helper function to collect an RDD as an iterator, then serve it via socket
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
serveIterator(rdd.collect().iterator)
}

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +583,32 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}

def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeToFile(items.asScala, filename)
}
private def serveIterator[T](items: Iterator[T]): Int = {
val serverSocket = new ServerSocket(0, 1)
serverSocket.setReuseAddress(true)
serverSocket.setSoTimeout(3000)
Review comment (Contributor):

What's the motivation for this timeout? Might be nice to add a comment.
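
A plausible rationale, offered as an assumption since the diff itself does not state one: accept() below blocks until a client connects, so without a timeout a caller that never connects would leak the daemon thread, the server socket, and the collected results. A comment along these lines would answer the question:

    // Close the server socket if no client has connected within 3 seconds, so
    // that a caller that never connects does not leak this thread, the socket,
    // or the collected data. (Assumed rationale, not stated in the PR.)
    serverSocket.setSoTimeout(3000)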


new Thread("serve iterator") {
Review comment (Contributor):

Maybe we could add a second threadName parameter to this method and pass some more informative identifier to it, such as the RDD id or RDD + partition id.
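
A minimal sketch of the suggested signature change, assuming a new threadName parameter and keeping the thread body otherwise as written below; the parameter name and example call sites are hypothetical:

    private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
      val serverSocket = new ServerSocket(0, 1)
      serverSocket.setReuseAddress(true)
      serverSocket.setSoTimeout(3000)
      new Thread(threadName) {
        setDaemon(true)
        override def run() {
          // identical body to the anonymous "serve iterator" thread below
        }
      }.start()
      serverSocket.getLocalPort
    }

    // Callers could then pass a more informative identifier, e.g.:
    //   serveIterator(flattenedPartition.iterator, s"serve RDD ${rdd.id}")
    //   serveIterator(rdd.collect().iterator, s"serve collected RDD ${rdd.id}")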

setDaemon(true)
override def run() {
try {
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
writeIteratorToStream(items, out)
} finally {
out.close()
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator: $e")
Review comment (Contributor):

If we change this to

logError("Error while sending iterator", e)

then the full stacktrace will be printed as part of the log message.

} finally {
serverSocket.close()
}
}
}.start()

def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
writeIteratorToStream(items, file)
file.close()
serverSocket.getLocalPort
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
13 changes: 6 additions & 7 deletions python/pyspark/context.py
@@ -21,6 +21,8 @@
from threading import Lock
from tempfile import NamedTemporaryFile

from py4j.java_collections import ListConverter

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler

from py4j.java_collections import ListConverter


__all__ = ['SparkContext']

@@ -59,7 +59,6 @@ class SparkContext(object):

_gateway = None
_jvm = None
_writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
@@ -221,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

if instance:
if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
allowLocal)
return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))

def show_profiles(self):
""" Print the profile stats to stdout """
30 changes: 14 additions & 16 deletions python/pyspark/rdd.py
@@ -19,7 +19,6 @@
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@
import heapq
import bisect
import random
import socket
from math import sqrt, log, isinf, isnan, pow, ceil

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


def _load_from_socket(port, serializer):
sock = socket.socket()
try:
sock.connect(("localhost", port))
rf = sock.makefile("rb", 65536)
for item in serializer.load_stream(rf):
yield item
finally:
sock.close()


class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ def collect(self):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))

def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
return list(_load_from_socket(port, self._jrdd_deserializer))

def reduce(self, f):
"""
14 changes: 3 additions & 11 deletions python/pyspark/sql/dataframe.py
@@ -19,13 +19,11 @@
import itertools
import warnings
import random
import os
from tempfile import NamedTemporaryFile

from py4j.java_collections import ListConverter, MapConverter

from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
bytesInJava = self._jdf.javaToPython().collect().iterator()
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile.close()
self._sc._writeToFile(bytesInJava, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema)
return [cls(r) for r in rs]
