Skip to content

Commit

Permalink
Managed Identity - Calculate the refresh_in value based on expires_on (
Browse files Browse the repository at this point in the history
…#4068)

* initial commit

* remove

* simple case

* pr comments

* add AT refresh logic

* uni test

* pr comments

* Update src/client/Microsoft.Identity.Client/OAuth2/MsalTokenResponse.cs

Co-authored-by: Peter M <34331512+pmaytak@users.noreply.github.com>

* ManagedIdentitySourceType fix

* Update ManagedIdentityTests.cs

* pr comments

---------

Co-authored-by: Gladwin Johnson <gljohns@microsoft.com>
Co-authored-by: Peter M <34331512+pmaytak@users.noreply.github.com>
  • Loading branch information
3 people committed May 1, 2023
1 parent 5effd6c commit 9e4bf56
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,21 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
MsalErrorMessage.ScopesRequired);
}

AuthenticationResult authResult = null;
MsalAccessTokenCacheItem cachedAccessTokenItem = null;
var logger = AuthenticationRequestParameters.RequestContext.Logger;
CacheRefreshReason cacheInfoTelemetry = CacheRefreshReason.NotApplicable;

if (!_managedIdentityParameters.ForceRefresh)
{
MsalAccessTokenCacheItem cachedAccessTokenItem = await CacheManager.FindAccessTokenAsync().ConfigureAwait(false);
cachedAccessTokenItem = await CacheManager.FindAccessTokenAsync().ConfigureAwait(false);

if (cachedAccessTokenItem != null)
{
AuthenticationRequestParameters.RequestContext.ApiEvent.IsAccessTokenCacheHit = true;

Metrics.IncrementTotalAccessTokensFromCache();
return new AuthenticationResult(
authResult = new AuthenticationResult(
cachedAccessTokenItem,
null,
AuthenticationRequestParameters.AuthenticationScheme,
Expand Down Expand Up @@ -79,9 +81,35 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = cacheInfoTelemetry;
}

// No AT in the cache
return await FetchNewAccessTokenAsync(cancellationToken).ConfigureAwait(false);

// No AT in the cache or AT needs to be refreshed
try
{
if (cachedAccessTokenItem == null)
{
authResult = await FetchNewAccessTokenAsync(cancellationToken).ConfigureAwait(false);
}
else
{
var shouldRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem);

// may fire a request to get a new token in the background
if (shouldRefresh)
{
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ProactivelyRefreshed;

SilentRequestHelper.ProcessFetchInBackground(
cachedAccessTokenItem,
() => FetchNewAccessTokenAsync(cancellationToken), logger);
}
}

return authResult;
}
catch (MsalServiceException e)
{
return await HandleTokenRefreshErrorAsync(e, cachedAccessTokenItem).ConfigureAwait(false);
}

}

private async Task<AuthenticationResult> FetchNewAccessTokenAsync(CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,45 @@ namespace Microsoft.Identity.Client.ManagedIdentity
[Preserve(AllMembers = true)]
internal class ManagedIdentityResponse
{
/// <summary>
/// The requested access token.
/// </summary>
/// <remarks>When you call a secured REST API, the token is embedded in the Authorization
/// request header field as a "bearer" token,
/// allowing the API to authenticate the caller.</remarks>
[JsonProperty("access_token")]
public string AccessToken { get; set; }

/// <summary>
/// The timespan when the access token expires.
/// </summary>
/// <remarks>The date is represented as the number of seconds from "1970-01-01T0:0:0Z UTC"
/// (corresponds to the token's exp claim).</remarks>
[JsonProperty("expires_on")]
public string ExpiresOn { get; set; }

/// <summary>
/// The resource the access token was requested for.
/// </summary>
/// <remarks>Which matches the resource query string parameter of the request.</remarks>
[JsonProperty("resource")]
public string Resource { get; set; }

/// <summary>
/// The type of token returned by the Managed Identity endpoint.
/// </summary>
/// <remarks>which is a "Bearer" access token, which means the resource
/// can give access to the bearer of this token.</remarks>
[JsonProperty("token_type")]
public string TokenType { get; set; }

/// <summary>
/// A unique identifier generated by Azure AD for the Azure Resource.
/// </summary>
/// <remarks>The Client ID is a GUID value that uniquely identifies the application
/// and its configuration within the identity platform</remarks>
[JsonProperty("client_id")]
public string ClientId { get; set; }

}
}
19 changes: 17 additions & 2 deletions src/client/Microsoft.Identity.Client/OAuth2/MsalTokenResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,30 @@ internal static MsalTokenResponse CreateFromManagedIdentityResponse(ManagedIdent
{
ValidateManagedIdentityResult(managedIdentityResponse);

long expiresIn = DateTimeHelpers.GetDurationFromNowInSeconds(managedIdentityResponse.ExpiresOn);

return new MsalTokenResponse
{
AccessToken = managedIdentityResponse.AccessToken,
ExpiresIn = DateTimeHelpers.GetDurationFromNowInSeconds(managedIdentityResponse.ExpiresOn),
ExpiresIn = expiresIn,
TokenType = managedIdentityResponse.TokenType,
TokenSource = TokenSource.IdentityProvider
TokenSource = TokenSource.IdentityProvider,
RefreshIn = InferManagedIdentityRefreshInValue(expiresIn)
};
}

// Compute refresh_in as 1/2 expires_in, but only if expires_in > 2h.
private static long? InferManagedIdentityRefreshInValue(long expiresIn)

{
if (expiresIn > 2 * 3600)
{
return expiresIn / 2;
}

return null;
}

private static void ValidateManagedIdentityResult(ManagedIdentityResponse response)
{
if (string.IsNullOrEmpty(response.AccessToken))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,8 @@ public void AssertAccessCounts(int expectedReads, int expectedWrites)

public void WaitTo_AssertAcessCounts(int expectedReads, int expectedWrites, int maxTimeInMilliSec = 30000)
{
YieldTillSatisfied(() => BeforeWriteCount == expectedWrites && AfterAccessWriteCount == expectedWrites && AfterAccessTotalCount == (expectedReads + expectedWrites) && BeforeAccessCount == (expectedReads + expectedWrites), maxTimeInMilliSec);
TestCommon.YieldTillSatisfied(() => BeforeWriteCount == expectedWrites && AfterAccessWriteCount == expectedWrites && AfterAccessTotalCount == (expectedReads + expectedWrites) && BeforeAccessCount == (expectedReads + expectedWrites), maxTimeInMilliSec);
AssertAccessCounts(expectedReads, expectedWrites);
}

private bool YieldTillSatisfied(Func<bool> func, int maxTimeInMilliSec = 30000)
{
int iCount = maxTimeInMilliSec / 100;
while (iCount > 0)
{
if (func())
{
return true;
}
Thread.Yield();
Thread.Sleep(100);
iCount--;
}

return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ public static string GetBridgedHybridSpaTokenResponse(string spaAccountId)
",\"id_token_expires_in\":\"3600\"}";
}

public static string GetMsiSuccessfulResponse()
public static string GetMsiSuccessfulResponse(int expiresInHours = 1)
{
string expiresOn = DateTimeHelpers.DateTimeToUnixTimestamp(DateTime.UtcNow.AddHours(1));
string expiresOn = DateTimeHelpers.DateTimeToUnixTimestamp(DateTime.UtcNow.AddHours(expiresInHours));
return
"{\"access_token\":\"" + TestConstants.ATSecret + "\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"https://management.azure.com/\",\"token_type\":" +
"\"Bearer\",\"client_id\":\"client_id\"}";
Expand Down Expand Up @@ -519,6 +519,5 @@ public static HttpResponseMessage CreateAdfsOpenIdConfigurationResponse(string a
}
};
}

}
}
45 changes: 44 additions & 1 deletion tests/Microsoft.Identity.Test.Common/TestCommon.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Cache;
using Microsoft.Identity.Client.Cache.Items;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Instance;
using Microsoft.Identity.Client.Instance.Discovery;
Expand All @@ -24,7 +27,7 @@
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using Microsoft.Identity.Test.Unit;
using Microsoft.VisualStudio.TestTools.UnitTesting;

