Commit 4126c1b

Add support for migrating shuffle files

holdenk committed Apr 24, 2020
1 parent 249b214 commit 4126c1b
Showing 19 changed files with 462 additions and 46 deletions.
23 changes: 22 additions & 1 deletion core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util._
*
* All public methods of this class are thread-safe.
*/
private class ShuffleStatus(numPartitions: Int) {
private class ShuffleStatus(numPartitions: Int) extends Logging {

private val (readLock, writeLock) = {
val lock = new ReentrantReadWriteLock()
@@ -121,6 +121,20 @@ private class ShuffleStatus(numPartitions: Int) {
mapStatuses(mapIndex) = status
}

/**
* Update the map output location (e.g. during migration).
*/
def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
val mapStatusOpt = mapStatuses.find(_.mapId == mapId)
mapStatusOpt match {
case Some(mapStatus) =>
mapStatus.updateLocation(bmAddress)
invalidateSerializedMapOutputStatusCache()
case None =>
logError("Asked to update map output ${mapId} for untracked map status.")
}
}

/**
* Remove the map output which was served by the specified block manager.
* This is a no-op if there is no registered map output or if the registered output is from a
@@ -479,6 +493,13 @@ private[spark] class MapOutputTrackerMaster(
}
}

def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) => shuffleStatus.updateMapOutput(mapId, bmAddress)
case None => logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
}
}

def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}
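A minimal driver-side sketch of the new hook: once a migrated shuffle file is registered on its destination executor, the MapOutputTrackerMaster can be repointed at the new location. The tracker reference, IDs, and host/port below are illustrative placeholders, not values from this commit.

import org.apache.spark.storage.BlockManagerId

// `tracker` is assumed to be the driver's MapOutputTrackerMaster instance.
val newLocation = BlockManagerId("exec-2", "host-b", 7337)
tracker.updateMapOutput(shuffleId = 0, mapId = 5L, bmAddress = newLocation)
// Unknown shuffle or map IDs are only logged as errors, per the pattern matches above.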
15 changes: 13 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -57,7 +57,7 @@ import org.apache.spark.resource._
import org.apache.spark.resource.ResourceUtils._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend}
import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.spark.shuffle.ShuffleDataIOUtils
import org.apache.spark.shuffle.api.ShuffleDriverComponents
@@ -1586,7 +1586,7 @@ class SparkContext(config: SparkConf) extends Logging {
listenerBus.removeListener(listener)
}

private[spark] def getExecutorIds(): Seq[String] = {
def getExecutorIds(): Seq[String] = {
schedulerBackend match {
case b: ExecutorAllocationClient =>
b.getExecutorIds()
@@ -1725,6 +1725,17 @@
}
}


@DeveloperApi
def decommissionExecutors(executorIds: Seq[String]): Unit = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
executorIds.foreach(b.decommissionExecutor)
case _ =>
logWarning("Decommissioning executors is not supported by current scheduler.")
}
}

/** The version of Spark on which this application is running. */
def version: String = SPARK_VERSION

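A hedged usage sketch for the now-public getExecutorIds and the new @DeveloperApi decommissionExecutors; it assumes the application runs on a CoarseGrainedSchedulerBackend, otherwise the call only logs the warning above. Which executor gets picked here is arbitrary.

// `sc` is an existing SparkContext.
val executorIds: Seq[String] = sc.getExecutorIds()
executorIds.headOption.foreach { execId =>
  sc.decommissionExecutors(Seq(execId))
}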
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -367,7 +367,8 @@ object SparkEnv extends Logging {
externalShuffleClient
} else {
None
}, blockManagerInfo)),
}, blockManagerInfo,
mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
15 changes: 15 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -420,6 +420,21 @@ package object config {
.booleanConf
.createWithDefault(false)

private[spark] val STORAGE_SHUFFLE_DECOMMISSION_ENABLED =
ConfigBuilder("spark.storage.decommission.shuffle_blocks")
.doc("Whether to transfer shuffle blocks during block manager decommissioning. Requires " +
"an indexed shuffle resolver (like sort based shuffe)")
.version("3.1.0")
.booleanConf
.createWithDefault(true)

private[spark] val STORAGE_RDD_DECOMMISSION_ENABLED =
ConfigBuilder("spark.storage.decommission.rdd_blocks")
.doc("Whether to transfer RDD blocks during block manager decommissioning.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)

private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK =
ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock")
.internal()
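The two flags are independent; a sketch of a SparkConf that migrates shuffle files but leaves cached RDD blocks alone (key strings copied verbatim from the ConfigBuilder entries above; both default to true per createWithDefault):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.storage.decommission.shuffle_blocks", "true")  // migrate shuffle index/data files
  .set("spark.storage.decommission.rdd_blocks", "false")     // leave cached RDD blocks in place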
12 changes: 11 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -33,9 +33,11 @@ import org.apache.spark.util.Utils
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
*/
private[spark] sealed trait MapStatus {
/** Location where this task was run. */
/** Location where this task's output is stored. */
def location: BlockManagerId

def updateLocation(bm: BlockManagerId): Unit

/**
* Estimated size for the reduce block, in bytes.
*
@@ -126,6 +128,10 @@ private[spark] class CompressedMapStatus(

override def location: BlockManagerId = loc

override def updateLocation(bm: BlockManagerId): Unit = {
loc = bm
}

override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
@@ -178,6 +184,10 @@ private[spark] class HighlyCompressedMapStatus private (

override def location: BlockManagerId = loc

override def updateLocation(bm: BlockManagerId): Unit = {
loc = bm
}

override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
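A small sketch of the new updateLocation hook; MapStatus.apply picks a CompressedMapStatus or HighlyCompressedMapStatus internally, and both now carry the override shown above. Note MapStatus is private[spark], so this only compiles inside Spark's own packages; the IDs and sizes are made up for illustration.

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId

val status = MapStatus(BlockManagerId("exec-1", "host-a", 7337), Array(100L, 200L, 300L), 42L)
// After the shuffle files move, repoint the status without touching the recorded block sizes.
status.updateLocation(BlockManagerId("exec-2", "host-b", 7337))
assert(status.location.executorId == "exec-2")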
core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -18,15 +18,18 @@
package org.apache.spark.shuffle

import java.io._
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.client.StreamCallbackWithID
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExecutorDiskUtils
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -55,6 +58,25 @@ private[spark] class IndexShuffleBlockResolver(

def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)

/**
* Get the shuffle files that are stored locally. Used for block migrations.
*/
def getStoredShuffles(): Set[(Int, Long)] = {
// Matches ShuffleIndexBlockId name
val pattern = "shuffle_(\\d+)_(\\d+)_.+\\.index".r
val rootDirs = blockManager.diskBlockManager.localDirs
// ExecutorDiskUtils puts files inside one level of hashed subdirectories
val searchDirs = rootDirs.flatMap(_.listFiles()).filter(_.isDirectory()) ++ rootDirs
val filenames = searchDirs.flatMap(_.list())
logDebug(s"Got block files ${filenames.toList}")
filenames.flatMap{ fname =>
pattern.findAllIn(fname).matchData.map {
matched => (matched.group(1).toInt, matched.group(2).toLong)
}
}.toSet
}


/**
* Get the shuffle data file.
*
@@ -148,6 +170,86 @@
}
}

/**
* Write a provided shuffle block as a stream. Used for block migrations.
* ShuffleBlockBatchIds must contain the full range represented in the ShuffleIndexBlock.
* If writing the shuffle data block fails, the caller is responsible for deleting the
* corresponding shuffle index block.
*/
def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
StreamCallbackWithID = {
val file = blockId match {
case ShuffleIndexBlockId(shuffleId, mapId, _) =>
getIndexFile(shuffleId, mapId)
case ShuffleBlockBatchId(shuffleId, mapId, _, _) =>
getDataFile(shuffleId, mapId)
case _ =>
throw new Exception(s"Unexpected shuffle block transfer ${blockId}")
}
val fileTmp = Utils.tempFileWith(file)
val channel = Channels.newChannel(
serializerManager.wrapStream(blockId,
new FileOutputStream(fileTmp)))

new StreamCallbackWithID {

override def getID: String = blockId.name

override def onData(streamId: String, buf: ByteBuffer): Unit = {
while (buf.hasRemaining) {
channel.write(buf)
}
}

override def onComplete(streamId: String): Unit = {
logTrace(s"Done receiving block $blockId, now putting into local shuffle service")
channel.close()
val diskSize = fileTmp.length()
this.synchronized {
if (file.exists()) {
file.delete()
}
if (!fileTmp.renameTo(file)) {
throw new IOException(s"fail to rename file ${fileTmp} to ${file}")
}
}
blockManager.reportBlockStatus(blockId, BlockStatus(
StorageLevel(
useDisk = true,
useMemory = false,
useOffHeap = false,
deserialized = false,
replication = 0),
0, diskSize))
}

override def onFailure(streamId: String, cause: Throwable): Unit = {
// the framework handles the connection itself, we just need to do local cleanup
channel.close()
fileTmp.delete()
}
}
}

/**
* Get the index & data block for migration.
*/
def getMigrationBlocks(shuffleId: Int, mapId: Long):
((BlockId, ManagedBuffer), (BlockId, ManagedBuffer)) = {
// Load the index block
val indexFile = getIndexFile(shuffleId, mapId)
val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0)
val indexFileSize = indexFile.length()
val indexBlockData = new FileSegmentManagedBuffer(transportConf, indexFile, 0, indexFileSize)

// Load the data block
val dataFile = getDataFile(shuffleId, mapId)
val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0)
val dataBlockData = new FileSegmentManagedBuffer(transportConf, dataFile, 0, dataFile.length())
((indexBlockId, indexBlockData), (dataBlockId, dataBlockData))
}


/**
* Write an index file with the offsets of each block, plus a final offset at the end for the
* end of the output file. This will be used by getBlockData to figure out where each block
@@ -169,7 +271,7 @@
val dataFile = getDataFile(shuffleId, mapId)
// There is only one IndexShuffleBlockResolver per executor; this synchronization makes sure
// the following check and rename are atomic.
synchronized {
this.synchronized {
val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
if (existingLengths != null) {
// Another attempt for the same task has already written our map outputs successfully,
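Putting the resolver hooks together, a hedged sketch of the shape of a migration pass: enumerate the shuffle files stored locally, wrap each index/data pair as managed buffers via getMigrationBlocks, and hand them to whatever transfer path pushes them to a peer block manager. The uploadToPeer function is a hypothetical stand-in, not an API from this commit, and like the MapStatus sketch above this assumes code living inside Spark's packages since the resolver is private[spark].

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.storage.BlockId

def migrateAllShuffles(
    resolver: IndexShuffleBlockResolver,
    uploadToPeer: (BlockId, ManagedBuffer) => Unit): Unit = {
  resolver.getStoredShuffles().foreach { case (shuffleId, mapId) =>
    val ((indexBlockId, indexBuf), (dataBlockId, dataBuf)) =
      resolver.getMigrationBlocks(shuffleId, mapId)
    // Order is illustrative; on the receiving side putShuffleBlockAsStream writes
    // whichever block arrives into the matching index or data file.
    uploadToPeer(indexBlockId, indexBuf)
    uploadToPeer(dataBlockId, dataBuf)
  }
}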
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -40,6 +40,9 @@ sealed abstract class BlockId {
def isRDD: Boolean = isInstanceOf[RDDBlockId]
def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId]
def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
def isInternalShuffle: Boolean = {
isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId]
}

override def toString: String = name
}
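A quick illustration of the distinction the new predicate draws, assuming BlockId's standard string parser (which maps shuffle_*_*_*.index / .data names to ShuffleIndexBlockId / ShuffleDataBlockId):

import org.apache.spark.storage.BlockId

val indexBlock = BlockId("shuffle_0_42_0.index")  // parses to ShuffleIndexBlockId(0, 42, 0)
indexBlock.isShuffle          // false: only ShuffleBlockId / ShuffleBlockBatchId count here
indexBlock.isInternalShuffle  // true: the on-disk index and data files do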