[SPARK-28560][SQL][followup] code cleanup for local shuffle reader #26128

Closed · wants to merge 1 commit
34 changes: 17 additions & 17 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -340,17 +340,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* Called from executors to get the server URIs and output sizes for each shuffle block that
* needs to be read from a given range of map output partitions (startPartition is included but
* endPartition is excluded from the range) and a given mapId.
* endPartition is excluded from the range) and is produced by a specific mapper.
*
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size, map index)
* tuples describing the shuffle blocks that are stored at that block manager.
*/
def getMapSizesByExecutorId(
def getMapSizesByMapIndex(
shuffleId: Int,
mapIndex: Int,
startPartition: Int,
endPartition: Int,
mapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

/**
* Deletes map output status information for the specified shuffle stage.
@@ -741,13 +741,12 @@ private[spark] class MapOutputTrackerMaster(
}
}

override def getMapSizesByExecutorId(
override def getMapSizesByMapIndex(
shuffleId: Int,
mapIndex: Int,
startPartition: Int,
endPartition: Int,
mapId: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId" +
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
s"partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
@@ -757,7 +756,7 @@
startPartition,
endPartition,
statuses,
Some(mapId))
Some(mapIndex))
}
case None =>
Iterator.empty
@@ -809,17 +808,17 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
}
}

override def getMapSizesByExecutorId(
override def getMapSizesByMapIndex(
shuffleId: Int,
mapIndex: Int,
startPartition: Int,
endPartition: Int,
mapId: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId" +
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
s"partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition,
statuses, Some(mapId))
statuses, Some(mapIndex))
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -962,6 +961,7 @@ private[spark] object MapOutputTracker extends Logging {
* @param startPartition Start of map output partition ID range (included in range)
* @param endPartition End of map output partition ID range (excluded from range)
* @param statuses List of map statuses, indexed by map partition index.
* @param mapIndex When specified, only shuffle blocks from this mapper will be processed.
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size, map index)
* tuples describing the shuffle blocks that are stored at that block manager.
@@ -971,11 +971,11 @@ private[spark] object MapOutputTracker extends Logging {
startPartition: Int,
endPartition: Int,
statuses: Array[MapStatus],
mapId : Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
mapIndex : Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
assert (statuses != null)
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
val iter = statuses.iterator.zipWithIndex
for ((status, mapIndex) <- mapId.map(id => iter.filter(_._2 == id)).getOrElse(iter)) {
for ((status, mapIndex) <- mapIndex.map(index => iter.filter(_._2 == index)).getOrElse(iter)) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
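For orientation (not part of the diff): after this rename, a caller inside Spark that only wants the blocks written by one map task goes through `getMapSizesByMapIndex`. A minimal sketch with made-up ids, assuming Spark-internal code since `MapOutputTracker` is `private[spark]`:

```scala
import org.apache.spark.SparkEnv

// Fetch locations and sizes of the shuffle blocks produced by map task #3 of shuffle 0,
// covering reduce partitions 0 to 9 (startPartition is inclusive, endPartition exclusive).
val tracker = SparkEnv.get.mapOutputTracker
val blocksByAddress = tracker.getMapSizesByMapIndex(
  shuffleId = 0,
  mapIndex = 3,
  startPartition = 0,
  endPartition = 10)
// Each element pairs a BlockManagerId with the (block id, size in bytes, map index)
// triples hosted by that block manager.
```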
BlockStoreShuffleReader.scala
@@ -20,7 +20,7 @@ package org.apache.spark.shuffle
import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

@@ -30,34 +30,18 @@ import org.apache.spark.util.collection.ExternalSorter
*/
private[spark] class BlockStoreShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
context: TaskContext,
readMetrics: ShuffleReadMetricsReporter,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
mapId: Option[Int] = None)
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {

private val dep = handle.dependency

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blocksByAddress = mapId match {
case (Some(mapId)) => mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId,
startPartition,
endPartition,
mapId)
case (None) => mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId,
startPartition,
endPartition)
case (_) => throw new IllegalArgumentException(
"mapId should be both set or unset")
}

val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
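Since the reader now receives the block list instead of computing it, it may help to spell out the shape of that `blocksByAddress` argument. The values below are made up purely for illustration (Spark-internal types, so assume code under `org.apache.spark`):

```scala
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

// blocksByAddress pairs each BlockManagerId with the (block id, size in bytes, map index)
// triples stored on it: the same shape getMapSizesBy* returns and the reader consumes.
val blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = Iterator(
  (BlockManagerId("exec-1", "host-a", 7337),
    Seq(
      (ShuffleBlockId(0, 3, 0), 1024L, 3),    // shuffle 0, map 3, reduce 0: 1 KiB
      (ShuffleBlockId(0, 3, 1), 2048L, 3))))  // shuffle 0, map 3, reduce 1: 2 KiB
```

Moving this resolution out of the reader is what lets `getReader` and `getReaderForOneMapper` share a single `BlockStoreShuffleReader` code path, as the SortShuffleManager diff below shows.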
ShuffleManager.scala
@@ -55,17 +55,16 @@ private[spark] trait ShuffleManager {
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
* read from mapId.
* Called on executors by reduce tasks.
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive)
* that are produced by one specific mapper. Called on executors by reduce tasks.
*/
def getMapReader[K, C](
def getReaderForOneMapper[K, C](
handle: ShuffleHandle,
mapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter,
mapId: Int): ShuffleReader[K, C]
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]

/**
* Remove a shuffle's metadata from the ShuffleManager.
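To make the renamed `getReaderForOneMapper` concrete, here is a hedged usage sketch from the reduce side; `handle`, `numReducers`, `context`, and `metrics` are assumed to be in scope and are not part of this patch:

```scala
// Read everything that map task #0 produced for reduce partitions [0, numReducers).
// This is the call shape the local shuffle reader uses, one call per map output.
val reader = SparkEnv.get.shuffleManager.getReaderForOneMapper(
  handle,       // ShuffleHandle from the shuffle dependency
  0,            // mapIndex: which map task's output to read
  0,            // startPartition (inclusive)
  numReducers,  // endPartition (exclusive)
  context,      // TaskContext
  metrics)      // ShuffleReadMetricsReporter
val rows = reader.read()  // Iterator[Product2[K, C]]
```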
SortShuffleManager.scala
@@ -122,30 +122,23 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId, startPartition, endPartition)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
startPartition, endPartition, context, metrics)
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics)
}

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
* read from mapId.
* Called on executors by reduce tasks.
*/
override def getMapReader[K, C](
override def getReaderForOneMapper[K, C](
handle: ShuffleHandle,
mapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter,
mapId: Int): ShuffleReader[K, C] = {
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByMapIndex(
handle.shuffleId, mapIndex, startPartition, endPartition)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
startPartition,
endPartition,
context,
metrics,
mapId = Some(mapId))
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics)
}

/** Get a writer for a given partition. Called on executors by map tasks. */
BlockStoreShuffleReaderSuite.scala
@@ -130,15 +130,15 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext

val taskContext = TaskContext.empty()
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId(
shuffleId, reduceId, reduceId + 1)
val shuffleReader = new BlockStoreShuffleReader(
shuffleHandle,
reduceId,
reduceId + 1,
blocksByAddress,
taskContext,
metrics,
serializerManager,
blockManager,
mapOutputTracker)
blockManager)

assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)

LocalShuffledRowRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
* (identified by `preShufflePartitionIndex`) contains a range of post-shuffle partitions
* (`startPostShufflePartitionIndex` to `endPostShufflePartitionIndex - 1`, inclusive).
*/
private final class LocalShuffleRowRDDPartition(
private final class LocalShuffledRowRDDPartition(
val preShufflePartitionIndex: Int) extends Partition {
override val index: Int = preShufflePartitionIndex
}
@@ -63,7 +63,7 @@ class LocalShuffledRowRDD(
override def getPartitions: Array[Partition] = {

Array.tabulate[Partition](numMappers) { i =>
new LocalShuffleRowRDDPartition(i)
new LocalShuffledRowRDDPartition(i)
}
}

@@ -73,20 +73,20 @@ class LocalShuffledRowRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val localRowPartition = split.asInstanceOf[LocalShuffleRowRDDPartition]
val mapId = localRowPartition.index
val localRowPartition = split.asInstanceOf[LocalShuffledRowRDDPartition]
val mapIndex = localRowPartition.index
val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
// `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
// as well as the `tempMetrics` for basic shuffle metrics.
val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)

val reader = SparkEnv.get.shuffleManager.getMapReader(
val reader = SparkEnv.get.shuffleManager.getReaderForOneMapper(
dependency.shuffleHandle,
mapIndex,
0,
numReducers,
context,
sqlMetricsReporter,
mapId)
sqlMetricsReporter)
reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
}

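One detail worth calling out in `compute()` above: shuffle rows come back as (reduce partition id, row) pairs, and the local reader simply drops the key. A tiny standalone illustration, with `String` standing in for `InternalRow`:

```scala
// Rows arrive as (reducePartitionId, row) pairs; `.map(_._2)` keeps only the row,
// mirroring the final line of LocalShuffledRowRDD.compute().
val fetched: Iterator[Product2[Int, String]] =
  Iterator((0, "row-a"), (2, "row-b"), (1, "row-c"))
val rows = fetched.map(_._2)
println(rows.toList)  // List(row-a, row-b, row-c)
```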