Skip to content

Commit

Permalink
Close AMQP connection explicitly when no more links (removing links k…
Browse files Browse the repository at this point in the history
…ept tcp level connection) (#4984)

cherry pick of #4914
  • Loading branch information
vipeller committed May 14, 2021
1 parent fb8367b commit 5be30fb
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void OnConnectionOpening(object sender, OpenEventArgs e)
amqpConnection.Extensions.Add(cbsNode);
}

IClientConnectionsHandler connectionHandler = new ClientConnectionsHandler(this.connectionProvider);
IClientConnectionsHandler connectionHandler = new ClientConnectionsHandler(this.connectionProvider, amqpConnection);
amqpConnection.Extensions.Add(connectionHandler);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
using System.Linq;
using System.Threading.Tasks;
using System.Web;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
Expand All @@ -21,18 +22,21 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
/// </summary>
class ClientConnectionHandler : IConnectionHandler
{
readonly TimeSpan closeTimeout = TimeSpan.FromSeconds(60);
readonly IDictionary<LinkType, ILinkHandler> registry = new Dictionary<LinkType, ILinkHandler>();
readonly IIdentity identity;
readonly AmqpConnectionBase amqpConnection;

readonly AsyncLock initializationLock = new AsyncLock();
readonly AsyncLock registryUpdateLock = new AsyncLock();
readonly IConnectionProvider connectionProvider;
Option<IDeviceListener> deviceListener = Option.None<IDeviceListener>();

public ClientConnectionHandler(IIdentity identity, IConnectionProvider connectionProvider)
public ClientConnectionHandler(IIdentity identity, IConnectionProvider connectionProvider, AmqpConnectionBase amqpConnection)
{
this.identity = Preconditions.CheckNotNull(identity, nameof(identity));
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
this.amqpConnection = Preconditions.CheckNotNull(amqpConnection, nameof(amqpConnection));
}

public Task<IDeviceListener> GetDeviceListener()
Expand Down Expand Up @@ -129,18 +133,13 @@ public async Task RemoveLinkHandler(ILinkHandler linkHandler)
}
}

Task CloseAllLinks()
{
IList<ILinkHandler> links = this.registry.Values.ToList();
IEnumerable<Task> closeTasks = links.Select(l => l.CloseAsync(Constants.DefaultTimeout));
return Task.WhenAll(closeTasks);
}

async Task CloseConnection()
{
using (await this.initializationLock.LockAsync())
{
await this.deviceListener.ForEachAsync(d => d.CloseAsync());
this.deviceListener = Option.None<IDeviceListener>();
await this.amqpConnection.CloseAsync(this.closeTimeout);
}
}

Expand All @@ -157,14 +156,16 @@ public DeviceProxy(ClientConnectionHandler clientConnectionHandler, IIdentity id

public bool IsActive => this.isActive;

public bool IsDirectClient => true;

public IIdentity Identity { get; }

public Task CloseAsync(Exception ex)
{
if (this.isActive.GetAndSet(false))
{
Events.ClosingProxy(this.Identity, ex);
return this.clientConnectionHandler.CloseAllLinks();
return this.clientConnectionHandler.CloseConnection();
}

return Task.CompletedTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System.Collections.Concurrent;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;
Expand All @@ -10,13 +11,15 @@ class ClientConnectionsHandler : IClientConnectionsHandler
{
readonly ConcurrentDictionary<string, ClientConnectionHandler> connectionHandlers = new ConcurrentDictionary<string, ClientConnectionHandler>();
readonly IConnectionProvider connectionProvider;
readonly AmqpConnection amqpConnection;

public ClientConnectionsHandler(IConnectionProvider connectionProvider)
public ClientConnectionsHandler(IConnectionProvider connectionProvider, AmqpConnection amqpConnection)
{
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
this.amqpConnection = Preconditions.CheckNotNull(amqpConnection, nameof(amqpConnection));
}

public IConnectionHandler GetConnectionHandler(IIdentity identity) =>
this.connectionHandlers.GetOrAdd(identity.Id, i => new ClientConnectionHandler(identity, this.connectionProvider));
this.connectionHandlers.GetOrAdd(identity.Id, i => new ClientConnectionHandler(identity, this.connectionProvider, this.amqpConnection));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Amqp.Framing;
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
Expand All @@ -23,12 +26,14 @@ public void ConnectionHandlerCtorTest()
{
// Arrange
var identity = Mock.Of<IIdentity>();
var connectionPovider = Mock.Of<IConnectionProvider>();
var connectionProvider = Mock.Of<IConnectionProvider>();
var amqpConnection = new AmqpTestConnection();

// Act / Assert
Assert.NotNull(new ClientConnectionHandler(identity, connectionPovider));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(null, connectionPovider));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(identity, null));
Assert.NotNull(new ClientConnectionHandler(identity, connectionProvider, amqpConnection));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(null, connectionProvider, amqpConnection));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(identity, null, amqpConnection));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(identity, connectionProvider, null));
}

