Permalink
Browse files

Merge branch '0.5'

  • Loading branch information...
2 parents 92e96f0 + 52bd5e6 commit 59191e00b969a33ff3767894298910ba9dc0387b @davidfowl davidfowl committed May 1, 2012
@@ -1,4 +1,5 @@
-using System.Net;
+using System;
+using System.Net;
using System.Threading;
using System.Threading.Tasks;
using SignalR.Hosting.Self.Infrastructure;
@@ -8,11 +9,16 @@ namespace SignalR.Hosting.Self
public class HttpListenerResponseWrapper : IResponse
{
private readonly HttpListenerResponse _httpListenerResponse;
+ private readonly Action _onInitialWrite;
private readonly CancellationToken _cancellationToken;
- public HttpListenerResponseWrapper(HttpListenerResponse httpListenerResponse, CancellationToken cancellationToken)
+ private bool _ended;
+ private int _writeInitialized;
+
+ public HttpListenerResponseWrapper(HttpListenerResponse httpListenerResponse, Action onInitialWrite, CancellationToken cancellationToken)
{
_httpListenerResponse = httpListenerResponse;
+ _onInitialWrite = onInitialWrite;
_cancellationToken = cancellationToken;
}
@@ -32,18 +38,31 @@ public bool IsClientConnected
{
get
{
- return !_cancellationToken.IsCancellationRequested;
+ return !_ended && !_cancellationToken.IsCancellationRequested;
}
}
public Task WriteAsync(string data)
{
- return DoWrite(data).Then(response => response.OutputStream.Flush(), _httpListenerResponse);
- }
+ if (Interlocked.Exchange(ref _writeInitialized, 1) == 0)
+ {
+ _onInitialWrite();
+ }
+
+ return DoWrite(data).Then(response => response.OutputStream.Flush(), _httpListenerResponse)
+ .Catch(ex => _ended = true);
+ }
public Task EndAsync(string data)
{
- return DoWrite(data).Then(response => response.CloseSafe(), _httpListenerResponse);
+ return DoWrite(data).Then(response =>
+ {
+ response.CloseSafe();
+
+ // Mark the connection as ended after we close it
+ _ended = true;
+ },
+ _httpListenerResponse);
}
private Task DoWrite(string data)
@@ -80,44 +80,11 @@ private void ReceiveLoop()
{
return;
}
-
- var cts = new CancellationTokenSource();
-
- // Get the connection id value
- var connectionIdField = typeof(HttpListenerRequest).GetField("m_ConnectionId", BindingFlags.Instance | BindingFlags.NonPublic);
- if (_requestQueueHandle != null && connectionIdField != null)
- {
- ulong connectionId = (ulong)connectionIdField.GetValue(context.Request);
- // Create a nativeOverlapped callback so we can register for disconnect callback
- var overlapped = new Overlapped();
- var nativeOverlapped = overlapped.UnsafePack((errorCode, numBytes, pOVERLAP) =>
- {
- // Free the overlapped
- Overlapped.Free(pOVERLAP);
-
- // Mark the client as disconnected
- cts.Cancel();
- },
- null);
-
- uint hr = NativeMethods.HttpWaitForDisconnect(_requestQueueHandle, connectionId, nativeOverlapped);
-
- if (hr != NativeMethods.HttpErrors.ERROR_IO_PENDING &&
- hr != NativeMethods.HttpErrors.NO_ERROR)
- {
- // We got an unknown result so throw
- throw new InvalidOperationException("Unable to register disconnect callback");
- }
- }
- else
- {
- Debug.WriteLine("Unable to resolve requestQueue handle. Disconnect notifications will be ignored");
- }
-
+
ReceiveLoop();
// Process the request async
- ProcessRequestAsync(context, cts.Token).ContinueWith(task =>
+ ProcessRequestAsync(context).ContinueWith(task =>
{
if (task.IsFaulted)
{
@@ -133,20 +100,60 @@ private void ReceiveLoop()
}, null);
}
- private Task ProcessRequestAsync(HttpListenerContext context, CancellationToken token)
+ private void RegisterForDisconnect(HttpListenerContext context, Action disconnectCallback)
+ {
+ // Get the connection id value
+ FieldInfo connectionIdField = typeof(HttpListenerRequest).GetField("m_ConnectionId", BindingFlags.Instance | BindingFlags.NonPublic);
+ if (_requestQueueHandle != null && connectionIdField != null)
+ {
+ Debug.WriteLine("Server: Registering for disconnect");
+
+ ulong connectionId = (ulong)connectionIdField.GetValue(context.Request);
+ // Create a nativeOverlapped callback so we can register for disconnect callback
+ var overlapped = new Overlapped();
+ var nativeOverlapped = overlapped.UnsafePack((errorCode, numBytes, pOVERLAP) =>
+ {
+ Debug.WriteLine("Server: http.sys disconnect callback fired.");
+
+ // Free the overlapped
+ Overlapped.Free(pOVERLAP);
+
+ // Mark the client as disconnected
+ disconnectCallback();
+ },
+ null);
+
+ uint hr = NativeMethods.HttpWaitForDisconnect(_requestQueueHandle, connectionId, nativeOverlapped);
+
+ if (hr != NativeMethods.HttpErrors.ERROR_IO_PENDING &&
+ hr != NativeMethods.HttpErrors.NO_ERROR)
+ {
+ // We got an unknown result so throw
+ throw new InvalidOperationException("Unable to register disconnect callback");
+ }
+ }
+ else
+ {
+ Debug.WriteLine("Server: Unable to resolve requestQueue handle. Disconnect notifications will be ignored");
+ }
+ }
+
+ private Task ProcessRequestAsync(HttpListenerContext context)
{
try
{
- Debug.WriteLine("Incoming request to {0}.", context.Request.Url);
+ Debug.WriteLine("Server: Incoming request to {0}.", context.Request.Url);
PersistentConnection connection;
string path = ResolvePath(context.Request.Url);
if (TryGetConnection(path, out connection))
{
+ var cts = new CancellationTokenSource();
+
var request = new HttpListenerRequestWrapper(context.Request);
- var response = new HttpListenerResponseWrapper(context.Response, token);
+ var response = new HttpListenerResponseWrapper(context.Response, () => RegisterForDisconnect(context, cts.Cancel), cts.Token);
var hostContext = new HostContext(request, response, context.User);
if (OnProcessRequest != null)
@@ -11,10 +11,12 @@ namespace SignalR.Hubs
public class ReflectedMethodDescriptorProvider : IMethodDescriptorProvider
{
private readonly ConcurrentDictionary<string, IDictionary<string, IEnumerable<MethodDescriptor>>> _methods;
+ private readonly ConcurrentDictionary<string, MethodDescriptor> _executableMethods;
public ReflectedMethodDescriptorProvider()
{
_methods = new ConcurrentDictionary<string, IDictionary<string, IEnumerable<MethodDescriptor>>>(StringComparer.OrdinalIgnoreCase);
+ _executableMethods = new ConcurrentDictionary<string, MethodDescriptor>(StringComparer.OrdinalIgnoreCase);
}
public IEnumerable<MethodDescriptor> GetMethods(HubDescriptor hub)
@@ -68,20 +70,32 @@ public IEnumerable<MethodDescriptor> GetMethods(HubDescriptor hub)
public bool TryGetMethod(HubDescriptor hub, string method, out MethodDescriptor descriptor, params JToken[] parameters)
{
- IEnumerable<MethodDescriptor> overloads;
+ string hubMethodKey = hub.Name + "::" + method;
- if (FetchMethodsFor(hub).TryGetValue(method, out overloads))
+ if(!_executableMethods.TryGetValue(hubMethodKey, out descriptor))
{
- var matches = overloads.Where(o => o.Matches(parameters)).ToList();
- if (matches.Count == 1)
+ IEnumerable<MethodDescriptor> overloads;
+
+ if(FetchMethodsFor(hub).TryGetValue(method, out overloads))
+ {
+ var matches = overloads.Where(o => o.Matches(parameters)).ToList();
+
+ // If only one match is found, that is the "executable" version, otherwise none of the methods can be returned because we don't know which one was actually being targeted
+ descriptor = matches.Count == 1 ? matches[0] : null;
+ }
+ else
+ {
+ descriptor = null;
+ }
+
+ // If an executable method was found, cache it for future lookups (NOTE: we don't cache null instances because it could be a surface area for DoS attack by supplying random method names to flood the cache)
+ if(descriptor != null)
{
- descriptor = matches.First();
- return true;
+ _executableMethods.TryAdd(hubMethodKey, descriptor);
}
}
- descriptor = null;
- return false;
+ return descriptor != null;
}
private static string GetMethodName(MethodInfo method)
@@ -34,6 +34,21 @@ public void RemoveWithLock(T item)
}
}
+ public void RemoveWithLock(Predicate<T> match)
+ {
+ try
+ {
+ _listLock.EnterWriteLock();
+
+ // REVIEW: Should we only lock if there's any matches?
+ RemoveAll(match);
+ }
+ finally
+ {
+ _listLock.ExitWriteLock();
+ }
+ }
+
public List<T> CopyWithLock()
{
try
@@ -344,15 +344,7 @@ private void RemoveExpiredEntries(object state)
// Remove all the expired ones
foreach (var entry in entries)
{
- var messages = entry.Value.CopyWithLock();
-
- foreach (var item in messages)
- {
- if (item.Expired)
- {
- entry.Value.RemoveWithLock(item);
- }
- }
+ entry.Value.RemoveWithLock(item => item.Expired);
}
}
catch (Exception ex)
@@ -18,20 +18,20 @@ static Message()
public string SignalKey { get; set; }
public object Value { get; private set; }
public DateTime Created { get; private set; }
+ private DateTime ExpiresAt { get; set; }
public bool Expired
{
get
{
- // TODO: Handle disconnect timeout
- return DateTime.Now.Subtract(Created) >= ExpiresAfter;
+ return DateTime.UtcNow >= ExpiresAt;
}
}
private Message() { }
public Message(string signalKey, object value)
- : this(signalKey, value, DateTime.Now)
+ : this(signalKey, value, DateTime.UtcNow)
{
}
@@ -41,6 +41,7 @@ public Message(string signalKey, object value, DateTime created)
SignalKey = signalKey;
Value = value;
Created = created;
- }
+ ExpiresAt = created.Add(ExpiresAfter);
+ }
}
}
@@ -40,6 +40,23 @@ public static Task Empty
return task;
}
+ public static TTask Catch<TTask>(this TTask task, Action<Exception> handler) where TTask : Task
+ {
+ if (task != null && task.Status != TaskStatus.RanToCompletion)
+ {
+ task.ContinueWith(innerTask =>
+ {
+ var ex = innerTask.Exception;
+ // observe Exception
+#if !WINDOWS_PHONE && !SILVERLIGHT && !NETFX_CORE
+ Trace.TraceError("SignalR exception thrown by Task: {0}", ex);
+#endif
+ handler(ex);
+ }, TaskContinuationOptions.OnlyOnFaulted);
+ }
+ return task;
+ }
+
public static void ContinueWithNotComplete(this Task task, TaskCompletionSource<object> tcs)
{
task.ContinueWith(t =>
@@ -14,6 +14,7 @@ static void Main(string[] args)
string url = "http://*:8081/";
var server = new Server(url);
+ server.Configuration.DisconnectTimeout = TimeSpan.Zero;
// Map connections
server.MapConnection<MyConnection>("/echo")
@@ -23,21 +24,35 @@ static void Main(string[] args)
server.Start();
Console.WriteLine("Server running on {0}", url);
-
- Console.ReadKey();
+
+ while (true)
+ {
+ ConsoleKeyInfo ki = Console.ReadKey(true);
+ if (ki.Key == ConsoleKey.X)
+ {
+ break;
+ }
+ }
}
public class MyConnection : PersistentConnection
{
protected override Task OnConnectedAsync(IRequest request, string connectionId)
{
- return Connection.Broadcast(String.Format("{0} connected from {1}", connectionId, request.Headers["User-Agent"]));
+ Console.WriteLine("{0} connected", connectionId);
+ return base.OnConnectedAsync(request, connectionId);
}
protected override Task OnReceivedAsync(string connectionId, string data)
{
return Connection.Broadcast(data);
}
+
+ protected override Task OnDisconnectAsync(string connectionId)
+ {
+ Console.WriteLine("{0} left", connectionId);
+ return base.OnDisconnectAsync(connectionId);
+ }
}
}
}

0 comments on commit 59191e0

Please sign in to comment.