Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Service.Tests/SqlTests/SqlTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,11 @@ protected static void SetUpSQLMetadataProvider()
_sqlMetadataLogger);
break;
case TestCategory.MYSQL:
Mock<ILogger<QueryExecutor<MySqlConnection>>> mySqlQueryExecutorLogger = new();
Mock<ILogger<MySqlQueryExecutor>> mySqlQueryExecutorLogger = new();
_queryBuilder = new MySqlQueryBuilder();
_defaultSchemaName = "mysql";
_dbExceptionParser = new MySqlDbExceptionParser(_runtimeConfigProvider);
_queryExecutor = new QueryExecutor<MySqlConnection>(
_queryExecutor = new MySqlQueryExecutor(
_runtimeConfigProvider,
_dbExceptionParser,
mySqlQueryExecutorLogger.Object);
Expand Down
89 changes: 89 additions & 0 deletions src/Service.Tests/Unittests/MySqlQueryExecutorUnitTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.DataApiBuilder.Service.Configurations;
using Azure.DataApiBuilder.Service.Resolvers;
using Azure.Identity;
using Microsoft.Extensions.Logging;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
using MySqlConnector;

namespace Azure.DataApiBuilder.Service.Tests.UnitTests
{
[TestClass, TestCategory(TestCategory.MYSQL)]
public class MySqlQueryExecutorUnitTests
{
/// <summary>
/// Validates managed identity token issued ONLY when connection string does not specify
/// User, Password, and Authentication method.
/// </summary>
[DataTestMethod]
[DataRow("Server =<>;Database=<>;User=xyz;Password=xxx", false, false,
DisplayName = "No managed identity access token when connection string specifies both User and Password.")]
[DataRow("Server =<>;Database=<>;User=xyz;Password=xxx;", false, false,
DisplayName = "No managed identity access token when connection string specifies User, Password.")]
[DataRow("Server =<>;Database=<>;User=xyz;", true, false,
DisplayName = "Managed identity access token from config used when connection string specifies User but not the Password.")]
[DataRow("Server =<>;Database=<>;User=xyz;", true, true,
DisplayName = "Managed identity access token from Default Azure Credential used when connection string specifies User but not the Password.")]
public async Task TestHandleManagedIdentityAccess(
string connectionString,
bool expectManagedIdentityAccessToken,
bool isDefaultAzureCredential)
{
RuntimeConfigProvider runtimeConfigProvider = TestHelper.GetRuntimeConfigProvider(TestCategory.MYSQL);
runtimeConfigProvider.GetRuntimeConfiguration().ConnectionString = connectionString;
Mock<DbExceptionParser> dbExceptionParser = new(runtimeConfigProvider, new HashSet<string>());
Mock<ILogger<MySqlQueryExecutor>> queryExecutorLogger = new();
MySqlQueryExecutor mySqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser.Object, queryExecutorLogger.Object);

const string DEFAULT_TOKEN = "Default access token";
const string CONFIG_TOKEN = "Configuration controller access token";
AccessToken testValidToken = new(accessToken: DEFAULT_TOKEN, expiresOn: DateTimeOffset.MaxValue);
if (expectManagedIdentityAccessToken)
{
if (isDefaultAzureCredential)
{
Mock<DefaultAzureCredential> dacMock = new();
dacMock
.Setup(m => m.GetTokenAsync(It.IsAny<TokenRequestContext>(),
It.IsAny<System.Threading.CancellationToken>()))
.Returns(ValueTask.FromResult(testValidToken));
mySqlQueryExecutor.AzureCredential = dacMock.Object;
}
else
{
runtimeConfigProvider.Initialize(
JsonSerializer.Serialize(runtimeConfigProvider.GetRuntimeConfiguration()),
schema: null,
connectionString: connectionString,
accessToken: CONFIG_TOKEN);
mySqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser.Object, queryExecutorLogger.Object);
}
}

using MySqlConnection conn = new(connectionString);
await mySqlQueryExecutor.SetManagedIdentityAccessTokenIfAnyAsync(conn);
MySqlConnectionStringBuilder my = new(conn.ConnectionString);

if (expectManagedIdentityAccessToken)
{
if (isDefaultAzureCredential)
{
Assert.AreEqual(expected: DEFAULT_TOKEN, actual: my.Password);
}
else
{
Assert.AreEqual(expected: CONFIG_TOKEN, actual: my.Password);
}
}
else
{
Assert.AreEqual(expected: "xxx", actual: my.Password);
}
}
}
}
131 changes: 131 additions & 0 deletions src/Service/Resolvers/MySqlQueryExecutor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
using System.Data.Common;
using System.Threading.Tasks;
using Azure.Core;
using Azure.DataApiBuilder.Service.Configurations;
using Azure.Identity;
using Microsoft.Extensions.Logging;
using MySqlConnector;

