Skip to content

Commit

Permalink
Fix for issue dotnet#54500: Added Microsoft.Extensions.DependencyInje…
Browse files Browse the repository at this point in the history
…ction to Microsoft.AspNetCore.Http.Abstractions and utilized the ActivatorUtilities it provides to obtain a middleware instance. Also changed the ReflectionMiddlewareBinder to be able to handle keyed injection.
  • Loading branch information
NicoBrabers committed May 14, 2024
1 parent 1b454f5 commit 33c9cf4
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Builder;

Expand All @@ -21,6 +21,7 @@ public static class UseMiddlewareExtensions
internal const string InvokeAsyncMethodName = "InvokeAsync";

private static readonly MethodInfo GetServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetService), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetKeyedServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetKeyedService), BindingFlags.NonPublic | BindingFlags.Static)!;

// We're going to keep all public constructors and public methods on middleware
private const DynamicallyAccessedMemberTypes MiddlewareAccessibility =
Expand Down Expand Up @@ -215,13 +216,55 @@ public RequestDelegate CreateMiddleware(RequestDelegate next)
methodArguments[0] = context;
for (var i = 1; i < parameters.Length; i++)
{
methodArguments[i] = GetService(serviceProvider, parameters[i].ParameterType, methodInfo.DeclaringType!);
var parameter = parameters[i];
var hasServiceKey = TryGetServiceKey(parameter, out object? key);
var parameterType = parameter.ParameterType;
var declaringType = methodInfo.DeclaringType;
methodArguments[i] = hasServiceKey ? GetKeyedService(serviceProvider, key!, parameterType, declaringType!) : GetService(serviceProvider, parameterType, declaringType!);
}
return (Task)methodInfo.Invoke(middleware, BindingFlags.DoNotWrapExceptions, binder: null, methodArguments, culture: null)!;
};
}

private static bool TryGetServiceKey(ParameterInfo parameterInfo, out object? key)
{
if (parameterInfo.CustomAttributes != null)
{
foreach (var attribute in parameterInfo.GetCustomAttributes(true))
{
if (attribute is FromKeyedServicesAttribute keyed)
{
key = keyed.Key;
return true;
}
}
}
key = null;
return false;
}

private static UnaryExpression GetMethodArgument(ParameterInfo parameter, ParameterExpression providerArg, Type parameterType, Type? declaringType)
{
var parameterTypeExpression = new List<Expression>() { providerArg };
var hasServiceKey = TryGetServiceKey(parameter, out object? key);

if (hasServiceKey)
{
parameterTypeExpression.Add(Expression.Constant(key, typeof(object)));
}

parameterTypeExpression.Add(Expression.Constant(parameterType, typeof(Type)));
parameterTypeExpression.Add(Expression.Constant(declaringType, typeof(Type)));

var getServiceCall = Expression.Call(hasServiceKey ? GetKeyedServiceInfo : GetServiceInfo, parameterTypeExpression);
var methodArgument = Expression.Convert(getServiceCall, parameterType);

return methodArgument;
}

