Skip to content

Commit

Permalink
=rem #3967 Handle refused connections as association failures
Browse files Browse the repository at this point in the history
  • Loading branch information
bantonsson committed Apr 7, 2014
1 parent bddab2a commit 672e7f9
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 33 deletions.
68 changes: 39 additions & 29 deletions akka-remote/src/main/scala/akka/remote/Remoting.scala
Expand Up @@ -411,6 +411,12 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
var pendingReadHandoffs = Map[ActorRef, AkkaProtocolHandle]()
var stashedInbound = Map[ActorRef, Vector[InboundAssociation]]()

def handleStashedInbound(endpoint: ActorRef) {
val stashed = stashedInbound.getOrElse(endpoint, Vector.empty)
stashedInbound -= endpoint
stashed foreach (handleInboundAssociation _)
}

def keepQuarantinedOr(remoteAddress: Address)(body: Unit): Unit = endpoints.refuseUid(remoteAddress) match {
case Some(uid)
log.info("Quarantined address [{}] is still unreachable or has not been restarted. Keeping it quarantined.", remoteAddress)
Expand Down Expand Up @@ -567,44 +573,19 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends

}

case ia @ InboundAssociation(handle: AkkaProtocolHandle) endpoints.readOnlyEndpointFor(handle.remoteAddress) match {
case Some(endpoint)
pendingReadHandoffs.get(endpoint) foreach (_.disassociate())
pendingReadHandoffs += endpoint -> handle
endpoint ! EndpointWriter.TakeOver(handle)
case None
if (endpoints.isQuarantined(handle.remoteAddress, handle.handshakeInfo.uid))
handle.disassociate(AssociationHandle.Quarantined)
else endpoints.writableEndpointWithPolicyFor(handle.remoteAddress) match {
case Some(Pass(ep, None))
stashedInbound += ep -> (stashedInbound.getOrElse(ep, Vector.empty) :+ ia)
case Some(Pass(ep, Some(uid)))
if (handle.handshakeInfo.uid == uid) {
pendingReadHandoffs.get(ep) foreach (_.disassociate())
pendingReadHandoffs += ep -> handle
ep ! EndpointWriter.StopReading(ep)
} else {
context.stop(ep)
endpoints.unregisterEndpoint(ep)
pendingReadHandoffs -= ep
createAndRegisterEndpoint(handle, Some(uid))
}
case state
createAndRegisterEndpoint(handle, None)
}
}
case ia @ InboundAssociation(handle: AkkaProtocolHandle)
handleInboundAssociation(ia)
case EndpointWriter.StoppedReading(endpoint)
acceptPendingReader(takingOverFrom = endpoint)
case Terminated(endpoint)
acceptPendingReader(takingOverFrom = endpoint)
endpoints.unregisterEndpoint(endpoint)
stashedInbound -= endpoint
handleStashedInbound(endpoint)
case EndpointWriter.TookOver(endpoint, handle)
removePendingReader(takingOverFrom = endpoint, withHandle = handle)
case ReliableDeliverySupervisor.GotUid(uid)
endpoints.registerWritableEndpointUid(sender, uid)
stashedInbound.getOrElse(sender, Vector.empty) foreach (self ! _)
stashedInbound -= sender
handleStashedInbound(sender)
case Prune
endpoints.prune()
case ShutdownAndFlush
Expand Down Expand Up @@ -635,6 +616,35 @@ private[remote] class EndpointManager(conf: Config, log: LoggingAdapter) extends
case Terminated(_) // why should we care now?
}

def handleInboundAssociation(ia: InboundAssociation): Unit = ia match {
case ia @ InboundAssociation(handle: AkkaProtocolHandle) endpoints.readOnlyEndpointFor(handle.remoteAddress) match {
case Some(endpoint)
pendingReadHandoffs.get(endpoint) foreach (_.disassociate())
pendingReadHandoffs += endpoint -> handle
endpoint ! EndpointWriter.TakeOver(handle)
case None
if (endpoints.isQuarantined(handle.remoteAddress, handle.handshakeInfo.uid))
handle.disassociate(AssociationHandle.Quarantined)
else endpoints.writableEndpointWithPolicyFor(handle.remoteAddress) match {
case Some(Pass(ep, None))
stashedInbound += ep -> (stashedInbound.getOrElse(ep, Vector.empty) :+ ia)
case Some(Pass(ep, Some(uid)))
if (handle.handshakeInfo.uid == uid) {
pendingReadHandoffs.get(ep) foreach (_.disassociate())
pendingReadHandoffs += ep -> handle
ep ! EndpointWriter.StopReading(ep)
} else {
context.stop(ep)
endpoints.unregisterEndpoint(ep)
pendingReadHandoffs -= ep
createAndRegisterEndpoint(handle, Some(uid))
}
case state
createAndRegisterEndpoint(handle, None)
}
}
}

