diff --git a/src/Service.Tests/SqlTests/SqlTestBase.cs b/src/Service.Tests/SqlTests/SqlTestBase.cs index 2d6cd68a11..65abf290d9 100644 --- a/src/Service.Tests/SqlTests/SqlTestBase.cs +++ b/src/Service.Tests/SqlTests/SqlTestBase.cs @@ -84,7 +84,7 @@ protected static async Task InitializeTestFixture(TestContext context, List>> pgQueryExecutorLogger = new(); + Mock> pgQueryExecutorLogger = new(); _queryBuilder = new PostgresQueryBuilder(); _defaultSchemaName = "public"; _dbExceptionParser = new PostgreSqlDbExceptionParser(_runtimeConfigProvider); - _queryExecutor = new QueryExecutor( + _queryExecutor = new PostgreSqlQueryExecutor( _runtimeConfigProvider, _dbExceptionParser, pgQueryExecutorLogger.Object); diff --git a/src/Service.Tests/Unittests/PostgreSqlQueryExecutorUnitTests.cs b/src/Service.Tests/Unittests/PostgreSqlQueryExecutorUnitTests.cs new file mode 100644 index 0000000000..aacafc0fbf --- /dev/null +++ b/src/Service.Tests/Unittests/PostgreSqlQueryExecutorUnitTests.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.DataApiBuilder.Service.Configurations; +using Azure.DataApiBuilder.Service.Resolvers; +using Azure.Identity; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Npgsql; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + [TestClass, TestCategory(TestCategory.POSTGRESQL)] + public class PostgreSqlQueryExecutorUnitTests + { + /// + /// Validates managed identity token issued ONLY when connection string does not specify password + /// + [DataTestMethod] + [DataRow("Server =<>;Database=<>;Username=xyz;", false, false, + DisplayName = "No managed identity access token even when connection string specifies Username only.")] + [DataRow("Server =<>;Database=<>;Username=xyz;", true, false, + DisplayName = "Managed identity access token from config used when connection string specifies Username only.")] + [DataRow("Server =<>;Database=<>;Username=xyz;", true, true, + DisplayName = "Default managed identity access token used when connection string specifies Username only.")] + [DataRow("Server =<>;Database=<>;Password=xyz;", false, false, + DisplayName = "No managed identity access token when connection string specifies Password only.")] + [DataRow("Server =<>;Database=<>;Username=xyz;Password=xxx", false, false, + DisplayName = "No managed identity access token when connection string specifies both Username and Password.")] + public async Task TestHandleManagedIdentityAccess( + string connectionString, + bool expectManagedIdentityAccessToken, + bool isDefaultAzureCredential) + { + RuntimeConfigProvider runtimeConfigProvider = TestHelper.GetRuntimeConfigProvider(TestCategory.POSTGRESQL); + runtimeConfigProvider.GetRuntimeConfiguration().ConnectionString = connectionString; + Mock dbExceptionParser = new(runtimeConfigProvider, new HashSet()); + Mock> queryExecutorLogger = new(); + PostgreSqlQueryExecutor postgreSqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser.Object, queryExecutorLogger.Object); + + const string DEFAULT_TOKEN = "Default access token"; + const string CONFIG_TOKEN = "Configuration controller access token"; + AccessToken testValidToken = new(accessToken: DEFAULT_TOKEN, expiresOn: DateTimeOffset.MaxValue); + if (expectManagedIdentityAccessToken) + { + if (isDefaultAzureCredential) + { + Mock dacMock = new(); + dacMock + .Setup(m => m.GetTokenAsync(It.IsAny(), + It.IsAny())) + .Returns(ValueTask.FromResult(testValidToken)); + postgreSqlQueryExecutor.AzureCredential = dacMock.Object; + } + else + { + runtimeConfigProvider.Initialize( + JsonSerializer.Serialize(runtimeConfigProvider.GetRuntimeConfiguration()), + schema: null, + connectionString: connectionString, + accessToken: CONFIG_TOKEN); + postgreSqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser.Object, queryExecutorLogger.Object); + } + } + + using NpgsqlConnection conn = new(connectionString); + await postgreSqlQueryExecutor.SetManagedIdentityAccessTokenIfAnyAsync(conn); + NpgsqlConnectionStringBuilder connStringBuilder = new(conn.ConnectionString); + + if (expectManagedIdentityAccessToken) + { + if (isDefaultAzureCredential) + { + Assert.AreEqual(expected: DEFAULT_TOKEN, actual: connStringBuilder.Password); + } + else + { + Assert.AreEqual(expected: CONFIG_TOKEN, actual: connStringBuilder.Password); + } + } + else + { + Assert.AreEqual(connectionString, conn.ConnectionString); + } + } + } +} diff --git a/src/Service/Resolvers/PostgreSqlExecutor.cs b/src/Service/Resolvers/PostgreSqlExecutor.cs new file mode 100644 index 0000000000..870ba6e3e3 --- /dev/null +++ b/src/Service/Resolvers/PostgreSqlExecutor.cs @@ -0,0 +1,151 @@ +using System; +using System.Data.Common; +using System.Threading.Tasks; +using Azure.Core; +using Azure.DataApiBuilder.Service.Configurations; +using Azure.Identity; +using Microsoft.Extensions.Logging; +using Npgsql; + +namespace Azure.DataApiBuilder.Service.Resolvers +{ + /// + /// Specialized QueryExecutor for PostgreSql mainly providing methods to + /// handle connecting to the database with a managed identity. + /// for more info: https://learn.microsoft.com/EN-us/azure/postgresql/single-server/how-to-connect-with-managed-identity + /// + public class PostgreSqlQueryExecutor : QueryExecutor + { + // This is the same scope for any Azure Database for PostgreSQL that is + // required to request a default azure credential access token + // for a managed identity. + public const string DATABASE_SCOPE = @"https://ossrdbms-aad.database.windows.net/.default"; + + /// + /// The managed identity Access Token string obtained + /// from the configuration controller. + /// + private readonly string? _accessTokenFromController; + + public DefaultAzureCredential AzureCredential { get; set; } = new(); + + /// + /// The saved cached access token obtained from DefaultAzureCredentials + /// representing a managed identity. + /// + private AccessToken? _defaultAccessToken; + + private bool _attemptToSetAccessToken; + + public PostgreSqlQueryExecutor( + RuntimeConfigProvider runtimeConfigProvider, + DbExceptionParser dbExceptionParser, + ILogger> logger) + : base(runtimeConfigProvider, dbExceptionParser, logger) + { + _accessTokenFromController = runtimeConfigProvider.ManagedIdentityAccessToken; + _attemptToSetAccessToken = + ShouldManagedIdentityAccessBeAttempted(runtimeConfigProvider.GetRuntimeConfiguration().ConnectionString); + } + + /// + /// Modifies the properties of the supplied connection string to support managed identity access. + /// In the case of Postgres, if a default managed identity needs to be used, the password in the + /// connection needs to be replaced with the default access token. + /// + /// The supplied connection to modify for managed identity access. + public override async Task SetManagedIdentityAccessTokenIfAnyAsync(DbConnection conn) + { + // Only attempt to get the access token if the connection string is in the appropriate format + if (_attemptToSetAccessToken) + { + NpgsqlConnection sqlConn = (NpgsqlConnection)conn; + + // If the configuration controller provided a managed identity access token use that, + // else use the default saved access token if still valid. + // Get a new token only if the saved token is null or expired. + string? accessToken = _accessTokenFromController ?? + (IsDefaultAccessTokenValid() ? + ((AccessToken)_defaultAccessToken!).Token : + await GetAccessTokenAsync()); + + if (accessToken is not null) + { + NpgsqlConnectionStringBuilder newConnectionString = new(sqlConn.ConnectionString) + { + Password = accessToken + }; + sqlConn.ConnectionString = newConnectionString.ToString(); + } + } + } + + /// + /// Determines if managed identity access should be attempted or not. + /// It should only be attempted if the password is not provided + /// + private static bool ShouldManagedIdentityAccessBeAttempted(string connString) + { + NpgsqlConnectionStringBuilder connStringBuilder = new(connString); + return string.IsNullOrEmpty(connStringBuilder.Password); + } + + /// + /// Determines if the saved default azure credential's access token is valid and not expired. + /// + /// True if valid, false otherwise. + private bool IsDefaultAccessTokenValid() + { + return _defaultAccessToken is not null && + ((AccessToken)_defaultAccessToken).ExpiresOn.CompareTo(System.DateTimeOffset.Now) > 0; + } + + /// + /// Tries to get an access token using DefaultAzureCredentials. + /// Catches any CredentialUnavailableException and logs only a warning + /// since this is best effort. + /// + /// The string representation of the access token if found, + /// null otherwise. + private async Task GetAccessTokenAsync() + { + bool firstAttemptAtDefaultAccessToken = _defaultAccessToken is null; + + try + { + _defaultAccessToken = + await AzureCredential.GetTokenAsync( + new TokenRequestContext(new[] { DATABASE_SCOPE })); + } + // because there can be scenarios where password is not specified but + // default managed identity is not the intended method of authentication + // so a bunch of different exceptions could occur in that scenario + catch (Exception ex) + { + QueryExecutorLogger.LogWarning($"No password detected in the connection string. Attempt to retrieve " + + $"a managed identity access token using DefaultAzureCredential failed due to: \n{ex}\n" + + (firstAttemptAtDefaultAccessToken ? + $"If authentication with DefaultAzureCrendential is not intended, this warning can be safely ignored." : + string.Empty)); + + // the config doesn't contain an identity token + // and a default identity token cannot be obtained + // so the application should not attempt to set the token + // for future conntions + // note though that if a default access token has been previously + // obtained successfully (firstAttemptAtDefaultAccessToken == false) + // this might be a transitory failure don't disable attempts to set + // the token + // + // disabling the attempts is useful in scenarios where the user + // has a valid connection string without a password in it + if (firstAttemptAtDefaultAccessToken) + { + _attemptToSetAccessToken = false; + } + } + + return _defaultAccessToken?.Token; + } + } +} diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index deccc3b56b..bc0f6fc4cc 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -26,7 +26,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -using Npgsql; namespace Azure.DataApiBuilder.Service { @@ -107,7 +106,7 @@ public void ConfigureServices(IServiceCollection services) case DatabaseType.mssql: return ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider); case DatabaseType.postgresql: - return ActivatorUtilities.GetServiceOrCreateInstance>(serviceProvider); + return ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider); case DatabaseType.mysql: return ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider); default: