Skip to content

Commit

Permalink
#1509 #1683 Replace non-WS protocols for the 'ClientWebSocket' in Web…
Browse files Browse the repository at this point in the history
…SocketsProxyMiddleware (#1689)

* Update WebSocketsProxyMiddleware.cs

Fix WebSocket for SignalR

* Repalce url protocol after null check

* small refactoring

* Add error log when replacing protocol in WebSocketProxyMiddleware

Co-authored-by: Raman Maksimchuk <dotnet044@gmail.com>

* Fix build

* Code review

* Fix unit test

* Refactor to remove hardcoded strings of schemes

* Define public constants

* Add unit tests

---------

Co-authored-by: raman-m <dotnet044@gmail.com>
  • Loading branch information
ArtRoman and raman-m committed Sep 29, 2023
1 parent 190b001 commit ab29442
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 39 deletions.
35 changes: 27 additions & 8 deletions src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Ocelot.Configuration;
using Ocelot.Logging;
using Ocelot.Middleware;
using Ocelot.Request.Middleware;
using System.Net.WebSockets;

namespace Ocelot.WebSockets
Expand All @@ -17,10 +18,14 @@ public class WebSocketsProxyMiddleware : OcelotMiddleware
"Connection", "Host", "Upgrade",
"Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions",
};

private const int DefaultWebSocketBufferSize = 4096;
private readonly RequestDelegate _next;
private readonly IWebSocketsFactory _factory;

public const string IgnoredSslWarningFormat = $"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{{0}}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{{1}}'.";
public const string InvalidSchemeWarningFormat = "Invalid scheme has detected which will be replaced! Scheme '{0}' of the downstream '{1}'.";

public WebSocketsProxyMiddleware(IOcelotLoggerFactory loggerFactory,
RequestDelegate next,
IWebSocketsFactory factory)
Expand Down Expand Up @@ -73,21 +78,26 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination,

public async Task Invoke(HttpContext httpContext)
{
var uri = httpContext.Items.DownstreamRequest().ToUri();
var downstreamRequest = httpContext.Items.DownstreamRequest();
var downstreamRoute = httpContext.Items.DownstreamRoute();
await Proxy(httpContext, uri, downstreamRoute);
await Proxy(httpContext, downstreamRequest, downstreamRoute);
}

private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamRoute downstreamRoute)
private async Task Proxy(HttpContext context, DownstreamRequest request, DownstreamRoute route)
{
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}

if (serverEndpoint == null)
if (request == null)
{
throw new ArgumentNullException(nameof(request));
}

if (route == null)
{
throw new ArgumentNullException(nameof(serverEndpoint));
throw new ArgumentNullException(nameof(route));
}

if (!context.WebSockets.IsWebSocketRequest)
Expand All @@ -97,10 +107,10 @@ private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamR

var client = _factory.CreateClient(); // new ClientWebSocket();

if (downstreamRoute.DangerousAcceptAnyServerCertificateValidator)
if (route.DangerousAcceptAnyServerCertificateValidator)
{
client.Options.RemoteCertificateValidationCallback = (request, certificate, chain, errors) => true;
Logger.LogWarning($"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{downstreamRoute.UpstreamPathTemplate}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{downstreamRoute.DownstreamPathTemplate}'.");
Logger.LogWarning(string.Format(IgnoredSslWarningFormat, route.UpstreamPathTemplate, route.DownstreamPathTemplate));
}

foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols)
Expand All @@ -125,7 +135,16 @@ private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamR
}
}

var destinationUri = new Uri(serverEndpoint);
// Only Uris starting with 'ws://' or 'wss://' are supported in System.Net.WebSockets.ClientWebSocket
var scheme = request.Scheme;
if (!scheme.StartsWith(Uri.UriSchemeWs))
{
Logger.LogWarning(string.Format(InvalidSchemeWarningFormat, scheme, request.ToUri()));
request.Scheme = scheme == Uri.UriSchemeHttp ? Uri.UriSchemeWs
: scheme == Uri.UriSchemeHttps ? Uri.UriSchemeWss : scheme;
}

var destinationUri = new Uri(request.ToUri());
await client.ConnectAsync(destinationUri, context.RequestAborted);

