Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option --prompt-hint to enable customized header text for WAM prompts and web mode. #11

Merged
merged 19 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [v0.1.0] - 2022-03-30
### Added
- Initial project release.
- Option `--prompt-hint` to support custom text to prompt caller in web and WAM mode.
goagain marked this conversation as resolved.
Show resolved Hide resolved

[Unreleased]: https://github.com/AzureAD/microsoft-authentication-cli/compare/v0.1.0...HEAD
[v0.1.0]: https://github.com/AzureAD/microsoft-authentication-cli/releases/tag/v0.1.0
[v0.1.0]: https://github.com/AzureAD/microsoft-authentication-cli/releases/tag/v0.1.0
15 changes: 15 additions & 0 deletions src/AzureAuth.Test/AliasTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ public void TestMergePrefersNonNullMembers()
result.Should().BeEquivalentTo(this.expected);
}

/// <summary>
/// The test to merge prompt hints.
/// </summary>
[Test]
public void TestMergePromptHint()
{
this.alias.PromptHint = "expected prompt hint";
this.expected.PromptHint = "expected prompt hint";

Alias result = this.alias.Override(this.other);

result.Should().BeEquivalentTo(this.expected);
}

/// <summary>
/// The test to merge multiple members.
/// </summary>
Expand Down Expand Up @@ -100,6 +114,7 @@ public void TestMergeAllMembers()
this.alias.Client = "unexpected client";
this.alias.Domain = "unexpected domain";
this.alias.Tenant = "unexpected tenant";
this.alias.PromptHint = "unexpected prompt hint";
this.alias.Scopes = new List<string> { "unexpected scope" };

Alias result = this.alias.Override(this.other);
Expand Down
4 changes: 4 additions & 0 deletions src/AzureAuth.Test/CommandMainTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ internal class CommandMainTest
domain = ""contoso.com""
tenant = ""a3be859b-7f9a-4955-98ed-f3602dbd954c""
scopes = [ "".default"", ]
prompt_hint = ""sample prompt hint.""
";

private const string PartialAliasTOML = @"
Expand Down Expand Up @@ -162,6 +163,7 @@ public void TestEvaluateOptionsProvidedAliasWithoutCommandLineOptions()
Domain = "contoso.com",
Tenant = "a3be859b-7f9a-4955-98ed-f3602dbd954c",
Scopes = new List<string> { ".default" },
PromptHint = "sample prompt hint.",
};

CommandMain subject = this.serviceProvider.GetService<CommandMain>();
Expand All @@ -188,6 +190,7 @@ public void TestEvaluateOptionsProvidedAliasWithCommandLineOptions()
Domain = "contoso.com",
Tenant = "a3be859b-7f9a-4955-98ed-f3602dbd954c",
Scopes = new List<string> { ".default" },
PromptHint = "sample prompt hint.",
};

CommandMain subject = this.serviceProvider.GetService<CommandMain>();
Expand Down Expand Up @@ -217,6 +220,7 @@ public void TestEvaluateOptionsProvidedAliasWithEnvVarConfig()
Domain = "contoso.com",
Tenant = "a3be859b-7f9a-4955-98ed-f3602dbd954c",
Scopes = new List<string> { ".default" },
PromptHint = "sample prompt hint.",
};

CommandMain subject = this.serviceProvider.GetService<CommandMain>();
Expand Down
6 changes: 6 additions & 0 deletions src/AzureAuth/Alias.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ public class Alias
/// </summary>
public string Tenant { get; set; }

/// <summary>
/// Gets or sets the customized prompt hint.
/// </summary>
public string PromptHint { get; set; }

/// <summary>
/// Gets or sets the scopes.
/// </summary>
Expand All @@ -53,6 +58,7 @@ public Alias Override(Alias other)
Client = other.Client ?? this.Client,
Domain = other.Domain ?? this.Domain,
Tenant = other.Tenant ?? this.Tenant,
PromptHint = other.PromptHint ?? this.PromptHint,
Scopes = other.Scopes ?? this.Scopes,
};
}
Expand Down
11 changes: 10 additions & 1 deletion src/AzureAuth/CommandMain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class CommandMain
private const string ResourceOption = "--resource";
private const string ClientOption = "--client";
private const string TenantOption = "--tenant";
private const string PromptHintOption = "--prompt-hint";
private const string ScopeOption = "--scope";
private const string ClearOption = "--clear";
private const string DomainOption = "--domain";
Expand Down Expand Up @@ -104,6 +105,12 @@ public CommandMain(CommandExecuteEventData eventData, ILogger<CommandMain> logge
[Option(TenantOption, "The ID of the Tenant where the client and resource entities exist in", CommandOptionType.SingleValue)]
public string Tenant { get; set; }

/// <summary>
/// Gets or sets the customized prompt hint text for WAM prompts and web mode.
/// </summary>
[Option(PromptHintOption, "The prompt hint text for WAM prompts and web mode.", CommandOptionType.SingleValue)]
public string PromptHint { get; set; }

/// <summary>
/// Gets or sets the scopes.
/// </summary>
Expand Down Expand Up @@ -172,6 +179,7 @@ public bool EvaluateOptions()
Client = this.Client,
Domain = this.PreferredDomain,
Tenant = this.Tenant,
PromptHint = this.PromptHint,
Scopes = this.Scopes?.ToList(),
};

