Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate server certificate for service fabric #4655

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/client/Microsoft.Identity.Client/Http/HttpManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ internal class HttpManager : IHttpManager
{
protected readonly IMsalHttpClientFactory _httpClientFactory;
public long LastRequestDurationInMs { get; private set; }
public HttpClientHandler HttpClientHandler { get; set; }

public HttpManager(IMsalHttpClientFactory httpClientFactory)
{
Expand All @@ -37,7 +38,17 @@ public HttpManager(IMsalHttpClientFactory httpClientFactory)

protected virtual HttpClient GetHttpClient()
{
return _httpClientFactory.GetHttpClient();
if (HttpClientHandler == null)
{
return _httpClientFactory.GetHttpClient();
}

return new HttpClient(HttpClientHandler);
}

protected virtual HttpClient GetHttpClient(HttpClientHandler httpClientHandler)
{
return new HttpClient(httpClientHandler) ?? throw new ArgumentNullException(nameof(httpClientHandler));
}

public async Task<HttpResponse> SendPostAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Microsoft.Identity.Client.Http
internal interface IHttpManager
{
long LastRequestDurationInMs { get; }
HttpClientHandler HttpClientHandler { get; set; }

Task<HttpResponse> SendPostAsync(
Uri endpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Globalization;
using System.Net.Http;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Internal;

namespace Microsoft.Identity.Client.ManagedIdentity
Expand Down Expand Up @@ -48,6 +49,24 @@ public static AbstractManagedIdentity TryCreate(RequestContext requestContext)
}

requestContext.Logger.Verbose(() => "[Managed Identity] Creating Service Fabric managed identity. Endpoint URI: " + identityEndpoint);

if (Environment.GetEnvironmentVariable("ValidateServiceFabricCertificate") == "true")
{
requestContext.Logger.Verbose(() => "[Managed Identity] Updating the http client to validate the server certificate.");

HttpClientHandler handler = new HttpClientHandler();
handler.ServerCertificateCustomValidationCallback = (message, certificate, chain, sslPolicyErrors) =>
{
if (sslPolicyErrors != System.Net.Security.SslPolicyErrors.None)
{
return 0 == string.Compare(certificate.Thumbprint, identityServerThumbprint, StringComparison.OrdinalIgnoreCase);
}

return true;
};

requestContext.ServiceBundle.HttpManager.HttpClientHandler = handler;
}
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, identityHeader);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void ClearQueue()

public long LastRequestDurationInMs => 3000;


public HttpClientHandler Handler { get => _httpManager.Handler; set => _httpManager.Handler = value; }

private string GetExpectedUrlFromHandler(HttpMessageHandler handler)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public MsiProxyHttpManager(string testWebServiceEndpoint)
}

public long LastRequestDurationInMs { get; private set; }
HttpClientHandler IHttpManager.Handler { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public Task<HttpResponse> SendPostAsync(
Uri endpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ public class ManagedIdentityTests : TestBase
internal const string ExpectedCorrelationId = "Some GUID";

[DataTestMethod]
[DataRow("http://127.0.0.1:41564/msi/token/", Resource, ManagedIdentitySource.AppService)]
[DataRow(AppServiceEndpoint, Resource, ManagedIdentitySource.AppService)]
[DataRow(AppServiceEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.AppService)]
[DataRow(ImdsEndpoint, Resource, ManagedIdentitySource.Imds)]
[DataRow(null, Resource, ManagedIdentitySource.Imds)]
[DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)]
[DataRow(AzureArcEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.AzureArc)]
[DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)]
[DataRow(CloudShellEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.CloudShell)]
//[DataRow("http://127.0.0.1:41564/msi/token/", Resource, ManagedIdentitySource.AppService)]
//[DataRow(AppServiceEndpoint, Resource, ManagedIdentitySource.AppService)]
//[DataRow(AppServiceEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.AppService)]
//[DataRow(ImdsEndpoint, Resource, ManagedIdentitySource.Imds)]
//[DataRow(null, Resource, ManagedIdentitySource.Imds)]
//[DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)]
//[DataRow(AzureArcEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.AzureArc)]
//[DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)]
//[DataRow(CloudShellEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.CloudShell)]
[DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)]
[DataRow(ServiceFabricEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.ServiceFabric)]
public async Task ManagedIdentityHappyPathAsync(
Expand Down
6 changes: 4 additions & 2 deletions tests/Microsoft.Identity.Test.Unit/ParallelRequestsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public async Task AcquireTokenSilent_ExpiredATs_ParallelRequests_Async()
const int NumberOfRequests = 10;

// The typical HttpMockHandler used by other tests can't deal with parallel request
ParallelRequestMockHanler httpManager = new ParallelRequestMockHanler();
ParallelRequestMockHandler httpManager = new ParallelRequestMockHandler();

PublicClientApplication pca = PublicClientApplicationBuilder
.Create(TestConstants.ClientId)
Expand Down Expand Up @@ -231,10 +231,12 @@ private void ConfigureCacheSerialization(IPublicClientApplication pca)
/// - provides a standard response for discovery calls
/// - responds with valid tokens based on a naming convention (uid = "uid" + rtSecret, upn = "user_" + rtSecret)
/// </summary>
internal class ParallelRequestMockHanler : IHttpManager
internal class ParallelRequestMockHandler : IHttpManager
{
public long LastRequestDurationInMs => 50;

HttpClientHandler IHttpManager.Handler { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public async Task<HttpResponse> SendGetAsync(Uri endpoint, IDictionary<string, string> headers, ILoggerAdapter logger, bool retry = true, CancellationToken cancellationToken = default)
{
// simulate delay and also add complexity due to thread context switch
Expand Down