[SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory #7130

Closed
wants to merge 14 commits
30 changes: 30 additions & 0 deletions core/src/main/java/org/apache/spark/Spillable.java
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

/**
* An object whose in-memory contents can be forcibly spilled to disk to release memory.
*/
public interface Spillable {

/**
* Force the contents of the memory buffer to be spilled to disk.
* @return the number of bytes spilled
*/
public long forceSpill();
}
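
For illustration, a minimal sketch of an implementor of this interface (hypothetical class, not part of this PR; written in Scala like the rest of the change):

// Hypothetical sketch: a buffer that tracks how many bytes it holds and reports
// them as freed when asked to spill. Writing the records to disk is elided.
class SketchBuffer extends org.apache.spark.Spillable {
  private var bytesHeld: Long = 0L

  def append(recordSize: Long): Unit = { bytesHeld += recordSize }

  override def forceSpill(): Long = {
    val freed = bytesHeld
    bytesHeld = 0L
    freed
  }
}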
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle

import scala.collection.mutable

import org.apache.spark.{Logging, SparkException, SparkConf}
import org.apache.spark._

/**
* Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
@@ -38,8 +38,47 @@ import org.apache.spark.{Logging, SparkException, SparkConf}
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes

// threadId -> list of Spillables with memory reserved by that thread
private val threadReservedList = new mutable.HashMap[Long, mutable.ListBuffer[Spillable]]()

def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))

/**
* Release memory reserved by other Spillables of the current thread until the amount
* granted reaches requestMemory, and return the amount finally granted.
*/
def releaseReservedMemory(toGrant: Long, requestMemory: Long): Long = synchronized {
val threadId = Thread.currentThread().getId
if (toGrant >= requestMemory || !threadReservedList.contains(threadId)) {
toGrant
} else {
// try to release Spillables' memory in the current thread to make space for the new request
var addMemory = toGrant
while (addMemory < requestMemory && threadReservedList(threadId).nonEmpty) {
val toSpill = threadReservedList(threadId).remove(0)
val spillMemory = toSpill.forceSpill()
logInfo(s"Thread $threadId forced a spill of $spillMemory bytes to free up memory")
addMemory += spillMemory
}
if (addMemory > requestMemory) {
this.release(addMemory - requestMemory)
addMemory = requestMemory
}
addMemory
}
}
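
To make the arithmetic concrete, here is a standalone simplification of the loop above (illustrative only, not Spark code; all names are made up): reserved collections are spilled one by one until the request is covered, and anything beyond the request would go back to the pool.

import scala.collection.mutable.ListBuffer

object ReleaseLoopSketch {
  trait Spillable { def forceSpill(): Long }

  def releaseReserved(toGrant: Long, requested: Long, reserved: ListBuffer[Spillable]): Long = {
    var granted = toGrant
    // Spill reserved collections until the request is met or the list is exhausted.
    while (granted < requested && reserved.nonEmpty) {
      granted += reserved.remove(0).forceSpill()
    }
    // The real method returns any excess to the pool via release(); here we just cap.
    math.min(granted, requested)
  }

  def main(args: Array[String]): Unit = {
    val mb = 1024L * 1024
    val reserved = ListBuffer[Spillable](new Spillable { def forceSpill(): Long = 80 * mb })
    // Granted 40 MB directly but asked for 100 MB: the reserved collection spills 80 MB,
    // so the thread ends up with the full 100 MB and 20 MB is conceptually handed back.
    println(releaseReserved(40 * mb, 100 * mb, reserved) / mb) // prints 100
  }
}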

/**
* Register a Spillable in the current thread's reserved list. When the thread cannot
* get enough memory, entries in this list can be forced to spill and release memory.
*/
def addSpillableToReservedList(spill: Spillable): Unit = synchronized {
val threadId = Thread.currentThread().getId
if (!threadReservedList.contains(threadId)) {
threadReservedList(threadId) = new mutable.ListBuffer[Spillable]()
}
threadReservedList(threadId) += spill
}

/**
* Try to acquire up to numBytes memory for the current thread, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
@@ -77,7 +116,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
return this.releaseReservedMemory(toGrant, numBytes)
} else {
logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
@@ -86,7 +125,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
return this.releaseReservedMemory(toGrant, numBytes)
}
}
0L // Never reached
@@ -108,6 +147,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
def releaseMemoryForThisThread(): Unit = synchronized {
val threadId = Thread.currentThread().getId
threadMemory.remove(threadId)
threadReservedList.remove(threadId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
@@ -17,14 +17,15 @@

package org.apache.spark.util.collection

import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.{Logging, SparkEnv, Spillable}

import scala.reflect.ClassTag
Review comment (contributor): nit: scala imports before spark imports


/**
* Spills contents of an in-memory collection to disk when the memory threshold
* has been exceeded.
*/
private[spark] trait Spillable[C] extends Logging {
private[spark] trait CollectionSpillable[C] extends Logging with Spillable {
/**
* Spills the current in-memory collection to disk, and releases the memory.
*
@@ -40,25 +41,25 @@ private[spark] trait Spillable[C] extends Logging {
protected def addElementsRead(): Unit = { _elementsRead += 1 }

// Memory manager that can be used to acquire/release memory
private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

// Initial threshold for the size of a collection before we start tracking its memory usage
// Exposed for testing
private[this] val initialMemoryThreshold: Long =
protected val initialMemoryThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)

// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
private[this] var myMemoryThreshold = initialMemoryThreshold
protected var myMemoryThreshold = initialMemoryThreshold

// Number of elements read from input since last spill
private[this] var _elementsRead = 0L
protected var _elementsRead = 0L

// Number of bytes spilled in total
private[this] var _memoryBytesSpilled = 0L
protected var _memoryBytesSpilled = 0L

// Number of spills
private[this] var _spillCount = 0
protected var _spillCount = 0

/**
* Spills the current in-memory collection to disk if needed. Attempts to acquire more
@@ -111,7 +112,7 @@
*
* @param size number of bytes spilled
*/
@inline private def logSpillage(size: Long) {
@inline protected def logSpillage(size: Long) {
val threadId = Thread.currentThread().getId
logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)"
.format(threadId, org.apache.spark.util.Utils.bytesToString(size),
@@ -35,7 +35,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics

/**
* :: DeveloperApi ::
* An append-only map that spills sorted content to disk when there is insufficient space for it
* An append-only map that spills sorted content to disk when there is insufficient space for inMemory
* to grow.
*
* This map takes two passes over the data:
Expand Down Expand Up @@ -69,7 +69,7 @@ class ExternalAppendOnlyMap[K, V, C](
extends Iterable[(K, C)]
with Serializable
with Logging
with Spillable[SizeTracker] {
with CollectionSpillable[SizeTracker] {

private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
@@ -100,6 +100,8 @@ class ExternalAppendOnlyMap[K, V, C](
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()

private var memoryOrDiskIter: Option[MemoryOrDiskIterator] = None

/**
* Insert the given key and value into the map.
*/
@@ -151,6 +153,14 @@
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
val it = currentMap.destructiveSortedIterator(keyComparator)
spilledMaps.append(spillMemoryToDisk(it))
}

/**
* Spill the contents of an in-memory iterator to a temporary file on disk and return
* an iterator over the spilled file.
*/
private[this] def spillMemoryToDisk(inMemory: Iterator[(K, C)]): DiskMapIterator = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -171,9 +181,8 @@

var success = false
try {
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
while (inMemory.hasNext) {
val kv = inMemory.next()
writer.write(kv._1, kv._2)
objectsWritten += 1

Expand Down Expand Up @@ -203,8 +212,7 @@ class ExternalAppendOnlyMap[K, V, C](
}
}
}

spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
new DiskMapIterator(file, blockId, batchSizes)
}

def diskBytesSpilled: Long = _diskBytesSpilled
@@ -214,25 +222,75 @@
* If no spill has occurred, simply return the in-memory map's iterator.
*/
override def iterator: Iterator[(K, C)] = {
shuffleMemoryManager.addSpillableToReservedList(this)
if (spilledMaps.isEmpty) {
currentMap.iterator
memoryOrDiskIter = Some(MemoryOrDiskIterator(currentMap.iterator))
memoryOrDiskIter.get
} else {
new ExternalIterator()
}
}

/**
* Spill the contents of the in-memory map to disk and release the memory it holds.
*/
override def forceSpill(): Long = {
var freeMemory = 0L
if (memoryOrDiskIter.isDefined) {
_spillCount += 1
logSpillage(currentMap.estimateSize())

memoryOrDiskIter.get.spill()

_elementsRead = 0
_memoryBytesSpilled += currentMap.estimateSize()
freeMemory = myMemoryThreshold - initialMemoryThreshold
myMemoryThreshold = initialMemoryThreshold
}

freeMemory
}
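
The amount reported as freed is only the growth above the initial threshold: that initial slice is configured locally and was never requested from the ShuffleMemoryManager, so only memory obtained through tryToAcquire is handed back. A worked example with illustrative numbers:

// Illustrative values only; 5 MB is the default of spark.shuffle.spill.initialMemoryThreshold.
val initialMemoryThreshold = 5L * 1024 * 1024  // never requested from the manager
val myMemoryThreshold = 37L * 1024 * 1024      // grown through tryToAcquire calls
val freed = myMemoryThreshold - initialMemoryThreshold
assert(freed == 32L * 1024 * 1024)             // 32 MB goes back to the shuffle pool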

/**
* An iterator that reads elements from the in-memory iterator, or from a disk iterator
* once the in-memory contents have been spilled to disk.
*/
case class MemoryOrDiskIterator(memIter: Iterator[(K, C)]) extends Iterator[(K, C)] {

var currentIter = memIter

override def hasNext: Boolean = currentIter.hasNext

override def next(): (K, C) = currentIter.next()

def spill(): Unit = {
if (hasNext) {
currentIter = spillMemoryToDisk(currentIter)
} else {
// the in-memory iterator is already drained; swap in an empty iterator to release it
currentIter = Iterator.empty
logInfo("In-memory iterator already drained, nothing to spill")
}
}
}
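
A standalone sketch of the switch-over idea (simplified types and hypothetical names, not the PR code): the consumer keeps calling next() and never notices that the underlying source changed.

class SwitchableIterator[T](memIter: Iterator[T], spillRemaining: Iterator[T] => Iterator[T]) extends Iterator[T] {
  private var current: Iterator[T] = memIter

  override def hasNext: Boolean = current.hasNext
  override def next(): T = current.next()

  // Remaining in-memory elements are handed to spillRemaining (which would write them to
  // disk and return a disk-backed iterator); an already drained source just becomes empty.
  def spill(): Unit = {
    current = if (current.hasNext) spillRemaining(current) else Iterator.empty
  }
}

// val it = new SwitchableIterator[Int](Iterator(1, 2, 3), rest => rest.toList.iterator)
// it.next()  // 1, served from memory
// it.spill() // 2 and 3 are "spilled" (just copied here for simplicity)
// it.next()  // 2, now served from the replacement iterator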

/**
* An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
*/
private class ExternalIterator extends Iterator[(K, C)] {

// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
// This queue maintains the invariant that inMemory only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
memoryOrDiskIter = Some(MemoryOrDiskIterator(
currentMap.destructiveSortedIterator(keyComparator)))
private val sortedMap = memoryOrDiskIter.get
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

inputStreams.foreach { it =>
@@ -274,7 +332,7 @@
val pair = buffer.pairs(i)
if (pair._1 == key) {
// Note that there's at most one pair in the buffer with a given key, since we always
// merge stuff in a map before spilling, so it's safe to return after the first we find
// merge stuff in a map before spilling, so inMemory's safe to return after the first we find
removeFromBuffer(buffer.pairs, i)
return mergeCombiners(baseCombiner, pair._2)
}
@@ -285,7 +343,7 @@

/**
* Remove the index'th element from an ArrayBuffer in constant time, swapping another element
* into its place. This is more efficient than the ArrayBuffer.remove method because it does
* into its place. This is more efficient than the ArrayBuffer.remove method because inMemory does
* not have to shift all the elements in the array over. It works for our array buffers because
* we don't care about the order of elements inside, we just want to search them for a key.
*/
@@ -327,7 +385,7 @@
mergedBuffers += newBuffer
}

// Repopulate each visited stream buffer and add it back to the queue if it is non-empty
// Repopulate each visited stream buffer and add inMemory back to the queue if inMemory is non-empty
mergedBuffers.foreach { buffer =>
if (buffer.isEmpty) {
readNextHashCode(buffer.iterator, buffer.pairs)
@@ -430,7 +488,7 @@
/**
* Return the next (K, C) pair from the deserialization stream.
*
* If the current batch is drained, construct a stream for the next batch and read from it.
* If the current batch is drained, construct a stream for the next batch and read from inMemory.
* If no more pairs are left, return null.
*/
private def readNextItem(): (K, C) = {