Skip to content

Commit

Permalink
Respond to bad requests with a 400 status code
Browse files Browse the repository at this point in the history
- Previously bad requests caused PersistentConnection.ProcessRequest to throw
  resulting in a 500
- Respond with a 403 status code when the client changes identity
- No longer throwing on bad requests reduces extraneous error logging
- #2522
  • Loading branch information
halter73 committed Nov 1, 2013
1 parent c23520d commit 961398c
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 23 deletions.
5 changes: 5 additions & 0 deletions src/Microsoft.AspNet.SignalR.Core/Hosting/IResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ public interface IResponse
/// </summary>
CancellationToken CancellationToken { get; }

/// <summary>
/// Gets or sets the status code of the response.
/// </summary>
int StatusCode { get; set; }

/// <summary>
/// Gets or sets the content type of the response.
/// </summary>
Expand Down
43 changes: 35 additions & 8 deletions src/Microsoft.AspNet.SignalR.Core/PersistentConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,25 @@ public virtual Task ProcessRequest(HostContext context)

if (Transport == null)
{
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ProtocolErrorUnknownTransport));
return FailResponse(context.Response, String.Format(CultureInfo.CurrentCulture, Resources.Error_ProtocolErrorUnknownTransport));
}

string connectionToken = context.Request.QueryString["connectionToken"];

// If there's no connection id then this is a bad request
if (String.IsNullOrEmpty(connectionToken))
{
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ProtocolErrorMissingConnectionToken));
return FailResponse(context.Response, String.Format(CultureInfo.CurrentCulture, Resources.Error_ProtocolErrorMissingConnectionToken));
}

string connectionId = GetConnectionId(context, connectionToken);
string connectionId;
string message;
int statusCode;

if (!TryGetConnectionId(context, connectionToken, out connectionId, out message, out statusCode))
{
return FailResponse(context.Response, message, statusCode);
}

// Set the transport's connection id to the unprotected one
Transport.ConnectionId = connectionId;
Expand Down Expand Up @@ -228,10 +235,21 @@ public virtual Task ProcessRequest(HostContext context)
}

[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to catch any exception when unprotecting data.")]
internal string GetConnectionId(HostContext context, string connectionToken)
internal bool TryGetConnectionId(HostContext context,
string connectionToken,
out string connectionId,
out string message,
out int statusCode)
{
string unprotectedConnectionToken = null;

// connectionId is only valid when this method returns true
connectionId = null;

// message and statusCode are only valid when this method returns false
message = null;
statusCode = 400;

try
{
unprotectedConnectionToken = ProtectedData.Unprotect(connectionToken, Purposes.ConnectionToken);
Expand All @@ -243,21 +261,24 @@ internal string GetConnectionId(HostContext context, string connectionToken)

if (String.IsNullOrEmpty(unprotectedConnectionToken))
{
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ConnectionIdIncorrectFormat));
message = String.Format(CultureInfo.CurrentCulture, Resources.Error_ConnectionIdIncorrectFormat);
return false;
}

var tokens = unprotectedConnectionToken.Split(SplitChars, 2);

string connectionId = tokens[0];
connectionId = tokens[0];
string tokenUserName = tokens.Length > 1 ? tokens[1] : String.Empty;
string userName = GetUserIdentity(context);

if (!String.Equals(tokenUserName, userName, StringComparison.OrdinalIgnoreCase))
{
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_UnrecognizedUserIdentity));
message = String.Format(CultureInfo.CurrentCulture, Resources.Error_UnrecognizedUserIdentity);
statusCode = 403;
return false;
}

return connectionId;
return true;
}

[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to prevent any failures in unprotecting")]
Expand Down Expand Up @@ -478,6 +499,12 @@ private Task ProcessJsonpRequest(HostContext context, object payload)
return context.Response.End(data);
}

private static Task FailResponse(IResponse response, string message, int statusCode = 400)
{
response.StatusCode = statusCode;
return response.End(message);
}

