Skip to content

Commit

Permalink
make if allowing duplicate update as an option of accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Nan Zhu authored and Nan Zhu committed Sep 25, 2014
1 parent 74fb2ec commit af3ba6c
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 57 deletions.
34 changes: 20 additions & 14 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark

import java.io.{ObjectInputStream, Serializable}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.generic.Growable
import scala.collection.mutable.Map
Expand All @@ -43,11 +44,13 @@ import org.apache.spark.serializer.JavaSerializer
class Accumulable[R, T] (
@transient initialValue: R,
param: AccumulableParam[R, T],
val name: Option[String])
val name: Option[String],
val allowDuplicate: Boolean = true)
extends Serializable {

def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
this(initialValue, param, None)
def this(@transient initialValue: R, param: AccumulableParam[R, T],
allowDuplicate: Boolean = true) =
this(initialValue, param, None, allowDuplicate)

val id: Long = Accumulators.newId

Expand Down Expand Up @@ -225,9 +228,13 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
extends Accumulable[T,T](initialValue, param, name) {
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T],
name: Option[String], allowDuplicate: Boolean = true)
extends Accumulable[T,T](initialValue, param, name, allowDuplicate) {
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)

def this(initialValue: T, param: AccumulatorParam[T], allowDuplicate: Boolean) =
this(initialValue, param, None, allowDuplicate)
}

/**
Expand All @@ -251,10 +258,9 @@ private object Accumulators {
val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
var lastId: Long = 0

def newId: Long = synchronized {
lastId += 1
lastId
}
private val nextAccumID = new AtomicLong(0)

def newId: Long = nextAccumID.getAndIncrement

def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
Expand All @@ -265,6 +271,8 @@ private object Accumulators {
}
}

/**
 * Tells whether the original accumulator registered under `id` accepts duplicate updates.
 *
 * An id with no registered original reports `true`, so that callers fall through to `add`,
 * which ignores ids it does not know, instead of failing with a NoSuchElementException.
 * The lookup holds the same lock as `register` and `add`, which mutate and read `originals`.
 *
 * @param id the accumulator id to look up
 * @return the `allowDuplicate` flag of the original, or `true` when the id is not registered
 */
def isAllowDuplicate(id: Long): Boolean = synchronized {
  originals.get(id).forall(_.allowDuplicate)
}

// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
Expand All @@ -282,11 +290,9 @@ private object Accumulators {
}

// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (originals.contains(id)) {
originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
}
def add(value: (Long, Any)): Unit = synchronized {
if (originals.contains(value._1)) {
originals(value._1).asInstanceOf[Accumulable[Any, Any]] ++= value._2
}
}

