Clean up DAGScheduler getMissingParentStages / stageDependsOn methods.
JoshRosen committed Nov 29, 2014
1 parent 48223d8 commit 1ab3d6d
Showing 2 changed files with 17 additions and 73 deletions.
79 changes: 6 additions & 73 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -360,37 +360,6 @@ class DAGScheduler(
     parents
   }
 
-  private def getMissingParentStages(stage: Stage): List[Stage] = {
-    val missing = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    // We are manually maintaining a stack here to prevent StackOverflowError
-    // caused by recursively visiting
-    val waitingForVisit = new Stack[RDD[_]]
-    def visit(rdd: RDD[_]) {
-      if (!visited(rdd)) {
-        visited += rdd
-        if (getCacheLocs(rdd).contains(Nil)) {
-          for (dep <- rdd.dependencies) {
-            dep match {
-              case shufDep: ShuffleDependency[_, _, _] =>
-                val mapStage = getShuffleMapStage(shufDep, stage.jobId)
-                if (!mapStage.isAvailable) {
-                  missing += mapStage
-                }
-              case narrowDep: NarrowDependency[_] =>
-                waitingForVisit.push(narrowDep.rdd)
-            }
-          }
-        }
-      }
-    }
-    waitingForVisit.push(stage.rdd)
-    while (!waitingForVisit.isEmpty) {
-      visit(waitingForVisit.pop())
-    }
-    missing.toList
-  }
-
   /**
    * Registers the given jobId among the jobs that need the given stage and
    * all of that stage's ancestors.
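
Worth noting about the removed method: it walks the RDD dependency graph with an explicit stack rather than recursion, so a deeply nested lineage cannot trigger a StackOverflowError. A minimal, self-contained sketch of that traversal pattern, with a hypothetical Node type standing in for RDD and its dependencies:

import scala.collection.mutable

// Hypothetical stand-in for an RDD and its dependencies.
case class Node(id: Int, children: Seq[Node] = Nil)

// Depth-first traversal with an explicit stack: the same shape as the removed
// getMissingParentStages, minus the shuffle/narrow dependency handling.
def reachable(root: Node): Set[Node] = {
  val visited = mutable.HashSet[Node]()
  val waitingForVisit = mutable.Stack[Node](root)
  while (waitingForVisit.nonEmpty) {
    val node = waitingForVisit.pop()
    if (!visited(node)) {
      visited += node
      node.children.foreach(waitingForVisit.push)
    }
  }
  visited.toSet
}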
@@ -401,7 +370,7 @@ class DAGScheduler(
         val s = stages.head
         s.jobIds += jobId
         jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
-        val parents: List[Stage] = getParentStages(s.rdd, jobId)
+        val parents: List[Stage] = s.parents
         val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
         updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
       }
@@ -745,7 +714,7 @@ class DAGScheduler(
         job.jobId, callSite.shortForm, partitions.length, allowLocal))
       logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
       logInfo("Parents of final stage: " + finalStage.parents)
-      logInfo("Missing parents: " + getMissingParentStages(finalStage))
+      logInfo("Missing parents: " + finalStage.missingParents)
       val shouldRunLocally =
         localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
       if (shouldRunLocally) {
@@ -771,7 +740,7 @@ class DAGScheduler(
     if (jobId.isDefined) {
       logDebug("submitStage(" + stage + ")")
       if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
-        val missing = getMissingParentStages(stage).sortBy(_.id)
+        val missing = stage.missingParents.sortBy(_.id)
         logDebug("missing: " + missing)
         if (missing == Nil) {
           logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
@@ -1040,9 +1009,9 @@ class DAGScheduler(
         } else {
           val newlyRunnable = new ArrayBuffer[Stage]
           for (stage <- waitingStages) {
-            logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+            logInfo(s"Missing parents for $stage: ${stage.missingParents}")
           }
-          for (stage <- waitingStages if getMissingParentStages(stage) == Nil) {
+          for (stage <- waitingStages if stage.missingParents == Nil) {
             newlyRunnable += stage
           }
           waitingStages --= newlyRunnable
@@ -1197,7 +1166,7 @@ class DAGScheduler(
       return
     }
     val dependentJobs: Seq[ActiveJob] =
-      activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
+      activeJobs.filter(job => job.finalStage.dependsOn(failedStage)).toSeq
     failedStage.latestInfo.completionTime = Some(clock.getTime())
     for (job <- dependentJobs) {
       failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
@@ -1257,42 +1226,6 @@ class DAGScheduler(
     }
   }
 
-  /**
-   * Return true if one of stage's ancestors is target.
-   */
-  private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
-    if (stage == target) {
-      return true
-    }
-    val visitedRdds = new HashSet[RDD[_]]
-    val visitedStages = new HashSet[Stage]
-    // We are manually maintaining a stack here to prevent StackOverflowError
-    // caused by recursively visiting
-    val waitingForVisit = new Stack[RDD[_]]
-    def visit(rdd: RDD[_]) {
-      if (!visitedRdds(rdd)) {
-        visitedRdds += rdd
-        for (dep <- rdd.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_, _, _] =>
-              val mapStage = getShuffleMapStage(shufDep, stage.jobId)
-              if (!mapStage.isAvailable) {
-                visitedStages += mapStage
-                waitingForVisit.push(mapStage.rdd)
-              }  // Otherwise there's no need to follow the dependency back
-            case narrowDep: NarrowDependency[_] =>
-              waitingForVisit.push(narrowDep.rdd)
-          }
-        }
-      }
-    }
-    waitingForVisit.push(stage.rdd)
-    while (!waitingForVisit.isEmpty) {
-      visit(waitingForVisit.pop())
-    }
-    visitedRdds.contains(target.rdd)
-  }
-
   /**
    * Synchronized method that might be called from other threads.
    * @param rdd whose partitions are to be looked at
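
One semantic difference to note: the removed stageDependsOn walked the RDD graph transitively, while its replacement dependsOn (added to Stage below) only inspects direct parents. If transitive semantics were wanted, a sketch over the stage DAG itself might look like this; SimpleStage is a hypothetical stand-in for Spark's Stage, and this helper is not part of the commit:

import scala.collection.mutable

// Hypothetical stand-in for Spark's Stage, carrying only what the check needs.
class SimpleStage(val id: Int, val parents: List[SimpleStage])

// Transitive ancestor check over the stage DAG, again using an explicit
// stack to stay safe on deep graphs.
def dependsOnTransitively(stage: SimpleStage, target: SimpleStage): Boolean = {
  val visited = mutable.HashSet[SimpleStage]()
  val waitingForVisit = mutable.Stack[SimpleStage](stage)
  var found = false
  while (waitingForVisit.nonEmpty && !found) {
    val s = waitingForVisit.pop()
    if (s eq target) {
      found = true
    } else if (!visited(s)) {
      visited += s
      s.parents.foreach(waitingForVisit.push)
    }
  }
  found
}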
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -85,6 +85,17 @@ private[spark] class Stage(
     }
   }
 
+  def missingParents: List[Stage] = {
+    parents.filterNot(_.isAvailable)
+  }
+
+  /**
+   * Returns true if one of this stage's ancestors is `otherStage`.
+   */
+  def dependsOn(otherStage: Stage): Boolean = {
+    parents.exists(_.rdd == otherStage.rdd)
+  }
+
   def addOutputLoc(partition: Int, status: MapStatus) {
     val prevList = outputLocs(partition)
     outputLocs(partition) = status :: prevList
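
A self-contained sketch of how the two new helpers compose at the scheduler call sites rewritten above. DemoStage and its isAvailable flag are hypothetical stand-ins for Spark's Stage internals, not the real API:

// Minimal model: a stage is "available" once its shuffle outputs exist.
final case class DemoStage(id: Int, parents: List[DemoStage] = Nil,
                           isAvailable: Boolean = false) {
  // Mirrors the new Stage.missingParents: direct parents still lacking output.
  def missingParents: List[DemoStage] = parents.filterNot(_.isAvailable)
  // Mirrors the new Stage.dependsOn (the commit compares the parents' RDDs;
  // here we compare the stages themselves).
  def dependsOn(other: DemoStage): Boolean = parents.contains(other)
}

object StageDemo extends App {
  val mapA = DemoStage(0, isAvailable = true)
  val mapB = DemoStage(1)
  val finalStage = DemoStage(2, parents = List(mapA, mapB))

  // Replaces getMissingParentStages(finalStage): only mapB is still missing.
  println(finalStage.missingParents.map(_.id).sorted) // List(1)
  // Replaces stageDependsOn(finalStage, mapB).
  println(finalStage.dependsOn(mapB)) // true
}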