private static bool IsNegotiationRequest(IRequest request)
{
return request.Url.LocalPath.EndsWith("/negotiate", StringComparison.OrdinalIgnoreCase);
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.AspNet.SignalR.Owin/ServerRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ public Task AcceptWebSocketRequest(Func<IWebSocket, Task> callback, Task initTas
var accept = _environment.Get<Action<IDictionary<string, object>, WebSocketFunc>>(OwinConstants.WebSocketAccept);
if (accept == null)
{
throw new InvalidOperationException(Resources.Error_NotWebSocketRequest);
var response = new ServerResponse(_environment);
response.StatusCode = 400;
response.End(Resources.Error_NotWebSocketRequest);
}

var handler = new OwinWebSocketHandler(callback, initTask);
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.AspNet.SignalR.Owin/ServerResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ public CancellationToken CancellationToken
get { return _callCancelled; }
}

public int StatusCode
{
get
{
return _environment.Get<int>(OwinConstants.ResponseStatusCode);
}
set
{
_environment[OwinConstants.ResponseStatusCode] = value;
}
}

public string ContentType
{
get { return ResponseHeaders.GetHeader("Content-Type"); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ public CancellationToken CancellationToken
get { return CancellationToken.None; }
}

public int StatusCode { get; set; }

public string ContentType { get; set; }

public void Write(ArraySegment<byte> data)
Expand Down
65 changes: 51 additions & 14 deletions tests/Microsoft.AspNet.SignalR.Tests/PersistentConnectionFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,30 @@ public void UninitializedThrows()
}

[Fact]
public void UnknownTransportThrows()
public void UnknownTransportFails()
{
var connection = new Mock<PersistentConnection>() { CallBase = true };
var req = new Mock<IRequest>();
req.Setup(m => m.Url).Returns(new Uri("http://foo"));
var qs = new NameValueCollection();
req.Setup(m => m.QueryString).Returns(qs);

var res = new Mock<IResponse>();
res.SetupProperty(m => m.StatusCode);
res.Setup(m => m.End()).Returns(TaskAsyncHelper.Empty);

var dr = new DefaultDependencyResolver();
var context = new HostContext(req.Object, null);
var context = new HostContext(req.Object, res.Object);
connection.Object.Initialize(dr, context);

Assert.Throws<InvalidOperationException>(() => connection.Object.ProcessRequest(context));
var task = connection.Object.ProcessRequest(context);

Assert.True(task.IsCompleted);
Assert.Equal(400, context.Response.StatusCode);
}

[Fact]
public void MissingConnectionTokenThrows()
public void MissingConnectionTokenFails()
{
var connection = new Mock<PersistentConnection>() { CallBase = true };
var req = new Mock<IRequest>();
Expand All @@ -53,11 +60,18 @@ public void MissingConnectionTokenThrows()
qs["transport"] = "serverSentEvents";
req.Setup(m => m.QueryString).Returns(qs);

var res = new Mock<IResponse>();
res.SetupProperty(m => m.StatusCode);
res.Setup(m => m.End()).Returns(TaskAsyncHelper.Empty);

var dr = new DefaultDependencyResolver();
var context = new HostContext(req.Object, null);
var context = new HostContext(req.Object, res.Object);
connection.Object.Initialize(dr, context);

Assert.Throws<InvalidOperationException>(() => connection.Object.ProcessRequest(context));
var task = connection.Object.ProcessRequest(context);

Assert.True(task.IsCompleted);
Assert.Equal(400, context.Response.StatusCode);
}
}

Expand Down Expand Up @@ -128,7 +142,7 @@ private static IList<string> DoVerifyGroups(string groupsToken, string connectio
public class GetConnectionId
{
[Fact]
public void UnprotectedConnectionTokenThrows()
public void UnprotectedConnectionTokenFails()
{
var connection = new Mock<PersistentConnection>() { CallBase = true };
var req = new Mock<IRequest>();
Expand All @@ -144,11 +158,17 @@ public void UnprotectedConnectionTokenThrows()
var context = new HostContext(req.Object, null);
connection.Object.Initialize(dr, context);

Assert.Throws<InvalidOperationException>(() => connection.Object.GetConnectionId(context, "1"));
string connectionId;
string message;
int statusCode;

Assert.Equal(false, connection.Object.TryGetConnectionId(context, "1", out connectionId, out message, out statusCode));
Assert.Equal(null, connectionId);
Assert.Equal(400, statusCode);
}

[Fact]
public void NullUnprotectedConnectionTokenThrows()
public void NullUnprotectedConnectionTokenFails()
{
var connection = new Mock<PersistentConnection>() { CallBase = true };
var req = new Mock<IRequest>();
Expand All @@ -163,11 +183,17 @@ public void NullUnprotectedConnectionTokenThrows()
var context = new HostContext(req.Object, null);
connection.Object.Initialize(dr, context);

Assert.Throws<InvalidOperationException>(() => connection.Object.GetConnectionId(context, "1"));
string connectionId;
string message;
int statusCode;

Assert.Equal(false, connection.Object.TryGetConnectionId(context, "1", out connectionId, out message, out statusCode));
Assert.Equal(null, connectionId);
Assert.Equal(400, statusCode);
}

[Fact]
public void UnauthenticatedUserWithAuthenticatedTokenThrows()
public void UnauthenticatedUserWithAuthenticatedTokenFails()
{
var connection = new Mock<PersistentConnection>() { CallBase = true };
var req = new Mock<IRequest>();
Expand All @@ -182,7 +208,12 @@ public void UnauthenticatedUserWithAuthenticatedTokenThrows()
var context = new HostContext(req.Object, null);
connection.Object.Initialize(dr, context);

Assert.Throws<InvalidOperationException>(() => connection.Object.GetConnectionId(context, "1:::11:::::::1:1"));
string connectionId;
string message;
int statusCode;

Assert.Equal(false, connection.Object.TryGetConnectionId(context, "1:::11:::::::1:1", out connectionId, out message, out statusCode));
Assert.Equal(403, statusCode);
}

[Fact]
Expand All @@ -202,8 +233,11 @@ public void AuthenticatedUserNameMatches()
var context = new HostContext(req.Object, null);
connection.Object.Initialize(dr, context);

var connectionId = connection.Object.GetConnectionId(context, "1:Name");
string connectionId;
string message;
int statusCode;

Assert.Equal(true, connection.Object.TryGetConnectionId(context, "1:Name", out connectionId, out message, out statusCode));
Assert.Equal("1", connectionId);
}

Expand All @@ -226,8 +260,11 @@ public void AuthenticatedUserWithColonsInUserName()
var context = new HostContext(req.Object, null);
connection.Object.Initialize(dr, context);

string cid = connection.Object.GetConnectionId(context, connectionId + ":::11:::::::1:1");
string cid;
string message;
int statusCode;

Assert.Equal(true, connection.Object.TryGetConnectionId(context, connectionId + ":::11:::::::1:1", out cid, out message, out statusCode));
Assert.Equal(connectionId, cid);
}
}
Expand Down

0 comments on commit 961398c

Please sign in to comment.