[SPARK-8464][Core][Shuffle] Consider separating aggregator and non-aggregator paths in ExternalSorter #7129

Closed
wants to merge 20 commits
HashShuffleReader.scala
@@ -17,12 +17,13 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.Logging
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.ExternalSorterNoAgg
import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -98,7 +99,8 @@ private[spark] class HashShuffleReader[K, C](
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
val sorter =
new ExternalSorterNoAgg[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
SortShuffleWriter.scala
@@ -21,9 +21,9 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.{ExternalSorterAgg, ExternalSorterNoAgg}

private[spark] class SortShuffleWriter[K, V, C](
shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -52,8 +52,8 @@ private[spark] class SortShuffleWriter[K, V, C](
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
new ExternalSorterAgg[K, V, C](
dep.aggregator.get, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else if (SortShuffleWriter.shouldBypassMergeSort(
SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
@@ -67,8 +67,7 @@
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
new ExternalSorterNoAgg[K, V, V](Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)
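For reference, a sketch of the constructor shapes these call sites imply for the two new classes. The signatures are inferred from the usages above and in HashShuffleReader, not taken from the subclass files (which are not shown in this excerpt); the real subclasses would also implement the abstract members that the ExternalSorter diff below introduces.

// Sketch only: hypothetical signatures inferred from the call sites.
private[spark] class ExternalSorterAgg[K, V, C](
    aggregator: Aggregator[K, V, C],          // required here, no longer an Option
    partitioner: Option[Partitioner] = None,
    ordering: Option[Ordering[K]] = None,
    serializer: Option[Serializer] = None)
  extends ExternalSorter[K, V, C](partitioner, ordering, serializer) { /* ... */ }

private[spark] class ExternalSorterNoAgg[K, V, C](
    partitioner: Option[Partitioner] = None,
    ordering: Option[Ordering[K]] = None,
    serializer: Option[Serializer] = None)
  extends ExternalSorter[K, V, C](partitioner, ordering, serializer) { /* ... */ }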

ExternalSorter.scala
@@ -20,16 +20,15 @@ package org.apache.spark.util.collection
import java.io._
import java.util.Comparator

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.annotations.VisibleForTesting
import com.google.common.io.ByteStreams

import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
import org.apache.spark.serializer._
import org.apache.spark.shuffle.sort.SortShuffleFileWriter
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}

/**
@@ -45,7 +44,6 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
* `spark.shuffle.compress`). We may need to revisit this if ExternalSorter is used in other
* non-shuffle contexts where we might want to use different configuration settings.
*
* @param aggregator optional Aggregator with combine functions to use for merging data
* @param partitioner optional Partitioner; if given, sort by partition ID and then key
* @param ordering optional Ordering to sort keys within each partition; should be a total ordering
* @param serializer serializer to use when spilling to disk
@@ -87,39 +85,31 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
*
* - Users are expected to call stop() at the end to delete all the intermediate files.
*/
private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
private[spark] abstract class ExternalSorter[K, V, C](
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]]
with SortShuffleFileWriter[K, V] {

private val conf = SparkEnv.get.conf
protected val conf = SparkEnv.get.conf

private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
private val shouldPartition = numPartitions > 1
private def getPartition(key: K): Int = {
protected val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
protected val shouldPartition = numPartitions > 1
protected def getPartition(key: K): Int = {
if (shouldPartition) partitioner.get.getPartition(key) else 0
}

// Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
// As a sanity check, make sure that we're not handling a shuffle which should use that path.
if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+ " a sort that the BypassMergeSortShuffleWriter should handle")
}

private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()
protected val blockManager = SparkEnv.get.blockManager
protected val diskBlockManager = blockManager.diskBlockManager
protected val ser = Serializer.getSerializer(serializer)
protected val serInstance = ser.newInstance()

private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
protected val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)

// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
protected val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024

// Size of object batches when reading/writing from serializers.
//
@@ -128,14 +118,14 @@ private[spark] class ExternalSorter[K, V, C](
//
// NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
// grow internal data structures by growing + copying every time the number of objects doubles.
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
protected val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)

private val useSerializedPairBuffer =
protected val useSerializedPairBuffer =
ordering.isEmpty &&
conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
ser.supportsRelocationOfSerializedObjects
private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
protected val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
protected def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
if (useSerializedPairBuffer) {
new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
} else {
@@ -145,81 +135,49 @@ private[spark] class ExternalSorter[K, V, C](
// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
private var map = new PartitionedAppendOnlyMap[K, C]
private var buffer = newBuffer()
protected var map = new PartitionedAppendOnlyMap[K, C]
protected var buffer = newBuffer()

// Total spilling statistics
private var _diskBytesSpilled = 0L
protected var _diskBytesSpilled = 0L
def diskBytesSpilled: Long = _diskBytesSpilled


// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided by the
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
// non-equal keys also have this, so we need to do a later pass to find truly equal keys).
// Note that we ignore this if no aggregator and no ordering are given.
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
protected val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
override def compare(a: K, b: K): Int = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
})

private def comparator: Option[Comparator[K]] = {
if (ordering.isDefined || aggregator.isDefined) {
Some(keyComparator)
} else {
None
}
}
// Interface for comparator object to abstract away aggregator dependence
protected def comparator: Option[Comparator[K]]
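A sketch of how the two subclasses might satisfy this member, mirroring the implementation removed above: the aggregating sorter always needs a key comparator (equal keys must be found so they can be combined), while the non-aggregating one needs it only when an ordering was supplied.

// In ExternalSorterAgg (hypothetical; an aggregator is always present):
override protected def comparator: Option[Comparator[K]] = Some(keyComparator)

// In ExternalSorterNoAgg (hypothetical; only when the caller asked for sorted keys):
override protected def comparator: Option[Comparator[K]] =
  if (ordering.isDefined) Some(keyComparator) else None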

// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
private[this] case class SpilledFile(
protected[this] case class SpilledFile(
file: File,
blockId: BlockId,
serializerBatchSizes: Array[Long],
elementsPerPartition: Array[Long])

private val spills = new ArrayBuffer[SpilledFile]
protected val spills = new ArrayBuffer[SpilledFile]

override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined

if (shouldCombine) {
// Combine values in-memory first using our AppendOnlyMap
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
}
// Interface for insertAll to abstract away aggregator dependence
def insertAll(records: Iterator[Product2[K, V]]): Unit
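A sketch of the two insertAll implementations this split implies, reusing the two branches of the method removed above and assuming ExternalSorterAgg holds its Aggregator as a plain constructor value rather than an Option.

// In ExternalSorterAgg: combine values in the in-memory map as records arrive.
override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  val mergeValue = aggregator.mergeValue
  val createCombiner = aggregator.createCombiner
  var kv: Product2[K, V] = null
  val update = (hadValue: Boolean, oldValue: C) => {
    if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
  }
  while (records.hasNext) {
    addElementsRead()
    kv = records.next()
    map.changeValue((getPartition(kv._1), kv._1), update)
    maybeSpillCollection(usingMap = true)
  }
}

// In ExternalSorterNoAgg: no combining, just append to the buffer.
override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  while (records.hasNext) {
    addElementsRead()
    val kv = records.next()
    buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
    maybeSpillCollection(usingMap = false)
  }
}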

/**
* Spill the current in-memory collection to disk if needed.
*
* @param usingMap whether we're using a map or buffer as our current in-memory collection
*/
private def maybeSpillCollection(usingMap: Boolean): Unit = {
protected def maybeSpillCollection(usingMap: Boolean): Unit = {
if (!spillingEnabled) {
return
}
@@ -323,33 +281,17 @@ private[spark] class ExternalSorter[K, V, C](
* partition we then have an iterator over its contents, and these are expected to be accessed
* in order (you can't "skip ahead" to one partition without reading the previous one).
* Guaranteed to return a key-value pair for each partition, in order of partition ID.
*
* This interface abstracts away aggregator dependence.
*/
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
(0 until numPartitions).iterator.map { p =>
val inMemIterator = new IteratorForPartition(p, inMemBuffered)
val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
if (aggregator.isDefined) {
// Perform partial aggregation across partitions
(p, mergeWithAggregation(
iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
} else if (ordering.isDefined) {
// No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
// sort the elements without trying to merge them
(p, mergeSort(iterators, ordering.get))
} else {
(p, iterators.iterator.flatten)
}
}
}
protected def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])]
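Both subclasses would build the same per-partition iterator chain as the removed body (one SpillReader per spilled file plus an IteratorForPartition over the in-memory data) and differ only in the last step. A sketch of that final step, per partition p:

// Shared prelude (as in the removed implementation):
//   val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)

// In ExternalSorterAgg: partial aggregation across spilled and in-memory data.
(p, mergeWithAggregation(iterators, aggregator.mergeCombiners, keyComparator, ordering.isDefined))

// In ExternalSorterNoAgg: sort if an ordering was given, otherwise just concatenate.
if (ordering.isDefined) (p, mergeSort(iterators, ordering.get)) else (p, iterators.iterator.flatten)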

/**
* Merge-sort a sequence of (K, C) iterators using a given comparator for the keys.
*/
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
protected def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
{
val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
@@ -381,12 +323,11 @@ private[spark] class ExternalSorter[K, V, C](
* (e.g. when we sort objects by hash code and different keys may compare as equal although
* they're not), we still merge them by doing equality tests for all keys that compare as equal.
*/
private def mergeWithAggregation(
protected def mergeWithAggregation(
iterators: Seq[Iterator[Product2[K, C]]],
mergeCombiners: (C, C) => C,
comparator: Comparator[K],
totalOrder: Boolean)
: Iterator[Product2[K, C]] =
totalOrder: Boolean): Iterator[Product2[K, C]] =
{
if (!totalOrder) {
// We only have a partial ordering, e.g. comparing the keys by hash code, which means that
@@ -461,7 +402,7 @@ private[spark] class ExternalSorter[K, V, C](
* An internal class for reading a spilled file partition by partition. Expects all the
* partitions to be requested in order.
*/
private[this] class SpillReader(spill: SpilledFile) {
protected[this] class SpillReader(spill: SpilledFile) {
// Serializer batch offsets; size will be batchSize.length + 1
val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)

@@ -611,11 +552,14 @@ private[spark] class ExternalSorter[K, V, C](
*
* For now, we just merge all the spilled files in one pass, but this can be modified to
* support hierarchical merging.
*
* This interface abstracts away aggregator dependence.
*/
@VisibleForTesting
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])]
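The parameterless partitionedIterator then becomes a one-liner in each subclass, choosing its own in-memory collection just as the removed check (if aggregator.isDefined then map else buffer) did. A sketch:

// In ExternalSorterAgg:
override def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] =
  partitionedIterator(map)

// In ExternalSorterNoAgg:
override def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] =
  partitionedIterator(buffer)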

protected def partitionedIterator(collection: WritablePartitionedPairCollection[K, C])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
if (spills.isEmpty) {
// Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
// we don't even need to sort by anything other than partition ID
@@ -644,18 +588,25 @@ private[spark] class ExternalSorter[K, V, C](
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*
* This interface abstracts away aggregator dependence.
*/
override def writePartitionedFile(
def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long]
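Similarly, a sketch of how each subclass might satisfy the SortShuffleFileWriter contract by forwarding to the protected overload below with its own collection, again mirroring the removed aggregator.isDefined check.

// In ExternalSorterAgg:
override def writePartitionedFile(
    blockId: BlockId,
    context: TaskContext,
    outputFile: File): Array[Long] =
  writePartitionedFile(blockId, context, outputFile, map)

// In ExternalSorterNoAgg:
override def writePartitionedFile(
    blockId: BlockId,
    context: TaskContext,
    outputFile: File): Array[Long] =
  writePartitionedFile(blockId, context, outputFile, buffer)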

protected def writePartitionedFile(
blockId: BlockId,
context: TaskContext,
outputFile: File): Array[Long] = {
outputFile: File,
collection: WritablePartitionedPairCollection[K, C]): Array[Long] = {

// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)

if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
@@ -701,7 +652,7 @@ private[spark] class ExternalSorter[K, V, C](
*
* @param data an iterator of elements, assumed to already be sorted by partition ID
*/
private def groupByPartition(data: Iterator[((Int, K), C)])
protected def groupByPartition(data: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] =
{
val buffered = data.buffered
Expand All @@ -713,8 +664,8 @@ private[spark] class ExternalSorter[K, V, C](
* stream, assuming this partition is the next one to be read. Used to make it easier to return
* partitioned iterators from our in-memory collection.
*/
private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)])
extends Iterator[Product2[K, C]]
protected[this] class IteratorForPartition(partitionId: Int,
data: BufferedIterator[((Int, K), C)]) extends Iterator[Product2[K, C]]
{
override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId
