Skip to content

Commit

Permalink
fixed subscription restore from client state when mqtt client reconne…
Browse files Browse the repository at this point in the history
…cts (#5048)

Cherry picked (#4189)

When an MQTT client used persistent session, the previous subscriptions did not get restored. If the same client connected to IoTHub directly, its subscriptions (e.g. for receiving twin responses) would have been restored.
In case of c# sdk it would not be a problem, because the sdk automatically resubscribes when it is able to restore the connection to edgeHub.

In case of python sdk however this does not happen. As a result, if a device connects using the python sdk and edgeHub gets restarted, the client will not be receiving m2m messages, twin patches, etc anymore.

The fix seems big but the main code in "MessagingServiceClient.cs", the rest is just the result of a changed interface and some plumbing.
  • Loading branch information
vipeller committed May 28, 2021
1 parent a6d7fee commit aadf030
Show file tree
Hide file tree
Showing 12 changed files with 360 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Core.Device
using System.Diagnostics;
using System.Net;
using System.Threading.Tasks;
using App.Metrics.Infrastructure;
using Microsoft.Azure.Devices.Edge.Hub.Core.Cloud;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;
Expand Down Expand Up @@ -71,9 +70,12 @@ public Task ProcessMethodResponseAsync(IMessage message)

public void BindDeviceProxy(IDeviceProxy deviceProxy)
{
this.underlyingProxy = Preconditions.CheckNotNull(deviceProxy);
this.connectionManager.AddDeviceConnection(this.Identity, this);
Events.BindDeviceProxy(this.Identity);
this.BindDeviceProxyAsync(deviceProxy, Option.None<Action>());
}

public void BindDeviceProxy(IDeviceProxy deviceProxy, Action initWhenBound)
{
this.BindDeviceProxyAsync(deviceProxy, Option.Some(initWhenBound));
}

public async Task CloseAsync()
Expand Down Expand Up @@ -247,6 +249,33 @@ public async Task UpdateReportedPropertiesAsync(IMessage reportedPropertiesMessa
}
}

async void BindDeviceProxyAsync(IDeviceProxy deviceProxy, Option<Action> initWhenBound)
{
this.underlyingProxy = Preconditions.CheckNotNull(deviceProxy);

try
{
await this.connectionManager.AddDeviceConnection(this.Identity, this);
}
catch (Exception ex)
{
Events.ErrorBindingDeviceProxy(this.Identity, ex);
return;
}

try
{
initWhenBound.ForEach(a => a?.Invoke());
}
catch (Exception ex)
{
Events.ErrorPostBindAction(this.Identity, ex);
return;
}

Events.BindDeviceProxy(this.Identity);
}

async Task HandleTwinOperationException(string correlationId, Exception e)
{
if (!string.IsNullOrWhiteSpace(correlationId))
Expand Down Expand Up @@ -276,6 +305,8 @@ static class Events
enum EventIds
{
BindDeviceProxy = IdStart,
ErrorBindingDeviceProxy,
ErrorPostBindAction,
RemoveDeviceConnection,
MethodSentToClient,
MethodResponseReceived,
Expand All @@ -301,6 +332,16 @@ public static void BindDeviceProxy(IIdentity identity)
Log.LogInformation((int)EventIds.BindDeviceProxy, Invariant($"Bind device proxy for device {identity.Id}"));
}

public static void ErrorBindingDeviceProxy(IIdentity identity, Exception ex)
{
Log.LogError((int)EventIds.ErrorBindingDeviceProxy, ex, Invariant($"Error binding device proxy for device {identity.Id}"));
}

public static void ErrorPostBindAction(IIdentity identity, Exception ex)
{
Log.LogError((int)EventIds.ErrorPostBindAction, ex, Invariant($"Error executing post-bind action for device {identity.Id}"));
}

public static void Close(IIdentity identity)
{
Log.LogInformation((int)EventIds.RemoveDeviceConnection, Invariant($"Remove device connection for device {identity.Id}"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Core.Device
{
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;

public interface IDeviceListener
{
Expand All @@ -21,6 +23,8 @@ public interface IDeviceListener

void BindDeviceProxy(IDeviceProxy deviceProxy);

void BindDeviceProxy(IDeviceProxy deviceProxy, Action initWhenBound);

Task CloseAsync();

Task ProcessMessageFeedbackAsync(string messageId, FeedbackStatus feedbackStatus);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Mqtt
{
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Azure.Devices.ProtocolGateway.Identity;

public class AuthenticatedIdentity : IDeviceIdentity
{
public AuthenticatedIdentity(string id)
{
this.Id = Preconditions.CheckNotNull(id, nameof(id));
}

public bool IsAuthenticated => true;
public string Id { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Mqtt
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Azure.Devices.Edge.Util.Metrics;
using Microsoft.Azure.Devices.ProtocolGateway.Messaging;
using Microsoft.Azure.Devices.ProtocolGateway.Mqtt.Persistence;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using static System.FormattableString;
Expand All @@ -24,12 +24,14 @@ public class MessagingServiceClient : IMessagingServiceClient
readonly IDeviceListener deviceListener;
readonly IMessageConverter<IProtocolGatewayMessage> messageConverter;
readonly IByteBufferConverter byteBufferConverter;
readonly ISessionStatePersistenceProvider sessionStatePersistenceProvider;

public MessagingServiceClient(IDeviceListener deviceListener, IMessageConverter<IProtocolGatewayMessage> messageConverter, IByteBufferConverter byteBufferConverter)
public MessagingServiceClient(IDeviceListener deviceListener, IMessageConverter<IProtocolGatewayMessage> messageConverter, IByteBufferConverter byteBufferConverter, ISessionStatePersistenceProvider sessionStatePersistenceProvider)
{
this.deviceListener = Preconditions.CheckNotNull(deviceListener, nameof(deviceListener));
this.messageConverter = Preconditions.CheckNotNull(messageConverter, nameof(messageConverter));
this.byteBufferConverter = Preconditions.CheckNotNull(byteBufferConverter, nameof(byteBufferConverter));
this.sessionStatePersistenceProvider = Preconditions.CheckNotNull(sessionStatePersistenceProvider, nameof(sessionStatePersistenceProvider));
}

public int MaxPendingMessages => 100;
Expand All @@ -43,8 +45,38 @@ public IProtocolGatewayMessage CreateMessage(string address, IByteBuffer payload

public void BindMessagingChannel(IMessagingChannel<IProtocolGatewayMessage> channel)
{
var sessionStateQuery = this.sessionStatePersistenceProvider.GetAsync(new AuthenticatedIdentity(this.deviceListener.Identity.Id));

IDeviceProxy deviceProxy = new DeviceProxy(channel, this.deviceListener.Identity, this.messageConverter, this.byteBufferConverter);
this.deviceListener.BindDeviceProxy(deviceProxy);
this.deviceListener.BindDeviceProxy(
deviceProxy,
async () =>
{
var sessionState = await sessionStateQuery;
if (sessionState is SessionState registrationSessionState)
{
var subscriptions = SessionStateParser.GetDeviceSubscriptions(registrationSessionState.SubscriptionRegistrations, this.deviceListener.Identity.Id);
foreach ((DeviceSubscription deviceSubscription, bool addSubscription) in subscriptions)
{
if (deviceSubscription == DeviceSubscription.Unknown)
{
continue;
}
if (addSubscription)
{
await this.deviceListener.AddSubscription(deviceSubscription);
}
else
{
await this.deviceListener.RemoveSubscription(deviceSubscription);
}
}
}
});

Events.BindMessageChannel(this.deviceListener.Identity);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Mqtt
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Azure.Devices.ProtocolGateway.Messaging;
using Microsoft.Azure.Devices.ProtocolGateway.Mqtt.Persistence;
using IDeviceIdentity = Microsoft.Azure.Devices.ProtocolGateway.Identity.IDeviceIdentity;
using IProtocolGatewayMessage = Microsoft.Azure.Devices.ProtocolGateway.Messaging.IMessage;

Expand All @@ -15,12 +16,14 @@ public class MqttConnectionProvider : IMqttConnectionProvider
readonly IConnectionProvider connectionProvider;
readonly IMessageConverter<IProtocolGatewayMessage> pgMessageConverter;
readonly IByteBufferConverter byteBufferConverter;
readonly ISessionStatePersistenceProvider sessionStatePersistenceProvider;

public MqttConnectionProvider(IConnectionProvider connectionProvider, IMessageConverter<IProtocolGatewayMessage> pgMessageConverter, IByteBufferConverter byteBufferConverter)
public MqttConnectionProvider(IConnectionProvider connectionProvider, IMessageConverter<IProtocolGatewayMessage> pgMessageConverter, IByteBufferConverter byteBufferConverter, ISessionStatePersistenceProvider sessionStatePersistenceProvider)
{
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
this.pgMessageConverter = Preconditions.CheckNotNull(pgMessageConverter, nameof(pgMessageConverter));
this.byteBufferConverter = Preconditions.CheckNotNull(byteBufferConverter, nameof(byteBufferConverter));
this.sessionStatePersistenceProvider = Preconditions.CheckNotNull(sessionStatePersistenceProvider, nameof(sessionStatePersistenceProvider));
}

public async Task<IMessagingBridge> Connect(IDeviceIdentity deviceidentity)
Expand All @@ -31,7 +34,7 @@ public async Task<IMessagingBridge> Connect(IDeviceIdentity deviceidentity)
}

IDeviceListener deviceListener = await this.connectionProvider.GetDeviceListenerAsync(protocolGatewayIdentity.ClientCredentials.Identity, protocolGatewayIdentity.ModelId);
IMessagingServiceClient messagingServiceClient = new MessagingServiceClient(deviceListener, this.pgMessageConverter, this.byteBufferConverter);
IMessagingServiceClient messagingServiceClient = new MessagingServiceClient(deviceListener, this.pgMessageConverter, this.byteBufferConverter, this.sessionStatePersistenceProvider);
IMessagingBridge messagingBridge = new SingleClientMessagingBridge(deviceidentity, messagingServiceClient);

return messagingBridge;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public static class MqttEventIds
public const int DeviceProxy = EventIdStart + 100;
public const int MessagingServiceClient = EventIdStart + 200;
public const int SessionStatePersistenceProvider = EventIdStart + 300;
public const int SessionStateParser = EventIdStart + 320;
public const int SessionStateStoragePersistenceProvider = EventIdStart + 400;
public const int MqttWebSocketListener = EventIdStart + 500;
public const int ServerWebSocketChannel = EventIdStart + 600;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Mqtt
{
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Extensions.Logging;

using static System.FormattableString;

public class SessionStateParser
{
internal const string C2DSubscriptionTopicPrefix = @"messages/devicebound/#";
internal const string MethodSubscriptionTopicPrefix = @"$iothub/methods/POST/";
internal const string TwinSubscriptionTopicPrefix = @"$iothub/twin/PATCH/properties/desired/";
internal const string TwinResponseTopicFilter = "$iothub/twin/res/#";
internal static readonly Regex ModuleMessageTopicRegex = new Regex("^devices/.+/modules/.+/#$");

public static IEnumerable<(DeviceSubscription, bool)> GetDeviceSubscriptions(IReadOnlyDictionary<string, bool> topics, string id)
{
return topics.Select(
subscriptionRegistration =>
{
string topicName = subscriptionRegistration.Key;
bool addSubscription = subscriptionRegistration.Value;
DeviceSubscription deviceSubscription = GetDeviceSubscription(topicName);
if (deviceSubscription == DeviceSubscription.Unknown)
{
Events.UnknownTopicSubscription(topicName, id);
}
return (deviceSubscription, addSubscription);
});
}

public static DeviceSubscription GetDeviceSubscription(string topicName)
{
Preconditions.CheckNonWhiteSpace(topicName, nameof(topicName));
if (topicName.StartsWith(MethodSubscriptionTopicPrefix))
{
return DeviceSubscription.Methods;
}
else if (topicName.StartsWith(TwinSubscriptionTopicPrefix))
{
return DeviceSubscription.DesiredPropertyUpdates;
}
else if (topicName.EndsWith(C2DSubscriptionTopicPrefix))
{
return DeviceSubscription.C2D;
}
else if (topicName.Equals(TwinResponseTopicFilter))
{
return DeviceSubscription.TwinResponse;
}
else if (ModuleMessageTopicRegex.IsMatch(topicName))
{
return DeviceSubscription.ModuleMessages;
}
else
{
return DeviceSubscription.Unknown;
}
}

static class Events
{
const int IdStart = MqttEventIds.SessionStateParser;
static readonly ILogger Log = Logger.Factory.CreateLogger<SessionStateParser>();

enum EventIds
{
UnknownSubscription = IdStart,
}

public static void UnknownTopicSubscription(string topicName, string id)
{
Log.LogInformation((int)EventIds.UnknownSubscription, Invariant($"Ignoring unknown subscription to topic {topicName} for client {id}."));
}
}
}
}

0 comments on commit aadf030

Please sign in to comment.