using Microsoft.Identity.Test.Common.Core.Mocks;
using NSubstitute;
using static Microsoft.Identity.Client.TelemetryCore.Internal.Events.ApiEvent;

Expand Down Expand Up @@ -270,5 +273,45 @@ public static async Task ValidatePopNonceAsync(string nonce)

// Assert.AreEqual(response.StatusCode, System.Net.HttpStatusCode.OK);
//}

public static bool YieldTillSatisfied(Func<bool> func, int maxTimeInMilliSec = 30000)
{
int iCount = maxTimeInMilliSec / 100;
while (iCount > 0)
{
if (func())
{
return true;
}
Thread.Yield();
Thread.Sleep(100);
iCount--;
}

return false;
}

public static MsalAccessTokenCacheItem UpdateATWithRefreshOn(
ITokenCacheAccessor accessor,
DateTimeOffset? refreshOn = null,
bool expired = false)
{
MsalAccessTokenCacheItem atItem = accessor.GetAllAccessTokens().Single();

refreshOn = refreshOn ?? DateTimeOffset.UtcNow - TimeSpan.FromMinutes(30);

atItem = atItem.WithRefreshOn(refreshOn);

Assert.IsTrue(atItem.ExpiresOn > DateTime.UtcNow + TimeSpan.FromMinutes(10));

if (expired)
{
atItem = atItem.WithExpiresOn(DateTime.UtcNow - TimeSpan.FromMinutes(1));
}

accessor.SaveAccessToken(atItem);

return atItem;
}
}
}

0 comments on commit 9e4bf56

Please sign in to comment.