Skip to content
Merged

CAE #14567

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 335 additions & 0 deletions src/Accounts/Accounts.Test/SilentReAuthByTenantCmdletTest.cs

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System.Management.Automation;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -476,11 +477,17 @@ public override void ExecuteCmdlet()
}
catch (AuthenticationFailedException ex)
{
string message = string.Empty;
if (IsUnableToOpenWebPageError(ex))
{
WriteWarning(Resources.InteractiveAuthNotSupported);
WriteDebug(ex.ToString());
}
else if (TryParseUnknownAuthenticationException(ex, out message))
{
WriteDebug(ex.ToString());
throw ex.FromExceptionAndAdditionalMessage(message);
}
else
{
if (IsUsingInteractiveAuthentication())
Expand Down Expand Up @@ -519,6 +526,21 @@ private bool IsUnableToOpenWebPageError(AuthenticationFailedException exception)
|| (exception.Message?.ToLower()?.Contains("unable to open a web page") ?? false);
}

private bool TryParseUnknownAuthenticationException(AuthenticationFailedException exception, out string message)
{

var innerException = exception?.InnerException as MsalServiceException;
bool isUnknownMsalServiceException = string.Equals(innerException?.ErrorCode, "access_denied", StringComparison.OrdinalIgnoreCase);
message = null;
if(isUnknownMsalServiceException)
{
StringBuilder messageBuilder = new StringBuilder(nameof(innerException.ErrorCode));
messageBuilder.Append(": ").Append(innerException.ErrorCode);
message = messageBuilder.ToString();
}
return isUnknownMsalServiceException;
}

private ConcurrentQueue<Task> _tasks = new ConcurrentQueue<Task>();

private void HandleActions()
Expand Down
5 changes: 5 additions & 0 deletions src/Accounts/Accounts/ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
-->

## Upcoming Release
* Implement CAE by adding handler to http pipeline
* Improved error message when login is blocked by AAD
* Improved error message when silent reauthentication failed
* Enabled CAE for Get-AzTenant and Get-AzSubcription
* Added test cases
* Disabled context auto saving when token cache persistence fails on Windows and macOS
* Upgraded Microsoft.ApplicationInsights from 2.4.0 to 2.12.0
* Updated Azure.Core to 1.16.0
Expand Down
38 changes: 32 additions & 6 deletions src/Accounts/Accounts/CommonModule/ContextAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
using System.Linq;
using System.Management.Automation;
using Microsoft.Azure.Commands.Profile.Properties;
using System.Linq.Expressions;
using Azure.Identity;

namespace Microsoft.Azure.Commands.Common
{
Expand Down Expand Up @@ -156,8 +158,31 @@ internal Func<HttpRequestMessage, CancellationToken, Action, SignalDelegate, Nex
return async (request, cancelToken, cancelAction, signal, next) =>
{
PatchRequestUri(context, request);
await AuthorizeRequest(context, resourceId, request, cancelToken);
return await next(request, cancelToken, cancelAction, signal);
IAccessToken accessToken = await AuthorizeRequest(context, resourceId, request, cancelToken);
var response = await next(request, cancelToken, cancelAction, signal);

if(response.StatusCode == System.Net.HttpStatusCode.Unauthorized && response.Headers.WwwAuthenticate?.Count > 0)
{
//get token again with claims challenge
if(accessToken is IClaimsChallengeProcessor processor)
{
try
{
var claimsChallenge = ClaimsChallengeUtilities.GetClaimsChallenge(response);
if (!string.IsNullOrEmpty(claimsChallenge))
{
await processor.OnClaimsChallenageAsync(request, claimsChallenge, cancelToken).ConfigureAwait(false);
response = await next(request, cancelToken, cancelAction, signal);
}
}
catch (AuthenticationFailedException e)
{
string message = response?.GetWwwAuthenticateMessage() ?? string.Empty;
throw e.FromExceptionAndAdditionalMessage(message);
}
}
}
return response;
};
}

Expand All @@ -167,21 +192,22 @@ internal Func<HttpRequestMessage, CancellationToken, Action, SignalDelegate, Nex
/// <param name="context"></param>
/// <param name="resourceId"></param>
/// <param name="request"></param>
/// <param name="outerToken"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
internal async Task AuthorizeRequest(IAzureContext context, string resourceId, HttpRequestMessage request, CancellationToken outerToken)
internal async Task<IAccessToken> AuthorizeRequest(IAzureContext context, string resourceId, HttpRequestMessage request, CancellationToken cancellationToken)
{
if (context == null || context.Account == null || context.Environment == null)
{
throw new InvalidOperationException(Resources.InvalidAzureContext);
}

await Task.Run(() =>
return await Task.Run(() =>
{
resourceId = context?.Environment?.GetAudienceFromRequestUri(request.RequestUri) ?? resourceId;
var authToken = _authenticator.Authenticate(context.Account, context.Environment, context.Tenant.Id, null, "Never", null, resourceId);
authToken.AuthorizeRequest((type, token) => request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue(type, token));
}, outerToken);
return authToken;
}, cancellationToken);
}

