From 2e4fcb9823a49a2134d7a6fc1ed3ae00695009c8 Mon Sep 17 00:00:00 2001 From: Christopher Hunt Date: Sun, 11 Nov 2018 22:49:18 +0100 Subject: [PATCH 1/6] Corrected misinterpretation of behaviour setup I had thought that Behaviors.setup would only get called once for an actor. This is not the case - it is called each time a behaviour is transitioned to. There was only one place where my misunderstanding mattered, which was where a promise could be tried to be completed multiple times. --- .../akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala index e997758097..328e01465e 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala @@ -325,7 +325,7 @@ import scala.util.{Failure, Success} private val UnpublisherNamePrefix = "unpublisher-" def clientConnect(data: ConnectReceived)(implicit mat: Materializer): Behavior[Event] = Behaviors.setup { _ => - data.local.success(ForwardConnect) + data.local.trySuccess(ForwardConnect) Behaviors.withTimers { timer => timer.startSingleTimer("receive-connack", ReceiveConnAckTimeout, data.settings.receiveConnAckTimeout) From 7b8763e55586ec66f2da1c70fb86ec1f3da602e0 Mon Sep 17 00:00:00 2001 From: Christopher Hunt Date: Sun, 11 Nov 2018 08:32:40 +0100 Subject: [PATCH 2/6] Avoid a race condition when creating child actors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We observed an `InvalidActorNameException` when creating child actors. This could have been due to the following sequence when receiving a “Publish Received Locally” message: 1. prl received, producer actor created 2. producer actor terminates and sends a termination message 3. prl received before termination message is received, the parent actor creates another producer 4. the termination message is received and then attempts to create another producer with the same name The solution is to explicitly track active consumers and producers rather than rely on another data structure such as `context.children`, which will be updated in response to other events. --- .../mqtt/streaming/impl/ClientState.scala | 72 ++++++------ .../mqtt/streaming/impl/ServerState.scala | 106 +++++++++++------- 2 files changed, 101 insertions(+), 77 deletions(-) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala index 41200383f1..d1a04593a7 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala @@ -89,6 +89,8 @@ import scala.util.{Failure, Success} connectFlags: ConnectFlags, keepAlive: FiniteDuration, pendingPingResp: Boolean, + activeConsumers: Set[String], + activeProducers: Set[String], pendingLocalPublications: Seq[(String, PublishReceivedLocally)], pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)], remote: SourceQueueWithComplete[ForwardConnectCommand], @@ -210,6 +212,8 @@ import scala.util.{Failure, Success} data.connect.connectFlags, data.connect.keepAlive, pendingPingResp = false, + Set.empty, + Set.empty, Vector.empty, Vector.empty, data.remote, @@ -284,27 +288,26 @@ import scala.util.{Failure, Success} local.success(Consumer.ForwardPublish) serverConnected(data) case (context, prfr @ PublishReceivedFromRemote(publish @ Publish(_, topicName, Some(packetId), _), local)) => - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName) - context.child(consumerName) match { - case None if !data.pendingRemotePublications.exists(_._1 == publish.topicName) => - context.watchWith( - context.spawn( - Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), - consumerName - ), - ConsumerFree(publish.topicName) - ) - serverConnected(data) - case _ => - serverConnected( - data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr)) - ) + if (!data.activeConsumers.contains(topicName)) { + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) + context.watchWith( + context.spawn( + Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), + consumerName + ), + ConsumerFree(publish.topicName) + ) + serverConnected(data.copy(activeConsumers = data.activeConsumers + publish.topicName)) + } else { + serverConnected( + data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr)) + ) } case (context, ConsumerFree(topicName)) => val i = data.pendingRemotePublications.indexWhere(_._1 == topicName) if (i >= 0) { val prfr = data.pendingRemotePublications(i)._2 - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName) + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) context.watchWith( context.spawn( Consumer(prfr.publish, @@ -323,35 +326,34 @@ import scala.util.{Failure, Success} ) ) } else { - serverConnected(data) + serverConnected(data.copy(activeConsumers = data.activeConsumers - topicName)) } case (_, PublishReceivedLocally(publish, _)) if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 => data.remote.offer(ForwardPublish(publish, None)) serverConnected(data) case (context, prl @ PublishReceivedLocally(publish, publishData)) => - val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName) - context.child(producerName) match { - case None if !data.pendingLocalPublications.exists(_._1 == publish.topicName) => - val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] - import context.executionContext - reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) - context.watchWith( - context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings), - producerName), - ProducerFree(publish.topicName) - ) - serverConnected(data) - case _ => - serverConnected( - data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl)) - ) + val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + context.children.size) + if (!data.activeProducers.contains(publish.topicName)) { + val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] + import context.executionContext + reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) + context.watchWith( + context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings), + producerName), + ProducerFree(publish.topicName) + ) + serverConnected(data.copy(activeProducers = data.activeProducers + publish.topicName)) + } else { + serverConnected( + data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl)) + ) } case (context, ProducerFree(topicName)) => val i = data.pendingLocalPublications.indexWhere(_._1 == topicName) if (i >= 0) { val prl = data.pendingLocalPublications(i)._2 - val producerName = ActorName.mkName(ProducerNamePrefix + topicName) + val producerName = ActorName.mkName(ProducerNamePrefix + topicName + context.children.size) val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] import context.executionContext reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) @@ -369,7 +371,7 @@ import scala.util.{Failure, Success} ) ) } else { - serverConnected(data) + serverConnected(data.copy(activeProducers = data.activeProducers - topicName)) } case (_, ReceivedProducerPublishingCommand(command)) => command.runWith(Sink.foreach { diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala index 328e01465e..a8ba57e02b 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala @@ -214,6 +214,8 @@ import scala.util.{Failure, Success} connect, local, Set.empty, + Set.empty, + Set.empty, Vector.empty, Vector.empty, Vector.empty, @@ -236,6 +238,8 @@ import scala.util.{Failure, Success} connect: Connect, local: Promise[ForwardConnect.type], publishers: Set[String], + activeConsumers: Set[String], + activeProducers: Set[String], pendingLocalPublications: Seq[(String, PublishReceivedLocally)], pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)], stash: Seq[Event], @@ -249,6 +253,8 @@ import scala.util.{Failure, Success} connect: Connect, remote: SourceQueueWithComplete[ForwardConnAckCommand], publishers: Set[String], + activeConsumers: Set[String], + activeProducers: Set[String], pendingLocalPublications: Seq[(String, PublishReceivedLocally)], pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)], override val consumerPacketRouter: ActorRef[RemotePacketRouter.Request[Consumer.Event]], @@ -262,6 +268,8 @@ import scala.util.{Failure, Success} connect: Connect, remote: SourceQueueWithComplete[ForwardConnAckCommand], publishers: Set[String], + activeConsumers: Set[String], + activeProducers: Set[String], pendingLocalPublications: Seq[(String, PublishReceivedLocally)], pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)], stash: Seq[Event], @@ -273,6 +281,8 @@ import scala.util.{Failure, Success} ) extends Data(consumerPacketRouter, producerPacketRouter, publisherPacketRouter, unpublisherPacketRouter, settings) final case class Disconnected( publishers: Set[String], + activeConsumers: Set[String], + activeProducers: Set[String], pendingLocalPublications: Seq[(String, PublishReceivedLocally)], pendingRemotePublications: Seq[(String, PublishReceivedFromRemote)], override val consumerPacketRouter: ActorRef[RemotePacketRouter.Request[Consumer.Event]], @@ -347,6 +357,8 @@ import scala.util.{Failure, Success} data.connect, queue, data.publishers, + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, data.consumerPacketRouter, @@ -392,6 +404,8 @@ import scala.util.{Failure, Success} data.connect, data.remote, data.publishers, + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, Vector.empty, @@ -426,29 +440,28 @@ import scala.util.{Failure, Success} case (_, PublishReceivedFromRemote(publish, local)) if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 => local.success(Consumer.ForwardPublish) - Behaviors.same + clientConnected(data) case (context, prfr @ PublishReceivedFromRemote(publish @ Publish(_, topicName, Some(packetId), _), local)) => - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName) - context.child(consumerName) match { - case None if !data.pendingRemotePublications.exists(_._1 == publish.topicName) => - context.watchWith( - context.spawn( - Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), - consumerName - ), - ConsumerFree(publish.topicName) - ) - clientConnected(data) - case _ => - clientConnected( - data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr)) - ) + if (!data.activeConsumers.contains(topicName)) { + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) + context.watchWith( + context.spawn( + Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), + consumerName + ), + ConsumerFree(publish.topicName) + ) + clientConnected(data.copy(activeConsumers = data.activeConsumers + publish.topicName)) + } else { + clientConnected( + data.copy(pendingRemotePublications = data.pendingRemotePublications :+ (publish.topicName -> prfr)) + ) } case (context, ConsumerFree(topicName)) => val i = data.pendingRemotePublications.indexWhere(_._1 == topicName) if (i >= 0) { val prfr = data.pendingRemotePublications(i)._2 - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName) + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) context.watchWith( context.spawn( Consumer(prfr.publish, @@ -467,39 +480,34 @@ import scala.util.{Failure, Success} ) ) } else { - clientConnected(data) + clientConnected(data.copy(activeConsumers = data.activeConsumers - topicName)) } case (_, PublishReceivedLocally(publish, _)) - if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 && - data.publishers.exists(matchTopicFilter(_, publish.topicName)) => + if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 => data.remote.offer(ForwardPublish(publish, None)) clientConnected(data) - case (context, prl @ PublishReceivedLocally(publish, publishData)) - if data.publishers.exists(matchTopicFilter(_, publish.topicName)) => - val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName) - context.child(producerName) match { - case None if !data.pendingLocalPublications.exists(_._1 == publish.topicName) => - val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] - import context.executionContext - reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) - context.watchWith( - context.spawn( - Producer(publish, publishData, reply, data.producerPacketRouter, data.settings), - producerName - ), - ProducerFree(publish.topicName) - ) - clientConnected(data) - case _ => - clientConnected( - data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl)) - ) + case (context, prl @ PublishReceivedLocally(publish, publishData)) => + val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + context.children.size) + if (!data.activeProducers.contains(publish.topicName)) { + val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] + import context.executionContext + reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) + context.watchWith( + context.spawn(Producer(publish, publishData, reply, data.producerPacketRouter, data.settings), + producerName), + ProducerFree(publish.topicName) + ) + clientConnected(data.copy(activeProducers = data.activeProducers + publish.topicName)) + } else { + clientConnected( + data.copy(pendingLocalPublications = data.pendingLocalPublications :+ (publish.topicName -> prl)) + ) } case (context, ProducerFree(topicName)) => val i = data.pendingLocalPublications.indexWhere(_._1 == topicName) if (i >= 0) { val prl = data.pendingLocalPublications(i)._2 - val producerName = ActorName.mkName(ProducerNamePrefix + topicName) + val producerName = ActorName.mkName(ProducerNamePrefix + topicName + context.children.size) val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] import context.executionContext reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) @@ -517,7 +525,7 @@ import scala.util.{Failure, Success} ) ) } else { - clientConnected(data) + clientConnected(data.copy(activeProducers = data.activeProducers - topicName)) } case (_, ReceivedProducerPublishingCommand(command)) => command.runWith(Sink.foreach { @@ -534,6 +542,8 @@ import scala.util.{Failure, Success} clientDisconnected( Disconnected( data.publishers, + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, data.consumerPacketRouter, @@ -548,6 +558,8 @@ import scala.util.{Failure, Success} clientDisconnected( Disconnected( data.publishers, + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, data.consumerPacketRouter, @@ -565,6 +577,8 @@ import scala.util.{Failure, Success} connect, local, Set.empty, + Set.empty, + Set.empty, Vector.empty, Vector.empty, Vector.empty, @@ -581,6 +595,8 @@ import scala.util.{Failure, Success} connect, local, data.publishers, + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, Vector.empty, @@ -616,6 +632,8 @@ import scala.util.{Failure, Success} data.remote, data.publishers ++ (if (t.failure.contains(Publisher.SubscribeFailed)) Vector.empty else data.subscribe.topicFilters.map(_._1)), + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, data.consumerPacketRouter, @@ -644,6 +662,8 @@ import scala.util.{Failure, Success} connect, local, Set.empty, + Set.empty, + Set.empty, Vector.empty, Vector.empty, Vector.empty, @@ -660,6 +680,8 @@ import scala.util.{Failure, Success} connect, local, data.publishers, + data.activeConsumers, + data.activeProducers, data.pendingLocalPublications, data.pendingRemotePublications, Vector.empty, From ba072d2045953df59850ce7062147944ebe949c6 Mon Sep 17 00:00:00 2001 From: Christopher Hunt Date: Mon, 12 Nov 2018 19:16:15 +0100 Subject: [PATCH 3/6] Handle a duplicate publish while consuming Duplicate publish events received from a remote were previously stashed until any existing handling was complete. This commit changes that so that they are routed and handled. The commit also tightens up some exception handling around benign exceptions that the developer should not need to consider. --- .../mqtt/streaming/impl/RequestState.scala | 18 +++++ .../mqtt/streaming/scaladsl/MqttSession.scala | 54 ++++++++++++--- .../scala/docs/scaladsl/MqttSessionSpec.scala | 69 +++++++++++++++++++ 3 files changed, 132 insertions(+), 9 deletions(-) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala index 9de39d6e8d..104c50fc46 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala @@ -175,6 +175,11 @@ import scala.util.{Failure, Success} */ case object ConsumeFailed extends Exception with NoStackTrace + /* + * A consume is active while a duplicate publish was received. + */ + case object ConsumeActive extends Exception with NoStackTrace + /* * Construct with the starting state */ @@ -213,6 +218,7 @@ import scala.util.{Failure, Success} case object ReceivePubRelTimeout extends Event final case class PubCompReceivedLocally(remote: Promise[ForwardPubComp.type]) extends Event case object ReceivePubCompTimeout extends Event + final case class DupPublishReceivedFromRemote(local: Promise[ForwardPublish.type]) extends Event sealed abstract class Command case object ForwardPublish extends Command @@ -236,6 +242,9 @@ import scala.util.{Failure, Success} case RegisteredPacketId => data.local.success(ForwardPublish) consumeUnacknowledged(ClientConsuming(data.publish, data.packetId, data.packetRouter, data.settings)) + case _: DupPublishReceivedFromRemote => + data.local.failure(ConsumeActive) + throw ConsumeActive case UnobtainablePacketId => data.local.failure(ConsumeFailed) throw ConsumeFailed @@ -253,6 +262,9 @@ import scala.util.{Failure, Success} case PubRecReceivedLocally(remote) if data.publish.flags.contains(ControlPacketFlags.QoSExactlyOnceDelivery) => remote.success(ForwardPubRec) consumeReceived(data) + case DupPublishReceivedFromRemote(local) => + local.success(ForwardPublish) + consumeUnacknowledged(data) case ReceivePubAckRecTimeout => throw ConsumeFailed } @@ -270,6 +282,9 @@ import scala.util.{Failure, Success} case PubRelReceivedFromRemote(local) => local.success(ForwardPubRel) consumeAcknowledged(data) + case DupPublishReceivedFromRemote(local) => + local.success(ForwardPublish) + consumeUnacknowledged(data) case ReceivePubRelTimeout => throw ConsumeFailed } @@ -287,6 +302,9 @@ import scala.util.{Failure, Success} case PubCompReceivedLocally(remote) => remote.success(ForwardPubComp) Behaviors.stopped + case DupPublishReceivedFromRemote(local) => + local.success(ForwardPublish) + consumeUnacknowledged(data) case ReceivePubCompTimeout => throw ConsumeFailed } diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala index 0c71ba4d4b..c7d81f0f3d 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala @@ -9,7 +9,7 @@ import java.util.concurrent.atomic.AtomicLong import akka.{NotUsed, actor => untyped} import akka.actor.typed.scaladsl.adapter._ -import akka.stream.{Materializer, OverflowStrategy} +import akka.stream.{ActorAttributes, Materializer, OverflowStrategy, Supervision} import akka.stream.alpakka.mqtt.streaming.impl._ import akka.stream.scaladsl.{BroadcastHub, Flow, Keep, Source} import akka.util.ByteString @@ -127,7 +127,7 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: override def ![A](cp: Command[A]): Unit = cp match { case Command(cp: Publish, carry) => clientConnector ! ClientConnector.PublishReceivedLocally(cp, carry) - case c: Command[_] => throw new IllegalStateException(c + " is not a client command that can be sent directly") + case c: Command[A] => throw new IllegalStateException(c + " is not a client command that can be sent directly") } override def shutdown(): Unit = { @@ -141,7 +141,7 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: private val pingReqBytes = PingReq.encode(ByteString.newBuilder).result() override def commandFlow[A]: CommandFlow[A] = - Flow[Command[_]] + Flow[Command[A]] .watch(clientConnector.toUntyped) .watchTermination() { case (_, terminated) => @@ -189,9 +189,16 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: val reply = Promise[ClientConnector.ForwardDisconnect.type] clientConnector ! ClientConnector.DisconnectReceivedLocally(reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) - case c: Command[_] => throw new IllegalStateException(c + " is not a client command") + case c: Command[A] => throw new IllegalStateException(c + " is not a client command") } ) + .withAttributes(ActorAttributes.supervisionStrategy { + // Benign exceptions + case RemotePacketRouter.CannotRoute => + Supervision.Resume + case _ => + Supervision.Stop + }) override def eventFlow[A]: EventFlow[A] = Flow[ByteString] @@ -203,7 +210,7 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: } .via(new MqttFrameStage(settings.maxPacketSize)) .map(_.iterator.decodeControlPacket(settings.maxPacketSize)) - .mapAsync(settings.eventParallelism) { + .mapAsync[Either[MqttCodec.DecodeError, Event[A]]](settings.eventParallelism) { case Right(cp: ConnAck) => val reply = Promise[ClientConnector.ForwardConnAck] clientConnector ! ClientConnector.ConnAckReceivedFromRemote(cp, reply) @@ -226,6 +233,10 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: reply.future.map { case Unsubscriber.ForwardUnsubAck(carry: Option[A] @unchecked) => Right(Event(cp, carry)) } + case Right(cp @ Publish(flags, _, Some(packetId), _)) if flags.contains(ControlPacketFlags.DUP) => + val reply = Promise[Consumer.ForwardPublish.type] + consumerPacketRouter ! RemotePacketRouter.Route(packetId, Consumer.DupPublishReceivedFromRemote(reply), reply) + reply.future.map(_ => Right(Event(cp))) case Right(cp: Publish) => val reply = Promise[Consumer.ForwardPublish.type] clientConnector ! ClientConnector.PublishReceivedFromRemote(cp, reply) @@ -259,6 +270,13 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: case Right(cp) => Future.failed(new IllegalStateException(cp + " is not a client event")) case Left(de) => Future.successful(Left(de)) } + .withAttributes(ActorAttributes.supervisionStrategy { + // Benign exceptions + case Consumer.ConsumeActive | LocalPacketRouter.CannotRoute | RemotePacketRouter.CannotRoute => + Supervision.Resume + case _ => + Supervision.Stop + }) } object MqttServerSession { @@ -359,7 +377,7 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: override def ![A](cp: Command[A]): Unit = cp match { case Command(cp: Publish, carry) => serverConnector ! ServerConnector.PublishReceivedLocally(cp, carry) - case c: Command[_] => throw new IllegalStateException(c + " is not a server command that can be sent directly") + case c: Command[A] => throw new IllegalStateException(c + " is not a server command that can be sent directly") } override def shutdown(): Unit = { @@ -374,7 +392,7 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: private val pingRespBytes = PingResp.encode(ByteString.newBuilder).result() override def commandFlow[A](connectionId: ByteString): CommandFlow[A] = - Flow[Command[_]] + Flow[Command[A]] .watch(serverConnector.toUntyped) .watchTermination() { case (_, terminated) => @@ -422,9 +440,16 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: val reply = Promise[Consumer.ForwardPubComp.type] consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubCompReceivedLocally(reply), reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) - case c: Command[_] => throw new IllegalStateException(c + " is not a server command") + case c: Command[A] => throw new IllegalStateException(c + " is not a server command") } ) + .withAttributes(ActorAttributes.supervisionStrategy { + // Benign exceptions + case RemotePacketRouter.CannotRoute => + Supervision.Resume + case _ => + Supervision.Stop + }) override def eventFlow[A](connectionId: ByteString): EventFlow[A] = Flow[ByteString] @@ -436,7 +461,7 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: } .via(new MqttFrameStage(settings.maxPacketSize)) .map(_.iterator.decodeControlPacket(settings.maxPacketSize)) - .mapAsync(settings.eventParallelism) { + .mapAsync[Either[MqttCodec.DecodeError, Event[A]]](settings.eventParallelism) { case Right(cp: Connect) => val reply = Promise[ClientConnection.ForwardConnect.type] serverConnector ! ServerConnector.ConnectReceivedFromRemote(connectionId, cp, reply) @@ -449,6 +474,10 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: val reply = Promise[Unpublisher.ForwardUnsubscribe.type] serverConnector ! ServerConnector.UnsubscribeReceivedFromRemote(connectionId, cp, reply) reply.future.map(_ => Right(Event(cp))) + case Right(cp @ Publish(flags, _, Some(packetId), _)) if flags.contains(ControlPacketFlags.DUP) => + val reply = Promise[Consumer.ForwardPublish.type] + consumerPacketRouter ! RemotePacketRouter.Route(packetId, Consumer.DupPublishReceivedFromRemote(reply), reply) + reply.future.map(_ => Right(Event(cp))) case Right(cp: Publish) => val reply = Promise[Consumer.ForwardPublish.type] serverConnector ! ServerConnector.PublishReceivedFromRemote(connectionId, cp, reply) @@ -486,4 +515,11 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: case Right(cp) => Future.failed(new IllegalStateException(cp + " is not a server event")) case Left(de) => Future.successful(Left(de)) } + .withAttributes(ActorAttributes.supervisionStrategy { + // Benign exceptions + case Consumer.ConsumeActive | LocalPacketRouter.CannotRoute | RemotePacketRouter.CannotRoute => + Supervision.Resume + case _ => + Supervision.Stop + }) } diff --git a/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala b/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala index 8ce1cbe0c6..463b55bccc 100644 --- a/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala +++ b/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala @@ -1086,6 +1086,75 @@ class MqttSessionSpec "re-connect given connect, subscribe, connect again, publish" in reconnectTest(explicitDisconnect = false) + + "receive a duplicate publish" in { + val session = ActorMqttServerSession(settings) + + val client = TestProbe() + val toClient = Sink.foreach[ByteString](bytes => client.ref ! bytes) + val (fromClientQueue, fromClient) = Source + .queue[ByteString](1, OverflowStrategy.dropHead) + .toMat(BroadcastHub.sink)(Keep.both) + .run() + + val pipeToClient = Flow.fromSinkAndSource(toClient, fromClient) + + val connect = Connect("some-client-id", ConnectFlags.None) + val connectReceived = Promise[Done] + + val publish = Publish("some-topic", ByteString("some-payload")) + val publishReceived = Promise[Done] + val dupPublishReceived = Promise[Done] + + val server = + Source + .queue[Command[Nothing]](1, OverflowStrategy.fail) + .via( + Mqtt + .serverSessionFlow(session, ByteString.empty) + .join(pipeToClient) + ) + .wireTap(Sink.foreach[Either[DecodeError, Event[_]]] { + case Right(Event(`connect`, _)) => + connectReceived.success(Done) + case Right(Event(cp: Publish, _)) if cp.flags.contains(ControlPacketFlags.DUP) => + dupPublishReceived.success(Done) + case Right(Event(_: Publish, _)) => + publishReceived.success(Done) + }) + .toMat(Sink.collection)(Keep.left) + .run + + val connectBytes = connect.encode(ByteString.newBuilder).result() + val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) + val connAckBytes = connAck.encode(ByteString.newBuilder).result() + + val publishBytes = publish.encode(ByteString.newBuilder, Some(PacketId(1))).result() + val dupPublishBytes = publish + .copy(flags = publish.flags | ControlPacketFlags.DUP) + .encode(ByteString.newBuilder, Some(PacketId(1))) + .result() + val pubAck = PubAck(PacketId(1)) + val pubAckBytes = pubAck.encode(ByteString.newBuilder).result() + + fromClientQueue.offer(connectBytes) + + connectReceived.future.futureValue shouldBe Done + + server.offer(Command(connAck)) + client.expectMsg(connAckBytes) + + fromClientQueue.offer(publishBytes) + + publishReceived.future.futureValue shouldBe Done + + fromClientQueue.offer(dupPublishBytes) + + dupPublishReceived.future.futureValue shouldBe Done + + server.offer(Command(pubAck)) + client.expectMsg(pubAckBytes) + } } override def afterAll: Unit = From 861d3c607fe5bb4e1567482ca360ff429c761c37 Mon Sep 17 00:00:00 2001 From: Christopher Hunt Date: Wed, 14 Nov 2018 06:54:02 +0100 Subject: [PATCH 4/6] Miscellaneous PR feedback In particular, actor names need to avoid the sequence number of an actor conflicting with MQTT topic names. --- .../mqtt/streaming/impl/ClientState.scala | 8 ++-- .../mqtt/streaming/impl/ServerState.scala | 8 ++-- .../scala/docs/scaladsl/MqttSessionSpec.scala | 40 +++++++++---------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala index d1a04593a7..1b56ea30c5 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala @@ -289,7 +289,7 @@ import scala.util.{Failure, Success} serverConnected(data) case (context, prfr @ PublishReceivedFromRemote(publish @ Publish(_, topicName, Some(packetId), _), local)) => if (!data.activeConsumers.contains(topicName)) { - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size) context.watchWith( context.spawn( Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), @@ -307,7 +307,7 @@ import scala.util.{Failure, Success} val i = data.pendingRemotePublications.indexWhere(_._1 == topicName) if (i >= 0) { val prfr = data.pendingRemotePublications(i)._2 - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size) context.watchWith( context.spawn( Consumer(prfr.publish, @@ -333,7 +333,7 @@ import scala.util.{Failure, Success} data.remote.offer(ForwardPublish(publish, None)) serverConnected(data) case (context, prl @ PublishReceivedLocally(publish, publishData)) => - val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + context.children.size) + val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size) if (!data.activeProducers.contains(publish.topicName)) { val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] import context.executionContext @@ -353,7 +353,7 @@ import scala.util.{Failure, Success} val i = data.pendingLocalPublications.indexWhere(_._1 == topicName) if (i >= 0) { val prl = data.pendingLocalPublications(i)._2 - val producerName = ActorName.mkName(ProducerNamePrefix + topicName + context.children.size) + val producerName = ActorName.mkName(ProducerNamePrefix + topicName + "-" + context.children.size) val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] import context.executionContext reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala index a8ba57e02b..5cdde19a2a 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala @@ -443,7 +443,7 @@ import scala.util.{Failure, Success} clientConnected(data) case (context, prfr @ PublishReceivedFromRemote(publish @ Publish(_, topicName, Some(packetId), _), local)) => if (!data.activeConsumers.contains(topicName)) { - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size) context.watchWith( context.spawn( Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), @@ -461,7 +461,7 @@ import scala.util.{Failure, Success} val i = data.pendingRemotePublications.indexWhere(_._1 == topicName) if (i >= 0) { val prfr = data.pendingRemotePublications(i)._2 - val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + context.children.size) + val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size) context.watchWith( context.spawn( Consumer(prfr.publish, @@ -487,7 +487,7 @@ import scala.util.{Failure, Success} data.remote.offer(ForwardPublish(publish, None)) clientConnected(data) case (context, prl @ PublishReceivedLocally(publish, publishData)) => - val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + context.children.size) + val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size) if (!data.activeProducers.contains(publish.topicName)) { val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] import context.executionContext @@ -507,7 +507,7 @@ import scala.util.{Failure, Success} val i = data.pendingLocalPublications.indexWhere(_._1 == topicName) if (i >= 0) { val prl = data.pendingLocalPublications(i)._2 - val producerName = ActorName.mkName(ProducerNamePrefix + topicName + context.children.size) + val producerName = ActorName.mkName(ProducerNamePrefix + topicName + "-" + context.children.size) val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]] import context.executionContext reply.future.foreach(command => context.self ! ReceivedProducerPublishingCommand(command)) diff --git a/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala b/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala index 463b55bccc..616dc0e79a 100644 --- a/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala +++ b/mqtt-streaming/src/test/scala/docs/scaladsl/MqttSessionSpec.scala @@ -57,7 +57,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.collection)(Keep.both) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -176,7 +176,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -211,7 +211,7 @@ class MqttSessionSpec ) .drop(2) .toMat(Sink.head)(Keep.both) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -260,7 +260,7 @@ class MqttSessionSpec } .wireTap(_ => publishReceived.success(Done)) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -323,7 +323,7 @@ class MqttSessionSpec } } .toMat(Sink.ignore)(Keep.left) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -382,7 +382,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -418,7 +418,7 @@ class MqttSessionSpec ) .drop(1) .toMat(Sink.head)(Keep.both) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -459,7 +459,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -501,7 +501,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -544,7 +544,7 @@ class MqttSessionSpec ) .drop(2) .toMat(Sink.head)(Keep.both) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -592,7 +592,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None).copy(keepAlive = 200.millis.dilated) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -631,7 +631,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.both) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None).copy(keepAlive = 100.millis.dilated) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -666,7 +666,7 @@ class MqttSessionSpec ) .drop(1) .toMat(Sink.head)(Keep.both) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) val connectBytes = connect.encode(ByteString.newBuilder).result() @@ -707,7 +707,7 @@ class MqttSessionSpec .join(pipeToServer) ) .toMat(Sink.ignore)(Keep.both) - .run + .run() val connect = Connect("some-client-id", ConnectFlags.None) @@ -763,7 +763,7 @@ class MqttSessionSpec unsubscribeReceived.success(Done) }) .toMat(Sink.collection)(Keep.both) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -851,7 +851,7 @@ class MqttSessionSpec case _ => }) .toMat(Sink.collection)(Keep.both) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -905,7 +905,7 @@ class MqttSessionSpec case _ => }) .toMat(Sink.ignore)(Keep.both) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -955,7 +955,7 @@ class MqttSessionSpec case _ => }) .toMat(Sink.ignore)(Keep.left) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -1027,7 +1027,7 @@ class MqttSessionSpec }) .drop(4) .toMat(Sink.head)(Keep.left) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) @@ -1123,7 +1123,7 @@ class MqttSessionSpec publishReceived.success(Done) }) .toMat(Sink.collection)(Keep.left) - .run + .run() val connectBytes = connect.encode(ByteString.newBuilder).result() val connAck = ConnAck(ConnAckFlags.None, ConnAckReturnCode.ConnectionAccepted) From 6da7a7a48cc92a13f4fabb7d0148ea72b53acb94 Mon Sep 17 00:00:00 2001 From: Christopher Hunt Date: Thu, 15 Nov 2018 16:13:26 +0100 Subject: [PATCH 5/6] Connection packet id distinction Packet ids were previously not distinguished according to their connection. This caused a problem when there was more than one client issuing the same packet ids! This commit addresses the problem by permitting the routing of remotely received packet ids by connection id and client id. --- .../mqtt/streaming/impl/ClientState.scala | 3 +- .../mqtt/streaming/impl/RequestState.scala | 76 ++++++++++++++----- .../mqtt/streaming/impl/ServerState.scala | 66 +++++++++++----- .../mqtt/streaming/scaladsl/MqttSession.scala | 62 +++++++++++---- .../streaming/impl/RequestStateSpec.scala | 26 +++++-- 5 files changed, 174 insertions(+), 59 deletions(-) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala index 1b56ea30c5..ccffc942ce 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ClientState.scala @@ -292,7 +292,7 @@ import scala.util.{Failure, Success} val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size) context.watchWith( context.spawn( - Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), + Consumer(publish, None, packetId, local, data.consumerPacketRouter, data.settings), consumerName ), ConsumerFree(publish.topicName) @@ -311,6 +311,7 @@ import scala.util.{Failure, Success} context.watchWith( context.spawn( Consumer(prfr.publish, + None, prfr.publish.packetId.get, prfr.local, data.consumerPacketRouter, diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala index 104c50fc46..f54230df47 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestState.scala @@ -11,6 +11,7 @@ import akka.actor.typed.scaladsl.Behaviors import akka.annotation.InternalApi import akka.stream.{Materializer, OverflowStrategy} import akka.stream.scaladsl.{BroadcastHub, Keep, Source, SourceQueueWithComplete} +import akka.util.ByteString import scala.concurrent.Promise import scala.util.control.NoStackTrace @@ -184,29 +185,33 @@ import scala.util.{Failure, Success} * Construct with the starting state */ def apply(publish: Publish, + clientId: Option[String], packetId: PacketId, local: Promise[ForwardPublish.type], packetRouter: ActorRef[RemotePacketRouter.Request[Event]], settings: MqttSessionSettings): Behavior[Event] = - prepareClientConsumption(Start(publish, packetId, local, packetRouter, settings)) + prepareClientConsumption(Start(publish, clientId, packetId, local, packetRouter, settings)) // Our FSM data, FSM events and commands emitted by the FSM sealed abstract class Data(val publish: Publish, + val clientId: Option[String], val packetId: PacketId, val packetRouter: ActorRef[RemotePacketRouter.Request[Event]], val settings: MqttSessionSettings) final case class Start(override val publish: Publish, + override val clientId: Option[String], override val packetId: PacketId, local: Promise[ForwardPublish.type], override val packetRouter: ActorRef[RemotePacketRouter.Request[Event]], override val settings: MqttSessionSettings) - extends Data(publish, packetId, packetRouter, settings) + extends Data(publish, clientId, packetId, packetRouter, settings) final case class ClientConsuming(override val publish: Publish, + override val clientId: Option[String], override val packetId: PacketId, override val packetRouter: ActorRef[RemotePacketRouter.Request[Event]], override val settings: MqttSessionSettings) - extends Data(publish, packetId, packetRouter, settings) + extends Data(publish, clientId, packetId, packetRouter, settings) sealed abstract class Event final case object RegisteredPacketId extends Event @@ -231,7 +236,7 @@ import scala.util.{Failure, Success} def prepareClientConsumption(data: Start): Behavior[Event] = Behaviors.setup { context => val reply = Promise[RemotePacketRouter.Registered.type] - data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.packetId, reply) + data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.clientId, data.packetId, reply) import context.executionContext reply.future.onComplete { case Success(RemotePacketRouter.Registered) => context.self ! RegisteredPacketId @@ -241,7 +246,9 @@ import scala.util.{Failure, Success} Behaviors.receiveMessagePartial[Event] { case RegisteredPacketId => data.local.success(ForwardPublish) - consumeUnacknowledged(ClientConsuming(data.publish, data.packetId, data.packetRouter, data.settings)) + consumeUnacknowledged( + ClientConsuming(data.publish, data.clientId, data.packetId, data.packetRouter, data.settings) + ) case _: DupPublishReceivedFromRemote => data.local.failure(ConsumeActive) throw ConsumeActive @@ -270,7 +277,7 @@ import scala.util.{Failure, Success} } .receiveSignal { case (_, PostStop) => - data.packetRouter ! RemotePacketRouter.Unregister(data.packetId) + data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId) Behaviors.same } } @@ -290,7 +297,7 @@ import scala.util.{Failure, Success} } .receiveSignal { case (_, PostStop) => - data.packetRouter ! RemotePacketRouter.Unregister(data.packetId) + data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId) Behaviors.same } } @@ -310,7 +317,7 @@ import scala.util.{Failure, Success} } .receiveSignal { case (_, PostStop) => - data.packetRouter ! RemotePacketRouter.Unregister(data.packetId) + data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId) Behaviors.same } } @@ -399,10 +406,21 @@ import scala.util.{Failure, Success} // Requests sealed abstract class Request[A] - final case class Register[A](registrant: ActorRef[A], packetId: PacketId, reply: Promise[Registered.type]) + final case class Register[A](registrant: ActorRef[A], + clientId: Option[String], + packetId: PacketId, + reply: Promise[Registered.type]) + extends Request[A] + final case class RegisterConnection[A](connectionId: ByteString, clientId: String) extends Request[A] + final case class Unregister[A](clientId: Option[String], packetId: PacketId) extends Request[A] + final case class UnregisterConnection[A](connectionId: ByteString) extends Request[A] + final case class Route[A](clientId: Option[String], packetId: PacketId, event: A, failureReply: Promise[_]) + extends Request[A] + final case class RouteViaConnection[A](connectionId: ByteString, + packetId: PacketId, + event: A, + failureReply: Promise[_]) extends Request[A] - final case class Unregister[A](packetId: PacketId) extends Request[A] - final case class Route[A](packetId: PacketId, event: A, failureReply: Promise[_]) extends Request[A] // Replies @@ -413,7 +431,7 @@ import scala.util.{Failure, Success} * Construct with the starting state */ def apply[A]: Behavior[Request[A]] = - new RemotePacketRouter[A].main(Map.empty) + new RemotePacketRouter[A].main(Map.empty, Map.empty) } /* @@ -427,18 +445,38 @@ import scala.util.{Failure, Success} // Processing - def main(registrantsByPacketId: Map[PacketId, ActorRef[A]]): Behavior[Request[A]] = + def main(registrantsByPacketId: Map[(Option[String], PacketId), ActorRef[A]], + clientIdsByConnectionId: Map[ByteString, String]): Behavior[Request[A]] = Behaviors.receiveMessage { - case Register(registrant: ActorRef[A], packetId, reply) => + case Register(registrant: ActorRef[A], clientId, packetId, reply) => reply.success(Registered) - main(registrantsByPacketId + (packetId -> registrant)) - case Unregister(packetId) => - main(registrantsByPacketId - packetId) - case Route(packetId, event, failureReply) => - registrantsByPacketId.get(packetId) match { + val key = (clientId, packetId) + main(registrantsByPacketId + (key -> registrant), clientIdsByConnectionId) + case RegisterConnection(connectionId, clientId) => + main(registrantsByPacketId, clientIdsByConnectionId + (connectionId -> clientId)) + case Unregister(clientId, packetId) => + val key = (clientId, packetId) + main(registrantsByPacketId - key, clientIdsByConnectionId) + case UnregisterConnection(connectionId) => + main(registrantsByPacketId, clientIdsByConnectionId - connectionId) + case Route(clientId, packetId, event, failureReply) => + val key = (clientId, packetId) + registrantsByPacketId.get(key) match { case Some(reply) => reply ! event case None => failureReply.failure(CannotRoute) } Behaviors.same + case RouteViaConnection(connectionId, packetId, event, failureReply) => + clientIdsByConnectionId.get(connectionId) match { + case clientId: Some[String] => + val key = (clientId, packetId) + registrantsByPacketId.get(key) match { + case Some(reply) => reply ! event + case None => failureReply.failure(CannotRoute) + } + case None => + failureReply.failure(CannotRoute) + } + Behaviors.same } } diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala index 5cdde19a2a..509aa356f2 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala @@ -138,6 +138,9 @@ import scala.util.{Failure, Success} cc } context.watch(clientConnection) + data.consumerPacketRouter ! RemotePacketRouter.RegisterConnection(connectionId, connect.clientId) + data.publisherPacketRouter ! RemotePacketRouter.RegisterConnection(connectionId, connect.clientId) + data.unpublisherPacketRouter ! RemotePacketRouter.RegisterConnection(connectionId, connect.clientId) val newConnection = (connectionId, (connect.clientId, clientConnection)) listening( data.copy( @@ -174,6 +177,9 @@ import scala.util.{Failure, Success} case Some((connectionId, (clientId, _))) => if (t.failure.contains(ClientConnection.ClientConnectionFailed)) data.terminations.offer(ClientSessionTerminated(clientId)) + data.consumerPacketRouter ! RemotePacketRouter.UnregisterConnection(connectionId) + data.publisherPacketRouter ! RemotePacketRouter.UnregisterConnection(connectionId) + data.unpublisherPacketRouter ! RemotePacketRouter.UnregisterConnection(connectionId) listening(data.copy(clientConnections = data.clientConnections - connectionId)) case None => Behaviors.same @@ -394,7 +400,11 @@ import scala.util.{Failure, Success} context.child(publisherName) match { case None => val publisher = context.spawn( - Publisher(subscribe.packetId, local, data.publisherPacketRouter, data.settings), + Publisher(data.connect.clientId, + subscribe.packetId, + local, + data.publisherPacketRouter, + data.settings), publisherName ) context.watch(publisher) @@ -425,7 +435,11 @@ import scala.util.{Failure, Success} context.child(unpublisherName) match { case None => val unpublisher = context.spawn( - Unpublisher(unsubscribe.packetId, local, data.unpublisherPacketRouter, data.settings), + Unpublisher(data.connect.clientId, + unsubscribe.packetId, + local, + data.unpublisherPacketRouter, + data.settings), unpublisherName ) context.watchWith(unpublisher, UnpublisherFree(unsubscribe.topicFilters)) @@ -446,7 +460,12 @@ import scala.util.{Failure, Success} val consumerName = ActorName.mkName(ConsumerNamePrefix + topicName + "-" + context.children.size) context.watchWith( context.spawn( - Consumer(publish, packetId, local, data.consumerPacketRouter, data.settings), + Consumer(publish, + Some(data.connect.clientId), + packetId, + local, + data.consumerPacketRouter, + data.settings), consumerName ), ConsumerFree(publish.topicName) @@ -465,6 +484,7 @@ import scala.util.{Failure, Success} context.watchWith( context.spawn( Consumer(prfr.publish, + Some(data.connect.clientId), prfr.publish.packetId.get, prfr.local, data.consumerPacketRouter, @@ -756,24 +776,27 @@ import scala.util.{Failure, Success} /* * Construct with the starting state */ - def apply(packetId: PacketId, + def apply(clientId: String, + packetId: PacketId, local: Promise[ForwardSubscribe.type], packetRouter: ActorRef[RemotePacketRouter.Request[Event]], settings: MqttSessionSettings): Behavior[Event] = - preparePublisher(Start(packetId, local, packetRouter, settings)) + preparePublisher(Start(Some(clientId), packetId, local, packetRouter, settings)) // Our FSM data, FSM events and commands emitted by the FSM - sealed abstract class Data(val packetId: PacketId, val settings: MqttSessionSettings) - final case class Start(override val packetId: PacketId, + sealed abstract class Data(val clientId: Some[String], val packetId: PacketId, val settings: MqttSessionSettings) + final case class Start(override val clientId: Some[String], + override val packetId: PacketId, local: Promise[ForwardSubscribe.type], packetRouter: ActorRef[RemotePacketRouter.Request[Event]], override val settings: MqttSessionSettings) - extends Data(packetId, settings) - final case class ServerSubscribe(override val packetId: PacketId, + extends Data(clientId, packetId, settings) + final case class ServerSubscribe(override val clientId: Some[String], + override val packetId: PacketId, packetRouter: ActorRef[RemotePacketRouter.Request[Event]], override val settings: MqttSessionSettings) - extends Data(packetId, settings) + extends Data(clientId, packetId, settings) sealed abstract class Event final case object RegisteredPacketId extends Event @@ -789,7 +812,7 @@ import scala.util.{Failure, Success} def preparePublisher(data: Start): Behavior[Event] = Behaviors.setup { context => val reply = Promise[RemotePacketRouter.Registered.type] - data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.packetId, reply) + data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.clientId, data.packetId, reply) import context.executionContext reply.future.onComplete { case Success(RemotePacketRouter.Registered) => context.self ! RegisteredPacketId @@ -799,7 +822,7 @@ import scala.util.{Failure, Success} Behaviors.receiveMessagePartial[Event] { case RegisteredPacketId => data.local.success(ForwardSubscribe) - serverSubscribe(ServerSubscribe(data.packetId, data.packetRouter, data.settings)) + serverSubscribe(ServerSubscribe(data.clientId, data.packetId, data.packetRouter, data.settings)) case UnobtainablePacketId => data.local.failure(SubscribeFailed) throw SubscribeFailed @@ -820,7 +843,7 @@ import scala.util.{Failure, Success} } .receiveSignal { case (_, PostStop) => - data.packetRouter ! RemotePacketRouter.Unregister(data.packetId) + data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId) Behaviors.same } } @@ -840,21 +863,24 @@ import scala.util.{Failure, Success} /* * Construct with the starting state */ - def apply(packetId: PacketId, + def apply(clientId: String, + packetId: PacketId, local: Promise[ForwardUnsubscribe.type], packetRouter: ActorRef[RemotePacketRouter.Request[Event]], settings: MqttSessionSettings): Behavior[Event] = - prepareServerUnpublisher(Start(packetId, local, packetRouter, settings)) + prepareServerUnpublisher(Start(Some(clientId), packetId, local, packetRouter, settings)) // Our FSM data, FSM events and commands emitted by the FSM sealed abstract class Data(val settings: MqttSessionSettings) - final case class Start(packetId: PacketId, + final case class Start(clientId: Some[String], + packetId: PacketId, local: Promise[ForwardUnsubscribe.type], packetRouter: ActorRef[RemotePacketRouter.Request[Event]], override val settings: MqttSessionSettings) extends Data(settings) - final case class ServerUnsubscribe(packetId: PacketId, + final case class ServerUnsubscribe(clientId: Some[String], + packetId: PacketId, packetRouter: ActorRef[RemotePacketRouter.Request[Event]], override val settings: MqttSessionSettings) extends Data(settings) @@ -873,7 +899,7 @@ import scala.util.{Failure, Success} def prepareServerUnpublisher(data: Start): Behavior[Event] = Behaviors.setup { context => val reply = Promise[RemotePacketRouter.Registered.type] - data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.packetId, reply) + data.packetRouter ! RemotePacketRouter.Register(context.self.upcast, data.clientId, data.packetId, reply) import context.executionContext reply.future.onComplete { case Success(RemotePacketRouter.Registered) => context.self ! RegisteredPacketId @@ -883,7 +909,7 @@ import scala.util.{Failure, Success} Behaviors.receiveMessagePartial[Event] { case RegisteredPacketId => data.local.success(ForwardUnsubscribe) - serverUnsubscribe(ServerUnsubscribe(data.packetId, data.packetRouter, data.settings)) + serverUnsubscribe(ServerUnsubscribe(data.clientId, data.packetId, data.packetRouter, data.settings)) case UnobtainablePacketId => data.local.failure(UnsubscribeFailed) throw UnsubscribeFailed @@ -903,7 +929,7 @@ import scala.util.{Failure, Success} } .receiveSignal { case (_, PostStop) => - data.packetRouter ! RemotePacketRouter.Unregister(data.packetId) + data.packetRouter ! RemotePacketRouter.Unregister(data.clientId, data.packetId) Behaviors.same } } diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala index c7d81f0f3d..33143bd2ba 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/scaladsl/MqttSession.scala @@ -167,15 +167,24 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: ) case Command(cp: PubAck, _) => val reply = Promise[Consumer.ForwardPubAck.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubAckReceivedLocally(reply), reply) + consumerPacketRouter ! RemotePacketRouter.Route(None, + cp.packetId, + Consumer.PubAckReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: PubRec, _) => val reply = Promise[Consumer.ForwardPubRec.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubRecReceivedLocally(reply), reply) + consumerPacketRouter ! RemotePacketRouter.Route(None, + cp.packetId, + Consumer.PubRecReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: PubComp, _) => val reply = Promise[Consumer.ForwardPubComp.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubCompReceivedLocally(reply), reply) + consumerPacketRouter ! RemotePacketRouter.Route(None, + cp.packetId, + Consumer.PubCompReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: Subscribe, carry) => val reply = Promise[Subscriber.ForwardSubscribe] @@ -235,7 +244,10 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: } case Right(cp @ Publish(flags, _, Some(packetId), _)) if flags.contains(ControlPacketFlags.DUP) => val reply = Promise[Consumer.ForwardPublish.type] - consumerPacketRouter ! RemotePacketRouter.Route(packetId, Consumer.DupPublishReceivedFromRemote(reply), reply) + consumerPacketRouter ! RemotePacketRouter.Route(None, + packetId, + Consumer.DupPublishReceivedFromRemote(reply), + reply) reply.future.map(_ => Right(Event(cp))) case Right(cp: Publish) => val reply = Promise[Consumer.ForwardPublish.type] @@ -255,7 +267,10 @@ final class ActorMqttClientSession(settings: MqttSessionSettings)(implicit mat: } case Right(cp: PubRel) => val reply = Promise[Consumer.ForwardPubRel.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubRelReceivedFromRemote(reply), reply) + consumerPacketRouter ! RemotePacketRouter.Route(None, + cp.packetId, + Consumer.PubRelReceivedFromRemote(reply), + reply) reply.future.map(_ => Right(Event(cp))) case Right(cp: PubComp) => val reply = Promise[Producer.ForwardPubComp] @@ -420,25 +435,38 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: ) case Command(cp: SubAck, _) => val reply = Promise[Publisher.ForwardSubAck.type] - publisherPacketRouter ! RemotePacketRouter.Route(cp.packetId, Publisher.SubAckReceivedLocally(reply), reply) + publisherPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + cp.packetId, + Publisher.SubAckReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: UnsubAck, _) => val reply = Promise[Unpublisher.ForwardUnsubAck.type] - unpublisherPacketRouter ! RemotePacketRouter.Route(cp.packetId, - Unpublisher.UnsubAckReceivedLocally(reply), - reply) + unpublisherPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + cp.packetId, + Unpublisher.UnsubAckReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: PubAck, _) => val reply = Promise[Consumer.ForwardPubAck.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubAckReceivedLocally(reply), reply) + consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + cp.packetId, + Consumer.PubAckReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: PubRec, _) => val reply = Promise[Consumer.ForwardPubRec.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubRecReceivedLocally(reply), reply) + consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + cp.packetId, + Consumer.PubRecReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case Command(cp: PubComp, _) => val reply = Promise[Consumer.ForwardPubComp.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubCompReceivedLocally(reply), reply) + consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + cp.packetId, + Consumer.PubCompReceivedLocally(reply), + reply) Source.fromFuture(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())) case c: Command[A] => throw new IllegalStateException(c + " is not a server command") } @@ -476,7 +504,10 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: reply.future.map(_ => Right(Event(cp))) case Right(cp @ Publish(flags, _, Some(packetId), _)) if flags.contains(ControlPacketFlags.DUP) => val reply = Promise[Consumer.ForwardPublish.type] - consumerPacketRouter ! RemotePacketRouter.Route(packetId, Consumer.DupPublishReceivedFromRemote(reply), reply) + consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + packetId, + Consumer.DupPublishReceivedFromRemote(reply), + reply) reply.future.map(_ => Right(Event(cp))) case Right(cp: Publish) => val reply = Promise[Consumer.ForwardPublish.type] @@ -496,7 +527,10 @@ final class ActorMqttServerSession(settings: MqttSessionSettings)(implicit mat: } case Right(cp: PubRel) => val reply = Promise[Consumer.ForwardPubRel.type] - consumerPacketRouter ! RemotePacketRouter.Route(cp.packetId, Consumer.PubRelReceivedFromRemote(reply), reply) + consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(connectionId, + cp.packetId, + Consumer.PubRelReceivedFromRemote(reply), + reply) reply.future.map(_ => Right(Event(cp))) case Right(cp: PubComp) => val reply = Promise[Producer.ForwardPubComp] diff --git a/mqtt-streaming/src/test/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestStateSpec.scala b/mqtt-streaming/src/test/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestStateSpec.scala index dba7db4d91..a405dcd3c8 100644 --- a/mqtt-streaming/src/test/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestStateSpec.scala +++ b/mqtt-streaming/src/test/scala/akka/stream/alpakka/mqtt/streaming/impl/RequestStateSpec.scala @@ -6,6 +6,7 @@ package akka.stream.alpakka.mqtt.streaming package impl import akka.actor.testkit.typed.scaladsl.ActorTestKit +import akka.util.ByteString import org.scalatest.concurrent.ScalaFutures import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec} @@ -83,25 +84,40 @@ class RequestStateSpec extends WordSpec with Matchers with BeforeAndAfterAll wit "remote packet router" should { "route a packet" in { + val clientId = "some-client" val packetId = PacketId(1) + val connectionId = ByteString("some-connection") + val registrant = testKit.createTestProbe[String]() val registerReply = Promise[RemotePacketRouter.Registered.type]() val failureReply1 = Promise[String] val failureReply2 = Promise[String] + val failureReply3 = Promise[String] + val failureReply4 = Promise[String] val router = testKit.spawn(RemotePacketRouter[String]) - router ! RemotePacketRouter.Register(registrant.ref, packetId, registerReply) + router ! RemotePacketRouter.Register(registrant.ref, Some(clientId), packetId, registerReply) registerReply.future.futureValue shouldBe RemotePacketRouter.Registered - router ! RemotePacketRouter.Route(packetId, "some-packet", failureReply1) + router ! RemotePacketRouter.Route(Some(clientId), packetId, "some-packet", failureReply1) registrant.expectMessage("some-packet") failureReply1.future.isCompleted shouldBe false - router ! RemotePacketRouter.Unregister(packetId) - router ! RemotePacketRouter.Route(packetId, "some-packet", failureReply2) + router ! RemotePacketRouter.RegisterConnection(connectionId, clientId) + router ! RemotePacketRouter.RouteViaConnection(connectionId, packetId, "some-packet2", failureReply3) + registrant.expectMessage("some-packet2") + failureReply3.future.isCompleted shouldBe false + + router ! RemotePacketRouter.UnregisterConnection(connectionId) + router ! RemotePacketRouter.RouteViaConnection(connectionId, packetId, "some-packet2", failureReply4) + failureReply4.future.failed.futureValue shouldBe RemotePacketRouter.CannotRoute + registrant.expectNoMessage(100.millis) + + router ! RemotePacketRouter.Unregister(Some(clientId), packetId) + router ! RemotePacketRouter.Route(Some(clientId), packetId, "some-packet", failureReply2) failureReply2.future.failed.futureValue shouldBe RemotePacketRouter.CannotRoute - registrant.expectNoMessage(1.second) + registrant.expectNoMessage(100.millis) } } From d550f756cff6ef0184d51b8f6f8dddadce6673ac Mon Sep 17 00:00:00 2001 From: Christopher Hunt Date: Fri, 16 Nov 2018 08:57:08 +0100 Subject: [PATCH 6/6] Reinstate the topic filtering Accidentally removed topic filtering from a previous commit. This could cause unexpected behaviour given that every publish would be broadcast to all topics. --- .../stream/alpakka/mqtt/streaming/impl/ServerState.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala index 509aa356f2..e6e0a3446f 100644 --- a/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala +++ b/mqtt-streaming/src/main/scala/akka/stream/alpakka/mqtt/streaming/impl/ServerState.scala @@ -503,10 +503,12 @@ import scala.util.{Failure, Success} clientConnected(data.copy(activeConsumers = data.activeConsumers - topicName)) } case (_, PublishReceivedLocally(publish, _)) - if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 => + if (publish.flags & ControlPacketFlags.QoSReserved).underlying == 0 && + data.publishers.exists(matchTopicFilter(_, publish.topicName)) => data.remote.offer(ForwardPublish(publish, None)) clientConnected(data) - case (context, prl @ PublishReceivedLocally(publish, publishData)) => + case (context, prl @ PublishReceivedLocally(publish, publishData)) + if data.publishers.exists(matchTopicFilter(_, publish.topicName)) => val producerName = ActorName.mkName(ProducerNamePrefix + publish.topicName + "-" + context.children.size) if (!data.activeProducers.contains(publish.topicName)) { val reply = Promise[Source[Producer.ForwardPublishingCommand, NotUsed]]