diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 25f591f1c2..127ab5bcc3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -133,7 +133,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti VisualStudioTenantId = tenantId, ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications. }; - AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token); + AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -142,7 +142,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) { - AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token); + AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -150,7 +150,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti AuthenticationResult result; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) { - AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token); + AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -181,40 +181,13 @@ public override async Task AcquireTokenAsync(SqlAuthenti IPublicClientApplication app = GetPublicClientAppInstance(pcaKey); - if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) - { - if (!string.IsNullOrEmpty(parameters.UserId)) - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .WithUsername(parameters.UserId) - .ExecuteAsync(cancellationToken: cts.Token); - } - else - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token); - } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) - { - SecureString password = new SecureString(); - foreach (char c in parameters.Password) - password.AppendChar(c); - password.MakeReadOnly(); - - result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) { // Fetch available accounts from 'app' instance - System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync()).GetEnumerator(); + System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); IAccount account = default; if (accounts.MoveNext()) @@ -244,7 +217,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti { // If 'account' is available in 'app', we use the same to acquire token silently. // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent - result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token); + result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } catch (MsalUiRequiredException) @@ -252,15 +225,58 @@ public override async Task AcquireTokenAsync(SqlAuthenti // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + { + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + else + { + throw; + } } } else { - // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .WithUsername(parameters.UserId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + } + else + { + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + } + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) + { + SecureString password = new SecureString(); + foreach (char c in parameters.Password) + password.AppendChar(c); + password.MakeReadOnly(); + + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + else + { + // If no existing 'account' is found, we request user to sign in interactively. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } } } else @@ -298,7 +314,8 @@ public override async Task AcquireTokenAsync(SqlAuthenti .WithCorrelationId(connectionId) .WithCustomWebUi(_customWebUI) .WithLoginHint(userId) - .ExecuteAsync(ctsInteractive.Token); + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); } else { @@ -322,7 +339,8 @@ public override async Task AcquireTokenAsync(SqlAuthenti return await app.AcquireTokenInteractive(scopes) .WithCorrelationId(connectionId) .WithLoginHint(userId) - .ExecuteAsync(ctsInteractive.Token); + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); } } else @@ -330,7 +348,8 @@ public override async Task AcquireTokenAsync(SqlAuthenti AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)) .WithCorrelationId(connectionId) - .ExecuteAsync(cancellationToken: cts.Token); + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); return result; } }