diff --git a/src/Core/Resolvers/QueryExecutor.cs b/src/Core/Resolvers/QueryExecutor.cs index 610100ea8d..fc76465764 100644 --- a/src/Core/Resolvers/QueryExecutor.cs +++ b/src/Core/Resolvers/QueryExecutor.cs @@ -412,7 +412,7 @@ public async Task ExtractResultSetFromDbDataReaderAsync(DbDataReader dbDataReader, List? args = null) { DbResultSet dbResultSet = new(resultProperties: GetResultPropertiesAsync(dbDataReader).Result ?? new()); - + long availableBytes = _maxResponseSizeBytes; while (await ReadAsync(dbDataReader)) { if (dbDataReader.HasRows) @@ -434,7 +434,16 @@ public async Task int colIndex = dbDataReader.GetOrdinal(columnName); if (!dbDataReader.IsDBNull(colIndex)) { - dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + if (!ConfigProvider.GetConfig().MaxResponseSizeLogicEnabled()) + { + dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + } + else + { + int columnSize = (int)schemaRow["ColumnSize"]; + availableBytes -= StreamDataIntoDbResultSetRow( + dbDataReader, dbResultSetRow, columnName, columnSize, ordinal: colIndex, availableBytes); + } } else { @@ -455,7 +464,7 @@ public DbResultSet ExtractResultSetFromDbDataReader(DbDataReader dbDataReader, List? args = null) { DbResultSet dbResultSet = new(resultProperties: GetResultProperties(dbDataReader) ?? 
new()); - + long availableBytes = _maxResponseSizeBytes; while (Read(dbDataReader)) { if (dbDataReader.HasRows) @@ -477,7 +486,16 @@ public DbResultSet int colIndex = dbDataReader.GetOrdinal(columnName); if (!dbDataReader.IsDBNull(colIndex)) { - dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + if (!ConfigProvider.GetConfig().MaxResponseSizeLogicEnabled()) + { + dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + } + else + { + int columnSize = (int)schemaRow["ColumnSize"]; + availableBytes -= StreamDataIntoDbResultSetRow( + dbDataReader, dbResultSetRow, columnName, columnSize, ordinal: colIndex, availableBytes); + } } else { @@ -649,10 +667,11 @@ public Dictionary GetResultProperties( /// DbDataReader. /// Available buffer. /// jsonString to read into. + /// Ordinal of column being read. /// size of data read in bytes. - internal int StreamData(DbDataReader dbDataReader, long availableSize, StringBuilder resultJsonString) + internal int StreamCharData(DbDataReader dbDataReader, long availableSize, StringBuilder resultJsonString, int ordinal) { - long resultFieldSize = dbDataReader.GetChars(ordinal: 0, dataOffset: 0, buffer: null, bufferOffset: 0, length: 0); + long resultFieldSize = dbDataReader.GetChars(ordinal: ordinal, dataOffset: 0, buffer: null, bufferOffset: 0, length: 0); // if the size of the field is less than available size, then we can read the entire field. // else we throw exception. @@ -661,12 +680,75 @@ internal int StreamData(DbDataReader dbDataReader, long availableSize, StringBui char[] buffer = new char[resultFieldSize]; // read entire field into buffer and reduce available size. - dbDataReader.GetChars(ordinal: 0, dataOffset: 0, buffer: buffer, bufferOffset: 0, length: buffer.Length); + dbDataReader.GetChars(ordinal: ordinal, dataOffset: 0, buffer: buffer, bufferOffset: 0, length: buffer.Length); resultJsonString.Append(buffer); return buffer.Length; } + /// + /// Reads data into byteObject. 
+ /// + /// DbDataReader. + /// Available buffer. + /// ordinal of column being read + /// bytes array to read result into. + /// size of data read in bytes. + internal int StreamByteData(DbDataReader dbDataReader, long availableSize, int ordinal, out byte[]? resultBytes) + { + long resultFieldSize = dbDataReader.GetBytes( + ordinal: ordinal, dataOffset: 0, buffer: null, bufferOffset: 0, length: 0); + + // if the size of the field is less than available size, then we can read the entire field. + // else we throw exception. + ValidateSize(availableSize, resultFieldSize); + + resultBytes = new byte[resultFieldSize]; + + dbDataReader.GetBytes(ordinal: ordinal, dataOffset: 0, buffer: resultBytes, bufferOffset: 0, length: resultBytes.Length); + + return resultBytes.Length; + } + + /// + /// Streams a column into the dbResultSetRow + /// + /// DbDataReader + /// Result set row to read into + /// Available bytes to read. + /// columnName to read + /// ordinal of column. + /// size of data read in bytes + internal int StreamDataIntoDbResultSetRow(DbDataReader dbDataReader, DbResultSetRow dbResultSetRow, string columnName, int columnSize, int ordinal, long availableBytes) + { + Type systemType = dbDataReader.GetFieldType(ordinal); + int dataRead; + + if (systemType == typeof(string)) + { + StringBuilder jsonString = new(); + dataRead = StreamCharData( + dbDataReader: dbDataReader, availableSize: availableBytes, resultJsonString: jsonString, ordinal: ordinal); + + dbResultSetRow.Columns.Add(columnName, jsonString.ToString()); + } + else if (systemType == typeof(byte[])) + { + dataRead = StreamByteData( + dbDataReader: dbDataReader, availableSize: availableBytes, ordinal: ordinal, out byte[]? 
result); + + dbResultSetRow.Columns.Add(columnName, result); + } + else + { + dataRead = columnSize; + ValidateSize(availableBytes, dataRead); + dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + } + + return dataRead; + } + /// /// This function reads the data from the DbDataReader and returns a JSON string. /// 1. MaxResponseSizeLogicEnabled is used like a feature flag. @@ -698,7 +780,9 @@ private async Task GetJsonStringFromDbReader(DbDataReader dbDataReader) long availableSize = _maxResponseSizeBytes; while (await ReadAsync(dbDataReader)) { - availableSize -= StreamData(dbDataReader, availableSize, jsonString); + // We only have a single column and hence when streaming data, we pass in 0 as the ordinal. + availableSize -= StreamCharData( + dbDataReader: dbDataReader, availableSize: availableSize, resultJsonString: jsonString, ordinal: 0); } } diff --git a/src/Service.Tests/Unittests/SqlQueryExecutorUnitTests.cs b/src/Service.Tests/Unittests/SqlQueryExecutorUnitTests.cs index d507e849cb..06551e091b 100644 --- a/src/Service.Tests/Unittests/SqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/Unittests/SqlQueryExecutorUnitTests.cs @@ -364,8 +364,10 @@ public void ValidateStreamingLogicAsync(int readDataLoops, bool exceptionExpecte int availableSize = (int)runtimeConfig.MaxResponseSizeMB() * 1024 * 1024; for (int i = 0; i < readDataLoops; i++) { - availableSize -= msSqlQueryExecutor.StreamData(dbDataReader: dbDataReader.Object, availableSize: availableSize, resultJsonString: new()); + availableSize -= msSqlQueryExecutor.StreamCharData( + dbDataReader: dbDataReader.Object, availableSize: availableSize, resultJsonString: new(), ordinal: 0); } + } catch (DataApiBuilderException ex) { @@ -375,6 +377,76 @@ public void ValidateStreamingLogicAsync(int readDataLoops, bool exceptionExpecte } } + /// + /// Validates streaming logic for QueryExecutor + /// In this test the streaming logic for stored procedures is tested. 
+ /// The test tries to validate the streaming across different column types (Byte, string, int etc) + /// Max available size is set to 4 MB, getChars and getBytes are mocked to return 1MB per read. + /// Exception should be thrown in test cases where we go above 4MB. + /// + [DataTestMethod, TestCategory(TestCategory.MSSQL)] + [DataRow(4, false, + DisplayName = "Max available size is set to 4MB. 4 data read loop iterations, 4 columns of size 1MB -> should successfully read because max-db-response-size-mb is 4MB")] + [DataRow(5, true, + DisplayName = "Max available size is set to 4MB. 5 data read loop iterations, 4 columns of size 1MB and one int read of 4 bytes -> Fails to read because max-db-response-size-mb is 4MB")] + public void ValidateStreamingLogicForStoredProcedures(int readDataLoops, bool exceptionExpected) + { + TestHelper.SetupDatabaseEnvironment(TestCategory.MSSQL); + string[] columnNames = { "NVarcharStringColumn1", "VarCharStringColumn2", "ImageByteColumn", "ImageByteColumn2", "IntColumn" }; + // 1MB -> 1024*1024 bytes, an int is 4 bytes + int[] columnSizeBytes = { 1024 * 1024, 1024 * 1024, 1024 * 1024, 1024 * 1024, 4 }; + + FileSystem fileSystem = new(); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Host: new(Cors: null, Authentication: null, MaxResponseSizeMB: 4) + ), + Entities: new(new Dictionary())); + + RuntimeConfigProvider runtimeConfigProvider = TestHelper.GenerateInMemoryRuntimeConfigProvider(runtimeConfig); + + Mock>> queryExecutorLogger = new(); + Mock httpContextAccessor = new(); + DbExceptionParser dbExceptionParser = new MsSqlDbExceptionParser(runtimeConfigProvider); + + // Instantiate the MsSqlQueryExecutor and Setup parameters for the query + MsSqlQueryExecutor msSqlQueryExecutor = new(runtimeConfigProvider, dbExceptionParser, 
queryExecutorLogger.Object, httpContextAccessor.Object); + + try + { + // Test for general queries and mutations + Mock dbDataReader = new(); + dbDataReader.Setup(d => d.HasRows).Returns(true); + dbDataReader.Setup(x => x.GetChars(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Returns(1024 * 1024); + dbDataReader.Setup(x => x.GetBytes(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Returns(1024 * 1024); + dbDataReader.Setup(x => x.GetFieldType(0)).Returns(typeof(string)); + dbDataReader.Setup(x => x.GetFieldType(1)).Returns(typeof(string)); + dbDataReader.Setup(x => x.GetFieldType(2)).Returns(typeof(byte[])); + dbDataReader.Setup(x => x.GetFieldType(3)).Returns(typeof(byte[])); + dbDataReader.Setup(x => x.GetFieldType(4)).Returns(typeof(int)); + int availableSizeBytes = runtimeConfig.MaxResponseSizeMB() * 1024 * 1024; + DbResultSetRow dbResultSetRow = new(); + for (int i = 0; i < readDataLoops; i++) + { + availableSizeBytes -= msSqlQueryExecutor.StreamDataIntoDbResultSetRow( + dbDataReader.Object, dbResultSetRow, columnName: columnNames[i], + columnSize: columnSizeBytes[i], ordinal: i, availableBytes: availableSizeBytes); + Assert.IsTrue(dbResultSetRow.Columns.ContainsKey(columnNames[i]), $"Column {columnNames[i]} should be successfully read and added to DbResultRow while streaming."); + } + } + catch (DataApiBuilderException ex) + { + Assert.IsTrue(exceptionExpected); + Assert.AreEqual(HttpStatusCode.RequestEntityTooLarge, ex.StatusCode); + Assert.AreEqual("The JSON result size exceeds max result size of 4MB. Please use pagination to reduce size of result.", ex.Message); + } + } + [TestCleanup] public void CleanupAfterEachTest() {