Skip to content

Commit

Permalink
Disable cross domain requests by default.
Browse files Browse the repository at this point in the history
- Added EnableCrossDomain flag to ConnectionConfiguration.
- Reject JSONP and cross domain requests with the origin header by default
  if EnableCrossDomain isn't true.
- Refactored HubDispatcherHandler and PersistentConnectionHandler to take
  respective configuration objects.

#1306
  • Loading branch information
davidfowl committed Jan 12, 2013
1 parent ecf62de commit 034b4a8
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 33 deletions.
5 changes: 5 additions & 0 deletions src/Microsoft.AspNet.SignalR.Core/ConnectionConfiguration.cs
Expand Up @@ -13,5 +13,10 @@ public IDependencyResolver Resolver
get { return _resolver ?? GlobalHost.DependencyResolver; }
set { _resolver = value; }
}

/// <summary>
/// Determines if browsers can make cross domain requests to SignalR endpoints.
/// </summary>
public bool EnableCrossDomain { get; set; }
}
}
64 changes: 48 additions & 16 deletions src/Microsoft.AspNet.SignalR.Owin/Handlers/CallHandler.cs
Expand Up @@ -2,7 +2,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.SignalR.Hosting;
Expand All @@ -12,18 +11,16 @@ namespace Microsoft.AspNet.SignalR.Owin
{
public class CallHandler
{
private readonly IDependencyResolver _resolver;
private readonly ConnectionConfiguration _configuration;
private readonly PersistentConnection _connection;

private static readonly string[] AllowCredentialsTrue = new[] { "true" };

private static bool _supportWebSockets;
private static bool _supportWebSocketsInitialized;
private static object _supportWebSocketsLock = new object();

public CallHandler(IDependencyResolver resolver, PersistentConnection connection)
public CallHandler(ConnectionConfiguration configuration, PersistentConnection connection)
{
_resolver = resolver;
_configuration = configuration;
_connection = connection;
}

Expand All @@ -33,12 +30,29 @@ public Task Invoke(IDictionary<string, object> environment)
var serverResponse = new ServerResponse(environment);
var hostContext = new HostContext(serverRequest, serverResponse);

// Add CORS support
var origins = serverRequest.RequestHeaders.GetHeaders("Origin");
if (origins != null && origins.Any(origin => !String.IsNullOrEmpty(origin)))
string origin = serverRequest.RequestHeaders.GetHeader("Origin");

if (_configuration.EnableCrossDomain)
{
serverResponse.ResponseHeaders["Access-Control-Allow-Origin"] = origins;
serverResponse.ResponseHeaders["Access-Control-Allow-Credentials"] = AllowCredentialsTrue;
// Add CORS response headers support
if (!String.IsNullOrEmpty(origin))
{
serverResponse.ResponseHeaders.SetHeader("Access-Control-Allow-Origin", origin);
serverResponse.ResponseHeaders.SetHeader("Access-Control-Allow-Credentials", "true");
}
}
else
{
string callback = serverRequest.QueryString["callback"];

// If it's a JSONP request and we're not allowing cross domain requests then block it
// If there's an origin header and it's not a same origin request then block it.

if (!String.IsNullOrEmpty(callback) ||
(!String.IsNullOrEmpty(origin) && !IsSameOrigin(serverRequest.Url, origin)))
{
return EndResponse(environment, 403, "Forbidden");
}
}

hostContext.Items[HostConstants.SupportsWebSockets] = LazyInitializer.EnsureInitialized(
Expand All @@ -53,21 +67,39 @@ public Task Invoke(IDictionary<string, object> environment)
serverRequest.DisableRequestBuffering();
serverResponse.DisableResponseBuffering();

_connection.Initialize(_resolver, hostContext);
_connection.Initialize(_configuration.Resolver, hostContext);

if (!_connection.Authorize(serverRequest))
{
// If we failed to authorize the request then return a 403 since the request
// can't do anything
environment[OwinConstants.ResponseStatusCode] = 403;
environment[OwinConstants.ResponseReasonPhrase] = "Forbidden";

return TaskAsyncHelper.Empty;
return EndResponse(environment, 403, "Forbidden");
}
else
{
return _connection.ProcessRequest(hostContext);
}
}

private static Task EndResponse(IDictionary<string, object> environment, int statusCode, string reason)
{
environment[OwinConstants.ResponseStatusCode] = statusCode;
environment[OwinConstants.ResponseReasonPhrase] = reason;

return TaskAsyncHelper.Empty;
}

private static bool IsSameOrigin(Uri requestUri, string origin)
{
Uri originUri;
if (!Uri.TryCreate(origin.Trim(), UriKind.Absolute, out originUri))
{
return false;
}

return (requestUri.Scheme == originUri.Scheme) &&
(requestUri.Host == originUri.Host) &&
(requestUri.Port == originUri.Port);
}
}
}
Expand Up @@ -13,15 +13,13 @@ public class HubDispatcherHandler
{
private readonly AppFunc _next;
private readonly string _path;
private readonly bool _enableJavaScriptProxies;
private readonly IDependencyResolver _resolver;
private readonly HubConfiguration _configuration;

public HubDispatcherHandler(AppFunc next, string path, bool enableJavaScriptProxies, IDependencyResolver resolver)
public HubDispatcherHandler(AppFunc next, string path, HubConfiguration configuration)
{
_next = next;
_path = path;
_enableJavaScriptProxies = enableJavaScriptProxies;
_resolver = resolver;
_configuration = configuration;
}

public Task Invoke(IDictionary<string, object> environment)
Expand All @@ -33,9 +31,9 @@ public Task Invoke(IDictionary<string, object> environment)
}

