@@ -205,6 +205,13 @@ private[spark] class ExternalSorter[K, V, C](
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
} else if (bypassMergeSort) {
// SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
Contributor:
Skipping this buffering seems to make it so that much of the rest of the bypassMergeSort-handling code is no longer needed. For example, if we don't buffer then we won't need to spill, so we can remove the code that deals with merging spills in the bypassMergeSort case. Based on this, I've opened #6397 to remove all of this now-unused code and to move the handling of the bypassMergeSort path into its own file. It would be great if this PR's reviewers could look at that PR to double-check my reasoning.
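
A standalone toy sketch of the point above (illustration only — not the ExternalSorter code, and not what #6397 actually adds): if every record is routed straight to its partition's file, nothing accumulates in memory, so there is no collection to spill and therefore no spilled files to merge.

```scala
import java.io.{File, FileOutputStream, ObjectOutputStream}

// Hypothetical helper for illustration only: one writer per partition, each record
// written out as soon as it is seen, so memory usage stays constant.
def writePartitionedDirectly[K, V](
    records: Iterator[(K, V)],
    numPartitions: Int,
    getPartition: K => Int,
    dir: File): Seq[File] = {
  val files = (0 until numPartitions).map(i => new File(dir, s"part-$i"))
  val writers = files.map(f => new ObjectOutputStream(new FileOutputStream(f)))
  try {
    // No buffering, no sorting: route each record to its partition's stream.
    records.foreach { case (k, v) => writers(getPartition(k)).writeObject((k, v)) }
  } finally {
    writers.foreach(_.close())
  }
  files
}
```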

if (records.hasNext) {
spillToPartitionFiles(records.map { kv =>
((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
})
}
} else {
// Stick values into our buffer
while (records.hasNext) {
@@ -336,6 +343,10 @@ private[spark] class ExternalSorter[K, V, C](
* @param collection whichever collection we're using (map or buffer)
*/
private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
spillToPartitionFiles(collection.iterator)
}

private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
assert(bypassMergeSort)

// Create our file writers if we haven't done so yet
@@ -350,9 +361,9 @@
}
}

val it = collection.iterator // No need to sort stuff, just write each element out
while (it.hasNext) {
val elem = it.next()
// No need to sort stuff, just write each element out
while (iterator.hasNext) {
val elem = iterator.next()
val partitionId = elem._1._1
val key = elem._1._2
val value = elem._2
@@ -748,6 +759,12 @@ private[spark] class ExternalSorter[K, V, C](

context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
Contributor:
I was reading through ExternalSorter to try to understand how shuffle write time metrics are calculated and came across this line. This style is confusing to a casual reader: it looks like the logic here is "if shuffle write metrics are defined and merge sort is bypassed, then run this block", but this is slightly obfuscated by the fact that we're filtering an option with a filter function that doesn't depend on that option's value.

For next time, I think we should just use a simple if statement instead.
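
For reference, a minimal self-contained sketch of that plain-if shape (the ShuffleWriteMetrics below is a stub with just the two fields used here, not Spark's class):

```scala
// Stub type for illustration only, so the sketch compiles on its own.
class ShuffleWriteMetrics(var shuffleBytesWritten: Long = 0L, var shuffleWriteTime: Long = 0L)

def mergeWriteMetrics(
    bypassMergeSort: Boolean,
    curWriteMetrics: ShuffleWriteMetrics,
    taskWriteMetrics: Option[ShuffleWriteMetrics]): Unit = {
  // Plain if instead of .filter(_ => bypassMergeSort): same behaviour, clearer intent.
  if (bypassMergeSort && curWriteMetrics != null) {
    taskWriteMetrics.foreach { m =>
      m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten
      m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime
    }
  }
}
```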

Contributor Author:
Ah, sorry for that. "Learned" this from Michael, won't do this again, lol

Contributor:
Lies 😜

I don't think I've ever written a .filter(_ => ... though I will admit guilt for merging all but one instance in the codebase...

if (curWriteMetrics != null) {
m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten
m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime
}
}

lengths
}
12 changes: 6 additions & 6 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -95,14 +95,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)

// 10 partitions from 4 keys
val NUM_BLOCKS = 10
// 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys
val NUM_BLOCKS = 201
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
val b = a.map(x => (x, x*2))

// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS))
.setSerializer(new KryoSerializer(conf))

val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -122,13 +122,13 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)

// 10 partitions from 4 keys
val NUM_BLOCKS = 10
// 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys
val NUM_BLOCKS = 201
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
val b = a.map(x => (x, x*2))

// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS))

val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
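
(Aside: the threshold these tests are sized against is an ordinary conf entry with a default of 200, read the same way the Exchange change below reads it. A short sketch:)

```scala
import org.apache.spark.SparkConf

val bypassMergeThreshold =
  new SparkConf().getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
```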
@@ -41,11 +41,21 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
/** We must copy rows when sort based shuffle is on */
protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]

private val bypassMergeThreshold =
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

override def execute() = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
val rdd = if (sortBasedShuffleOn) {
// This is a workaround for SPARK-4479. When:
// 1. sort based shuffle is on, and
// 2. the partition number is under the merge threshold, and
// 3. no ordering is required
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -82,6 +92,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
shuffled.map(_._1)

case SinglePartition =>
// SPARK-4479: Can't turn off defensive copy as what we do for `HashPartitioning`, since
// operators like `TakeOrdered` may require an ordering within the partition, and currently
// `SinglePartition` doesn't include ordering information.
// TODO Add `SingleOrderedPartition` for operators like `TakeOrdered`
val rdd = if (sortBasedShuffleOn) {
child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
} else {
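
A toy sketch (plain Scala, not Spark code) of why the defensive copies discussed in the comments above exist at all: when the upstream iterator keeps handing back the same mutable object, buffering it without `.copy()` leaves every buffered entry aliasing the final state, whereas consuming each element immediately, as the bypass path does, reads the right values.

```scala
class MutableRow(var value: Int)

val reused = new MutableRow(0)

// An iterator that, like Spark's mutable-row iterators, reuses a single object.
def rows(): Iterator[MutableRow] =
  Iterator.tabulate(3) { i => reused.value = i; reused }

// Buffered without copying: all three slots point at the same object, so the
// earlier values are lost.
val buffered = rows().toArray
assert(buffered.map(_.value).toSeq == Seq(2, 2, 2))

// Streamed: each value is read before the next mutation happens, so no copy is needed.
val streamed = rows().map(_.value).toList
assert(streamed == List(0, 1, 2))
```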