Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 69 additions & 8 deletions src/SqlAsyncCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static Microsoft.Azure.WebJobs.Extensions.Sql.Telemetry.Telemetry;
using Microsoft.Azure.WebJobs.Logging;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using MoreLinq;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Microsoft.Azure.WebJobs.Extensions.Sql.Telemetry;
using System.Diagnostics;

namespace Microsoft.Azure.WebJobs.Extensions.Sql
{
Expand Down Expand Up @@ -74,6 +77,7 @@ public SqlAsyncCollector(IConfiguration configuration, SqlAttribute attribute, I
this._configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
this._attribute = attribute ?? throw new ArgumentNullException(nameof(attribute));
this._logger = loggerFactory?.CreateLogger(LogCategories.Bindings) ?? throw new ArgumentNullException(nameof(loggerFactory));
TelemetryInstance.TrackCreate(CreateType.SqlAsyncCollector);
}

/// <summary>
Expand All @@ -89,7 +93,7 @@ public async Task AddAsync(T item, CancellationToken cancellationToken = default
if (item != null)
{
await this._rowLock.WaitAsync(cancellationToken);

TelemetryInstance.TrackEvent(TelemetryEventName.AddAsync);
try
{
this._rows.Add(item);
Expand All @@ -116,6 +120,7 @@ public async Task FlushAsync(CancellationToken cancellationToken = default)
{
if (this._rows.Count != 0)
{
TelemetryInstance.TrackEvent(TelemetryEventName.FlushAsync);
await this.UpsertRowsAsync(this._rows, this._attribute, this._configuration);
this._rows.Clear();
}
Expand All @@ -139,6 +144,7 @@ private async Task UpsertRowsAsync(IEnumerable<T> rows, SqlAttribute attribute,
{
using SqlConnection connection = SqlBindingUtilities.BuildConnection(attribute.ConnectionStringSetting, configuration);
await connection.OpenAsync();
Dictionary<string, string> props = connection.AsConnectionProps();

string fullTableName = attribute.CommandText;

Expand All @@ -150,8 +156,8 @@ private async Task UpsertRowsAsync(IEnumerable<T> rows, SqlAttribute attribute,

if (tableInfo == null)
{
TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheMiss, props);
tableInfo = await TableInformation.RetrieveTableInformationAsync(connection, fullTableName, this._logger);

var policy = new CacheItemPolicy
{
// Re-look up the primary key(s) after 10 minutes (they should not change very often!)
Expand All @@ -161,14 +167,22 @@ private async Task UpsertRowsAsync(IEnumerable<T> rows, SqlAttribute attribute,
this._logger.LogInformation($"DB and Table: {connection.Database}.{fullTableName}. Primary keys: [{string.Join(",", tableInfo.PrimaryKeys.Select(pk => pk.Name))}]. SQL Column and Definitions: [{string.Join(",", tableInfo.ColumnDefinitions)}]");
cachedTables.Set(cacheKey, tableInfo, policy);
}
else
{
TelemetryInstance.TrackEvent(TelemetryEventName.TableInfoCacheHit, props);
}

IEnumerable<string> extraProperties = GetExtraProperties(tableInfo.Columns);
if (extraProperties.Any())
{
string message = $"The following properties in {typeof(T)} do not exist in the table {fullTableName}: {string.Join(", ", extraProperties.ToArray())}.";
throw new InvalidOperationException(message);
var ex = new InvalidOperationException(message);
TelemetryInstance.TrackError(TelemetryErrorName.PropsNotExistOnTable, ex, props);
throw ex;
}

TelemetryInstance.TrackEvent(TelemetryEventName.UpsertStart, props);
var transactionSw = Stopwatch.StartNew();
int batchSize = 1000;
SqlTransaction transaction = connection.BeginTransaction();
try
Expand All @@ -177,24 +191,35 @@ private async Task UpsertRowsAsync(IEnumerable<T> rows, SqlAttribute attribute,
command.Connection = connection;
command.Transaction = transaction;
SqlParameter par = command.Parameters.Add(RowDataParameter, SqlDbType.NVarChar, -1);

int batchCount = 0;
var commandSw = Stopwatch.StartNew();
foreach (IEnumerable<T> batch in rows.Batch(batchSize))
{
batchCount++;
GenerateDataQueryForMerge(tableInfo, batch, out string newDataQuery, out string rowData);
command.CommandText = $"{newDataQuery} {tableInfo.Query};";
par.Value = rowData;
await command.ExecuteNonQueryAsync();
}
transaction.Commit();
var measures = new Dictionary<string, double>()
{
{ TelemetryMeasureName.BatchCount.ToString(), batchCount },
{ TelemetryMeasureName.TransactionDurationMs.ToString(), transactionSw.ElapsedMilliseconds },
{ TelemetryMeasureName.CommandDurationMs.ToString(), commandSw.ElapsedMilliseconds }
};
TelemetryInstance.TrackEvent(TelemetryEventName.UpsertEnd, props, measures);
}
catch (Exception ex)
{
try
{
TelemetryInstance.TrackError(TelemetryErrorName.Upsert, ex, props);
transaction.Rollback();
}
catch (Exception ex2)
{
TelemetryInstance.TrackError(TelemetryErrorName.UpsertRollback, ex2, props);
string message2 = $"Encountered exception during upsert and rollback.";
throw new AggregateException(message2, new List<Exception> { ex, ex2 });
}
Expand Down Expand Up @@ -418,10 +443,14 @@ WHEN NOT MATCHED THEN
/// <returns>TableInformation object containing primary keys, column types, etc.</returns>
public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConnection sqlConnection, string fullName, ILogger logger)
{
Dictionary<string, string> sqlConnProps = sqlConnection.AsConnectionProps();
TelemetryInstance.TrackEvent(TelemetryEventName.GetTableInfoStart, sqlConnProps);
var table = new SqlObject(fullName);

// Get case sensitivity from database collation (default to false if any exception occurs)
bool caseSensitive = false;
var tableInfoSw = Stopwatch.StartNew();
var caseSensitiveSw = Stopwatch.StartNew();
try
{
var cmdCollation = new SqlCommand(GetDatabaseCollationQuery(sqlConnection), sqlConnection);
Expand All @@ -430,16 +459,23 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
{
caseSensitive = GetCaseSensitivityFromCollation(rdr[Collation].ToString());
}
caseSensitiveSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetCaseSensitivity, caseSensitiveSw.ElapsedMilliseconds, sqlConnProps);
}
catch (Exception ex)
{
// Since this doesn't rethrow make sure we stop here too (don't use finally because we want the execution time to be the same here and in the
// overall event but we also only want to send the GetCaseSensitivity event if it succeeds)
caseSensitiveSw.Stop();
TelemetryInstance.TrackError(TelemetryErrorName.GetCaseSensitivity, ex, sqlConnProps);
logger.LogWarning($"Encountered exception while retrieving database collation: {ex}. Case insensitive behavior will be used by default.");
}

StringComparer comparer = caseSensitive ? StringComparer.Ordinal : StringComparer.OrdinalIgnoreCase;

// Get all column names and types
var columnDefinitionsFromSQL = new Dictionary<string, string>(comparer);
var columnDefinitionsSw = Stopwatch.StartNew();
try
{
var cmdColDef = new SqlCommand(GetColumnDefinitionsQuery(table), sqlConnection);
Expand All @@ -449,9 +485,12 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
string columnName = caseSensitive ? rdr[ColumnName].ToString() : rdr[ColumnName].ToString().ToLowerInvariant();
columnDefinitionsFromSQL.Add(columnName, rdr[ColumnDefinition].ToString());
}
columnDefinitionsSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetColumnDefinitions, columnDefinitionsSw.ElapsedMilliseconds, sqlConnProps);
}
catch (Exception ex)
{
TelemetryInstance.TrackError(TelemetryErrorName.GetColumnDefinitions, ex, sqlConnProps);
// Throw a custom error so that it's easier to decipher.
string message = $"Encountered exception while retrieving column names and types for table {table}. Cannot generate upsert command without them.";
throw new InvalidOperationException(message, ex);
Expand All @@ -460,11 +499,14 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
if (columnDefinitionsFromSQL.Count == 0)
{
string message = $"Table {table} does not exist.";
throw new InvalidOperationException(message);
var ex = new InvalidOperationException(message);
TelemetryInstance.TrackError(TelemetryErrorName.GetColumnDefinitionsTableDoesNotExist, ex, sqlConnProps);
throw ex;
}

// Query SQL for table Primary Keys
var primaryKeys = new List<PrimaryKey>();
var primaryKeysSw = Stopwatch.StartNew();
try
{
var cmd = new SqlCommand(GetPrimaryKeysQuery(table), sqlConnection);
Expand All @@ -474,9 +516,12 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
string columnName = caseSensitive ? rdr[ColumnName].ToString() : rdr[ColumnName].ToString().ToLowerInvariant();
primaryKeys.Add(new PrimaryKey(columnName, bool.Parse(rdr[IsIdentity].ToString())));
}
primaryKeysSw.Stop();
TelemetryInstance.TrackDuration(TelemetryEventName.GetPrimaryKeys, primaryKeysSw.ElapsedMilliseconds, sqlConnProps);
}
catch (Exception ex)
{
TelemetryInstance.TrackError(TelemetryErrorName.GetPrimaryKeys, ex, sqlConnProps);
// Throw a custom error so that it's easier to decipher.
string message = $"Encountered exception while retrieving primary keys for table {table}. Cannot generate upsert command without them.";
throw new InvalidOperationException(message, ex);
Expand All @@ -485,7 +530,9 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
if (!primaryKeys.Any())
{
string message = $"Did not retrieve any primary keys for {table}. Cannot generate upsert command without them.";
throw new InvalidOperationException(message);
var ex = new InvalidOperationException(message);
TelemetryInstance.TrackError(TelemetryErrorName.NoPrimaryKeys, ex, sqlConnProps);
throw ex;
}

// Match SQL Primary Key column names to POCO field/property objects. Ensure none are missing.
Expand All @@ -500,12 +547,26 @@ public static async Task<TableInformation> RetrieveTableInformationAsync(SqlConn
if (!hasIdentityColumnPrimaryKeys && missingPrimaryKeysFromPOCO.Any())
{
string message = $"All primary keys for SQL table {table} need to be found in '{typeof(T)}.' Missing primary keys: [{string.Join(",", missingPrimaryKeysFromPOCO)}]";
throw new InvalidOperationException(message);
var ex = new InvalidOperationException(message);
TelemetryInstance.TrackError(TelemetryErrorName.MissingPrimaryKeys, ex, sqlConnProps);
throw ex;
}

// If any identity columns aren't included in the object then we have to generate a basic insert since the merge statement expects all primary key
// columns to exist. (the merge statement can handle nullable columns though if those exist)
string query = hasIdentityColumnPrimaryKeys && missingPrimaryKeysFromPOCO.Any() ? GetInsertQuery(table) : GetMergeQuery(primaryKeys, table, comparison);
bool usingInsertQuery = hasIdentityColumnPrimaryKeys && missingPrimaryKeysFromPOCO.Any();
string query = usingInsertQuery ? GetInsertQuery(table) : GetMergeQuery(primaryKeys, table, comparison);

tableInfoSw.Stop();
var durations = new Dictionary<string, double>()
{
{ TelemetryMeasureName.GetCaseSensitivityDurationMs.ToString(), caseSensitiveSw.ElapsedMilliseconds },
{ TelemetryMeasureName.GetColumnDefinitionsDurationMs.ToString(), columnDefinitionsSw.ElapsedMilliseconds },
{ TelemetryMeasureName.GetPrimaryKeysDurationMs.ToString(), primaryKeysSw.ElapsedMilliseconds }
};
sqlConnProps.Add(TelemetryPropertyName.QueryType.ToString(), usingInsertQuery ? "insert" : "merge");
sqlConnProps.Add(TelemetryPropertyName.HasIdentityColumn.ToString(), hasIdentityColumnPrimaryKeys.ToString());
TelemetryInstance.TrackDuration(TelemetryEventName.GetTableInfoEnd, tableInfoSw.ElapsedMilliseconds, sqlConnProps, durations);
return new TableInformation(primaryKeyFields, columnDefinitionsFromSQL, comparer, query);
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/SqlBindingConfigProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using Microsoft.Azure.WebJobs.Description;
using static Microsoft.Azure.WebJobs.Extensions.Sql.SqlConverters;
using static Microsoft.Azure.WebJobs.Extensions.Sql.Telemetry.Telemetry;
using Microsoft.Azure.WebJobs.Host.Bindings;
using Microsoft.Azure.WebJobs.Host.Config;
using Microsoft.Extensions.Configuration;
Expand Down Expand Up @@ -45,7 +46,7 @@ public void Initialize(ExtensionConfigContext context)
{
throw new ArgumentNullException(nameof(context));
}
Telemetry.Telemetry.Instance.Initialize(this._configuration, this._loggerFactory);
TelemetryInstance.Initialize(this._configuration, this._loggerFactory);
#pragma warning disable CS0618 // Fine to use this for our stuff
FluentBindingRule<SqlAttribute> inputOutputRule = context.AddBindingRule<SqlAttribute>();
var converter = new SqlConverter(this._configuration);
Expand Down
Loading