Skip to content

Commit

Permalink
Further MQTT streaming hardening (#1327)
Browse files Browse the repository at this point in the history
Numerous hardening activities around MQTT streaming. However, in summary:

* Corrected misinterpretation of behaviour setup
* Avoid a race condition when creating child actors
* Handle a duplicate publish while consuming
* Connection packet id distinction which now allows more than one client to send the same packet id as another
  • Loading branch information
huntc authored and ennru committed Nov 16, 2018
1 parent c367b46 commit f4d6193
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ import scala.util.{Failure, Success}
connectFlags: ConnectFlags,
keepAlive: FiniteDuration,
pendingPingResp: Boolean,
activeConsumers: Set[String],
activeProducers: Set[String],
pendingLocalPublications: Seq[(String, PublishReceivedLocally)],
pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)],
remote: SourceQueueWithComplete[ForwardConnectCommand],
Expand Down Expand Up @@ -210,6 +212,8 @@ import scala.util.{Failure, Success}
data.connect.connectFlags,
data.connect.keepAlive,
pendingPingResp = false,
Set.empty,
Set.empty,
Vector.empty,
Vector.empty,
data.remote,
Expand Down Expand Up @@ -284,30 +288,30 @@ import scala.util.{Failure, Success}
local.success(Consumer.ForwardPublish)
serverConnected(data)
case (context, prfr @ PublishReceivedFromRemote(publish @ Publish(_, topicName, Some(packetId), _), local)) =>
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName)
context.child(consumerName) match {
case None if !data.pendingRemotePublications.exists(_._1 == publish.topicName) =>
context.watchWith(
context.spawn(
Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings),
consumerName
),
ConsumerFree(publish.topicName)
)
serverConnected(data)
case _ =>
serverConnected(
data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr))
)
if (!data.activeConsumers.contains(topicName)) {
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size)
context.watchWith(
context.spawn(
Consumer(publish, None, packetId, local, data.consumerPacketRouter, data.settings),
consumerName
),
ConsumerFree(publish.topicName)
)
serverConnected(data.copy(activeConsumers = data.activeConsumers + publish.topicName))
} else {
serverConnected(
data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr))
)
}
case (context, ConsumerFree(topicName)) =>
val i = data.pendingRemotePublications.indexWhere(_._1 == topicName)
if (i >= 0) {
val prfr = data.pendingRemotePublications(i)._2
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName)
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size)
context.watchWith(
context.spawn(
Consumer(prfr.publish,
None,
prfr.publish.packetId.get,
prfr.local,
data.consumerPacketRouter,
Expand All @@ -323,35 +327,34 @@ import scala.util.{Failure, Success}
)
)
} else {
serverConnected(data)
serverConnected(data.copy(activeConsumers = data.activeConsumers - topicName))
}
case (_, PublishReceivedLocally(publish, _))
if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 =>
data.remote.offer(ForwardPublish(publish, None))
serverConnected(data)
case (context, prl @ PublishReceivedLocally(publish, publishData)) =>
val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName)
context.child(producerName) match {
case None if !data.pendingLocalPublications.exists(_._1 == publish.topicName) =>
val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]
import context.executionContext
reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command))
context.watchWith(
context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings),
producerName),
ProducerFree(publish.topicName)
)
serverConnected(data)
case _ =>
serverConnected(
data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl))
)
val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size)
if (!data.activeProducers.contains(publish.topicName)) {
val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]
import context.executionContext
reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command))
context.watchWith(
context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings),
producerName),
ProducerFree(publish.topicName)
)
serverConnected(data.copy(activeProducers = data.activeProducers + publish.topicName))
} else {
serverConnected(
data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl))
)
}
case (context, ProducerFree(topicName)) =>
val i = data.pendingLocalPublications.indexWhere(_._1 == topicName)
if (i >= 0) {
val prl = data.pendingLocalPublications(i)._2
val producerName = ActorName.mkName(ProducerNamePrefix + topicName)
val producerName = ActorName.mkName(ProducerNamePrefix + topicName + "-" + context.children.size)
val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]
import context.executionContext
reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command))
Expand All @@ -369,7 +372,7 @@ import scala.util.{Failure, Success}
)
)
} else {
serverConnected(data)
serverConnected(data.copy(activeProducers = data.activeProducers - topicName))
}
case (_, ReceivedProducerPublishingCommand(command)) =>
command.runWith(Sink.foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import akka.actor.typed.scaladsl.Behaviors
import akka.annotation.InternalApi
import akka.stream.{Materializer, OverflowStrategy}
import akka.stream.scaladsl.{BroadcastHub, Keep, Source, SourceQueueWithComplete}
import akka.util.ByteString

import scala.concurrent.Promise
import scala.util.control.NoStackTrace
Expand Down Expand Up @@ -175,33 +176,42 @@ import scala.util.{Failure, Success}
*/
case object ConsumeFailed extends Exception with NoStackTrace

/*
* A consume is active while a duplicate publish was received.
*/
case object ConsumeActive extends Exception with NoStackTrace

/*
* Construct with the starting state
*/
def apply(publish: Publish,
clientId: Option[String],
packetId: PacketId,
local: Promise[ForwardPublish.type],
packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
settings: MqttSessionSettings): Behavior[Event] =
prepareClientConsumption(Start(publish, packetId, local, packetRouter, settings))
prepareClientConsumption(Start(publish, clientId, packetId, local, packetRouter, settings))

// Our FSM data, FSM events and commands emitted by the FSM

sealed abstract class Data(val publish: Publish,
val clientId: Option[String],
val packetId: PacketId,
val packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
val settings: MqttSessionSettings)
final case class Start(override val publish: Publish,
override val clientId: Option[String],
override val packetId: PacketId,
local: Promise[ForwardPublish.type],
override val packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
override val settings: MqttSessionSettings)
extends Data(publish, packetId, packetRouter, settings)
extends Data(publish, clientId, packetId, packetRouter, settings)
final case class ClientConsuming(override val publish: Publish,
override val clientId: Option[String],
override val packetId: PacketId,
override val packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
override val settings: MqttSessionSettings)
extends Data(publish, packetId, packetRouter, settings)
extends Data(publish, clientId, packetId, packetRouter, settings)

sealed abstract class Event
final case object RegisteredPacketId extends Event
Expand All @@ -213,6 +223,7 @@ import scala.util.{Failure, Success}
case object ReceivePubRelTimeout extends Event
final case class PubCompReceivedLocally(remote: Promise[ForwardPubComp.type]) extends Event
case object ReceivePubCompTimeout extends Event
final case class DupPublishReceivedFromRemote(local: Promise[ForwardPublish.type]) extends Event

sealed abstract class Command
case object ForwardPublish extends Command
Expand All @@ -225,7 +236,7 @@ import scala.util.{Failure, Success}

def prepareClientConsumption(data: Start): Behavior[Event] = Behaviors.setup { context =>
val reply = Promise[RemotePacketRouter.Registered.type]
data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.packetId, reply)
data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.clientId, data.packetId, reply)
import context.executionContext
reply.future.onComplete {
case Success(RemotePacketRouter.Registered) => context.self ! RegisteredPacketId
Expand All @@ -235,7 +246,12 @@ import scala.util.{Failure, Success}
Behaviors.receiveMessagePartial[Event] {
case RegisteredPacketId =>
data.local.success(ForwardPublish)
consumeUnacknowledged(ClientConsuming(data.publish, data.packetId, data.packetRouter, data.settings))
consumeUnacknowledged(
ClientConsuming(data.publish, data.clientId, data.packetId, data.packetRouter, data.settings)
)
case _: DupPublishReceivedFromRemote =>
data.local.failure(ConsumeActive)
throw ConsumeActive
case UnobtainablePacketId =>
data.local.failure(ConsumeFailed)
throw ConsumeFailed
Expand All @@ -253,12 +269,15 @@ import scala.util.{Failure, Success}
case PubRecReceivedLocally(remote) if data.publish.flags.contains(ControlPacketFlags.QoSExactlyOnceDelivery) =>
remote.success(ForwardPubRec)
consumeReceived(data)
case DupPublishReceivedFromRemote(local) =>
local.success(ForwardPublish)
consumeUnacknowledged(data)
case ReceivePubAckRecTimeout =>
throw ConsumeFailed
}
.receiveSignal {
case (_, PostStop) =>
data.packetRouter ! RemotePacketRouter.Unregister(data.packetId)
data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId)
Behaviors.same
}
}
Expand All @@ -270,12 +289,15 @@ import scala.util.{Failure, Success}
case PubRelReceivedFromRemote(local) =>
local.success(ForwardPubRel)
consumeAcknowledged(data)
case DupPublishReceivedFromRemote(local) =>
local.success(ForwardPublish)
consumeUnacknowledged(data)
case ReceivePubRelTimeout =>
throw ConsumeFailed
}
.receiveSignal {
case (_, PostStop) =>
data.packetRouter ! RemotePacketRouter.Unregister(data.packetId)
data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId)
Behaviors.same
}
}
Expand All @@ -287,12 +309,15 @@ import scala.util.{Failure, Success}
case PubCompReceivedLocally(remote) =>
remote.success(ForwardPubComp)
Behaviors.stopped
case DupPublishReceivedFromRemote(local) =>
local.success(ForwardPublish)
consumeUnacknowledged(data)
case ReceivePubCompTimeout =>
throw ConsumeFailed
}
.receiveSignal {
case (_, PostStop) =>
data.packetRouter ! RemotePacketRouter.Unregister(data.packetId)
data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId)
Behaviors.same
}
}
Expand Down Expand Up @@ -381,10 +406,21 @@ import scala.util.{Failure, Success}
// Requests

