[SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation #8038

Closed · wants to merge 14 commits
@@ -191,6 +191,11 @@ public BytesToBytesMap(
TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
}
allocate(initialCapacity);

// Acquire a new page as soon as we construct the map to ensure that we have at least
// one page to work with. Otherwise, other operators in the same task may starve this
// map (SPARK-9747).
acquireNewPage();
}

public BytesToBytesMap(
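The new comment above captures the starvation hazard this hunk fixes: every operator in a task draws pages from the same ShuffleMemoryManager budget on a first-come, first-served basis, so an operator constructed late can find the budget already exhausted. A toy sketch of that failure mode (illustrative classes only, not Spark's real memory manager):

object StarvationSketch extends App {
  // Toy model of per-task shuffle memory with the same first-come,
  // first-served semantics as ShuffleMemoryManager.tryToAcquire.
  class ToyMemoryPool(private var remaining: Long) {
    // Grants as many bytes as are left, possibly fewer than requested.
    def tryToAcquire(bytes: Long): Long = {
      val granted = math.min(bytes, remaining)
      remaining -= granted
      granted
    }
  }

  val pool = new ToyMemoryPool(64L * 1024 * 1024)  // 64 MB task budget
  val pageSize = 16L * 1024 * 1024                 // 16 MB pages

  // Operator A (e.g. a sort) runs first and consumes the entire budget.
  pool.tryToAcquire(64L * 1024 * 1024)

  // Operator B's map is constructed afterwards: without a page reserved in
  // its constructor, its first insert request is granted zero bytes.
  val granted = pool.tryToAcquire(pageSize)
  assert(granted == 0)  // operator B is starved (SPARK-9747)
}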
@@ -565,16 +570,9 @@ public boolean putNewKey(
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
}
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryGranted != pageSizeBytes) {
shuffleMemoryManager.release(memoryGranted);
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
if (!acquireNewPage()) {
return false;
}
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
dataPage = currentDataPage;
dataPageBaseObject = currentDataPage.getBaseObject();
dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -633,6 +631,24 @@ public boolean putNewKey(
}
}

/**
* Acquire a new page from the {@link ShuffleMemoryManager}.
* @return whether there is enough space to allocate the new page.
*/
private boolean acquireNewPage() {
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryGranted != pageSizeBytes) {
shuffleMemoryManager.release(memoryGranted);
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
return false;
}
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
return true;
}

/**
* Allocate new data structures for this map. When calling this outside of the constructor,
* make sure to keep references to the old data structures so that you can free them.
@@ -722,7 +738,7 @@ public long getNumHashCollisions() {
}

@VisibleForTesting
int getNumDataPages() {
public int getNumDataPages() {
return dataPages.size();
}

@@ -549,7 +549,7 @@ public void testTotalMemoryConsumption() {
PlatformDependent.LONG_ARRAY_OFFSET,
8);
newMemory = map.getTotalMemoryConsumption();
if (i % numRecordsPerPage == 0) {
if (i % numRecordsPerPage == 0 && i > 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousMemory + pageSizeBytes, newMemory);
} else {
@@ -561,4 +561,13 @@ public void testTotalMemoryConsumption() {
map.free();
}
}

@Test
public void testAcquirePageInConstructor() {
final BytesToBytesMap map = new BytesToBytesMap(
taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
assertEquals(1, map.getNumDataPages());
map.free();
}

}
@@ -19,6 +19,8 @@

import java.io.IOException;

import com.google.common.annotations.VisibleForTesting;

import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -217,6 +219,11 @@ public long getMemoryUsage() {
return map.getTotalMemoryConsumption();
}

@VisibleForTesting
public int getNumDataPages() {
return map.getNumDataPages();
}

/**
* Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
@@ -17,7 +17,8 @@

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.rdd.RDD
import org.apache.spark.TaskContext
import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
@@ -61,32 +62,54 @@ case class TungstenAggregate(
}

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
val hasInput = iter.hasNext
if (!hasInput && groupingExpressions.nonEmpty) {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
} else {
val aggregationIterator =
new TungstenAggregationIterator(
groupingExpressions,
nonCompleteAggregateExpressions,
completeAggregateExpressions,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
iter.asInstanceOf[Iterator[UnsafeRow]],
testFallbackStartsAt)

if (!hasInput && groupingExpressions.isEmpty) {
/**
* Set up the underlying unsafe data structures used before computing the parent partition.
* This makes sure our iterator is not starved by other operators in the same task.
*/
def preparePartition(): TungstenAggregationIterator = {
new TungstenAggregationIterator(
groupingExpressions,
nonCompleteAggregateExpressions,
completeAggregateExpressions,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
child.output,
testFallbackStartsAt)
Contributor: What will happen if there is no memory space left to reserve?

Contributor Author: We'll fail fast with an "unable to acquire memory" exception.
}

/** Compute a partition using the iterator already set up previously. */
def executePartition(
context: TaskContext,
partitionIndex: Int,
aggregationIterator: TungstenAggregationIterator,
parentIterator: Iterator[UnsafeRow]): Iterator[UnsafeRow] = {
val hasInput = parentIterator.hasNext
if (!hasInput) {
// We're not using the underlying map, so we can just free it here
aggregationIterator.free()
if (groupingExpressions.isEmpty) {
// This is a grouped aggregate and the input iterator is empty,
// so return an empty iterator.
Contributor: Seems we should put this comment in the else block. Instead, this branch is used when we do not have an input row and there is no grouping expression.

Contributor Author: Good catch.
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
} else {
aggregationIterator
Iterator[UnsafeRow]()
}
} else {
aggregationIterator.start(parentIterator)
aggregationIterator
}
}

// Note: we need to set up the iterator in each partition before computing the
// parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
val parentPartition = child.execute().asInstanceOf[RDD[UnsafeRow]]
val resultRdd = {
new MapPartitionsWithPreparationRDD[UnsafeRow, UnsafeRow, TungstenAggregationIterator](
parentPartition, preparePartition, executePartition, preservesPartitioning = true)
}
resultRdd.asInstanceOf[RDD[InternalRow]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just return resultRdd? Seems we do not need to cast?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually result RDD is of type RDD[UnsafeRow]. Since RDDs are not covariant I think we do need the cast.

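To illustrate the author's point: RDD's type parameter is invariant, so RDD[UnsafeRow] is not a subtype of RDD[InternalRow] even though UnsafeRow is a subtype of InternalRow, and an explicit cast is required. A toy demonstration with stand-in types (not Spark's classes):

object InvarianceSketch extends App {
  class Row
  class UnsafeRowLike extends Row

  // Invariant in T, like RDD[T] (RDD is declared RDD[T: ClassTag], not RDD[+T]).
  class Container[T](val value: T)

  val unsafeRows: Container[UnsafeRowLike] = new Container(new UnsafeRowLike)

  // val rows: Container[Row] = unsafeRows            // does not compile
  val rows = unsafeRows.asInstanceOf[Container[Row]]  // the cast doExecute needs
  println(rows.value)
}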
}
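The Note in the hunk above is the heart of the change: with plain mapPartitions, the aggregation iterator (and its up-front page reservation) would only be created inside the closure, after upstream operators in the same task have started claiming memory. A toy version of the prepare-then-execute contract (a sketch of the idea, not Spark's MapPartitionsWithPreparationRDD itself):

object PreparationSketch extends App {
  // Runs `prepare` before the parent partition is computed, so resources are
  // reserved before upstream operators in the task can claim them.
  def mapPartitionsWithPreparation[T, P, U](
      parent: () => Iterator[T],
      prepare: () => P,
      execute: (P, Iterator[T]) => Iterator[U]): Iterator[U] = {
    val prepared = prepare()     // e.g. reserve a memory page (SPARK-9747)
    execute(prepared, parent())  // only now materialize the parent iterator
  }

  val doubled = mapPartitionsWithPreparation[Int, String, Int](
    () => Iterator(1, 2, 3),
    () => "reserved-page",
    (p, it) => { println(s"prepared: $p"); it.map(_ * 2) })
  println(doubled.toList)  // List(2, 4, 6)
}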

override def simpleString: String = {
@@ -71,8 +71,6 @@ import org.apache.spark.sql.types.StructType
* the function used to create mutable projections.
* @param originalInputAttributes
* attributes of representing input rows from `inputIter`.
* @param inputIter
* the iterator containing input [[UnsafeRow]]s.
*/
class TungstenAggregationIterator(
groupingExpressions: Seq[NamedExpression],
@@ -82,10 +80,12 @@ class TungstenAggregationIterator(
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[UnsafeRow],
testFallbackStartsAt: Option[Int])
extends Iterator[UnsafeRow] with Logging {

// The parent partition iterator, to be initialized later in `start`
private[this] var inputIter: Iterator[UnsafeRow] = Iterator[UnsafeRow]()

///////////////////////////////////////////////////////////////////////////
// Part 1: Initializing aggregate functions.
///////////////////////////////////////////////////////////////////////////
@@ -335,7 +335,7 @@ class TungstenAggregationIterator(
// This is the hash map used for hash-based aggregation. It is backed by an
// UnsafeFixedWidthAggregationMap and it is used to store
// all groups and their corresponding aggregation buffers for hash-based aggregation.
private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
private[aggregate] val hashMap = new UnsafeFixedWidthAggregationMap(
Contributor: If we change it to private[aggregate], it looks like we will lose the direct field access of hashMap (provided by private[this]). How about we create a method to return hashMap instead of changing the scope?
initialAggregationBuffer,
StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
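A minimal sketch of the accessor alternative the reviewer suggests above (hypothetical code, not part of this PR): keep the field private[this], which compiles to a direct field access, and expose it to tests through a narrowly scoped method. A stand-in type replaces UnsafeFixedWidthAggregationMap.

package org.apache.spark.sql.execution.aggregate

class AccessorSketch {
  // private[this] allows direct field access within this instance.
  private[this] val hashMap = new java.util.HashMap[String, Long]()

  // Tests in the aggregate package use this instead of a widened field.
  private[aggregate] def hashMapForTesting: java.util.HashMap[String, Long] = hashMap
}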
@@ -576,27 +576,33 @@
// have not switched to sort-based aggregation.
///////////////////////////////////////////////////////////////////////////

// Starts to process input rows.
testFallbackStartsAt match {
case None =>
processInputs()
case Some(fallbackStartsAt) =>
// This is the testing path. processInputsWithControlledFallback is same as processInputs
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
// have been processed.
processInputsWithControlledFallback(fallbackStartsAt)
}
/**
* Start processing input rows.
* Only after this method is called will this iterator be non-empty.
*/
def start(parentIter: Iterator[UnsafeRow]): Unit = {
inputIter = parentIter
testFallbackStartsAt match {
case None =>
processInputs()
case Some(fallbackStartsAt) =>
// This is the testing path. processInputsWithControlledFallback is the same as processInputs
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
// have been processed.
processInputsWithControlledFallback(fallbackStartsAt)
}

// If we did not switch to sort-based aggregation in processInputs,
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
if (!sortBased) {
// First, set aggregationBufferMapIterator.
aggregationBufferMapIterator = hashMap.iterator()
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
mapIteratorHasNext = aggregationBufferMapIterator.next()
// If the map is empty, we just free it.
if (!mapIteratorHasNext) {
hashMap.free()
// If we did not switch to sort-based aggregation in processInputs,
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
if (!sortBased) {
// First, set aggregationBufferMapIterator.
aggregationBufferMapIterator = hashMap.iterator()
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
mapIteratorHasNext = aggregationBufferMapIterator.next()
// If the map is empty, we just free it.
if (!mapIteratorHasNext) {
hashMap.free()
}
}
}

@@ -648,20 +654,20 @@
}

///////////////////////////////////////////////////////////////////////////
// Part 8: A utility function used to generate a output row when there is no
// input and there is no grouping expression.
// Part 8: Utility functions
///////////////////////////////////////////////////////////////////////////

/**
* Generate an output row when there is no input and there is no grouping expression.
*/
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
if (groupingExpressions.isEmpty) {
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
// We create a output row and copy it. So, we can free the map.
val resultCopy =
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
hashMap.free()
resultCopy
} else {
throw new IllegalStateException(
"This method should not be called when groupingExpressions is not empty.")
}
assert(groupingExpressions.isEmpty)
assert(!inputIter.hasNext)
Contributor: If we change inputIter's initial value to null, we need to also change this to a null check.
generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
}

/** Free memory used in the underlying map. */
def free(): Unit = {
hashMap.free()
}
}
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.aggregate

import org.apache.spark._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection

class TungstenAggregationIteratorSuite extends SparkFunSuite with LocalSparkContext {

test("memory acquired on construction") {
// Needed for various things in SparkEnv
sc = new SparkContext("local", "testing")
Contributor: I feel the SparkContext we are creating here messes up the following tests. How about we comment it out and try the PR builder?

Contributor: Actually, is it possible to create the taskMemoryManager and shuffleMemoryManager without creating a new SparkContext?

Contributor Author: Yeah, I can figure something out.

Contributor Author: (This is why we shouldn't have singleton SQLContexts!)

val taskMemoryManager = new TaskMemoryManager(sc.env.executorMemoryManager)
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
TaskContext.setTaskContext(taskContext)

// Assert that a page is allocated before processing starts
var iter: TungstenAggregationIterator = null
try {
val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
() => new InterpretedMutableProjection(expr, schema)
}
iter = new TungstenAggregationIterator(
Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None)
val numPages = iter.hashMap.getNumDataPages
assert(numPages === 1)
} finally {
// Clean up
if (iter != null) {
iter.free()
}
TaskContext.unset()
Contributor: Should we also call sc.stop?

Contributor Author: This extends LocalSparkContext, which does that for us.

}
}
}
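On the reviewers' question of constructing the managers without a SparkContext: a hedged sketch assuming the Spark 1.5-era unsafe memory API (ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager). The ShuffleMemoryManager constructor shown is an assumption and may differ across branches; verify against the actual source.

// Hedged sketch: build the per-task managers directly, with no SparkContext.
import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}

object ManagerSketch {
  // An on-heap executor-level manager backs the per-task manager.
  val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))

  // Assumed constructor: an effectively unlimited per-task shuffle budget.
  val shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
}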