Skip to content

Commit

Permalink
Merge pull request #32 from Azure-Samples/jmprieur/addSignOut
Browse files Browse the repository at this point in the history
Implement SignOut
  • Loading branch information
jmprieur committed Dec 10, 2018
2 parents 1f2348a + 8b305e9 commit c77a270
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 33 deletions.
38 changes: 17 additions & 21 deletions Controllers/HomeController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,18 @@ public async Task<IActionResult> Contact()
}
catch (MsalUiRequiredException ex)
{
AuthenticationProperties properties = BuildAuthenticationPropertiesForIncrementalConsent(scopes);
return Challenge(properties);
if (ex.ErrorCode != MsalUiRequiredException.UserNullError)
{
AuthenticationProperties properties = BuildAuthenticationPropertiesForIncrementalConsent(scopes);
return Challenge(properties);
}
else
{
// UserNullError indicates a cache problem as when calling Contact we should have an
// authenticate user (see the [Authenticate] attribute on the controller, but
// and therefore its account should be in the cache
throw;
}
}
}

Expand All @@ -64,33 +74,19 @@ public async Task<IActionResult> Contact()
private AuthenticationProperties BuildAuthenticationPropertiesForIncrementalConsent(string[] scopes)
{
AuthenticationProperties properties = new AuthenticationProperties();
const string msaTenantId = "9188040d-6c67-4c5b-b112-36a304b66dad";

// Set the scopes, including the scopes that ADAL.NET / MASL.NET need for the Token cache
string[] additionalBuildInScopes = new string[] { "openid", "offline_access", "profile" };
properties.SetParameter<ICollection<string>>(OpenIdConnectParameterNames.Scope, scopes.Union(additionalBuildInScopes).ToList());

// Attempts to set the login_hint to avoid the logged-in user to be presented with an account selection dialog
string loginHint = string.Empty;
string displayName = HttpContext.User.FindFirstValue("preferred_username");
if (!string.IsNullOrWhiteSpace(displayName))
string loginHint = HttpContext.User.GetLoginHint();
if (!string.IsNullOrWhiteSpace(loginHint))
{
properties.SetParameter<string>(OpenIdConnectParameterNames.LoginHint, displayName);
properties.SetParameter<string>(OpenIdConnectParameterNames.LoginHint, loginHint);

string tenantId = HttpContext.User.FindFirstValue("http://schemas.microsoft.com/identity/claims/tenantid");
if (!string.IsNullOrWhiteSpace(tenantId))
{
string domainHint;
if (tenantId == msaTenantId)
{
domainHint = "consumers";
}
else
{
domainHint = "organizations";
}
properties.SetParameter<string>(OpenIdConnectParameterNames.DomainHint, domainHint);
}
string domainHint = HttpContext.User.GetDomainHint();
properties.SetParameter<string>(OpenIdConnectParameterNames.DomainHint, domainHint);
}

return properties;
Expand Down
80 changes: 73 additions & 7 deletions Extensions/ClaimPrincipalExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,94 @@ public static class ClaimsPrincipalExtension
/// <param name="claimsPrincipal">Claims principal</param>
/// <returns>A string corresponding to an account identifier as defined in <see cref="Microsoft.Identity.Client.AccountId.Identifier"/></returns>
public static string GetMsalAccountId(this ClaimsPrincipal claimsPrincipal)
{
string userObjectId = GetObjectId(claimsPrincipal);
string tenantId = GetTenantId(claimsPrincipal);

if (string.IsNullOrWhiteSpace(userObjectId)) // TODO: find a better typed exception
throw new ArgumentOutOfRangeException("Missing claim 'http://schemas.microsoft.com/identity/claims/objectidentifier' or 'oid' ");

if (string.IsNullOrWhiteSpace(tenantId))
throw new ArgumentOutOfRangeException("Missing claim 'http://schemas.microsoft.com/identity/claims/tenantid' or 'tid' ");

string accountId = userObjectId + "." + tenantId;
return accountId;
}

/// <summary>
/// Get the unique object ID associated with the claimsPrincipal
/// </summary>
/// <param name="claimsPrincipal">Claims principal from which to retrieve the unique object id</param>
/// <returns>Unique object ID of the identity, or <c>null</c> if it cannot be found</returns>
private static string GetObjectId(ClaimsPrincipal claimsPrincipal)
{
string userObjectId = claimsPrincipal.FindFirstValue("http://schemas.microsoft.com/identity/claims/objectidentifier");
if (string.IsNullOrEmpty(userObjectId))
{
userObjectId = claimsPrincipal.FindFirstValue("oid");
}

return userObjectId;
}

/// <summary>
/// Tenant ID of the identity
/// </summary>
/// <param name="claimsPrincipal">Claims principal from which to retrieve the tenant id</param>
/// <returns>Tenant ID of the identity, or <c>null</c> if it cannot be found</returns>
private static string GetTenantId(ClaimsPrincipal claimsPrincipal)
{
string tenantId = claimsPrincipal.FindFirstValue("http://schemas.microsoft.com/identity/claims/tenantid");
if (string.IsNullOrEmpty(tenantId))
{
tenantId = claimsPrincipal.FindFirstValue("tid");
}

if (string.IsNullOrWhiteSpace(userObjectId)) // TODO: find a better typed exception
throw new ArgumentOutOfRangeException("Missing claim 'http://schemas.microsoft.com/identity/claims/objectidentifier' or 'oid' ");
return tenantId;
}

if (string.IsNullOrWhiteSpace(tenantId))
throw new ArgumentOutOfRangeException("Missing claim 'http://schemas.microsoft.com/identity/claims/tenantid' or 'tid' ");
/// <summary>
/// Gets the login-hint associated with an identity
/// </summary>
/// <param name="claimsPrincipal">Identity for which to compte the login-hint</param>
/// <returns>login-hint for the identity, or <c>null</c> if it cannot be found</returns>
public static string GetLoginHint(this ClaimsPrincipal claimsPrincipal)
{
return GetDisplayName(claimsPrincipal);
}

string accountId = userObjectId + "." + tenantId;
return accountId;
/// <summary>
/// Gets the domain-hint associated with an identity
/// </summary>
/// <param name="claimsPrincipal">Identity for which to compte the domain-hint</param>
/// <returns>domain-hint for the identity, or <c>null</c> if it cannot be found</returns>
public static string GetDomainHint(this ClaimsPrincipal claimsPrincipal)
{
// Tenant for MSA accounts
const string msaTenantId = "9188040d-6c67-4c5b-b112-36a304b66dad";

string tenantId = GetTenantId(claimsPrincipal);
string domainHint;

if (!string.IsNullOrWhiteSpace(tenantId))
{
if (tenantId == msaTenantId)
{
domainHint = "consumers";
}
else
{
domainHint = "organizations";
}
}
else
{
domainHint = null;
}
return domainHint;
}


/// <summary>
/// Get the display name for the signed-in user, based on their claims principal
/// </summary>
Expand All @@ -50,7 +116,7 @@ public static string GetDisplayName(this ClaimsPrincipal claimsPrincipal)
{
displayName = claimsPrincipal.FindFirstValue(ClaimsIdentity.DefaultNameClaimType);
}

// Finally falling back to name
if (string.IsNullOrWhiteSpace(displayName))
{
Expand Down
7 changes: 7 additions & 0 deletions Extensions/ITokenAcquisition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,12 @@ public interface ITokenAcquisition
/// </example>
void AddAccountToCacheFromJwt(OpenIdConnect.TokenValidatedContext tokenValidationContext, IEnumerable<string> scopes = null);

/// <summary>
/// Removes the account associated with context.HttpContext.User from the MSAL.NET cache
/// </summary>
/// <param name="context">RedirectContext passed-in to a <see cref="OnRedirectToIdentityProviderForSignOut"/>
/// Openidconnect event</param>
/// <returns></returns>
Task RemoveAccount(RedirectContext context);
}
}
38 changes: 33 additions & 5 deletions Extensions/TokenAcquisition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,27 @@ public void AddAccountToCacheFromJwt(OpenIdConnect.TokenValidatedContext tokenVa
tokenValidatedContext.HttpContext);
}

