From 934f6a70d198dd99bdb08d7e32ae2845e7b03829 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Wed, 25 Jan 2017 22:27:55 +0000 Subject: [PATCH] Various fixes in HttpConnectionDispatcher (#151) - The connection state object is manipulated by multiple parties in a non thread safe way. This change introduces a semaphore that should be used by anyone updating or reading the connection state. - Handle cases where there's an active request for a connection id and another incoming request for the same connection id, sse and websockets 409 and long polling kicks out the previous connection (https://github.com/aspnet/SignalR/issues/27 and https://github.com/aspnet/SignalR/issues/4) - Handle requests being processed for disposed connections. There was a race where the background thread could remove and clean up the connection while it was about to be processed. - Synchronize between the background scanning thread and the request threads when updating the connection state. - Added `DisposeAndRemoveAsync` to the connection manager that handles`DisposeAsync` throwing and properly removes connections from connection tracking. - Added Start to ConnectionManager so that testing is easier (background timer doesn't kick in unless start is called). - Added RequestId to connection state for easier debugging and correlation (can easily see which request is currently processing the logical connection). - Added tests --- .../ConnectionManager.cs | 88 +++++++-- .../HttpConnectionDispatcher.cs | 175 ++++++++++++++---- .../Internal/ConnectionState.cs | 63 +++++-- .../SocketsApplicationLifetimeService.cs | 5 +- .../Transports/IHttpTransport.cs | 4 +- .../Transports/LongPollingTransport.cs | 24 +-- .../Transports/ServerSentEventsTransport.cs | 5 +- .../Transports/WebSocketsTransport.cs | 3 +- .../ConnectionManagerTests.cs | 20 +- .../HttpConnectionDispatcherTests.cs | 169 ++++++++++++++++- .../LongPollingTests.cs | 4 +- .../ServerSentEventsTests.cs | 4 +- 12 files changed, 462 insertions(+), 102 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 888677f3a7..14894f6ac5 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -8,17 +8,27 @@ using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets { public class ConnectionManager { private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); - private readonly Timer _timer; + private Timer _timer; + private readonly ILogger _logger; - public ConnectionManager() + public ConnectionManager(ILogger logger) { - _timer = new Timer(Scan, this, 0, 1000); + _logger = logger; + } + + public void Start() + { + if (_timer == null) + { + _timer = new Timer(Scan, this, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); + } } public bool TryGetConnection(string id, out ConnectionState state) @@ -47,9 +57,11 @@ public ConnectionState CreateConnection() public void RemoveConnection(string id) { ConnectionState state; - _connections.TryRemove(id, out state); - - // Remove the connection completely + if (_connections.TryRemove(id, out state)) + { + // Remove the connection completely + _logger.LogDebug("Removing {connectionId} from the list of connections", id); + } } private static string MakeNewConnectionId() @@ -65,38 +77,76 @@ private static void Scan(object state) private void Scan() { - // Scan the registered connections looking for ones that have timed out - foreach (var c in _connections) + // Pause the timer while we're running + _timer.Change(Timeout.Infinite, Timeout.Infinite); + + try { - if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeenUtc).TotalSeconds > 5) + // Scan the registered connections looking for ones that have timed out + foreach (var c in _connections) { - ConnectionState s; - if (_connections.TryRemove(c.Key, out s)) + var status = ConnectionState.ConnectionStatus.Inactive; + var lastSeenUtc = DateTimeOffset.UtcNow; + + try + { + c.Value.Lock.Wait(); + + // Capture the connection state + status = c.Value.Status; + + lastSeenUtc = c.Value.LastSeenUtc; + } + finally { - // REVIEW: Should we keep firing and forgetting this? - var ignore = s.DisposeAsync(); + c.Value.Lock.Release(); + } + + // Once the decision has been made to to dispose we don't check the status again + if (status == ConnectionState.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) + { + var ignore = DisposeAndRemoveAsync(c.Value); } } } + finally + { + // Resume once we finished processing all connections + _timer.Change(TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); + } } public void CloseConnections() { // Stop firing the timer - _timer.Dispose(); + _timer?.Dispose(); var tasks = new List(); foreach (var c in _connections) { - ConnectionState s; - if (_connections.TryRemove(c.Key, out s)) - { - tasks.Add(s.DisposeAsync()); - } + tasks.Add(DisposeAndRemoveAsync(c.Value)); } Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5)); } + + public async Task DisposeAndRemoveAsync(ConnectionState state) + { + try + { + await state.DisposeAsync(); + } + catch (Exception ex) + { + _logger.LogError(0, ex, "Failed disposing connection {connectionId}", state.Connection.ConnectionId); + } + finally + { + // Remove it from the list after disposal so that's it's easy to see + // connections that might be in a hung state via the connections list + RemoveConnection(state.Connection.ConnectionId); + } + } } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 5dc231eaf3..163c4d2584 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -5,6 +5,7 @@ using System.IO; using System.IO.Pipelines; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Sockets.Internal; @@ -70,8 +71,6 @@ private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoi var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory); await DoPersistentConnection(endpoint, sse, context, state); - - _manager.RemoveConnection(state.Connection.ConnectionId); } else if (context.Request.Path.StartsWithSegments(path + "/ws")) { @@ -92,8 +91,6 @@ private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoi var ws = new WebSocketsTransport(state.Application, _loggerFactory); await DoPersistentConnection(endpoint, ws, context, state); - - _manager.RemoveConnection(state.Connection.ConnectionId); } else if (context.Request.Path.StartsWithSegments(path + "/poll")) { @@ -111,39 +108,112 @@ private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoi return; } - // Mark the connection as active - state.Active = true; - - // Raise OnConnected for new connections only since polls happen all the time - if (state.ApplicationTask == null) + try { - _logger.LogDebug("Establishing new Long Polling connection: {0}", state.Connection.ConnectionId); + await state.Lock.WaitAsync(); + + if (state.Status == ConnectionState.ConnectionStatus.Disposed) + { + _logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId); + + // The connection was disposed + context.Response.StatusCode = StatusCodes.Status404NotFound; + return; + } + + if (state.Status == ConnectionState.ConnectionStatus.Active) + { + _logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", state.Connection.ConnectionId, state.RequestId); + + using (state.Cancellation) + { + // Cancel the previous request + state.Cancellation.Cancel(); + + try + { + // Wait for the previous request to drain + await state.TransportTask; + } + catch (OperationCanceledException) + { + // Should be a cancelled task + } + + _logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", state.Connection.ConnectionId, state.RequestId); + } + } + + // Mark the request identifier + state.RequestId = context.TraceIdentifier; - // This will re-initialize formatType metadata, but meh... - state.Connection.Metadata["transport"] = LongPollingTransport.Name; + // Mark the connection as active + state.Status = ConnectionState.ConnectionStatus.Active; - state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + // Raise OnConnected for new connections only since polls happen all the time + if (state.ApplicationTask == null) + { + _logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId); + + state.Connection.Metadata["transport"] = LongPollingTransport.Name; + + state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + } + else + { + _logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId); + } + + var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory); + + state.Cancellation = new CancellationTokenSource(); + + // REVIEW: Performance of this isn't great as this does a bunch of per request allocations + var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(state.Cancellation.Token, context.RequestAborted); + + // Start the transport + state.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); } - else + finally { - _logger.LogDebug("Resuming existing Long Polling connection: {0}", state.Connection.ConnectionId); + state.Lock.Release(); } - var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory); - - // Start the transport - state.TransportTask = longPolling.ProcessRequestAsync(context); - var resultTask = await Task.WhenAny(state.ApplicationTask, state.TransportTask); + // If the application ended before the transport task then we need to end the connection completely + // so there is no future polling if (resultTask == state.ApplicationTask) { - await state.DisposeAsync(); + await _manager.DisposeAndRemoveAsync(state); + } + else if (!resultTask.IsCanceled) + { + // Otherwise, we update the state to inactive again and wait for the next poll + try + { + await state.Lock.WaitAsync(); + + if (state.Status == ConnectionState.ConnectionStatus.Active) + { + // Mark the connection as inactive + state.LastSeenUtc = DateTime.UtcNow; + + state.Status = ConnectionState.ConnectionStatus.Inactive; + + state.RequestId = null; + + // Dispose the cancellation token + state.Cancellation.Dispose(); + + state.Cancellation = null; + } + } + finally + { + state.Lock.Release(); + } } - - // Mark the connection as inactive - state.LastSeenUtc = DateTime.UtcNow; - state.Active = false; } } @@ -163,22 +233,55 @@ private ConnectionState CreateConnection(HttpContext context) return state; } - private static async Task DoPersistentConnection(EndPoint endpoint, - IHttpTransport transport, - HttpContext context, - ConnectionState state) + private async Task DoPersistentConnection(EndPoint endpoint, + IHttpTransport transport, + HttpContext context, + ConnectionState state) { - // Call into the end point passing the connection - state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + try + { + await state.Lock.WaitAsync(); + + if (state.Status == ConnectionState.ConnectionStatus.Disposed) + { + _logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId); - // Start the transport - state.TransportTask = transport.ProcessRequestAsync(context); + // Connection was disposed + context.Response.StatusCode = StatusCodes.Status404NotFound; + return; + } + + // There's already an active request + if (state.Status == ConnectionState.ConnectionStatus.Active) + { + _logger.LogDebug("Connection {connectionId} is already active via {requestId}.", state.Connection.ConnectionId, state.RequestId); + + // Reject the request with a 409 conflict + context.Response.StatusCode = StatusCodes.Status409Conflict; + return; + } + + // Mark the connection as active + state.Status = ConnectionState.ConnectionStatus.Active; + + // Store the request identifier + state.RequestId = context.TraceIdentifier; + + // Call into the end point passing the connection + state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + + // Start the transport + state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); + } + finally + { + state.Lock.Release(); + } // Wait for any of them to end await Task.WhenAny(state.ApplicationTask, state.TransportTask); - // Kill the channel - await state.DisposeAsync(); + await _manager.DisposeAndRemoveAsync(state); } private Task ProcessNegotiate(HttpContext context) @@ -243,7 +346,7 @@ private async Task EnsureConnectionStateAsync(ConnectionState connectionSt } else if (!string.Equals(transport, transportName, StringComparison.Ordinal)) { - context.Response.StatusCode = 400; + context.Response.StatusCode = StatusCodes.Status400BadRequest; await context.Response.WriteAsync("Cannot change transports mid-connection"); return false; } diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index fa25a232ff..a1882e74c1 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -2,7 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.Sockets.Internal { @@ -11,11 +13,17 @@ public class ConnectionState public Connection Connection { get; set; } public IChannelConnection Application { get; } + public CancellationTokenSource Cancellation { get; set; } + + public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1); + + public string RequestId { get; set; } + public Task TransportTask { get; set; } public Task ApplicationTask { get; set; } public DateTime LastSeenUtc { get; set; } - public bool Active { get; set; } = true; + public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; public ConnectionState(Connection connection, IChannelConnection application) { @@ -26,23 +34,54 @@ public ConnectionState(Connection connection, IChannelConnection applic public async Task DisposeAsync() { - // If the application task is faulted, propagate the error to the transport - if (ApplicationTask.IsFaulted) + Task applicationTask = TaskCache.CompletedTask; + Task transportTask = TaskCache.CompletedTask; + + try { - Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); - } + await Lock.WaitAsync(); + + if (Status == ConnectionStatus.Disposed) + { + return; + } + + Status = ConnectionStatus.Disposed; + + RequestId = null; - // If the transport task is faulted, propagate the error to the application - if (TransportTask.IsFaulted) + // If the application task is faulted, propagate the error to the transport + if (ApplicationTask.IsFaulted) + { + Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + } + + // If the transport task is faulted, propagate the error to the application + if (TransportTask.IsFaulted) + { + Application.Output.TryComplete(TransportTask.Exception.InnerException); + } + + Connection.Dispose(); + Application.Dispose(); + + applicationTask = ApplicationTask; + transportTask = TransportTask; + } + finally { - Application.Output.TryComplete(TransportTask.Exception.InnerException); + Lock.Release(); } - Connection.Dispose(); - Application.Dispose(); - // REVIEW: Add a timeout so we don't wait forever - await Task.WhenAll(ApplicationTask, TransportTask); + await Task.WhenAll(applicationTask, transportTask); + } + + public enum ConnectionStatus + { + Inactive, + Active, + Disposed } } } diff --git a/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeService.cs b/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeService.cs index afb44b1543..f05ecb61c2 100644 --- a/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeService.cs +++ b/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeService.cs @@ -1,10 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; namespace Microsoft.AspNetCore.Sockets @@ -20,6 +16,7 @@ public SocketsApplicationLifetimeService(ConnectionManager connectionManager) public void Start() { + _connectionManager.Start(); } public void Stop() diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/IHttpTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/IHttpTransport.cs index eb4d62c136..dc58f2dbcf 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/IHttpTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/IHttpTransport.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -12,7 +13,8 @@ public interface IHttpTransport /// Executes the transport /// /// + /// /// A that completes when the transport has finished processing - Task ProcessRequestAsync(HttpContext context); + Task ProcessRequestAsync(HttpContext context, CancellationToken token); } } diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs index 379c9b3aee..eec753f472 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs @@ -23,17 +23,14 @@ public LongPollingTransport(ReadableChannel application, ILoggerFactory _logger = loggerFactory.CreateLogger(); } - public async Task ProcessRequestAsync(HttpContext context) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { try { - // TODO: We need the ability to yield the connection without completing the channel. - // This is to force ReadAsync to yield without data to end to poll but not the entire connection. - // This is for cases when the client reconnects see issue #27 - if (!await _application.WaitToReadAsync(context.RequestAborted)) + if (!await _application.WaitToReadAsync(token)) { _logger.LogInformation("Terminating Long Polling connection by sending 204 response."); - context.Response.StatusCode = 204; + context.Response.StatusCode = StatusCodes.Status204NoContent; return; } @@ -50,14 +47,17 @@ public async Task ProcessRequestAsync(HttpContext context) } catch (OperationCanceledException) { - // Suppress the exception + if (!context.RequestAborted.IsCancellationRequested) + { + _logger.LogInformation("Terminating Long Polling connection by sending 204 response."); + context.Response.StatusCode = StatusCodes.Status204NoContent; + throw; + } + + // Don't count this as cancellation, this is normal as the poll can end due to the browesr closing. + // The background thread will eventually dispose this connection if it's inactive _logger.LogDebug("Client disconnected from Long Polling endpoint."); } - catch (Exception ex) - { - _logger.LogError("Error reading next message from Application: {0}", ex); - throw; - } } } } diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs index ec34631695..4863ccc396 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http; @@ -21,7 +22,7 @@ public ServerSentEventsTransport(ReadableChannel application, ILoggerFa _logger = loggerFactory.CreateLogger(); } - public async Task ProcessRequestAsync(HttpContext context) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { context.Response.ContentType = "text/event-stream"; context.Response.Headers["Cache-Control"] = "no-cache"; @@ -30,7 +31,7 @@ public async Task ProcessRequestAsync(HttpContext context) try { - while (await _application.WaitToReadAsync(context.RequestAborted)) + while (await _application.WaitToReadAsync(token)) { Message message; while (_application.TryRead(out message)) diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs index 710c2ee6ea..215b5ede5d 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.WebSockets.Internal; @@ -40,7 +41,7 @@ public WebSocketsTransport(IChannelConnection application, ILoggerFacto _logger = loggerFactory.CreateLogger(); } - public async Task ProcessRequestAsync(HttpContext context) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { var feature = context.Features.Get(); if (feature == null || !feature.IsWebSocketRequest) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 8295793ac3..0476fd9131 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -4,6 +4,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.Extensions.Logging; using Xunit; namespace Microsoft.AspNetCore.Sockets.Tests @@ -13,21 +14,23 @@ public class ConnectionManagerTests [Fact] public void NewConnectionsHaveConnectionId() { - var connectionManager = new ConnectionManager(); + var connectionManager = CreateConnectionManager(); var state = connectionManager.CreateConnection(); Assert.NotNull(state.Connection); Assert.NotNull(state.Connection.ConnectionId); - Assert.True(state.Active); + Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status); Assert.Null(state.ApplicationTask); Assert.Null(state.TransportTask); + Assert.Null(state.Cancellation); + Assert.Null(state.RequestId); Assert.NotNull(state.Connection.Transport); } [Fact] public void NewConnectionsCanBeRetrieved() { - var connectionManager = new ConnectionManager(); + var connectionManager = CreateConnectionManager(); var state = connectionManager.CreateConnection(); Assert.NotNull(state.Connection); @@ -41,7 +44,7 @@ public void NewConnectionsCanBeRetrieved() [Fact] public void AddNewConnection() { - var connectionManager = new ConnectionManager(); + var connectionManager = CreateConnectionManager(); var state = connectionManager.CreateConnection(); var transport = state.Connection.Transport; @@ -59,7 +62,7 @@ public void AddNewConnection() [Fact] public void RemoveConnection() { - var connectionManager = new ConnectionManager(); + var connectionManager = CreateConnectionManager(); var state = connectionManager.CreateConnection(); var transport = state.Connection.Transport; @@ -80,7 +83,7 @@ public void RemoveConnection() [Fact] public async Task CloseConnectionsEndsAllPendingConnections() { - var connectionManager = new ConnectionManager(); + var connectionManager = CreateConnectionManager(); var state = connectionManager.CreateConnection(); state.ApplicationTask = Task.Run(async () => @@ -97,5 +100,10 @@ public async Task CloseConnectionsEndsAllPendingConnections() await state.DisposeAsync(); } + + private static ConnectionManager CreateConnectionManager() + { + return new ConnectionManager(new Logger(new LoggerFactory())); + } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 532b0c137f..a503ac053d 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -21,7 +22,7 @@ public class HttpConnectionDispatcherTests [Fact] public async Task NegotiateReservesConnectionIdAndReturnsIt() { - var manager = new ConnectionManager(); + var manager = CreateConnectionManager(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -46,7 +47,7 @@ public async Task NegotiateReservesConnectionIdAndReturnsIt() [InlineData("/ws")] public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(string path) { - var manager = new ConnectionManager(); + var manager = CreateConnectionManager(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); using (var strm = new MemoryStream()) @@ -77,7 +78,7 @@ public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvi [InlineData("/poll")] public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(string path) { - var manager = new ConnectionManager(); + var manager = CreateConnectionManager(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); using (var strm = new MemoryStream()) { @@ -95,13 +96,171 @@ public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided( Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray())); } } + + [Fact] + public async Task CompletedEndPointEndsConnection() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/sse", state); + + await dispatcher.ExecuteAsync("", context); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + ConnectionState removed; + bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + Assert.False(exists); + } + + [Fact] + public async Task CompletedEndPointEndsLongPollingConnection() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/poll", state); + + await dispatcher.ExecuteAsync("", context); + + Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); + + ConnectionState removed; + bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + Assert.False(exists); + } + + [Fact] + public async Task RequestToActiveConnectionId409ForStreamingTransports() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context1 = MakeRequest("/sse", state); + var context2 = MakeRequest("/sse", state); + + var request1 = dispatcher.ExecuteAsync("", context1); + + await dispatcher.ExecuteAsync("", context2); + + Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); + + manager.CloseConnections(); + + await request1; + } + + [Fact] + public async Task RequestToActiveConnectionIdKillsPreviousConnectionLongPolling() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context1 = MakeRequest("/poll", state); + var context2 = MakeRequest("/poll", state); + + var request1 = dispatcher.ExecuteAsync("", context1); + var request2 = dispatcher.ExecuteAsync("", context2); + + await request1; + + Assert.Equal(StatusCodes.Status204NoContent, context1.Response.StatusCode); + Assert.Equal(ConnectionState.ConnectionStatus.Active, state.Status); + + Assert.False(request2.IsCompleted); + + manager.CloseConnections(); + + await request2; + } + + [Theory] + [InlineData("/sse")] + [InlineData("/poll")] + public async Task RequestToDisposedConnectionIdReturns404(string path) + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + state.Status = ConnectionState.ConnectionStatus.Disposed; + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest(path, state); + + await dispatcher.ExecuteAsync("", context); + + Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); + } + + [Fact] + public async Task ConnectionStateSetToInactiveAfterPoll() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/poll", state); + + var task = dispatcher.ExecuteAsync("", context); + + var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(); + + // Write to the transport so the poll yields + await state.Connection.Transport.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true)); + + await task; + + Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status); + Assert.Null(state.RequestId); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + } + + private static DefaultHttpContext MakeRequest(string path, ConnectionState state) where TEndPoint : EndPoint + { + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton(); + context.RequestServices = services.BuildServiceProvider(); + context.Request.Path = path; + var values = new Dictionary(); + values["id"] = state.Connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + return context; + } + + private static ConnectionManager CreateConnectionManager() + { + return new ConnectionManager(new Logger(new LoggerFactory())); + } } - public class TestEndPoint : EndPoint + public class ImmediatelyCompleteEndPoint : EndPoint { public override Task OnConnectedAsync(Connection connection) { - throw new NotImplementedException(); + return Task.CompletedTask; + } + } + + public class TestEndPoint : EndPoint + { + public override async Task OnConnectedAsync(Connection connection) + { + while (await connection.Transport.Input.WaitToReadAsync()) + { + } } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index f85cc529f8..035c81db65 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -27,7 +27,7 @@ public async Task Set204StatusCodeWhenChannelComplete() Assert.True(channel.Out.TryComplete()); - await poll.ProcessRequestAsync(context); + await poll.ProcessRequestAsync(context, context.RequestAborted); Assert.Equal(204, context.Response.StatusCode); } @@ -48,7 +48,7 @@ public async Task FrameSentAsSingleResponse() Assert.True(channel.Out.TryComplete()); - await poll.ProcessRequestAsync(context); + await poll.ProcessRequestAsync(context, context.RequestAborted); Assert.Equal(200, context.Response.StatusCode); Assert.Equal("Hello World", Encoding.UTF8.GetString(ms.ToArray())); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index 04c1a67245..5fd0e7104e 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -24,7 +24,7 @@ public async Task SSESetsContentType() Assert.True(channel.Out.TryComplete()); - await sse.ProcessRequestAsync(context); + await sse.ProcessRequestAsync(context, context.RequestAborted); Assert.Equal("text/event-stream", context.Response.ContentType); Assert.Equal("no-cache", context.Response.Headers["Cache-Control"]); @@ -46,7 +46,7 @@ public async Task SSEAddsAppropriateFraming() Assert.True(channel.Out.TryComplete()); - await sse.ProcessRequestAsync(context); + await sse.ProcessRequestAsync(context, context.RequestAborted); var expected = "data: Hello World\n\n"; Assert.Equal(expected, Encoding.UTF8.GetString(ms.ToArray()));