[SPARK-13398][STREAMING] Move away from thread pool task support to forkjoin #11423
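In brief, this patch replaces the deprecated `scala.collection.parallel.ThreadPoolTaskSupport` with an `ExecutionContextTaskSupport` backed by a Scala `ForkJoinPool`. A minimal before/after sketch of that swap (standalone illustration, not code from this patch):

```scala
import scala.collection.parallel.{ExecutionContextTaskSupport, ThreadPoolTaskSupport}
import scala.concurrent.ExecutionContext
import scala.concurrent.forkjoin.ForkJoinPool

object TaskSupportSwapSketch extends App {
  val data = (1 to 100).par

  // Before: a task support tied to a concrete ThreadPoolExecutor
  // (deprecated since Scala 2.11).
  data.tasksupport = new ThreadPoolTaskSupport()

  // After: any ExecutionContext can drive the collection; here a Scala
  // ForkJoinPool, mirroring the change in this diff.
  data.tasksupport = new ExecutionContextTaskSupport(
    ExecutionContext.fromExecutorService(new ForkJoinPool(8)))

  println(data.map(_ * 2).sum) // 10100
}
```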

18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
import java.util.concurrent._

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
import scala.util.control.NonFatal

import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
@@ -156,4 +157,21 @@ private[spark] object ThreadUtils {
result
}
}

+  /**
+   * Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix.
+   */
+  def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = {
+    // Custom factory to set thread names
+    val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory {
+      override def newThread(pool: SForkJoinPool) =
+        new SForkJoinWorkerThread(pool) {
+          setName(prefix + "-" + super.getName)
+        }
+    }
+    new SForkJoinPool(maxThreadNumber, factory,
+      null, // handler
+      false // asyncMode
+    )
+  }
}
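A usage sketch for the new helper (hypothetical caller code; `ThreadUtils` is `private[spark]`, so this only compiles from inside an `org.apache.spark` package). As the write-ahead-log change below does, the pool is typically wrapped in an `ExecutionContext`:

```scala
package org.apache.spark.examples // hypothetical; needed because ThreadUtils is private[spark]

import scala.concurrent.ExecutionContext

import org.apache.spark.util.ThreadUtils

object ForkJoinPoolUsageSketch extends App {
  // Worker threads get names like "wal-recovery-ForkJoinPool-1-worker-1":
  // the prefix is prepended to the default ForkJoinWorkerThread name.
  val pool = ThreadUtils.newForkJoinPool("wal-recovery", 8)
  val ec = ExecutionContext.fromExecutorService(pool)

  // ... run parallel work through `ec`, then release the threads.
  ec.shutdownNow()
}
```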
streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -18,11 +18,11 @@ package org.apache.spark.streaming.util

import java.nio.ByteBuffer
import java.util.{Iterator => JIterator}
-import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor}
+import java.util.concurrent.RejectedExecutionException

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import scala.collection.parallel.ThreadPoolTaskSupport
+import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.postfixOps

@@ -62,8 +62,8 @@ private[streaming] class FileBasedWriteAheadLog(
private val threadpoolName = {
"WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("")
}
-  private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20)
-  private val executionContext = ExecutionContext.fromExecutorService(threadpool)
+  private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20)
+  private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool)

override protected def logName = {
getClass.getName.stripSuffix("$") +
@@ -144,7 +144,7 @@ private[streaming] class FileBasedWriteAheadLog(
} else {
// For performance gains, it makes sense to parallelize the recovery if
// closeFileAfterWrite = true
-      seqToParIterator(threadpool, logFilesToRead, readFile).asJava
+      seqToParIterator(executionContext, logFilesToRead, readFile).asJava
}
}

@@ -283,16 +283,17 @@ private[streaming] object FileBasedWriteAheadLog {

/**
* This creates an iterator from a parallel collection, by keeping at most `n` objects in memory
-   * at any given time, where `n` is the size of the thread pool. This is crucial for use cases
-   * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to
-   * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize.
+   * at any given time, where `n` is at most the max of the size of the thread pool or 8. This is
+   * crucial for use cases where we create `FileBasedWriteAheadLogReader`s during parallel recovery.
+   * We don't want to open up `k` streams altogether where `k` is the size of the Seq that we want
+   * to parallelize.
*/
def seqToParIterator[I, O](
-      tpool: ThreadPoolExecutor,
+      executionContext: ExecutionContext,
source: Seq[I],
handler: I => Iterator[O]): Iterator[O] = {
-    val taskSupport = new ThreadPoolTaskSupport(tpool)
-    val groupSize = tpool.getMaximumPoolSize.max(8)
+    val taskSupport = new ExecutionContextTaskSupport(executionContext)
+    val groupSize = taskSupport.parallelismLevel.max(8)
source.grouped(groupSize).flatMap { group =>
val parallelCollection = group.par
parallelCollection.tasksupport = taskSupport
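The grouping trick documented above can be sketched standalone (illustrative names, not code from the patch): walking `source` in `groupSize` chunks means at most one chunk's worth of `handler` results is live at a time, while the outer iterator stays lazy.

```scala
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.ExecutionContext
import scala.concurrent.forkjoin.ForkJoinPool

object BoundedParIteratorSketch extends App {
  // Same shape as seqToParIterator: parallelism only within the current group,
  // so at most `groupSize` handler results (e.g. open readers) exist at once.
  def seqToLazyParIterator[I, O](
      ec: ExecutionContext,
      source: Seq[I],
      handler: I => Iterator[O]): Iterator[O] = {
    val taskSupport = new ExecutionContextTaskSupport(ec)
    val groupSize = taskSupport.parallelismLevel.max(8)
    source.grouped(groupSize).flatMap { group =>
      val parallelGroup = group.par
      parallelGroup.tasksupport = taskSupport
      parallelGroup.map(handler).toList
    }.flatten
  }

  val ec = ExecutionContext.fromExecutorService(new ForkJoinPool(4))
  // Each "file" i yields three records; only one group of iterators is open at a time.
  val records = seqToLazyParIterator[Int, Int](ec, 1 to 32, i => Iterator.fill(3)(i))
  println(records.take(5).toList) // List(1, 1, 1, 2, 2)
  ec.shutdownNow()
}
```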
streaming/src/test/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSuite.scala
@@ -228,7 +228,9 @@ class FileBasedWriteAheadLogSuite
the list of files.
*/
val numThreads = 8
-    val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool")
+    val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads)
+    val executionContext = ExecutionContext.fromExecutorService(fpool)

class GetMaxCounter {
private val value = new AtomicInteger()
@volatile private var max: Int = 0
@@ -258,7 +260,8 @@ class FileBasedWriteAheadLogSuite
val t = new Thread() {
override def run() {
// run the calculation on a separate thread so that we can release the latch
-        val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle)
+        val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext,
+          testSeq, handle)
collected = iterator.toSeq
}
}
@@ -273,7 +276,7 @@ class FileBasedWriteAheadLogSuite
// make sure we didn't open too many Iterators
assert(counter.getMax() <= numThreads)
} finally {
-      tpool.shutdownNow()
+      fpool.shutdownNow()
}
}