/// <summary>
/// Removes the account associated with context.HttpContext.User from the MSAL.NET cache
/// </summary>
/// <param name="context">RedirectContext passed-in to a <see cref="OnRedirectToIdentityProviderForSignOut"/>
/// Openidconnect event</param>
/// <returns></returns>
public async Task RemoveAccount(RedirectContext context)
{
var user = context.HttpContext.User;
var app = CreateApplication(context.HttpContext, user, context.Properties, AzureADDefaults.CookieScheme);
var account = await app.GetAccountAsync(context.HttpContext.User.GetMsalAccountId());

// Workaround for the guest account
if (account == null)
{
var accounts = await app.GetAccountsAsync();
account = accounts.FirstOrDefault(a => a.Username == user.GetLoginHint());
}
await app.RemoveAsync(account);
}

/// <summary>
/// Creates an MSAL Confidential client application
/// </summary>
Expand Down Expand Up @@ -254,7 +275,8 @@ private ConfidentialClientApplication CreateApplication(HttpContext httpContext,
private async Task<string> GetAccessTokenOnBehalfOfUser(ConfidentialClientApplication application, ClaimsPrincipal claimsPrincipal, IEnumerable<string> scopes)
{
string accountIdentifier = claimsPrincipal.GetMsalAccountId();
return await GetAccessTokenOnBehalfOfUser(application, accountIdentifier, scopes);
string loginHint = claimsPrincipal.GetLoginHint();
return await GetAccessTokenOnBehalfOfUser(application, accountIdentifier, scopes, loginHint);
}

/// <summary>
Expand All @@ -263,21 +285,27 @@ private async Task<string> GetAccessTokenOnBehalfOfUser(ConfidentialClientApplic
/// <param name="accountIdentifier">User account identifier for which to acquire a token.
/// See <see cref="Microsoft.Identity.Client.AccountId.Identifier"/></param>
/// <param name="scopes">Scopes for the downstream API to call</param>
private async Task<string> GetAccessTokenOnBehalfOfUser(ConfidentialClientApplication application, string accountIdentifier, IEnumerable<string> scopes)
private async Task<string> GetAccessTokenOnBehalfOfUser(ConfidentialClientApplication application, string accountIdentifier, IEnumerable<string> scopes, string loginHint)
{
if (accountIdentifier == null)
throw new ArgumentNullException(nameof(accountIdentifier));

if (scopes == null)
throw new ArgumentNullException(nameof(scopes));

// Remove: this is for debugging
var accounts = await application.GetAccountsAsync();
// Get the account
IAccount account = await application.GetAccountAsync(accountIdentifier);

// Special case for guest users as the Guest iod / tenant id are not surfaced.
if (account == null)
{
var accounts = await application.GetAccountsAsync();
account = accounts.FirstOrDefault(a => a.Username == loginHint);
}

try
{
AuthenticationResult result = null;
IAccount account = await application.GetAccountAsync(accountIdentifier);
result = await application.AcquireTokenSilentAsync(scopes.Except(scopesRequestedByMsalNet), account);
return result.AccessToken;
}
Expand Down
14 changes: 14 additions & 0 deletions Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ public void ConfigureServices(IServiceCollection services)
await handler(context);
};
// Handling the sign-out: removing the account from MSAL.NET cache
options.Events.OnRedirectToIdentityProviderForSignOut = async context =>
{
var user = context.HttpContext.User;
// Avoid displaying the select account dialog
context.ProtocolMessage.LoginHint = user.GetLoginHint();
context.ProtocolMessage.DomainHint = user.GetDomainHint();
// Remove the account from MSAL.NET token cache
var _tokenAcquisition = context.HttpContext.RequestServices.GetRequiredService<ITokenAcquisition>();
await _tokenAcquisition.RemoveAccount(context);
};
// Avoids having users being presented the select account dialog when they are already signed-in
// for instance when going through incremental consent
options.Events.OnRedirectToIdentityProvider = async context =>
Expand Down

0 comments on commit c77a270

Please sign in to comment.