From a17d3250d4472619fa987d9b389d48b165c44745 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 Dec 2016 00:14:26 +0800 Subject: [PATCH 1/5] [SPARK-18742][CORE]readd spark.broadcast.factory conf to implement user-defined BroadcastFactory --- .../spark/broadcast/BroadcastManager.scala | 8 ++- .../spark/broadcast/BroadcastSuite.scala | 63 ++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2e..1751710840d34 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -20,9 +20,9 @@ package org.apache.spark.broadcast import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag - import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -39,7 +39,11 @@ private[spark] class BroadcastManager( private def initialize() { synchronized { if (!initialized) { - broadcastFactory = new TorrentBroadcastFactory + val broadcastFactoryClass = + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") + broadcastFactory = + Utils.classForName(broadcastFactoryClass).newInstance().asInstanceOf[BroadcastFactory] + broadcastFactory.initialize(isDriver, conf, securityManager) initialized = true } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 973676398ae54..cdf716dd5a68b 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,15 +17,21 @@ package org.apache.spark.broadcast -import scala.util.Random +import java.io.Serializable -import org.scalatest.Assertions +import scala.collection.immutable.HashMap +import scala.reflect.ClassTag +import scala.util.Random import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.scalatest.Assertions + // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -43,8 +49,61 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } +class UserDefineBroadcast[T: ClassTag](obj: T, id: Long) + extends Broadcast[T](id) with Logging with Serializable { + var map = HashMap[Long, Any]() + @transient private lazy val _value: T = readBroadcastBlock() + + map += id -> obj + + private def readBroadcastBlock(): T = Utils.tryOrIOException { + map.get(id) match { + case Some(v) => v.asInstanceOf[T] + case _ => + throw new SparkException(s"Failed to get $id from map") + } + } + + override protected def getValue() = { + _value + } + + override protected def doUnpersist(blocking: Boolean) = {} + + override protected def doDestroy(blocking: Boolean) = {} + +} + +class UserDefineBroadcastFactory extends BroadcastFactory { + + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } + + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { + new UserDefineBroadcast[T](value_, id) + } + + override def stop() {} + + override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {} +} + class BroadcastSuite extends SparkFunSuite with LocalSparkContext { + test("Using user-defined local BroadcastFactory") { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.UserDefineBroadcastFactory") + + sc = new SparkContext("local", "test user-define userDefineBroadcastFactory", conf) + + val list = List[String]("a", "b", "c", "d") + val broadcast = sc.broadcast(list) + + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.mkString)) + + assert(results.collect().toSet === (1 to 10) + .map(x => (x, "abcd")).toSet) + } + test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") val list = List[Int](1, 2, 3, 4) From 1b4d4b632523f33b38e6997dddbba954f324ea59 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 Dec 2016 00:24:07 +0800 Subject: [PATCH 2/5] fix style --- .../org/apache/spark/broadcast/BroadcastManager.scala | 3 ++- .../scala/org/apache/spark/broadcast/BroadcastSuite.scala | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 1751710840d34..78954e3fb83bf 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -20,8 +20,9 @@ package org.apache.spark.broadcast import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.spark.{SecurityManager, SparkConf} + import org.apache.spark.internal.Logging +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.Utils private[spark] class BroadcastManager( diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index cdf716dd5a68b..9c065dac16707 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -23,6 +23,8 @@ import scala.collection.immutable.HashMap import scala.reflect.ClassTag import scala.util.Random +import org.scalatest.Assertions + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.SnappyCompressionCodec @@ -30,7 +32,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ import org.apache.spark.util.Utils -import org.scalatest.Assertions // Dummy class that creates a broadcast variable but doesn't use it @@ -64,9 +65,7 @@ class UserDefineBroadcast[T: ClassTag](obj: T, id: Long) } } - override protected def getValue() = { - _value - } + override protected def getValue() = _value override protected def doUnpersist(blocking: Boolean) = {} From f23beccf1d484ea5121f6cf605b92090400808ad Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 Dec 2016 00:30:58 +0800 Subject: [PATCH 3/5] fix style --- .../scala/org/apache/spark/broadcast/BroadcastManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 78954e3fb83bf..f03cf2844672f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.spark.internal.Logging import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils private[spark] class BroadcastManager( From 7c9519e502a264f07bb50fed5f9d4f516ab207e7 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 15 Dec 2016 09:38:48 +0800 Subject: [PATCH 4/5] just modify the comment of BroadcastFactory --- .../spark/broadcast/BroadcastFactory.scala | 5 +- .../spark/broadcast/BroadcastManager.scala | 7 +-- .../spark/broadcast/BroadcastSuite.scala | 57 ------------------- 3 files changed, 3 insertions(+), 66 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index fd7b4fc88b697..ece4ae6ab0310 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -24,9 +24,8 @@ import org.apache.spark.SparkConf /** * An interface for all the broadcast implementations in Spark (to allow - * multiple broadcast implementations). SparkContext uses a user-specified - * BroadcastFactory implementation to instantiate a particular broadcast for the - * entire Spark job. + * multiple broadcast implementations). SparkContext uses a BroadcastFactory + * implementation to instantiate a particular broadcast for the entire Spark job. */ private[spark] trait BroadcastFactory { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index f03cf2844672f..e88988fe03b2e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -23,7 +23,6 @@ import scala.reflect.ClassTag import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -40,11 +39,7 @@ private[spark] class BroadcastManager( private def initialize() { synchronized { if (!initialized) { - val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - broadcastFactory = - Utils.classForName(broadcastFactoryClass).newInstance().asInstanceOf[BroadcastFactory] - + broadcastFactory = new TorrentBroadcastFactory broadcastFactory.initialize(isDriver, conf, securityManager) initialized = true } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 9c065dac16707..6ef0e5a31ea21 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.broadcast -import java.io.Serializable - -import scala.collection.immutable.HashMap -import scala.reflect.ClassTag import scala.util.Random import org.scalatest.Assertions import org.apache.spark._ -import org.apache.spark.internal.Logging import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ -import org.apache.spark.util.Utils - // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -50,59 +43,9 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class UserDefineBroadcast[T: ClassTag](obj: T, id: Long) - extends Broadcast[T](id) with Logging with Serializable { - var map = HashMap[Long, Any]() - @transient private lazy val _value: T = readBroadcastBlock() - - map += id -> obj - - private def readBroadcastBlock(): T = Utils.tryOrIOException { - map.get(id) match { - case Some(v) => v.asInstanceOf[T] - case _ => - throw new SparkException(s"Failed to get $id from map") - } - } - - override protected def getValue() = _value - - override protected def doUnpersist(blocking: Boolean) = {} - - override protected def doDestroy(blocking: Boolean) = {} - -} - -class UserDefineBroadcastFactory extends BroadcastFactory { - - override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { - new UserDefineBroadcast[T](value_, id) - } - - override def stop() {} - - override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {} -} class BroadcastSuite extends SparkFunSuite with LocalSparkContext { - test("Using user-defined local BroadcastFactory") { - val conf = new SparkConf - conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.UserDefineBroadcastFactory") - - sc = new SparkContext("local", "test user-define userDefineBroadcastFactory", conf) - - val list = List[String]("a", "b", "c", "d") - val broadcast = sc.broadcast(list) - - val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.mkString)) - - assert(results.collect().toSet === (1 to 10) - .map(x => (x, "abcd")).toSet) - } - test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") val list = List[Int](1, 2, 3, 4) From 49d799fbbbf49b94af890820a050e395edf24c10 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 15 Dec 2016 09:39:52 +0800 Subject: [PATCH 5/5] fix a empty line --- .../test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 6ef0e5a31ea21..973676398ae54 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -43,7 +43,6 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } - class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Using TorrentBroadcast locally") {