private def createAndRegisterEndpoint(handle: AkkaProtocolHandle, refuseUid: Option[Int]): Unit = {
val writing = settings.UsePassiveConnections && !endpoints.hasWritableEndpointFor(handle.remoteAddress)
eventPublisher.notifyListeners(AssociatedEvent(handle.localAddress, handle.remoteAddress, inbound = true))
Expand Down
Expand Up @@ -441,7 +441,7 @@ class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedA
readyChannel.getPipeline.get(classOf[ClientHandler]).statusFuture
} yield handle) recover {
case c: CancellationException throw new NettyTransportException("Connection was cancelled") with NoStackTrace
case u @ (_: UnknownHostException | _: SecurityException) throw new InvalidAssociationException(u.getMessage, u.getCause)
case u @ (_: UnknownHostException | _: SecurityException | _: ConnectException) throw new InvalidAssociationException(u.getMessage, u.getCause)
case NonFatal(t) throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
}
}
Expand Down
88 changes: 85 additions & 3 deletions akka-remote/src/test/scala/akka/remote/RemotingSpec.scala
Expand Up @@ -14,6 +14,7 @@ import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.concurrent.forkjoin.ThreadLocalRandom
import akka.TestUtils.temporaryServerAddress

object RemotingSpec {

Expand Down Expand Up @@ -115,6 +116,12 @@ object RemotingSpec {
}
""")

def muteSystem(system: ActorSystem) {
system.eventStream.publish(TestEvent.Mute(
EventFilter.error(start = "AssociationError"),
EventFilter.warning(start = "AssociationError"),
EventFilter.warning(pattern = "received dead letter.*")))
}
}

@org.junit.runner.RunWith(classOf[org.scalatest.junit.JUnitRunner])
Expand Down Expand Up @@ -179,9 +186,7 @@ class RemotingSpec extends AkkaSpec(RemotingSpec.cfg) with ImplicitSender with D
}

override def atStartup() = {
system.eventStream.publish(TestEvent.Mute(
EventFilter.error(start = "AssociationError"),
EventFilter.warning(pattern = "received dead letter.*")))
muteSystem(system);
remoteSystem.eventStream.publish(TestEvent.Mute(
EventFilter[EndpointException](),
EventFilter.error(start = "AssociationError"),
Expand Down Expand Up @@ -563,5 +568,82 @@ class RemotingSpec extends AkkaSpec(RemotingSpec.cfg) with ImplicitSender with D
shutdown(otherSystem)
}
}

"be able to connect to system even if it's not there at first" in {
val config = ConfigFactory.parseString(s"""
akka.remote.enabled-transports = ["akka.remote.netty.tcp"]
akka.remote.netty.tcp.port = 0
akka.remote.retry-gate-closed-for = 5s
""").withFallback(remoteSystem.settings.config)
val thisSystem = ActorSystem("this-system", config)
try {
muteSystem(thisSystem)
val probe = new TestProbe(thisSystem)
val probeSender = probe.ref
val otherAddress = temporaryServerAddress()
val otherConfig = ConfigFactory.parseString(s"""
akka.remote.netty.tcp.port = ${otherAddress.getPort}
""").withFallback(config)
val otherSelection = thisSystem.actorSelection(s"akka.tcp://other-system@localhost:${otherAddress.getPort}/user/echo")
otherSelection.tell("ping", probeSender)
probe.expectNoMsg(1 seconds)
val otherSystem = ActorSystem("other-system", otherConfig)
try {
muteSystem(otherSystem)
probe.expectNoMsg(2 seconds)
otherSystem.actorOf(Props[Echo2], "echo")
within(5 seconds) {
awaitAssert {
otherSelection.tell("ping", probeSender)
assert(probe.expectMsgType[(String, ActorRef)](500 millis)._1 == "pong")
}
}
} finally {
shutdown(otherSystem)
}
} finally {
shutdown(thisSystem)
}
}

"allow other system to connect even if it's not there at first" in {
val config = ConfigFactory.parseString(s"""
akka.remote.enabled-transports = ["akka.remote.netty.tcp"]
akka.remote.netty.tcp.port = 0
akka.remote.retry-gate-closed-for = 5s
""").withFallback(remoteSystem.settings.config)
val thisSystem = ActorSystem("this-system", config)
try {
muteSystem(thisSystem)
val thisProbe = new TestProbe(thisSystem)
val thisSender = thisProbe.ref
thisSystem.actorOf(Props[Echo2], "echo")
val otherAddress = temporaryServerAddress()
val otherConfig = ConfigFactory.parseString(s"""
akka.remote.netty.tcp.port = ${otherAddress.getPort}
""").withFallback(config)
val otherSelection = thisSystem.actorSelection(s"akka.tcp://other-system@localhost:${otherAddress.getPort}/user/echo")
otherSelection.tell("ping", thisSender)
thisProbe.expectNoMsg(1 seconds)
val otherSystem = ActorSystem("other-system", otherConfig)
try {
muteSystem(otherSystem)
thisProbe.expectNoMsg(2 seconds)
val otherProbe = new TestProbe(otherSystem)
val otherSender = otherProbe.ref
val thisSelection = otherSystem.actorSelection(s"akka.tcp://this-system@localhost:${port(thisSystem, "tcp")}/user/echo")
within(5 seconds) {
awaitAssert {
thisSelection.tell("ping", otherSender)
assert(otherProbe.expectMsgType[(String, ActorRef)](500 millis)._1 == "pong")
}
}
} finally {
shutdown(otherSystem)
}
} finally {
shutdown(thisSystem)
}
}
}
}

0 comments on commit 672e7f9

Please sign in to comment.