diff --git a/SignalR.Tests/HubFacts.cs b/SignalR.Tests/HubFacts.cs index dec10d015c..bd2590100e 100644 --- a/SignalR.Tests/HubFacts.cs +++ b/SignalR.Tests/HubFacts.cs @@ -71,5 +71,22 @@ public void GenericTaskWithException() Assert.Equal("Exception of type 'System.Exception' was thrown.", ex.GetBaseException().Message); } + + [Fact] + public void Overloads() + { + var host = new MemoryHost(); + host.MapHubs(); + var connection = new Client.Hubs.HubConnection("http://foo/"); + + var hub = connection.CreateProxy("demo"); + + connection.Start(host).Wait(); + + hub.Invoke("Overload").Wait(); + int n = hub.Invoke("Overload", 1).Result; + + Assert.Equal(1, n); + } } } diff --git a/SignalR/Hubs/Lookup/ReflectedMethodDescriptorProvider.cs b/SignalR/Hubs/Lookup/ReflectedMethodDescriptorProvider.cs index e82df50813..cda195fc53 100644 --- a/SignalR/Hubs/Lookup/ReflectedMethodDescriptorProvider.cs +++ b/SignalR/Hubs/Lookup/ReflectedMethodDescriptorProvider.cs @@ -11,12 +11,10 @@ namespace SignalR.Hubs public class ReflectedMethodDescriptorProvider : IMethodDescriptorProvider { private readonly ConcurrentDictionary>> _methods; - private readonly ConcurrentDictionary _executableMethods; public ReflectedMethodDescriptorProvider() { _methods = new ConcurrentDictionary>>(StringComparer.OrdinalIgnoreCase); - _executableMethods = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); } public IEnumerable GetMethods(HubDescriptor hub) @@ -59,10 +57,10 @@ public IEnumerable GetMethods(HubDescriptor hub) Invoker = oload.Invoke, Parameters = oload.GetParameters() .Select(p => new ParameterDescriptor - { - Name = p.Name, - Type = p.ParameterType, - }) + { + Name = p.Name, + Type = p.ParameterType, + }) .ToList() }), StringComparer.OrdinalIgnoreCase); @@ -70,32 +68,20 @@ public IEnumerable GetMethods(HubDescriptor hub) public bool TryGetMethod(HubDescriptor hub, string method, out MethodDescriptor descriptor, params JToken[] parameters) { - string hubMethodKey = hub.Name + "::" + method; + IEnumerable overloads; - if(!_executableMethods.TryGetValue(hubMethodKey, out descriptor)) + if (FetchMethodsFor(hub).TryGetValue(method, out overloads)) { - IEnumerable overloads; - - if(FetchMethodsFor(hub).TryGetValue(method, out overloads)) - { - var matches = overloads.Where(o => o.Matches(parameters)).ToList(); - - // If only one match is found, that is the "executable" version, otherwise none of the methods can be returned because we don't know which one was actually being targeted - descriptor = matches.Count == 1 ? matches[0] : null; - } - else - { - descriptor = null; - } - - // If an executable method was found, cache it for future lookups (NOTE: we don't cache null instances because it could be a surface area for DoS attack by supplying random method names to flood the cache) - if(descriptor != null) + var matches = overloads.Where(o => o.Matches(parameters)).ToList(); + if (matches.Count == 1) { - _executableMethods.TryAdd(hubMethodKey, descriptor); + descriptor = matches.First(); + return true; } } - return descriptor != null; + descriptor = null; + return false; } private static string GetMethodName(MethodInfo method)