Expand Down Expand Up @@ -358,7 +366,8 @@ private ITokenFetcher TokenFetcher()
new Guid(this.tokenFetcherOptions.Client),
new Guid(this.tokenFetcherOptions.Tenant),
osxKeyChainSuffix: Constants.AuthOSXKeyChainSuffix,
preferredDomain: this.tokenFetcherOptions.Domain);
preferredDomain: this.tokenFetcherOptions.Domain,
promptHint: this.PromptHint);
}

return this.tokenFetcher;
Expand Down
31 changes: 30 additions & 1 deletion src/MSALWrapper.Test/TokenFetcherPublicClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class TokenFetcherPublicClientTest
private IEnumerable<string> scopes = new string[] { $"{ResourceId}/.default" };
private TokenResult tokenResult;

private string promptHint = "test prompt hint";

/// <summary>
/// The setup.
/// </summary>
Expand All @@ -73,7 +75,7 @@ public void Setup()
.AddTransient<TokenFetcherPublicClient>((provider) =>
{
var logger = provider.GetService<ILogger<TokenFetcherPublicClient>>();
return new TokenFetcherPublicClient(logger, ResourceId, ClientId, TenantId);
return new TokenFetcherPublicClient(logger, ResourceId, ClientId, TenantId, promptHint: this.promptHint);
goagain marked this conversation as resolved.
Show resolved Hide resolved
})
.BuildServiceProvider();

Expand Down Expand Up @@ -371,6 +373,25 @@ public async Task GetTokenNormalFlowAsync_GetTokenInteractive_Throws_OperationCa
tokenFetcher.ErrorsList[2].Message.Should().Be("Interactive Auth (with extra claims) timed out after 15 minutes.");
}

/// <summary>
/// Ensure <see cref="IPCAWrapper.WithPromptHint"/> be invoked in <see cref="TokenFetcherPublicClient.GetTokenNormalFlowAsync"/>.
/// </summary>
/// <returns>The <see cref="Task"/>.</returns>
[Test]
public async Task GetTokenNormalFlowAsync_GetTokenInteractive_WithPromptHint()
{
this.SilentAuthUIRequired();
this.InteractiveAuthResult();

// Act
var tokenFetcher = this.Subject();
var result = await tokenFetcher.GetTokenNormalFlowAsync(this.pcaMock.Object, this.scopes, this.testAccount.Object);

// Verify
this.pcaMock.Verify((pca) => pca.WithPromptHint(this.promptHint), Times.Once());
this.pcaMock.VerifyAll();
}

/// <summary>
/// The get token_ device code_ flow_ happy path.
/// </summary>
Expand Down Expand Up @@ -556,6 +577,7 @@ private void SilentAuthUIRequired()
this.pcaMock
.Setup((pca) => pca.GetTokenSilentAsync(this.scopes, this.testAccount.Object, It.IsAny<CancellationToken>()))
.Throws(new MsalUiRequiredException("1", "UI is required"));
this.SetupInteractiveAuthWithPromptHint();
}

private void SilentAuthServiceException()
Expand Down Expand Up @@ -628,6 +650,13 @@ private void InteractiveAuthWithClaimsTimeout()
.Throws(new OperationCanceledException());
}

private void SetupInteractiveAuthWithPromptHint()
{
this.pcaMock
.Setup(pca => pca.WithPromptHint(It.IsAny<string>()))
.Returns((string s) => this.pcaMock.Object);
}

private TokenFetcherPublicClient Subject() => this.serviceProvider.GetService<TokenFetcherPublicClient>();
}
}
31 changes: 31 additions & 0 deletions src/MSALWrapper/PCAWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ public interface IPCAWrapper
/// The <see cref="Task"/>.
/// </returns>
Task<TokenResult> GetTokenDeviceCodeAsync(IEnumerable<string> scopes, Func<DeviceCodeResult, Task> callback, CancellationToken cancellationToken);

/// <summary>
/// Customize the title bar by prompt hint(Web mode only).
/// </summary>
/// <param name="promptHint">The prompt hint text.</param>
/// <returns>This.</returns>
IPCAWrapper WithPromptHint(string promptHint);
}

/// <summary>
Expand All @@ -106,6 +113,22 @@ public PCAWrapper(IPublicClientApplication pca)
this.pca = pca;
}