sealed abstract class Request[A]
final case class Register[A](registrant: ActorRef[A], packetId: PacketId, reply: Promise[Registered.type])
final case class Register[A](registrant: ActorRef[A],
clientId: Option[String],
packetId: PacketId,
reply: Promise[Registered.type])
extends Request[A]
final case class RegisterConnection[A](connectionId: ByteString, clientId: String) extends Request[A]
final case class Unregister[A](clientId: Option[String], packetId: PacketId) extends Request[A]
final case class UnregisterConnection[A](connectionId: ByteString) extends Request[A]
final case class Route[A](clientId: Option[String], packetId: PacketId, event: A, failureReply: Promise[_])
extends Request[A]
final case class RouteViaConnection[A](connectionId: ByteString,
packetId: PacketId,
event: A,
failureReply: Promise[_])
extends Request[A]
final case class Unregister[A](packetId: PacketId) extends Request[A]
final case class Route[A](packetId: PacketId, event: A, failureReply: Promise[_]) extends Request[A]

// Replies

Expand All @@ -395,7 +431,7 @@ import scala.util.{Failure, Success}
* Construct with the starting state
*/
def apply[A]: Behavior[Request[A]] =
new RemotePacketRouter[A].main(Map.empty)
new RemotePacketRouter[A].main(Map.empty, Map.empty)
}

