diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 85b46ec8bfae3..87ed47e88c4ef 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -193,6 +193,11 @@ public BytesToBytesMap( TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); } public BytesToBytesMap( @@ -574,16 +579,9 @@ public boolean putNewKey( final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + if (!acquireNewPage()) { return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; dataPage = currentDataPage; dataPageBaseObject = currentDataPage.getBaseObject(); dataPageInsertOffset = currentDataPage.getBaseOffset(); @@ -642,6 +640,24 @@ public boolean putNewKey( } } + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. + */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; + } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; + } + /** * Allocate new data structures for this map. When calling this outside of the constructor, * make sure to keep references to the old data structures so that you can free them. @@ -748,7 +764,7 @@ public long getNumHashCollisions() { } @VisibleForTesting - int getNumDataPages() { + public int getNumDataPages() { return dataPages.size(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9601aafe55464..fc364e0a895b1 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -132,16 +132,15 @@ private UnsafeExternalSorter( if (existingInMemorySorter == null) { initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); } else { this.isInMemSorterExternal = true; this.inMemSorter = existingInMemorySorter; } - // Acquire a new page as soon as we construct the sorter to ensure that we have at - // least one page to work with. Otherwise, other operators in the same task may starve - // this sorter (SPARK-9709). - acquireNewPage(); - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 1a79c20c35246..ab480b60adaed 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -543,7 +543,7 @@ public void testPeakMemoryUsed() { Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -561,4 +561,13 @@ public void testPeakMemoryUsed() { map.free(); } } + + @Test + public void testAcquirePageInConstructor() { + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + assertEquals(1, map.getNumDataPages()); + map.free(); + } + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 5cce41d5a7569..09511ff35f785 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,6 +19,8 @@ import java.io.IOException; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; @@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() { return map.getPeakMemoryUsedBytes(); } + @VisibleForTesting + public int getNumDataPages() { + return map.getNumDataPages(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 6b5935a7ce296..c40ca973796a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -68,35 +69,56 @@ case class TungstenAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => - val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] - } else { - val aggregationIterator = - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - completeAggregateExpressions, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - iter, - testFallbackStartsAt, - numInputRows, - numOutputRows) - - if (!hasInput && groupingExpressions.isEmpty) { + + /** + * Set up the underlying unsafe data structures used before computing the parent partition. + * This makes sure our iterator is not starved by other operators in the same task. + */ + def preparePartition(): TungstenAggregationIterator = { + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + testFallbackStartsAt, + numInputRows, + numOutputRows) + } + + /** Compute a partition using the iterator already set up previously. */ + def executePartition( + context: TaskContext, + partitionIndex: Int, + aggregationIterator: TungstenAggregationIterator, + parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { + val hasInput = parentIterator.hasNext + if (!hasInput) { + // We're not using the underlying map, so we just can free it here + aggregationIterator.free() + if (groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - aggregationIterator + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[UnsafeRow]() } + } else { + aggregationIterator.start(parentIterator) + aggregationIterator } } + + // Note: we need to set up the iterator in each partition before computing the + // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). + val resultRdd = { + new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) + } + resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1f383dd04482f..af7e0fcedbe4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. - * @param inputIter - * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -83,12 +81,14 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { + // The parent partition iterator, to be initialized later in `start` + private[this] var inputIter: Iterator[InternalRow] = null + /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -348,11 +348,15 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) + // Exposed for testing + private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap + // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). private def processInputs(): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 @@ -372,6 +376,7 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() @@ -412,6 +417,7 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() @@ -431,6 +437,11 @@ class TungstenAggregationIterator( case _ => false } + // Note: Since we spill the sorter's contents immediately after creating it, we must insert + // something into the sorter here to ensure that we acquire at least a page of memory. + // This is done through `externalSorter.insertKV`, which will trigger the page allocation. + // Otherwise, children operators may steal the window of opportunity and starve our sorter. + if (needsProcess) { // First, we create a buffer. val buffer = createNewAggregationBuffer() @@ -588,27 +599,33 @@ class TungstenAggregationIterator( // have not switched to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// - // Starts to process input rows. - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + /** + * Start processing input rows. + * Only after this method is called will this iterator be non-empty. + */ + def start(parentIter: Iterator[InternalRow]): Unit = { + inputIter = parentIter + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } } } @@ -673,21 +690,20 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 8: A utility function used to generate a output row when there is no - // input and there is no grouping expression. + // Part 8: Utility functions /////////////////////////////////////////////////////////////////////////// + /** + * Generate a output row when there is no input and there is no grouping expression. + */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - if (groupingExpressions.isEmpty) { - sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) - // We create a output row and copy it. So, we can free the map. - val resultCopy = - generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() - hashMap.free() - resultCopy - } else { - throw new IllegalStateException( - "This method should not be called when groupingExpressions is not empty.") - } + assert(groupingExpressions.isEmpty) + assert(inputIter == null) + generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) + } + + /** Free memory used in the underlying map. */ + def free(): Unit = { + hashMap.free() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 0000000000000..ac22c2f3c0a58 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite { + + test("memory acquired on construction") { + // set up environment + val ctx = TestSQLContext + + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +}