Skip to content

Commit

Permalink
Making datasourcename mandatory during executequeryasync for multi-db…
Browse files Browse the repository at this point in the history
… scenario. (#2112)

## Why make this change?
During the call to executeQuery, if DataSource Name is not specified,
the query will execute against the default dB. Currently this option is
an optional argument. This can lead to caller not sending the DataSource
Name as the method does not require it.

If used incorrectly, it will then execute against the default db which
can lead to errors during multi-db scenario. For example, in
SQLMetadataProvider, we have the queryexecutor call executequeryAsync
for readonly columns. Currently, that does not specify the datasource as
it is not required by the method. This means that will always execute
against the default db, even though the datasourceName property is
available on the SQLMetadataprovider which tells you which db to execute
this query against.

currently we execute the query against the default db (this was done to
maintain backward compatibility for single source scenarios). In this
change, we make it mandatory to specify the datasource name during a
call to executequeryAsync. String.Empty or null can be sent if query is
to be executed against default db.

## What is this change?
1. Change the executeAsync call to explicitly require for the the
dbname. The executeAsync in the queryExecutor is the final call made to
actually execute the query against the underlying db. By this time in
both rest and GQL, we have either determined the dbname or gone with the
default db name.
2. Updates to all callers to specify dbname explicitly.

## How was this tested?
1. Existing test cases are updated to account for mandatory passing of
the datasourceName.
3. Integration test done.
  • Loading branch information
rohkhann committed Mar 18, 2024
1 parent af42ac2 commit 31429e5
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/Core/Resolvers/IQueryExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ public interface IQueryExecutor
/// in the DbDataReader obtained after executing the query.</param>
/// <param name="httpContext">Current request httpContext.</param>
/// <param name="args">List of string arguments to the DbDataReader handler.</param>
/// <param name="dataSourceName">dataSourceName against which to run query.</param>
/// <param name="dataSourceName">dataSourceName against which to run query. Can specify null or empty to run against default db.</param>
/// <returns>An object formed using the results of the query as returned by the given handler.</returns>
public Task<TResult?> ExecuteQueryAsync<TResult>(
string sqltext,
IDictionary<string, DbConnectionParam> parameters,
Func<DbDataReader, List<string>?, Task<TResult>>? dataReaderHandler,
string dataSourceName,
HttpContext? httpContext = null,
List<string>? args = null,
string dataSourceName = "");
List<string>? args = null);

/// <summary>
/// Extracts the rows from the given DbDataReader to populate
Expand Down
4 changes: 2 additions & 2 deletions src/Core/Resolvers/QueryExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ public class QueryExecutor<TConnection> : IQueryExecutor
string sqltext,
IDictionary<string, DbConnectionParam> parameters,
Func<DbDataReader, List<string>?, Task<TResult>>? dataReaderHandler,
string dataSourceName,
HttpContext? httpContext = null,
List<string>? args = null,
string dataSourceName = "")
List<string>? args = null)
{
if (string.IsNullOrEmpty(dataSourceName))
{
Expand Down
9 changes: 5 additions & 4 deletions src/Core/Resolvers/SqlMutationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ public class SqlMutationEngine : IMutationEngine
queryText,
executeQueryStructure.Parameters,
queryExecutor.GetJsonArrayAsync,
dataSourceName,
GetHttpContext());

transactionScope.Complete();
Expand Down Expand Up @@ -840,9 +841,9 @@ private FindRequestContext ConstructFindRequestContext(RestRequestContext contex
queryString,
queryParameters,
queryExecutor.ExtractResultSetFromDbDataReader,
dataSourceName,
GetHttpContext(),
primaryKeyExposedColumnNames.Count > 0 ? primaryKeyExposedColumnNames : sourceDefinition.PrimaryKey,
dataSourceName);
primaryKeyExposedColumnNames.Count > 0 ? primaryKeyExposedColumnNames : sourceDefinition.PrimaryKey);

dbResultSetRow = dbResultSet is not null ?
(dbResultSet.Rows.FirstOrDefault() ?? new DbResultSetRow()) : null;
Expand Down Expand Up @@ -991,9 +992,9 @@ private FindRequestContext ConstructFindRequestContext(RestRequestContext contex
queryString,
queryParameters,
queryExecutor.GetMultipleResultSetsIfAnyAsync,
dataSourceName,
GetHttpContext(),
new List<string> { prettyPrintPk, entityName },
dataSourceName);
new List<string> { prettyPrintPk, entityName });
}

private Dictionary<string, object?> PrepareParameters(RestRequestContext context)
Expand Down
3 changes: 2 additions & 1 deletion src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public override async Task PopulateTriggerMetadataForTable(string entityName, st
JsonArray? resultArray = await QueryExecutor.ExecuteQueryAsync(
sqltext: enumerateEnabledTriggers,
parameters: parameters,
dataReaderHandler: QueryExecutor.GetJsonArrayAsync);
dataReaderHandler: QueryExecutor.GetJsonArrayAsync,
dataSourceName: _dataSourceName);
using JsonDocument sqlResult = JsonDocument.Parse(resultArray!.ToJsonString());

foreach (JsonElement element in sqlResult.RootElement.EnumerateArray())
Expand Down
5 changes: 3 additions & 2 deletions src/Core/Services/MetadataProviders/SqlMetadataProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,8 @@ private async Task PopulateColumnDefinitionsWithReadOnlyFlag(string tableName, s
List<string>? readOnlyFields = await QueryExecutor.ExecuteQueryAsync(
sqltext: queryToGetReadOnlyColumns,
parameters: parameters,
dataReaderHandler: SummarizeReadOnlyFieldsMetadata);
dataReaderHandler: SummarizeReadOnlyFieldsMetadata,
dataSourceName: _dataSourceName);

if (readOnlyFields is not null && readOnlyFields.Count > 0)
{
Expand Down Expand Up @@ -1496,7 +1497,7 @@ private async Task PopulateForeignKeyDefinitionAsync()
// Gather all the referencing and referenced columns for each pair
// of referencing and referenced tables.
PairToFkDefinition = await QueryExecutor.ExecuteQueryAsync(
queryForForeignKeyInfo, parameters, SummarizeFkMetadata, httpContext: null, args: null, _dataSourceName);
queryForForeignKeyInfo, parameters, SummarizeFkMetadata, _dataSourceName, httpContext: null, args: null);

if (PairToFkDefinition is not null)
{
Expand Down
16 changes: 8 additions & 8 deletions src/Service.Tests/Caching/DabCacheServiceIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public async Task FirstCacheServiceInvocationCallsFactory()
IReadOnlyList<object> actualExecuteQueryAsyncArguments = mockQueryExecutor.Invocations[0].Arguments;
Assert.AreEqual(expected: queryMetadata.QueryText, actual: actualExecuteQueryAsyncArguments[0], message: "QueryText " + ERROR_FAILED_ARG_PASSTHROUGH);
Assert.AreEqual(expected: queryMetadata.QueryParameters, actual: actualExecuteQueryAsyncArguments[1], message: "Query parameters " + ERROR_FAILED_ARG_PASSTHROUGH);
Assert.AreEqual(expected: queryMetadata.DataSource, actual: actualExecuteQueryAsyncArguments[5], message: "Data source " + ERROR_FAILED_ARG_PASSTHROUGH);
Assert.AreEqual(expected: queryMetadata.DataSource, actual: actualExecuteQueryAsyncArguments[3], message: "Data source " + ERROR_FAILED_ARG_PASSTHROUGH);
}

/// <summary>
Expand Down Expand Up @@ -217,7 +217,7 @@ public async Task CacheServiceFactoryInvocationReturnsNull()
IReadOnlyList<object> actualExecuteQueryAsyncArguments = mockQueryExecutor.Invocations[0].Arguments;
Assert.AreEqual(expected: queryMetadata.QueryText, actual: actualExecuteQueryAsyncArguments[0], message: "QueryText " + ERROR_FAILED_ARG_PASSTHROUGH);
Assert.AreEqual(expected: queryMetadata.QueryParameters, actual: actualExecuteQueryAsyncArguments[1], message: "Query parameters " + ERROR_FAILED_ARG_PASSTHROUGH);
Assert.AreEqual(expected: queryMetadata.DataSource, actual: actualExecuteQueryAsyncArguments[5], message: "Data source " + ERROR_FAILED_ARG_PASSTHROUGH);
Assert.AreEqual(expected: queryMetadata.DataSource, actual: actualExecuteQueryAsyncArguments[3], message: "Data source " + ERROR_FAILED_ARG_PASSTHROUGH);

