Skip to content

Commit

Permalink
Use Netty HashedWheelTimer
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Nov 14, 2014
1 parent f847dd4 commit 3200c33
Showing 1 changed file with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
import java.util.{Timer, TimerTask}

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8
import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}

import org.apache.spark._
import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
Expand Down Expand Up @@ -77,7 +77,8 @@ private[nio] class ConnectionManager(
}

private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
private val ackTimeoutMonitor =
new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))

private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)

Expand Down Expand Up @@ -903,8 +904,8 @@ private[nio] class ConnectionManager(
// memory leaks since cancelled TimerTasks won't necessarily be garbage collected until they are
// scheduled to run. Therefore, extract the message id from outside of the task:
val messageId = message.id
val timeoutTask = new TimerTask {
override def run(): Unit = {
val timeoutTask: TimerTask = new TimerTask {
override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
messageStatuses.remove(messageId).foreach ( s => {
val e = new IOException("sendMessageReliably failed because ack " +
Expand All @@ -917,8 +918,10 @@ private[nio] class ConnectionManager(
}
}

val timoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)

val status = new MessageStatus(message, connectionManagerId, s => {
timeoutTask.cancel()
timoutTaskHandle.cancel()
s match {
case scala.util.Failure(e) =>
// Indicates a failure where we either never sent or never got ACK'd
Expand Down Expand Up @@ -947,7 +950,6 @@ private[nio] class ConnectionManager(
messageStatuses += ((message.id, status))
}

ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
Expand All @@ -957,7 +959,7 @@ private[nio] class ConnectionManager(
}

def stop() {
ackTimeoutMonitor.cancel()
ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
Expand Down

0 comments on commit 3200c33

Please sign in to comment.