Skip to content

Commit

Permalink
Add retry for specific codes in managed identity (#4078)
Browse files Browse the repository at this point in the history
* Add retry for specific codes in managed identity

* Refactor http manager and tests

* Mege conflict

* Send back IsRetryable to true

* Address comments

* Address comments

* Update src/client/Microsoft.Identity.Client/MsalException.cs

Co-authored-by: Peter M <34331512+pmaytak@users.noreply.github.com>

* Update src/client/Microsoft.Identity.Client/Http/HttpManagerManagedIdentity.cs

Co-authored-by: Peter M <34331512+pmaytak@users.noreply.github.com>

* Fix test

* Update tests

---------

Co-authored-by: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com>
Co-authored-by: Peter M <34331512+pmaytak@users.noreply.github.com>
  • Loading branch information
3 people committed May 3, 2023
1 parent c3b7ba0 commit 90d6daf
Show file tree
Hide file tree
Showing 30 changed files with 550 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.ComponentModel;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Cache;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Extensibility;
Expand All @@ -23,9 +24,18 @@ namespace Microsoft.Identity.Client
{
internal sealed class ApplicationConfiguration : IAppConfig
{
public ApplicationConfiguration(bool isConfidentialClient)
public ApplicationConfiguration(MsalClientType applicationType)
{
IsConfidentialClient = isConfidentialClient;
switch (applicationType)
{
case MsalClientType.ConfidentialClient:
IsConfidentialClient = true;
break;

case MsalClientType.ManagedIdentityClient:
IsManagedIdentity = true;
break;
}
}

public const string DefaultClientName = "UnknownClient";
Expand Down Expand Up @@ -111,6 +121,7 @@ public string ClientVersion
public bool IsUserAssignedManagedIdentity { get; internal set; } = false;
public string ManagedIdentityUserAssignedClientId { get; internal set; }
public string ManagedIdentityUserAssignedResourceId { get; internal set; }
public bool IsManagedIdentity { get; }

public Func<AppTokenProviderParameters, Task<AppTokenProviderResult>> AppTokenProvider;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Instance;
using Microsoft.Identity.Client.Internal;
Expand Down Expand Up @@ -46,7 +47,7 @@ internal ConfidentialClientApplicationBuilder(ApplicationConfiguration configura
{
ClientApplicationBase.GuardMobileFrameworks();

var config = new ApplicationConfiguration(isConfidentialClient: true);
var config = new ApplicationConfiguration(MsalClientType.ConfidentialClient);
var builder = new ConfidentialClientApplicationBuilder(config).WithOptions(options);

if (!string.IsNullOrWhiteSpace(options.ClientSecret))
Expand Down Expand Up @@ -79,7 +80,7 @@ public static ConfidentialClientApplicationBuilder Create(string clientId)
{
ClientApplicationBase.GuardMobileFrameworks();

var config = new ApplicationConfiguration(isConfidentialClient: true);
var config = new ApplicationConfiguration(MsalClientType.ConfidentialClient);
return new ConfidentialClientApplicationBuilder(config)
.WithClientId(clientId)
.WithCacheSynchronization(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.TelemetryCore;
Expand Down Expand Up @@ -46,7 +47,7 @@ internal ManagedIdentityApplicationBuilder(ApplicationConfiguration configuratio
{
ApplicationBase.GuardMobileFrameworks();

var config = new ApplicationConfiguration(isConfidentialClient: false);
var config = new ApplicationConfiguration(MsalClientType.ManagedIdentityClient);
var builder = new ManagedIdentityApplicationBuilder(config).WithOptions(options);

if (!string.IsNullOrWhiteSpace(options.UserAssignedClientId))
Expand All @@ -72,7 +73,7 @@ public static ManagedIdentityApplicationBuilder Create()
{
ApplicationBase.GuardMobileFrameworks();

var config = new ApplicationConfiguration(isConfidentialClient: false);
var config = new ApplicationConfiguration(MsalClientType.ManagedIdentityClient);
return new ManagedIdentityApplicationBuilder(config)
.WithCacheSynchronization(false);
}
Expand All @@ -96,7 +97,7 @@ public static ManagedIdentityApplicationBuilder Create(string userAssignedId)
throw new ArgumentNullException(nameof(userAssignedId));
}

var config = new ApplicationConfiguration(isConfidentialClient: false);
var config = new ApplicationConfiguration(MsalClientType.ManagedIdentityClient);
return new ManagedIdentityApplicationBuilder(config)
.WithUserAssignedManagedIdentity(userAssignedId)
.WithCacheSynchronization(false);
Expand Down Expand Up @@ -200,7 +201,7 @@ public IManagedIdentityApplication Build()
/// <returns></returns>
internal ManagedIdentityApplication BuildConcrete()
{
ValidateUseOfExperimentalFeature("ManagedIdentity");
ValidateUseOfExperimentalFeature("ManagedIdentityClient");
DefaultConfiguration();
return new ManagedIdentityApplication(BuildConfiguration());
}
Expand Down
21 changes: 21 additions & 0 deletions src/client/Microsoft.Identity.Client/AppConfig/MsalClientType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Microsoft.Identity.Client.AppConfig
{
/// <summary>
/// Enum to represent the type of MSAL application.
/// </summary>
internal enum MsalClientType
{
ConfidentialClient,
PublicClient,
ManagedIdentityClient
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.Identity.Client.PlatformsCommon.Factories;
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using System.Runtime.CompilerServices;
using Microsoft.Identity.Client.AppConfig;

#if iOS
using UIKit;
Expand Down Expand Up @@ -46,7 +47,7 @@ internal PublicClientApplicationBuilder(ApplicationConfiguration configuration)
/// parameters, and to create a public client application instance</returns>
public static PublicClientApplicationBuilder CreateWithApplicationOptions(PublicClientApplicationOptions options)
{
var config = new ApplicationConfiguration(isConfidentialClient: false);
var config = new ApplicationConfiguration(MsalClientType.PublicClient);
return new PublicClientApplicationBuilder(config)
.WithOptions(options)
.WithKerberosTicketClaim(options.KerberosServicePrincipalName, options.TicketContainer);
Expand All @@ -62,7 +63,7 @@ public static PublicClientApplicationBuilder CreateWithApplicationOptions(Public
/// parameters, and to create a public client application instance</returns>
public static PublicClientApplicationBuilder Create(string clientId)
{
var config = new ApplicationConfiguration(isConfidentialClient: false);
var config = new ApplicationConfiguration(MsalClientType.PublicClient);
return new PublicClientApplicationBuilder(config).WithClientId(clientId);
}

Expand Down
77 changes: 19 additions & 58 deletions src/client/Microsoft.Identity.Client/Http/HttpManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,13 @@ namespace Microsoft.Identity.Client.Http
/// </remarks>
internal class HttpManager : IHttpManager
{
private readonly IMsalHttpClientFactory _httpClientFactory;
//Determines whether or not to retry on 5xx errors. Configurable on application creation. default is true;
private readonly bool _retryConfig;
protected readonly IMsalHttpClientFactory _httpClientFactory;
public long LastRequestDurationInMs { get; private set; }

public HttpManager(IMsalHttpClientFactory httpClientFactory, bool retry = true)
public HttpManager(IMsalHttpClientFactory httpClientFactory)
{
_httpClientFactory = httpClientFactory ??
throw new ArgumentNullException(nameof(httpClientFactory));

_retryConfig = retry;
}

protected virtual HttpClient GetHttpClient()
Expand All @@ -53,70 +49,70 @@ protected virtual HttpClient GetHttpClient()
return await SendPostAsync(endpoint, headers, body, logger, cancellationToken).ConfigureAwait(false);
}

public async Task<HttpResponse> SendPostAsync(
public virtual async Task<HttpResponse> SendPostAsync(
Uri endpoint,
IDictionary<string, string> headers,
HttpContent body,
ILoggerAdapter logger,
CancellationToken cancellationToken = default)
{
return await ExecuteWithRetryAsync(endpoint, headers, body, HttpMethod.Post, logger, cancellationToken: cancellationToken).ConfigureAwait(false);
return await SendRequestAsync(endpoint, headers, body, HttpMethod.Post, logger, cancellationToken: cancellationToken).ConfigureAwait(false);
}

public async Task<HttpResponse> SendGetAsync(
public virtual async Task<HttpResponse> SendGetAsync(
Uri endpoint,
IDictionary<string, string> headers,
ILoggerAdapter logger,
bool retry = true,
CancellationToken cancellationToken = default)
{
return await ExecuteWithRetryAsync(endpoint, headers, null, HttpMethod.Get, logger, retry: retry, cancellationToken: cancellationToken).ConfigureAwait(false);
return await SendRequestAsync(endpoint, headers, null, HttpMethod.Get, logger, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Performs the GET request just like <see cref="SendGetAsync(Uri, IDictionary{string, string}, ILoggerAdapter, bool, CancellationToken)"/>
/// but does not throw a ServiceUnavailable service exception. Instead, it returns the <see cref="HttpResponse"/> associated
/// with the request.
/// </summary>
public async Task<HttpResponse> SendGetForceResponseAsync(
public virtual async Task<HttpResponse> SendGetForceResponseAsync(
Uri endpoint,
IDictionary<string, string> headers,
ILoggerAdapter logger,
bool retry = true,
CancellationToken cancellationToken = default)
{
return await ExecuteWithRetryAsync(endpoint, headers, null, HttpMethod.Get, logger, retry: retry, doNotThrow: true, cancellationToken: cancellationToken).ConfigureAwait(false);
return await SendRequestAsync(endpoint, headers, null, HttpMethod.Get, logger, doNotThrow: true, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Performs the POST request just like <see cref="SendPostAsync(Uri, IDictionary{string, string}, IDictionary{String, String}, ILoggerAdapter, CancellationToken)"/>
/// but does not throw a ServiceUnavailable service exception. Instead, it returns the <see cref="HttpResponse"/> associated
/// with the request.
/// </summary>
public async Task<HttpResponse> SendPostForceResponseAsync(
public virtual async Task<HttpResponse> SendPostForceResponseAsync(
Uri uri,
IDictionary<string, string> headers,
IDictionary<string, string> bodyParameters,
ILoggerAdapter logger,
CancellationToken cancellationToken = default)
{
HttpContent body = bodyParameters == null ? null : new FormUrlEncodedContent(bodyParameters);
return await ExecuteWithRetryAsync(uri, headers, body, HttpMethod.Post, logger, doNotThrow: true, cancellationToken: cancellationToken).ConfigureAwait(false);
return await SendRequestAsync(uri, headers, body, HttpMethod.Post, logger, doNotThrow: true, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Performs the POST request just like <see cref="SendPostAsync(Uri, IDictionary{string, string}, HttpContent, ILoggerAdapter, CancellationToken)"/>
/// but does not throw a ServiceUnavailable service exception. Instead, it returns the <see cref="HttpResponse"/> associated
/// with the request.
/// </summary>
public async Task<HttpResponse> SendPostForceResponseAsync(
public virtual async Task<HttpResponse> SendPostForceResponseAsync(
Uri uri,
IDictionary<string, string> headers,
StringContent body,
ILoggerAdapter logger,
CancellationToken cancellationToken = default)
{
return await ExecuteWithRetryAsync(uri, headers, body, HttpMethod.Post, logger, doNotThrow: true, cancellationToken: cancellationToken).ConfigureAwait(false);
return await SendRequestAsync(uri, headers, body, HttpMethod.Post, logger, doNotThrow: true, cancellationToken: cancellationToken).ConfigureAwait(false);
}

private HttpRequestMessage CreateRequestMessage(Uri endpoint, IDictionary<string, string> headers)
Expand All @@ -134,20 +130,17 @@ private HttpRequestMessage CreateRequestMessage(Uri endpoint, IDictionary<string
return requestMessage;
}

private async Task<HttpResponse> ExecuteWithRetryAsync(
protected virtual async Task<HttpResponse> SendRequestAsync(
Uri endpoint,
IDictionary<string, string> headers,
HttpContent body,
HttpMethod method,
ILoggerAdapter logger,
bool doNotThrow = false,
bool retry = true,
bool retry = false,
CancellationToken cancellationToken = default)
{
Exception timeoutException = null;
bool isRetryableStatusCode = false;
HttpResponse response = null;
bool isRetryable;

try
{
Expand All @@ -172,9 +165,6 @@ private HttpRequestMessage CreateRequestMessage(Uri endpoint, IDictionary<string
logger.Info(() => string.Format(CultureInfo.InvariantCulture,
MsalErrorMessage.HttpRequestUnsuccessful,
(int)response.StatusCode, response.StatusCode));

isRetryableStatusCode = IsRetryableStatusCode((int)response.StatusCode);
isRetryable = isRetryableStatusCode && _retryConfig && !HasRetryAfterHeader(response);
}
catch (TaskCanceledException exception)
{
Expand All @@ -185,32 +175,10 @@ private HttpRequestMessage CreateRequestMessage(Uri endpoint, IDictionary<string
}

logger.Error("The HTTP request failed. " + exception.Message);
isRetryable = true;
timeoutException = exception;
}

if (isRetryable && retry)
{
logger.Info("Retrying one more time..");
await Task.Delay(TimeSpan.FromSeconds(1)).ConfigureAwait(false);
return await ExecuteWithRetryAsync(
endpoint,
headers,
body,
method,
logger,
doNotThrow,
retry: false,
cancellationToken: cancellationToken).ConfigureAwait(false);
}

logger.Warning("Request retry failed.");
if (timeoutException != null)
{
throw new MsalServiceException(
MsalError.RequestTimeout,
"Request to the endpoint timed out.",
timeoutException);
exception);
}

if (doNotThrow)
Expand All @@ -219,7 +187,7 @@ private HttpRequestMessage CreateRequestMessage(Uri endpoint, IDictionary<string
}

// package 500 errors in a "service not available" exception
if (isRetryableStatusCode)
if (IsRetryableStatusCode((int)response.StatusCode))
{
throw MsalServiceExceptionFactory.FromHttpResponse(
MsalError.ServiceNotAvailable,
Expand All @@ -230,14 +198,7 @@ private HttpRequestMessage CreateRequestMessage(Uri endpoint, IDictionary<string
return response;
}

private static bool HasRetryAfterHeader(HttpResponse response)
{
var retryAfter = response?.Headers?.RetryAfter;
return retryAfter != null &&
(retryAfter.Delta.HasValue || retryAfter.Date.HasValue);
}

private async Task<HttpResponse> ExecuteAsync(
protected async Task<HttpResponse> ExecuteAsync(
Uri endpoint,
IDictionary<string, string> headers,
HttpContent body,
Expand Down Expand Up @@ -284,7 +245,7 @@ await client.SendAsync(requestMessage, cancellationToken).ConfigureAwait(false))
};
}

private async Task<HttpContent> CloneHttpContentAsync(HttpContent httpContent)
protected async Task<HttpContent> CloneHttpContentAsync(HttpContent httpContent)
{
var temp = new MemoryStream();
await httpContent.CopyToAsync(temp).ConfigureAwait(false);
Expand Down Expand Up @@ -316,7 +277,7 @@ private async Task<HttpContent> CloneHttpContentAsync(HttpContent httpContent)
/// In HttpManager, the retry policy is based on this simple condition.
/// Avoid changing this, as it's breaking change.
/// </summary>
private static bool IsRetryableStatusCode(int statusCode)
protected virtual bool IsRetryableStatusCode(int statusCode)
{
return statusCode >= 500 && statusCode < 600;
}
Expand Down
29 changes: 29 additions & 0 deletions src/client/Microsoft.Identity.Client/Http/HttpManagerFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Microsoft.Identity.Client.Http
{
/// <summary>
/// Factory to return the instance of HttpManager based on retry configuration and type of MSAL application.
/// </summary>
internal sealed class HttpManagerFactory
{
public static IHttpManager GetHttpManager(IMsalHttpClientFactory httpClientFactory, bool withRetry, bool isManagedIdentity)
{
if (!withRetry)
{
return new HttpManager(httpClientFactory);
}

return isManagedIdentity ?
new HttpManagerManagedIdentity(httpClientFactory) :
new HttpManagerWithRetry(httpClientFactory);
}
}
}

0 comments on commit 90d6daf

Please sign in to comment.