diff --git a/generator/.DevConfigs/e11928c9-15fb-4b9e-91d3-91775150d378.json b/generator/.DevConfigs/e11928c9-15fb-4b9e-91d3-91775150d378.json new file mode 100644 index 000000000000..1f4735227ec2 --- /dev/null +++ b/generator/.DevConfigs/e11928c9-15fb-4b9e-91d3-91775150d378.json @@ -0,0 +1,13 @@ +{ + "core": { + "updateMinimum": true, + "type": "minor", + "changeLogMessages": [ + "Re-introduce background refresh of credentials during their preempt expiry period (https://github.com/aws/aws-sdk-net/issues/4024)" + ], + "backwardIncompatibilitiesToIgnore": [ + "Amazon.Runtime.RefreshingAWSCredentials/MethodAdded", + "Amazon.Runtime.RefreshingAWSCredentials/FieldTypeChanged" + ] + } +} \ No newline at end of file diff --git a/sdk/src/Core/Amazon.Runtime/Credentials/InstanceProfileAWSCredentials.cs b/sdk/src/Core/Amazon.Runtime/Credentials/InstanceProfileAWSCredentials.cs index fbdb0b85f813..1ff9ea211100 100644 --- a/sdk/src/Core/Amazon.Runtime/Credentials/InstanceProfileAWSCredentials.cs +++ b/sdk/src/Core/Amazon.Runtime/Credentials/InstanceProfileAWSCredentials.cs @@ -91,10 +91,7 @@ protected override CredentialsRefreshState GenerateNewCredentials() // but try again to refresh them in 2 minutes. if (null != _currentRefreshState) { -#pragma warning disable CS0612, CS0618 // Type or member is obsolete - var newExpiryTime = AWSSDKUtils.CorrectedUtcNow + TimeSpan.FromMinutes(2); -#pragma warning restore CS0612,CS0618 // Type or member is obsolete - + var newExpiryTime = _timeProvider.CorrectedUtcNow + TimeSpan.FromMinutes(2); _currentRefreshState = new CredentialsRefreshState(_currentRefreshState.Credentials, newExpiryTime); return _currentRefreshState; } @@ -107,10 +104,7 @@ protected override CredentialsRefreshState GenerateNewCredentials() // use a custom refresh time -#pragma warning disable CS0612, CS0618 // Type or member is obsolete - var newExpiryTime = AWSSDKUtils.CorrectedUtcNow + TimeSpan.FromMinutes(new Random().Next(5, 11)); -#pragma warning restore CS0612, CS0618 // Type or member is obsolete - + var newExpiryTime = _timeProvider.CorrectedUtcNow + TimeSpan.FromMinutes(new Random().Next(5, 11)); _currentRefreshState = new CredentialsRefreshState(newState.Credentials, newExpiryTime); return _currentRefreshState; @@ -175,7 +169,7 @@ protected override async Task GenerateNewCredentialsAsy // but try again to refresh them in 2 minutes. if (null != _currentRefreshState) { - var newExpiryTime = AWSSDKUtils.CorrectedUtcNow + TimeSpan.FromMinutes(2); + var newExpiryTime = _timeProvider.CorrectedUtcNow + TimeSpan.FromMinutes(2); _currentRefreshState = new CredentialsRefreshState(_currentRefreshState.Credentials, newExpiryTime); return _currentRefreshState; } @@ -187,7 +181,7 @@ protected override async Task GenerateNewCredentialsAsy _logger.InfoFormat(_receivedExpiredCredentialsFromIMDS); // use a custom refresh time - var newExpiryTime = AWSSDKUtils.CorrectedUtcNow + TimeSpan.FromMinutes(new Random().Next(5, 11)); + var newExpiryTime = _timeProvider.CorrectedUtcNow + TimeSpan.FromMinutes(new Random().Next(5, 11)); _currentRefreshState = new CredentialsRefreshState(newState.Credentials, newExpiryTime); return _currentRefreshState; @@ -396,7 +390,7 @@ private static Uri InfoUri private CredentialsRefreshState GetEarlyRefreshState(CredentialsRefreshState state) { - DateTime newExpiryTime = AWSSDKUtils.CorrectedUtcNow + _refreshAttemptPeriod + PreemptExpiryTime; + DateTime newExpiryTime = _timeProvider.CorrectedUtcNow + _refreshAttemptPeriod + PreemptExpiryTime; // Use this only if the time is earlier than the default expiration time if (newExpiryTime > state.Expiration) diff --git a/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs b/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs index afe79d0b5af6..56e771e9b91f 100644 --- a/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs +++ b/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs @@ -14,10 +14,11 @@ */ using Amazon.Runtime.Internal.Util; -using Amazon.Util; +using Amazon.Util.Internal; using System; using System.Globalization; using System.Threading; +using System.Threading.Tasks; namespace Amazon.Runtime { @@ -27,6 +28,12 @@ namespace Amazon.Runtime public abstract class RefreshingAWSCredentials : AWSCredentials, IDisposable { private readonly Logger _logger = Logger.GetLogger(typeof(RefreshingAWSCredentials)); + protected readonly ITimeProvider _timeProvider; + + protected RefreshingAWSCredentials() : this(DefaultTimeProvider.Instance) { } + + protected RefreshingAWSCredentials(ITimeProvider timeProvider) + => _timeProvider = timeProvider ?? DefaultTimeProvider.Instance; /// public override DateTime? Expiration @@ -49,22 +56,26 @@ public override DateTime? Expiration /// public class CredentialsRefreshState { + private readonly ITimeProvider _timeProvider; + public ImmutableCredentials Credentials { get; set; } public DateTime Expiration { get; set; } - public CredentialsRefreshState() - { - } + public CredentialsRefreshState() : this(null, default) { } - public CredentialsRefreshState(ImmutableCredentials credentials, DateTime expiration) + public CredentialsRefreshState(ImmutableCredentials credentials, DateTime expiration) + : this (credentials, expiration, DefaultTimeProvider.Instance) { } + + public CredentialsRefreshState(ImmutableCredentials credentials, DateTime expiration, ITimeProvider timeProvider) { Credentials = credentials; Expiration = expiration; + _timeProvider = timeProvider ?? DefaultTimeProvider.Instance; } internal bool IsExpiredWithin(TimeSpan preemptExpiryTime) { - var now = AWSSDKUtils.CorrectedUtcNow; + var now = _timeProvider.CorrectedUtcNow; var exp = Expiration.ToUniversalTime(); return now > exp - preemptExpiryTime; } @@ -74,19 +85,38 @@ internal bool IsExpiredWithin(TimeSpan preemptExpiryTime) /// Represents the current state of the Credentials. /// /// This can be cleared without synchronization. - protected CredentialsRefreshState currentState; + protected volatile CredentialsRefreshState currentState; #region Private members private TimeSpan _preemptExpiryTime = TimeSpan.FromMinutes(0); + private TimeSpan _expirationBuffer = TimeSpan.FromMinutes(1); + private bool _disposed; + private readonly SemaphoreSlim _updateGeneratedCredentialsSemaphore = new SemaphoreSlim(1, 1); + /// - /// Semaphore to control thread access to GetCredentialsAsync method. - /// The semaphore will allow only one thread to generate new credentials and - /// update the current state. + /// Tracks the current state of background credentials refresh. /// - private readonly SemaphoreSlim _updateGeneratedCredentialsSemaphore = new SemaphoreSlim(1, 1); + private enum CredentialsLoadState + { + /// + /// No background refresh is currently in progress. + /// This is the default state, where credentials are either valid or expired. + /// + NotLoading, + + /// + /// A background refresh is currently in progress. + /// This means we're within the preempt expiry window, and credentials are still valid. + /// + Loading, + } + + // Note this is purposefuly marked as volatile since it is modified by multiple threads. Read the comments + // in the GetCredentials method for more information on the locking flow and how this is used. + private volatile CredentialsLoadState currentLoadState; #endregion @@ -95,8 +125,8 @@ internal bool IsExpiredWithin(TimeSpan preemptExpiryTime) #region Properties /// - /// The time before actual expiration to expire the credentials. - /// Property cannot be set to a negative TimeSpan. + /// If credentials are still valid but the expiration is within the Expiration minus PreemptExpiryTime a + /// background refresh of the credentials will be triggered. /// public TimeSpan PreemptExpiryTime { @@ -109,6 +139,22 @@ public TimeSpan PreemptExpiryTime } } + /// + /// The time subtracted from the expiration provided by the credentials provider and then used for determining + /// if the credentials are expired. This provides a buffer to avoid corner case issues of processing time + /// on the client side before the credentials are actually used for signing and validation on the server side. + /// + public TimeSpan ExpirationBuffer + { + get { return _expirationBuffer; } + set + { + if (value < TimeSpan.Zero) + throw new ArgumentOutOfRangeException("value", "ExpirationBuffer cannot be negative"); + _expirationBuffer = value; + } + } + #endregion #region Override methods @@ -119,51 +165,146 @@ public TimeSpan PreemptExpiryTime /// public override sealed ImmutableCredentials GetCredentials() { - _updateGeneratedCredentialsSemaphore.Wait(); - - try + // We save the currentState as it might be modified or cleared. + var tempState = currentState; + + // Before acquiring the lock, check if we need to refresh credentials. This is essentially the read only section + // and going into the if block is the costly write section. Majority of threads going through this path will stay + // in the read section avoiding blocking on the semaphore. Execution only goes into the write section when credentials + // are nearing expiration or already expired. + // + // There are two phases of credentials needing to be refreshed: + // + // Credentials are expired. In that case the lock will be acquired and credential refresh will be blocking further + // execution until new credentials are retrieved. Once credentials are retrieved and the lock is released any other + // threads that were blocked due to expired credentials will one by one acquire the lock and see that + // new credentials are present and use those refreshed credentials. + // + // Credential epiration in preempt window. This is the case credentials are still valid but to avoid a later + // blocking expiration refresh a background refresh is triggered. Only one background refresh will be triggered + // which is controlled by setting currentLoadState to Loading while the lock is held. + // Any other threads that come through in either the read or write section in the prempt window will see the + // currentLoadState is Loading and not trigger another background refresh. The current credentials will be returned + // instead of waiting for the background refresh to complete since they are still valid. + // + // The background refresh task is in charge of reseting the currentLoadState to NotLoading once the background refresh is complete + // or an exception happened. In the case of an exception during the background refresh resettting the currentLoadState to NotLoading + // is not done within the scope of the lock. This is safe because the only other place currentLoadState is modified to a different value + // is within the block of code that initiates the background refresh. The execution will never go into that block while a background + // refresh is in progress because the currentLoadState was set to Loading which prevents starting a background refresh. Since + // currentLoadState is potentially modified by multiple threads it has been marked as volatile to ensure the latest value is always read. + if (IsExpired(tempState) || (currentLoadState != CredentialsLoadState.Loading && IsPreemptExpiryWindow(tempState))) { - // We save the currentState as it might be modified or cleared. - var tempState = currentState; - - // If credentials are expired or we don't have any state yet, update - if (ShouldUpdateState(tempState)) + _updateGeneratedCredentialsSemaphore.Wait(); + try { - tempState = GenerateNewCredentials(); - UpdateToGeneratedCredentials(tempState); - currentState = tempState; + // Update the local variable for credentials after acquiring the lock in case another thread got the lock first and updated the current state. + tempState = currentState; + + // If credentials are expired block for credential refresh. + if (IsExpired(tempState)) + { + LogCredentialsExpired(tempState); + + tempState = GenerateNewCredentials(); + ValidateGeneratedCredentials(tempState); + currentState = tempState; + currentLoadState = CredentialsLoadState.NotLoading; + } + else if (currentLoadState != CredentialsLoadState.Loading && IsPreemptExpiryWindow(tempState)) + { + LogCredentialsPreemptExpiry(tempState); + + currentLoadState = CredentialsLoadState.Loading; + _ = BackgroundCredentialsRefreshAsync(); + } + } + finally + { + _updateGeneratedCredentialsSemaphore.Release(); } - - return tempState.Credentials; } - finally + + return tempState.Credentials; + } + + public override sealed async Task GetCredentialsAsync() + { + // NOTICE: Before modifying any of the logic read the comments in the synchronous GetCredentials method to + // understand the locking flow. If any changes are required be sure the comments in that method are also updated. + + // We save the currentState as it might be modified or cleared. + var tempState = currentState; + + if (IsExpired(tempState) || (currentLoadState != CredentialsLoadState.Loading && IsPreemptExpiryWindow(tempState))) { - _updateGeneratedCredentialsSemaphore.Release(); + await _updateGeneratedCredentialsSemaphore.WaitAsync().ConfigureAwait(false); + try + { + // Update the local variable for credentials after acquiring the lock in case another thread got the lock first and updated the current state. + tempState = currentState; + + // If credentials are expired block for credential refresh. + if (IsExpired(tempState)) + { + LogCredentialsExpired(tempState); + + tempState = await GenerateNewCredentialsAsync().ConfigureAwait(false); + ValidateGeneratedCredentials(tempState); + currentState = tempState; + currentLoadState = CredentialsLoadState.NotLoading; + } + else if (currentLoadState != CredentialsLoadState.Loading && IsPreemptExpiryWindow(tempState)) + { + LogCredentialsPreemptExpiry(tempState); + + currentLoadState = CredentialsLoadState.Loading; + _ = BackgroundCredentialsRefreshAsync(); + } + } + finally + { + _updateGeneratedCredentialsSemaphore.Release(); + } } + + return tempState.Credentials; } - public override sealed async System.Threading.Tasks.Task GetCredentialsAsync() + private async Task BackgroundCredentialsRefreshAsync() { - await _updateGeneratedCredentialsSemaphore.WaitAsync().ConfigureAwait(false); + // NOTICE: Before modifying any of the logic read the comments in the synchronous GetCredentials method to + // understand the locking flow. If any changes are required be sure the comments in that method are also updated. try { - // We save the currentState as it might be modified or cleared. - var tempState = currentState; + var newState = await GenerateNewCredentialsAsync().ConfigureAwait(false); + ValidateGeneratedCredentials(newState); - // If credentials are expired, update - if (ShouldUpdateState(tempState)) + // Acquire the lock to atomically update both currentState and currentLoadState + await _updateGeneratedCredentialsSemaphore.WaitAsync().ConfigureAwait(false); + try { - tempState = await GenerateNewCredentialsAsync().ConfigureAwait(false); - UpdateToGeneratedCredentials(tempState); - currentState = tempState; + currentState = newState; + currentLoadState = CredentialsLoadState.NotLoading; + } + finally + { + _updateGeneratedCredentialsSemaphore.Release(); } - - return tempState.Credentials; } - finally + catch (Exception e) { - _updateGeneratedCredentialsSemaphore.Release(); + _logger.Error(e, "Exception occurred performing background credentials refresh."); + + // If any exceptions occur during background refresh, reset the state to NotLoading + // so that future GetCredentials calls can attempt to refresh again. + // + // This is safe to modify outside of the lock because the only other place currentLoadState is modified + // to a different value is within the block of code that initiates the background refresh. The block + // can never be entered while this background refresh is in progress because currentLoadState was set to Loading. + currentLoadState = CredentialsLoadState.NotLoading; + throw; } } @@ -171,10 +312,10 @@ public override sealed async System.Threading.Tasks.Task G #region Private/protected credential update methods - private void UpdateToGeneratedCredentials(CredentialsRefreshState state) + private void ValidateGeneratedCredentials(CredentialsRefreshState state) { // Check if the new credentials are already expired - if (ShouldUpdateState(state)) + if (IsExpired(state)) { string errorMessage; if (state == null) @@ -185,65 +326,81 @@ private void UpdateToGeneratedCredentials(CredentialsRefreshState state) { errorMessage = string.Format(CultureInfo.InvariantCulture, "The retrieved credentials have already expired: Now = {0}, Credentials expiration = {1}", - AWSSDKUtils.CorrectedUtcNow, state.Expiration); + _timeProvider.CorrectedUtcNow, state.Expiration); } throw new AmazonClientException(errorMessage); } - // Offset the Expiration by PreemptExpiryTime. This produces the expiration window - // where the credentials should be updated before they actually expire. - state.Expiration -= PreemptExpiryTime; - - if (ShouldUpdateState(state)) - { - // This could happen if the default value of PreemptExpiryTime is - // overridden and set too high such that ShouldUpdate returns true. - _logger.InfoFormat( - "The preempt expiry time is set too high: Current time = {0}, Credentials expiry time = {1}, Preempt expiry time = {2}.", - AWSSDKUtils.CorrectedUtcNow, - state.Expiration, PreemptExpiryTime); - } + state.Expiration -= ExpirationBuffer; } /// - /// Test credentials existence and expiration time - /// should update if: - /// credentials have not been loaded yet - /// it's past the expiration time. At this point currentState.Expiration may - /// have the PreemptExpiryTime baked into to the expiration from a call to - /// UpdateToGeneratedCredentials but it may not if this is new application load. + /// This property has been marked as Obsolete because it is no longer used for determining + /// if credentials should be updated. The boolean ShouldUpdate property did not provide + /// enough information on whether credentials are expired or in the preempt expiry window. /// + [Obsolete("Property is no longer used for determining if credentials should be updated.")] protected bool ShouldUpdate { get { - return ShouldUpdateState(currentState); + return IsExpired(currentState); } } - // Test credentials existence and expiration time - // should update if: - // credentials have not been loaded yet - // it's past the expiration time. At this point currentState.Expiration may - // have the PreemptExpiryTime baked into to the expiration from a call to - // UpdateToGeneratedCredentials but it may not if this is new application - // load. - private bool ShouldUpdateState(CredentialsRefreshState state) + /// + /// Test if the credentials are currently expired. + /// + private static bool IsExpired(CredentialsRefreshState state) { - // it's past the expiration time. At this point currentState.Expiration may - // have the PreemptExpiryTime baked into to the expiration from a call to - // UpdateToGeneratedCredentials but it may not if this is new application - // load. var isExpired = state?.IsExpiredWithin(TimeSpan.Zero); - if (isExpired == true) + return isExpired ?? true; + } + + private void LogCredentialsExpired(CredentialsRefreshState state) + { + if (state == null) { - _logger.InfoFormat("Determined refreshing credentials should update. Expiration time: {0}, Current time: {1}", - state.Expiration.Add(PreemptExpiryTime).ToUniversalTime().ToString("yyyy-MM-ddTHH:mm:ss.f ffffffK", CultureInfo.InvariantCulture), - AWSSDKUtils.CorrectedUtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffffffK", CultureInfo.InvariantCulture)); + return; } - return isExpired ?? true; + _logger.InfoFormat( + "Determined refreshing credentials should update. Expiration time: {0}, Current time: {1}", + state.Expiration.ToUniversalTime().ToString("yyyy-MM-ddTHH:mm:ss.fffffffK", CultureInfo.InvariantCulture), + _timeProvider.CorrectedUtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffffffK", CultureInfo.InvariantCulture) + ); + } + + /// + /// Test if the credentials are in the preempt expiry window. + /// + /// That means the instance currently has credentials and they are not expired but that will expire + /// within the window of expiration minus PreemptExpiryTime. + /// + private bool IsPreemptExpiryWindow(CredentialsRefreshState state) + { + if (state == null || IsExpired(state)) + { + return false; + } + + var isPreemptWindow = state.IsExpiredWithin(PreemptExpiryTime); + return isPreemptWindow; + } + + private void LogCredentialsPreemptExpiry(CredentialsRefreshState state) + { + if (state == null) + { + return; + } + + _logger.InfoFormat( + "Determined refreshing credentials are in window for preempt expiration. Preempt time: {0}, Current time: {1}", + state.Expiration.Subtract(PreemptExpiryTime).ToUniversalTime().ToString("yyyy-MM-ddTHH:mm:ss.fffffffK", CultureInfo.InvariantCulture), + _timeProvider.CorrectedUtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffffffK", CultureInfo.InvariantCulture) + ); } /// @@ -263,9 +420,9 @@ protected virtual CredentialsRefreshState GenerateNewCredentials() /// Called on first credentials request and when expiration date is in the past. /// /// - protected virtual System.Threading.Tasks.Task GenerateNewCredentialsAsync() + protected virtual Task GenerateNewCredentialsAsync() { - return System.Threading.Tasks.Task.Run(() => this.GenerateNewCredentials()); + return Task.Run(() => this.GenerateNewCredentials()); } protected virtual void Dispose(bool disposing) diff --git a/sdk/src/Core/Amazon.Util/Internal/ITimeProvider.cs b/sdk/src/Core/Amazon.Util/Internal/ITimeProvider.cs new file mode 100644 index 000000000000..1bb3a4fe909d --- /dev/null +++ b/sdk/src/Core/Amazon.Util/Internal/ITimeProvider.cs @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; + +namespace Amazon.Util.Internal +{ + /// + /// Interface that can be used to mock the current (corrected) UTC time for testing purposes. + /// + /// The default implementation will use the value from that takes + /// any manual clock correction into account. + /// + public interface ITimeProvider + { + public DateTime CorrectedUtcNow { get; } + } + + /// + /// Default implementation of using . + /// + public sealed class DefaultTimeProvider : ITimeProvider + { + public static readonly DefaultTimeProvider Instance = new(); + + private DefaultTimeProvider() { } + + public DateTime CorrectedUtcNow => AWSSDKUtils.CorrectedUtcNow; + } +} diff --git a/sdk/src/Core/GlobalSuppressions.cs b/sdk/src/Core/GlobalSuppressions.cs index 0eac98387a68..026c03612606 100644 --- a/sdk/src/Core/GlobalSuppressions.cs +++ b/sdk/src/Core/GlobalSuppressions.cs @@ -405,7 +405,8 @@ [module: SuppressMessage("AwsSdkRules", "CR1001:PreventHashAlgorithmCreateRule", Scope = "member", Target = "Amazon.Util.CryptoUtilFactory+CryptoUtil.#CreateSHA256Instance()")] // Visible instance fields -[module: SuppressMessage("Microsoft.Design", "CA1051:DoNotDeclareVisibleInstanceFields", Scope = "member", Target = "Amazon.Runtime.RefreshingAWSCredentials.#currentState")] +[module: SuppressMessage("Microsoft.Design", "CA1051:DoNotDeclareVisibleInstanceFields", Scope = "member", Target = "~F:Amazon.Runtime.RefreshingAWSCredentials.currentState")] +[module: SuppressMessage("Microsoft.Design", "CA1051:DoNotDeclareVisibleInstanceFields", Scope = "member", Target = "~F:Amazon.Runtime.RefreshingAWSCredentials._timeProvider")] // Supression due to IL2CPP error [module: SuppressMessage("Microsoft.Design", "CA1006:DoNotNestGenericTypesInMemberSignatures", Scope = "member", Target = "Amazon.Runtime.Internal.Auth.AWS4Signer.#SortHeaders(System.Collections.Generic.IEnumerable`1>)")] diff --git a/sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs b/sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs new file mode 100644 index 000000000000..e45cd38d4cbe --- /dev/null +++ b/sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs @@ -0,0 +1,264 @@ +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Runtime; +using Amazon.Util.Internal; +using Xunit; + +namespace UnitTests.NetStandard.Core.Credentials +{ + public sealed class RefreshingAWSCredentialsTests + { + private readonly MockTimeProvider _mockProvider; + private readonly DateTime _baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + private readonly TimeSpan _lifetime = TimeSpan.FromMinutes(60); + + public RefreshingAWSCredentialsTests() => _mockProvider = new MockTimeProvider(); + + [Theory] + [InlineData(59.5)] // Credentials are not expired yet but just entered the expiration buffer + [InlineData(60)] // Credentials have just expired + [InlineData(75)] // Credentials are way past expiration + public void ConcurrentCallsToGetCredentialsOnlyGeneratesNewCredentialsOnce(double instantInMinutes) + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + // Prevent GenerateNewCredentials from returning. + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(instantInMinutes); + mockCredentials.CloseGenerateCredentialsGate(); + var concurrentCredentialTasks = Task.WhenAll( + Enumerable.Range(1, 5).Select(i => Task.Run(() => mockCredentials.GetCredentials())) + ); + + // Allow GenerateNewCredentials to complete. + mockCredentials.OpenGenerateCredentialsGate(); + var allCreds = concurrentCredentialTasks.Result; + Assert.NotEqual(initialCreds, allCreds[0]); + + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + for (var i = 1; i < allCreds.Length; i++) + { + Assert.Equal(allCreds[0], allCreds[i]); + } + } + + [Theory] + [InlineData(59.5)] // Credentials are not expired yet but just entered the expiration buffer + [InlineData(60)] // Credentials have just expired + [InlineData(75)] // Credentials are way past expiration + public async Task ConcurrentCallsToGetCredentialsOnlyGeneratesNewCredentialsOnceAsync(double instantInMinutes) + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + // Prevent GenerateNewCredentials from returning. + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(instantInMinutes); + mockCredentials.CloseGenerateCredentialsGate(); + var concurrentCredentialTasks = Task.WhenAll( + Enumerable.Range(1, 5).Select(i => mockCredentials.GetCredentialsAsync()) + ); + + // Allow GenerateNewCredentials to complete. + mockCredentials.OpenGenerateCredentialsGate(); + var allCreds = await concurrentCredentialTasks; + + Assert.NotEqual(initialCreds, allCreds[0]); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + for (var i = 1; i < allCreds.Length; i++) + { + Assert.Equal(allCreds[0], allCreds[i]); + } + } + + [Theory] + [InlineData(15, 1, 60)] // Credentials have just expired. + [InlineData(5, 10, 51)] // Credentials have expired considering the expiration buffer. + public void CredentialsAreRefreshedImmediatelyWhenExpired(int preemptInMinutes, int bufferInMinutes, double instantInMinutes) + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(preemptInMinutes), + ExpirationBuffer = TimeSpan.FromMinutes(bufferInMinutes), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(instantInMinutes); + var credsAfterExpiration = mockCredentials.GetCredentials(); + Assert.NotEqual(initialCreds, credsAfterExpiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Theory] + [InlineData(15, 1, 60)] // Credentials have just expired. + [InlineData(5, 10, 51)] // Credentials have expired considering the expiration buffer. + public async Task CredentialsAreRefreshedImmediatelyWhenExpiredAsync(int preemptInMinutes, int bufferInMinutes, double instantInMinutes) + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(preemptInMinutes), + ExpirationBuffer = TimeSpan.FromMinutes(bufferInMinutes), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(instantInMinutes); + var credsAfterExpiration = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + Assert.NotEqual(initialCreds, credsAfterExpiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Theory] + [InlineData(45.5)] // Credentials just entered the preempt expiry period + [InlineData(50)] + [InlineData(58.5)] // Credentials are still within the preempt expiry period but before the expiration buffer + public void CredentialsAreRefreshedInBackgroundDuringPreemptyExpiryPeriod(double instantInMinutes) + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(instantInMinutes); + var previousState = mockCredentials.CurrentState; + var credsDuringPreemptExpiry = mockCredentials.GetCredentials(); + Assert.Equal(initialCreds, credsDuringPreemptExpiry); + + // wait for background refresh to complete + Assert.True(SpinWait.SpinUntil(() => !ReferenceEquals(mockCredentials.CurrentState, previousState), 1_000)); + + var credsAfterRefresh = mockCredentials.GetCredentials(); + Assert.NotEqual(credsAfterRefresh, credsDuringPreemptExpiry); + Assert.Equal(_mockProvider.CorrectedUtcNow + _lifetime - mockCredentials.ExpirationBuffer, mockCredentials.CurrentState.Expiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Theory] + [InlineData(45.5)] // Credentials just entered the preempt expiry period + [InlineData(50)] + [InlineData(58.5)] // Credentials are still within the preempt expiry period but before the expiration buffer + public async Task CredentialsAreRefreshedInBackgroundDuringPreemptyExpiryPeriodAsync(double instantInMinutes) + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(instantInMinutes); + var previousState = mockCredentials.CurrentState; + var credsDuringPreemptExpiry = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + Assert.Equal(initialCreds, credsDuringPreemptExpiry); + + // wait for background refresh to complete + Assert.True(SpinWait.SpinUntil(() => !ReferenceEquals(mockCredentials.CurrentState, previousState), 1_000)); + + var credsAfterRefresh = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + Assert.NotEqual(credsAfterRefresh, credsDuringPreemptExpiry); + Assert.Equal(_mockProvider.CorrectedUtcNow + _lifetime - mockCredentials.ExpirationBuffer, mockCredentials.CurrentState.Expiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Fact] + public void ConcurrentCallsDuringPreemptWindowOnlyGeneratesNewCredentialsOnce() + { + var mockCredentials = new MockRefreshingAWSCredentials(_lifetime, _mockProvider) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + _mockProvider.CorrectedUtcNow = _baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + var previousState = mockCredentials.CurrentState; + + // Move time into preempt expiry period (not expired, but within preempt window). + _mockProvider.CorrectedUtcNow = _baseTimeUtc + TimeSpan.FromMinutes(50); + mockCredentials.CloseGenerateCredentialsGate(); + + // Multiple parallel calls during preempt expiry. + var tasks = Enumerable.Range(1, 5) + .Select(_ => Task.Run(() => mockCredentials.GetCredentials())) + .ToArray(); + mockCredentials.OpenGenerateCredentialsGate(); + Task.WaitAll(tasks); + + // Wait for background refresh to complete. + Assert.True(SpinWait.SpinUntil(() => !ReferenceEquals(mockCredentials.CurrentState, previousState), 1_000)); + + // Only one background refresh should have occurred. + var credsAfterRefresh = mockCredentials.GetCredentials(); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + Assert.NotEqual(initialCreds, credsAfterRefresh); + } + + private class MockTimeProvider : ITimeProvider + { + public DateTime CorrectedUtcNow { get; set; } + } + + // using a hand-written mock in order to have access to the protected fields + private sealed class MockRefreshingAWSCredentials : RefreshingAWSCredentials + { + private readonly TimeSpan _credentialsLifetime; + private readonly ManualResetEventSlim _generateCredsEvent; + private int _tokenCounter; + + public MockRefreshingAWSCredentials(TimeSpan credentialsLifetime, ITimeProvider timeProvider) + : base(timeProvider) + { + _credentialsLifetime = credentialsLifetime; + _generateCredsEvent = new ManualResetEventSlim(initialState: true); + _tokenCounter = 0; + } + + public CredentialsRefreshState CurrentState => base.currentState; + + public int GeneratedTokenCount => _tokenCounter; + + public bool IsGenerateCredentialsGateClosed => !_generateCredsEvent.IsSet; + + public void OpenGenerateCredentialsGate() + { + _generateCredsEvent.Set(); + } + + public void CloseGenerateCredentialsGate() + { + _generateCredsEvent.Reset(); + } + + protected override CredentialsRefreshState GenerateNewCredentials() + { + _generateCredsEvent.Wait(); + + var credentials = new ImmutableCredentials("access_key_id", "secret_access_key", $"token_{Interlocked.Increment(ref _tokenCounter)}"); + var expiration = _timeProvider.CorrectedUtcNow + _credentialsLifetime; + return new CredentialsRefreshState(credentials, expiration, _timeProvider); + } + + protected override Task GenerateNewCredentialsAsync() + { + return Task.Run(GenerateNewCredentials); + } + } + } +}