Skip to content

Commit

Permalink
Addresses issue #1155 for the forever transports.
Browse files Browse the repository at this point in the history
While PersistentConnection.OnConnectedAsync is running we queue all writes
to the response stream to prevent clients from prematurely indicating the
connection has started.

We keep the subscriber's receive loop running to allow the subscriber to
continue processing commands and to prevent deadlocks caused by waiting on
ACKs.
  • Loading branch information
halter73 authored and davidfowl committed Dec 21, 2012
1 parent 2222082 commit 7ef74bc
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 69 deletions.
11 changes: 10 additions & 1 deletion src/Microsoft.AspNet.SignalR.Core/Infrastructure/TaskQueue.cs
Expand Up @@ -12,9 +12,18 @@ namespace Microsoft.AspNet.SignalR.Infrastructure
internal sealed class TaskQueue
{
private readonly object _lockObj = new object();
private Task _lastQueuedTask = TaskAsyncHelper.Empty;
private Task _lastQueuedTask;
private volatile bool _drained;

public TaskQueue() : this(TaskAsyncHelper.Empty)
{
}

public TaskQueue(Task initialTask)
{
_lastQueuedTask = initialTask;
}

[SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode", Justification = "This is shared code")]
public bool IsDrained
{
Expand Down
Expand Up @@ -59,7 +59,7 @@ private HTMLTextWriter HTMLOutputWriter

public override Task KeepAlive()
{
if (!Initialized)
if (!InitializeTcs.Task.IsCompleted)
{
return TaskAsyncHelper.Empty;
}
Expand Down Expand Up @@ -97,13 +97,10 @@ protected override Task InitializeResponse(ITransportConnection connection)
{
Context.Response.ContentType = "text/html";
return EnqueueOperation(() =>
{
HTMLOutputWriter.WriteRaw(initScript);
HTMLOutputWriter.Flush();
HTMLOutputWriter.WriteRaw(initScript);
HTMLOutputWriter.Flush();
return Context.Response.FlushAsync();
});
return Context.Response.FlushAsync();
},
_initPrefix + Context.Request.QueryString["frameId"] + _initSuffix);
}
Expand Down
83 changes: 39 additions & 44 deletions src/Microsoft.AspNet.SignalR.Core/Transports/ForeverTransport.cs
Expand Up @@ -14,9 +14,6 @@ public abstract class ForeverTransport : TransportDisconnectBase, ITransport
private IJsonSerializer _jsonSerializer;
private string _lastMessageId;

// Determines whether the transport has been fully initialized.
private volatile bool _initialized;

private const int MaxMessages = 10;

protected ForeverTransport(HostContext context, IDependencyResolver resolver)
Expand All @@ -37,6 +34,7 @@ protected ForeverTransport(HostContext context, IDependencyResolver resolver)
{
_jsonSerializer = jsonSerializer;
_counters = performanceCounterWriter;

}

protected string LastMessageId
Expand All @@ -57,17 +55,7 @@ protected IJsonSerializer JsonSerializer
get { return _jsonSerializer; }
}

protected bool Initialized
{
get
{
return _initialized;
}
set
{
_initialized = value;
}
}
protected TaskCompletionSource<object> InitializeTcs { get; set; }

protected virtual void OnSending(string payload)
{
Expand All @@ -93,6 +81,16 @@ protected virtual void OnSendingResponse(PersistentResponse response)
internal Action BeforeReceive;
internal Action AfterRequestEnd;

protected override void InitializePersistentState()
{
// PersistentConnection.OnConnectedAsync must complete before we can write to the output stream,
// so clients don't indicate the connection has started too early.
InitializeTcs = new TaskCompletionSource<object>();
WriteQueue = new TaskQueue(InitializeTcs.Task);

base.InitializePersistentState();
}

protected Task ProcessRequestCore(ITransportConnection connection)
{
Connection = connection;
Expand Down Expand Up @@ -168,6 +166,21 @@ protected virtual Task InitializeResponse(ITransportConnection connection)
return TaskAsyncHelper.Empty;
}

protected internal override Task EnqueueOperation(Func<Task> writeAsync)
{
Task task = base.EnqueueOperation(writeAsync);

// If PersistentConnection.OnConnectedAsync has not completed (as indicated by InitializeTcs),
// the queue will be blocked to prevent clients from prematurely indicating the connection has
// started, but we must keep receive loop running to continue processing commands and to
// prevent deadlocks caused by waiting on ACKs.
if (InitializeTcs == null || InitializeTcs.Task.IsCompleted)
{
return task;
}
return TaskAsyncHelper.Empty;
}

protected void IncrementErrorCounters(Exception exception)
{
_counters.ErrorsTransportTotal.Increment();
Expand Down Expand Up @@ -208,12 +221,7 @@ private Task ProcessReceiveRequestWithoutTracking(ITransportConnection connectio
}
return TaskAsyncHelper.Empty;
},
() => InitializeResponse(connection),
() =>
{
Initialized = true;
return TaskAsyncHelper.Empty;
});
() => InitializeResponse(connection));
};