using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol))
Expand Down
1 change: 0 additions & 1 deletion test/Ocelot.UnitTests/WebSockets/MockWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ protected virtual void Dispose(bool disposing)
// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
// Dispose(disposing: false);
// }

public override void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Expand Down
112 changes: 82 additions & 30 deletions test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class WebSocketsProxyMiddlewareTests

private readonly Mock<HttpContext> _context;
private readonly Mock<IOcelotLogger> _logger;
private readonly Mock<IClientWebSocket> _client;

public WebSocketsProxyMiddlewareTests()
{
Expand All @@ -27,48 +28,48 @@ public WebSocketsProxyMiddlewareTests()
_factory = new Mock<IWebSocketsFactory>();

_context = new Mock<HttpContext>();
_context.SetupGet(x => x.WebSockets.IsWebSocketRequest).Returns(true);

_logger = new Mock<IOcelotLogger>();
_loggerFactory.Setup(x => x.CreateLogger<WebSocketsProxyMiddleware>())
.Returns(_logger.Object);

_middleware = new WebSocketsProxyMiddleware(_loggerFactory.Object, _next.Object, _factory.Object);

_client = new Mock<IClientWebSocket>();
_factory.Setup(x => x.CreateClient()).Returns(_client.Object);
}

[Fact]
public void ShouldIgnoreAllSslWarnings_WhenDangerousAcceptAnyServerCertificateValidatorIsTrue()
public void ShouldIgnoreAllSslWarningsWhenDangerousAcceptAnyServerCertificateValidatorIsTrue()
{
this.Given(x => x.GivenPropertyDangerousAcceptAnyServerCertificateValidator(true))
List<object> actual = new();
this.Given(x => x.GivenPropertyDangerousAcceptAnyServerCertificateValidator(true, actual))
.And(x => x.AndDoNotSetupProtocolsAndHeaders())
.And(x => x.AndDoNotConnectReally())
.And(x => x.AndDoNotConnectReally(null))
.When(x => x.WhenInvokeWithHttpContext())
.Then(x => x.ThenIgnoredAllSslWarnings())
.Then(x => x.ThenIgnoredAllSslWarnings(actual))
.BDDfy();
}

private void GivenPropertyDangerousAcceptAnyServerCertificateValidator(bool enabled)
private void GivenPropertyDangerousAcceptAnyServerCertificateValidator(bool enabled, List<object> actual)
{
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:80");
var request = new HttpRequestMessage(HttpMethod.Get, $"{Uri.UriSchemeWs}://localhost:12345");
var downstream = new DownstreamRequest(request);
var route = new DownstreamRouteBuilder()
.WithDangerousAcceptAnyServerCertificateValidator(enabled)
.Build();
_context.SetupGet(x => x.Items).Returns(new Dictionary<object, object>
{
{ "DownstreamRequest", downstream },
{ "DownstreamRoute", route },
});

_context.SetupGet(x => x.WebSockets.IsWebSocketRequest).Returns(true);

_client = new Mock<IClientWebSocket>();
_factory.Setup(x => x.CreateClient()).Returns(_client.Object);
{
{ "DownstreamRequest", downstream },
{ "DownstreamRoute", route },
});

_client.SetupSet(x => x.Options.RemoteCertificateValidationCallback = It.IsAny<RemoteCertificateValidationCallback>())
.Callback<RemoteCertificateValidationCallback>(value => _callback = value);
.Callback<RemoteCertificateValidationCallback>(actual.Add);

_warning = string.Empty;
_logger.Setup(x => x.LogWarning(It.IsAny<string>()))
.Callback<string>(message => _warning = message);
.Callback<string>(actual.Add);
}

private void AndDoNotSetupProtocolsAndHeaders()
Expand All @@ -77,9 +78,11 @@ private void AndDoNotSetupProtocolsAndHeaders()
_context.SetupGet(x => x.Request.Headers).Returns(new HeaderDictionary());
}

private void AndDoNotConnectReally()
private void AndDoNotConnectReally(Action<Uri, CancellationToken> callbackConnectAsync)
{
_client.Setup(x => x.ConnectAsync(It.IsAny<Uri>(), It.IsAny<CancellationToken>())).Verifiable();
Action<Uri, CancellationToken> doNothing = (u, t) => { };
_client.Setup(x => x.ConnectAsync(It.IsAny<Uri>(), It.IsAny<CancellationToken>()))
.Callback(callbackConnectAsync ?? doNothing);
var clientSocket = new Mock<WebSocket>();
var serverSocket = new Mock<WebSocket>();
_client.Setup(x => x.ToWebSocket()).Returns(clientSocket.Object);
Expand All @@ -97,28 +100,77 @@ private void AndDoNotConnectReally()
serverSocket.SetupGet(x => x.CloseStatus).Returns(WebSocketCloseStatus.Empty);
}