Expand Down
19 changes: 11 additions & 8 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,9 @@ class SparkContext(config: SparkConf) extends Logging {
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
* values to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
def accumulator[T](initialValue: T, allowDuplicate: Boolean = true)
(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param, None, allowDuplicate)

/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
Expand All @@ -762,8 +763,9 @@ class SparkContext(config: SparkConf) extends Logging {
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param)
def accumulable[T, R](initialValue: T, allowDuplicate: Boolean = true)
(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param, allowDuplicate)

/**
* Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the
Expand All @@ -772,8 +774,9 @@ class SparkContext(config: SparkConf) extends Logging {
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param, Some(name))
def accumulable[T, R](initialValue: T, name: String, allowDuplicate: Boolean = true)
(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param, Some(name), allowDuplicate)

/**
* Create an accumulator from a "mutable collection" type.
Expand All @@ -782,9 +785,9 @@ class SparkContext(config: SparkConf) extends Logging {
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
(initialValue: R, allowDuplicate: Boolean = true): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
new Accumulable(initialValue, param, allowDuplicate)
}

/**
Expand Down
99 changes: 70 additions & 29 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack, ListBuffer}
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
Expand Down Expand Up @@ -112,6 +112,10 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]

// stageId => (partitionId -> buffered (accumulatorId, accumulatorValue) updates)
private[scheduler] val stageIdToAccumulators = new HashMap[Int,
HashMap[Int, ListBuffer[(Long, Any)]]]

private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))

Expand Down Expand Up @@ -406,6 +410,66 @@ class DAGScheduler(
updateJobIdStageIdMapsList(List(stage))
}

/**
 * Forgets everything the scheduler tracks for `stageId` and applies the accumulator
 * updates that were buffered for it in `stageIdToAccumulators`.
 *
 * @param stageId id of the stage to remove
 */
def removeStage(stageId: Int) {
  // Bookkeeping keyed by the Stage object; skipped when the id is unknown.
  stageIdToStage.get(stageId).foreach { stage =>
    if (runningStages.contains(stage)) {
      logDebug("Removing running stage %d".format(stageId))
      runningStages -= stage
    }
    shuffleToMapStage.find(_._2 == stage).foreach { case (shuffleId, _) =>
      shuffleToMapStage.remove(shuffleId)
    }
    if (waitingStages.contains(stage)) {
      logDebug("Removing stage %d from waiting set.".format(stageId))
      waitingStages -= stage
    }
    if (failedStages.contains(stage)) {
      logDebug("Removing stage %d from failed set.".format(stageId))
      failedStages -= stage
    }
  }

  // Bookkeeping keyed by the stage id.
  stageIdToStage -= stageId

  // Apply whatever updates are still buffered for this stage. An aborted stage has
  // nothing left here because its entry was dropped when it was aborted.
  stageIdToAccumulators.get(stageId).foreach { updatesByPartition =>
    updatesByPartition.values.foreach { pendingUpdates =>
      pendingUpdates.foreach(update => Accumulators.add(update))
    }
  }
  stageIdToAccumulators -= stageId

  val remainingStages = stageIdToStage.size
  logDebug("After removal of stage %d, remaining stages = %d"
    .format(stageId, remainingStages))
}

/**
 * Applies or buffers the accumulator updates reported by a completed task.
 *
 * Updates for accumulators that allow duplicates are added to the originals right away.
 * Updates for the others are buffered in `stageIdToAccumulators` under the task's stage
 * and partition, but only when nothing has been buffered for that partition yet, so a
 * repeated run of the same partition is ignored.
 *
 * Whether the partition was already recorded is decided once per task, before any update
 * is buffered. Checking it per accumulator would make the first buffered update hide all
 * remaining non-duplicate accumulators of the same task.
 *
 * @param accumValue the accumulator values received from the task; `null` is ignored
 * @param stage the stage which the task belongs to
 * @param task the completed task
 */
private def saveAccumulatorValue(accumValue: Map[Long, Any], stage: Stage, task: Task[_]) {
  if (accumValue != null) {
    val partitionAlreadyRecorded =
      stageIdToAccumulators.get(stage.id).exists(_.contains(task.partitionId))
    for ((id, value) <- accumValue) {
      if (Accumulators.isAllowDuplicate(id)) {
        Accumulators.add((id, value))
      } else if (!partitionAlreadyRecorded) {
        val buffered = stageIdToAccumulators
          .getOrElseUpdate(stage.id, new HashMap[Int, ListBuffer[(Long, Any)]])
          .getOrElseUpdate(task.partitionId, new ListBuffer[(Long, Any)])
        buffered += id -> value
      }
    }
  }
}

/**
* Removes state for job and any stages that are not needed by any other job. Does not
* handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
Expand All @@ -425,32 +489,6 @@ class DAGScheduler(
"Job %d not registered for stage %d even though that stage was registered for the job"
.format(job.jobId, stageId))
} else {
// Removes the given stage from every scheduler data structure that references it.
def removeStage(stageId: Int) {
// data structures based on Stage
// (nothing happens in this loop when stageId is not present in stageIdToStage)
for (stage <- stageIdToStage.get(stageId)) {
if (runningStages.contains(stage)) {
logDebug("Removing running stage %d".format(stageId))
runningStages -= stage
}
// drop the first shuffle entry whose map stage is this stage, if there is one
for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
shuffleToMapStage.remove(k)
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
waitingStages -= stage
}
if (failedStages.contains(stage)) {
logDebug("Removing stage %d from failed set.".format(stageId))
failedStages -= stage
}
}
// data structures based on StageId
stageIdToStage -= stageId

logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}

jobSet -= job.jobId
if (jobSet.isEmpty) { // no other job needs this stage
removeStage(stageId)
Expand Down Expand Up @@ -939,7 +977,7 @@ class DAGScheduler(
case Success =>
if (event.accumUpdates != null) {
try {
Accumulators.add(event.accumUpdates)
saveAccumulatorValue(event.accumUpdates, stage, task)
event.accumUpdates.foreach { case (id, partialValue) =>
val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
// To avoid UI cruft, ignore cases where value wasn't updated
Expand Down Expand Up @@ -1076,7 +1114,7 @@ class DAGScheduler(
}
failedStages += failedStage
failedStages += mapStage

stageIdToAccumulators -= failedStage.id
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
Expand Down Expand Up @@ -1211,6 +1249,9 @@ class DAGScheduler(
} else {
// This is the only job that uses this stage, so fail the stage if it is running.
val stage = stageIdToStage(stageId)
// Drop the updates buffered in stageIdToAccumulators for the aborted stage so that
// they are not applied to the accumulators when the stage is later removed
stageIdToAccumulators -= stage.id
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null))
}
}
}

/**
 * Sends one CompletionEvent per entry of `results` for the task at the same index of
 * `taskSet`, each carrying an accumulator update of 1 for `accumId`.
 */
private def completeWithAccumulator(accumId: Long, taskSet: TaskSet,
    results: Seq[(TaskEndReason, Any)]) {
  assert(taskSet.tasks.size >= results.size)
  results.zipWithIndex.foreach { case (outcome, idx) =>
    if (idx < taskSet.tasks.size) {
      val accumUpdates = Map[Long, Any](accumId -> 1)
      val event = CompletionEvent(
        taskSet.tasks(idx), outcome._1, outcome._2, accumUpdates, null, null)
      runEvent(event)
    }
  }
}
Expand Down Expand Up @@ -493,17 +504,16 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
runEvent(ExecutorLost("exec-hostA"))
val newEpoch = mapOutputTracker.getEpoch
assert(newEpoch > oldEpoch)
val noAccum = Map[Long, Any]()
val taskSet = taskSets(0)
// should be ignored for being too old
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null))
// should work because it's a non-failed host
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null))
// should be ignored for being too old
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null))
// should work because it's a new epoch
taskSet.tasks(1).epoch = newEpoch
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
Expand Down Expand Up @@ -728,6 +738,72 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assert(scheduler.sc.dagScheduler === null)
}

