Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further MQTT streaming hardening #1327

Merged
merged 6 commits into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ import scala.util.{Failure, Success}
connectFlags: ConnectFlags,
keepAlive: FiniteDuration,
pendingPingResp: Boolean,
activeConsumers: Set[String],
activeProducers: Set[String],
pendingLocalPublications: Seq[(String, PublishReceivedLocally)],
pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)],
remote: SourceQueueWithComplete[ForwardConnectCommand],
Expand Down Expand Up @@ -210,6 +212,8 @@ import scala.util.{Failure, Success}
data.connect.connectFlags,
data.connect.keepAlive,
pendingPingResp = false,
Set.empty,
Set.empty,
Vector.empty,
Vector.empty,
data.remote,
Expand Down Expand Up @@ -284,30 +288,30 @@ import scala.util.{Failure, Success}
local.success(Consumer.ForwardPublish)
serverConnected(data)
case (context, prfr @ PublishReceivedFromRemote(publish @ Publish(_, topicName, Some(packetId), _), local)) =>
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName)
context.child(consumerName) match {
case None if !data.pendingRemotePublications.exists(_._1 == publish.topicName) =>
context.watchWith(
context.spawn(
Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings),
consumerName
),
ConsumerFree(publish.topicName)
)
serverConnected(data)
case _ =>
serverConnected(
data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr))
)
if (!data.activeConsumers.contains(topicName)) {
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size)
context.watchWith(
context.spawn(
Consumer(publish, None, packetId, local, data.consumerPacketRouter, data.settings),
consumerName
),
ConsumerFree(publish.topicName)
)
serverConnected(data.copy(activeConsumers = data.activeConsumers + publish.topicName))
} else {
serverConnected(
data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr))
)
}
case (context, ConsumerFree(topicName)) =>
val i = data.pendingRemotePublications.indexWhere(_._1 == topicName)
if (i >= 0) {
val prfr = data.pendingRemotePublications(i)._2
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName)
val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size)
context.watchWith(
context.spawn(
Consumer(prfr.publish,
None,
prfr.publish.packetId.get,
prfr.local,
data.consumerPacketRouter,
Expand All @@ -323,35 +327,34 @@ import scala.util.{Failure, Success}
)
)
} else {
serverConnected(data)
serverConnected(data.copy(activeConsumers = data.activeConsumers - topicName))
}
case (_, PublishReceivedLocally(publish, _))
if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 =>
data.remote.offer(ForwardPublish(publish, None))
serverConnected(data)
case (context, prl @ PublishReceivedLocally(publish, publishData)) =>
val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName)
context.child(producerName) match {
case None if !data.pendingLocalPublications.exists(_._1 == publish.topicName) =>
val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]
import context.executionContext
reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command))
context.watchWith(
context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings),
producerName),
ProducerFree(publish.topicName)
)
serverConnected(data)
case _ =>
serverConnected(
data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl))
)
val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size)
if (!data.activeProducers.contains(publish.topicName)) {
val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]
import context.executionContext
reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command))
context.watchWith(
context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings),
producerName),
ProducerFree(publish.topicName)
)
serverConnected(data.copy(activeProducers = data.activeProducers + publish.topicName))
} else {
serverConnected(
data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl))
)
}
case (context, ProducerFree(topicName)) =>
val i = data.pendingLocalPublications.indexWhere(_._1 == topicName)
if (i >= 0) {
val prl = data.pendingLocalPublications(i)._2
val producerName = ActorName.mkName(ProducerNamePrefix + topicName)
val producerName = ActorName.mkName(ProducerNamePrefix + topicName + "-" + context.children.size)
val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]
import context.executionContext
reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command))
Expand All @@ -369,7 +372,7 @@ import scala.util.{Failure, Success}
)
)
} else {
serverConnected(data)
serverConnected(data.copy(activeProducers = data.activeProducers - topicName))
}
case (_, ReceivedProducerPublishingCommand(command)) =>
command.runWith(Sink.foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import akka.actor.typed.scaladsl.Behaviors
import akka.annotation.InternalApi
import akka.stream.{Materializer, OverflowStrategy}
import akka.stream.scaladsl.{BroadcastHub, Keep, Source, SourceQueueWithComplete}
import akka.util.ByteString

import scala.concurrent.Promise
import scala.util.control.NoStackTrace
Expand Down Expand Up @@ -175,33 +176,42 @@ import scala.util.{Failure, Success}
*/
case object ConsumeFailed extends Exception with NoStackTrace

/*
* A consume is active while a duplicate publish was received.
*/
case object ConsumeActive extends Exception with NoStackTrace

/*
* Construct with the starting state
*/
def apply(publish: Publish,
clientId: Option[String],
packetId: PacketId,
local: Promise[ForwardPublish.type],
packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
settings: MqttSessionSettings): Behavior[Event] =
prepareClientConsumption(Start(publish, packetId, local, packetRouter, settings))
prepareClientConsumption(Start(publish, clientId, packetId, local, packetRouter, settings))

// Our FSM data, FSM events and commands emitted by the FSM

sealed abstract class Data(val publish: Publish,
val clientId: Option[String],
val packetId: PacketId,
val packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
val settings: MqttSessionSettings)
final case class Start(override val publish: Publish,
override val clientId: Option[String],
override val packetId: PacketId,
local: Promise[ForwardPublish.type],
override val packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
override val settings: MqttSessionSettings)
extends Data(publish, packetId, packetRouter, settings)
extends Data(publish, clientId, packetId, packetRouter, settings)
final case class ClientConsuming(override val publish: Publish,
override val clientId: Option[String],
override val packetId: PacketId,
override val packetRouter: ActorRef[RemotePacketRouter.Request[Event]],
override val settings: MqttSessionSettings)
extends Data(publish, packetId, packetRouter, settings)
extends Data(publish, clientId, packetId, packetRouter, settings)

sealed abstract class Event
final case object RegisteredPacketId extends Event
Expand All @@ -213,6 +223,7 @@ import scala.util.{Failure, Success}
case object ReceivePubRelTimeout extends Event
final case class PubCompReceivedLocally(remote: Promise[ForwardPubComp.type]) extends Event
case object ReceivePubCompTimeout extends Event
final case class DupPublishReceivedFromRemote(local: Promise[ForwardPublish.type]) extends Event

sealed abstract class Command
case object ForwardPublish extends Command
Expand All @@ -225,7 +236,7 @@ import scala.util.{Failure, Success}

def prepareClientConsumption(data: Start): Behavior[Event] = Behaviors.setup { context =>
val reply = Promise[RemotePacketRouter.Registered.type]
data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.packetId, reply)
data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.clientId, data.packetId, reply)
import context.executionContext
reply.future.onComplete {
case Success(RemotePacketRouter.Registered) => context.self ! RegisteredPacketId
Expand All @@ -235,7 +246,12 @@ import scala.util.{Failure, Success}
Behaviors.receiveMessagePartial[Event] {
case RegisteredPacketId =>
data.local.success(ForwardPublish)
consumeUnacknowledged(ClientConsuming(data.publish, data.packetId, data.packetRouter, data.settings))
consumeUnacknowledged(
ClientConsuming(data.publish, data.clientId, data.packetId, data.packetRouter, data.settings)
)
case _: DupPublishReceivedFromRemote =>
data.local.failure(ConsumeActive)
throw ConsumeActive
case UnobtainablePacketId =>
data.local.failure(ConsumeFailed)
throw ConsumeFailed
Expand All @@ -253,12 +269,15 @@ import scala.util.{Failure, Success}
case PubRecReceivedLocally(remote) if data.publish.flags.contains(ControlPacketFlags.QoSExactlyOnceDelivery) =>
remote.success(ForwardPubRec)
consumeReceived(data)
case DupPublishReceivedFromRemote(local) =>
local.success(ForwardPublish)
consumeUnacknowledged(data)
case ReceivePubAckRecTimeout =>
throw ConsumeFailed
}
.receiveSignal {
case (_, PostStop) =>
data.packetRouter ! RemotePacketRouter.Unregister(data.packetId)
data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId)
Behaviors.same
}
}
Expand All @@ -270,12 +289,15 @@ import scala.util.{Failure, Success}
case PubRelReceivedFromRemote(local) =>
local.success(ForwardPubRel)
consumeAcknowledged(data)
case DupPublishReceivedFromRemote(local) =>
local.success(ForwardPublish)
consumeUnacknowledged(data)
case ReceivePubRelTimeout =>
throw ConsumeFailed
}
.receiveSignal {
case (_, PostStop) =>
data.packetRouter ! RemotePacketRouter.Unregister(data.packetId)
data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId)
Behaviors.same
}
}
Expand All @@ -287,12 +309,15 @@ import scala.util.{Failure, Success}
case PubCompReceivedLocally(remote) =>
remote.success(ForwardPubComp)
Behaviors.stopped
case DupPublishReceivedFromRemote(local) =>
local.success(ForwardPublish)
consumeUnacknowledged(data)
case ReceivePubCompTimeout =>
throw ConsumeFailed
}
.receiveSignal {
case (_, PostStop) =>
data.packetRouter ! RemotePacketRouter.Unregister(data.packetId)
data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId)
Behaviors.same
}
}
Expand Down Expand Up @@ -381,10 +406,21 @@ import scala.util.{Failure, Success}
// Requests

