Skip to content

Commit

Permalink
Fix ConfidentialClient's AcquireTokenSilent and AcquireTokenOnBehalfO…
Browse files Browse the repository at this point in the history
…f claims logic (#43513)
  • Loading branch information
christothes committed Apr 19, 2024
1 parent 1835c8f commit ae13ec2
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 15 deletions.
6 changes: 3 additions & 3 deletions sdk/identity/Azure.Identity/CHANGELOG.md
@@ -1,9 +1,9 @@
# Release History

## 1.12.0-beta.1 (2024-04-17)
## 1.11.2 (2024-04-19)

### Other Changes
- An experimental overload `Authenticate` method on `InteractiveBrowserCredential` now supports the experimental `PopTokenRequestContext` parameter.
### Bugs Fixed
- Fixed an issue which caused claims to be incorrectly added to confidential client credentials such as `DeviceCodeCredential` [#43468](https://github.com/Azure/azure-sdk-for-net/issues/43468)

## 1.11.1 (2024-04-16)

Expand Down
Expand Up @@ -261,23 +261,19 @@ public static partial class IdentityModelFactory
public static Azure.Identity.AuthenticationRecord AuthenticationRecord(string username, string authority, string homeAccountId, string tenantId, string clientId) { throw null; }
public static Azure.Identity.DeviceCodeInfo DeviceCodeInfo(string userCode, string deviceCode, System.Uri verificationUri, System.DateTimeOffset expiresOn, string message, string clientId, System.Collections.Generic.IReadOnlyCollection<string> scopes) { throw null; }
}
public partial class InteractiveBrowserCredential : Azure.Core.TokenCredential, Azure.Core.ISupportsProofOfPossession
public partial class InteractiveBrowserCredential : Azure.Core.TokenCredential
{
public InteractiveBrowserCredential() { }
public InteractiveBrowserCredential(Azure.Identity.InteractiveBrowserCredentialOptions options) { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public InteractiveBrowserCredential(string clientId) { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public InteractiveBrowserCredential(string tenantId, string clientId, Azure.Identity.TokenCredentialOptions options = null) { }
public virtual Azure.Identity.AuthenticationRecord Authenticate(Azure.Core.PopTokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Identity.AuthenticationRecord Authenticate(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Identity.AuthenticationRecord Authenticate(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Identity.AuthenticationRecord> AuthenticateAsync(Azure.Core.PopTokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Identity.AuthenticationRecord> AuthenticateAsync(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Identity.AuthenticationRecord> AuthenticateAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public Azure.Core.AccessToken GetToken(Azure.Core.PopTokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override Azure.Core.AccessToken GetToken(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask<Azure.Core.AccessToken> GetTokenAsync(Azure.Core.PopTokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override System.Threading.Tasks.ValueTask<Azure.Core.AccessToken> GetTokenAsync(Azure.Core.TokenRequestContext requestContext, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
}
public partial class InteractiveBrowserCredentialOptions : Azure.Identity.TokenCredentialOptions
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/Azure.Identity/src/Azure.Identity.csproj
Expand Up @@ -2,7 +2,7 @@
<PropertyGroup>
<Description>This is the implementation of the Azure SDK Client Library for Azure Identity</Description>
<AssemblyTitle>Microsoft Azure.Identity Component</AssemblyTitle>
<Version>1.12.0-beta.1</Version>
<Version>1.11.2</Version>
<!--The ApiCompatVersion is managed automatically and should not generally be modified manually.-->
<ApiCompatVersion>1.11.1</ApiCompatVersion>
<PackageTags>Microsoft Azure Identity;$(PackageCommonTags)</PackageTags>
Expand Down
Expand Up @@ -181,7 +181,15 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
private async Task<AccessToken> AcquireTokenWithCode(bool async, TokenRequestContext requestContext, AccessToken token, string tenantId, CancellationToken cancellationToken)
{
AuthenticationResult result = await Client
.AcquireTokenByAuthorizationCodeAsync(requestContext.Scopes, _authCode, tenantId, _redirectUri, requestContext.Claims, requestContext.IsCaeEnabled, async, cancellationToken)
.AcquireTokenByAuthorizationCodeAsync(
scopes: requestContext.Scopes,
code: _authCode,
tenantId: tenantId,
redirectUri: _redirectUri,
claims: requestContext.Claims,
enableCae: requestContext.IsCaeEnabled,
async,
cancellationToken)
.ConfigureAwait(false);
_record = new AuthenticationRecord(result, _clientId);
token = new AccessToken(result.AccessToken, result.ExpiresOn);
Expand Down
10 changes: 7 additions & 3 deletions sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs
Expand Up @@ -220,7 +220,7 @@ protected virtual async ValueTask<IConfidentialClientApplication> CreateClientCo
};
builder.WithTenantIdFromAuthority(uriBuilder.Uri);
}
if (string.IsNullOrEmpty(claims))
if (!string.IsNullOrEmpty(claims))
{
builder.WithClaims(claims);
}
Expand All @@ -239,7 +239,7 @@ protected virtual async ValueTask<IConfidentialClientApplication> CreateClientCo
bool async,
CancellationToken cancellationToken)
{
var result = await AcquireTokenByAuthorizationCodeCoreAsync(scopes, code, tenantId, redirectUri, claims, enableCae, async, cancellationToken).ConfigureAwait(false);
var result = await AcquireTokenByAuthorizationCodeCoreAsync(scopes: scopes, code: code, tenantId: tenantId, redirectUri: redirectUri, claims: claims, enableCae: enableCae, async, cancellationToken).ConfigureAwait(false);
LogAccountDetails(result);
return result;
}
Expand All @@ -248,8 +248,8 @@ protected virtual async ValueTask<IConfidentialClientApplication> CreateClientCo
string[] scopes,
string code,
string tenantId,
string claims,
string redirectUri,
string claims,
bool enableCae,
bool async,
CancellationToken cancellationToken)
Expand Down Expand Up @@ -312,6 +312,10 @@ protected virtual async ValueTask<IConfidentialClientApplication> CreateClientCo
};
builder.WithTenantIdFromAuthority(uriBuilder.Uri);
}
if (!string.IsNullOrEmpty(claims))
{
builder.WithClaims(claims);
}
return await builder
.ExecuteAsync(async, cancellationToken)
.ConfigureAwait(false);
Expand Down
Expand Up @@ -33,6 +33,7 @@ public override TokenCredential GetTokenCredential(CommonCredentialTestConfig co
DisableInstanceDiscovery = config.DisableInstanceDiscovery,
AdditionallyAllowedTenants = config.AdditionallyAllowedTenants,
IsUnsafeSupportLoggingEnabled = config.IsUnsafeSupportLoggingEnabled,
RedirectUri = config.RedirectUri,
};
if (config.Transport != null)
{
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/Azure.Identity/tests/Azure.Identity.Tests.csproj
Expand Up @@ -16,9 +16,9 @@
<PackageReference Include="Azure.Storage.Blobs" />
</ItemGroup>
<!-- Remove before shipping GA -->
<PropertyGroup>
<!-- <PropertyGroup>
<DefineConstants>PREVIEW_FEATURE_FLAG</DefineConstants>
</PropertyGroup>
</PropertyGroup> -->
<!-- End remove before shipping GA -->
<ItemGroup>
<ProjectReference Include="$(AzureCoreTestFramework)" />
Expand Down
87 changes: 87 additions & 0 deletions sdk/identity/Azure.Identity/tests/CredentialTestBase.cs
Expand Up @@ -355,6 +355,92 @@ public async Task EnableCae()
Assert.True(observedNoCae);
}

[Test]
public async Task ClaimsSetCorrectlyOnRequest()
{
// Configure the transport
var token = Guid.NewGuid().ToString();
var idToken = CredentialTestHelpers.CreateMsalIdToken(Guid.NewGuid().ToString(), "userName", TenantId);
bool calledDiscoveryEndpoint = false;
bool isPubClient = false;
const string Claims = "myClaims";

var mockTransport = new MockTransport(req =>
{
calledDiscoveryEndpoint |= req.Uri.Path.Contains("discovery/instance");
MockResponse response = new(200);
if (req.Uri.Path.EndsWith("/devicecode"))
{
response = CredentialTestHelpers.CreateMockMsalDeviceCodeResponse();
}
else if (req.Uri.Path.Contains("/userrealm/"))
{
response.SetContent(UserrealmResponse);
}
else
{
if (isPubClient || typeof(TCredOptions) == typeof(AuthorizationCodeCredentialOptions) || typeof(TCredOptions) == typeof(OnBehalfOfCredentialOptions))
{
response = CredentialTestHelpers.CreateMockMsalTokenResponse(200, token, TenantId, ExpectedUsername, ObjectId);
}
else
{
response.SetContent($"{{\"token_type\": \"Bearer\",\"expires_in\": 9999,\"ext_expires_in\": 9999,\"access_token\": \"{token}\" }}");
}
if (req.Content != null)
{
var stream = new MemoryStream();
req.Content.WriteTo(stream, default);
var content = new BinaryData(stream.ToArray()).ToString();
var queryString = Uri.UnescapeDataString(content)
.Split('&')
.Select(q => q.Split('='))
.ToDictionary(kvp => kvp[0], kvp => kvp[1]);
bool containsClaims = queryString.TryGetValue("claims", out var claimsJson);
if (req.ClientRequestId == "NoClaims")
{
Assert.False(containsClaims, "(NoClaims) Claims should not be present. Claims=" + claimsJson);
}
if (req.ClientRequestId == "WithClaims")
{
Assert.True(containsClaims, "(WithClaims) Claims should be present");
Assert.AreEqual(Claims, claimsJson, "(WithClaims) Claims should match");
}
}
}
return response;
});

var config = new CommonCredentialTestConfig()
{
Transport = mockTransport,
TenantId = TenantId,
RedirectUri = new Uri("http://localhost:8400/")
};
var credential = GetTokenCredential(config);
if (!CredentialTestHelpers.IsMsalCredential(credential))
{
Assert.Ignore("EnableCAE tests do not apply to the non-MSAL credentials.");
}
isPubClient = CredentialTestHelpers.IsCredentialTypePubClient(credential);

using (HttpPipeline.CreateClientRequestIdScope("NoClaims"))
{
// First call to populate the account record for confidential client creds
await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default), default);
var actualToken = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Alternate), default);
Assert.AreEqual(token, actualToken.Token);
}
using (HttpPipeline.CreateClientRequestIdScope("WithClaims"))
{
var actualToken = await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Alternate2, claims: Claims), default);
Assert.AreEqual(token, actualToken.Token);
}
}

[Test]
public async Task TokenRequestContextClaimsPassedToMSAL()
{
Expand Down Expand Up @@ -624,6 +710,7 @@ public class CommonCredentialTestConfig : TokenCredentialOptions, ISupportsAddit
public TokenRequestContext RequestContext { get; set; }
public string TenantId { get; set; }
public IList<string> AdditionallyAllowedTenants { get; set; } = new List<string>();
public Uri RedirectUri { get; set; }
internal TenantIdResolverBase TestTentantIdResolver { get; set; }
internal MockMsalConfidentialClient MockConfidentialMsalClient { get; set; }
internal MockMsalPublicClient MockPublicMsalClient { get; set; }
Expand Down
1 change: 1 addition & 0 deletions sdk/identity/Azure.Identity/tests/Mock/MockScopes.cs
Expand Up @@ -19,6 +19,7 @@ private MockScopes(string[] scopes)
public static MockScopes Default = new MockScopes(new string[] { "https://default.mock.auth.scope/.default" });

public static MockScopes Alternate = new MockScopes(new string[] { "https://alternate.mock.auth.scope/.default" });
public static MockScopes Alternate2 = new MockScopes(new string[] { "https://alternate2.mock.auth.scope/.default" });

public override string ToString()
{
Expand Down

0 comments on commit ae13ec2

Please sign in to comment.