// The accumulator is created with the default allowDuplicate = true, so every successful
// task completion is applied, including those of the resubmitted stages.
// Expected total: 2 + 2 (first two stages) + 0 (fetch failure) + 1 + 1 + 1 (resubmission) = 7.
test("accumulator allowing duplication can be calculated correctly") {
val accum = new Accumulator[Int](0, SparkContext.IntAccumulatorParam)
val shuffleOneRdd = new MyRDD(sc, 2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne))
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
submit(finalRdd, Array(0))
// have the first stage complete normally
completeWithAccumulator(accum.id, taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// have the second stage complete normally
completeWithAccumulator(accum.id, taskSets(1), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostC", 1))))
// fail the third stage because hostA went down
completeWithAccumulator(accum.id, taskSets(2), Seq(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
scheduler.resubmitFailedStages()
// each resubmitted task set completes one task, and each completion counts again
completeWithAccumulator(accum.id, taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
completeWithAccumulator(accum.id, taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
completeWithAccumulator(accum.id, taskSets(5), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assert(Accumulators.originals(accum.id).value === 7)
assertDataStructuresEmpty
}

// Same scenario as the duplicate-allowing test, but with allowDuplicate = false.
// Expected total: 2 + 2 (first two stages) + 1 (final stage) = 5. The re-run map tasks
// presumably cover partitions that were already recorded, so their updates are discarded.
test("accumulator not allowing duplication is not calculated for resubmitted stage") {
// created with allowDuplicate = false; constructing it is what registers it
val accum = new Accumulator[Int](0, SparkContext.IntAccumulatorParam, false)
val shuffleOneRdd = new MyRDD(sc, 2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne))
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
submit(finalRdd, Array(0))
// have the first stage complete normally
completeWithAccumulator(accum.id, taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// have the second stage complete normally
completeWithAccumulator(accum.id, taskSets(1), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostC", 1))))
// fail the third stage because hostA went down
completeWithAccumulator(accum.id, taskSets(2), Seq(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
scheduler.resubmitFailedStages()
// resubmitted task sets: only the final stage is expected to add to the total
completeWithAccumulator(accum.id, taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
completeWithAccumulator(accum.id, taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
completeWithAccumulator(accum.id, taskSets(5), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assert(Accumulators.originals(accum.id).value === 5)
assertDataStructuresEmpty
}

// Fails the only task set of a job and checks that the scheduler keeps no state for it,
// including buffered accumulator updates (see assertDataStructuresEmpty).
test("accumulator is cleared for aborted stages") {
  // created only so that an accumulator is registered while the stage is aborted
  new Accumulator[Int](0, SparkContext.IntAccumulatorParam)
  val rdd = new MyRDD(sc, 2, Nil)
  submit(rdd, Array(0))
  failed(taskSets(0), "taskset failed")
  assertDataStructuresEmpty
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
Expand All @@ -754,6 +830,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assert(scheduler.runningStages.isEmpty)
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
assert(scheduler.stageIdToAccumulators.isEmpty)
}
}

0 comments on commit af3ba6c

Please sign in to comment.