namespace Azure.DataApiBuilder.Service.Resolvers
{
/// <summary>
/// Specialized QueryExecutor for MySql mainly providing methods to
/// handle connecting to the database with a managed identity.
/// /// </summary>
public class MySqlQueryExecutor : QueryExecutor<MySqlConnection>
{
// This is the same scope for any Azure SQL database that is
// required to request a default azure credential access token
// for a managed identity.
public const string DATABASE_SCOPE = @"https://ossrdbms-aad.database.windows.net/.default";

/// <summary>
/// The managed identity Access Token string obtained
/// from the configuration controller.
/// </summary>
private readonly string? _accessTokenFromController;

public DefaultAzureCredential AzureCredential { get; set; } = new();

/// <summary>
/// The saved cached access token obtained from DefaultAzureCredentials
/// representing a managed identity.
/// </summary>
private AccessToken? _defaultAccessToken;

private bool _attemptToSetAccessToken;

public MySqlQueryExecutor(
RuntimeConfigProvider runtimeConfigProvider,
DbExceptionParser dbExceptionParser,
ILogger<QueryExecutor<MySqlConnection>> logger)
: base(runtimeConfigProvider, dbExceptionParser, logger)
{
_accessTokenFromController = runtimeConfigProvider.ManagedIdentityAccessToken;
_attemptToSetAccessToken =
ShouldManagedIdentityAccessBeAttempted(runtimeConfigProvider.GetRuntimeConfiguration().ConnectionString);
}

/// <summary>
/// Modifies the properties of the supplied connection to support managed identity access.
/// In the case of MySql, gets access token if deemed necessary and sets it on the connection.
/// The supplied connection is assumed to already have the same connection string
/// provided in the runtime configuration.
/// </summary>
/// <param name="conn">The supplied connection to modify for managed identity access.</param>
public override async Task SetManagedIdentityAccessTokenIfAnyAsync(DbConnection conn)
{
// Only attempt to get the access token if the connection string is in the appropriate format
if (_attemptToSetAccessToken)
{

// If the configuration controller provided a managed identity access token use that,
// else use the default saved access token if still valid.
// Get a new token only if the saved token is null or expired.
string? accessToken = _accessTokenFromController ??
(IsDefaultAccessTokenValid() ?
((AccessToken)_defaultAccessToken!).Token :
await GetAccessTokenAsync());

if (accessToken is not null)
{
MySqlConnectionStringBuilder connstr = new(conn.ConnectionString)
{
Password = accessToken
};
conn.ConnectionString = connstr.ConnectionString;
}
}
}

/// <summary>
/// Determines if managed identity access should be attempted or not.
/// It should only be attempted,
/// 1. If none of UserID, Password or Authentication
/// method are specified in the connection string since they have higher precedence
/// and any attempt to use an access token in their presence would lead to
/// a System.InvalidOperationException.
/// 2. It is NOT a Windows Integrated Security scenario.
/// </summary>
private static bool ShouldManagedIdentityAccessBeAttempted(string connString)
{
MySqlConnectionStringBuilder connStringBuilder = new(connString);
return !string.IsNullOrEmpty(connStringBuilder.UserID) &&
string.IsNullOrEmpty(connStringBuilder.Password);
}

/// <summary>
/// Determines if the saved default azure credential's access token is valid and not expired.
/// </summary>
/// <returns>True if valid, false otherwise.</returns>
private bool IsDefaultAccessTokenValid()
{
return _defaultAccessToken is not null &&
((AccessToken)_defaultAccessToken).ExpiresOn.CompareTo(System.DateTimeOffset.Now) > 0;
}

/// <summary>
/// Tries to get an access token using DefaultAzureCredentials.
/// Catches any CredentialUnavailableException and logs only a warning
/// since this is best effort.
/// </summary>
/// <returns>The string representation of the access token if found,
/// null otherwise.</returns>
private async Task<string?> GetAccessTokenAsync()
{
try
{
_defaultAccessToken =
await AzureCredential.GetTokenAsync(
new TokenRequestContext(new[] { DATABASE_SCOPE }));
}
catch (CredentialUnavailableException ex)
{
QueryExecutorLogger.LogWarning($"Attempt to retrieve a managed identity access token using DefaultAzureCredential" +
$" failed due to: \n{ex}");
}

return _defaultAccessToken?.Token;
}
}
}
3 changes: 1 addition & 2 deletions src/Service/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using MySqlConnector;
using Npgsql;

namespace Azure.DataApiBuilder.Service
Expand Down Expand Up @@ -110,7 +109,7 @@ public void ConfigureServices(IServiceCollection services)
case DatabaseType.postgresql:
return ActivatorUtilities.GetServiceOrCreateInstance<QueryExecutor<NpgsqlConnection>>(serviceProvider);
case DatabaseType.mysql:
return ActivatorUtilities.GetServiceOrCreateInstance<QueryExecutor<MySqlConnection>>(serviceProvider);
return ActivatorUtilities.GetServiceOrCreateInstance<MySqlQueryExecutor>(serviceProvider);
default:
throw new NotSupportedException(
runtimeConfig.DatabaseTypeNotSupportedMessage);
Expand Down