[SPARK-2521] Broadcast RDD object (instead of sending it along with every task). #1452

Status: Closed. Wants to merge 5 commits; changes shown from 4 commits.
28 changes: 18 additions & 10 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
* Base class for dependencies.
*/
@DeveloperApi
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class Dependency[T] extends Serializable {
def rdd: RDD[T]
}


/**
@@ -36,41 +38,47 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
@DeveloperApi
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]

override def rdd: RDD[T] = _rdd
}


/**
* :: DeveloperApi ::
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
* the RDD is transient since we don't need it on the executor side.
*
* @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
@transient _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
extends Dependency[Product2[K, V]] {

override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]

val shuffleId: Int = rdd.context.newShuffleId()
val shuffleId: Int = _rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, _rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}


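To make the Dependency change above concrete: the parent RDD moves from a constructor `val` on the base class to an abstract `def rdd` accessor, which lets `ShuffleDependency` hold its parent in a `@transient` field and avoid shipping the RDD (and its whole lineage) to executors. The following is a minimal, self-contained sketch of that pattern; `FakeRDD`, `Dep`, and `ShuffleDep` are stand-ins, not the real Spark classes.

```scala
import java.io._

// Stand-in for an RDD; all that matters here is that it is Serializable.
class FakeRDD(val id: Int) extends Serializable

// Mirrors the new base class: the parent is exposed through an abstract accessor
// rather than a constructor val, so subclasses decide how the reference is stored.
abstract class Dep extends Serializable {
  def rdd: FakeRDD
}

// Mirrors ShuffleDependency: the reference lives in a @transient field, so
// serializing the dependency does not drag the parent RDD along with it.
class ShuffleDep(@transient private val _rdd: FakeRDD) extends Dep {
  override def rdd: FakeRDD = _rdd
}

object TransientDepDemo {
  def main(args: Array[String]): Unit = {
    val dep = new ShuffleDep(new FakeRDD(42))

    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(dep)
    oos.close()

    val in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    val back = in.readObject().asInstanceOf[ShuffleDep]

    // The transient field comes back as null: this is what the executor side sees,
    // which is fine because the executor never needs dep.rdd for a shuffle.
    println(back.rdd == null) // prints: true
  }
}
```

By the same reasoning, `NarrowDependency` keeps a regular (non-transient) reference, since narrow dependencies are needed on executors for pipelined execution.
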
2 changes: 0 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
26 changes: 19 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

@@ -1195,21 +1196,32 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed: Boolean = {
checkpointData.map(_.isCheckpointed).getOrElse(false)
}
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)

/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile: Option[String] = {
checkpointData.flatMap(_.getCheckpointFile)
}
def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)

// =======================================================================
// Other internal methods and fields
// =======================================================================

/**
* Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
* the serialized copy of the RDD and for each task we will deserialize it, which means each
* task gets a different copy of the RDD. This provides stronger isolation between tasks that
* might modify state of objects referenced in their closures. This is necessary in Hadoop
* where the JobConf/Configuration object is not thread-safe.
*/
@transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
// TODO: Warn users about very large RDDs.
[Review comment from Contributor] It would be nice to add this in this patch, we can just choose a threshold

val ser = SparkEnv.get.closureSerializer.newInstance()
val bytes = ser.serialize(this).array()
logDebug(s"Broadcasting RDD $id using ${bytes.length} bytes")
sc.broadcast(bytes)
}

private var storageLevel: StorageLevel = StorageLevel.NONE

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
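The `broadcasted` field above is the heart of this patch: serialize the RDD once on the driver, broadcast the bytes, and have each task deserialize its own private copy. Here is a minimal sketch of that pattern against the public API, assuming an existing `SparkContext` named `sc`; `MutableConf` is a hypothetical stand-in for the non-thread-safe Hadoop `JobConf` mentioned in the comment, and plain Java serialization stands in for Spark's closure serializer.

```scala
import java.io._

// Hypothetical stand-in for a mutable, non-thread-safe configuration object.
class MutableConf(var setting: String) extends Serializable

object SerUtil {
  def serialize(obj: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(obj)
    oos.close()
    bos.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T =
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject().asInstanceOf[T]
}

// Driver side: serialize once, broadcast the bytes once.
val confBytes = sc.broadcast(SerUtil.serialize(new MutableConf("initial")))

// Executor side: each task deserializes its own copy from the broadcast, so a task
// that mutates the object cannot affect the copy seen by other tasks in the same JVM.
val results = sc.parallelize(1 to 4, 4).map { i =>
  val myCopy = SerUtil.deserialize[MutableConf](confBytes.value)
  myCopy.setting = s"task-$i"   // safe: this copy is task-local
  myCopy.setting
}.collect()
// results: Array("task-1", "task-2", "task-3", "task-4")
```
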
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

private[spark] object RDDCheckpointData {
def clearTaskCaches() {
ShuffleMapTask.clearCache()
ResultTask.clearCache()
}
}
// Used for synchronization
private[spark] object RDDCheckpointData
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -376,9 +376,6 @@ class DAGScheduler(
stageIdToStage -= stageId
stageIdToJobIds -= stageId

ShuffleMapTask.removeStage(stageId)
ResultTask.removeStage(stageId)

logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -723,7 +720,6 @@
}
}


/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
128 changes: 31 additions & 97 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,134 +17,68 @@

package org.apache.spark.scheduler

import scala.language.existentials
import java.nio.ByteBuffer

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.rdd.{RDD, RDDCheckpointData}

private[spark] object ResultTask {

// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
private val serializedInfoCache = new HashMap[Int, Array[Byte]]

def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
{
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
old
} else {
val out = new ByteArrayOutputStream
val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
bytes
}
}
}

def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
(rdd, func)
}

def removeStage(stageId: Int) {
serializedInfoCache.remove(stageId)
}

def clearCache() {
synchronized {
serializedInfoCache.clear()
}
}
}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

/**
* A task that sends back the output to the driver application.
*
* See [[org.apache.spark.scheduler.Task]] for more information.
* See [[Task]] for more information.
*
* @param stageId id of the stage this task belongs to
* @param rdd input to func
* @param rddBinary broadcast version of of the serialized RDD
[Review comment from Contributor] *of - ha!
[Review comment from Contributor] also past tense -- broadcasted

* @param func a function to apply on a partition of the RDD
* @param _partitionId index of the number in the RDD
* @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
*/
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
_partitionId: Int,
val rddBinary: Broadcast[Array[Byte]],
val func: (TaskContext, Iterator[T]) => U,
val partition: Partition,
@transient locs: Seq[TaskLocation],
var outputId: Int)
extends Task[U](stageId, _partitionId) with Externalizable {

def this() = this(0, null, null, 0, null, 0)

var split = if (rdd == null) null else rdd.partitions(partitionId)
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {

[Review comment from Contributor] Is partitionId the same thing as partition.index?
[Reply from Contributor Author] @mateiz and I looked and it seems so.


// TODO: Should we also broadcast func? For that we would need a place to
// keep a reference to it (perhaps in DAGScheduler's job object).

[Review comment from Contributor] Perhaps we can just turn this into a JIRA rather than keeping it here in the code.


def this(
stageId: Int,
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitionId: Int,
locs: Seq[TaskLocation],
outputId: Int) = {
this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}

override def runTask(context: TaskContext): U = {
// Deserialize the RDD using the broadcast variable.
val ser = SparkEnv.get.closureSerializer.newInstance()
val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
Thread.currentThread.getContextClassLoader)
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
func(context, rdd.iterator(partition, context))
} finally {
context.executeOnCompleteCallbacks()
}
}

// This is only callable on the driver side.
override def preferredLocations: Seq[TaskLocation] = preferredLocs

override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"

override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
}
}

override def readExternal(in: ObjectInput) {
val stageId = in.readInt()
val numBytes = in.readInt()
val bytes = new Array[Byte](numBytes)
in.readFully(bytes)
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
}
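One detail of the new `runTask` above worth spelling out: the RDD is rebuilt from `rddBinary.value` with the task's context class loader, so classes shipped with the application can be resolved during deserialization. The helper below is a rough analogue of that step using plain Java serialization instead of Spark's closure serializer (which additionally special-cases primitive types); `deserializeWithLoader` and the commented usage are illustrative, not the actual implementation.

```scala
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

// Deserialize an object from bytes, resolving class names against the supplied
// class loader rather than the default one. This mirrors what the executor does
// once per task with the broadcast RDD bytes, giving each task a fresh copy.
def deserializeWithLoader[T](bytes: Array[Byte], loader: ClassLoader): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
    override def resolveClass(desc: ObjectStreamClass): Class[_] =
      Class.forName(desc.getName, false, loader)
  }
  try in.readObject().asInstanceOf[T]
  finally in.close()
}

// Hypothetical usage inside a task body (rddBinary is the Broadcast[Array[Byte]]):
// val rdd = deserializeWithLoader[RDD[T]](
//   rddBinary.value, Thread.currentThread.getContextClassLoader)
```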