Skip to content
This repository has been archived by the owner on Nov 22, 2018. It is now read-only.

Inject ICorsPolicyProvider instance through Invoke #106

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ public class CorsMiddleware
{
private readonly RequestDelegate _next;
private readonly ICorsService _corsService;
private readonly ICorsPolicyProvider _corsPolicyProvider;
private readonly CorsPolicy _policy;
private readonly string _corsPolicyName;

Expand All @@ -24,12 +23,10 @@ public class CorsMiddleware
/// </summary>
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policyProvider">A policy provider which can get an <see cref="CorsPolicy"/>.</param>
/// <param name="policyName">An optional name of the policy to be fetched.</param>
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
ICorsPolicyProvider policyProvider,
string policyName)
{
if (next == null)
Expand All @@ -42,14 +39,8 @@ public CorsMiddleware(
throw new ArgumentNullException(nameof(corsService));
}

if (policyProvider == null)
{
throw new ArgumentNullException(nameof(policyProvider));
}

_next = next;
_corsService = corsService;
_corsPolicyProvider = policyProvider;
_corsPolicyName = policyName;
}

Expand Down Expand Up @@ -85,11 +76,11 @@ public CorsMiddleware(
}

/// <inheritdoc />
public async Task Invoke(HttpContext context)
public async Task Invoke(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
{
if (context.Request.Headers.ContainsKey(CorsConstants.Origin))
{
var corsPolicy = _policy ?? await _corsPolicyProvider?.GetPolicyAsync(context, _corsPolicyName);
var corsPolicy = _policy ?? await corsPolicyProvider.GetPolicyAsync(context, _corsPolicyName);
if (corsPolicy != null)
{
var corsResult = _corsService.EvaluatePolicy(context, corsPolicy);
Expand Down
6 changes: 2 additions & 4 deletions test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,13 @@ public async Task Uses_PolicyProvider_AsFallback()
var middleware = new CorsMiddleware(
Mock.Of<RequestDelegate>(),
corsService,
mockProvider.Object,
policyName: null);

var httpContext = new DefaultHttpContext();
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

// Act
await middleware.Invoke(httpContext);
await middleware.Invoke(httpContext, mockProvider.Object);

// Assert
mockProvider.Verify(
Expand All @@ -281,14 +280,13 @@ public async Task DoesNotSetHeaders_ForNoPolicy()
var middleware = new CorsMiddleware(
Mock.Of<RequestDelegate>(),
corsService,
mockProvider.Object,
policyName: null);

var httpContext = new DefaultHttpContext();
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

// Act
await middleware.Invoke(httpContext);
await middleware.Invoke(httpContext, mockProvider.Object);

// Assert
Assert.Equal(200, httpContext.Response.StatusCode);
Expand Down