[Fact]
Expand All @@ -42,7 +47,9 @@ public async Task GetDeviceListenerTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

// Act
var tasks = new List<Task<IDeviceListener>>();
Expand Down Expand Up @@ -77,8 +84,9 @@ public async Task RegisterC2DMessageSenderTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var c2DLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -113,8 +121,9 @@ public async Task RegisterModuleMessageSenderTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var moduleMessageLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -149,8 +158,9 @@ public async Task RegisterMethodInvokerTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var methodSendingLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -185,8 +195,9 @@ public async Task RegisterDesiredPropertiesUpdateSenderTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var twinSendingLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -214,11 +225,13 @@ public async Task CloseOnRemovingAllLinksTest()
// Arrange
var deviceListener = new Mock<IDeviceListener>();
deviceListener.Setup(d => d.CloseAsync()).Returns(Task.CompletedTask);
deviceListener.Setup(d => d.BindDeviceProxy(It.IsAny<IDeviceProxy>()));

var identity = Mock.Of<IIdentity>(i => i.Id == "d1/m1");
var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener.Object));
deviceListener.Setup(d => d.BindDeviceProxy(It.IsAny<IDeviceProxy>()));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

var eventsLinkHandler = Mock.Of<ILinkHandler>(l => l.Type == LinkType.Events);
string twinCorrelationId = Guid.NewGuid().ToString();
Expand Down Expand Up @@ -249,6 +262,70 @@ public async Task CloseOnRemovingAllLinksTest()

// Assert
deviceListener.Verify(d => d.CloseAsync(), Times.Once);
Assert.True(amqpConnection.CloseCalled);

// Act
await connectionHandler.GetDeviceListener();

// Assert
deviceListener.Verify(d => d.BindDeviceProxy(It.IsAny<IDeviceProxy>()), Times.Exactly(2));
}
}

class AmqpTestConnection : AmqpConnectionBase
{
public AmqpTestConnection()
: base("test", new TestTransport(), new AmqpConnectionSettings(), false)
{
}

public bool CloseCalled { get; private set; }

protected override void AbortInternal()
{
}

protected override bool CloseInternal()
{
this.CloseCalled = true;
return true;
}

protected override void OnFrameBuffer(ByteBuffer buffer)
{
}

protected override void OnProtocolHeader(ProtocolHeader header)
{
}

protected override bool OpenInternal() => true;
}

class TestTransport : TransportBase
{
public override string LocalEndPoint => "localhost";
public override string RemoteEndPoint => "remotehost";

public TestTransport()
: base("test")
{
}

public bool CloseCalled { get; private set; }

public override bool ReadAsync(TransportAsyncCallbackArgs args) => true;

public override void SetMonitor(ITransportMonitor usageMeter)
{
}

public override bool WriteAsync(TransportAsyncCallbackArgs args) => false;

protected override void AbortInternal()
{
}

protected override bool CloseInternal() => true;
}
}

0 comments on commit 5be30fb

Please sign in to comment.