Skip to content

Commit

Permalink
Refactor broadcast classes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewor14 committed Mar 26, 2014
1 parent c7ccef1 commit ba52e00
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 268 deletions.
7 changes: 1 addition & 6 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -641,13 +641,8 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*
* If `registerBlocks` is true, workers will notify driver about blocks they create
* and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
*/
def broadcast[T](value: T, registerBlocks: Boolean = false) = {
env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
}
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)

/**
* Add a file to be downloaded with this Spark job on every node.
Expand Down
51 changes: 0 additions & 51 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
package org.apache.spark.broadcast

import java.io.Serializable
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark._

/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
Expand Down Expand Up @@ -53,56 +50,8 @@ import org.apache.spark._
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

/**
* Removes all blocks of this broadcast from memory (and disk if removeSource is true).
*
* @param removeSource Whether to remove data from disk as well.
* Will cause errors if broadcast is accessed on workers afterwards
* (e.g. in case of RDD re-computation due to executor failure).
*/
def unpersist(removeSource: Boolean = false)

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

override def toString = "Broadcast(" + id + ")"
}

private[spark]
class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
extends Logging with Serializable {

private var initialized = false
private var broadcastFactory: BroadcastFactory = null

initialize()

// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = conf.get(
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")

broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)

initialized = true
}
}
}

def stop() {
broadcastFactory.stop()
}

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

def isDriver = _isDriver
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ import org.apache.spark.SparkConf
*/
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.broadcast

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark._

private[spark] class BroadcastManager(
val isDriver: Boolean,
conf: SparkConf,
securityManager: SecurityManager)
extends Logging with Serializable {

private var initialized = false
private var broadcastFactory: BroadcastFactory = null

initialize()

// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass =
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")

broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)

initialized = true
}
}
}

def stop() {
broadcastFactory.stop()
}

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}

}
59 changes: 10 additions & 49 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,11 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(blockId)
SparkEnv.get.blockManager.removeBlock(blockId)
}

if (removeSource) {
HttpBroadcast.synchronized {
HttpBroadcast.cleanupById(id)
}
}
}

def blockId = BroadcastBlockId(id)

HttpBroadcast.synchronized {
Expand All @@ -67,7 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
Expand All @@ -76,20 +63,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
}

/**
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { HttpBroadcast.stop() }
}

private object HttpBroadcast extends Logging {
private var initialized = false

Expand Down Expand Up @@ -149,10 +122,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}

def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)

def write(id: Long, value: Any) {
val file = getFile(id)
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
Expand Down Expand Up @@ -198,30 +169,20 @@ private object HttpBroadcast extends Logging {
obj
}

def deleteFile(fileName: String) {
try {
new File(fileName).delete()
logInfo("Deleted broadcast file '" + fileName + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
}
}

def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
iterator.remove()
deleteFile(file)
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
}
}
}

def cleanupById(id: Long) {
val file = getFile(id).getAbsolutePath
files.internalMap.remove(file)
deleteFile(file)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.broadcast

import org.apache.spark.{SecurityManager, SparkConf}

/**
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)

def stop() { HttpBroadcast.stop() }
}
Loading

0 comments on commit ba52e00

Please sign in to comment.