diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cb8ccfbdbdcbb..9759963de5f7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -360,37 +360,6 @@ class DAGScheduler(
     parents
   }
 
-  private def getMissingParentStages(stage: Stage): List[Stage] = {
-    val missing = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    // We are manually maintaining a stack here to prevent StackOverflowError
-    // caused by recursively visiting
-    val waitingForVisit = new Stack[RDD[_]]
-    def visit(rdd: RDD[_]) {
-      if (!visited(rdd)) {
-        visited += rdd
-        if (getCacheLocs(rdd).contains(Nil)) {
-          for (dep <- rdd.dependencies) {
-            dep match {
-              case shufDep: ShuffleDependency[_, _, _] =>
-                val mapStage = getShuffleMapStage(shufDep, stage.jobId)
-                if (!mapStage.isAvailable) {
-                  missing += mapStage
-                }
-              case narrowDep: NarrowDependency[_] =>
-                waitingForVisit.push(narrowDep.rdd)
-            }
-          }
-        }
-      }
-    }
-    waitingForVisit.push(stage.rdd)
-    while (!waitingForVisit.isEmpty) {
-      visit(waitingForVisit.pop())
-    }
-    missing.toList
-  }
-
   /**
    * Registers the given jobId among the jobs that need the given stage and
    * all of that stage's ancestors.
@@ -401,8 +370,7 @@ class DAGScheduler(
         val s = stages.head
         s.jobIds += jobId
         jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
-        val parents: List[Stage] = getParentStages(s.rdd, jobId)
-        val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
+        val parentsWithoutThisJobId = s.parents.filter { ! _.jobIds.contains(jobId) }
         updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
       }
     }
@@ -745,7 +713,7 @@ class DAGScheduler(
       job.jobId, callSite.shortForm, partitions.length, allowLocal))
     logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
     logInfo("Parents of final stage: " + finalStage.parents)
-    logInfo("Missing parents: " + getMissingParentStages(finalStage))
+    logInfo("Missing parents: " + finalStage.missingParents)
     val shouldRunLocally =
       localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
     if (shouldRunLocally) {
@@ -771,7 +739,7 @@ class DAGScheduler(
     if (jobId.isDefined) {
       logDebug("submitStage(" + stage + ")")
       if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
-        val missing = getMissingParentStages(stage).sortBy(_.id)
+        val missing = stage.missingParents.sortBy(_.id)
         logDebug("missing: " + missing)
         if (missing == Nil) {
           logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
@@ -1040,9 +1008,9 @@ class DAGScheduler(
           } else {
             val newlyRunnable = new ArrayBuffer[Stage]
             for (stage <- waitingStages) {
-              logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+              logInfo(s"Missing parents for $stage: ${stage.missingParents}")
             }
-            for (stage <- waitingStages if getMissingParentStages(stage) == Nil) {
+            for (stage <- waitingStages if stage.missingParents == Nil) {
              newlyRunnable += stage
            }
            waitingStages --= newlyRunnable
@@ -1197,7 +1165,7 @@ class DAGScheduler(
      return
    }
    val dependentJobs: Seq[ActiveJob] =
-      activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
+      activeJobs.filter(job => job.finalStage.dependsOn(failedStage)).toSeq
    failedStage.latestInfo.completionTime = Some(clock.getTime())
    for (job <- dependentJobs) {
      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
failure: $reason") @@ -1257,42 +1225,6 @@ class DAGScheduler( } } - /** - * Return true if one of stage's ancestors is target. - */ - private def stageDependsOn(stage: Stage, target: Stage): Boolean = { - if (stage == target) { - return true - } - val visitedRdds = new HashSet[RDD[_]] - val visitedStages = new HashSet[Stage] - // We are manually maintaining a stack here to prevent StackOverflowError - // caused by recursively visiting - val waitingForVisit = new Stack[RDD[_]] - def visit(rdd: RDD[_]) { - if (!visitedRdds(rdd)) { - visitedRdds += rdd - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) - if (!mapStage.isAvailable) { - visitedStages += mapStage - waitingForVisit.push(mapStage.rdd) - } // Otherwise there's no need to follow the dependency back - case narrowDep: NarrowDependency[_] => - waitingForVisit.push(narrowDep.rdd) - } - } - } - } - waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { - visit(waitingForVisit.pop()) - } - visitedRdds.contains(target.rdd) - } - /** * Synchronized method that might be called from other threads. * @param rdd whose partitions are to be looked at diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b89..620b92a35f928 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -85,6 +85,17 @@ private[spark] class Stage( } } + def missingParents: List[Stage] = { + parents.filterNot(_.isAvailable) + } + + /** + * Returns true if one of this stage's ancestors is `otherStage`. + */ + def dependsOn(otherStage: Stage): Boolean = { + parents.exists(_.rdd == otherStage.rdd) + } + def addOutputLoc(partition: Int, status: MapStatus) { val prevList = outputLocs(partition) outputLocs(partition) = status :: prevList