Permalink
Browse files

Fixed several issues with ServerSentEventsTransport and Stopping the …

…Connection.

- Associate the lifetime of a connection with a cancellation token.
- Flow that cancellation token through to the transport so that it can be used to
detect when to stop the underlying connection.
  • Loading branch information...
1 parent 1f86f3b commit fb232f5a2b75f1c1d742bd16fe1eee197d49f110 @davidfowl davidfowl committed Jun 8, 2012
@@ -5,6 +5,7 @@
using System.Linq;
using System.Net;
using System.Reflection;
+using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json.Linq;
using SignalR.Client.Http;
@@ -21,6 +22,7 @@ public class Connection : IConnection
private IClientTransport _transport;
private ConnectionState _state;
+ private CancellationTokenSource _cancel;
/// <summary>
/// Occurs when the <see cref="Connection"/> has received data from the server.
@@ -133,6 +135,14 @@ public Connection(string url, string queryString)
/// </summary>
public string QueryString { get; private set; }
+ public CancellationToken CancellationToken
+ {
+ get
+ {
+ return _cancel.Token;
+ }
+ }
+
/// <summary>
/// Gets the current <see cref="ConnectionState"/> of the connection.
/// </summary>
@@ -206,6 +216,7 @@ public virtual Task Start(IClientTransport transport)
}
State = ConnectionState.Connecting;
+ _cancel = new CancellationTokenSource();
_transport = transport;
@@ -260,7 +271,7 @@ private Task Negotiate(IClientTransport transport)
private Task StartTransport(string data)
{
- return _transport.Start(this, data)
+ return _transport.Start(this, _cancel.Token, data)
.Then(() => State = ConnectionState.Connected);
}
@@ -290,6 +301,7 @@ public virtual void Stop()
State = ConnectionState.Disconnecting;
+ _cancel.Cancel(throwOnFirstException: false);
_transport.Stop(this);
if (Closed != null)
@@ -1,4 +1,5 @@
using System.Diagnostics;
+using System.Threading;
using System.Threading.Tasks;
using SignalR.Client.Http;
@@ -25,22 +26,22 @@ public Task<NegotiationResponse> Negotiate(IConnection connection)
return HttpBasedTransport.GetNegotiationResponse(_httpClient, connection);
}
- public Task Start(IConnection connection, string data)
+ public Task Start(IConnection connection, CancellationToken cancellationToken, string data)
{
var tcs = new TaskCompletionSource<object>();
// Resolve the transport
- ResolveTransport(connection, data, tcs, 0);
+ ResolveTransport(connection, cancellationToken, data, tcs, 0);
return tcs.Task;
}
- private void ResolveTransport(IConnection connection, string data, TaskCompletionSource<object> tcs, int index)
+ private void ResolveTransport(IConnection connection, CancellationToken cancellationToken, string data, TaskCompletionSource<object> tcs, int index)
{
// Pick the current transport
IClientTransport transport = _transports[index];
- transport.Start(connection, data).ContinueWith(task =>
+ transport.Start(connection, cancellationToken, data).ContinueWith(task =>
{
if (task.IsFaulted)
{
@@ -56,7 +57,7 @@ private void ResolveTransport(IConnection connection, string data, TaskCompletio
if (next < _transports.Length)
{
// Try the next transport
- ResolveTransport(connection, data, tcs, next);
+ ResolveTransport(connection, cancellationToken, data, tcs, next);
}
else
{
@@ -2,7 +2,8 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
-using System.Net;
+using System.Text;
+using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
@@ -12,9 +13,6 @@ namespace SignalR.Client.Transports
{
public abstract class HttpBasedTransport : IClientTransport
{
- // The receive query string
- private const string _receiveQueryString = "?transport={0}&connectionId={1}&messageId={2}&groups={3}&connectionData={4}{5}";
-
// The send query string
private const string _sendQueryString = "?transport={0}&connectionId={1}{2}";
@@ -31,6 +29,12 @@ public HttpBasedTransport(IHttpClient httpClient, string transport)
_transport = transport;
}
+ public CancellationToken CancellationToken
+ {
+ get;
+ private set;
+ }
+
public Task<NegotiationResponse> Negotiate(IConnection connection)
{
return GetNegotiationResponse(_httpClient, connection);
@@ -53,10 +57,13 @@ internal static Task<NegotiationResponse> GetNegotiationResponse(IHttpClient htt
});
}
- public Task Start(IConnection connection, string data)
+ public Task Start(IConnection connection, CancellationToken cancellationToken, string data)
{
var tcs = new TaskCompletionSource<object>();
+ // Set the cancellation token for this operation
+ CancellationToken = cancellationToken;
+
OnStart(connection, data, () => tcs.TrySetResult(null), exception => tcs.TrySetException(exception));
return tcs.Task;
@@ -90,13 +97,35 @@ public Task<T> Send<T>(IConnection connection, string data)
protected string GetReceiveQueryString(IConnection connection, string data)
{
- return String.Format(_receiveQueryString,
- _transport,
- Uri.EscapeDataString(connection.ConnectionId),
- Convert.ToString(connection.MessageId),
- Uri.EscapeDataString(JsonConvert.SerializeObject(connection.Groups)),
- data,
- GetCustomQueryString(connection));
+ // ?transport={0}&connectionId={1}&messageId={2}&groups={3}&connectionData={4}{5}
+ var qsBuilder = new StringBuilder();
+ qsBuilder.Append("?transport=" + _transport)
+ .Append("&connectionId=" + Uri.EscapeDataString(connection.ConnectionId));
+
+ if (connection.MessageId != null)
+ {
+ qsBuilder.Append("&messageId=" + connection.MessageId);
+ }
+
+ if (connection.Groups != null && connection.Groups.Any())
+ {
+ qsBuilder.Append("&groups=" + Uri.EscapeDataString(JsonConvert.SerializeObject(connection.Groups)));
+ }
+
+ if (data != null)
+ {
+ qsBuilder.Append("&connectionData=" + data);
+ }
+
+ string customQuery = GetCustomQueryString(connection);
+
+ if (!String.IsNullOrEmpty(customQuery))
+ {
+ qsBuilder.Append("&")
+ .Append(customQuery);
+ }
+
+ return qsBuilder.ToString();
}
protected virtual Action<IRequest> PrepareRequest(IConnection connection)
@@ -141,7 +170,7 @@ private void AbortConnection(IConnection connection)
// Attempt to perform a clean disconnect, but only wait 2 seconds
_httpClient.PostAsync(url, connection.PrepareRequest).Wait(TimeSpan.FromSeconds(2));
}
- catch(Exception ex)
+ catch (Exception ex)
{
// Swallow any exceptions, but log them
Debug.WriteLine("Clean disconnect failed. " + ex.Unwrap().Message);
@@ -1,11 +1,12 @@
using System.Threading.Tasks;
+using System.Threading;
namespace SignalR.Client.Transports
{
public interface IClientTransport
{
Task<NegotiationResponse> Negotiate(IConnection connection);
- Task Start(IConnection connection, string data);
+ Task Start(IConnection connection, CancellationToken cancellationToken, string data);
Task<T> Send<T>(IConnection connection, string data);
void Stop(IConnection connection);
}
@@ -124,7 +124,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC
// before polling again so we aren't hammering the server
TaskAsyncHelper.Delay(_errorDelay).Then(() =>
{
- if (!connection.IsDisconnecting())
+ if (!CancellationToken.IsCancellationRequested)
{
PollingLoop(connection,
data,
@@ -138,7 +138,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC
}
else
{
- if (!connection.IsDisconnecting())
+ if (!CancellationToken.IsCancellationRequested)
{
// Continue polling if there was no error
PollingLoop(connection,
@@ -18,6 +18,7 @@ public class EventSourceStreamReader
private int _reading;
private Action _setOpened;
+ private readonly CancellationToken _cancellationToken;
/// <summary>
/// Invoked when the connection is open.
@@ -43,6 +44,7 @@ public class EventSourceStreamReader
///
/// </summary>
/// <param name="stream"></param>
+ /// <param name="cancellationToken"></param>
public EventSourceStreamReader(Stream stream)
{
_stream = stream;
@@ -36,7 +36,7 @@ protected override void OnStart(IConnection connection, string data, Action init
private void Reconnect(IConnection connection, string data)
{
- if (connection.IsDisconnecting())
+ if (CancellationToken.IsCancellationRequested)
{
return;
}
@@ -85,7 +85,7 @@ private void OpenConnection(IConnection connection, string data, Action initiali
}
}
- if (reconnecting)
+ if (reconnecting && !CancellationToken.IsCancellationRequested)
{
connection.State = ConnectionState.Reconnecting;
@@ -96,13 +96,15 @@ private void OpenConnection(IConnection connection, string data, Action initiali
}
else
{
- // Get the reseponse stream and read it for messages
var response = task.Result;
var stream = response.GetResponseStream();
var eventSource = new EventSourceStreamReader(stream);
bool retry = true;
+ // When this fires close the event source
+ CancellationToken.Register(() => eventSource.Close());
+
eventSource.Opened = () =>
{
if (Interlocked.CompareExchange(ref _initializedCalled, 1, 0) == 0)
@@ -146,7 +148,7 @@ private void OpenConnection(IConnection connection, string data, Action initiali
{
response.Close();
- if (retry)
+ if (retry && !CancellationToken.IsCancellationRequested)
{
// If we're retrying then just go again
connection.State = ConnectionState.Reconnecting;
@@ -13,7 +13,7 @@ namespace SignalR.Hosting.Memory
public class MemoryHost : RoutingHost, IHttpClient
{
public MemoryHost()
- : base()
+ : this(new DefaultDependencyResolver())
{
}
@@ -1,4 +1,5 @@
using System;
+using System.Threading;
using Moq;
using SignalR.Client.Transports;
using Xunit;
@@ -52,7 +53,7 @@ public void FailedStartShouldNotBeActive()
ConnectionId = "Something"
}));
- transport.Setup(m => m.Start(connection, null))
+ transport.Setup(m => m.Start(connection, It.IsAny<CancellationToken>(), null))
.Returns(TaskAsyncHelper.FromError(new InvalidOperationException("Something failed.")));
var aggEx = Assert.Throws<AggregateException>(() => connection.Start(transport.Object).Wait());
@@ -24,6 +24,8 @@ public void ReadingState()
var result = hub.Invoke<string>("ReadStateValue").Result;
Assert.Equal("test", result);
+
+ connection.Stop();
}
[Fact]
@@ -40,6 +42,8 @@ public void SettingState()
Assert.Equal("test", result);
Assert.Equal("test", hub["Company"]);
+
+ connection.Stop();
}
[Fact]
@@ -56,6 +60,7 @@ public void GetValueFromServer()
var result = hub.Invoke<int>("GetValue").Result;
Assert.Equal(10, result);
+ connection.Stop();
}
[Fact]
@@ -72,6 +77,7 @@ public void TaskWithException()
var ex = Assert.Throws<AggregateException>(() => hub.Invoke("TaskWithException").Wait());
Assert.Equal("Exception of type 'System.Exception' was thrown.", ex.GetBaseException().Message);
+ connection.Stop();
}
[Fact]
@@ -88,6 +94,7 @@ public void GenericTaskWithException()
var ex = Assert.Throws<AggregateException>(() => hub.Invoke("GenericTaskWithException").Wait());
Assert.Equal("Exception of type 'System.Exception' was thrown.", ex.GetBaseException().Message);
+ connection.Stop();
}
[Fact]
@@ -105,6 +112,7 @@ public void Overloads()
int n = hub.Invoke<int>("Overload", 1).Result;
Assert.Equal(1, n);
+ connection.Stop();
}
[Fact]
@@ -121,6 +129,7 @@ public void UnsupportedOverloads()
var ex = Assert.Throws<InvalidOperationException>(() => hub.Invoke("UnsupportedOverload", 13177).Wait());
Assert.Equal("'UnsupportedOverload' method could not be resolved.", ex.GetBaseException().Message);
+ connection.Stop();
}
[Fact]
@@ -145,6 +154,7 @@ public void ChangeHubUrl()
hub.Invoke("DynamicTask").Wait();
Assert.True(wh.WaitOne(TimeSpan.FromSeconds(5)));
+ connection.Stop();
}
}
}
Oops, something went wrong.

0 comments on commit fb232f5

Please sign in to comment.