From fb91792e9dc310f9a38075debe0d515c72fcac77 Mon Sep 17 00:00:00 2001
From: GuoQiang Li
Date: Thu, 7 Aug 2014 21:41:16 +0800
Subject: [PATCH 1/3] org.apache.spark.broadcast.TorrentBroadcast does use the
 serializer class specified in the spark option "spark.serializer"

---
 .../spark/broadcast/TorrentBroadcast.scala    | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 86731b684f441..f69dacf32d78b 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.broadcast
 
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
+import java.io._
 
 import scala.reflect.ClassTag
 import scala.util.Random
 
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
 
 /**
  * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
@@ -228,8 +227,12 @@ private[broadcast] object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T](obj: T): TorrentInfo = {
-    val byteArray = Utils.serialize[T](obj)
+  def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+    val bos = new ByteArrayOutputStream()
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serOut = ser.serializeStream(bos)
+    serOut.writeObject[T](obj).close()
+    val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
 
     var blockNum = byteArray.length / BLOCK_SIZE
@@ -255,7 +258,7 @@ private[broadcast] object TorrentBroadcast extends Logging {
     info
   }
 
-  def unBlockifyObject[T](
+  def unBlockifyObject[T: ClassTag](
       arrayOfBlocks: Array[TorrentBlock],
       totalBytes: Int,
       totalBlocks: Int): T = {
@@ -264,7 +267,12 @@ private[broadcast] object TorrentBroadcast extends Logging {
       System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
         i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
     }
-    Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+    val in = new ByteArrayInputStream(retByteArray)
+    val ser = SparkEnv.get.serializer.newInstance()
+    val serIn = ser.deserializeStream(in)
+    val obj = serIn.readObject[T]()
+    serIn.close()
+    obj
   }
 
   /**

From ada4fbaab90169fc498b6da3beaedc0a724f5b60 Mon Sep 17 00:00:00 2001
From: GuoQiang Li
Date: Fri, 8 Aug 2014 11:40:15 +0800
Subject: [PATCH 2/3] TorrentBroadcast does not support broadcast compression

---
 .../spark/broadcast/TorrentBroadcast.scala    | 25 ++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index f69dacf32d78b..d9dfa619d2347 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -22,6 +22,7 @@ import java.io._
 import scala.reflect.ClassTag
 import scala.util.Random
 
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 
@@ -213,11 +214,15 @@ private[broadcast] object TorrentBroadcast extends Logging {
   private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
   private var initialized = false
   private var conf: SparkConf = null
+  private var compress: Boolean = false
+  private var compressionCodec: CompressionCodec = null
 
   def initialize(_isDriver: Boolean, conf: SparkConf) {
     TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
     synchronized {
       if (!initialized) {
+        compress = conf.getBoolean("spark.broadcast.compress", true)
+        compressionCodec = CompressionCodec.createCodec(conf)
         initialized = true
       }
     }
@@ -228,9 +233,16 @@ private[broadcast] object TorrentBroadcast extends Logging {
   }
 
   def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
-    val bos = new ByteArrayOutputStream()
+    val bos =new ByteArrayOutputStream()
+    val out: OutputStream = {
+      if (compress) {
+        compressionCodec.compressedOutputStream(bos)
+      } else {
+        bos
+      }
+    }
     val ser = SparkEnv.get.serializer.newInstance()
-    val serOut = ser.serializeStream(bos)
+    val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
@@ -267,7 +279,14 @@ private[broadcast] object TorrentBroadcast extends Logging {
       System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
         i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
     }
-    val in = new ByteArrayInputStream(retByteArray)
+
+    val in: InputStream = {
+      if (compress) {
+        compressionCodec.compressedInputStream(new ByteArrayInputStream(retByteArray))
+      } else {
+        new ByteArrayInputStream(retByteArray)
+      }
+    }
     val ser = SparkEnv.get.serializer.newInstance()
     val serIn = ser.deserializeStream(in)
     val obj = serIn.readObject[T]()

From 23cdc5b7d1807a01fc800b33f6bfa25db625874e Mon Sep 17 00:00:00 2001
From: GuoQiang Li
Date: Sat, 9 Aug 2014 04:39:06 +0800
Subject: [PATCH 3/3] review commit

---
 .../spark/broadcast/TorrentBroadcast.scala    | 22 ++++++------------
 .../spark/broadcast/BroadcastSuite.scala      | 10 +++++++--
 2 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index d9dfa619d2347..fe73456ef8fad 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.broadcast
 
-import java.io._
+import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream,
+  ObjectInputStream, ObjectOutputStream, OutputStream}
 
 import scala.reflect.ClassTag
 import scala.util.Random
 
-import org.apache.spark.io.CompressionCodec
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 
 /**
  * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
@@ -233,14 +234,8 @@ private[broadcast] object TorrentBroadcast extends Logging {
   }
 
   def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
-    val bos =new ByteArrayOutputStream()
-    val out: OutputStream = {
-      if (compress) {
-        compressionCodec.compressedOutputStream(bos)
-      } else {
-        bos
-      }
-    }
+    val bos = new ByteArrayOutputStream()
+    val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
@@ -281,11 +276,8 @@ private[broadcast] object TorrentBroadcast extends Logging {
     }
 
     val in: InputStream = {
-      if (compress) {
-        compressionCodec.compressedInputStream(new ByteArrayInputStream(retByteArray))
-      } else {
-        new ByteArrayInputStream(retByteArray)
-      }
+      val arrIn = new ByteArrayInputStream(retByteArray)
+      if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
     }
     val ser = SparkEnv.get.serializer.newInstance()
     val serIn = ser.deserializeStream(in)
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 7c3d0208b195a..17c64455b2429 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -44,7 +44,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   test("Accessing HttpBroadcast variables in a local cluster") {
     val numSlaves = 4
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
+    val conf = httpConf.clone
+    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    conf.set("spark.broadcast.compress", "true")
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
     val list = List[Int](1, 2, 3, 4)
     val broadcast = sc.broadcast(list)
     val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
@@ -69,7 +72,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   test("Accessing TorrentBroadcast variables in a local cluster") {
     val numSlaves = 4
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
+    val conf = torrentConf.clone
+    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    conf.set("spark.broadcast.compress", "true")
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
     val list = List[Int](1, 2, 3, 4)
     val broadcast = sc.broadcast(list)
     val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
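
Usage note: the following is an illustrative sketch, not part of the patches. It shows a minimal driver program that sets the two options these patches make TorrentBroadcast honor, "spark.serializer" and "spark.broadcast.compress". It assumes the Spark 1.x SparkConf/SparkContext API; the master URL, app name, and object name are placeholders, and "spark.broadcast.factory" is the standard 1.x key for selecting the torrent implementation.

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical example program; names and master URL are placeholders.
    object TorrentBroadcastExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[4]")
          .setAppName("torrent-broadcast-example")
          // Serializer that blockifyObject/unBlockifyObject now obtain via SparkEnv.get.serializer.
          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
          // When true (the default), broadcast data goes through CompressionCodec.createCodec(conf).
          .set("spark.broadcast.compress", "true")
          // Select the torrent-based broadcast implementation.
          .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

        val sc = new SparkContext(conf)
        val broadcast = sc.broadcast(List(1, 2, 3, 4))
        // Each task reads the broadcast value and sums it.
        val sums = sc.parallelize(1 to 4).map(_ => broadcast.value.sum).collect()
        println(sums.mkString(", "))
        sc.stop()
      }
    }

This mirrors what the updated BroadcastSuite tests in PATCH 3/3 do by setting the same two keys on cloned httpConf and torrentConf configurations.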