Skip to content

Commit

Permalink
Implement SignalR.Client's connection state management with thread sa…
Browse files Browse the repository at this point in the history
…fety in mind.

- Check to make sure we're doing valid state transitions.
- Made State get only and added ChangeSet method which makes sure
  state transitions are valid.
Fixes #474
  • Loading branch information
davidfowl committed Jun 22, 2012
1 parent e7be16f commit 74a128a
Show file tree
Hide file tree
Showing 10 changed files with 315 additions and 347 deletions.
72 changes: 36 additions & 36 deletions SignalR.Client/Connection.cs
Expand Up @@ -5,7 +5,6 @@
using System.Linq;
using System.Net;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json.Linq;
using SignalR.Client.Http;
Expand All @@ -21,8 +20,12 @@ public class Connection : IConnection
private static Version _assemblyVersion;

private IClientTransport _transport;
private ConnectionState _state;
private CancellationTokenSource _cancel;

// The default connection state is disconnected
private ConnectionState _state = ConnectionState.Disconnected;

// Used to synchornize state changes
private readonly object _stateLock = new object();

/// <summary>
/// Occurs when the <see cref="Connection"/> has received data from the server.
Expand Down Expand Up @@ -135,22 +138,17 @@ public Connection(string url, string queryString)
/// </summary>
public string QueryString { get; private set; }

public CancellationToken CancellationToken
{
get
{
return _cancel.Token;
}
}

/// <summary>
/// Gets the current <see cref="ConnectionState"/> of the connection.
/// </summary>
public ConnectionState State
{
get
{
return _state;
lock (_stateLock)
{
return _state;
}
}
private set
{
Expand All @@ -166,18 +164,6 @@ private set
}
}

ConnectionState IConnection.State
{
get
{
return State;
}
set
{
State = value;
}
}

/// <summary>
/// Starts the <see cref="Connection"/>.
/// </summary>
Expand Down Expand Up @@ -209,15 +195,11 @@ public Task Start(IHttpClient httpClient)
/// <returns>A task that represents when the connection has started.</returns>
public virtual Task Start(IClientTransport transport)
{
if (State == ConnectionState.Connected ||
State == ConnectionState.Connecting)
if (!ChangeState(ConnectionState.Disconnected, ConnectionState.Connecting))
{
return TaskAsyncHelper.Empty;
}

State = ConnectionState.Connecting;
_cancel = new CancellationTokenSource();

_transport = transport;

return Negotiate(transport);
Expand Down Expand Up @@ -271,8 +253,29 @@ private Task Negotiate(IClientTransport transport)

private Task StartTransport(string data)
{
return _transport.Start(this, _cancel.Token, data)
.Then(() => State = ConnectionState.Connected);
return _transport.Start(this, data)
.Then(() =>
{
ChangeState(ConnectionState.Connecting, ConnectionState.Connected);
});
}

private bool ChangeState(ConnectionState oldState, ConnectionState newState)
{
return ((IConnection)this).ChangeState(oldState, newState);
}

bool IConnection.ChangeState(ConnectionState oldState, ConnectionState newState)
{
lock (_stateLock)
{
if (_state == oldState)
{
State = newState;
return true;
}
return false;
}
}

private static void VerifyProtocolVersion(string versionString)
Expand All @@ -294,14 +297,11 @@ public virtual void Stop()
try
{
// Do nothing if the connection is offline
if (this.IsDisconnecting())
if (State == ConnectionState.Disconnected)
{
return;
}

State = ConnectionState.Disconnecting;

_cancel.Cancel(throwOnFirstException: false);
_transport.Stop(this);

if (Closed != null)
Expand All @@ -327,7 +327,7 @@ public Task Send(string data)

Task<T> IConnection.Send<T>(string data)
{
if (this.IsDisconnecting())
if (State == ConnectionState.Disconnected)
{
// TODO: Update this error message
throw new InvalidOperationException("Start must be called before data can be sent");
Expand Down
12 changes: 0 additions & 12 deletions SignalR.Client/ConnectionExtensions.cs
Expand Up @@ -16,18 +16,6 @@ public static T GetValue<T>(this IConnection connection, string key)
return default(T);
}

public static bool IsDisconnecting(this IConnection connection)
{
return connection.State == ConnectionState.Disconnecting ||
connection.State == ConnectionState.Disconnected;
}

public static bool IsActive(this IConnection connection)
{
return connection.State == ConnectionState.Connected ||
connection.State == ConnectionState.Connecting;
}

#if !WINDOWS_PHONE && !SILVERLIGHT && !NET35
public static IObservable<string> AsObservable(this Connection connection)
{
Expand Down
1 change: 0 additions & 1 deletion SignalR.Client/ConnectionState.cs
Expand Up @@ -5,7 +5,6 @@ public enum ConnectionState
Connecting,
Connected,
Reconnecting,
Disconnecting,
Disconnected
}
}
4 changes: 3 additions & 1 deletion SignalR.Client/IConnection.cs
Expand Up @@ -16,7 +16,9 @@ public interface IConnection
string ConnectionId { get; }
string Url { get; }
string QueryString { get; }
ConnectionState State { get; set; }
ConnectionState State { get; }

bool ChangeState(ConnectionState oldState, ConnectionState newState);

ICredentials Credentials { get; set; }
CookieContainer CookieContainer { get; set; }
Expand Down
10 changes: 5 additions & 5 deletions SignalR.Client/Transports/AutoTransport.cs
Expand Up @@ -26,22 +26,22 @@ public Task<NegotiationResponse> Negotiate(IConnection connection)
return HttpBasedTransport.GetNegotiationResponse(_httpClient, connection);
}

public Task Start(IConnection connection, CancellationToken cancellationToken, string data)
public Task Start(IConnection connection, string data)
{
var tcs = new TaskCompletionSource<object>();

// Resolve the transport
ResolveTransport(connection, cancellationToken, data, tcs, 0);
ResolveTransport(connection, data, tcs, 0);

return tcs.Task;
}

private void ResolveTransport(IConnection connection, CancellationToken cancellationToken, string data, TaskCompletionSource<object> tcs, int index)
private void ResolveTransport(IConnection connection, string data, TaskCompletionSource<object> tcs, int index)
{
// Pick the current transport
IClientTransport transport = _transports[index];

transport.Start(connection, cancellationToken, data).ContinueWith(task =>
transport.Start(connection, data).ContinueWith(task =>
{
if (task.IsFaulted)
{
Expand All @@ -61,7 +61,7 @@ private void ResolveTransport(IConnection connection, CancellationToken cancella
if (next < _transports.Length)
{
// Try the next transport
ResolveTransport(connection, cancellationToken, data, tcs, next);
ResolveTransport(connection, data, tcs, next);
}
else
{
Expand Down

0 comments on commit 74a128a

Please sign in to comment.