sealed abstract class Request[A]
final case class Register[A](registrant: ActorRef[A], packetId: PacketId, reply: Promise[Registered.type])
final case class Register[A](registrant: ActorRef[A],
clientId: Option[String],
packetId: PacketId,
reply: Promise[Registered.type])
extends Request[A]
final case class RegisterConnection[A](connectionId: ByteString, clientId: String) extends Request[A]
final case class Unregister[A](clientId: Option[String], packetId: PacketId) extends Request[A]
final case class UnregisterConnection[A](connectionId: ByteString) extends Request[A]
final case class Route[A](clientId: Option[String], packetId: PacketId, event: A, failureReply: Promise[_])
extends Request[A]
final case class RouteViaConnection[A](connectionId: ByteString,
packetId: PacketId,
event: A,
failureReply: Promise[_])
extends Request[A]
final case class Unregister[A](packetId: PacketId) extends Request[A]
final case class Route[A](packetId: PacketId, event: A, failureReply: Promise[_]) extends Request[A]

// Replies

Expand All @@ -395,7 +431,7 @@ import scala.util.{Failure, Success}
* Construct with the starting state
*/
def apply[A]: Behavior[Request[A]] =
new RemotePacketRouter[A].main(Map.empty)
new RemotePacketRouter[A].main(Map.empty, Map.empty)
}

