diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 64ef90b2b..2783b631d 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -56,6 +55,8 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta private volatile bool _hasClients; private volatile bool _terminated = false; + private volatile IServiceConnection _inUseConnection; + protected ILogger Logger { get; } protected List ServiceConnections @@ -140,7 +141,7 @@ private set ConnectionStatusChanged += OnStatusChanged; _statusPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.ServiceStatus, WriteServiceStatusPingAsync, Constants.Periods.DefaultStatusPingInterval, Constants.Periods.DefaultStatusPingInterval); - + // when server connection count is specified to 0, the app server only handle negotiate requests if (initial.Count > 0) { @@ -231,7 +232,7 @@ public void HandleAck(AckMessage ackMessage) public virtual Task WriteAsync(ServiceMessage serviceMessage) { - return WriteToScopedOrRandomAvailableConnection(serviceMessage); + return WriteToScopedOrFixedAvailableConnection(serviceMessage); } public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) @@ -248,7 +249,7 @@ public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage, // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout). // Therefore sending them over different connections creates a possibility for processing them out of original order. // By sending both message types over the same connection we ensure that they are sent (and processed) in their original order. - await WriteToScopedOrRandomAvailableConnection(serviceMessage); + await WriteToScopedOrFixedAvailableConnection(serviceMessage); var status = await task; switch (status) @@ -484,7 +485,7 @@ private void OnConnectionStatusChanged(StatusChange obj) } } - private async Task WriteToScopedOrRandomAvailableConnection(ServiceMessage serviceMessage) + private async Task WriteToScopedOrFixedAvailableConnection(ServiceMessage serviceMessage) { // ServiceConnections can change the collection underneath so we make a local copy and pass it along var currentConnections = ServiceConnections; @@ -498,7 +499,7 @@ private async Task WriteToScopedOrRandomAvailableConnection(ServiceMessage servi IServiceConnection connection = null; connectionWeakReference?.TryGetTarget(out connection); - var connectionUsed = await WriteWithRetry(serviceMessage, connection, currentConnections); + var connectionUsed = await WriteWithRetry(serviceMessage, connection, currentConnections, IterateConnectionsInRandomOrder); // Todo: // There is currently no synchronization when persisting selected connection in ClientConnectionScope. @@ -515,22 +516,21 @@ private async Task WriteToScopedOrRandomAvailableConnection(ServiceMessage servi } else { - await WriteWithRetry(serviceMessage, null, currentConnections); + var connectionUsed = await WriteWithRetry(serviceMessage, _inUseConnection, currentConnections, IterateConnectionsInFixedOrder); + // Similarly, here is currently no synchronization when persisting selected connection in _inUseConnection. + + if (connectionUsed != _inUseConnection) + { + _inUseConnection = connectionUsed; + } } } - private async Task WriteWithRetry(ServiceMessage serviceMessage, IServiceConnection connection, List currentConnections) + private async Task WriteWithRetry(ServiceMessage serviceMessage, IServiceConnection connection, List currentConnections, Func, IEnumerable> iterateConnections) { // go through all the connections, it can be useful when one of the remote service instances is down - var count = currentConnections.Count; - var initial = StaticRandom.Next(-count, count); - var maxRetry = count; - var retry = 0; - var index = (initial & int.MaxValue) % count; - var direction = initial > 0 ? 1 : count - 1; - - // ensure a full sweep starting with the connection flowed with the async context - while (retry <= maxRetry) + IEnumerator iterator = null; + while (true) { if (connection != null && connection.Status == ServiceConnectionStatus.Connected) { @@ -542,23 +542,33 @@ private async Task WriteWithRetry(ServiceMessage serviceMess } catch (ServiceConnectionNotActiveException) { - if (retry == maxRetry - 1) - { - throw; - } } } + iterator ??= iterateConnections(currentConnections).GetEnumerator(); + connection = iterator.MoveNext() ? iterator.Current : throw new ServiceConnectionNotActiveException(); + } + } - // try current index instead - connection = currentConnections[index]; + private static IEnumerable IterateConnectionsInRandomOrder(List connections) + { + var count = connections.Count; + var initial = StaticRandom.Next(-count, count); + var maxRetry = count; + var retry = 0; + var index = (initial & int.MaxValue) % count; + var direction = initial > 0 ? 1 : count - 1; + + while (retry <= maxRetry) + { + yield return connections[index]; retry++; index = (index + direction) % count; } - - throw new ServiceConnectionNotActiveException(); } + private static IEnumerable IterateConnectionsInFixedOrder(List connections) => connections; + private IEnumerable CreateFixedServiceConnection(int count) { for (int i = 0; i < count; i++) diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/MessageOrderTest.cs b/test/Microsoft.Azure.SignalR.Management.Tests/MessageOrderTest.cs new file mode 100644 index 000000000..07597a270 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Management.Tests/MessageOrderTest.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests.Common; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Azure.SignalR.Management.Tests; +public class MessageOrderTest +{ + private readonly ITestOutputHelper _output; + + public MessageOrderTest(ITestOutputHelper output) + { + _output = output; + } + + /// + /// First, set up a client to send a fixed number of messages. Now the first service connection is used to send messages. + /// Then in the middle of sending, disconnect one of the service connections and the second service connections should be used to send message. + /// + [Fact] + public async Task TestMessageOrderWithSequentialSending() + { + async Task testAction(ServiceHubContext hubContext, TestServiceConnectionFactory testConnectionFactory) + { + var sendTask = SendingTask(hubContext); + + await Task.Delay(7 * 100); + foreach (var connections in testConnectionFactory.CreatedConnections.Values) + { + (connections.First() as TestServiceConnection).SetStatus(ServiceConnectionStatus.Disconnected); + } + + await sendTask; + + foreach (var connections in testConnectionFactory.CreatedConnections.Values) + { + var expectedIndex = 0; + + foreach (var message in (connections[0] as TestServiceConnection).ReceivedMessages) + { + Assert.Equal(expectedIndex.ToString(), (message as BroadcastDataMessage).ExcludedList.Single()); + expectedIndex++; + } + + Assert.True(21 > expectedIndex); + + foreach (var message in (connections[1] as TestServiceConnection).ReceivedMessages) + { + Assert.Equal(expectedIndex.ToString(), (message as BroadcastDataMessage).ExcludedList.Single()); + expectedIndex++; + } + Assert.Equal(21, expectedIndex); + } + } + await MockConnectionTestAsync(testAction); + } + + + private static async Task SendingTask(ServiceHubContext hubContext) + { + for (var i = 0; i < 21; i++) + { + await hubContext.Clients.AllExcept(new string[] { i.ToString() }).SendAsync("Send"); + await Task.Delay(300); + } + } + + private async Task MockConnectionTestAsync(Func testAction) + { + var connectionFactory = new TestServiceConnectionFactory(); + var serviceManager = new ServiceManagerBuilder() + .WithOptions(o => + { + o.ServiceTransportType = ServiceTransportType.Persistent; + o.ServiceEndpoints = FakeEndpointUtils.GetFakeEndpoint(2).ToArray(); + o.ConnectionCount = 3; + }) + .WithLoggerFactory(new LoggerFactory().AddXunit(_output)) + .ConfigureServices(services => services.AddSingleton(connectionFactory)) + .BuildServiceManager(); + var hubContext = await serviceManager.CreateHubContextAsync("hub1", default); + + await testAction.Invoke(hubContext, connectionFactory); + } +}