diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 0c8d437c98bd9..24dfaf4472ebb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -683,28 +683,56 @@ private[spark] class DAGScheduler( }.toList } - /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ - private def getMissingAncestorShuffleDependencies( - rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { - val ancestors = new ListBuffer[ShuffleDependency[_, _, _]] + /** + * Like [[traverseRDDGraphUntil]], but does not support early termination. For each unvisited RDD, + * calls `visitor(rdd, enqueue)` where `enqueue` can be called to schedule additional RDDs for + * traversal. + */ + private def traverseRDDGraph(rdd: RDD[_])(visitor: (RDD[_], RDD[_] => Unit) => Unit): Unit = { + traverseRDDGraphUntil(rdd) { (r, enqueue) => + visitor(r, enqueue) + true + } + } + + /** + * Traverses the RDD dependency graph using a manually maintained stack to prevent + * StackOverflowError caused by recursive traversal. For each unvisited RDD, calls + * `visitor(rdd, enqueue)` where `enqueue` can be called to schedule additional RDDs for + * traversal. If `visitor` returns `false`, the traversal stops immediately. Returns `true` + * if the traversal completed normally, `false` if it was terminated early by the visitor. + */ + private def traverseRDDGraphUntil( + rdd: RDD[_])(visitor: (RDD[_], RDD[_] => Unit) => Boolean): Boolean = { val visited = new HashSet[RDD[_]] - // We are manually maintaining a stack here to prevent StackOverflowError - // caused by recursively visiting val waitingForVisit = new ListBuffer[RDD[_]] waitingForVisit += rdd + def enqueue(r: RDD[_]): Unit = waitingForVisit.prepend(r) while (waitingForVisit.nonEmpty) { val toVisit = waitingForVisit.remove(0) if (!visited(toVisit)) { visited += toVisit - val (shuffleDeps, _) = getShuffleDependenciesAndResourceProfiles(toVisit) - shuffleDeps.foreach { shuffleDep => - if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) { - ancestors.prepend(shuffleDep) - waitingForVisit.prepend(shuffleDep.rdd) - } // Otherwise, the dependency and its ancestors have already been registered. + if (!visitor(toVisit, enqueue)) { + return false } } } + true + } + + /** Find ancestor shuffle dependencies that are not registered in shuffleIdToMapStage yet */ + private def getMissingAncestorShuffleDependencies( + rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { + val ancestors = new ListBuffer[ShuffleDependency[_, _, _]] + traverseRDDGraph(rdd) { (toVisit, enqueue) => + val (shuffleDeps, _) = getShuffleDependenciesAndResourceProfiles(toVisit) + shuffleDeps.foreach { shuffleDep => + if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) { + ancestors.prepend(shuffleDep) + enqueue(shuffleDep.rdd) + } // Otherwise, the dependency and its ancestors have already been registered. + } + } ancestors } @@ -725,20 +753,13 @@ private[spark] class DAGScheduler( rdd: RDD[_]): (HashSet[ShuffleDependency[_, _, _]], HashSet[ResourceProfile]) = { val parents = new HashSet[ShuffleDependency[_, _, _]] val resourceProfiles = new HashSet[ResourceProfile] - val visited = new HashSet[RDD[_]] - val waitingForVisit = new ListBuffer[RDD[_]] - waitingForVisit += rdd - while (waitingForVisit.nonEmpty) { - val toVisit = waitingForVisit.remove(0) - if (!visited(toVisit)) { - visited += toVisit - Option(toVisit.getResourceProfile()).foreach(resourceProfiles += _) - toVisit.dependencies.foreach { - case shuffleDep: ShuffleDependency[_, _, _] => - parents += shuffleDep - case dependency => - waitingForVisit.prepend(dependency.rdd) - } + traverseRDDGraph(rdd) { (toVisit, enqueue) => + Option(toVisit.getResourceProfile()).foreach(resourceProfiles += _) + toVisit.dependencies.foreach { + case shuffleDep: ShuffleDependency[_, _, _] => + parents += shuffleDep + case dependency => + enqueue(dependency.rdd) } } (parents, resourceProfiles) @@ -749,100 +770,68 @@ private[spark] class DAGScheduler( * RDDs satisfy a given predicate. */ private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { - val visited = new HashSet[RDD[_]] - val waitingForVisit = new ListBuffer[RDD[_]] - waitingForVisit += rdd - while (waitingForVisit.nonEmpty) { - val toVisit = waitingForVisit.remove(0) - if (!visited(toVisit)) { - if (!predicate(toVisit)) { - return false - } - visited += toVisit + traverseRDDGraphUntil(rdd) { (toVisit, enqueue) => + if (!predicate(toVisit)) { + false + } else { toVisit.dependencies.foreach { case _: ShuffleDependency[_, _, _] => - // Not within the same stage with current rdd, do nothing. + // Not within the same stage as the current RDD, do nothing. case dependency => - waitingForVisit.prepend(dependency.rdd) + enqueue(dependency.rdd) } + true } } - true } private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - // We are manually maintaining a stack here to prevent StackOverflowError - // caused by recursively visiting - val waitingForVisit = new ListBuffer[RDD[_]] - waitingForVisit += stage.rdd - def visit(rdd: RDD[_]): Unit = { - if (!visited(rdd)) { - visited += rdd - val rddHasUncachedPartitions = try { - getCacheLocs(rdd).contains(Nil) - } catch { - case e: RpcTimeoutException => - logWarning(log"Failed to get cache locations for RDD ${MDC(RDD_ID, rdd.id)} due " + - log"to rpc timeout, assuming not fully cached.", e) - true - } - if (rddHasUncachedPartitions) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) - // Mark mapStage as available with shuffle outputs only after shuffle merge is - // finalized with push based shuffle. If not, subsequent ShuffleMapStage won't - // read from merged output as the MergeStatuses are not available. - if (!mapStage.isAvailable || !mapStage.shuffleDep.shuffleMergeFinalized) { - missing += mapStage - } else { - // Forward the nextAttemptId if skipped and get visited for the first time. - // Otherwise, once it gets retried, - // 1) the stuffs in stage info become distorting, e.g. task num, input byte, e.t.c - // 2) the first attempt starts from 0-idx, it will not be marked as a retry - mapStage.increaseAttemptIdOnFirstSkip() - } - case narrowDep: NarrowDependency[_] => - waitingForVisit.prepend(narrowDep.rdd) - } + traverseRDDGraph(stage.rdd) { (rdd, enqueue) => + val rddHasUncachedPartitions = try { + getCacheLocs(rdd).contains(Nil) + } catch { + case e: RpcTimeoutException => + logWarning(log"Failed to get cache locations for RDD ${MDC(RDD_ID, rdd.id)} due " + + log"to rpc timeout, assuming not fully cached.", e) + true + } + if (rddHasUncachedPartitions) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) + // Mark mapStage as available with shuffle outputs only after shuffle merge is + // finalized with push based shuffle. If not, subsequent ShuffleMapStage won't + // read from merged output as the MergeStatuses are not available. + if (!mapStage.isAvailable || !mapStage.shuffleDep.shuffleMergeFinalized) { + missing += mapStage + } else { + // Forward the nextAttemptId if skipped and visited for the first time. + // Otherwise, once it gets retried: + // 1) the stage info fields become skewed, e.g. task count, input bytes, etc. + // 2) the first attempt starts from 0-idx, it will not be marked as a retry + mapStage.increaseAttemptIdOnFirstSkip() + } + case narrowDep: NarrowDependency[_] => + enqueue(narrowDep.rdd) } } } } - while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.remove(0)) - } missing.toList } /** Invoke `.partitions` on the given RDD and all of its ancestors */ private def eagerlyComputePartitionsForRddAndAncestors(rdd: RDD[_]): Unit = { val startTime = System.nanoTime - val visitedRdds = new HashSet[RDD[_]] - // We are manually maintaining a stack here to prevent StackOverflowError - // caused by recursively visiting - val waitingForVisit = new ListBuffer[RDD[_]] - waitingForVisit += rdd - - def visit(rdd: RDD[_]): Unit = { - if (!visitedRdds(rdd)) { - visitedRdds += rdd - - // Eagerly compute: - rdd.partitions - - for (dep <- rdd.dependencies) { - waitingForVisit.prepend(dep.rdd) - } + traverseRDDGraph(rdd) { (toVisit, enqueue) => + // Eagerly compute: + toVisit.partitions + for (dep <- toVisit.dependencies) { + enqueue(dep.rdd) } } - - while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.remove(0)) - } logDebug("eagerlyComputePartitionsForRddAndAncestors for RDD %d took %f seconds" .format(rdd.id, (System.nanoTime - startTime) / 1e9)) } @@ -3374,31 +3363,24 @@ private[spark] class DAGScheduler( if (stage == target) { return true } - val visitedRdds = new HashSet[RDD[_]] - // We are manually maintaining a stack here to prevent StackOverflowError - // caused by recursively visiting - val waitingForVisit = new ListBuffer[RDD[_]] - waitingForVisit += stage.rdd - def visit(rdd: RDD[_]): Unit = { - if (!visitedRdds(rdd)) { - visitedRdds += rdd + !traverseRDDGraphUntil(stage.rdd) { (rdd, enqueue) => + if (rdd == target.rdd) { + false + } else { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { - waitingForVisit.prepend(mapStage.rdd) + enqueue(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => - waitingForVisit.prepend(narrowDep.rdd) + enqueue(narrowDep.rdd) } } + true } } - while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.remove(0)) - } - visitedRdds.contains(target.rdd) } /**