diff --git a/SignalR.Client/Connection.cs b/SignalR.Client/Connection.cs index 311d48a89f..0c91d8f0eb 100644 --- a/SignalR.Client/Connection.cs +++ b/SignalR.Client/Connection.cs @@ -5,7 +5,6 @@ using System.Linq; using System.Net; using System.Reflection; -using System.Threading; using System.Threading.Tasks; using Newtonsoft.Json.Linq; using SignalR.Client.Http; @@ -21,8 +20,12 @@ public class Connection : IConnection private static Version _assemblyVersion; private IClientTransport _transport; - private ConnectionState _state; - private CancellationTokenSource _cancel; + + // The default connection state is disconnected + private ConnectionState _state = ConnectionState.Disconnected; + + // Used to synchornize state changes + private readonly object _stateLock = new object(); /// /// Occurs when the has received data from the server. @@ -135,14 +138,6 @@ public Connection(string url, string queryString) /// public string QueryString { get; private set; } - public CancellationToken CancellationToken - { - get - { - return _cancel.Token; - } - } - /// /// Gets the current of the connection. /// @@ -150,7 +145,10 @@ public ConnectionState State { get { - return _state; + lock (_stateLock) + { + return _state; + } } private set { @@ -166,18 +164,6 @@ private set } } - ConnectionState IConnection.State - { - get - { - return State; - } - set - { - State = value; - } - } - /// /// Starts the . /// @@ -209,15 +195,11 @@ public Task Start(IHttpClient httpClient) /// A task that represents when the connection has started. public virtual Task Start(IClientTransport transport) { - if (State == ConnectionState.Connected || - State == ConnectionState.Connecting) + if (!ChangeState(ConnectionState.Disconnected, ConnectionState.Connecting)) { return TaskAsyncHelper.Empty; } - State = ConnectionState.Connecting; - _cancel = new CancellationTokenSource(); - _transport = transport; return Negotiate(transport); @@ -271,8 +253,29 @@ private Task Negotiate(IClientTransport transport) private Task StartTransport(string data) { - return _transport.Start(this, _cancel.Token, data) - .Then(() => State = ConnectionState.Connected); + return _transport.Start(this, data) + .Then(() => + { + ChangeState(ConnectionState.Connecting, ConnectionState.Connected); + }); + } + + private bool ChangeState(ConnectionState oldState, ConnectionState newState) + { + return ((IConnection)this).ChangeState(oldState, newState); + } + + bool IConnection.ChangeState(ConnectionState oldState, ConnectionState newState) + { + lock (_stateLock) + { + if (_state == oldState) + { + State = newState; + return true; + } + return false; + } } private static void VerifyProtocolVersion(string versionString) @@ -294,14 +297,11 @@ public virtual void Stop() try { // Do nothing if the connection is offline - if (this.IsDisconnecting()) + if (State == ConnectionState.Disconnected) { return; } - State = ConnectionState.Disconnecting; - - _cancel.Cancel(throwOnFirstException: false); _transport.Stop(this); if (Closed != null) @@ -327,7 +327,7 @@ public Task Send(string data) Task IConnection.Send(string data) { - if (this.IsDisconnecting()) + if (State == ConnectionState.Disconnected) { // TODO: Update this error message throw new InvalidOperationException("Start must be called before data can be sent"); diff --git a/SignalR.Client/ConnectionExtensions.cs b/SignalR.Client/ConnectionExtensions.cs index 8e706878cd..50e6ed688c 100644 --- a/SignalR.Client/ConnectionExtensions.cs +++ b/SignalR.Client/ConnectionExtensions.cs @@ -16,18 +16,6 @@ public static T GetValue(this IConnection connection, string key) return default(T); } - public static bool IsDisconnecting(this IConnection connection) - { - return connection.State == ConnectionState.Disconnecting || - connection.State == ConnectionState.Disconnected; - } - - public static bool IsActive(this IConnection connection) - { - return connection.State == ConnectionState.Connected || - connection.State == ConnectionState.Connecting; - } - #if !WINDOWS_PHONE && !SILVERLIGHT && !NET35 public static IObservable AsObservable(this Connection connection) { diff --git a/SignalR.Client/ConnectionState.cs b/SignalR.Client/ConnectionState.cs index 73177e8fdb..99028df137 100644 --- a/SignalR.Client/ConnectionState.cs +++ b/SignalR.Client/ConnectionState.cs @@ -5,7 +5,6 @@ public enum ConnectionState Connecting, Connected, Reconnecting, - Disconnecting, Disconnected } } diff --git a/SignalR.Client/IConnection.cs b/SignalR.Client/IConnection.cs index 6804e59eac..2f1e0a8d29 100644 --- a/SignalR.Client/IConnection.cs +++ b/SignalR.Client/IConnection.cs @@ -16,7 +16,9 @@ public interface IConnection string ConnectionId { get; } string Url { get; } string QueryString { get; } - ConnectionState State { get; set; } + ConnectionState State { get; } + + bool ChangeState(ConnectionState oldState, ConnectionState newState); ICredentials Credentials { get; set; } CookieContainer CookieContainer { get; set; } diff --git a/SignalR.Client/Transports/AutoTransport.cs b/SignalR.Client/Transports/AutoTransport.cs index 87f4452749..2986d9d552 100644 --- a/SignalR.Client/Transports/AutoTransport.cs +++ b/SignalR.Client/Transports/AutoTransport.cs @@ -26,22 +26,22 @@ public Task Negotiate(IConnection connection) return HttpBasedTransport.GetNegotiationResponse(_httpClient, connection); } - public Task Start(IConnection connection, CancellationToken cancellationToken, string data) + public Task Start(IConnection connection, string data) { var tcs = new TaskCompletionSource(); // Resolve the transport - ResolveTransport(connection, cancellationToken, data, tcs, 0); + ResolveTransport(connection, data, tcs, 0); return tcs.Task; } - private void ResolveTransport(IConnection connection, CancellationToken cancellationToken, string data, TaskCompletionSource tcs, int index) + private void ResolveTransport(IConnection connection, string data, TaskCompletionSource tcs, int index) { // Pick the current transport IClientTransport transport = _transports[index]; - transport.Start(connection, cancellationToken, data).ContinueWith(task => + transport.Start(connection, data).ContinueWith(task => { if (task.IsFaulted) { @@ -61,7 +61,7 @@ private void ResolveTransport(IConnection connection, CancellationToken cancella if (next < _transports.Length) { // Try the next transport - ResolveTransport(connection, cancellationToken, data, tcs, next); + ResolveTransport(connection, data, tcs, next); } else { diff --git a/SignalR.Client/Transports/HttpBasedTransport.cs b/SignalR.Client/Transports/HttpBasedTransport.cs index 86fb93abd0..79bb3717fe 100644 --- a/SignalR.Client/Transports/HttpBasedTransport.cs +++ b/SignalR.Client/Transports/HttpBasedTransport.cs @@ -11,261 +11,252 @@ namespace SignalR.Client.Transports { - public abstract class HttpBasedTransport : IClientTransport - { - // The send query string - private const string _sendQueryString = "?transport={0}&connectionId={1}{2}"; - - // The transport name - protected readonly string _transport; + public abstract class HttpBasedTransport : IClientTransport + { + // The send query string + private const string _sendQueryString = "?transport={0}&connectionId={1}{2}"; - protected const string HttpRequestKey = "http.Request"; + // The transport name + protected readonly string _transport; - protected readonly IHttpClient _httpClient; - - public HttpBasedTransport(IHttpClient httpClient, string transport) - { - _httpClient = httpClient; - _transport = transport; - } + protected const string HttpRequestKey = "http.Request"; - public CancellationToken CancellationToken - { - get; - private set; - } - - public Task Negotiate(IConnection connection) - { - return GetNegotiationResponse(_httpClient, connection); - } - - internal static Task GetNegotiationResponse(IHttpClient httpClient, IConnection connection) - { - string negotiateUrl = connection.Url + "negotiate"; - - return httpClient.GetAsync(negotiateUrl, connection.PrepareRequest).Then(response => - { - string raw = response.ReadAsString(); - - if (raw == null) - { - throw new InvalidOperationException("Server negotiation failed."); - } - - return JsonConvert.DeserializeObject(raw); - }); - } - - public Task Start(IConnection connection, CancellationToken cancellationToken, string data) - { - var tcs = new TaskCompletionSource(); - - // Set the cancellation token for this operation - CancellationToken = cancellationToken; - - OnStart(connection, data, () => tcs.TrySetResult(null), exception => tcs.TrySetException(exception)); - - return tcs.Task; - } - - protected abstract void OnStart(IConnection connection, string data, Action initializeCallback, Action errorCallback); - - public Task Send(IConnection connection, string data) - { - string url = connection.Url + "send"; - string customQueryString = GetCustomQueryString(connection); - - url += String.Format(_sendQueryString, _transport, connection.ConnectionId, customQueryString); - - var postData = new Dictionary { - { "data", data } - }; - - return _httpClient.PostAsync(url, connection.PrepareRequest, postData).Then(response => - { - string raw = response.ReadAsString(); - - if (String.IsNullOrEmpty(raw)) - { - return default(T); - } - - return JsonConvert.DeserializeObject(raw); - }); - } - - protected string GetReceiveQueryString(IConnection connection, string data) - { - // ?transport={0}&connectionId={1}&messageId={2}&groups={3}&connectionData={4}{5} - var qsBuilder = new StringBuilder(); - qsBuilder.Append("?transport=" + _transport) - .Append("&connectionId=" + Uri.EscapeDataString(connection.ConnectionId)); - - if (connection.MessageId != null) - { - qsBuilder.Append("&messageId=" + connection.MessageId); - } - - if (connection.Groups != null && connection.Groups.Any()) - { - qsBuilder.Append("&groups=" + Uri.EscapeDataString(JsonConvert.SerializeObject(connection.Groups))); - } - - if (data != null) - { - qsBuilder.Append("&connectionData=" + data); - } - - string customQuery = GetCustomQueryString(connection); - - if (!String.IsNullOrEmpty(customQuery)) - { - qsBuilder.Append("&") - .Append(customQuery); - } - - return qsBuilder.ToString(); - } - - protected virtual Action PrepareRequest(IConnection connection) - { - return request => - { - // Setup the user agent along with any other defaults - connection.PrepareRequest(request); - - connection.Items[HttpRequestKey] = request; - }; - } - - public void Stop(IConnection connection) - { - var httpRequest = connection.GetValue(HttpRequestKey); - if (httpRequest != null) - { - try - { - OnBeforeAbort(connection); - - // Abort the server side connection - AbortConnection(connection); - - // Now abort the client connection - httpRequest.Abort(); - } - catch (NotImplementedException) - { - // If this isn't implemented then do nothing - } - } - } - - private void AbortConnection(IConnection connection) - { - string url = connection.Url + "abort" + String.Format(_sendQueryString, _transport, connection.ConnectionId, null); - - try - { - // Attempt to perform a clean disconnect, but only wait 2 seconds - _httpClient.PostAsync(url, connection.PrepareRequest).Wait(TimeSpan.FromSeconds(2)); - } - catch (Exception ex) - { - // Swallow any exceptions, but log them - Debug.WriteLine("Clean disconnect failed. " + ex.Unwrap().Message); - } - } - - - protected virtual void OnBeforeAbort(IConnection connection) - { - - } - - protected static void ProcessResponse(IConnection connection, string response, out bool timedOut, out bool disconnected) - { - timedOut = false; - disconnected = false; - - if (String.IsNullOrEmpty(response)) - { - return; - } - - if (connection.MessageId == null) - { - connection.MessageId = 0; - } - - try - { - var result = JValue.Parse(response); - - if (!result.HasValues) - { - return; - } - - timedOut = result.Value("TimedOut"); - disconnected = result.Value("Disconnect"); - - if (disconnected) - { - return; - } - - var messages = result["Messages"] as JArray; - if (messages != null) - { - foreach (JToken message in messages) - { - try - { - connection.OnReceived(message); - } - catch (Exception ex) - { + protected readonly IHttpClient _httpClient; + + public HttpBasedTransport(IHttpClient httpClient, string transport) + { + _httpClient = httpClient; + _transport = transport; + } + + public Task Negotiate(IConnection connection) + { + return GetNegotiationResponse(_httpClient, connection); + } + + internal static Task GetNegotiationResponse(IHttpClient httpClient, IConnection connection) + { + string negotiateUrl = connection.Url + "negotiate"; + + return httpClient.GetAsync(negotiateUrl, connection.PrepareRequest).Then(response => + { + string raw = response.ReadAsString(); + + if (raw == null) + { + throw new InvalidOperationException("Server negotiation failed."); + } + + return JsonConvert.DeserializeObject(raw); + }); + } + + public Task Start(IConnection connection, string data) + { + var tcs = new TaskCompletionSource(); + + OnStart(connection, data, () => tcs.TrySetResult(null), exception => tcs.TrySetException(exception)); + + return tcs.Task; + } + + protected abstract void OnStart(IConnection connection, string data, Action initializeCallback, Action errorCallback); + + public Task Send(IConnection connection, string data) + { + string url = connection.Url + "send"; + string customQueryString = GetCustomQueryString(connection); + + url += String.Format(_sendQueryString, _transport, connection.ConnectionId, customQueryString); + + var postData = new Dictionary { + { "data", data } + }; + + return _httpClient.PostAsync(url, connection.PrepareRequest, postData).Then(response => + { + string raw = response.ReadAsString(); + + if (String.IsNullOrEmpty(raw)) + { + return default(T); + } + + return JsonConvert.DeserializeObject(raw); + }); + } + + protected string GetReceiveQueryString(IConnection connection, string data) + { + // ?transport={0}&connectionId={1}&messageId={2}&groups={3}&connectionData={4}{5} + var qsBuilder = new StringBuilder(); + qsBuilder.Append("?transport=" + _transport) + .Append("&connectionId=" + Uri.EscapeDataString(connection.ConnectionId)); + + if (connection.MessageId != null) + { + qsBuilder.Append("&messageId=" + connection.MessageId); + } + + if (connection.Groups != null && connection.Groups.Any()) + { + qsBuilder.Append("&groups=" + Uri.EscapeDataString(JsonConvert.SerializeObject(connection.Groups))); + } + + if (data != null) + { + qsBuilder.Append("&connectionData=" + data); + } + + string customQuery = GetCustomQueryString(connection); + + if (!String.IsNullOrEmpty(customQuery)) + { + qsBuilder.Append("&") + .Append(customQuery); + } + + return qsBuilder.ToString(); + } + + protected virtual Action PrepareRequest(IConnection connection) + { + return request => + { + // Setup the user agent along with any other defaults + connection.PrepareRequest(request); + + connection.Items[HttpRequestKey] = request; + }; + } + + public void Stop(IConnection connection) + { + var httpRequest = connection.GetValue(HttpRequestKey); + if (httpRequest != null) + { + try + { + OnBeforeAbort(connection); + + // Abort the server side connection + AbortConnection(connection); + + // Now abort the client connection + httpRequest.Abort(); + } + catch (NotImplementedException) + { + // If this isn't implemented then do nothing + } + } + } + + private void AbortConnection(IConnection connection) + { + string url = connection.Url + "abort" + String.Format(_sendQueryString, _transport, connection.ConnectionId, null); + + try + { + // Attempt to perform a clean disconnect, but only wait 2 seconds + _httpClient.PostAsync(url, connection.PrepareRequest).Wait(TimeSpan.FromSeconds(2)); + } + catch (Exception ex) + { + // Swallow any exceptions, but log them + Debug.WriteLine("Clean disconnect failed. " + ex.Unwrap().Message); + } + } + + + protected virtual void OnBeforeAbort(IConnection connection) + { + + } + + protected static void ProcessResponse(IConnection connection, string response, out bool timedOut, out bool disconnected) + { + timedOut = false; + disconnected = false; + + if (String.IsNullOrEmpty(response)) + { + return; + } + + if (connection.MessageId == null) + { + connection.MessageId = 0; + } + + try + { + var result = JValue.Parse(response); + + if (!result.HasValues) + { + return; + } + + timedOut = result.Value("TimedOut"); + disconnected = result.Value("Disconnect"); + + if (disconnected) + { + return; + } + + var messages = result["Messages"] as JArray; + if (messages != null) + { + foreach (JToken message in messages) + { + try + { + connection.OnReceived(message); + } + catch (Exception ex) + { #if NET35 - Debug.WriteLine(String.Format(System.Globalization.CultureInfo.InvariantCulture, "Failed to process message: {0}", ex)); + Debug.WriteLine(String.Format(System.Globalization.CultureInfo.InvariantCulture, "Failed to process message: {0}", ex)); #else - Debug.WriteLine("Failed to process message: {0}", ex); + Debug.WriteLine("Failed to process message: {0}", ex); #endif - connection.OnError(ex); - } - } - - connection.MessageId = result["MessageId"].Value(); - - var transportData = result["TransportData"] as JObject; - - if (transportData != null) - { - var groups = (JArray)transportData["Groups"]; - if (groups != null) - { - connection.Groups = groups.Select(token => token.Value()); - } - } - } - } - catch (Exception ex) - { + connection.OnError(ex); + } + } + + connection.MessageId = result["MessageId"].Value(); + + var transportData = result["TransportData"] as JObject; + + if (transportData != null) + { + var groups = (JArray)transportData["Groups"]; + if (groups != null) + { + connection.Groups = groups.Select(token => token.Value()); + } + } + } + } + catch (Exception ex) + { #if NET35 - Debug.WriteLine(String.Format(System.Globalization.CultureInfo.InvariantCulture, "Failed to response: {0}", ex)); + Debug.WriteLine(String.Format(System.Globalization.CultureInfo.InvariantCulture, "Failed to response: {0}", ex)); #else - Debug.WriteLine("Failed to response: {0}", ex); + Debug.WriteLine("Failed to response: {0}", ex); #endif - connection.OnError(ex); - } - } - - private static string GetCustomQueryString(IConnection connection) - { - return String.IsNullOrEmpty(connection.QueryString) - ? "" - : "&" + connection.QueryString; - } - } + connection.OnError(ex); + } + } + + private static string GetCustomQueryString(IConnection connection) + { + return String.IsNullOrEmpty(connection.QueryString) + ? "" + : "&" + connection.QueryString; + } + } } diff --git a/SignalR.Client/Transports/IClientTransport.cs b/SignalR.Client/Transports/IClientTransport.cs index 89a107a86a..d41f437a3b 100644 --- a/SignalR.Client/Transports/IClientTransport.cs +++ b/SignalR.Client/Transports/IClientTransport.cs @@ -6,7 +6,7 @@ namespace SignalR.Client.Transports public interface IClientTransport { Task Negotiate(IConnection connection); - Task Start(IConnection connection, CancellationToken cancellationToken, string data); + Task Start(IConnection connection, string data); Task Send(IConnection connection, string data); void Stop(IConnection connection); } diff --git a/SignalR.Client/Transports/LongPollingTransport.cs b/SignalR.Client/Transports/LongPollingTransport.cs index e34f5cb41e..3eeebbdfec 100644 --- a/SignalR.Client/Transports/LongPollingTransport.cs +++ b/SignalR.Client/Transports/LongPollingTransport.cs @@ -32,8 +32,6 @@ protected override void OnStart(IConnection connection, string data, Action init private void PollingLoop(IConnection connection, string data, Action initializeCallback, Action errorCallback, bool raiseReconnect = false) { string url = connection.Url; - var reconnectTokenSource = new CancellationTokenSource(); - int reconnectFired = 0; // This is only necessary for the initial request where initializeCallback and errorCallback are non-null int callbackFired = 0; @@ -46,7 +44,10 @@ private void PollingLoop(IConnection connection, string data, Action initializeC { url += "reconnect"; - connection.State = ConnectionState.Reconnecting; + if (!connection.ChangeState(ConnectionState.Connected, ConnectionState.Reconnecting)) + { + return; + } } url += GetReceiveQueryString(connection, data); @@ -73,7 +74,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC { // If the timeout for the reconnect hasn't fired as yet just fire the // event here before any incoming messages are processed - FireReconnected(connection, reconnectTokenSource, ref reconnectFired); + FireReconnected(connection); } // Get the response @@ -100,9 +101,6 @@ private void PollingLoop(IConnection connection, string data, Action initializeC if (task.IsFaulted) { - // Cancel the previous reconnect event - reconnectTokenSource.Cancel(); - // Raise the reconnect event if we successfully reconnect after failing shouldRaiseReconnect = true; @@ -110,7 +108,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC Exception exception = task.Exception.Unwrap(); // If the error callback isn't null then raise it and don't continue polling - if (errorCallback != null && + if (errorCallback != null && Interlocked.Exchange(ref callbackFired, 1) == 0) { // Call the callback @@ -132,7 +130,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC // before polling again so we aren't hammering the server TaskAsyncHelper.Delay(_errorDelay).Then(() => { - if (!CancellationToken.IsCancellationRequested) + if (connection.State != ConnectionState.Disconnected) { PollingLoop(connection, data, @@ -146,7 +144,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC } else { - if (!CancellationToken.IsCancellationRequested) + if (connection.State != ConnectionState.Disconnected) { // Continue polling if there was no error PollingLoop(connection, @@ -173,7 +171,7 @@ private void PollingLoop(IConnection connection, string data, Action initializeC TaskAsyncHelper.Delay(ReconnectDelay).Then(() => { // Fire the reconnect event after the delay. This gives the - FireReconnected(connection, reconnectTokenSource, ref reconnectFired); + FireReconnected(connection); }); } } @@ -181,17 +179,12 @@ private void PollingLoop(IConnection connection, string data, Action initializeC /// /// /// - private static void FireReconnected(IConnection connection, CancellationTokenSource reconnectTokenSource, ref int reconnectedFired) + private static void FireReconnected(IConnection connection) { - if (!reconnectTokenSource.IsCancellationRequested) + // Mark the connection as connected + if (connection.ChangeState(ConnectionState.Reconnecting, ConnectionState.Connected)) { - if (Interlocked.Exchange(ref reconnectedFired, 1) == 0) - { - // Mark the connection as connected - connection.State = ConnectionState.Connected; - - connection.OnReconnected(); - } + connection.OnReconnected(); } } } diff --git a/SignalR.Client/Transports/ServerSentEventsTransport.cs b/SignalR.Client/Transports/ServerSentEventsTransport.cs index 62e8ffbd4b..8c9d964810 100644 --- a/SignalR.Client/Transports/ServerSentEventsTransport.cs +++ b/SignalR.Client/Transports/ServerSentEventsTransport.cs @@ -12,6 +12,7 @@ public class ServerSentEventsTransport : HttpBasedTransport { private int _initializedCalled; + private const string EventSourceKey = "eventSourceStream"; private static readonly TimeSpan ReconnectDelay = TimeSpan.FromSeconds(2); public ServerSentEventsTransport() @@ -37,11 +38,6 @@ protected override void OnStart(IConnection connection, string data, Action init private void Reconnect(IConnection connection, string data) { - if (CancellationToken.IsCancellationRequested) - { - return; - } - // Wait for a bit before reconnecting TaskAsyncHelper.Delay(ReconnectDelay).Then(() => { @@ -90,10 +86,8 @@ private void OpenConnection(IConnection connection, string data, Action initiali } } - if (reconnecting && !CancellationToken.IsCancellationRequested) + if (reconnecting && connection.ChangeState(ConnectionState.Connected, ConnectionState.Reconnecting)) { - connection.State = ConnectionState.Reconnecting; - // Retry Reconnect(connection, data); return; @@ -106,9 +100,8 @@ private void OpenConnection(IConnection connection, string data, Action initiali var eventSource = new EventSourceStreamReader(stream); bool retry = true; - - // When this fires close the event source - CancellationToken.Register(() => eventSource.Close()); + + connection.Items[EventSourceKey] = eventSource; eventSource.Opened = () => { @@ -117,11 +110,8 @@ private void OpenConnection(IConnection connection, string data, Action initiali initializeCallback(); } - if (reconnecting) + if (reconnecting && connection.ChangeState(ConnectionState.Reconnecting, ConnectionState.Connected)) { - // Change the status to connected - connection.State = ConnectionState.Connected; - // Raise the reconnect event if the connection comes back up connection.OnReconnected(); } @@ -153,11 +143,8 @@ private void OpenConnection(IConnection connection, string data, Action initiali { response.Close(); - if (retry && !CancellationToken.IsCancellationRequested) + if (retry && connection.ChangeState(ConnectionState.Connected, ConnectionState.Reconnecting)) { - // If we're retrying then just go again - connection.State = ConnectionState.Reconnecting; - Reconnect(connection, data); } else @@ -166,10 +153,7 @@ private void OpenConnection(IConnection connection, string data, Action initiali } }; - if (!CancellationToken.IsCancellationRequested) - { - eventSource.Start(); - } + eventSource.Start(); } }); @@ -188,5 +172,16 @@ private void OpenConnection(IConnection connection, string data, Action initiali }); } } + + protected override void OnBeforeAbort(IConnection connection) + { + var eventSourceStream = connection.GetValue(EventSourceKey); + if (eventSourceStream != null) + { + eventSourceStream.Close(); + } + + base.OnBeforeAbort(connection); + } } } diff --git a/SignalR.Tests/ConnectionFacts.cs b/SignalR.Tests/ConnectionFacts.cs index cce5ab12c7..48315c759e 100644 --- a/SignalR.Tests/ConnectionFacts.cs +++ b/SignalR.Tests/ConnectionFacts.cs @@ -53,7 +53,7 @@ public void FailedStartShouldNotBeActive() ConnectionId = "Something" })); - transport.Setup(m => m.Start(connection, It.IsAny(), null)) + transport.Setup(m => m.Start(connection, null)) .Returns(TaskAsyncHelper.FromError(new InvalidOperationException("Something failed."))); var aggEx = Assert.Throws(() => connection.Start(transport.Object).Wait());