Skip to content

Commit

Permalink
[H/3] Fix code passed into QuicConnection.CloseAsync and QuicStream.A…
Browse files Browse the repository at this point in the history
…bort (#55282)

* Fix code passed into QuicConnection.CloseAsync and QuicStream.Abort

* Validate that configured error code is used

---------

Co-authored-by: Andrew Casey <andrew.casey@microsoft.com>
  • Loading branch information
ManickaP and amcasey committed May 3, 2024
1 parent e6cf613 commit b37eafd
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs
Expand Up @@ -166,6 +166,9 @@ private void AbortCore(Exception exception, Http3ErrorCode errorCode)
abortReason = new ConnectionAbortedException(exception.Message, exception);
}

// This has the side-effect of validating the error code, so do it before we consume the error code
_errorCodeFeature.Error = (long)errorCode;

_context.WebTransportSession?.Abort(abortReason, errorCode);

Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason);
Expand All @@ -181,7 +184,6 @@ private void AbortCore(Exception exception, Http3ErrorCode errorCode)
RequestBodyPipe.Writer.Complete(exception);

// Abort framewriter and underlying transport after stopping output.
_errorCodeFeature.Error = (long)errorCode;
_frameWriter.Abort(abortReason);
}
}
Expand Down
Expand Up @@ -16,7 +16,11 @@ internal sealed partial class QuicConnectionContext : IProtocolErrorCodeFeature,
public long Error
{
get => _error ?? -1;
set => _error = value;
set
{
QuicTransportOptions.ValidateErrorCode(value);
_error = value;
}
}

public X509Certificate2? ClientCertificate
Expand Down
Expand Up @@ -56,6 +56,7 @@ public override async ValueTask DisposeAsync()
{
lock (_shutdownLock)
{
// The DefaultCloseErrorCode setter validates that the error code is within the valid range
_closeTask ??= _connection.CloseAsync(errorCode: _context.Options.DefaultCloseErrorCode).AsTask();
}

Expand All @@ -81,7 +82,7 @@ public override void Abort(ConnectionAbortedException abortReason)
return;
}