/// <summary>
/// Gets or sets, The prompt hint displayed in the title bar.
/// </summary>
public string PromptHint { get; set; }

/// <summary>
/// Customize the title bar by prompt hint(Web mode only).
/// </summary>
/// <param name="promptHint">see <see cref="PromptHint"/>.</param>
/// <returns>This.</returns>
public IPCAWrapper WithPromptHint(string promptHint)
{
this.PromptHint = promptHint;
return this;
}

/// <summary>
/// The get token silent async.
/// </summary>
Expand Down Expand Up @@ -146,6 +169,10 @@ public async Task<TokenResult> GetTokenInteractiveAsync(IEnumerable<string> scop
{
AuthenticationResult result = await this.pca
.AcquireTokenInteractive(scopes)
.WithEmbeddedWebViewOptions(new EmbeddedWebViewOptions()
{
Title = this.PromptHint,
})
.WithAccount(account)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);
Expand All @@ -171,6 +198,10 @@ public async Task<TokenResult> GetTokenInteractiveAsync(IEnumerable<string> scop
{
AuthenticationResult result = await this.pca
.AcquireTokenInteractive(scopes)
.WithEmbeddedWebViewOptions(new EmbeddedWebViewOptions()
{
Title = this.PromptHint,
})
.WithClaims(claims)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);
Expand Down
22 changes: 19 additions & 3 deletions src/MSALWrapper/TokenFetcherPublicClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public class TokenFetcherPublicClient : ITokenFetcher
private readonly bool windows;
private readonly bool windows10;

private readonly string promptHint;

#region Required MSAL GUIDs

/// <summary>
Expand Down Expand Up @@ -115,7 +117,10 @@ public class TokenFetcherPublicClient : ITokenFetcher
/// <param name="verifyPersistence">
/// Optionally choose to verify the cache persistence layer when setting up the token cache.
/// </param>
public TokenFetcherPublicClient(ILogger logger, Guid resourceId, Guid clientId, Guid tenantId, string osxKeyChainSuffix = null, string preferredDomain = null, bool verifyPersistence = false)
/// <param name="promptHint">
/// The customized header text in account picker for WAM prompts.
/// </param>
public TokenFetcherPublicClient(ILogger logger, Guid resourceId, Guid clientId, Guid tenantId, string osxKeyChainSuffix = null, string preferredDomain = null, bool verifyPersistence = false, string promptHint = null)
{
this.windows = PlatformUtils.IsWindows(logger);
this.windows10 = PlatformUtils.IsWindows10(logger);
Expand All @@ -125,6 +130,8 @@ public TokenFetcherPublicClient(ILogger logger, Guid resourceId, Guid clientId,
this.resourceId = resourceId;
this.clientId = clientId;

this.promptHint = promptHint;

this.osxKeyChainSuffix = osxKeyChainSuffix;
this.verifyPersistence = verifyPersistence;
this.preferredDomain = preferredDomain;
Expand Down Expand Up @@ -334,7 +341,9 @@ public async Task<TokenResult> GetTokenNormalFlowAsync(IPCAWrapper pcaWrapper, I
var tokenResult = await this.CompleteWithin(
this.interactiveAuthTimeout,
"Interactive Auth",
(cancellationToken) => pcaWrapper.GetTokenInteractiveAsync(scopes, account, cancellationToken)) // TODO: Need to pass account here
(cancellationToken) => pcaWrapper
.WithPromptHint(this.promptHint)
.GetTokenInteractiveAsync(scopes, account, cancellationToken)) // TODO: Need to pass account here
.ConfigureAwait(false);
this.SetAuthenticationType(tokenResult, AuthType.Interactive);
return tokenResult;
Expand All @@ -347,7 +356,9 @@ public async Task<TokenResult> GetTokenNormalFlowAsync(IPCAWrapper pcaWrapper, I
var tokenResult = await this.CompleteWithin(
this.interactiveAuthTimeout,
"Interactive Auth (with extra claims)",
(cancellationToken) => pcaWrapper.GetTokenInteractiveAsync(scopes, ex.Claims, cancellationToken))
(cancellationToken) => pcaWrapper
.WithPromptHint(this.promptHint)
.GetTokenInteractiveAsync(scopes, ex.Claims, cancellationToken))
.ConfigureAwait(false);
this.SetAuthenticationType(tokenResult, AuthType.Interactive);
return tokenResult;
Expand Down Expand Up @@ -471,6 +482,11 @@ private IPublicClientApplication PCAWeb()
private IPublicClientApplication PCABroker()
{
var pcaBuilder = this.PCABase();
pcaBuilder.WithWindowsBrokerOptions(new WindowsBrokerOptions
{
HeaderText = this.promptHint,
});

#if NETFRAMEWORK
pcaBuilder.WithWindowsBroker();
#else
Expand Down