[SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory #7130

Closed
wants to merge 14 commits
30 changes: 30 additions & 0 deletions core/src/main/java/org/apache/spark/Spillable.java
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

/**
 * Forces the contents of an in-memory buffer to spill to disk and releases its memory.
 */
public interface Spillable {

/**
 * Force the contents of the memory buffer to spill to disk.
 * @return the number of bytes spilled
 */
public long forceSpill();
}
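
As a rough illustration of the contract, here is a hedged sketch of an implementer; the class and its byte counter are hypothetical, while the patch's real implementer is the CollectionSpillable trait further down:

// Hypothetical Spillable implementer, for illustration only.
class BufferedCollection extends Spillable {
  private var inMemoryBytes: Long = 0L

  override def forceSpill(): Long = {
    val freed = inMemoryBytes
    // ... write the buffered records to a temporary file on disk ...
    inMemoryBytes = 0L
    freed // number of bytes of reserved memory released
  }
}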
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle

import scala.collection.mutable

import org.apache.spark.{Logging, SparkException, SparkConf}
import org.apache.spark._

/**
* Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
@@ -38,8 +38,48 @@ import org.apache.spark.{Logging, SparkException, SparkConf}
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes

// threadId -> list of Spillables whose reserved memory can be force-spilled
private val threadReservedList = new mutable.HashMap[Long, mutable.ListBuffer[Spillable]]()

def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))

/**
 * Release memory reserved by other Spillables on the current thread, force-spilling them
 * until the total granted memory reaches requestMemory or the reserved list is exhausted.
 */
private[this] def releaseReservedMemory(toGrant: Long, requestMemory: Long): Long =
synchronized {
val threadId = Thread.currentThread().getId
if (toGrant >= requestMemory || !threadReservedList.contains(threadId)) {
toGrant
} else {
// Force-spill Spillables reserved by the current thread to make room for the new request
var addMemory = toGrant
while (addMemory < requestMemory && threadReservedList(threadId).nonEmpty) {
val toSpill = threadReservedList(threadId).remove(0)
val spillMemory = toSpill.forceSpill()
logInfo(s"Thread $threadId force-spilled $spillMemory bytes to free memory")
addMemory += spillMemory
}
if (addMemory > requestMemory) {
this.release(addMemory - requestMemory)
addMemory = requestMemory
}
addMemory
}
}
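// Illustrative walkthrough with hypothetical numbers: suppose only toGrant = 10 MB of a
// requestMemory = 50 MB request could be granted. The loop above force-spills the thread's
// reserved Spillables one at a time; if the first spill frees 60 MB, addMemory reaches
// 70 MB, the 20 MB overshoot is handed back via release(), and exactly 50 MB is returned.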

/**
 * Register a Spillable in the current thread's reserved list. When the thread cannot
 * acquire enough memory, Spillables on this list can be forced to spill to release memory.
 */
private[spark] def addSpillableToReservedList(spill: Spillable): Unit = synchronized {
val threadId = Thread.currentThread().getId
if (!threadReservedList.contains(threadId)) {
threadReservedList(threadId) = new mutable.ListBuffer[Spillable]()
}
threadReservedList(threadId) += spill
}

/**
* Try to acquire up to numBytes memory for the current thread, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
@@ -77,7 +117,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
return this.releaseReservedMemory(toGrant, numBytes)
} else {
logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
@@ -86,7 +126,7 @@
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
return this.releaseReservedMemory(toGrant, numBytes)
}
}
0L // Never reached
@@ -108,6 +148,7 @@
def releaseMemoryForThisThread(): Unit = synchronized {
val threadId = Thread.currentThread().getId
threadMemory.remove(threadId)
threadReservedList.remove(threadId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
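
Taken together, the intended call pattern is roughly the following sketch (an assumed flow, since tryToAcquire's full body is collapsed above; mySpillable stands in for any spillable collection):

val manager = SparkEnv.get.shuffleMemoryManager
manager.addSpillableToReservedList(mySpillable) // register before handing out an iterator
val granted = manager.tryToAcquire(numBytes)    // may force-spill other Spillables
                                                // reserved by this thread
if (granted < numBytes) {
  // still short on memory: spill the current collection itself
}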
@@ -17,14 +17,13 @@

package org.apache.spark.util.collection

import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.{Logging, SparkEnv, Spillable}

/**
* Spills contents of an in-memory collection to disk when the memory threshold
* has been exceeded.
*/
private[spark] trait Spillable[C] extends Logging {
private[spark] trait CollectionSpillable[C] extends Logging with Spillable {
/**
* Spills the current in-memory collection to disk, and releases the memory.
*
@@ -40,16 +39,10 @@ private[spark] trait Spillable[C] extends Logging {
protected def addElementsRead(): Unit = { _elementsRead += 1 }

// Memory manager that can be used to acquire/release memory
private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

// Initial threshold for the size of a collection before we start tracking its memory usage
// Exposed for testing
private[this] val initialMemoryThreshold: Long =
SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

// Threshold for this collection's size in bytes before we start tracking its memory usage
// (starts at zero now that the fixed initial threshold has been removed)
private[this] var myMemoryThreshold = initialMemoryThreshold
private[this] var myMemoryThreshold = 0L

// Number of elements read from input since last spill
private[this] var _elementsRead = 0L
@@ -102,8 +95,8 @@
*/
private def releaseMemoryForThisThread(): Unit = {
// Release everything we acquired from the shuffle memory manager (no initial threshold is retained)
shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
myMemoryThreshold = initialMemoryThreshold
shuffleMemoryManager.release(myMemoryThreshold)
myMemoryThreshold = 0L
}

/**
@@ -117,4 +110,18 @@
.format(threadId, org.apache.spark.util.Utils.bytesToString(size),
_spillCount, if (_spillCount > 1) "s" else ""))
}

/**
 * Log a forced spill and return the amount of reserved memory to be freed
 */
protected def logForceSpill(currentMemory: Long): Long = {
_spillCount += 1
logSpillage(currentMemory)

_elementsRead = 0
_memoryBytesSpilled += currentMemory
val freeMemory = myMemoryThreshold
myMemoryThreshold = 0L
freeMemory
}
}
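
For orientation, the collapsed maybeSpill logic that consumes myMemoryThreshold works roughly like this paraphrase (not the patch's exact code):

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  if (currentMemory >= myMemoryThreshold) {
    // Try to double the threshold by acquiring more memory from the manager
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    myMemoryThreshold += shuffleMemoryManager.tryToAcquire(amountToRequest)
    if (currentMemory >= myMemoryThreshold) {
      // The request was denied or only partially granted: spill to disk
      spill(collection)
      return true
    }
  }
  false
}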
@@ -69,7 +69,7 @@ class ExternalAppendOnlyMap[K, V, C](
extends Iterable[(K, C)]
with Serializable
with Logging
with Spillable[SizeTracker] {
with CollectionSpillable[SizeTracker] {

private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
@@ -100,6 +100,8 @@
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()

private var memoryOrDiskIter: Option[MemoryOrDiskIterator] = None

/**
* Insert the given key and value into the map.
*/
@@ -151,6 +153,14 @@
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
val it = currentMap.destructiveSortedIterator(keyComparator)
spilledMaps.append(spillMemoryToDisk(it))
}

/**
 * Spill the contents of an in-memory iterator to a temporary file on disk,
 * returning an iterator over the spilled file.
 */
private[this] def spillMemoryToDisk(inMemory: Iterator[(K, C)]): DiskMapIterator = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -171,9 +181,8 @@

var success = false
try {
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
while (inMemory.hasNext) {
val kv = inMemory.next()
writer.write(kv._1, kv._2)
objectsWritten += 1

@@ -203,8 +212,7 @@
}
}
}

spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
new DiskMapIterator(file, blockId, batchSizes)
}

def diskBytesSpilled: Long = _diskBytesSpilled
@@ -214,13 +222,53 @@
* If no spill has occurred, simply return the in-memory map's iterator.
*/
override def iterator: Iterator[(K, C)] = {
shuffleMemoryManager.addSpillableToReservedList(this)
if (spilledMaps.isEmpty) {
currentMap.iterator
memoryOrDiskIter = Some(MemoryOrDiskIterator(currentMap.iterator))
memoryOrDiskIter.get
} else {
new ExternalIterator()
}
}

/**
 * Spill the contents of the in-memory map to disk and release its memory.
 */
override def forceSpill(): Long = {
var freeMemory = 0L
if (memoryOrDiskIter.isDefined) {
freeMemory = logForceSpill(currentMap.estimateSize())
memoryOrDiskIter.get.spill()
}
freeMemory
}
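// Note: forceSpill frees memory only after an iterator has been handed out
// (memoryOrDiskIter is defined); before that point, the normal maybeSpill path
// in CollectionSpillable handles spilling.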

/**
 * An iterator that reads elements from the in-memory iterator, or from a disk iterator
 * once the in-memory contents have been spilled to disk.
 */
case class MemoryOrDiskIterator(memIter: Iterator[(K, C)]) extends Iterator[(K, C)] {

var currentIter = memIter

override def hasNext: Boolean = currentIter.hasNext

override def next(): (K, C) = currentIter.next()

private[spark] def spill(): Unit = {
if (hasNext) {
currentIter = spillMemoryToDisk(currentIter)
} else {
// The in-memory iterator is already drained; swap in an empty iterator
currentIter = Iterator.empty
logInfo("In-memory iterator already drained; nothing to spill")
}
}
}
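
A miniature timeline of why this wrapper exists (hypothetical usage; the spill request arrives from another Spillable's memory demand on the same thread):

val iter = map.iterator // MemoryOrDiskIterator wrapping the in-memory iterator
iter.next()             // served from memory
// Another collection on this thread calls forceSpill(), which invokes
// memoryOrDiskIter.get.spill() and swaps currentIter for a DiskMapIterator.
iter.next()             // transparently served from the spilled file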

/**
* An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
*/
@@ -232,7 +280,9 @@

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
memoryOrDiskIter = Some(MemoryOrDiskIterator(
currentMap.destructiveSortedIterator(keyComparator)))
private val sortedMap = memoryOrDiskIter.get
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

inputStreams.foreach { it =>