var resolvedErrorCode = _error ?? 0;
var resolvedErrorCode = _error ?? 0; // Only valid error codes are assigned to _error
_abortReason = ExceptionDispatchInfo.Capture(abortReason);
QuicLog.ConnectionAbort(_log, this, resolvedErrorCode, abortReason.Message);
_closeTask = _connection.CloseAsync(errorCode: resolvedErrorCode).AsTask();
Expand Down Expand Up @@ -130,7 +131,7 @@ public override void Abort(ConnectionAbortedException abortReason)
catch (QuicException ex) when (ex.QuicError == QuicError.ConnectionAborted)
{
// Shutdown initiated by peer, abortive.
_error = ex.ApplicationErrorCode;
_error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code
QuicLog.ConnectionAborted(_log, this, ex.ApplicationErrorCode.GetValueOrDefault(), ex);

ThreadPool.UnsafeQueueUserWorkItem(state =>
Expand Down
Expand Up @@ -38,7 +38,11 @@ public OnCloseRegistration(Action<object?> callback, object? state)
public long Error
{
get => _error ?? -1;
set => _error = value;
set
{
QuicTransportOptions.ValidateErrorCode(value);
_error = value;
}
}

public long StreamId { get; private set; }
Expand All @@ -54,6 +58,8 @@ public long Error

public void AbortRead(long errorCode, ConnectionAbortedException abortReason)
{
QuicTransportOptions.ValidateErrorCode(errorCode);

lock (_shutdownLock)
{
if (_stream != null)
Expand All @@ -74,6 +80,8 @@ public void AbortRead(long errorCode, ConnectionAbortedException abortReason)

public void AbortWrite(long errorCode, ConnectionAbortedException abortReason)
{
QuicTransportOptions.ValidateErrorCode(errorCode);

lock (_shutdownLock)
{
if (_stream != null)
Expand Down
Expand Up @@ -273,7 +273,7 @@ private async ValueTask DoReceiveAsync()
catch (QuicException ex) when (ex.QuicError is QuicError.StreamAborted or QuicError.ConnectionAborted)
{
// Abort from peer.
_error = ex.ApplicationErrorCode;
_error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code
QuicLog.StreamAbortedRead(_log, this, ex.ApplicationErrorCode.GetValueOrDefault());

// This could be ignored if _shutdownReason is already set.
Expand Down Expand Up @@ -434,7 +434,7 @@ private async ValueTask DoSendAsync()
catch (QuicException ex) when (ex.QuicError is QuicError.StreamAborted or QuicError.ConnectionAborted)
{
// Abort from peer.
_error = ex.ApplicationErrorCode;
_error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code
QuicLog.StreamAbortedWrite(_log, this, ex.ApplicationErrorCode.GetValueOrDefault());

// This could be ignored if _shutdownReason is already set.
Expand Down Expand Up @@ -501,7 +501,7 @@ public override void Abort(ConnectionAbortedException abortReason)
_shutdownReason = abortReason;
}

var resolvedErrorCode = _error ?? 0;
var resolvedErrorCode = _error ?? 0; // _error is validated on assignment
QuicLog.StreamAbort(_log, this, resolvedErrorCode, abortReason.Message);

if (stream.CanRead)
Expand Down
Expand Up @@ -68,14 +68,15 @@ public long DefaultCloseErrorCode
}
}

private static void ValidateErrorCode(long errorCode)
internal static void ValidateErrorCode(long errorCode)
{
const long MinErrorCode = 0;
const long MaxErrorCode = (1L << 62) - 1;

if (errorCode < MinErrorCode || errorCode > MaxErrorCode)
{
throw new ArgumentOutOfRangeException(nameof(errorCode), errorCode, $"A value between {MinErrorCode} and {MaxErrorCode} is required.");
// Print the values in hex since the max is unintelligible in decimal
throw new ArgumentOutOfRangeException(nameof(errorCode), errorCode, $"A value between 0x{MinErrorCode:x} and 0x{MaxErrorCode:x} is required.");
}
}

Expand Down
Expand Up @@ -706,6 +706,33 @@ public async Task PersistentState_StreamsReused_StatePersisted()
Assert.Equal(true, state);
}

[ConditionalTheory]
[MsQuicSupported]
[InlineData(-1L)] // Too small
[InlineData(1L << 62)] // Too big
public async Task IProtocolErrorFeature_InvalidErrorCode(long errorCode)
{
// Arrange
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
await using var clientConnection = await QuicConnection.ConnectAsync(options);

await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

// Act
var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await clientStream.WriteAsync(TestData).DefaultTimeout();

var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

var protocolErrorCodeFeature = serverConnection.Features.Get<IProtocolErrorCodeFeature>();

// Assert
Assert.IsType<QuicConnectionContext>(protocolErrorCodeFeature);
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.Error = errorCode);
}

private record RequestState(
QuicConnection QuicConnection,
MultiplexedConnectionContext ServerConnection,
Expand Down
55 changes: 55 additions & 0 deletions src/Servers/Kestrel/Transport.Quic/test/QuicStreamContextTests.cs
Expand Up @@ -526,4 +526,59 @@ public async Task StreamAbortFeature_AbortWrite_ClientReceivesAbort()
var serverEx = await Assert.ThrowsAsync<ConnectionAbortedException>(() => serverReadTask).DefaultTimeout();
Assert.Equal("Test reason", serverEx.Message);
}

[ConditionalTheory]
[MsQuicSupported]
[InlineData(-1L)] // Too small
[InlineData(1L << 62)] // Too big
public async Task IProtocolErrorFeature_InvalidErrorCode(long errorCode)
{
// Arrange
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
await using var clientConnection = await QuicConnection.ConnectAsync(options);

await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

// Act
var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await clientStream.WriteAsync(TestData).DefaultTimeout();

var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

var protocolErrorCodeFeature = serverStream.Features.Get<IProtocolErrorCodeFeature>();

// Assert
Assert.IsType<QuicStreamContext>(protocolErrorCodeFeature);
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.Error = errorCode);
}

[ConditionalTheory]
[MsQuicSupported]
[InlineData(-1L)] // Too small
[InlineData(1L << 62)] // Too big
public async Task IStreamAbortFeature_InvalidErrorCode(long errorCode)
{
// Arrange
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
await using var clientConnection = await QuicConnection.ConnectAsync(options);

await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

// Act
var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await clientStream.WriteAsync(TestData).DefaultTimeout();

var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

var protocolErrorCodeFeature = serverStream.Features.Get<IStreamAbortFeature>();

// Assert
Assert.IsType<QuicStreamContext>(protocolErrorCodeFeature);
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.AbortRead(errorCode, new ConnectionAbortedException()));
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.AbortWrite(errorCode, new ConnectionAbortedException()));
}
}
Expand Up @@ -13,17 +13,16 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Quic;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Diagnostics.Metrics;
using Microsoft.Extensions.Diagnostics.Metrics.Testing;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Primitives;
using Xunit;

namespace Interop.FunctionalTests.Http3;

Expand Down Expand Up @@ -2031,6 +2030,34 @@ public async Task GET_GracefulServerShutdown_RequestCompleteSuccessfullyInsideHo
}
}

[ConditionalFact]
[MsQuicSupported]
public async Task ServerReset_InvalidErrorCode()
{
var ranHandler = false;
var hostBuilder = CreateHostBuilder(context =>
{
ranHandler = true;
// Can't test a too-large value since it's bigger than int
//Assert.Throws<ArgumentOutOfRangeException>(() => context.Features.Get<IHttpResetFeature>().Reset(-1)); // Invalid negative value
context.Features.Get<IHttpResetFeature>().Reset(-1);
return Task.CompletedTask;
});

using var host = await hostBuilder.StartAsync().DefaultTimeout();
using var client = HttpHelpers.CreateClient();

var request = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
request.Version = GetProtocol(HttpProtocols.Http3);
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var response = await client.SendAsync(request, CancellationToken.None).DefaultTimeout();
await host.StopAsync().DefaultTimeout();

Assert.True(ranHandler);
Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
}

private IHostBuilder CreateHostBuilder(RequestDelegate requestDelegate, HttpProtocols? protocol = null, Action<KestrelServerOptions> configureKestrel = null)
{
return HttpHelpers.CreateHostBuilder(AddTestLogging, requestDelegate, protocol, configureKestrel);
Expand Down

0 comments on commit b37eafd

Please sign in to comment.