Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
David-Engel committed Feb 7, 2023
1 parent 947b7fc commit 6755789
Showing 1 changed file with 63 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
VisualStudioTenantId = tenantId,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}
Expand All @@ -142,15 +142,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
{
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

AuthenticationResult result;
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
{
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}
Expand Down Expand Up @@ -181,40 +181,13 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token);
}
else
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();

result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync()).GetEnumerator();
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();

IAccount account = default;
if (accounts.MoveNext())
Expand Down Expand Up @@ -244,23 +217,66 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token);
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
else
{
throw;
}
}
}
else
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
else
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();

result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
}
}
else
Expand Down Expand Up @@ -298,7 +314,8 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
.WithCorrelationId(connectionId)
.WithCustomWebUi(_customWebUI)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token);
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
}
else
{
Expand All @@ -322,15 +339,17 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
return await app.AcquireTokenInteractive(scopes)
.WithCorrelationId(connectionId)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token);
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
}
}
else
{
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
.WithCorrelationId(connectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
return result;
}
}
Expand Down

0 comments on commit 6755789

Please sign in to comment.