[SPARK-9709] [SQL] Avoid starving unsafe operators that use sort #8011

Closed · wants to merge 6 commits
core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -138,6 +138,11 @@ private UnsafeExternalSorter(
this.inMemSorter = existingInMemorySorter;
}

+ // Acquire a new page as soon as we construct the sorter to ensure that we have at
+ // least one page to work with. Otherwise, other operators in the same task may starve
+ // this sorter (SPARK-9709).
+ acquireNewPage();

// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
the end of the task. This is necessary to avoid memory leaks when the downstream operator
// does not fully consume the sorter's output (e.g. sort followed by limit).
@@ -343,22 +348,32 @@ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
pageSizeBytes + ")");
} else {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
- spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquiredAfterSpilling != pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
- throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
- }
- }
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = pageSizeBytes;
- allocatedPages.add(currentPage);
+ acquireNewPage();
}
}
}

+ /**
+ * Acquire a new page from the {@link ShuffleMemoryManager}.
+ *
+ * If there is not enough space to allocate the new page, spill all existing ones
+ * and try again. If there is still not enough space, report an error to the caller.
+ */
+ private void acquireNewPage() throws IOException {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquired < pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
+ }
+ }
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = pageSizeBytes;
+ allocatedPages.add(currentPage);
+ }

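For context, the window this patch closes can be sketched with a toy model. The names and sizes below are illustrative only: ToyMemoryManager is a stand-in for the first-come, first-served grants of ShuffleMemoryManager within a single task, not a Spark class.

// Toy model of the starvation window (illustrative, not Spark APIs).
class ToyMemoryManager(private var free: Long) {
  def tryToAcquire(bytes: Long): Long = {
    val granted = math.min(bytes, free) // grant whatever is still available
    free -= granted
    granted
  }
  def release(bytes: Long): Unit = { free += bytes }
}

object StarvationDemo extends App {
  val pageSizeBytes = 64L * 1024 * 1024
  val manager = new ToyMemoryManager(pageSizeBytes) // room for exactly one page

  // Operator A in the task reserves its page as soon as it sees data.
  val grantedToA = manager.tryToAcquire(pageSizeBytes) // gets the whole pool

  // Before this patch, operator B's sorter deferred its first page until the
  // first insert; by then nothing is left, and A does not give back memory
  // it is actively using.
  val grantedToB = manager.tryToAcquire(pageSizeBytes)
  assert(grantedToB < pageSizeBytes) // B is starved (SPARK-9709)
  manager.release(grantedToB)

  // Acquiring the first page in the sorter's constructor reserves B's page
  // before A can drain the pool.
}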
core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -21,6 +21,9 @@ import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}

+ /**
+ * An RDD that applies the provided function to every partition of the parent RDD.
+ */
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala (new file)
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, Partitioner, TaskContext}

/**
* An RDD that applies a user-provided function to every partition of the parent RDD, and
* additionally allows the user to prepare each partition before computing the parent partition.
*/
private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
prev: RDD[T],
preparePartition: () => M,
executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {

override val partitioner: Option[Partitioner] = {
if (preservesPartitioning) firstParent[T].partitioner else None
}

override def getPartitions: Array[Partition] = firstParent[T].partitions

/**
* Prepare a partition before computing it from its parent.
*/
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
val preparedArgument = preparePartition()
val parentIterator = firstParent[T].iterator(partition, context)
executePartition(context, partition.index, preparedArgument, parentIterator)
}
}
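As a usage sketch (hypothetical code, not part of this PR: the class is private[spark], so the snippet must live under the org.apache.spark package, and the buffer-copying logic is made up): the prepare function runs first in each task, and its result is threaded into the execute function together with the parent iterator.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.rdd.MapPartitionsWithPreparationRDD

def demo(sc: SparkContext): Array[Int] = {
  val parent = sc.parallelize(1 to 8, 2)

  // Prepare phase: runs before the parent partition is computed.
  val prepare = () => ArrayBuffer.empty[Int]

  // Execute phase: receives the prepared buffer and the parent iterator.
  val execute = (context: TaskContext, index: Int,
      buffer: ArrayBuffer[Int], iter: Iterator[Int]) => {
    buffer ++= iter
    buffer.iterator
  }

  new MapPartitionsWithPreparationRDD[Int, Int, ArrayBuffer[Int]](
    parent, prepare, execute).collect()
}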
core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -124,7 +124,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
}
}

- private object ShuffleMemoryManager {
+ private[spark] object ShuffleMemoryManager {
Comment from the author (on the private[spark] change): used in tests

/**
* Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
* of the memory pool and a safety factor since collections can sometimes grow bigger than
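The limit that javadoc describes works out to roughly the following sketch. The config keys and 1.x defaults are recalled from Spark, not quoted from this diff, so treat the exact values as assumptions.

import org.apache.spark.SparkConf

def maxShuffleMemory(conf: SparkConf): Long = {
  // Fraction of the heap set aside for shuffle collections...
  val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
  // ...scaled by a safety factor, because collections can overshoot their
  // estimated size between two size checks.
  val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
  (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}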
core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -340,7 +340,8 @@ public void testPeakMemoryUsed() throws Exception {
for (int i = 0; i < numRecordsPerPage * 10; i++) {
insertNumber(sorter, i);
newPeakMemory = sorter.getPeakMemoryUsedBytes();
- if (i % numRecordsPerPage == 0) {
+ // The first page is pre-allocated on instantiation
+ if (i % numRecordsPerPage == 0 && i > 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
@@ -364,5 +365,21 @@
}
}

+ @Test
+ public void testReservePageOnInstantiation() throws Exception {
+ final UnsafeExternalSorter sorter = newSorter();
+ try {
+ assertEquals(1, sorter.getNumberOfAllocatedPages());
+ // Inserting a new record doesn't allocate more memory since we already have a page
+ long peakMemory = sorter.getPeakMemoryUsedBytes();
+ insertNumber(sorter, 100);
+ assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
+ assertEquals(1, sorter.getNumberOfAllocatedPages());
+ } finally {
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
+ }
+ }

}

core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala (new file)
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import scala.collection.mutable

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}

class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {

test("prepare called before parent partition is computed") {
sc = new SparkContext("local", "test")

// Have the parent partition push a number to the list
val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
TestObject.things.append(20)
iter
}

// Push a different number during the prepare phase
val preparePartition = () => { TestObject.things.append(10) }

// Push yet another number during the execution phase
val executePartition = (
taskContext: TaskContext,
partitionIndex: Int,
notUsed: Unit,
parentIterator: Iterator[Int]) => {
TestObject.things.append(30)
TestObject.things.iterator
}

// Verify that the numbers are pushed in the order expected
val result = {
new MapPartitionsWithPreparationRDD[Int, Int, Unit](
parent, preparePartition, executePartition).collect()
}
assert(result === Array(10, 20, 30))
}

}

private object TestObject {
val things = new mutable.ListBuffer[Int]
}
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -158,7 +158,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
*/
final def prepare(): Unit = {
if (prepareCalled.compareAndSet(false, true)) {
- doPrepare
+ doPrepare()
children.foreach(_.prepare())
}
}
sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala (23 additions, 5 deletions)
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.{InternalAccumulator, TaskContext}
- import org.apache.spark.rdd.RDD
+ import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
@@ -123,7 +123,12 @@ case class TungstenSort(
val schema = child.schema
val childOutput = child.output
val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
- child.execute().mapPartitions({ iter =>
+
+ /**
+ * Set up the sorter in each partition before computing the parent partition.
+ * This makes sure our sorter is not starved by other sorters used in the same task.
+ */
+ def preparePartition(): UnsafeExternalRowSorter = {
val ordering = newOrdering(sortOrder, childOutput)

// The comparator for comparing prefix
@@ -143,12 +148,25 @@
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
- val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
- val taskContext = TaskContext.get()
+ sorter
+ }
+
+ /** Compute a partition using the sorter already set up previously. */
+ def executePartition(
+ taskContext: TaskContext,
+ partitionIndex: Int,
+ sorter: UnsafeExternalRowSorter,
+ parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
taskContext.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
sortedIterator
- }, preservesPartitioning = true)
+ }
+
+ // Note: we need to set up the external sorter in each partition before computing
+ // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
+ new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
+ child.execute(), preparePartition, executePartition, preservesPartitioning = true)
}

}
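To see why plain mapPartitions is insufficient here, consider a minimal model of the two compute paths (plain Scala, no Spark required; the function names are illustrative). In MapPartitionsRDD, the parent iterator is materialized before the mapping function runs, so upstream operators initialize, and grab memory, first; MapPartitionsWithPreparationRDD inverts that order.

object ComputeOrderSketch extends App {
  val events = scala.collection.mutable.ListBuffer[String]()

  def parentIterator(): Iterator[Int] = { events += "parent operators set up"; Iterator(1, 2, 3) }
  def createSorter(): String = { events += "sorter set up"; "sorter" }

  // mapPartitions style: f(context, index, firstParent.iterator(...)) --
  // the iterator argument is evaluated before f's body, so the sorter comes second.
  def mapPartitionsStyle(): Unit = {
    val iter = parentIterator()
    val sorter = createSorter()
  }

  // MapPartitionsWithPreparationRDD style: preparePartition() runs first.
  def withPreparationStyle(): Unit = {
    val sorter = createSorter()
    val iter = parentIterator()
  }

  mapPartitionsStyle()
  assert(events.toList == List("parent operators set up", "sorter set up"))
  events.clear()
  withPreparationStyle()
  assert(events.toList == List("sorter set up", "parent operators set up"))
}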