var pathBase = environment.Get<string>(OwinConstants.RequestPathBase);
var dispatcher = new HubDispatcher(pathBase + _path, _enableJavaScriptProxies);
var dispatcher = new HubDispatcher(pathBase + _path, _configuration.EnableJavaScriptProxies);

var handler = new CallHandler(_resolver, dispatcher);
var handler = new CallHandler(_configuration, dispatcher);
return handler.Invoke(environment);
}
}
Expand Down
Expand Up @@ -14,14 +14,14 @@ public class PersistentConnectionHandler
private readonly AppFunc _next;
private readonly string _path;
private readonly Type _connectionType;
private readonly IDependencyResolver _resolver;
private readonly ConnectionConfiguration _configuration;

public PersistentConnectionHandler(AppFunc next, string path, Type connectionType, IDependencyResolver resolver)
public PersistentConnectionHandler(AppFunc next, string path, Type connectionType, ConnectionConfiguration configuration)
{
_next = next;
_path = path;
_connectionType = connectionType;
_resolver = resolver;
_configuration = configuration;
}

public Task Invoke(IDictionary<string, object> environment)
Expand All @@ -32,10 +32,10 @@ public Task Invoke(IDictionary<string, object> environment)
return _next(environment);
}

var connectionFactory = new PersistentConnectionFactory(_resolver);
var connectionFactory = new PersistentConnectionFactory(_configuration.Resolver);
var connection = connectionFactory.CreateInstance(_connectionType);

var handler = new CallHandler(_resolver, connection);
var handler = new CallHandler(_configuration, connection);
return handler.Invoke(environment);
}
}
Expand Down
14 changes: 11 additions & 3 deletions src/Microsoft.AspNet.SignalR.Owin/OwinExtensions.cs
Expand Up @@ -30,7 +30,7 @@ public static IAppBuilder MapHubs(this IAppBuilder builder, string path, HubConf
throw new ArgumentNullException("configuration");
}

return builder.UseType<HubDispatcherHandler>(path, configuration.EnableJavaScriptProxies, configuration.Resolver);
return builder.UseType<HubDispatcherHandler>(path, configuration);
}

[SuppressMessage("Microsoft.Design", "CA1004:GenericMethodsShouldProvideTypeParameter", Justification = "The type parameter is syntactic sugar")]
Expand All @@ -52,14 +52,22 @@ public static IAppBuilder MapConnection(this IAppBuilder builder, string url, Ty
throw new ArgumentNullException("configuration");
}

return builder.UseType<PersistentConnectionHandler>(url, connectionType, configuration.Resolver);
return builder.UseType<PersistentConnectionHandler>(url, connectionType, configuration);
}

private static IAppBuilder UseType<T>(this IAppBuilder builder, params object[] args)
{
if (args.Length > 0)
{
var resolver = args[args.Length - 1] as IDependencyResolver;
var configuration = args[args.Length - 1] as ConnectionConfiguration;

if (configuration == null)
{
throw new ArgumentException(Resources.Error_NoConfiguration);
}

var resolver = configuration.Resolver;

if (resolver == null)
{
throw new ArgumentException(Resources.Error_NoDepenendeyResolver);
Expand Down
9 changes: 9 additions & 0 deletions src/Microsoft.AspNet.SignalR.Owin/Resources.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/Microsoft.AspNet.SignalR.Owin/Resources.resx
Expand Up @@ -117,6 +117,9 @@
<resheader name="writer">
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
</resheader>
<data name="Error_NoConfiguration" xml:space="preserve">
<value>A configuration object must be specified.</value>
</data>
<data name="Error_NoDepenendeyResolver" xml:space="preserve">
<value>A dependency resolver must be specified.</value>
</data>
Expand Down
Expand Up @@ -34,7 +34,7 @@ public static void Start()
{
GlobalHost.Configuration.ConnectionTimeout = TimeSpan.FromSeconds(connectionTimeout);
}

int disconnectTimeout;
if (Int32.TryParse(disconnectTimeoutRaw, out disconnectTimeout))
{
Expand All @@ -54,8 +54,12 @@ public static void Start()
GlobalHost.HubPipeline.EnableAutoRejoiningGroups();
}

var config = new HubConfiguration
{
EnableCrossDomain = true
};

RouteTable.Routes.MapHubs();
RouteTable.Routes.MapHubs(config);

RouteTable.Routes.MapHubs("signalr.hubs2", "/signalr2/test", new HubConfiguration());
RouteTable.Routes.MapConnection<MyBadConnection>("errors-are-fun", "ErrorsAreFun");
Expand Down

0 comments on commit 034b4a8

Please sign in to comment.