[SPARK-2024] Better type checking for batch serialized RDD
kanzhang committed Jul 29, 2014
1 parent 0bdec55 commit 75ca5bd
Showing 4 changed files with 44 additions and 18 deletions.
12 changes: 8 additions & 4 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -607,9 +607,11 @@ private[spark] object PythonRDD extends Logging {
    */
   def saveAsSequenceFile[K, V, C <: CompressionCodec](
       pyRDD: JavaRDD[Array[Byte]],
+      batchSerialized: Boolean,
       path: String,
       compressionCodecClass: String) = {
-    saveAsHadoopFile(pyRDD, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+    saveAsHadoopFile(
+      pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat",
       null, null, null, null, new java.util.HashMap(), compressionCodecClass, false)
   }

@@ -625,6 +627,7 @@ private[spark] object PythonRDD extends Logging {
   def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], G <: NewOutputFormat[_, _],
       C <: CompressionCodec](
       pyRDD: JavaRDD[Array[Byte]],
+      batchSerialized: Boolean,
       path: String,
       outputFormatClass: String,
       keyClass: String,
@@ -634,7 +637,7 @@
       confAsMap: java.util.HashMap[String, String],
       compressionCodecClass: String,
       useNewAPI: Boolean) = {
-    val rdd = SerDeUtil.pythonToPairRDD(pyRDD)
+    val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized)
     val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse(
       inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass))
     val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration)
@@ -660,13 +663,14 @@ private[spark] object PythonRDD extends Logging {
    */
   def saveAsHadoopDataset[K, V](
       pyRDD: JavaRDD[Array[Byte]],
+      batchSerialized: Boolean,
       confAsMap: java.util.HashMap[String, String],
       useNewAPI: Boolean,
       keyConverterClass: String,
       valueConverterClass: String) = {
     val conf = PythonHadoopUtil.mapToConf(confAsMap)
-    val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD), keyConverterClass,
-      valueConverterClass, new JavaToWritableConverter)
+    val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized),
+      keyConverterClass, valueConverterClass, new JavaToWritableConverter)
     if (useNewAPI) {
       converted.saveAsNewAPIHadoopDataset(conf)
     } else {
20 changes: 12 additions & 8 deletions core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -86,21 +86,25 @@ private[python] object SerDeUtil extends Logging {
   /**
    * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
    */
-  def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]]): RDD[(K, V)] = {
+  def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
     def isPair(obj: Any): Boolean = {
       Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
         obj.asInstanceOf[Array[_]].length == 2
     }
     pyRDD.mapPartitions { iter =>
       val unpickle = new Unpickler
-      iter.flatMap { row =>
-        unpickle.loads(row) match {
-          // batch serialized Python RDDs
-          case objs: java.util.List[_] => objs
-          // unbatched case
-          case obj => Seq(obj)
+      val unpickled = if (batchSerialized) {
+        iter.flatMap { batch =>
+          unpickle.loads(batch) match {
+            case objs: java.util.List[_] => collectionAsScalaIterable(objs)
+            case other => throw new SparkException(
+              s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
+          }
         }
-      }.map {
+      } else {
+        iter.map(unpickle.loads(_))
+      }
+      unpickled.map {
         // we only accept pickled (K, V)
         case obj if isPair(obj) =>
           val arr = obj.asInstanceOf[Array[_]]
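As an aside, the dispatch above can be mirrored in pure Python. The sketch below is illustrative only: the helper name python_to_pairs is hypothetical, and plain tuples stand in for the arrays that Pyrolite produces when unpickling Python tuples on the JVM.

    import pickle

    def python_to_pairs(rows, batch_serialized):
        """Unpickle rows, then accept only (K, V)-shaped elements."""
        for row in rows:
            obj = pickle.loads(row)
            if batch_serialized:
                # a batch-serialized row must unpickle to a list of elements
                if not isinstance(obj, list):
                    raise TypeError("Unexpected type %s for batch serialized Python RDD"
                                    % type(obj).__name__)
                elems = obj
            else:
                elems = [obj]
            for elem in elems:
                # we only accept pickled (K, V) pairs
                if not (isinstance(elem, tuple) and len(elem) == 2):
                    raise TypeError("RDD element is not a (K, V) pair: %r" % (elem,))
                yield elem

    # list(python_to_pairs([pickle.dumps([(1, "a"), (2, "b")])], True))
    # -> [(1, 'a'), (2, 'b')]; the same input with batch_serialized=False raises.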
22 changes: 16 additions & 6 deletions python/pyspark/rdd.py
@@ -232,7 +232,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._id = jrdd.id()
 
     def _toPickleSerialization(self):
-        if (self._jrdd_deserializer == PickleSerializer or
+        if (self._jrdd_deserializer == PickleSerializer() or
             self._jrdd_deserializer == BatchedSerializer(PickleSerializer())):
             return self
         else:
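The one-character change above fixes a real bug: the old code compared the deserializer against the PickleSerializer class object instead of an instance, so that clause could never be true. A quick illustration, assuming pyspark is importable and that PySpark serializers of this vintage define type-based equality:

    from pyspark.serializers import PickleSerializer

    print(PickleSerializer() == PickleSerializer)    # False: instance vs. the class object
    print(PickleSerializer() == PickleSerializer())  # True: equality compares serializer types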
@@ -1049,7 +1049,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         @param valueConverter: (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(self._toPickleSerialization()._jrdd, jconf,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickled._jrdd, batched, jconf,
                                                     True, keyConverter, valueConverter)

def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
@@ -1074,7 +1076,9 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
         @param conf: Hadoop job configuration, passed in as a dict (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopFile(self._toPickleSerialization()._jrdd, path,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickled._jrdd, batched, path,
             outputFormatClass, keyClass, valueClass, keyConverter, valueConverter,
             jconf, None, True)

@@ -1090,7 +1094,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         @param valueConverter: (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(self._toPickleSerialization()._jrdd, jconf,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickled._jrdd, batched, jconf,
                                                     False, keyConverter, valueConverter)

def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
@@ -1116,7 +1122,9 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
         @param compressionCodecClass: (None by default)
         """
         jconf = self.ctx._dictToJavaMap(conf)
-        self.ctx._jvm.PythonRDD.saveAsHadoopFile(self._toPickleSerialization()._jrdd,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickled._jrdd, batched,
             path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter,
             jconf, compressionCodecClass, False)
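For reference, a hypothetical call into this save path from PySpark, assuming a running SparkContext sc; the output path and Writable class names are illustrative:

    rdd = sc.parallelize([(1, "a"), (2, "bb")])
    rdd.saveAsHadoopFile(
        "/tmp/hadoop-out",
        "org.apache.hadoop.mapred.SequenceFileOutputFormat",
        keyClass="org.apache.hadoop.io.IntWritable",
        valueClass="org.apache.hadoop.io.Text")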

@@ -1131,7 +1139,9 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None):
         @param path: path to sequence file
         @param compressionCodecClass: (None by default)
         """
-        self.ctx._jvm.PythonRDD.saveAsSequenceFile(self._toPickleSerialization()._jrdd,
+        pickled = self._toPickleSerialization()
+        batched = isinstance(pickled._jrdd_deserializer, BatchedSerializer)
+        self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickled._jrdd, batched,
             path, compressionCodecClass)

def saveAsPickleFile(self, path, batchSize=10):
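With these changes, a pair RDD with the default batched serialization round-trips through the sequence-file path. A hypothetical sketch, assuming a running SparkContext sc and a writable path:

    rdd = sc.parallelize([(1, "a"), (2, "bb"), (3, "ccc")])
    rdd.saveAsSequenceFile("/tmp/seq-out")
    print(sorted(sc.sequenceFile("/tmp/seq-out").collect()))
    # expected: [(1, u'a'), (2, u'bb'), (3, u'ccc')]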
8 changes: 8 additions & 0 deletions python/pyspark/tests.py
@@ -702,6 +702,14 @@ def test_unbatched_save_and_read(self):
                                                       batchSize=1).collect())
         self.assertEqual(unbatched_newAPIHadoopRDD, ei)
 
+    def test_malformed_RDD(self):
+        basepath = self.tempdir.name
+        # non-batch-serialized RDD of type RDD[[(K, V)]] should be rejected
+        data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
+        rdd = self.sc.parallelize(data, numSlices=len(data))
+        self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile(
+            basepath + "/malformed/sequence"))
+

class TestDaemon(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
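The new test_malformed_RDD case pins down the failure mode the Scala-side check now catches: a non-batch-serialized RDD whose elements are lists of pairs, rather than pairs, is rejected instead of being written incorrectly. A hypothetical caller-side fix, again assuming a SparkContext sc and a writable path, is to flatten before saving:

    data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
    pairs = sc.parallelize(data, len(data)).flatMap(lambda kvs: kvs)
    pairs.saveAsSequenceFile("/tmp/flattened-sequence")  # elements are now (K, V) pairs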
