[SPARK-22537][core] Aggregation of map output statistics on driver faces single point bottleneck #19763

Closed
wants to merge 13 commits
60 changes: 57 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
@@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster(
shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
}

/**
* Grouped function of Range; this avoids traversing every element of the Range, as
* IterableLike's grouped function would.
*/
def rangeGrouped(range: Range, size: Int): Seq[Range] = {
val start = range.start
val step = range.step
val end = range.end
for (i <- start.until(end, size * step)) yield {
i.until(i + size * step, step)
}
}

/**
* To equally divide n elements into m buckets: each bucket gets n/m elements, and the
* remaining n%m elements are distributed by adding one extra element to each of the
* first n%m buckets.
*/
def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
val elementsPerBucket = numElements / numBuckets
val remaining = numElements % numBuckets
val splitPoint = (elementsPerBucket + 1) * remaining
if (elementsPerBucket == 0) {
rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
} else {
rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
}
}
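
For a concrete picture of what these two helpers produce, here is a small worked example (an illustration only, not part of the patch); the same cases are exercised by the new test in MapOutputTrackerSuite at the bottom of this diff:

  // Illustration: splitting 17 reduce partitions across 5 aggregation tasks.
  // elementsPerBucket = 17 / 5 = 3, remaining = 17 % 5 = 2, splitPoint = (3 + 1) * 2 = 8
  equallyDivide(17, 5)
  // => buckets covering [0, 4), [4, 8), [8, 11), [11, 14), [14, 17)
  // i.e. sizes 4, 4, 3, 3, 3 -- the 17 % 5 = 2 extra elements go to the first two buckets.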

/**
* Return statistics about all of the outputs for a given shuffle.
*/
def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
val totalSizes = new Array[Long](dep.partitioner.numPartitions)
for (s <- statuses) {
for (i <- 0 until totalSizes.length) {
totalSizes(i) += s.getSizeForBlock(i)
val parallelAggThreshold = conf.get(
SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
Member:

Maybe a little picky, but should we do:

val parallelAggThreshold = conf.get(
  SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) + 1
...
val parallelism = math.min(
  Runtime.getRuntime.availableProcessors(),
  (statuses.length.toLong * totalSizes.length + 1) / parallelAggThreshold + 1).toInt

In case of the threshold being set to zero?
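
To make the concern concrete, here is a standalone sketch (plain Scala with invented numbers, not code from this PR) comparing the patch's formula with the suggested guarded variant when the threshold is misconfigured to zero:

  val cores = 8
  val product = 2000L * 6000L  // mappers * shuffle partitions = 12,000,000

  // Formula as written in the patch: integer division by the raw threshold,
  // so a threshold of 0 throws ArithmeticException.
  def parallelismRaw(threshold: Long): Int =
    math.min(cores.toLong, product / threshold + 1).toInt

  // Suggested variant: read the threshold as the conf value + 1, so the divisor
  // can never be zero for any non-negative setting.
  def parallelismGuarded(confValue: Long): Int = {
    val threshold = confValue + 1
    math.min(cores.toLong, (product + 1) / threshold + 1).toInt
  }

  parallelismGuarded(0)   // 8, capped by the core count
  // parallelismRaw(0)    // would divide by zero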

Member:

For zero or negative threshold, see my above comment: #19763 (comment).

Author (@gcz2022, Nov 24, 2017):

I think that code would confuse people, and we would need more comments to explain it, which doesn't seem worth it.
In most cases the default value is enough, so would adding a value check plus an explanation in the docs be good?

Member:

Yeah, I left that comment before #19763 (comment). I think adding more explanation to the config entry is good enough.

val parallelism = math.min(
Runtime.getRuntime.availableProcessors(),
statuses.length * totalSizes.length / parallelAggThreshold + 1)
Member:

Use statuses.length.toLong. It's easy to overflow here.

if (parallelism <= 1) {
for (s <- statuses) {
for (i <- 0 until totalSizes.length) {
totalSizes(i) += s.getSizeForBlock(i)
}
}
} else {
try {
val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
reduceIds => Future {
for (s <- statuses; i <- reduceIds) {
totalSizes(i) += s.getSizeForBlock(i)
}
}
}
ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
} finally {
threadpool.shutdown()
Contributor:

cc @zsxwing do we really need to shut down the thread pool every time? This method may be called many times; would it be better to cache this thread pool, like the dispatcher thread pool?

Author:

I agree with you. With the thread pool kept as a field of the class, the only loss is that the pool still exists even when the single-threaded path is used. The gain is that we avoid creating a new pool after every shuffle.

Contributor:

We could shut down the pool after a certain idle time, but I'm not sure it's worth the complexity.

Member:

I'm fine with creating a thread pool every time, since this code path doesn't seem to run very frequently, because:

  • Using a shared cached thread pool is effectively the same as creating a new pool, since a thread's idle window between calls is long and it will likely be killed before the next call.
  • Using a shared fixed thread pool is simply a waste for most use cases.
  • The cost of creating threads is trivial compared to the total time of a job.

Member:

@gczsjdy could you fix the compile error?

Author:

@zsxwing Actually I built using sbt/mvn, no errors...

Member:

@gczsjdy Oh, sorry. I didn't realize there is already a threadpool field in MapOutputTrackerMaster; that's why there is no compile error. Here you are shutting down the wrong thread pool.

Contributor:

ah good catch! I misread it...

Author:

My fault!

Author:

@cloud-fan Regarding "We could shut down the pool after a certain idle time, but I'm not sure it's worth the complexity": I know we don't need to do this now, but if we did, how would we do it?
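
One possible shape for that (a sketch under assumed names, not something this PR implements) is a long-lived pool field whose threads are all allowed to die after an idle timeout, so it costs nothing between shuffles:

  import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

  // Hypothetical field on MapOutputTrackerMaster: every thread terminates after 60s idle.
  private val mapOutputAggregationPool: ThreadPoolExecutor = {
    val pool = new ThreadPoolExecutor(
      Runtime.getRuntime.availableProcessors(),  // core pool size
      Runtime.getRuntime.availableProcessors(),  // max pool size
      60L, TimeUnit.SECONDS,                     // idle keep-alive
      new LinkedBlockingQueue[Runnable]())
    pool.allowCoreThreadTimeOut(true)            // let core threads time out as well
    pool
  }

The tradeoff discussed above still applies: the pool object itself stays around even when only the single-threaded path is used.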

}
}
new MapOutputStatistics(dep.shuffleId, totalSizes)
@@ -485,4 +485,13 @@ package object config {
"array in the sorter.")
.intConf
.createWithDefault(Integer.MAX_VALUE)

private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
.internal()
.doc("Multi-thread is used when the number of mappers * shuffle partitions is greater than " +
"or equal to this threshold.")
Member:

Is this condition for enabling parallel aggregation still accurate?

Author:

Sorry, but I didn't get you.

Member:

Looks like the parallel aggregation is only enabled when parallelism >= 2. Is that equivalent to the number of mappers * shuffle partitions >= this threshold?

Member:

From statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1 above, it looks like we need at least two times this threshold to enable the parallel aggregation?

Author:

statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1 >= 2 is equivalent to statuses.length.toLong * totalSizes.length >= parallelAggThreshold, so it doesn't need to be 2 times the threshold; anything not smaller than 1x is enough.
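
Plugging in numbers makes the point (invented sizes, default threshold of 10000000):

  val threshold = 10000000L
  val product = 2000L * 6000L                // 12,000,000, only 1.2x the threshold
  val parallelism = product / threshold + 1  // 12000000 / 10000000 + 1 = 2
  // parallelism >= 2, so the parallel path is taken once the product reaches
  // 1x the threshold, not 2x.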

Author:

Do you think it's necessary to describe here how the actual parallelism is calculated?

Member:

It's ok. I misread the equation. Nvm.

Member:

I don't think we need to describe the calculation in the config description. The current one is enough.

Member:

After rethinking this, I think it is better to indicate that this threshold also determines the number of threads used for the parallel aggregation, so it should not be set to zero or a negative number.

Author:

Yeah, I will add some.

.intConf
.createWithDefault(10000000)

}
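
As a usage note (a sketch, not taken from this PR): the entry is internal but can still be set through its string key, for example to lower the threshold for jobs with very large shuffles:

  import org.apache.spark.SparkConf

  // Illustrative value only; the key matches the ConfigBuilder entry above.
  val conf = new SparkConf()
    .set("spark.shuffle.mapOutput.parallelAggregationThreshold", "5000000")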
23 changes: 23 additions & 0 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
}
}

test("equally divide map statistics tasks") {
val func = newTrackerMaster().equallyDivide _
val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5))
val expects = Seq(
Seq(0, 0, 0, 0, 0),
Seq(1, 1, 1, 1, 0),
Seq(3, 3, 3, 3, 3),
Seq(4, 3, 3, 3, 3),
Seq(4, 4, 3, 3, 3),
Seq(4, 4, 4, 3, 3),
Seq(4, 4, 4, 4, 3),
Seq(4, 4, 4, 4, 4))
cases.zip(expects).foreach { case ((num, divisor), expect) =>
val answer = func(num, divisor).toSeq
var wholeSplit = (0 until num)
answer.zip(expect).foreach { case (split, expectSplitLength) =>
val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength)
assert(currentSplit.toSet == split.toSet)
wholeSplit = rest
}
}
}

}