[SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation
This is the sister patch to #8011, but for aggregation.

In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally, this creates a `BytesToBytesMap`, which as of this patch acquires a page in its constructor. This ensures that the aggregation operator is not starved, since we reserve at least one page for it in advance.

rxin yhuai

Author: Andrew Or <andrew@databricks.com>

Closes #8038 from andrewor14/unsafe-starve-memory-agg.
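To make the mechanism concrete, here is a small, self-contained Scala sketch of the pattern, using toy classes invented for illustration (`PageReservationSketch`, `ToyTaskMemoryManager`, and `ToyAggregationMap` are not Spark APIs): an operator that reserves its first page in its own constructor cannot be starved by a sibling operator that grabs memory later in the same task.

```scala
// Hypothetical sketch, not part of this patch: it only illustrates why reserving the
// first page at construction time prevents starvation within a shared memory budget.
object PageReservationSketch {

  /** A toy per-task memory manager that hands out bytes from a fixed budget. */
  final class ToyTaskMemoryManager(private var remainingBytes: Long) {
    def tryToAcquire(bytes: Long): Long = {
      val granted = math.min(bytes, remainingBytes)
      remainingBytes -= granted
      granted
    }
    def release(bytes: Long): Unit = remainingBytes += bytes
  }

  /** A toy aggregation map that, like BytesToBytesMap after this patch, reserves a page up front. */
  final class ToyAggregationMap(manager: ToyTaskMemoryManager, pageSizeBytes: Long) {
    private var numPages = 0

    // Acquire the first page in the constructor so that operators which start allocating
    // later in the same task cannot consume the entire budget before this map gets any memory.
    require(acquireNewPage(), s"could not reserve an initial page of $pageSizeBytes bytes")

    private def acquireNewPage(): Boolean = {
      val granted = manager.tryToAcquire(pageSizeBytes)
      if (granted != pageSizeBytes) {
        manager.release(granted) // give back a partial grant rather than hoarding it
        false
      } else {
        numPages += 1
        true
      }
    }

    def pages: Int = numPages
  }

  def main(args: Array[String]): Unit = {
    val pageSize = 1L << 20                            // 1 MB pages
    val manager  = new ToyTaskMemoryManager(4L << 20)  // 4 MB per-task budget
    // Construct the aggregation map *before* running the upstream operator...
    val map = new ToyAggregationMap(manager, pageSize)
    // ...so a greedy sibling operator can only take what is left over.
    val siblingGrant = manager.tryToAcquire(100L << 20)
    println(s"aggregation map pages = ${map.pages}, sibling was granted $siblingGrant bytes")
  }
}
```

In the real patch, the same ordering is achieved by constructing the `TungstenAggregationIterator` (and therefore its `BytesToBytesMap`, which now acquires its first page eagerly) in a preparation step that runs before the parent partition is computed.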
Andrew Or authored and rxin committed Aug 12, 2015
1 parent 66d87c1 commit e011079
Showing 7 changed files with 201 additions and 76 deletions.
@@ -193,6 +193,11 @@ public BytesToBytesMap(
TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
}
allocate(initialCapacity);

// Acquire a new page as soon as we construct the map to ensure that we have at least
// one page to work with. Otherwise, other operators in the same task may starve this
// map (SPARK-9747).
acquireNewPage();
}

public BytesToBytesMap(
@@ -574,16 +579,9 @@ public boolean putNewKey(
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
}
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryGranted != pageSizeBytes) {
shuffleMemoryManager.release(memoryGranted);
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
if (!acquireNewPage()) {
return false;
}
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
dataPage = currentDataPage;
dataPageBaseObject = currentDataPage.getBaseObject();
dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -642,6 +640,24 @@ public boolean putNewKey(
}
}

/**
* Acquire a new page from the {@link ShuffleMemoryManager}.
* @return whether there is enough space to allocate the new page.
*/
private boolean acquireNewPage() {
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryGranted != pageSizeBytes) {
shuffleMemoryManager.release(memoryGranted);
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
return false;
}
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
return true;
}

/**
* Allocate new data structures for this map. When calling this outside of the constructor,
* make sure to keep references to the old data structures so that you can free them.
@@ -748,7 +764,7 @@ public long getNumHashCollisions() {
}

@VisibleForTesting
int getNumDataPages() {
public int getNumDataPages() {
return dataPages.size();
}

@@ -132,16 +132,15 @@ private UnsafeExternalSorter(

if (existingInMemorySorter == null) {
initializeForWriting();
// Acquire a new page as soon as we construct the sorter to ensure that we have at
// least one page to work with. Otherwise, other operators in the same task may starve
// this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
acquireNewPage();
} else {
this.isInMemSorterExternal = true;
this.inMemSorter = existingInMemorySorter;
}

// Acquire a new page as soon as we construct the sorter to ensure that we have at
// least one page to work with. Otherwise, other operators in the same task may starve
// this sorter (SPARK-9709).
acquireNewPage();

// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks when the downstream operator
// does not fully consume the sorter's output (e.g. sort followed by limit).
@@ -543,7 +543,7 @@ public void testPeakMemoryUsed() {
Platform.LONG_ARRAY_OFFSET,
8);
newPeakMemory = map.getPeakMemoryUsedBytes();
if (i % numRecordsPerPage == 0) {
if (i % numRecordsPerPage == 0 && i > 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
@@ -561,4 +561,13 @@ public void testPeakMemoryUsed() {
map.free();
}
}

@Test
public void testAcquirePageInConstructor() {
final BytesToBytesMap map = new BytesToBytesMap(
taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
assertEquals(1, map.getNumDataPages());
map.free();
}

}
@@ -19,6 +19,8 @@

import java.io.IOException;

import com.google.common.annotations.VisibleForTesting;

import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() {
return map.getPeakMemoryUsedBytes();
}

@VisibleForTesting
public int getNumDataPages() {
return map.getNumDataPages();
}

/**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
@@ -17,12 +17,13 @@

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.rdd.RDD
import org.apache.spark.TaskContext
import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics

@@ -68,35 +69,56 @@ case class TungstenAggregate(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitions { iter =>
val hasInput = iter.hasNext
if (!hasInput && groupingExpressions.nonEmpty) {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
} else {
val aggregationIterator =
new TungstenAggregationIterator(
groupingExpressions,
nonCompleteAggregateExpressions,
completeAggregateExpressions,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
iter,
testFallbackStartsAt,
numInputRows,
numOutputRows)

if (!hasInput && groupingExpressions.isEmpty) {

/**
* Set up the underlying unsafe data structures used before computing the parent partition.
* This makes sure our iterator is not starved by other operators in the same task.
*/
def preparePartition(): TungstenAggregationIterator = {
new TungstenAggregationIterator(
groupingExpressions,
nonCompleteAggregateExpressions,
completeAggregateExpressions,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
testFallbackStartsAt,
numInputRows,
numOutputRows)
}

/** Compute a partition using the iterator already set up previously. */
def executePartition(
context: TaskContext,
partitionIndex: Int,
aggregationIterator: TungstenAggregationIterator,
parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
val hasInput = parentIterator.hasNext
if (!hasInput) {
// We're not using the underlying map, so we can just free it here
aggregationIterator.free()
if (groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
aggregationIterator
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Iterator[UnsafeRow]()
}
} else {
aggregationIterator.start(parentIterator)
aggregationIterator
}
}

// Note: we need to set up the iterator in each partition before computing the
// parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
val resultRdd = {
new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
child.execute(), preparePartition, executePartition, preservesPartitioning = true)
}
resultRdd.asInstanceOf[RDD[InternalRow]]
}

override def simpleString: String = {
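The `doExecute` change above hinges on running a preparation step for each partition before the parent partition is computed. As a rough, self-contained illustration of that ordering (the helper below is invented for this sketch and is not the `MapPartitionsWithPreparationRDD` API):

```scala
// Hypothetical sketch of the "prepare before computing the parent" ordering used above.
object PrepareBeforeComputeSketch {

  /** Run `prepare` for the downstream operator before the parent iterator is materialized,
    * then hand both to `execute`. Loosely mirrors what happens per partition above. */
  def mapPartitionsWithPreparation[T, P, U](computeParent: () => Iterator[T])
                                           (prepare: () => P)
                                           (execute: (P, Iterator[T]) => Iterator[U]): Iterator[U] = {
    val prepared = prepare()       // e.g. build the aggregation iterator and reserve its first page
    val parent = computeParent()   // only now compute the (possibly memory-hungry) parent partition
    execute(prepared, parent)
  }

  def main(args: Array[String]): Unit = {
    val result = mapPartitionsWithPreparation(() => Iterator(1, 2, 3)) {
      () => new scala.collection.mutable.ArrayBuffer[Int]()
    } { (buffer, parent) =>
      parent.foreach(buffer += _)  // "aggregate" into the pre-allocated buffer
      buffer.iterator.map(_ * 10)
    }
    println(result.toList)         // List(10, 20, 30)
  }
}
```

The point of the ordering is that `prepare()` can reserve resources (here, merely allocate a buffer) before the parent iterator has any chance to consume the task's memory budget, which is exactly why plain `mapPartitions` is not enough for SPARK-9747.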
@@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType
* the function used to create mutable projections.
* @param originalInputAttributes
* attributes of representing input rows from `inputIter`.
* @param inputIter
* the iterator containing input [[UnsafeRow]]s.
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
@@ -83,12 +81,14 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[Int],
numInputRows: LongSQLMetric,
numOutputRows: LongSQLMetric)
extends Iterator[UnsafeRow] with Logging {

// The parent partition iterator, to be initialized later in `start`
private[this] var inputIter: Iterator[InternalRow] = null

///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
///////////////////////////////////////////////////////////////////////////
@@ -348,11 +348,15 @@ class TungstenAggregationIterator(
false // disable tracking of performance metrics
)

// Exposed for testing
private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap

// The function used to read and process input rows. When processing input rows,
// it first uses hash-based aggregation by putting groups and their buffers in
// hashMap. If we could not allocate more memory for the map, we switch to
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
numInputRows += 1
@@ -372,6 +376,7 @@
// that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
// been processed.
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
var i = 0
while (!sortBased && inputIter.hasNext) {
val newInput = inputIter.next()
@@ -412,6 +417,7 @@
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
*/
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
logInfo("falling back to sort based aggregation.")
// Step 1: Get the ExternalSorter containing sorted entries of the map.
externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -431,6 +437,11 @@
case _ => false
}

// Note: Since we spill the sorter's contents immediately after creating it, we must insert
// something into the sorter here to ensure that we acquire at least a page of memory.
// This is done through `externalSorter.insertKV`, which will trigger the page allocation.
// Otherwise, children operators may steal the window of opportunity and starve our sorter.

if (needsProcess) {
// First, we create a buffer.
val buffer = createNewAggregationBuffer()
@@ -588,27 +599,33 @@
// have not switched to sort-based aggregation.
///////////////////////////////////////////////////////////////////////////

// Starts to process input rows.
testFallbackStartsAt match {
case None =>
processInputs()
case Some(fallbackStartsAt) =>
// This is the testing path. processInputsWithControlledFallback is same as processInputs
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
// have been processed.
processInputsWithControlledFallback(fallbackStartsAt)
}
/**
* Start processing input rows.
* Only after this method is called will this iterator be non-empty.
*/
def start(parentIter: Iterator[InternalRow]): Unit = {
inputIter = parentIter
testFallbackStartsAt match {
case None =>
processInputs()
case Some(fallbackStartsAt) =>
// This is the testing path. processInputsWithControlledFallback is the same as processInputs
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
// have been processed.
processInputsWithControlledFallback(fallbackStartsAt)
}

// If we did not switch to sort-based aggregation in processInputs,
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
if (!sortBased) {
// First, set aggregationBufferMapIterator.
aggregationBufferMapIterator = hashMap.iterator()
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
mapIteratorHasNext = aggregationBufferMapIterator.next()
// If the map is empty, we just free it.
if (!mapIteratorHasNext) {
hashMap.free()
// If we did not switch to sort-based aggregation in processInputs,
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
if (!sortBased) {
// First, set aggregationBufferMapIterator.
aggregationBufferMapIterator = hashMap.iterator()
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
mapIteratorHasNext = aggregationBufferMapIterator.next()
// If the map is empty, we just free it.
if (!mapIteratorHasNext) {
hashMap.free()
}
}
}

@@ -673,21 +690,20 @@
}

///////////////////////////////////////////////////////////////////////////
// Part 8: A utility function used to generate a output row when there is no
// input and there is no grouping expression.
// Part 8: Utility functions
///////////////////////////////////////////////////////////////////////////

/**
* Generate an output row when there is no input and there is no grouping expression.
*/
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
if (groupingExpressions.isEmpty) {
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
// We create a output row and copy it. So, we can free the map.
val resultCopy =
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
hashMap.free()
resultCopy
} else {
throw new IllegalStateException(
"This method should not be called when groupingExpressions is not empty.")
}
assert(groupingExpressions.isEmpty)
assert(inputIter == null)
generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
}

/** Free memory used in the underlying map. */
def free(): Unit = {
hashMap.free()
}
}