/*
Expand All @@ -409,18 +445,38 @@ import scala.util.{Failure, Success}

// Processing

def main(registrantsByPacketId: Map[PacketId, ActorRef[A]]): Behavior[Request[A]] =
def main(registrantsByPacketId: Map[(Option[String], PacketId), ActorRef[A]],
clientIdsByConnectionId: Map[ByteString, String]): Behavior[Request[A]] =
Behaviors.receiveMessage {
case Register(registrant: ActorRef[A], packetId, reply) =>
case Register(registrant: ActorRef[A], clientId, packetId, reply) =>
reply.success(Registered)
main(registrantsByPacketId + (packetId -> registrant))
case Unregister(packetId) =>
main(registrantsByPacketId - packetId)
case Route(packetId, event, failureReply) =>
registrantsByPacketId.get(packetId) match {
val key = (clientId, packetId)
main(registrantsByPacketId + (key -> registrant), clientIdsByConnectionId)
case RegisterConnection(connectionId, clientId) =>
main(registrantsByPacketId, clientIdsByConnectionId + (connectionId -> clientId))
case Unregister(clientId, packetId) =>
val key = (clientId, packetId)
main(registrantsByPacketId - key, clientIdsByConnectionId)
case UnregisterConnection(connectionId) =>
main(registrantsByPacketId, clientIdsByConnectionId - connectionId)
case Route(clientId, packetId, event, failureReply) =>
val key = (clientId, packetId)
registrantsByPacketId.get(key) match {
case Some(reply) => reply ! event
case None => failureReply.failure(CannotRoute)
}
Behaviors.same
case RouteViaConnection(connectionId, packetId, event, failureReply) =>
clientIdsByConnectionId.get(connectionId) match {
case clientId: Some[String] =>
val key = (clientId, packetId)
registrantsByPacketId.get(key) match {
case Some(reply) => reply ! event
case None => failureReply.failure(CannotRoute)
}
case None =>
failureReply.failure(CannotRoute)
}
Behaviors.same
}
}
Loading

0 comments on commit f4d6193

Please sign in to comment.