return ProcessMessages(connection, afterReceive);
Expand All @@ -229,7 +237,7 @@ private Task OnTransportConnected()
return TaskAsyncHelper.Empty;
}

private Task ProcessMessages(ITransportConnection connection, Func<Task> postReceive = null)
private Task ProcessMessages(ITransportConnection connection, Func<Task> postReceive)
{
var tcs = new TaskCompletionSource<object>();

Expand All @@ -253,7 +261,7 @@ private Task ProcessMessages(ITransportConnection connection, Func<Task> postRec
CompleteRequest();
Trace.TraceInformation("EndRequest(" + ConnectionId + ")");
},
},
TaskContinuationOptions.ExecuteSynchronously);
if (AfterRequestEnd != null)
Expand All @@ -274,8 +282,6 @@ private void ProcessMessages(ITransportConnection connection, Func<Task> postRec
IDisposable subscription = null;
IDisposable registration = null;

var wh = new ManualResetEventSlim();

if (BeforeReceive != null)
{
BeforeReceive();
Expand All @@ -285,9 +291,6 @@ private void ProcessMessages(ITransportConnection connection, Func<Task> postRec
{
subscription = connection.Receive(LastMessageId, response =>
{
// We need to wait until post receive has been called
wh.Wait();
response.TimedOut = IsTimedOut;
// If we're telling the client to disconnect then clean up the instantiated connection.
Expand Down Expand Up @@ -342,9 +345,6 @@ private void ProcessMessages(ITransportConnection connection, Func<Task> postRec
catch (Exception ex)
{
endRequest(ex);

wh.Set();

return;
}

Expand All @@ -353,20 +353,15 @@ private void ProcessMessages(ITransportConnection connection, Func<Task> postRec
AfterReceive();
}

if (postReceive != null)
{
postReceive().Catch(_counters.ErrorsAllTotal, _counters.ErrorsAllPerSec)
.Catch(ex => endRequest(ex))
.Catch(ex =>
{
Trace.TraceInformation("Failed post receive for {0} with: {1}", ConnectionId, ex.GetBaseException());
})
.ContinueWith(task => wh.Set());
}
else
{
wh.Set();
}

postReceive().Catch(_counters.ErrorsAllTotal, _counters.ErrorsAllPerSec)
.Catch(ex => endRequest(ex))
.Catch(ex =>
{
Trace.TraceInformation("Failed post receive for {0} with: {1}", ConnectionId, ex.GetBaseException());
})
.ContinueWith(InitializeTcs);


if (BeforeCancellationTokenCallbackRegistered != null)
{
Expand Down
Expand Up @@ -15,7 +15,7 @@ public ServerSentEventsTransport(HostContext context, IDependencyResolver resolv

public override Task KeepAlive()
{
if (!Initialized)
if (!InitializeTcs.Task.IsCompleted)
{
return TaskAsyncHelper.Empty;
}
Expand Down Expand Up @@ -54,16 +54,13 @@ protected override Task InitializeResponse(ITransportConnection connection)
{
Context.Response.ContentType = "text/event-stream";
return EnqueueOperation(() =>
{
// "data: initialized\n\n"
OutputWriter.Write("data: initialized");
OutputWriter.WriteLine();
OutputWriter.WriteLine();
OutputWriter.Flush();
// "data: initialized\n\n"
OutputWriter.Write("data: initialized");
OutputWriter.WriteLine();
OutputWriter.WriteLine();
OutputWriter.Flush();
return Context.Response.FlushAsync();
});
return Context.Response.FlushAsync();
});
}
}
Expand Down
Expand Up @@ -120,7 +120,7 @@ protected TaskCompletionSource<object> Completed
internal TaskQueue WriteQueue
{
get;
private set;
set;
}

public IEnumerable<string> Groups
Expand Down Expand Up @@ -287,18 +287,18 @@ public void CompleteRequest()
}
}

protected internal Task EnqueueOperation(Func<Task> writeAsync)
protected virtual internal Task EnqueueOperation(Func<Task> writeAsync)
{
if (IsAlive)
if (!IsAlive)
{
// Only enqueue new writes if the connection is alive
return WriteQueue.Enqueue(writeAsync);
return TaskAsyncHelper.Empty;
}

return TaskAsyncHelper.Empty;
// Only enqueue new writes if the connection is alive
return WriteQueue.Enqueue(writeAsync);
}

protected void InitializePersistentState()
protected virtual void InitializePersistentState()
{
_hostShutdownToken = _context.HostShutdownToken();

Expand Down
Expand Up @@ -92,6 +92,32 @@ public void GroupsAreNotReadOnConnectedAsync(HostType hostType, TransportType tr
}
}

[Theory]
[InlineData(HostType.Memory, TransportType.Auto)]
// [InlineData(HostType.IISExpress, TransportType.Auto)]
public void GroupCanBeAddedAndMessagedOnConnected(HostType hostType, TransportType transportType)
{
using (var host = CreateHost(hostType, transportType))
{
var wh = new ManualResetEventSlim();
host.Initialize();

var connection = new Client.Connection(host.Url + "/add-group");
connection.Received += data =>
{
Assert.Equal("hey", data);
wh.Set();
};

connection.Start(host.Transport).Wait();
connection.SendWithTimeout("");

Assert.True(wh.Wait(TimeSpan.FromSeconds(5)));

connection.Stop();
}
}

[Theory]
[InlineData(HostType.Memory, TransportType.ServerSentEvents)]
[InlineData(HostType.Memory, TransportType.LongPolling)]
Expand Down
Expand Up @@ -65,6 +65,7 @@ public static void Start()
RouteTable.Routes.MapConnection<FilteredConnection>("filter", "filter/{*operation}");
RouteTable.Routes.MapConnection<ConnectionThatUsesItems>("items", "items/{*operation}");
RouteTable.Routes.MapConnection<SyncErrorConnection>("sync-error", "sync-error/{*operation}");
RouteTable.Routes.MapConnection<AddGroupOnConnectedConnection>("add-group", "add-group/{*operation}");

// End point to hit to verify the webserver is up
RouteTable.Routes.Add("test-endpoint", new Route("ping", new TestEndPoint()));
Expand Down
@@ -0,0 +1,21 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.AspNet.SignalR.FunctionalTests
{
class AddGroupOnConnectedConnection : PersistentConnection
{
protected override async Task OnConnectedAsync(IRequest request, string connectionId)
{
await Groups.Add(connectionId, "test");
Thread.Sleep(TimeSpan.FromSeconds(1));
await Groups.Add(connectionId, "test2");
}

protected override async Task OnReceivedAsync(IRequest request, string connectionId, string data)
{
await Groups.Send("test2", "hey");
}
}
}
Expand Up @@ -71,6 +71,7 @@ public string Url
_host.MapConnection<FilteredConnection>("/filter");
_host.MapConnection<SyncErrorConnection>("/sync-error");
_host.MapConnection<FallbackToLongPollingConnection>("/fall-back");
_host.MapConnection<AddGroupOnConnectedConnection>("/add-group");
}

public void Dispose()
Expand Down
Expand Up @@ -99,6 +99,7 @@
<Compile Include="App_Start\RegisterHubs.cs" />
<Compile Include="App_Start\TestEndPoint.cs" />
<Compile Include="Build\StartIISTask.cs" />
<Compile Include="Connections\AddGroupOnConnectedConnection.cs" />
<Compile Include="Connections\ConnectionThatUsesItems.cs" />
<Compile Include="Connections\FallbackToLongPollingConnection.cs" />
<Compile Include="Connections\FilteredConnection.cs" />
Expand Down

0 comments on commit 7ef74bc

Please sign in to comment.