/*
Expand All @@ -409,18 +445,38 @@ import scala.util.{Failure, Success}

// Processing

def main(registrantsByPacketId: Map[PacketId, ActorRef[A]]): Behavior[Request[A]] =
def main(registrantsByPacketId: Map[(Option[String], PacketId), ActorRef[A]],
clientIdsByConnectionId: Map[ByteString, String]): Behavior[Request[A]] =
Behaviors.receiveMessage {
case Register(registrant: ActorRef[A], packetId, reply) =>
case Register(registrant: ActorRef[A], clientId, packetId, reply) =>
reply.success(Registered)
main(registrantsByPacketId + (packetId -> registrant))
case Unregister(packetId) =>
main(registrantsByPacketId - packetId)
case Route(packetId, event, failureReply) =>
registrantsByPacketId.get(packetId) match {
val key = (clientId, packetId)
main(registrantsByPacketId + (key -> registrant), clientIdsByConnectionId)
case RegisterConnection(connectionId, clientId) =>
main(registrantsByPacketId, clientIdsByConnectionId + (connectionId -> clientId))
case Unregister(clientId, packetId) =>
val key = (clientId, packetId)
main(registrantsByPacketId - key, clientIdsByConnectionId)
case UnregisterConnection(connectionId) =>
main(registrantsByPacketId, clientIdsByConnectionId - connectionId)
case Route(clientId, packetId, event, failureReply) =>
val key = (clientId, packetId)
registrantsByPacketId.get(key) match {
case Some(reply) => reply ! event
case None => failureReply.failure(CannotRoute)
}
Behaviors.same
case RouteViaConnection(connectionId, packetId, event, failureReply) =>
clientIdsByConnectionId.get(connectionId) match {
case clientId: Some[String] =>
val key = (clientId, packetId)
registrantsByPacketId.get(key) match {
case Some(reply) => reply ! event
case None => failureReply.failure(CannotRoute)
}
case None =>
failureReply.failure(CannotRoute)
}
Behaviors.same
}
}
Loading