internal void PatchRequestUri(IAzureContext context, HttpRequestMessage request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
using Microsoft.Azure.Internal.Subscriptions;
using Microsoft.Azure.Internal.Subscriptions.Models;
using Microsoft.Azure.Internal.Subscriptions.Models.Utilities;
using Microsoft.Rest;
using Microsoft.WindowsAzure.Commands.Utilities.Common;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -41,7 +40,7 @@ public IList<AzureTenant> ListAccountTenants(IAccessToken accessToken, IAzureEnv
{
subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers());

var tenants = new GenericPageEnumerable<TenantIdDescription>(subscriptionClient.Tenants.List, subscriptionClient.Tenants.ListNext, ulong.MaxValue, 0).ToList();
Expand Down Expand Up @@ -71,7 +70,7 @@ public IList<AzureSubscription> ListAllSubscriptionsForTenant(IAccessToken acces
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
return (subscriptionClient.ListAllSubscriptions()?
Expand All @@ -83,7 +82,7 @@ public AzureSubscription GetSubscriptionById(string subscriptionId, IAccessToken
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
var subscription = subscriptionClient.Subscriptions.Get(subscriptionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public IList<AzureTenant> ListAccountTenants(IAccessToken accessToken, IAzureEnv
{
subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers());

var tenants = new GenericPageEnumerable<TenantIdDescription>(subscriptionClient.Tenants.List, subscriptionClient.Tenants.ListNext, ulong.MaxValue, 0).ToList();
Expand Down Expand Up @@ -72,7 +72,7 @@ public IList<AzureSubscription> ListAllSubscriptionsForTenant(IAccessToken acces
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
return subscriptionClient.ListAllSubscriptions()?
Expand All @@ -84,7 +84,7 @@ public AzureSubscription GetSubscriptionById(string subscriptionId, IAccessToken
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
var subscription = subscriptionClient.Subscriptions.Get(subscriptionId);
Expand Down
20 changes: 18 additions & 2 deletions src/Accounts/Accounts/Tenant/GetAzureRMTenant.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
using Microsoft.Azure.Commands.Common.Authentication.Models;
using Microsoft.Azure.Commands.Profile.Models;
using Microsoft.Azure.Commands.ResourceManager.Common;
using Microsoft.WindowsAzure.Commands.Common;
using System.Collections.Concurrent;
using System.Linq;
using System.Management.Automation;
using System.Threading.Tasks;

namespace Microsoft.Azure.Commands.Profile
{
Expand All @@ -36,11 +37,26 @@ public class GetAzureRMTenantCommand : AzureRMCmdlet
[ValidateNotNullOrEmpty]
public string TenantId { get; set; }


public override void ExecuteCmdlet()
{
var profileClient = new RMProfileClient(AzureRmProfileProvider.Instance.GetProfile<AzureRmProfile>());
profileClient.WarningLog = (message) => _tasks.Enqueue(new Task(() => this.WriteWarning(message)));

var tenants = profileClient.ListTenants(TenantId).Select((t) => new PSAzureTenant(t));
HandleActions();
WriteObject(tenants, enumerateCollection: true);
}

WriteObject(profileClient.ListTenants(TenantId).Select((t) => new PSAzureTenant(t)), enumerateCollection: true);
private ConcurrentQueue<Task> _tasks = new ConcurrentQueue<Task>();

private void HandleActions()
{
Task task;
while (_tasks.TryDequeue(out task))
{
task.RunSynchronously();
}
}
}
}
1 change: 0 additions & 1 deletion src/Accounts/Accounts/Token/GetAzureRmAccessToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ public override void ExecuteCmdlet()
{
var tokenParts = accessToken.AccessToken.Split('.');
var decodedToken = Base64UrlHelper.DecodeToString(tokenParts[1]);

var tokenDocument = JsonDocument.Parse(decodedToken);
int expSeconds = tokenDocument.RootElement.EnumerateObject()
.Where(p => p.Name == "exp")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

using Azure.Core;
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Commands.Common.Authentication;
using Microsoft.Azure.Commands.Common.Authentication.Abstractions;
using Microsoft.Azure.PowerShell.Authenticators;
using Microsoft.Azure.PowerShell.Authenticators.Factories;
using Microsoft.WindowsAzure.Commands.ScenarioTest;
using Moq;
using Xunit;
using Xunit.Abstractions;
using Azure.Identity;

namespace Common.Authenticators.Test
{
public class SilentAuthenticatorTests
{
private const string TestTenantId = "123";
private const string TestResourceId = "ActiveDirectoryServiceEndpointResourceId";

private const string fakeToken = "faketoken";

private ITestOutputHelper Output { get; set; }

class TokenCredentialMock : TokenCredential
{
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}

public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
return new ValueTask<AccessToken>(new AccessToken(fakeToken, DateTimeOffset.Now));
}
}

public SilentAuthenticatorTests(ITestOutputHelper output)
{
AzureSessionInitializer.InitializeAzureSession();
Output = output;
}

[Fact]
[Trait(Category.AcceptanceType, Category.CheckIn)]
public async Task SimpleSilentAuthenticationTest()
{
var accountId = "testuser";

//Setup
var mockAzureCredentialFactory = new Mock<AzureCredentialFactory>();
mockAzureCredentialFactory.Setup(f => f.CreateSharedTokenCacheCredentials(It.IsAny<SharedTokenCacheCredentialOptions>())).Returns(() => new TokenCredentialMock());
AzureSession.Instance.RegisterComponent(nameof(AzureCredentialFactory), () => mockAzureCredentialFactory.Object, true);
InMemoryTokenCacheProvider cacheProvider = new InMemoryTokenCacheProvider();

var account = new AzureAccount
{
Id = accountId,
Type = AzureAccount.AccountType.User,
};
account.SetTenants(TestTenantId);

var parameter = new SilentParameters(
cacheProvider,
AzureEnvironment.PublicEnvironments["AzureCloud"],
null,
TestTenantId,
TestResourceId,
account.Id,
accountId);

//Run
var authenticator = new SilentAuthenticator();
var token = await authenticator.Authenticate(parameter);

//Verify
mockAzureCredentialFactory.Verify();
Assert.Equal(fakeToken, token.AccessToken);
Assert.Equal(TestTenantId, token.TenantId);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

using Azure.Identity;
using System;
using System.Text;

namespace Microsoft.Azure.Commands.Common.Authentication
{
public static class AuthenticationFailedExceptionExtention
{
public static AuthenticationFailedException FromExceptionAndAdditionalMessage(this AuthenticationFailedException e, string additonal)
{
var errorMessage = new StringBuilder(e.Message);
errorMessage.Append(Environment.NewLine).Append(additonal);
return new AuthenticationFailedException(errorMessage.ToString(), e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.Commands.Common.Authentication
{
public interface IClaimsChallengeProcessor
{
ValueTask<bool> OnClaimsChallenageAsync(HttpRequestMessage request, string claimsChallenge, CancellationToken cancellationToken);
}
}
Loading