private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>(MethodInfo methodInfo, ParameterInfo[] parameters)
{
Debug.Assert(RuntimeFeature.IsDynamicCodeSupported, "Use compiled expression when dynamic code is supported.");
Expand Down Expand Up @@ -262,21 +305,14 @@ public RequestDelegate CreateMiddleware(RequestDelegate next)
methodArguments[0] = httpContextArg;
for (var i = 1; i < parameters.Length; i++)
{
var parameterType = parameters[i].ParameterType;
var parameter = parameters[i];
var parameterType = parameter.ParameterType;
if (parameterType.IsByRef)
{
throw new NotSupportedException(Resources.FormatException_InvokeDoesNotSupportRefOrOutParams(InvokeMethodName));
}

var parameterTypeExpression = new Expression[]
{
providerArg,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(methodInfo.DeclaringType, typeof(Type))
};

var getServiceCall = Expression.Call(GetServiceInfo, parameterTypeExpression);
methodArguments[i] = Expression.Convert(getServiceCall, parameterType);
methodArguments[i] = GetMethodArgument(parameter, providerArg, parameterType, methodInfo.DeclaringType);
}

Expression middlewareInstanceArg = instanceArg;
Expand All @@ -294,12 +330,20 @@ public RequestDelegate CreateMiddleware(RequestDelegate next)

private static object GetService(IServiceProvider sp, Type type, Type middleware)
{
var service = sp.GetService(type);
if (service == null)
var service = sp.GetService(type) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

private static object GetKeyedService(IServiceProvider sp, object key, Type type, Type middleware)
{
if (sp is IKeyedServiceProvider ksp)
{
throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));
var service = ksp.GetKeyedService(type, key) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

return service;
throw new InvalidOperationException(Resources.Exception_KeyedServicesNotSupported);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ Microsoft.AspNetCore.Http.HttpResponse</Description>
<Reference Include="Microsoft.AspNetCore.Http.Features" />
<Reference Include="Microsoft.Net.Http.Headers" />
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />

<Compile Include="$(SharedSourceRoot)ActivatorUtilities\*.cs" />
<Compile Include="$(SharedSourceRoot)ParameterDefaultValue\*.cs" />
<Compile Include="$(SharedSourceRoot)PropertyHelper\**\*.cs" />
<Compile Include="$(SharedSourceRoot)\UrlDecoder\UrlDecoder.cs" Link="UrlDecoder.cs" />
Expand Down
6 changes: 6 additions & 0 deletions src/Http/Http.Abstractions/src/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,10 @@
<data name="RouteValueDictionary_DuplicatePropertyName" xml:space="preserve">
<value>The type '{0}' defines properties '{1}' and '{2}' which differ only by casing. This is not supported by {3} which uses case-insensitive comparisons.</value>
</data>
<data name="Exception_KeyedServicesNotSupported" xml:space="preserve">
<value>This service provider doesn't support keyed services.</value>
</data>
<data name="Exception_NoServiceRegistered" xml:space="preserve">
<value>No service for type '{0}' has been registered.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Mono.TextTemplating" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
</ItemGroup>

<ItemGroup>
Expand Down
100 changes: 100 additions & 0 deletions src/Http/Http.Abstractions/test/UseMiddlewareTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Http;

Expand Down Expand Up @@ -130,6 +131,32 @@ public async Task UseMiddleware_ThrowsIfArgCantBeResolvedFromContainer()
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfKeyedArgCantBeResolvedFromContainer()
{
var builder = new ApplicationBuilder(new DummyKeyedServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvokeNoService));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.FormatException_InvokeMiddlewareNoService(
typeof(object),
typeof(MiddlewareKeyedInjectInvokeNoService)),
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfServiceProviderIsNotAIKeyedServiceProvider()
{
var builder = new ApplicationBuilder(new DummyServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvokeNoService));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.Exception_KeyedServicesNotSupported,
exception.Message);
}

[Fact]
public void UseMiddlewareWithInvokeArg()
{
Expand All @@ -139,6 +166,17 @@ public void UseMiddlewareWithInvokeArg()
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeKeyedArg()
{
var keyedServiceProvider = new DummyKeyedServiceProvider();
keyedServiceProvider.AddKeyedService("test", typeof(DummyKeyedServiceProvider), keyedServiceProvider);
var builder = new ApplicationBuilder(keyedServiceProvider);
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvoke));
var app = builder.Build();
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeWithOutAndRefThrows()
{
Expand Down Expand Up @@ -274,6 +312,54 @@ private class DummyServiceProvider : IServiceProvider
}
}

private class DummyKeyedServiceProvider : IKeyedServiceProvider
{
private readonly Dictionary<object, Tuple<Type, object>> _services = new Dictionary<object, Tuple<Type, object>>();

public DummyKeyedServiceProvider()
{

}

public void AddKeyedService(object key, Type type, object value) => _services[key] = new Tuple<Type, object>(type, value);

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (_services.TryGetValue(serviceKey!, out var value))
{
return value.Item2;
}

return null;
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
var service = GetKeyedService(serviceType, serviceKey);

if (service == null)
{
throw new InvalidOperationException(Resources.FormatException_NoServiceRegistered(serviceType));
}

return service;
}

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}
return null;
}
}

public class MiddlewareInjectWithOutAndRefParams
{
public MiddlewareInjectWithOutAndRefParams(RequestDelegate next) { }
Expand All @@ -293,13 +379,27 @@ private class MiddlewareInjectInvokeNoService
public Task Invoke(HttpContext context, object value) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvokeNoService
{
public MiddlewareKeyedInjectInvokeNoService(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] object value) => Task.CompletedTask;
}

private class MiddlewareInjectInvoke
{
public MiddlewareInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, IServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvoke
{
public MiddlewareKeyedInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] IKeyedServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareNoParametersStub
{
public MiddlewareNoParametersStub(RequestDelegate next) { }
Expand Down
24 changes: 22 additions & 2 deletions src/Http/Http.Abstractions/test/UsePathBaseExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ public IServiceProvider ApplicationServices
public IFeatureCollection ServerFeatures => _wrappedBuilder.ServerFeatures;
public RequestDelegate Build() => _wrappedBuilder.Build();
public IApplicationBuilder New() => _wrappedBuilder.New();

}

[Theory]
Expand Down Expand Up @@ -238,6 +237,27 @@ private static HttpContext CreateRequest(string pathBase, string requestPath)

private static ApplicationBuilder CreateBuilder()
{
return new ApplicationBuilder(serviceProvider: null!);
return new ApplicationBuilder(new DummyServiceProvider());
}

private class DummyServiceProvider : IServiceProvider
{
private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

public void AddService(Type type, object value) => _services[type] = value;

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}
return null;
}
}
}

0 comments on commit 33c9cf4

Please sign in to comment.