Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.WebJobs.Extensions.Sql.csproj
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.1</TargetFramework>
<TargetFramework>netstandard2.0</TargetFramework>
<Description>SQL binding extension for Azure Functions</Description>
<Company>Microsoft</Company>
<Authors>Microsoft</Authors>
Expand Down
183 changes: 96 additions & 87 deletions src/SqlAsyncCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,90 +147,92 @@ public async Task FlushAsync(CancellationToken cancellationToken = default)
/// <param name="configuration"> Used to build up the connection </param>
private async Task UpsertRowsAsync(IEnumerable<T> rows, SqlAttribute attribute, IConfiguration configuration)
{
using SqlConnection connection = SqlBindingUtilities.BuildConnection(attribute.ConnectionStringSetting, configuration);
await connection.OpenAsync();
Dictionary<string, string> props = connection.AsConnectionProps();
using (SqlConnection connection = SqlBindingUtilities.BuildConnection(attribute.ConnectionStringSetting, configuration))
{
await connection.OpenAsync();
Dictionary<string, string> props = connection.AsConnectionProps();

string fullTableName = attribute.CommandText;
string fullTableName = attribute.CommandText;

// Include the connection string hash as part of the key in case this customer has the same table in two different Sql Servers
string cacheKey = $"{connection.ConnectionString.GetHashCode()}-{fullTableName}";
// Include the connection string hash as part of the key in case this customer has the same table in two different Sql Servers
string cacheKey = $"{connection.ConnectionString.GetHashCode()}-{fullTableName}";

ObjectCache cachedTables = MemoryCache.Default;
var tableInfo = cachedTables[cacheKey] as TableInformation;
ObjectCache cachedTables = MemoryCache.Default;
var tableInfo = cachedTables[cacheKey] as TableInformation;

if (tableInfo == null)
{
TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheMiss, props);
// set the columnNames for supporting T as JObject since it doesn't have columns in the memeber info.
tableInfo = await TableInformation.RetrieveTableInformationAsync(connection, fullTableName, this._logger, GetColumnNamesFromItem(rows.First()));
var policy = new CacheItemPolicy
if (tableInfo == null)
{
// Re-look up the primary key(s) after 10 minutes (they should not change very often!)
AbsoluteExpiration = DateTimeOffset.Now.AddMinutes(10)
};

this._logger.LogInformation($"DB and Table: {connection.Database}.{fullTableName}. Primary keys: [{string.Join(",", tableInfo.PrimaryKeys.Select(pk => pk.Name))}]. SQL Column and Definitions: [{string.Join(",", tableInfo.ColumnDefinitions)}]");
cachedTables.Set(cacheKey, tableInfo, policy);
}
else
{
TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheHit, props);
}
TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheMiss, props);
// set the columnNames for supporting T as JObject since it doesn't have columns in the memeber info.
tableInfo = await TableInformation.RetrieveTableInformationAsync(connection, fullTableName, this._logger, GetColumnNamesFromItem(rows.First()));
var policy = new CacheItemPolicy
{
// Re-look up the primary key(s) after 10 minutes (they should not change very often!)
AbsoluteExpiration = DateTimeOffset.Now.AddMinutes(10)
};

IEnumerable<string> extraProperties = GetExtraProperties(tableInfo.Columns, rows.First());
if (extraProperties.Any())
{
string message = $"The following properties in {typeof(T)} do not exist in the table {fullTableName}: {string.Join(", ", extraProperties.ToArray())}.";
var ex = new InvalidOperationException(message);
TelemetryInstance.TrackException(TelemetryErrorName.PropsNotExistOnTable, ex, props);
throw ex;
}
this._logger.LogInformation($"DB and Table: {connection.Database}.{fullTableName}. Primary keys: [{string.Join(",", tableInfo.PrimaryKeys.Select(pk => pk.Name))}]. SQL Column and Definitions: [{string.Join(",", tableInfo.ColumnDefinitions)}]");
cachedTables.Set(cacheKey, tableInfo, policy);
}
else
{
TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheHit, props);
}

TelemetryInstance.TrackEvent(TelemetryEventName.UpsertStart, props);
var transactionSw = Stopwatch.StartNew();
int batchSize = 1000;
SqlTransaction transaction = connection.BeginTransaction();
try
{
SqlCommand command = connection.CreateCommand();
command.Connection = connection;
command.Transaction = transaction;
SqlParameter par = command.Parameters.Add(RowDataParameter, SqlDbType.NVarChar, -1);
int batchCount = 0;
var commandSw = Stopwatch.StartNew();
foreach (IEnumerable<T> batch in rows.Batch(batchSize))
IEnumerable<string> extraProperties = GetExtraProperties(tableInfo.Columns, rows.First());
if (extraProperties.Any())
{
batchCount++;
GenerateDataQueryForMerge(tableInfo, batch, out string newDataQuery, out string rowData);
command.CommandText = $"{newDataQuery} {tableInfo.Query};";
par.Value = rowData;
await command.ExecuteNonQueryAsync();
string message = $"The following properties in {typeof(T)} do not exist in the table {fullTableName}: {string.Join(", ", extraProperties.ToArray())}.";
var ex = new InvalidOperationException(message);
TelemetryInstance.TrackException(TelemetryErrorName.PropsNotExistOnTable, ex, props);
throw ex;
}
transaction.Commit();
var measures = new Dictionary<string, double>()

TelemetryInstance.TrackEvent(TelemetryEventName.UpsertStart, props);
var transactionSw = Stopwatch.StartNew();
int batchSize = 1000;
SqlTransaction transaction = connection.BeginTransaction();
try
{
SqlCommand command = connection.CreateCommand();
command.Connection = connection;
command.Transaction = transaction;
SqlParameter par = command.Parameters.Add(RowDataParameter, SqlDbType.NVarChar, -1);
int batchCount = 0;
var commandSw = Stopwatch.StartNew();
foreach (IEnumerable<T> batch in rows.Batch(batchSize))
{
batchCount++;
GenerateDataQueryForMerge(tableInfo, batch, out string newDataQuery, out string rowData);
command.CommandText = $"{newDataQuery} {tableInfo.Query};";
par.Value = rowData;
await command.ExecuteNonQueryAsync();
}
transaction.Commit();
var measures = new Dictionary<string, double>()
{
{ TelemetryMeasureName.BatchCount.ToString(), batchCount },
{ TelemetryMeasureName.TransactionDurationMs.ToString(), transactionSw.ElapsedMilliseconds },
{ TelemetryMeasureName.CommandDurationMs.ToString(), commandSw.ElapsedMilliseconds }
};
TelemetryInstance.TrackEvent(TelemetryEventName.UpsertEnd, props, measures);
this._logger.LogInformation($"Upserted {rows.Count()} row(s) into database: {connection.Database} and table: {fullTableName}.");
}
catch (Exception ex)
{
try
{
TelemetryInstance.TrackException(TelemetryErrorName.Upsert, ex, props);
transaction.Rollback();
TelemetryInstance.TrackEvent(TelemetryEventName.UpsertEnd, props, measures);
this._logger.LogInformation($"Upserted {rows.Count()} row(s) into database: {connection.Database} and table: {fullTableName}.");
}
catch (Exception ex2)
catch (Exception ex)
{
TelemetryInstance.TrackException(TelemetryErrorName.UpsertRollback, ex2, props);
string message2 = $"Encountered exception during upsert and rollback.";
throw new AggregateException(message2, new List<Exception> { ex, ex2 });
try
{
TelemetryInstance.TrackException(TelemetryErrorName.Upsert, ex, props);
transaction.Rollback();
}
catch (Exception ex2)
{
TelemetryInstance.TrackException(TelemetryErrorName.UpsertRollback, ex2, props);
string message2 = $"Encountered exception during upsert and rollback.";
throw new AggregateException(message2, new List<Exception> { ex, ex2 });
}
throw;
}
throw;
}
}

Expand Down Expand Up @@ -376,7 +378,7 @@ public static bool GetCaseSensitivityFromCollation(string collation)
public static string GetDatabaseCollationQuery(SqlConnection sqlConnection)
{
return $@"
SELECT
SELECT
DATABASEPROPERTYEX('{sqlConnection.Database}', '{Collation}') AS {Collation};";
}

Expand Down Expand Up @@ -454,7 +456,7 @@ public static string GetMergeQuery(IList<PrimaryKey> primaryKeys, SqlObject tabl
}

string columnMatchingQuery = columnMatchingQueryBuilder.ToString().TrimEnd(',');
return @$"
return $@"
MERGE INTO {table.BracketQuotedFullName} WITH (HOLDLOCK)
AS ExistingData
USING {CteName}
Expand Down Expand Up @@ -490,17 +492,19 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
try
{
var cmdCollation = new SqlCommand(GetDatabaseCollationQuery(sqlConnection), sqlConnection);
using SqlDataReader rdr = await cmdCollation.ExecuteReaderAsync();
while (await rdr.ReadAsync())
using (SqlDataReader rdr = await cmdCollation.ExecuteReaderAsync())
{
caseSensitive = GetCaseSensitivityFromCollation(rdr[Collation].ToString());
while (await rdr.ReadAsync())
{
caseSensitive = GetCaseSensitivityFromCollation(rdr[Collation].ToString());
}
caseSensitiveSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetCaseSensitivity, caseSensitiveSw.ElapsedMilliseconds, sqlConnProps);
}
caseSensitiveSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetCaseSensitivity, caseSensitiveSw.ElapsedMilliseconds, sqlConnProps);
}
catch (Exception ex)
{
// Since this doesn't rethrow make sure we stop here too (don't use finally because we want the execution time to be the same here and in the
// Since this doesn't rethrow make sure we stop here too (don't use finally because we want the execution time to be the same here and in the
// overall event but we also only want to send the GetCaseSensitivity event if it succeeds)
caseSensitiveSw.Stop();
TelemetryInstance.TrackException(TelemetryErrorName.GetCaseSensitivity, ex, sqlConnProps);
Expand All @@ -515,14 +519,17 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
try
{
var cmdColDef = new SqlCommand(GetColumnDefinitionsQuery(table), sqlConnection);
using SqlDataReader rdr = await cmdColDef.ExecuteReaderAsync();
while (await rdr.ReadAsync())
using (SqlDataReader rdr = await cmdColDef.ExecuteReaderAsync())
{
string columnName = caseSensitive ? rdr[ColumnName].ToString() : rdr[ColumnName].ToString().ToLowerInvariant();
columnDefinitionsFromSQL.Add(columnName, rdr[ColumnDefinition].ToString());
while (await rdr.ReadAsync())
{
string columnName = caseSensitive ? rdr[ColumnName].ToString() : rdr[ColumnName].ToString().ToLowerInvariant();
columnDefinitionsFromSQL.Add(columnName, rdr[ColumnDefinition].ToString());
}
columnDefinitionsSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetColumnDefinitions, columnDefinitionsSw.ElapsedMilliseconds, sqlConnProps);
}
columnDefinitionsSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetColumnDefinitions, columnDefinitionsSw.ElapsedMilliseconds, sqlConnProps);

}
catch (Exception ex)
{
Expand All @@ -546,14 +553,16 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
try
{
var cmd = new SqlCommand(GetPrimaryKeysQuery(table), sqlConnection);
using SqlDataReader rdr = await cmd.ExecuteReaderAsync();
while (await rdr.ReadAsync())
using (SqlDataReader rdr = await cmd.ExecuteReaderAsync())
{
string columnName = caseSensitive ? rdr[ColumnName].ToString() : rdr[ColumnName].ToString().ToLowerInvariant();
primaryKeys.Add(new PrimaryKey(columnName, bool.Parse(rdr[IsIdentity].ToString())));
while (await rdr.ReadAsync())
{
string columnName = caseSensitive ? rdr[ColumnName].ToString() : rdr[ColumnName].ToString().ToLowerInvariant();
primaryKeys.Add(new PrimaryKey(columnName, bool.Parse(rdr[IsIdentity].ToString())));
}
primaryKeysSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetPrimaryKeys, primaryKeysSw.ElapsedMilliseconds, sqlConnProps);
}
primaryKeysSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetPrimaryKeys, primaryKeysSw.ElapsedMilliseconds, sqlConnProps);
}
catch (Exception ex)
{
Expand Down
8 changes: 5 additions & 3 deletions src/SqlAsyncEnumerable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ private async Task<bool> GetNextRowAsync()
{
if (this._reader == null)
{
using SqlCommand command = SqlBindingUtilities.BuildCommand(this._attribute, this._connection);
await command.Connection.OpenAsync();
this._reader = await command.ExecuteReaderAsync();
using (SqlCommand command = SqlBindingUtilities.BuildCommand(this._attribute, this._connection))
{
await command.Connection.OpenAsync();
this._reader = await command.ExecuteReaderAsync();
}
}
if (await this._reader.ReadAsync())
{
Expand Down
22 changes: 12 additions & 10 deletions src/SqlConverters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,20 @@ async Task<string> IAsyncConverter<SqlAttribute, string>.ConvertAsync(SqlAttribu
/// <returns></returns>
public virtual async Task<string> BuildItemFromAttributeAsync(SqlAttribute attribute)
{
using SqlConnection connection = SqlBindingUtilities.BuildConnection(attribute.ConnectionStringSetting, this._configuration);
using (SqlConnection connection = SqlBindingUtilities.BuildConnection(attribute.ConnectionStringSetting, this._configuration))
// Ideally, we would like to move away from using SqlDataAdapter both here and in the
// SqlAsyncCollector since it does not support asynchronous operations.
// There is a GitHub issue open to track this
using var adapter = new SqlDataAdapter();
using SqlCommand command = SqlBindingUtilities.BuildCommand(attribute, connection);
adapter.SelectCommand = command;
await connection.OpenAsync();
var dataTable = new DataTable();
adapter.Fill(dataTable);
this._logger.LogInformation($"{dataTable.Rows.Count} row(s) queried from database: {connection.Database} using Command: {command.CommandText}");
return JsonConvert.SerializeObject(dataTable);
using (var adapter = new SqlDataAdapter())
using (SqlCommand command = SqlBindingUtilities.BuildCommand(attribute, connection))
{
adapter.SelectCommand = command;
await connection.OpenAsync();
var dataTable = new DataTable();
adapter.Fill(dataTable);
this._logger.LogInformation($"{dataTable.Rows.Count} row(s) queried from database: {connection.Database} using Command: {command.CommandText}");
return JsonConvert.SerializeObject(dataTable);
}

}

IAsyncEnumerable<T> IConverter<SqlAttribute, IAsyncEnumerable<T>>.Convert(SqlAttribute attribute)
Expand Down
Loading