// Validate that the null value retrned by the factory method is propogated through to and returned by the cache service.
Assert.AreEqual(expected: null, actual: result, message: "Expected factory to return a null result.");
Expand Down Expand Up @@ -422,19 +422,19 @@ private static Mock<IQueryExecutor> CreateMockQueryExecutor(string rawJsonRespon
It.IsAny<string>(),
It.IsAny<IDictionary<string, DbConnectionParam>>(),
It.IsAny<Func<DbDataReader?, List<string>?, Task<JsonElement?>>>(),
It.IsAny<string>(),
httpContext,
args,
It.IsAny<string>()).Result)
args).Result)
.Returns((JsonElement?)null);
break;
case ExecutorReturnType.Exception:
mockQueryExecutor.Setup(x => x.ExecuteQueryAsync(
It.IsAny<string>(),
It.IsAny<IDictionary<string, DbConnectionParam>>(),
It.IsAny<Func<DbDataReader?, List<string>?, Task<JsonElement?>>>(),
It.IsAny<string>(),
httpContext,
args,
It.IsAny<string>()).Result)
args).Result)
.Throws(new DataApiBuilderException(
message: "DB ERROR",
statusCode: HttpStatusCode.InternalServerError,
Expand All @@ -446,9 +446,9 @@ private static Mock<IQueryExecutor> CreateMockQueryExecutor(string rawJsonRespon
It.IsAny<string>(),
It.IsAny<IDictionary<string, DbConnectionParam>>(),
It.IsAny<Func<DbDataReader?, List<string>?, Task<JsonElement?>>>(),
It.IsAny<string>(),
httpContext,
args,
It.IsAny<string>()).Result)
args).Result)
.Returns(executorJsonResponse.RootElement.Clone());
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private static async Task SetupDatabaseAsync(string compositeDbViewquery, string
public async Task TestCleanup()
{
string dropViewQuery = $"DROP VIEW IF EXISTS {_compositeViewName}";
await _queryExecutor.ExecuteQueryAsync<object>(dropViewQuery, parameters: null, dataReaderHandler: null);
await _queryExecutor.ExecuteQueryAsync<object>(dropViewQuery, parameters: null, dataReaderHandler: null, dataSourceName: string.Empty);
}
}
}
9 changes: 6 additions & 3 deletions src/Service.Tests/SqlTests/SqlTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private static async Task ExecuteQueriesOnDbAsync(List<string> customQueries)
{
foreach (string query in customQueries)
{
await _queryExecutor.ExecuteQueryAsync<object>(query, parameters: null, dataReaderHandler: null);
await _queryExecutor.ExecuteQueryAsync<object>(query, dataSourceName: string.Empty, parameters: null, dataReaderHandler: null);
}
}
}
Expand Down Expand Up @@ -367,6 +367,7 @@ protected static async Task ResetDbStateAsync()
{
await _queryExecutor.ExecuteQueryAsync<object>(
File.ReadAllText($"DatabaseSchema-{DatabaseEngine}.sql"),
dataSourceName: string.Empty,
parameters: null,
dataReaderHandler: null);
}
Expand All @@ -388,7 +389,8 @@ protected static async Task ResetDbStateAsync()
await _queryExecutor.ExecuteQueryAsync(
queryText,
parameters: null,
_queryExecutor.GetJsonResultAsync<JsonDocument>);
_queryExecutor.GetJsonResultAsync<JsonDocument>,
string.Empty);

result = sqlResult is not null ?
sqlResult.RootElement.ToString() :
Expand All @@ -400,7 +402,8 @@ protected static async Task ResetDbStateAsync()
await _queryExecutor.ExecuteQueryAsync(
queryText,
parameters: null,
_queryExecutor.GetJsonArrayAsync);
_queryExecutor.GetJsonArrayAsync,
string.Empty);
using JsonDocument sqlResult = resultArray is not null ? JsonDocument.Parse(resultArray.ToJsonString()) : null;
result = sqlResult is not null ? sqlResult.RootElement.ToString() : null;
}
Expand Down
10 changes: 6 additions & 4 deletions src/Service.Tests/Unittests/SqlQueryExecutorUnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,17 @@ Mock<MsSqlQueryExecutor> queryExecutor
It.IsAny<string>(),
It.IsAny<IDictionary<string, DbConnectionParam>>(),
It.IsAny<Func<DbDataReader, List<string>, Task<object>>>(),
It.IsAny<string>(),
It.IsAny<HttpContext>(),
It.IsAny<List<string>>(),
It.IsAny<string>())).CallBase();
It.IsAny<List<string>>())).CallBase();

DataApiBuilderException ex = await Assert.ThrowsExceptionAsync<DataApiBuilderException>(async () =>
{
await queryExecutor.Object.ExecuteQueryAsync<object>(
sqltext: string.Empty,
parameters: new Dictionary<string, DbConnectionParam>(),
dataReaderHandler: null,
dataSourceName: String.Empty,
httpContext: null,
args: null);
});
Expand Down Expand Up @@ -242,14 +243,15 @@ Mock<MsSqlQueryExecutor> queryExecutor
It.IsAny<string>(),
It.IsAny<IDictionary<string, DbConnectionParam>>(),
It.IsAny<Func<DbDataReader, List<string>, Task<object>>>(),
It.IsAny<string>(),
It.IsAny<HttpContext>(),
It.IsAny<List<string>>(),
It.IsAny<string>())).CallBase();
It.IsAny<List<string>>())).CallBase();

string sqltext = "SELECT * from books";

await queryExecutor.Object.ExecuteQueryAsync<object>(
sqltext: sqltext,
dataSourceName: String.Empty,
parameters: new Dictionary<string, DbConnectionParam>(),
dataReaderHandler: null,
args: null);
Expand Down

0 comments on commit 31429e5

Please sign in to comment.