From 449624fd72bb9962772bd45db18e3f2bbf73e089 Mon Sep 17 00:00:00 2001 From: Christian <6939810+chkr1011@users.noreply.github.com> Date: Sat, 6 May 2023 12:49:41 +0200 Subject: [PATCH] Expose proper connect result in clients Disconnected event (#1728) --- .github/workflows/ReleaseNotes.md | 1 + .../MqttFactoryExtensions.cs | 27 ++- .../ManagedMqttClient_Tests.cs | 32 +++ .../MqttClient/MqttClient_Connection_Tests.cs | 2 + .../MQTTnet.Tests/Mockups/TestEnvironment.cs | 2 +- .../Client/Internal/MqttClientEvents.cs | 18 ++ .../Internal/MqttClientResultFactory.cs | 13 ++ Source/MQTTnet/Client/MqttClient.cs | 212 ++++++++++-------- 8 files changed, 205 insertions(+), 102 deletions(-) create mode 100644 Source/MQTTnet/Client/Internal/MqttClientEvents.cs create mode 100644 Source/MQTTnet/Client/Internal/MqttClientResultFactory.cs diff --git a/.github/workflows/ReleaseNotes.md b/.github/workflows/ReleaseNotes.md index b8635cbe1..0aafba3b7 100644 --- a/.github/workflows/ReleaseNotes.md +++ b/.github/workflows/ReleaseNotes.md @@ -1,3 +1,4 @@ * [Core] Add validation of maximum string lengths (#1718). * [Client] Added overloads for setting packet payload and will payload (#1720). +* [Client] The proper connect result is now exposed in the _Disconnected_ event when authentication fails (#1139). * [Server] Improved performance by changing internal locking strategy for subscriptions (#1716, thanks to @zeheng). diff --git a/Source/MQTTnet.Extensions.ManagedClient/MqttFactoryExtensions.cs b/Source/MQTTnet.Extensions.ManagedClient/MqttFactoryExtensions.cs index 3349ca5d4..f5b9a7335 100644 --- a/Source/MQTTnet.Extensions.ManagedClient/MqttFactoryExtensions.cs +++ b/Source/MQTTnet.Extensions.ManagedClient/MqttFactoryExtensions.cs @@ -12,22 +12,37 @@ public static class MqttFactoryExtensions { public static IManagedMqttClient CreateManagedMqttClient(this MqttFactory factory, IMqttClient mqttClient = null) { - if (factory == null) throw new ArgumentNullException(nameof(factory)); + if (factory == null) + { + throw new ArgumentNullException(nameof(factory)); + } if (mqttClient == null) { - return new ManagedMqttClient(factory.CreateMqttClient(), factory.DefaultLogger); + return new ManagedMqttClient(factory.CreateMqttClient(), factory.DefaultLogger); } - + return new ManagedMqttClient(mqttClient, factory.DefaultLogger); } - + public static IManagedMqttClient CreateManagedMqttClient(this MqttFactory factory, IMqttNetLogger logger) { - if (factory == null) throw new ArgumentNullException(nameof(factory)); - if (logger == null) throw new ArgumentNullException(nameof(logger)); + if (factory == null) + { + throw new ArgumentNullException(nameof(factory)); + } + + if (logger == null) + { + throw new ArgumentNullException(nameof(logger)); + } return new ManagedMqttClient(factory.CreateMqttClient(logger), logger); } + + public static ManagedMqttClientOptionsBuilder CreateManagedMqttClientOptionsBuilder(this MqttFactory _) + { + return new ManagedMqttClientOptionsBuilder(); + } } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs index f16d55fa8..665519836 100644 --- a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs @@ -22,6 +22,38 @@ namespace MQTTnet.Tests.Clients.ManagedMqttClient [TestClass] public sealed class ManagedMqttClient_Tests : BaseTestClass { + [TestMethod] + public async Task Expose_Custom_Connection_Error() + { + using (var testEnvironment = CreateTestEnvironment()) + { + var server = await testEnvironment.StartServer(); + + server.ValidatingConnectionAsync += args => + { + args.ReasonCode = MqttConnectReasonCode.BadUserNameOrPassword; + return CompletedTask.Instance; + }; + + var managedClient = testEnvironment.Factory.CreateManagedMqttClient(); + + MqttClientDisconnectedEventArgs disconnectedArgs = null; + managedClient.DisconnectedAsync += args => + { + disconnectedArgs = args; + return CompletedTask.Instance; + }; + + var clientOptions = testEnvironment.Factory.CreateManagedMqttClientOptionsBuilder().WithClientOptions(testEnvironment.CreateDefaultClientOptions()).Build(); + await managedClient.StartAsync(clientOptions); + + await LongTestDelay(); + + Assert.IsNotNull(disconnectedArgs); + Assert.AreEqual(MqttClientConnectResultCode.BadUserNameOrPassword, disconnectedArgs.ConnectResult.ResultCode); + } + } + [TestMethod] public async Task Receive_While_Not_Cleanly_Disconnected() { diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs index 1cdcd7608..4a2e1f063 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs @@ -127,6 +127,8 @@ public async Task Disconnect_Clean_With_Custom_Reason() // Perform a clean disconnect. await client.DisconnectAsync(disconnectOptions); + await LongTestDelay(); + Assert.IsNotNull(eventArgs); Assert.AreEqual(MqttDisconnectReasonCode.MessageRateTooHigh, eventArgs.ReasonCode); } diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs index 9995ee82c..b3543706b 100644 --- a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs @@ -245,7 +245,7 @@ public MqttClientOptions CreateDefaultClientOptions() public MqttClientOptionsBuilder CreateDefaultClientOptionsBuilder() { - return Factory.CreateClientOptionsBuilder().WithProtocolVersion(_protocolVersion).WithTcpServer("127.0.0.1", ServerPort); + return Factory.CreateClientOptionsBuilder().WithProtocolVersion(_protocolVersion).WithTcpServer("127.0.0.1", ServerPort).WithClientId(TestContext.TestName + "_" + Guid.NewGuid()); } public ILowLevelMqttClient CreateLowLevelClient() diff --git a/Source/MQTTnet/Client/Internal/MqttClientEvents.cs b/Source/MQTTnet/Client/Internal/MqttClientEvents.cs new file mode 100644 index 000000000..fe786e82c --- /dev/null +++ b/Source/MQTTnet/Client/Internal/MqttClientEvents.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Diagnostics; +using MQTTnet.Internal; + +namespace MQTTnet.Client.Internal +{ + public sealed class MqttClientEvents + { + public AsyncEvent ApplicationMessageReceivedEvent { get; } = new AsyncEvent(); + public AsyncEvent ConnectedEvent { get; } = new AsyncEvent(); + public AsyncEvent ConnectingEvent { get; } = new AsyncEvent(); + public AsyncEvent DisconnectedEvent { get; } = new AsyncEvent(); + public AsyncEvent InspectPacketEvent { get; } = new AsyncEvent(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Internal/MqttClientResultFactory.cs b/Source/MQTTnet/Client/Internal/MqttClientResultFactory.cs new file mode 100644 index 000000000..ef2800fe7 --- /dev/null +++ b/Source/MQTTnet/Client/Internal/MqttClientResultFactory.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.Client.Internal +{ + public static class MqttClientResultFactory + { + public static readonly MqttClientPublishResultFactory PublishResult = new MqttClientPublishResultFactory(); + public static readonly MqttClientSubscribeResultFactory SubscribeResult = new MqttClientSubscribeResultFactory(); + public static readonly MqttClientUnsubscribeResultFactory UnsubscribeResult = new MqttClientUnsubscribeResultFactory(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/MqttClient.cs b/Source/MQTTnet/Client/MqttClient.cs index 876c13905..0719ed1fa 100644 --- a/Source/MQTTnet/Client/MqttClient.cs +++ b/Source/MQTTnet/Client/MqttClient.cs @@ -7,6 +7,7 @@ using System.Threading; using System.Threading.Tasks; using MQTTnet.Adapter; +using MQTTnet.Client.Internal; using MQTTnet.Diagnostics; using MQTTnet.Exceptions; using MQTTnet.Formatter; @@ -20,17 +21,9 @@ namespace MQTTnet.Client public sealed class MqttClient : Disposable, IMqttClient { readonly IMqttClientAdapterFactory _adapterFactory; - readonly AsyncEvent _applicationMessageReceivedEvent = new AsyncEvent(); - readonly MqttClientPublishResultFactory _clientPublishResultFactory = new MqttClientPublishResultFactory(); - readonly MqttClientSubscribeResultFactory _clientSubscribeResultFactory = new MqttClientSubscribeResultFactory(); - readonly MqttClientUnsubscribeResultFactory _clientUnsubscribeResultFactory = new MqttClientUnsubscribeResultFactory(); - - readonly AsyncEvent _connectedEvent = new AsyncEvent(); - readonly AsyncEvent _connectingEvent = new AsyncEvent(); - readonly AsyncEvent _disconnectedEvent = new AsyncEvent(); readonly object _disconnectLock = new object(); - readonly AsyncEvent _inspectPacketEvent = new AsyncEvent(); + readonly MqttClientEvents _events = new MqttClientEvents(); readonly MqttNetSourceLogger _logger; readonly MqttPacketIdentifierProvider _packetIdentifierProvider = new MqttPacketIdentifierProvider(); @@ -65,32 +58,32 @@ public MqttClient(IMqttClientAdapterFactory channelFactory, IMqttNetLogger logge public event Func ApplicationMessageReceivedAsync { - add => _applicationMessageReceivedEvent.AddHandler(value); - remove => _applicationMessageReceivedEvent.RemoveHandler(value); + add => _events.ApplicationMessageReceivedEvent.AddHandler(value); + remove => _events.ApplicationMessageReceivedEvent.RemoveHandler(value); } public event Func ConnectedAsync { - add => _connectedEvent.AddHandler(value); - remove => _connectedEvent.RemoveHandler(value); + add => _events.ConnectedEvent.AddHandler(value); + remove => _events.ConnectedEvent.RemoveHandler(value); } public event Func ConnectingAsync { - add => _connectingEvent.AddHandler(value); - remove => _connectingEvent.RemoveHandler(value); + add => _events.ConnectingEvent.AddHandler(value); + remove => _events.ConnectingEvent.RemoveHandler(value); } public event Func DisconnectedAsync { - add => _disconnectedEvent.AddHandler(value); - remove => _disconnectedEvent.RemoveHandler(value); + add => _events.DisconnectedEvent.AddHandler(value); + remove => _events.DisconnectedEvent.RemoveHandler(value); } public event Func InspectPacketAsync { - add => _inspectPacketEvent.AddHandler(value); - remove => _inspectPacketEvent.RemoveHandler(value); + add => _events.InspectPacketEvent.AddHandler(value); + remove => _events.InspectPacketEvent.RemoveHandler(value); } public bool IsConnected => (MqttClientConnectionStatus)_connectionStatus == MqttClientConnectionStatus.Connected; @@ -114,9 +107,9 @@ public async Task ConnectAsync(MqttClientOptions option { Options = options; - if (_connectingEvent.HasHandlers) + if (_events.ConnectingEvent.HasHandlers) { - await _connectingEvent.InvokeAsync(new MqttClientConnectingEventArgs(options)); + await _events.ConnectingEvent.InvokeAsync(new MqttClientConnectingEventArgs(options)); } Cleanup(); @@ -127,7 +120,7 @@ public async Task ConnectAsync(MqttClientOptions option _mqttClientAlive = new CancellationTokenSource(); var mqttClientAliveToken = _mqttClientAlive.Token; - var adapter = _adapterFactory.CreateClientAdapter(options, new MqttPacketInspector(_inspectPacketEvent, _rootLogger), _rootLogger); + var adapter = _adapterFactory.CreateClientAdapter(options, new MqttPacketInspector(_events.InspectPacketEvent, _rootLogger), _rootLogger); _adapter = adapter; if (cancellationToken.CanBeCanceled) @@ -149,34 +142,35 @@ public async Task ConnectAsync(MqttClientOptions option var keepAliveInterval = Options.KeepAlivePeriod; if (connectResult.ServerKeepAlive > 0) { - _logger.Info($"Using keep alive value ({connectResult.ServerKeepAlive}) sent from the server."); + _logger.Info($"Using keep alive value ({connectResult.ServerKeepAlive}) sent from the server"); keepAliveInterval = TimeSpan.FromSeconds(connectResult.ServerKeepAlive); } if (keepAliveInterval != TimeSpan.Zero) { - _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessagesAsync(mqttClientAliveToken), mqttClientAliveToken); + _keepAlivePacketsSenderTask = Task.Run(() => TrySendKeepAliveMessages(mqttClientAliveToken), mqttClientAliveToken); } CompareExchangeConnectionStatus(MqttClientConnectionStatus.Connected, MqttClientConnectionStatus.Connecting); - _logger.Info("Connected."); + _logger.Info("Connected"); - if (_connectedEvent.HasHandlers) - { - var eventArgs = new MqttClientConnectedEventArgs(connectResult); - await _connectedEvent.InvokeAsync(eventArgs).ConfigureAwait(false); - } + await OnConnected(connectResult).ConfigureAwait(false); return connectResult; } catch (Exception exception) { + if (exception is MqttConnectingFailedException connectingFailedException) + { + connectResult = connectingFailedException.Result; + } + _disconnectReason = (int)MqttClientDisconnectOptionsReason.UnspecifiedError; - _logger.Error(exception, "Error while connecting with server."); + _logger.Error(exception, "Error while connecting with server"); - await DisconnectInternalAsync(null, exception, connectResult).ConfigureAwait(false); + await DisconnectInternal(null, exception, connectResult).ConfigureAwait(false); throw; } @@ -219,19 +213,19 @@ public async Task DisconnectAsync(MqttClientDisconnectOptions options, Cancellat if (cancellationToken.CanBeCanceled) { - await SendAsync(disconnectPacket, cancellationToken).ConfigureAwait(false); + await Send(disconnectPacket, cancellationToken).ConfigureAwait(false); } else { using (var timeout = new CancellationTokenSource(Options.Timeout)) { - await SendAsync(disconnectPacket, timeout.Token).ConfigureAwait(false); + await Send(disconnectPacket, timeout.Token).ConfigureAwait(false); } } } finally { - await DisconnectCoreAsync(null, null, null, clientWasConnected).ConfigureAwait(false); + await DisconnectCore(null, null, null, clientWasConnected).ConfigureAwait(false); } } @@ -239,13 +233,13 @@ public async Task PingAsync(CancellationToken cancellationToken = default) { if (cancellationToken.CanBeCanceled) { - await SendAndReceiveAsync(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false); + await Request(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false); } else { using (var timeout = new CancellationTokenSource(Options.Timeout)) { - await SendAndReceiveAsync(MqttPingReqPacket.Instance, timeout.Token).ConfigureAwait(false); + await Request(MqttPingReqPacket.Instance, timeout.Token).ConfigureAwait(false); } } } @@ -274,11 +268,11 @@ public Task PublishAsync(MqttApplicationMessage applica } case MqttQualityOfServiceLevel.AtLeastOnce: { - return PublishAtLeastOnceAsync(publishPacket, cancellationToken); + return PublishAtLeastOnce(publishPacket, cancellationToken); } case MqttQualityOfServiceLevel.ExactlyOnce: { - return PublishExactlyOnceAsync(publishPacket, cancellationToken); + return PublishExactlyOnce(publishPacket, cancellationToken); } default: { @@ -306,7 +300,7 @@ public Task SendExtendedAuthenticationExchangeDataAsync(MqttExtendedAuthenticati UserProperties = data.UserProperties }; - return SendAsync(authPacket, cancellationToken); + return Send(authPacket, cancellationToken); } public async Task SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken = default) @@ -335,17 +329,17 @@ public async Task SubscribeAsync(MqttClientSubscribeO MqttSubAckPacket subAckPacket; if (cancellationToken.CanBeCanceled) { - subAckPacket = await SendAndReceiveAsync(subscribePacket, cancellationToken).ConfigureAwait(false); + subAckPacket = await Request(subscribePacket, cancellationToken).ConfigureAwait(false); } else { using (var timeout = new CancellationTokenSource(Options.Timeout)) { - subAckPacket = await SendAndReceiveAsync(subscribePacket, timeout.Token).ConfigureAwait(false); + subAckPacket = await Request(subscribePacket, timeout.Token).ConfigureAwait(false); } } - return _clientSubscribeResultFactory.Create(subscribePacket, subAckPacket); + return MqttClientResultFactory.SubscribeResult.Create(subscribePacket, subAckPacket); } public async Task UnsubscribeAsync(MqttClientUnsubscribeOptions options, CancellationToken cancellationToken = default) @@ -374,17 +368,17 @@ public async Task UnsubscribeAsync(MqttClientUnsubs MqttUnsubAckPacket unsubAckPacket; if (cancellationToken.CanBeCanceled) { - unsubAckPacket = await SendAndReceiveAsync(unsubscribePacket, cancellationToken).ConfigureAwait(false); + unsubAckPacket = await Request(unsubscribePacket, cancellationToken).ConfigureAwait(false); } else { using (var timeout = new CancellationTokenSource(Options.Timeout)) { - unsubAckPacket = await SendAndReceiveAsync(unsubscribePacket, timeout.Token).ConfigureAwait(false); + unsubAckPacket = await Request(unsubscribePacket, timeout.Token).ConfigureAwait(false); } } - return _clientUnsubscribeResultFactory.Create(unsubscribePacket, unsubAckPacket); + return MqttClientResultFactory.UnsubscribeResult.Create(unsubscribePacket, unsubAckPacket); } protected override void Dispose(bool disposing) @@ -408,7 +402,7 @@ Task AcknowledgeReceivedPublishPacket(MqttApplicationMessageReceivedEventArgs ev if (!eventArgs.ProcessingFailed) { var pubAckPacket = MqttPacketFactories.PubAck.Create(eventArgs); - return SendAsync(pubAckPacket, cancellationToken); + return Send(pubAckPacket, cancellationToken); } } else if (eventArgs.PublishPacket.QualityOfServiceLevel == MqttQualityOfServiceLevel.ExactlyOnce) @@ -416,7 +410,7 @@ Task AcknowledgeReceivedPublishPacket(MqttApplicationMessageReceivedEventArgs ev if (!eventArgs.ProcessingFailed) { var pubRecPacket = MqttPacketFactories.PubRec.Create(eventArgs); - return SendAsync(pubRecPacket, cancellationToken); + return Send(pubRecPacket, cancellationToken); } } else @@ -427,18 +421,26 @@ Task AcknowledgeReceivedPublishPacket(MqttApplicationMessageReceivedEventArgs ev return CompletedTask.Instance; } - async Task AuthenticateAsync(MqttClientOptions options, CancellationToken cancellationToken) + async Task Authenticate(MqttClientOptions options, CancellationToken cancellationToken) { MqttClientConnectResult result; try { var connectPacket = MqttPacketFactories.Connect.Create(options); + await Send(connectPacket, cancellationToken).ConfigureAwait(false); - var connAckPacket = await SendAndReceiveAsync(connectPacket, cancellationToken).ConfigureAwait(false); + var receivedPacket = await Receive(cancellationToken).ConfigureAwait(false); - var clientConnectResultFactory = new MqttClientConnectResultFactory(); - result = clientConnectResultFactory.Create(connAckPacket, _adapter.PacketFormatterAdapter.ProtocolVersion); + if (receivedPacket is MqttConnAckPacket connAckPacket) + { + var clientConnectResultFactory = new MqttClientConnectResultFactory(); + result = clientConnectResultFactory.Create(connAckPacket, _adapter.PacketFormatterAdapter.ProtocolVersion); + } + else + { + throw new NotSupportedException("Extended authentication handler is not yet supported"); + } } catch (Exception exception) { @@ -495,14 +497,16 @@ async Task ConnectInternal(CancellationToken cancellati _publishPacketReceiverQueue?.Dispose(); _publishPacketReceiverQueue = new AsyncQueue(); + var connectResult = await Authenticate(Options, effectiveCancellationToken.Token).ConfigureAwait(false); + _publishPacketReceiverTask = Task.Run(() => ProcessReceivedPublishPackets(backgroundCancellationToken), backgroundCancellationToken); - _packetReceiverTask = Task.Run(() => TryReceivePacketsAsync(backgroundCancellationToken), backgroundCancellationToken); + _packetReceiverTask = Task.Run(() => ReceivePacketsLoop(backgroundCancellationToken), backgroundCancellationToken); - return await AuthenticateAsync(Options, effectiveCancellationToken.Token).ConfigureAwait(false); + return connectResult; } } - async Task DisconnectCoreAsync(Task sender, Exception exception, MqttClientConnectResult connectResult, bool clientWasConnected) + async Task DisconnectCore(Task sender, Exception exception, MqttClientConnectResult connectResult, bool clientWasConnected) { TryInitiateDisconnect(); @@ -557,17 +561,17 @@ async Task DisconnectCoreAsync(Task sender, Exception exception, MqttClientConne // This handler must be executed in a new thread because otherwise a dead lock may happen // when trying to reconnect in that handler etc. - Task.Run(() => _disconnectedEvent.InvokeAsync(eventArgs)).RunInBackground(_logger); + Task.Run(() => _events.DisconnectedEvent.InvokeAsync(eventArgs)).RunInBackground(_logger); } } - Task DisconnectInternalAsync(Task sender, Exception exception, MqttClientConnectResult connectResult) + Task DisconnectInternal(Task sender, Exception exception, MqttClientConnectResult connectResult) { var clientWasConnected = IsConnected; if (!DisconnectIsPendingOrFinished()) { - return DisconnectCoreAsync(sender, exception, connectResult, clientWasConnected); + return DisconnectCore(sender, exception, connectResult, clientWasConnected); } return CompletedTask.Instance; @@ -612,15 +616,26 @@ void EnqueueReceivedPublishPacket(MqttPublishPacket publishPacket) } } - async Task HandleReceivedApplicationMessageAsync(MqttPublishPacket publishPacket) + async Task HandleReceivedApplicationMessage(MqttPublishPacket publishPacket) { var applicationMessage = MqttApplicationMessageFactory.Create(publishPacket); var eventArgs = new MqttApplicationMessageReceivedEventArgs(Options.ClientId, applicationMessage, publishPacket, AcknowledgeReceivedPublishPacket); - await _applicationMessageReceivedEvent.InvokeAsync(eventArgs).ConfigureAwait(false); + await _events.ApplicationMessageReceivedEvent.InvokeAsync(eventArgs).ConfigureAwait(false); return eventArgs; } + Task OnConnected(MqttClientConnectResult connectResult) + { + if (_events.ConnectedEvent.HasHandlers) + { + var eventArgs = new MqttClientConnectedEventArgs(connectResult); + return _events.ConnectedEvent.InvokeAsync(eventArgs); + } + + return CompletedTask.Instance; + } + Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket) { var extendedAuthenticationExchangeHandler = Options.ExtendedAuthenticationExchangeHandler; @@ -641,7 +656,7 @@ Task ProcessReceivedDisconnectPacket(MqttDisconnectPacket disconnectPacket) // Also dispatch disconnect to waiting threads to generate a proper exception. _packetDispatcher.Dispose(new MqttClientUnexpectedDisconnectReceivedException(disconnectPacket)); - return DisconnectInternalAsync(_packetReceiverTask, null, null); + return DisconnectInternal(_packetReceiverTask, null, null); } async Task ProcessReceivedPublishPackets(CancellationToken cancellationToken) @@ -657,7 +672,7 @@ async Task ProcessReceivedPublishPackets(CancellationToken cancellationToken) } var publishPacket = publishPacketDequeueResult.Item; - var eventArgs = await HandleReceivedApplicationMessageAsync(publishPacket).ConfigureAwait(false); + var eventArgs = await HandleReceivedApplicationMessage(publishPacket).ConfigureAwait(false); if (eventArgs.AutoAcknowledge) { @@ -681,7 +696,7 @@ Task ProcessReceivedPubRecPacket(MqttPubRecPacket pubRecPacket, CancellationToke // The packet is unknown. Probably due to a restart of the client. // So wen send this to the server to trigger a full resend of the message. var pubRelPacket = MqttPacketFactories.PubRel.Create(pubRecPacket, MqttApplicationMessageReceivedReasonCode.PacketIdentifierNotFound); - return SendAsync(pubRelPacket, cancellationToken); + return Send(pubRelPacket, cancellationToken); } return CompletedTask.Instance; @@ -690,39 +705,56 @@ Task ProcessReceivedPubRecPacket(MqttPubRecPacket pubRecPacket, CancellationToke Task ProcessReceivedPubRelPacket(MqttPubRelPacket pubRelPacket, CancellationToken cancellationToken) { var pubCompPacket = MqttPacketFactories.PubComp.Create(pubRelPacket, MqttApplicationMessageReceivedReasonCode.Success); - return SendAsync(pubCompPacket, cancellationToken); + return Send(pubCompPacket, cancellationToken); } - async Task PublishAtLeastOnceAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + async Task PublishAtLeastOnce(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - var pubAckPacket = await SendAndReceiveAsync(publishPacket, cancellationToken).ConfigureAwait(false); - return _clientPublishResultFactory.Create(pubAckPacket); + var pubAckPacket = await Request(publishPacket, cancellationToken).ConfigureAwait(false); + return MqttClientResultFactory.PublishResult.Create(pubAckPacket); } async Task PublishAtMostOnce(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { // No packet identifier is used for QoS 0 [3.3.2.2 Packet Identifier] - await SendAsync(publishPacket, cancellationToken).ConfigureAwait(false); + await Send(publishPacket, cancellationToken).ConfigureAwait(false); - return _clientPublishResultFactory.Create(null); + return MqttClientResultFactory.PublishResult.Create(null); } - async Task PublishExactlyOnceAsync(MqttPublishPacket publishPacket, CancellationToken cancellationToken) + async Task PublishExactlyOnce(MqttPublishPacket publishPacket, CancellationToken cancellationToken) { publishPacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - var pubRecPacket = await SendAndReceiveAsync(publishPacket, cancellationToken).ConfigureAwait(false); + var pubRecPacket = await Request(publishPacket, cancellationToken).ConfigureAwait(false); var pubRelPacket = MqttPacketFactories.PubRel.Create(pubRecPacket, MqttApplicationMessageReceivedReasonCode.Success); - var pubCompPacket = await SendAndReceiveAsync(pubRelPacket, cancellationToken).ConfigureAwait(false); + var pubCompPacket = await Request(pubRelPacket, cancellationToken).ConfigureAwait(false); - return _clientPublishResultFactory.Create(pubRecPacket, pubCompPacket); + return MqttClientResultFactory.PublishResult.Create(pubRecPacket, pubCompPacket); } - async Task SendAndReceiveAsync(MqttPacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttPacket + async Task Receive(CancellationToken cancellationToken) + { + var packetTask = _adapter.ReceivePacketAsync(cancellationToken); + + MqttPacket packet; + if (packetTask.IsCompleted) + { + packet = packetTask.Result; + } + else + { + packet = await packetTask.ConfigureAwait(false); + } + + return packet; + } + + async Task Request(MqttPacket requestPacket, CancellationToken cancellationToken) where TResponsePacket : MqttPacket { cancellationToken.ThrowIfCancellationRequested(); @@ -736,7 +768,7 @@ async Task PublishExactlyOnceAsync(MqttPublishPacket pu { try { - await SendAsync(requestPacket, cancellationToken).ConfigureAwait(false); + await Send(requestPacket, cancellationToken).ConfigureAwait(false); } catch (Exception exception) { @@ -760,7 +792,7 @@ async Task PublishExactlyOnceAsync(MqttPublishPacket pu } } - Task SendAsync(MqttPacket packet, CancellationToken cancellationToken) + Task Send(MqttPacket packet, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -823,7 +855,7 @@ void TryInitiateDisconnect() } } - async Task TryProcessReceivedPacketAsync(MqttPacket packet, CancellationToken cancellationToken) + async Task TryProcessReceivedPacket(MqttPacket packet, CancellationToken cancellationToken) { try { @@ -884,11 +916,11 @@ async Task TryProcessReceivedPacketAsync(MqttPacket packet, CancellationToken ca _packetDispatcher.FailAll(exception); - await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); + await DisconnectInternal(_packetReceiverTask, exception, null).ConfigureAwait(false); } } - async Task TryReceivePacketsAsync(CancellationToken cancellationToken) + async Task ReceivePacketsLoop(CancellationToken cancellationToken) { try { @@ -896,17 +928,7 @@ async Task TryReceivePacketsAsync(CancellationToken cancellationToken) while (!cancellationToken.IsCancellationRequested) { - MqttPacket packet; - var packetTask = _adapter.ReceivePacketAsync(cancellationToken); - - if (packetTask.IsCompleted) - { - packet = packetTask.Result; - } - else - { - packet = await packetTask.ConfigureAwait(false); - } + var packet = await Receive(cancellationToken).ConfigureAwait(false); if (cancellationToken.IsCancellationRequested) { @@ -915,12 +937,12 @@ async Task TryReceivePacketsAsync(CancellationToken cancellationToken) if (packet == null) { - await DisconnectInternalAsync(_packetReceiverTask, null, null).ConfigureAwait(false); + await DisconnectInternal(_packetReceiverTask, null, null).ConfigureAwait(false); return; } - await TryProcessReceivedPacketAsync(packet, cancellationToken).ConfigureAwait(false); + await TryProcessReceivedPacket(packet, cancellationToken).ConfigureAwait(false); } } catch (Exception exception) @@ -949,7 +971,7 @@ async Task TryReceivePacketsAsync(CancellationToken cancellationToken) _packetDispatcher.FailAll(exception); - await DisconnectInternalAsync(_packetReceiverTask, exception, null).ConfigureAwait(false); + await DisconnectInternal(_packetReceiverTask, exception, null).ConfigureAwait(false); } finally { @@ -957,7 +979,7 @@ async Task TryReceivePacketsAsync(CancellationToken cancellationToken) } } - async Task TrySendKeepAliveMessagesAsync(CancellationToken cancellationToken) + async Task TrySendKeepAliveMessages(CancellationToken cancellationToken) { try { @@ -1006,7 +1028,7 @@ async Task TrySendKeepAliveMessagesAsync(CancellationToken cancellationToken) _logger.Error(exception, "Error exception while sending/receiving keep alive packets."); } - await DisconnectInternalAsync(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false); + await DisconnectInternal(_keepAlivePacketsSenderTask, exception, null).ConfigureAwait(false); } finally {