diff --git a/src/Service.Tests/SqlTests/SqlTestBase.cs b/src/Service.Tests/SqlTests/SqlTestBase.cs index 70d95a5984..2d6cd68a11 100644 --- a/src/Service.Tests/SqlTests/SqlTestBase.cs +++ b/src/Service.Tests/SqlTests/SqlTestBase.cs @@ -261,11 +261,11 @@ protected static void SetUpSQLMetadataProvider() _sqlMetadataLogger); break; case TestCategory.MYSQL: - Mock>> mySqlQueryExecutorLogger = new(); + Mock> mySqlQueryExecutorLogger = new(); _queryBuilder = new MySqlQueryBuilder(); _defaultSchemaName = "mysql"; _dbExceptionParser = new MySqlDbExceptionParser(_runtimeConfigProvider); - _queryExecutor = new QueryExecutor( + _queryExecutor = new MySqlQueryExecutor( _runtimeConfigProvider, _dbExceptionParser, mySqlQueryExecutorLogger.Object); diff --git a/src/Service.Tests/Unittests/MySqlQueryExecutorUnitTests.cs b/src/Service.Tests/Unittests/MySqlQueryExecutorUnitTests.cs new file mode 100644 index 0000000000..15b33d916e --- /dev/null +++ b/src/Service.Tests/Unittests/MySqlQueryExecutorUnitTests.cs @@ -0,0 +1,89 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.DataApiBuilder.Service.Configurations; +using Azure.DataApiBuilder.Service.Resolvers; +using Azure.Identity; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using MySqlConnector; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + [TestClass, TestCategory(TestCategory.MYSQL)] + public class MySqlQueryExecutorUnitTests + { + /// + /// Validates managed identity token issued ONLY when connection string does not specify + /// User, Password, and Authentication method. + /// + [DataTestMethod] + [DataRow("Server =<>;Database=<>;User=xyz;Password=xxx", false, false, + DisplayName = "No managed identity access token when connection string specifies both User and Password.")] + [DataRow("Server =<>;Database=<>;User=xyz;Password=xxx;", false, false, + DisplayName = "No managed identity access token when connection string specifies User, Password.")] + [DataRow("Server =<>;Database=<>;User=xyz;", true, false, + DisplayName = "Managed identity access token from config used when connection string specifies User but not the Password.")] + [DataRow("Server =<>;Database=<>;User=xyz;", true, true, + DisplayName = "Managed identity access token from Default Azure Credential used when connection string specifies User but not the Password.")] + public async Task TestHandleManagedIdentityAccess( + string connectionString, + bool expectManagedIdentityAccessToken, + bool isDefaultAzureCredential) + { + RuntimeConfigProvider runtimeConfigProvider = TestHelper.GetRuntimeConfigProvider(TestCategory.MYSQL); + runtimeConfigProvider.GetRuntimeConfiguration().ConnectionString = connectionString; + Mock dbExceptionParser = new(runtimeConfigProvider, new HashSet()); + Mock> queryExecutorLogger = new(); + MySqlQueryExecutor mySqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser.Object, queryExecutorLogger.Object); + + const string DEFAULT_TOKEN = "Default access token"; + const string CONFIG_TOKEN = "Configuration controller access token"; + AccessToken testValidToken = new(accessToken: DEFAULT_TOKEN, expiresOn: DateTimeOffset.MaxValue); + if (expectManagedIdentityAccessToken) + { + if (isDefaultAzureCredential) + { + Mock dacMock = new(); + dacMock + .Setup(m => m.GetTokenAsync(It.IsAny(), + It.IsAny())) + .Returns(ValueTask.FromResult(testValidToken)); + mySqlQueryExecutor.AzureCredential = dacMock.Object; + } + else + { + runtimeConfigProvider.Initialize( + JsonSerializer.Serialize(runtimeConfigProvider.GetRuntimeConfiguration()), + schema: null, + connectionString: connectionString, + accessToken: CONFIG_TOKEN); + mySqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser.Object, queryExecutorLogger.Object); + } + } + + using MySqlConnection conn = new(connectionString); + await mySqlQueryExecutor.SetManagedIdentityAccessTokenIfAnyAsync(conn); + MySqlConnectionStringBuilder my = new(conn.ConnectionString); + + if (expectManagedIdentityAccessToken) + { + if (isDefaultAzureCredential) + { + Assert.AreEqual(expected: DEFAULT_TOKEN, actual: my.Password); + } + else + { + Assert.AreEqual(expected: CONFIG_TOKEN, actual: my.Password); + } + } + else + { + Assert.AreEqual(expected: "xxx", actual: my.Password); + } + } + } +} diff --git a/src/Service/Resolvers/MySqlQueryExecutor.cs b/src/Service/Resolvers/MySqlQueryExecutor.cs new file mode 100644 index 0000000000..68070d2b5e --- /dev/null +++ b/src/Service/Resolvers/MySqlQueryExecutor.cs @@ -0,0 +1,131 @@ +using System.Data.Common; +using System.Threading.Tasks; +using Azure.Core; +using Azure.DataApiBuilder.Service.Configurations; +using Azure.Identity; +using Microsoft.Extensions.Logging; +using MySqlConnector; + +namespace Azure.DataApiBuilder.Service.Resolvers +{ + /// + /// Specialized QueryExecutor for MySql mainly providing methods to + /// handle connecting to the database with a managed identity. + /// /// + public class MySqlQueryExecutor : QueryExecutor + { + // This is the same scope for any Azure SQL database that is + // required to request a default azure credential access token + // for a managed identity. + public const string DATABASE_SCOPE = @"https://ossrdbms-aad.database.windows.net/.default"; + + /// + /// The managed identity Access Token string obtained + /// from the configuration controller. + /// + private readonly string? _accessTokenFromController; + + public DefaultAzureCredential AzureCredential { get; set; } = new(); + + /// + /// The saved cached access token obtained from DefaultAzureCredentials + /// representing a managed identity. + /// + private AccessToken? _defaultAccessToken; + + private bool _attemptToSetAccessToken; + + public MySqlQueryExecutor( + RuntimeConfigProvider runtimeConfigProvider, + DbExceptionParser dbExceptionParser, + ILogger> logger) + : base(runtimeConfigProvider, dbExceptionParser, logger) + { + _accessTokenFromController = runtimeConfigProvider.ManagedIdentityAccessToken; + _attemptToSetAccessToken = + ShouldManagedIdentityAccessBeAttempted(runtimeConfigProvider.GetRuntimeConfiguration().ConnectionString); + } + + /// + /// Modifies the properties of the supplied connection to support managed identity access. + /// In the case of MySql, gets access token if deemed necessary and sets it on the connection. + /// The supplied connection is assumed to already have the same connection string + /// provided in the runtime configuration. + /// + /// The supplied connection to modify for managed identity access. + public override async Task SetManagedIdentityAccessTokenIfAnyAsync(DbConnection conn) + { + // Only attempt to get the access token if the connection string is in the appropriate format + if (_attemptToSetAccessToken) + { + + // If the configuration controller provided a managed identity access token use that, + // else use the default saved access token if still valid. + // Get a new token only if the saved token is null or expired. + string? accessToken = _accessTokenFromController ?? + (IsDefaultAccessTokenValid() ? + ((AccessToken)_defaultAccessToken!).Token : + await GetAccessTokenAsync()); + + if (accessToken is not null) + { + MySqlConnectionStringBuilder connstr = new(conn.ConnectionString) + { + Password = accessToken + }; + conn.ConnectionString = connstr.ConnectionString; + } + } + } + + /// + /// Determines if managed identity access should be attempted or not. + /// It should only be attempted, + /// 1. If none of UserID, Password or Authentication + /// method are specified in the connection string since they have higher precedence + /// and any attempt to use an access token in their presence would lead to + /// a System.InvalidOperationException. + /// 2. It is NOT a Windows Integrated Security scenario. + /// + private static bool ShouldManagedIdentityAccessBeAttempted(string connString) + { + MySqlConnectionStringBuilder connStringBuilder = new(connString); + return !string.IsNullOrEmpty(connStringBuilder.UserID) && + string.IsNullOrEmpty(connStringBuilder.Password); + } + + /// + /// Determines if the saved default azure credential's access token is valid and not expired. + /// + /// True if valid, false otherwise. + private bool IsDefaultAccessTokenValid() + { + return _defaultAccessToken is not null && + ((AccessToken)_defaultAccessToken).ExpiresOn.CompareTo(System.DateTimeOffset.Now) > 0; + } + + /// + /// Tries to get an access token using DefaultAzureCredentials. + /// Catches any CredentialUnavailableException and logs only a warning + /// since this is best effort. + /// + /// The string representation of the access token if found, + /// null otherwise. + private async Task GetAccessTokenAsync() + { + try + { + _defaultAccessToken = + await AzureCredential.GetTokenAsync( + new TokenRequestContext(new[] { DATABASE_SCOPE })); + } + catch (CredentialUnavailableException ex) + { + QueryExecutorLogger.LogWarning($"Attempt to retrieve a managed identity access token using DefaultAzureCredential" + + $" failed due to: \n{ex}"); + } + + return _defaultAccessToken?.Token; + } + } +} diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 272c06c67d..deccc3b56b 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -26,7 +26,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -using MySqlConnector; using Npgsql; namespace Azure.DataApiBuilder.Service @@ -110,7 +109,7 @@ public void ConfigureServices(IServiceCollection services) case DatabaseType.postgresql: return ActivatorUtilities.GetServiceOrCreateInstance>(serviceProvider); case DatabaseType.mysql: - return ActivatorUtilities.GetServiceOrCreateInstance>(serviceProvider); + return ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider); default: throw new NotSupportedException( runtimeConfig.DatabaseTypeNotSupportedMessage);