[SPARK-31082][CORE] MapOutputTrackerMaster.getMapLocation should handle last mapIndex correctly

### What changes were proposed in this pull request?

In `getMapLocation`, change the condition from `...endMapIndex < statuses.length` to `...endMapIndex <= statuses.length`.

### Why are the changes needed?

`endMapIndex` is exclusive, so it can legitimately equal `statuses.length`, and the bounds check must allow that value. Otherwise, we can't get the location for the last map index.
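
As a quick illustration (not part of the patch): Scala's `slice(from, until)` treats `until` as exclusive, so an end index equal to the array length is the only way to reach the last element.

```scala
// Plain Scala, no Spark needed: `until` is exclusive in `slice(from, until)`.
val statuses = Array("host0", "host1", "host2", "host3")

// Selecting the last element requires until == statuses.length.
assert(statuses.slice(3, 4).sameElements(Array("host3")))

// With the old-style bound `until < statuses.length`, the pair (3, 4)
// would be rejected even though slice(3, 4) is perfectly valid.
```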

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Updated an existing test.

Closes #27850 from Ngone51/fix_getmaploction.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and cloud-fan committed Mar 9, 2020
1 parent 068bdd4 commit ef51ff9
Showing 2 changed files with 10 additions and 5 deletions.
core/src/main/scala/org/apache/spark/MapOutputTracker.scala (5 changes: 3 additions & 2 deletions)
@@ -696,7 +696,7 @@ private[spark] class MapOutputTrackerMaster(
    *
    * @param dep shuffle dependency object
    * @param startMapIndex the start map index
-   * @param endMapIndex the end map index
+   * @param endMapIndex the end map index (exclusive)
    * @return a sequence of locations where task runs.
    */
   def getMapLocation(
@@ -707,7 +707,8 @@ private[spark] class MapOutputTrackerMaster(
     val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
     if (shuffleStatus != null) {
       shuffleStatus.withMapStatuses { statuses =>
-        if (startMapIndex < endMapIndex && (startMapIndex >= 0 && endMapIndex < statuses.length)) {
+        if (startMapIndex < endMapIndex &&
+          (startMapIndex >= 0 && endMapIndex <= statuses.length)) {
           val statusesPicked = statuses.slice(startMapIndex, endMapIndex).filter(_ != null)
           statusesPicked.map(_.location.host).toSeq
         } else {
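
To make the behavioral difference concrete, here is a small standalone sketch. The names (`ExclusiveEndIndexDemo`, `pickHosts`) are hypothetical, and a plain `Array[String]` stands in for the `MapStatus` array; only the bounds check and the exclusive-end slice are modeled.

```scala
object ExclusiveEndIndexDemo {
  // Mirrors the fixed condition: an exclusive end index may equal the length.
  def pickHosts(statuses: Array[String], startMapIndex: Int, endMapIndex: Int): Seq[String] = {
    if (startMapIndex < endMapIndex &&
        (startMapIndex >= 0 && endMapIndex <= statuses.length)) {
      statuses.slice(startMapIndex, endMapIndex).filter(_ != null).toSeq
    } else {
      Seq.empty
    }
  }

  def main(args: Array[String]): Unit = {
    val statuses = Array("host0", "host1", "host2", "host3")
    // Request the last map's location: start = 3, end (exclusive) = 4.
    // The old check `endMapIndex < statuses.length` rejected 4 < 4 and fell
    // through to the empty branch; the fixed `<=` accepts it.
    println(pickHosts(statuses, 3, 4)) // List(host3)
  }
}
```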
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala (10 changes: 7 additions & 3 deletions)
@@ -114,9 +114,13 @@ class AdaptiveQueryExecSuite
 
     val numLocalReaders = collect(plan) {
       case reader @ CustomShuffleReaderExec(_, _, LOCAL_SHUFFLE_READER_DESCRIPTION) => reader
-    }.length
-
-    assert(numShuffles === (numLocalReaders + numShufflesWithoutLocalReader))
+    }
+    numLocalReaders.foreach { r =>
+      val rdd = r.execute()
+      val parts = rdd.partitions
+      assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
+    }
+    assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }
 
   test("Change merge join to broadcast join") {
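
For context, a hedged standalone sketch of the public RDD API the new assertions rely on. Assumptions: local mode and a plain parallelized RDD rather than the suite's `CustomShuffleReaderExec.execute()` RDD, whose preferred locations are what `getMapLocation` now supplies for the last map index.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object PreferredLocationsCheck {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("preferred-locations-check")
    val sc = new SparkContext(conf)
    try {
      val rdd = sc.parallelize(1 to 100, numSlices = 4)
      // Walk every partition and ask for its preferred locations, the same
      // shape as the test's `parts.forall(rdd.preferredLocations(_).nonEmpty)`.
      // A parallelized collection typically reports no preferred locations;
      // a local shuffle reader's RDD reports the hosts of its map outputs.
      rdd.partitions.foreach { p =>
        println(s"partition ${p.index}: ${rdd.preferredLocations(p)}")
      }
    } finally {
      sc.stop()
    }
  }
}
```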