private Mock<IClientWebSocket> _client;
private RemoteCertificateValidationCallback _callback;
private string _warning;

private async Task WhenInvokeWithHttpContext()
{
await _middleware.Invoke(_context.Object);
}

private void ThenIgnoredAllSslWarnings()
private void ThenIgnoredAllSslWarnings(List<object> actual)
{
_context.Object.Items.DownstreamRoute().DangerousAcceptAnyServerCertificateValidator
.ShouldBeTrue();
var route = _context.Object.Items.DownstreamRoute();
var request = _context.Object.Items.DownstreamRequest();
route.DangerousAcceptAnyServerCertificateValidator.ShouldBeTrue();

_logger.Verify(x => x.LogWarning(It.IsAny<string>()), Times.Once());
_warning.ShouldNotBeNullOrEmpty();
var warning = actual.Last() as string;
warning.ShouldNotBeNullOrEmpty();
var expectedWarning = string.Format(WebSocketsProxyMiddleware.IgnoredSslWarningFormat, route.UpstreamPathTemplate, route.DownstreamPathTemplate);
warning.ShouldBe(expectedWarning);

_client.VerifySet(x => x.Options.RemoteCertificateValidationCallback = It.IsAny<RemoteCertificateValidationCallback>(),
Times.Once());

_callback.ShouldNotBeNull();
var validation = _callback.Invoke(null, null, null, SslPolicyErrors.None);
var callback = actual.First() as RemoteCertificateValidationCallback;
callback.ShouldNotBeNull();
var validation = callback.Invoke(null, null, null, SslPolicyErrors.None);
validation.ShouldBeTrue();
}

[Theory]
[InlineData("http", "ws")]
[InlineData("https", "wss")]
[InlineData("ftp", "ftp")]
public void ShouldReplaceNonWsSchemes(string scheme, string expectedScheme)
{
List<object> actual = new();
this.Given(x => x.GivenNonWebsocketScheme(scheme, actual))
.And(x => x.AndDoNotSetupProtocolsAndHeaders())
.And(x => x.AndDoNotConnectReally((uri, token) => actual.Add(uri)))
.When(x => x.WhenInvokeWithHttpContext())
.Then(x => x.ThenNonWsSchemesAreReplaced(scheme, expectedScheme, actual))
.BDDfy();
}

private void GivenNonWebsocketScheme(string scheme, List<object> actual)
{
var requestMessage = new HttpRequestMessage(HttpMethod.Get, $"{scheme}://localhost:12345");
var request = new DownstreamRequest(requestMessage);
var route = new DownstreamRouteBuilder().Build();
var items = new Dictionary<object, object>
{
{ "DownstreamRequest", request },
{ "DownstreamRoute", route },
};
_context.SetupGet(x => x.Items).Returns(items);

_logger.Setup(x => x.LogWarning(It.IsAny<string>()))
.Callback<string>(actual.Add);
}

private void ThenNonWsSchemesAreReplaced(string scheme, string expectedScheme, List<object> actual)
{
var route = _context.Object.Items.DownstreamRoute();
var request = _context.Object.Items.DownstreamRequest();
route.DangerousAcceptAnyServerCertificateValidator.ShouldBeFalse();

_logger.Verify(x => x.LogWarning(It.IsAny<string>()), Times.Once());
var warning = actual.First() as string;
warning.ShouldNotBeNullOrEmpty();
warning.ShouldContain($"'{scheme}'");
var expectedWarning = string.Format(WebSocketsProxyMiddleware.InvalidSchemeWarningFormat, scheme, request.ToUri().Replace(expectedScheme, scheme));
warning.ShouldBe(expectedWarning);

request.Scheme.ShouldBe(expectedScheme);
((Uri)actual.Last()).Scheme.ShouldBe(expectedScheme);
}
}

0 comments on commit ab29442

Please sign in to comment.