Skip to content

Commit

Permalink
#1396 Lost context User in MultiplexingMiddleware (#1462)
Browse files Browse the repository at this point in the history
* Fix HttpContext.User is lost after passing MultiplexingMiddlware

* Simplify single downstream route handling

* fix

* some refactoring of long code

* add unit tests for #1396 user scenario

* Acceptance test for user forwarding

* refactor test

---------

Co-authored-by: Алексей Патрин <apatrin@croc.ru>
Co-authored-by: Raman Maksimchuk <dotnet044@gmail.com>
  • Loading branch information
3 people committed Feb 22, 2024
1 parent a9dff7c commit 108bede
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 133 deletions.
58 changes: 29 additions & 29 deletions src/Ocelot/Multiplexer/MultiplexingMiddleware.cs
Expand Up @@ -11,34 +11,41 @@ public class MultiplexingMiddleware : OcelotMiddleware
private readonly RequestDelegate _next;
private readonly IResponseAggregatorFactory _factory;

public MultiplexingMiddleware(RequestDelegate next,
public MultiplexingMiddleware(
RequestDelegate next,
IOcelotLoggerFactory loggerFactory,
IResponseAggregatorFactory factory
)
: base(loggerFactory.CreateLogger<MultiplexingMiddleware>())
IResponseAggregatorFactory factory)
: base(loggerFactory.CreateLogger<MultiplexingMiddleware>())
{
_factory = factory;
_next = next;
}

public async Task Invoke(HttpContext httpContext)
{
var route = httpContext.Items.DownstreamRouteHolder().Route;
if (httpContext.WebSockets.IsWebSocketRequest)
{
//todo this is obviously stupid
httpContext.Items.UpsertDownstreamRoute(httpContext.Items.DownstreamRouteHolder().Route.DownstreamRoute[0]);
// TODO: This is obviously stupid
httpContext.Items.UpsertDownstreamRoute(route.DownstreamRoute[0]);
await _next.Invoke(httpContext);
return;
}

var routeKeysConfigs = httpContext.Items.DownstreamRouteHolder().Route.DownstreamRouteConfig;
if (routeKeysConfigs == null || !routeKeysConfigs.Any())
// Don't do anything extra if downstream route is single
if (route.DownstreamRoute.Count == 1)
{
var downstreamRouteHolder = httpContext.Items.DownstreamRouteHolder();
httpContext.Items.UpsertDownstreamRoute(route.DownstreamRoute[0]);
var singleResponse = await Fire(httpContext, _next);
MapNotAggregate(httpContext, singleResponse);
return;
}

var tasks = new Task<HttpContext>[downstreamRouteHolder.Route.DownstreamRoute.Count];
if (route.DownstreamRouteConfig?.Any() != true)
{
var tasks = new Task<HttpContext>[route.DownstreamRoute.Count];

for (var i = 0; i < downstreamRouteHolder.Route.DownstreamRoute.Count; i++)
for (var i = 0; i < route.DownstreamRoute.Count; i++)
{
var newHttpContext = Copy(httpContext);

Expand All @@ -49,7 +56,7 @@ public async Task Invoke(HttpContext httpContext)
newHttpContext.Items
.UpsertTemplatePlaceholderNameAndValues(httpContext.Items.TemplatePlaceholderNameAndValues());
newHttpContext.Items
.UpsertDownstreamRoute(downstreamRouteHolder.Route.DownstreamRoute[i]);
.UpsertDownstreamRoute(route.DownstreamRoute[i]);

tasks[i] = Fire(newHttpContext, _next);
}
Expand All @@ -64,19 +71,13 @@ public async Task Invoke(HttpContext httpContext)
contexts.Add(finished);
}

await Map(httpContext, downstreamRouteHolder.Route, contexts);
await Map(httpContext, route, contexts);
}
else
{
httpContext.Items.UpsertDownstreamRoute(httpContext.Items.DownstreamRouteHolder().Route.DownstreamRoute[0]);
httpContext.Items.UpsertDownstreamRoute(route.DownstreamRoute[0]);
var mainResponse = await Fire(httpContext, _next);

if (httpContext.Items.DownstreamRouteHolder().Route.DownstreamRoute.Count == 1)
{
MapNotAggregate(httpContext, new List<HttpContext> { mainResponse });
return;
}

var tasks = new List<Task<HttpContext>>();

if (mainResponse.Items.DownstreamResponse() == null)
Expand All @@ -88,13 +89,13 @@ public async Task Invoke(HttpContext httpContext)

var jObject = Newtonsoft.Json.Linq.JToken.Parse(content);

for (var i = 1; i < httpContext.Items.DownstreamRouteHolder().Route.DownstreamRoute.Count; i++)
for (var i = 1; i < route.DownstreamRoute.Count; i++)
{
var templatePlaceholderNameAndValues = httpContext.Items.TemplatePlaceholderNameAndValues();

var downstreamRoute = httpContext.Items.DownstreamRouteHolder().Route.DownstreamRoute[i];
var downstreamRoute = route.DownstreamRoute[i];

var matchAdvancedAgg = routeKeysConfigs
var matchAdvancedAgg = route.DownstreamRouteConfig
.FirstOrDefault(q => q.RouteKey == downstreamRoute.Key);

if (matchAdvancedAgg != null)
Expand Down Expand Up @@ -153,7 +154,7 @@ public async Task Invoke(HttpContext httpContext)
contexts.Add(finished);
}

await Map(httpContext, httpContext.Items.DownstreamRouteHolder().Route, contexts);
await Map(httpContext, route, contexts);
}
}

Expand Down Expand Up @@ -182,6 +183,7 @@ private static HttpContext Copy(HttpContext source)
target.Connection.RemoteIpAddress = source.Connection.RemoteIpAddress;
target.RequestServices = source.RequestServices;
target.RequestAborted = source.RequestAborted;
target.User = source.User;
return target;
}

Expand All @@ -194,15 +196,13 @@ private async Task Map(HttpContext httpContext, Route route, List<HttpContext> c
}
else
{
MapNotAggregate(httpContext, contexts);
// Assume at least one... if this errors then it will be caught by global exception handler
MapNotAggregate(httpContext, contexts.First());
}
}

private static void MapNotAggregate(HttpContext httpContext, List<HttpContext> downstreamContexts)
private static void MapNotAggregate(HttpContext httpContext, HttpContext finished)
{
//assume at least one..if this errors then it will be caught by global exception handler
var finished = downstreamContexts.First();

httpContext.Items.UpsertErrors(finished.Items.Errors());

httpContext.Items.UpsertDownstreamRequest(finished.Items.DownstreamRequest());
Expand Down

0 comments on commit 108bede

Please sign in to comment.