From 7ef74bc95d80300609d23a7bc9392df08a0e6032 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 20 Dec 2012 16:53:48 -0800 Subject: [PATCH] Addresses issue #1155 for the forever transports. While PersistentConnection.OnConnectedAsync is running we queue all writes to the response stream to prevent clients from prematurely indicating the connection has started. We keep the subscriber's receive loop running to allow the subscriber to continue processing commands and to prevent deadlocks caused by waiting on ACKs. --- .../Infrastructure/TaskQueue.cs | 11 ++- .../Transports/ForeverFrameTransport.cs | 11 +-- .../Transports/ForeverTransport.cs | 83 +++++++++---------- .../Transports/ServerSentEventsTransport.cs | 17 ++-- .../Transports/TransportDisconnectBase.cs | 14 ++-- .../Connections/PersistentConnectionFacts.cs | 26 ++++++ .../App_Start/RegisterHubs.cs | 1 + .../AddGroupOnConnectedConnection.cs | 21 +++++ .../Infrastructure/MemoryTestHost.cs | 1 + ...crosoft.AspNet.SignalR.Tests.Common.csproj | 1 + 10 files changed, 117 insertions(+), 69 deletions(-) create mode 100644 tests/Microsoft.AspNet.SignalR.Tests.Common/Connections/AddGroupOnConnectedConnection.cs diff --git a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/TaskQueue.cs b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/TaskQueue.cs index cf8f7e60ce..9a4183359b 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/TaskQueue.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/TaskQueue.cs @@ -12,9 +12,18 @@ namespace Microsoft.AspNet.SignalR.Infrastructure internal sealed class TaskQueue { private readonly object _lockObj = new object(); - private Task _lastQueuedTask = TaskAsyncHelper.Empty; + private Task _lastQueuedTask; private volatile bool _drained; + public TaskQueue() : this(TaskAsyncHelper.Empty) + { + } + + public TaskQueue(Task initialTask) + { + _lastQueuedTask = initialTask; + } + [SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode", Justification = "This is shared code")] public bool IsDrained { diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverFrameTransport.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverFrameTransport.cs index 9abd8ab7a3..2fe919ac5d 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverFrameTransport.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverFrameTransport.cs @@ -59,7 +59,7 @@ private HTMLTextWriter HTMLOutputWriter public override Task KeepAlive() { - if (!Initialized) + if (!InitializeTcs.Task.IsCompleted) { return TaskAsyncHelper.Empty; } @@ -97,13 +97,10 @@ protected override Task InitializeResponse(ITransportConnection connection) { Context.Response.ContentType = "text/html"; - return EnqueueOperation(() => - { - HTMLOutputWriter.WriteRaw(initScript); - HTMLOutputWriter.Flush(); + HTMLOutputWriter.WriteRaw(initScript); + HTMLOutputWriter.Flush(); - return Context.Response.FlushAsync(); - }); + return Context.Response.FlushAsync(); }, _initPrefix + Context.Request.QueryString["frameId"] + _initSuffix); } diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverTransport.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverTransport.cs index 171a66cdb9..6be297b302 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverTransport.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/ForeverTransport.cs @@ -14,9 +14,6 @@ public abstract class ForeverTransport : TransportDisconnectBase, ITransport private IJsonSerializer _jsonSerializer; private string _lastMessageId; - // Determines whether the transport has been fully initialized. - private volatile bool _initialized; - private const int MaxMessages = 10; protected ForeverTransport(HostContext context, IDependencyResolver resolver) @@ -37,6 +34,7 @@ protected ForeverTransport(HostContext context, IDependencyResolver resolver) { _jsonSerializer = jsonSerializer; _counters = performanceCounterWriter; + } protected string LastMessageId @@ -57,17 +55,7 @@ protected IJsonSerializer JsonSerializer get { return _jsonSerializer; } } - protected bool Initialized - { - get - { - return _initialized; - } - set - { - _initialized = value; - } - } + protected TaskCompletionSource InitializeTcs { get; set; } protected virtual void OnSending(string payload) { @@ -93,6 +81,16 @@ protected virtual void OnSendingResponse(PersistentResponse response) internal Action BeforeReceive; internal Action AfterRequestEnd; + protected override void InitializePersistentState() + { + // PersistentConnection.OnConnectedAsync must complete before we can write to the output stream, + // so clients don't indicate the connection has started too early. + InitializeTcs = new TaskCompletionSource(); + WriteQueue = new TaskQueue(InitializeTcs.Task); + + base.InitializePersistentState(); + } + protected Task ProcessRequestCore(ITransportConnection connection) { Connection = connection; @@ -168,6 +166,21 @@ protected virtual Task InitializeResponse(ITransportConnection connection) return TaskAsyncHelper.Empty; } + protected internal override Task EnqueueOperation(Func writeAsync) + { + Task task = base.EnqueueOperation(writeAsync); + + // If PersistentConnection.OnConnectedAsync has not completed (as indicated by InitializeTcs), + // the queue will be blocked to prevent clients from prematurely indicating the connection has + // started, but we must keep receive loop running to continue processing commands and to + // prevent deadlocks caused by waiting on ACKs. + if (InitializeTcs == null || InitializeTcs.Task.IsCompleted) + { + return task; + } + return TaskAsyncHelper.Empty; + } + protected void IncrementErrorCounters(Exception exception) { _counters.ErrorsTransportTotal.Increment(); @@ -208,12 +221,7 @@ private Task ProcessReceiveRequestWithoutTracking(ITransportConnection connectio } return TaskAsyncHelper.Empty; }, - () => InitializeResponse(connection), - () => - { - Initialized = true; - return TaskAsyncHelper.Empty; - }); + () => InitializeResponse(connection)); }; return ProcessMessages(connection, afterReceive); @@ -229,7 +237,7 @@ private Task OnTransportConnected() return TaskAsyncHelper.Empty; } - private Task ProcessMessages(ITransportConnection connection, Func postReceive = null) + private Task ProcessMessages(ITransportConnection connection, Func postReceive) { var tcs = new TaskCompletionSource(); @@ -253,7 +261,7 @@ private Task ProcessMessages(ITransportConnection connection, Func postRec CompleteRequest(); Trace.TraceInformation("EndRequest(" + ConnectionId + ")"); - }, + }, TaskContinuationOptions.ExecuteSynchronously); if (AfterRequestEnd != null) @@ -274,8 +282,6 @@ private void ProcessMessages(ITransportConnection connection, Func postRec IDisposable subscription = null; IDisposable registration = null; - var wh = new ManualResetEventSlim(); - if (BeforeReceive != null) { BeforeReceive(); @@ -285,9 +291,6 @@ private void ProcessMessages(ITransportConnection connection, Func postRec { subscription = connection.Receive(LastMessageId, response => { - // We need to wait until post receive has been called - wh.Wait(); - response.TimedOut = IsTimedOut; // If we're telling the client to disconnect then clean up the instantiated connection. @@ -342,9 +345,6 @@ private void ProcessMessages(ITransportConnection connection, Func postRec catch (Exception ex) { endRequest(ex); - - wh.Set(); - return; } @@ -353,20 +353,15 @@ private void ProcessMessages(ITransportConnection connection, Func postRec AfterReceive(); } - if (postReceive != null) - { - postReceive().Catch(_counters.ErrorsAllTotal, _counters.ErrorsAllPerSec) - .Catch(ex => endRequest(ex)) - .Catch(ex => - { - Trace.TraceInformation("Failed post receive for {0} with: {1}", ConnectionId, ex.GetBaseException()); - }) - .ContinueWith(task => wh.Set()); - } - else - { - wh.Set(); - } + + postReceive().Catch(_counters.ErrorsAllTotal, _counters.ErrorsAllPerSec) + .Catch(ex => endRequest(ex)) + .Catch(ex => + { + Trace.TraceInformation("Failed post receive for {0} with: {1}", ConnectionId, ex.GetBaseException()); + }) + .ContinueWith(InitializeTcs); + if (BeforeCancellationTokenCallbackRegistered != null) { diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/ServerSentEventsTransport.cs index 21dbbab1ad..8780f00fcd 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/ServerSentEventsTransport.cs @@ -15,7 +15,7 @@ public ServerSentEventsTransport(HostContext context, IDependencyResolver resolv public override Task KeepAlive() { - if (!Initialized) + if (!InitializeTcs.Task.IsCompleted) { return TaskAsyncHelper.Empty; } @@ -54,16 +54,13 @@ protected override Task InitializeResponse(ITransportConnection connection) { Context.Response.ContentType = "text/event-stream"; - return EnqueueOperation(() => - { - // "data: initialized\n\n" - OutputWriter.Write("data: initialized"); - OutputWriter.WriteLine(); - OutputWriter.WriteLine(); - OutputWriter.Flush(); + // "data: initialized\n\n" + OutputWriter.Write("data: initialized"); + OutputWriter.WriteLine(); + OutputWriter.WriteLine(); + OutputWriter.Flush(); - return Context.Response.FlushAsync(); - }); + return Context.Response.FlushAsync(); }); } } diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs index 804762a8d0..2e7b88ce80 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs @@ -120,7 +120,7 @@ protected TaskCompletionSource Completed internal TaskQueue WriteQueue { get; - private set; + set; } public IEnumerable Groups @@ -287,18 +287,18 @@ public void CompleteRequest() } } - protected internal Task EnqueueOperation(Func writeAsync) + protected virtual internal Task EnqueueOperation(Func writeAsync) { - if (IsAlive) + if (!IsAlive) { - // Only enqueue new writes if the connection is alive - return WriteQueue.Enqueue(writeAsync); + return TaskAsyncHelper.Empty; } - return TaskAsyncHelper.Empty; + // Only enqueue new writes if the connection is alive + return WriteQueue.Enqueue(writeAsync); } - protected void InitializePersistentState() + protected virtual void InitializePersistentState() { _hostShutdownToken = _context.HostShutdownToken(); diff --git a/tests/Microsoft.AspNet.SignalR.FunctionalTests/Server/Connections/PersistentConnectionFacts.cs b/tests/Microsoft.AspNet.SignalR.FunctionalTests/Server/Connections/PersistentConnectionFacts.cs index ed7492d903..1b9089a716 100644 --- a/tests/Microsoft.AspNet.SignalR.FunctionalTests/Server/Connections/PersistentConnectionFacts.cs +++ b/tests/Microsoft.AspNet.SignalR.FunctionalTests/Server/Connections/PersistentConnectionFacts.cs @@ -92,6 +92,32 @@ public void GroupsAreNotReadOnConnectedAsync(HostType hostType, TransportType tr } } + [Theory] + [InlineData(HostType.Memory, TransportType.Auto)] + // [InlineData(HostType.IISExpress, TransportType.Auto)] + public void GroupCanBeAddedAndMessagedOnConnected(HostType hostType, TransportType transportType) + { + using (var host = CreateHost(hostType, transportType)) + { + var wh = new ManualResetEventSlim(); + host.Initialize(); + + var connection = new Client.Connection(host.Url + "/add-group"); + connection.Received += data => + { + Assert.Equal("hey", data); + wh.Set(); + }; + + connection.Start(host.Transport).Wait(); + connection.SendWithTimeout(""); + + Assert.True(wh.Wait(TimeSpan.FromSeconds(5))); + + connection.Stop(); + } + } + [Theory] [InlineData(HostType.Memory, TransportType.ServerSentEvents)] [InlineData(HostType.Memory, TransportType.LongPolling)] diff --git a/tests/Microsoft.AspNet.SignalR.Tests.Common/App_Start/RegisterHubs.cs b/tests/Microsoft.AspNet.SignalR.Tests.Common/App_Start/RegisterHubs.cs index 0cd72e74c1..08e94cfa80 100644 --- a/tests/Microsoft.AspNet.SignalR.Tests.Common/App_Start/RegisterHubs.cs +++ b/tests/Microsoft.AspNet.SignalR.Tests.Common/App_Start/RegisterHubs.cs @@ -65,6 +65,7 @@ public static void Start() RouteTable.Routes.MapConnection("filter", "filter/{*operation}"); RouteTable.Routes.MapConnection("items", "items/{*operation}"); RouteTable.Routes.MapConnection("sync-error", "sync-error/{*operation}"); + RouteTable.Routes.MapConnection("add-group", "add-group/{*operation}"); // End point to hit to verify the webserver is up RouteTable.Routes.Add("test-endpoint", new Route("ping", new TestEndPoint())); diff --git a/tests/Microsoft.AspNet.SignalR.Tests.Common/Connections/AddGroupOnConnectedConnection.cs b/tests/Microsoft.AspNet.SignalR.Tests.Common/Connections/AddGroupOnConnectedConnection.cs new file mode 100644 index 0000000000..8ba12152d3 --- /dev/null +++ b/tests/Microsoft.AspNet.SignalR.Tests.Common/Connections/AddGroupOnConnectedConnection.cs @@ -0,0 +1,21 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.SignalR.FunctionalTests +{ + class AddGroupOnConnectedConnection : PersistentConnection + { + protected override async Task OnConnectedAsync(IRequest request, string connectionId) + { + await Groups.Add(connectionId, "test"); + Thread.Sleep(TimeSpan.FromSeconds(1)); + await Groups.Add(connectionId, "test2"); + } + + protected override async Task OnReceivedAsync(IRequest request, string connectionId, string data) + { + await Groups.Send("test2", "hey"); + } + } +} diff --git a/tests/Microsoft.AspNet.SignalR.Tests.Common/Infrastructure/MemoryTestHost.cs b/tests/Microsoft.AspNet.SignalR.Tests.Common/Infrastructure/MemoryTestHost.cs index 7c6e57d043..598f14230d 100644 --- a/tests/Microsoft.AspNet.SignalR.Tests.Common/Infrastructure/MemoryTestHost.cs +++ b/tests/Microsoft.AspNet.SignalR.Tests.Common/Infrastructure/MemoryTestHost.cs @@ -71,6 +71,7 @@ public string Url _host.MapConnection("/filter"); _host.MapConnection("/sync-error"); _host.MapConnection("/fall-back"); + _host.MapConnection("/add-group"); } public void Dispose() diff --git a/tests/Microsoft.AspNet.SignalR.Tests.Common/Microsoft.AspNet.SignalR.Tests.Common.csproj b/tests/Microsoft.AspNet.SignalR.Tests.Common/Microsoft.AspNet.SignalR.Tests.Common.csproj index bbc9d01190..a0aaf584de 100644 --- a/tests/Microsoft.AspNet.SignalR.Tests.Common/Microsoft.AspNet.SignalR.Tests.Common.csproj +++ b/tests/Microsoft.AspNet.SignalR.Tests.Common/Microsoft.AspNet.SignalR.Tests.Common.csproj @@ -99,6 +99,7 @@ +