diff --git a/Directory.Packages.props b/Directory.Packages.props index e2940d07..46686db7 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -4,13 +4,14 @@ - + + - - + + @@ -23,7 +24,7 @@ - + diff --git a/docs/rules/DAP245.md b/docs/rules/DAP245.md new file mode 100644 index 00000000..674caf3d --- /dev/null +++ b/docs/rules/DAP245.md @@ -0,0 +1,24 @@ +# DAP245 + +It is possible for an identifier to be *technically valid* to use without quoting, yet highly confusing As an example, the following TSQL is *entirely valid*: + +``` sql +CREATE TABLE GO (GO int not null) +GO +INSERT GO ( GO ) VALUES (42) +GO +SELECT GO FROM GO +``` + +However, this can confuse readers and parsing tools. It would be *hugely* +advantageous to use delimited identifiers appropriately: + +``` sql +CREATE TABLE [GO] ([GO] int not null) +GO +INSERT [GO] ( [GO] ) VALUES (42) +GO +SELECT [GO] FROM [GO] +``` + +Or... maybe just use a different name? \ No newline at end of file diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs index a47041b1..3d481f48 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.Diagnostics.cs @@ -95,6 +95,7 @@ public static readonly DiagnosticDescriptor InterpolatedStringSqlExpression = SqlWarning("DAP241", "Interpolated string usage", "Data values should not be interpolated into SQL string - use parameters instead"), ConcatenatedStringSqlExpression = SqlWarning("DAP242", "Concatenated string usage", "Data values should not be concatenated into SQL string - use parameters instead"), InvalidDatepartToken = SqlWarning("DAP243", "Valid datepart token expected", "Date functions require a recognized datepart argument"), - SelectAggregateMismatch = SqlWarning("DAP244", "SELECT aggregate mismatch", "SELECT has mixture of aggregate and non-aggregate expressions"); + SelectAggregateMismatch = SqlWarning("DAP244", "SELECT aggregate mismatch", "SELECT has mixture of aggregate and non-aggregate expressions"), + DangerousNonDelimitedIdentifier = SqlWarning("DAP245", "Dangerous non-delimited identifier", "The identifier '{0}' can be confusing when not delimited; consider delimiting it with [...]"); } } diff --git a/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs b/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs index bb6356ea..a5c46a10 100644 --- a/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs +++ b/src/Dapper.AOT.Analyzers/Internal/DiagnosticTSqlProcessor.cs @@ -161,4 +161,7 @@ protected override void OnInvalidNullExpression(Location location) protected override void OnTrivialOperand(Location location) => OnDiagnostic(DapperAnalyzer.Diagnostics.TrivialOperand, location); + + protected override void OnDangerousNonDelimitedIdentifier(Location location, string name) + => OnDiagnostic(DapperAnalyzer.Diagnostics.DangerousNonDelimitedIdentifier, location, name); } diff --git a/src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs b/src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs new file mode 100644 index 00000000..d593da3c --- /dev/null +++ b/src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs @@ -0,0 +1,562 @@ +using Dapper.SqlAnalysis; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; +using System.Text; + +namespace Dapper.Internal.SqlParsing; + +internal readonly struct CommandVariable : IEquatable +{ + public CommandVariable(string name, int index) + { + Name = name; + Index = index; + } + public int Index { get; } + public string Name { get; } + + public ParameterKind Kind + { + get + { + var name = Name; + if (name is { Length: >= 2 } && GeneralSqlParser.IsParameterPrefix(name[0])) + { + if ((char.IsLetter(name[1]) || name[1] == '_')) + return ParameterKind.Nominal; + + if (char.IsNumber(name[1])) + return ParameterKind.Ordinal; + } + + return ParameterKind.Unknown; + } + } + + public override int GetHashCode() => Index; + public override bool Equals(object obj) => obj is CommandVariable other && Equals(other); + public bool Equals(CommandVariable other) + => Index == other.Index && Name == other.Name; + public override string ToString() => $"@{Index}:{Name}"; +} +internal readonly struct CommandBatch : IEquatable +{ + public ImmutableArray Variables { get; } + public string Sql { get; } + + public CommandBatch(string sql) : this(ImmutableArray.Empty, sql) { } + public CommandBatch(string sql, CommandVariable var0) : this(ImmutableArray.Create(var0), sql) { } + public CommandBatch(string sql, CommandVariable var0, CommandVariable var1) : this(ImmutableArray.Create(var0, var1), sql) { } + public CommandBatch(string sql, CommandVariable var0, CommandVariable var1, CommandVariable var2) : this(ImmutableArray.Create(var0, var1, var2), sql) { } + public CommandBatch(string sql, CommandVariable var0, CommandVariable var1, CommandVariable var2, CommandVariable var3) : this(ImmutableArray.Create(var0, var1, var2, var3), sql) { } + public CommandBatch(string sql, CommandVariable var0, CommandVariable var1, CommandVariable var2, CommandVariable var3, CommandVariable var4) : this(ImmutableArray.Create(var0, var1, var2, var3, var4), sql) { } + public static CommandBatch Create(string sql, params CommandVariable[] variables) + => new(ImmutableArray.Create(variables), sql); + public static CommandBatch Create(string sql, ImmutableArray variables) + => new(variables, sql); + // invert order to solve some .ctor ambiguity issues + private CommandBatch(ImmutableArray variables, string sql) + { + Sql = sql; + Variables = variables; + } + + public ParameterKind ParameterKind + { + get + { + if (Variables.IsDefaultOrEmpty) + { + return ParameterKind.NonParametrized; + } + var first = Variables[0].Kind; + foreach (var arg in Variables.AsSpan().Slice(1)) + { + if (arg.Kind != first) return ParameterKind.Mixed; + } + return first; + } + } + + public override int GetHashCode() => Sql.GetHashCode(); // args are a component of the sql; no need to hash them + public override string ToString() => Variables.IsDefaultOrEmpty ? Sql : + (Sql + " with " + string.Join(", ", Variables)); + + public override bool Equals(object obj) => obj is CommandBatch other && Equals(other); + + public bool Equals(CommandBatch other) + => Sql == other.Sql && Variables.SequenceEqual(other.Variables); + + public OrdinalResult TryMakeOrdinal(IList inputArgs, Func argName, Func argFactory, out IList args, out string sql, int argIndex = 0) + { + static bool TryFindByName(string name, IList inputArgs, Func argNameSelector, out T found) + { + if (string.IsNullOrWhiteSpace(name) || name.Length < 2 + || !GeneralSqlParser.IsParameterPrefix(name[0])) + { + // general preconditions for nominal match: looking for @x + // i.e. must have parameter symbol and at least one token character + found = default!; + return false; + + } + foreach (var arg in inputArgs) + { + var argName = argNameSelector(arg); + if (string.IsNullOrWhiteSpace(argName)) continue; // looking for nominal match + + // check for exact match including prefix, i.e. "@foo" vs "@foo" + if (string.Equals(name, argName, StringComparison.OrdinalIgnoreCase)) + { + found = arg; + return true; + } + // check for input name excluding prefix, i.e. "foo" vs detected "@foo" + // (when using Dapper, this is the normal usage) + if (argName.Length == name.Length - 1 && !GeneralSqlParser.IsParameterPrefix(argName[0]) + && name.EndsWith(argName, StringComparison.OrdinalIgnoreCase)) + { + found = arg; + return true; + } + } + + found = default!; + return false; + } + sql = Sql; + var kind = ParameterKind; + switch (kind) + { + case ParameterKind.NonParametrized: + args = []; + return OrdinalResult.NoChange; + case ParameterKind.Mixed: + args = inputArgs; + return OrdinalResult.MixedParameters; + case ParameterKind.Ordinal: + // TODO: rewrite, filtering and ordering; i.e. + // where Id = $4 and Name = $3 -- no mention of $1 or $2 + // could be + // where Id = $1 and Name = $2 + args = inputArgs; + return OrdinalResult.NoChange; + case ParameterKind.Nominal: + break; // below + default: + args = inputArgs; + return OrdinalResult.UnsupportedScenario; + } + + Debug.Assert(kind == ParameterKind.Nominal); + + var map = new Dictionary(Variables.Length, StringComparer.OrdinalIgnoreCase); + args = new List(); + var sb = new StringBuilder(sql); + int delta = 0; // change in length of string + foreach (var queryArg in Variables) + { + if (!map.TryGetValue(queryArg.Name, out var finalArg)) + { + if (!TryFindByName(queryArg.Name, inputArgs, argName, out var found)) + { + args = inputArgs; + return OrdinalResult.UnsupportedScenario; + } + finalArg = argFactory(found, argIndex++); + map.Add(queryArg.Name, finalArg); + args.Add(finalArg); + } + var newName = argName(finalArg); + // could potentially be more efficient with forwards-only write + sb.Remove(queryArg.Index + delta, queryArg.Name.Length); + sb.Insert(queryArg.Index + delta, newName); + delta += newName.Length - queryArg.Name.Length; + } + sql = sb.ToString(); + return sql == Sql ? OrdinalResult.NoChange : OrdinalResult.Success; + } + + internal static Func OrdinalNaming { get; } = (name, index) => $"${index + 1}"; +} + +internal enum ParameterKind +{ + NonParametrized, + Mixed, + Ordinal, // $1 + Nominal, // @foo + Unknown, +} +internal enum OrdinalResult +{ + NoChange, + MixedParameters, + Success, + UnsupportedScenario, +} + +internal static class GeneralSqlParser +{ + private enum ParseState + { + None, + Token, + Variable, + LineComment, + BlockComment, + String, + Whitespace, + } + + /// + /// Tokenize a sql fragment into batches, extracting the variables/locals in use + /// + /// This is a basic parse only; no syntax processing - just literals, identifiers, etc + public static List Parse(string sql, SqlSyntax syntax, bool strip = false) + { + // this is a basic first pass; TODO: rewrite using a forwards seek approach, i.e. + // "find first [@$:'"...] etc, copy backfill then search for end of that symbol and process + // accordingly + + int bIndex = 0, parenDepth = 0; + char[] buffer = ArrayPool.Shared.Rent(sql.Length + 1); + + char stringType = '\0'; + var state = ParseState.None; + int i = 0, elementStartbIndex = 0; + ImmutableArray.Builder? variables = null; + var result = new List(); + + bool BatchSemicolon() => syntax == SqlSyntax.PostgreSql; + + char LookAhead(int delta = 1) + { + var ci = i + delta; + return ci >= 0 && ci < sql.Length ? sql[ci] : '\0'; + } + char Last(int offset) + { + var ci = bIndex - (offset + 1); + return ci >= 0 && ci < bIndex ? buffer[ci] : '\0'; + } + char LookBehind(int delta = 1) => LookAhead(-delta); + void Discard() => bIndex--; + void NormalizeSpace() + { + if (strip) + { + if (bIndex > 1 && buffer[bIndex - 2] == ' ') + { + Discard(); + } + else + { + buffer[bIndex - 1] = ' '; + } + } + } + bool ActivateStringPrefix() + { + if (ElementLength() == 2) // already written, so: N'... E'... etc + { + stringType = Last(0); + return true; + }; + return false; + } + void SkipLeadingWhitespace(char v) + { + if (bIndex == 1 && ((v is '\r' or '\n') || strip)) + { + // always omit leading CRLFs; omit leading whitespace + // when stripping + Discard(); + } + else if (strip && Last(0) == ';') + { + Discard(); // don't write whitespace after ; + } + else + { + NormalizeSpace(); + } + } + int ElementLength() => bIndex - elementStartbIndex + 1; + + void FlushBatch() + { + if (IsGo()) bIndex -= 2; // don't retain the GO + + //bool removedSemicolon = false; + if ((strip || BatchSemicolon()) && Last(0) == ';') + { + Discard(); + //removedSemicolon = true; + } + + if (strip) // remove trailing whitespace + { + while (bIndex > 0 && char.IsWhiteSpace(buffer[bIndex - 1])) + { + bIndex--; + } + } + + if (!IsWhitespace()) // anything left? + { + //if (removedSemicolon) + //{ + // // reattach + // buffer[bIndex++] = ';'; + //} + + var batchSql = new string(buffer, 0, bIndex); + var args = variables is null ? ImmutableArray.Empty : variables.ToImmutable(); + result.Add(CommandBatch.Create(batchSql, args)); + } + // logical reset + bIndex = 0; + variables?.Clear(); + state = ParseState.None; + + // lose any same-line simple space between batches + while (LookAhead() is ' ' or '\t') + { + i++; // same as Advance();Discard(); + } + } + bool IsWhitespace() + { + if (bIndex == 0) return true; + for (int i = 0; i < bIndex; i++) + { + if (!char.IsWhiteSpace(buffer[i])) return false; + } + return true; + } + bool IsGo() + { + return syntax == SqlSyntax.SqlServer && ElementLength() == 2 + && Last(1) is 'g' or 'G' && Last(0) is 'o' or 'O'; + } + void FlushVariable() + { + int varLen = ElementLength(), varStart = bIndex - varLen; + var name = new string(buffer, varStart, varLen); + variables ??= ImmutableArray.CreateBuilder(); + variables.Add(new(name, varStart)); + } + + bool IsString(char c) => state == ParseState.String && stringType == c; + + bool IsSingleQuoteString() => state == ParseState.String && (stringType == '\'' || char.IsLetter(stringType)); + void Advance() => buffer[bIndex++] = sql[++i]; + + for (; i < sql.Length; i++) + { + var c = i == sql.Length ? ';' : sql[i]; // spoof a ; at the end to simplify end-of-block handling + + // detect end of GO token + if (state == ParseState.Token && !IsToken(c) && IsGo()) + { + FlushBatch(); // and keep going + } + else if (state == ParseState.Variable && !IsMidToken(c)) + { + FlushVariable(); + } + + // store by default, we'll backtrack in the rare scenarios that we don't want it + buffer[bIndex++] = sql[i]; + + switch (state) + { + case ParseState.Whitespace when char.IsWhiteSpace(c): // more contiguous whitespace + if (strip) Discard(); + else SkipLeadingWhitespace(c); + continue; + case ParseState.LineComment when c is '\r' or '\n': // end of line comment + case ParseState.BlockComment when c == '/' && LookBehind() == '*': // end of block comment + if (strip) Discard(); + else NormalizeSpace(); + state = ParseState.Whitespace; + continue; + case ParseState.BlockComment or ParseState.LineComment: // keep ignoring line comment + if (strip) Discard(); + continue; + // string-escape characters + case ParseState.String when c == '\'' && IsSingleQuoteString() && LookAhead() == '\'': // [?]'...''...' + case ParseState.String when c == '"' && IsString('"') && LookAhead() == '\"': // "...""..." + case ParseState.String when c == '\\' && (IsString('E') || IsString('e')) && LookAhead() != '\0' && AllowEscapedStrings(): // E'...\*...' + case ParseState.String when c == ']' && IsString('[') && LookAhead() == ']': // [...]]...] + // escaped or double-quote; move forwards immediately + Advance(); + continue; + // end string + case ParseState.String when c == '"' && IsString('"'): // "....." + case ParseState.String when c == ']' && IsString('['): // [.....] + case ParseState.String when c == '\'' && IsSingleQuoteString(): // [?]'....' + state = ParseState.None; + continue; + case ParseState.String: + // ongoing string content + continue; + case ParseState.Token when c == '\'' && ActivateStringPrefix(): // E'..., N'... etc + continue; + case ParseState.Token or ParseState.Variable when IsMidToken(c): + // ongoing token / variable content + continue; + case ParseState.Variable: // end of variable + case ParseState.Whitespace: // end of whitespace chunk + case ParseState.Token: // end of token + case ParseState.None: // not started + state = ParseState.None; + break; // make sure we still handle the value, below + default: + throw new InvalidOperationException($"Token kind not handled: {state}"); + } + + if (c == '-' && LookAhead() == '-') + { + state = ParseState.LineComment; + if (strip) Discard(); + continue; + } + if (c == '/' && LookAhead() == '*') + { + state = ParseState.BlockComment; + if (strip) Discard(); + continue; + } + + if (c == '(') parenDepth++; + if (c == ')') parenDepth--; + if (c == ';') + { + if (BatchSemicolon() && parenDepth == 0) + { + FlushBatch(); + continue; + } + + // otherwise end-statement + // (prevent unnecessary additional whitespace when stripping) + state = ParseState.Whitespace; + if (strip && Last(1) == ';') + { // squash down to just one + Discard(); + } + continue; + } + + if (char.IsWhiteSpace(c)) + { + SkipLeadingWhitespace(c); + state = ParseState.Whitespace; + continue; + } + + elementStartbIndex = bIndex; + + if (c is '"' or '\'' or '[') // TODO: '$' dollar quoting + { + // start a new string + state = ParseState.String; + stringType = c; + continue; + } + + if (c is '$' && AllowDollarQuotedStrings()) + { + TryReadDollarQuotedString(); + continue; + } + + if (IsParameterPrefix(c) + && IsToken(LookAhead()) && LookBehind() != c) // avoid @>, @@IDENTTIY etc + { + // start a new variable + state = ParseState.Variable; + continue; + } + + if (IsToken(c)) + { + // start a new token + state = ParseState.Token; + continue; + } + + // other arbitrary syntax - operators etc + } + + // deal with any remaining bits + if (state == ParseState.Variable) FlushVariable(); + if (BatchSemicolon()) + { + // spoof a final ; + buffer[bIndex++] = ';'; + } + FlushBatch(); + + ArrayPool.Shared.Return(buffer); + + return result; + + bool IsMidToken(char c) => IsToken(c) + || (syntax == SqlSyntax.PostgreSql && ( + (c == '$' && state != ParseState.Variable) // postgresql identifiers can contain $, but variables can't + || (c == '.' && state == ParseState.Variable) // postgresql mapped variables (only) can contain . + )); + + bool IsToken(char c) => c == '_' || char.IsLetterOrDigit(c); + + bool AllowEscapedStrings() => syntax == SqlSyntax.PostgreSql; + bool AllowDollarQuotedStrings() => syntax == SqlSyntax.PostgreSql; + + void TryReadDollarQuotedString() + { + // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-DOLLAR-QUOTING + + // A dollar-quoted string constant consists of a dollar sign ($), an optional “tag” of zero or more characters, + // another dollar sign, an arbitrary sequence of characters that makes up the string content, a dollar sign, + // the same tag that began this dollar quote, and a dollar sign. + + // The tag, if any, of a dollar-quoted string follows the same rules as an + // unquoted identifier, except that it cannot contain a dollar sign + + // note this is too complex to process in the iterative way; we'll switch to forwards looking + int len = 1; // $ + while (true) + { + var next = LookAhead(len++); + if (next == '$') break; // found end of marker + if (!IsToken(next)) return; // covers _, letters and digits + if (len == 2 && char.IsDigit(next)) return; // identifier rules; cannot start with digit + } + + var sqlSpan = sql.AsSpan(); + var hunt = sqlSpan.Slice(i, len); + var remaining = sqlSpan.Slice(i + len); + var found = remaining.IndexOf(hunt); + if (found < 0) return; // non-terminated; ignore + + var toCopy = len * 2 + found - 1; // we already copied the first $ + for (int j = 0; j < toCopy; j++) + { + Advance(); + } + + } + } + + internal static bool IsParameterPrefix(char c) + => SqlTools.ParameterPrefixCharacters.IndexOf(c) >= 0; +} + + diff --git a/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs b/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs index 875fa09c..5e15e454 100644 --- a/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs +++ b/src/Dapper.AOT.Analyzers/SqlAnalysis/TSqlProcessor.cs @@ -257,7 +257,8 @@ protected virtual void OnInvalidDatepartToken(Location location) => OnError($"Valid datepart token expected", location); protected virtual void OnTopWithOffset(Location location) => OnError($"TOP cannot be used when OFFSET is specified", location); - + protected virtual void OnDangerousNonDelimitedIdentifier(Location location, string name) + => OnError($"The identifier '{name}' can be confusing when not delimited", location); internal readonly struct Location { @@ -756,6 +757,15 @@ static bool IsAnyCaseInsensitive(string value, string[] options) "nanosecond", "ns" ]; + public override void Visit(Identifier node) + { + if (node.QuoteType == QuoteType.NotQuoted && string.Equals("GO", node.Value, StringComparison.OrdinalIgnoreCase)) + { + parser.OnDangerousNonDelimitedIdentifier(node, node.Value); + } + base.Visit(node); + } + private void ValidateDateArg(ScalarExpression value) { if (!(value is ColumnReferenceExpression col diff --git a/src/Dapper.AOT/CommandFactory.cs b/src/Dapper.AOT/CommandFactory.cs index ea0e0dda..cab68b2d 100644 --- a/src/Dapper.AOT/CommandFactory.cs +++ b/src/Dapper.AOT/CommandFactory.cs @@ -159,6 +159,57 @@ protected static bool TryRecycle(ref DbCommand? storage, DbCommand command) command.Transaction = null; return Interlocked.CompareExchange(ref storage, command, null) is null; } + + +#if NET6_0_OR_GREATER + /// + /// Provides an opportunity to recycle and reuse batch instances + /// + protected static bool TryRecycle(ref DbBatch? storage, DbBatch batch) + { + // detach and recycle + batch.Connection = null; + batch.Transaction = null; + return Interlocked.CompareExchange(ref storage, batch, null) is null; + } + + /// + /// Provides an opportunity to recycle and reuse batch instances + /// + public virtual bool TryRecycle(DbBatch batch) => false; +#endif + + /// + /// Creates and initializes new instances. + /// + public virtual DbParameter CreateNewParameter(in UnifiedCommand command) + => command.DefaultCreateParameter(); + + /// + /// Creates and initializes new instances. + /// + public virtual DbCommand CreateNewCommand(DbConnection connection) + => connection.CreateCommand(); + +#if NET6_0_OR_GREATER + /// + /// Creates and initializes new instances. + /// + public virtual DbBatch CreateNewBatch(DbConnection connection) + => connection.CreateBatch(); + + /// + /// Creates and initializes new instances. + /// + public virtual DbBatchCommand CreateNewCommand(DbBatch batch) + => batch.CreateBatchCommand(); +#endif + + + /// + /// Indicates where it is required to invoke post-operation logic to update parameter values. + /// + public virtual bool RequirePostProcess => false; } /// @@ -182,18 +233,13 @@ public class CommandFactory : CommandFactory public virtual DbCommand GetCommand(DbConnection connection, string sql, CommandType commandType, T args) { // default behavior assumes no args, no special logic - var cmd = connection.CreateCommand(); - Initialize(new(cmd), sql, commandType, args); + var cmd = CreateNewCommand(connection); + var unified = new UnifiedCommand(this, cmd); + unified.SetCommand(sql, commandType); + AddParameters(in unified, args); return cmd; } - internal void Initialize(in UnifiedCommand cmd, - string sql, CommandType commandType, T args) - { - cmd.CommandText = sql; - cmd.CommandType = commandType != 0 ? commandType : DapperAotExtensions.GetCommandType(sql); - AddParameters(in cmd, args); - } internal override sealed void PostProcessObject(in UnifiedCommand command, object? args, int rowCount) => PostProcess(in command, (T)args!, rowCount); @@ -214,9 +260,10 @@ public virtual void AddParameters(in UnifiedCommand command, T args) /// public virtual void UpdateParameters(in UnifiedCommand command, T args) { - if (command.Parameters.Count != 0) // try to avoid rogue "dirty" checks + var ps = command.Parameters; + if (ps.Count != 0) // try to avoid rogue "dirty" checks { - command.Parameters.Clear(); + ps.Clear(); } AddParameters(in command, args); } @@ -232,14 +279,80 @@ public virtual void UpdateParameters(in UnifiedCommand command, T args) // try to avoid any dirty detection in the setters if (cmd.CommandText != sql) cmd.CommandText = sql; if (cmd.CommandType != commandType) cmd.CommandType = commandType; - UpdateParameters(new(cmd), args); + UpdateParameters(new UnifiedCommand(this, cmd), args); } return cmd; } +#pragma warning disable IDE0079 // following will look unnecessary on up-level +#pragma warning disable CS1574 // DbBatchCommand will not resolve on down-level TFMs /// - /// Indicates where it is required to invoke post-operation logic to update parameter values. + /// Indicates whether the factory wishes to split this command into a multi-command batch. /// - public virtual bool RequirePostProcess => false; + /// This may or may not be implemented using , depending on the capabilities + /// of the runtime and ADO.NET provider. + /// #pragma warning disable IDE0079 // following will look unnecessary on up-level +#pragma warning restore CS1574 // DbBatchCommand will not resolve on down-level TFMs +#pragma warning restore IDE0079 // following will look unnecessary on up-level + public virtual bool UseBatch(string sql) => false; + +#if NET6_0_OR_GREATER + /// + /// Create a populated batch from a command + /// + public virtual DbBatch GetBatch(DbConnection connection, string sql, CommandType commandType, T args) + { + Debug.Assert(connection.CanCreateBatch); + var batch = CreateNewBatch(connection); + // initialize with a command + batch.BatchCommands.Add(CreateNewCommand(batch)); + AddCommands(new(this, batch), sql, args); + return batch; + } + + /// + /// Provides an opportunity to recycle and reuse batch instances + /// + protected DbBatch? TryReuse(ref DbBatch? storage, T args) + { + var batch = Interlocked.Exchange(ref storage, null); + if (batch is not null) + { + // try to avoid any dirty detection in the setters + UpdateParameters(new UnifiedBatch(this, batch), args); + } + return batch; + } +#endif + + /// + /// Allows the caller to rewrite a composite command into a multi-command batch. + /// + public virtual void AddCommands(in UnifiedBatch batch, string sql, T args) + { + // implement as basic mode + batch.SetCommand(sql, CommandType.Text); + AddParameters(in batch.Command, args); + } + + /// + /// Allows the caller to update the parameter values of a multi-command batch. + /// + public virtual void UpdateParameters(in UnifiedBatch batch, T args) + { + UpdateParameters(in batch.Command, args); + } + + /// + /// Allows an implementation to process output parameters etc after a multi-command batch has completed. + /// + /// This API is only invoked when reported true, and + /// corresponds to + public virtual void PostProcess(in UnifiedBatch batch, T args, int rowCount) { } + + internal void PostProcess(in UnifiedCommand command, TArgs? val, object recordsAffected) + { + throw new NotImplementedException(); + } } \ No newline at end of file diff --git a/src/Dapper.AOT/CommandT.Batch.cs b/src/Dapper.AOT/CommandT.Batch.cs index 65e10286..26675730 100644 --- a/src/Dapper.AOT/CommandT.Batch.cs +++ b/src/Dapper.AOT/CommandT.Batch.cs @@ -1,10 +1,9 @@ using Dapper.Internal; using System; +using System.Buffers; using System.Collections.Generic; using System.Collections.Immutable; -using System.Data.Common; using System.Diagnostics; -using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; @@ -142,24 +141,6 @@ public Task ExecuteAsync(TArgs[] values, int offset, int count, int batchSi }; } - internal void Recycle(ref SyncCommandState state) - { - Debug.Assert(state.Command is not null); - if (commandFactory.TryRecycle(state.Command!)) - { - state.Command = null; - } - } - - internal void Recycle(AsyncCommandState state) - { - Debug.Assert(state.Command is not null); - if (commandFactory.TryRecycle(state.Command!)) - { - state.Command = null; - } - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] private int ExecuteMulti(ReadOnlySpan source, int batchSize) { @@ -180,7 +161,7 @@ private int ExecuteMultiSequential(ReadOnlySpan source) var current = source[0]; var local = state.ExecuteNonQuery(GetCommand(current)); - UnifiedCommand cmdState = new(state.Command); + UnifiedCommand cmdState = new(commandFactory, state.Command); commandFactory.PostProcess(in cmdState, current, local); total += local; @@ -193,7 +174,7 @@ private int ExecuteMultiSequential(ReadOnlySpan source) total += local; } - Recycle(ref state); + state.UnifiedBatch.TryRecycle(); return total; } finally @@ -206,68 +187,140 @@ private int ExecuteMultiSequential(ReadOnlySpan source) [MethodImpl(MethodImplOptions.AggressiveInlining)] private bool UseBatch(int batchSize) => batchSize != 0 && connection is { CanCreateBatch: true }; - private DbBatchCommand AddCommand(ref UnifiedCommand state, TArgs args) + void Add(ref UnifiedBatch batch, TArgs args) { - var cmd = state.UnsafeCreateNewCommand(); - commandFactory.Initialize(state, sql, commandType, args); - state.AssertBatchCommands.Add(cmd); - return cmd; + if (batch.IsDefault) + { + // create a new batch initialized on a ready command + batch = new(commandFactory, connection, transaction); + if (timeout >= 0) batch.TimeoutSeconds = timeout; + Debug.Assert(batch.Mode is BatchMode.MultiCommandDbBatchCommand); // current expectations + batch.SetCommand(sql, commandType); + commandFactory.AddParameters(in batch.Command, args); + } + else if (batch.IsLastCommand) + { + // create a new command at the end of the batch + batch.CreateNextBatchGroup(sql, commandType); + commandFactory.AddParameters(in batch.Command, args); + } + else + { + // overwriting + batch.OverwriteNextBatchGroup(); + commandFactory.UpdateParameters(in batch.Command, args); + } } - private int ExecuteMultiBatch(ReadOnlySpan source, int batchSize) // TODO: sub-batching + private int ExecuteMultiBatch(ReadOnlySpan source, int batchSize) { Debug.Assert(source.Length > 1); - UnifiedCommand batch = default; + SyncCommandState state = default; + // note that we currently only use single-command-per-TArg mode, i.e. UseBatch is ignored try { + int sum = 0, ppOffset = 0; foreach (var arg in source) { - if (!batch.HasBatch) batch = new(connection.CreateBatch()); - AddCommand(ref batch, arg); + Add(ref state.UnifiedBatch, arg); + if (state.UnifiedBatch.GroupCount == batchSize) + { + sum += state.ExecuteNonQueryUnified(); + PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source); + } } - if (!batch.HasBatch) return 0; - - var result = batch.AssertBatch.ExecuteNonQuery(); - - if (commandFactory.RequirePostProcess) + if (state.UnifiedBatch.GroupCount != 0) { - batch.PostProcess(source, commandFactory); + state.UnifiedBatch.Trim(); + sum += state.ExecuteNonQueryUnified(); + PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source); } - return result; + state.UnifiedBatch.TryRecycle(); + return sum; } finally { - batch.Cleanup(); + state.Dispose(); } } - private int ExecuteMultiBatch(IEnumerable source, int batchSize) // TODO: sub-batching + + private void PostProcessMultiBatch(in UnifiedBatch batch, ReadOnlySpan args) { + int i = 0; + PostProcessMultiBatch(batch, ref i, args); + } + + private void PostProcessMultiBatch(in UnifiedBatch batch, ref int argOffset, ReadOnlySpan args) + { + // TODO: we'd need to buffer the sub-batch from IEnumerable + if (batch.IsDefault) return; + + // assert that we currently expect only single commands per element + if (batch.Command.CommandCount != batch.GroupCount) Throw(); + if (commandFactory.RequirePostProcess) { - // try to ensure it is repeatable - source = (source as IReadOnlyCollection) ?? source.ToList(); + batch.Command.UnsafeMoveTo(0); + foreach (var val in args.Slice(argOffset, batch.GroupCount)) + { + commandFactory.PostProcess(in batch.Command, val, batch.Command.RecordsAffected); + batch.Command.UnsafeAdvance(); + } + argOffset += batch.GroupCount; } + // prepare for the next batch, if one + batch.UnsafeMoveBeforeFirst(); + + static void Throw() => throw new InvalidOperationException("The number of operations should have matched the number of groups!"); + } + + private TArgs[]? GetMultiBatchBuffer(ref int batchSize) + { + if (!commandFactory.RequirePostProcess) return null; // no problem, then - UnifiedCommand batch = default; + const int MAX_SIZE = 1024; + if (batchSize < 0 || batchSize > 1024) batchSize = MAX_SIZE; + + return ArrayPool.Shared.Rent(batchSize); + } + private static void RecycleMultiBatchBuffer(TArgs[]? buffer) + { + if (buffer is not null) ArrayPool.Shared.Return(buffer); + } + + private int ExecuteMultiBatch(IEnumerable source, int batchSize) + { + SyncCommandState state = default; + var buffer = GetMultiBatchBuffer(ref batchSize); try { + int sum = 0, ppOffset = 0; foreach (var arg in source) { - if (!batch.HasBatch) batch = new(connection.CreateBatch()); - AddCommand(ref batch, arg); + Add(ref state.UnifiedBatch, arg); + if (buffer is not null) buffer[ppOffset++] = arg; + + if (state.UnifiedBatch.GroupCount == batchSize) + { + sum += state.ExecuteNonQueryUnified(); + PostProcessMultiBatch(in state.UnifiedBatch, buffer); + ppOffset = 0; + } } - if (!batch.HasBatch) return 0; - var result = batch.AssertBatch.ExecuteNonQuery(); - if (commandFactory.RequirePostProcess) + if (state.UnifiedBatch.GroupCount != 0) { - batch.PostProcess(source, commandFactory); + state.UnifiedBatch.Trim(); + sum += state.ExecuteNonQueryUnified(); + PostProcessMultiBatch(in state.UnifiedBatch, buffer); } - return result; + state.UnifiedBatch.TryRecycle(); + return sum; } finally { - batch.Cleanup(); + RecycleMultiBatchBuffer(buffer); + state.Dispose(); } } #endif @@ -280,7 +333,7 @@ private int ExecuteMulti(IEnumerable source, int batchSize) #endif return ExecuteMultiSequential(source); } - + private int ExecuteMultiSequential(IEnumerable source) { SyncCommandState state = default; @@ -294,7 +347,7 @@ private int ExecuteMultiSequential(IEnumerable source) bool haveMore = iterator.MoveNext(); if (haveMore && commandFactory.CanPrepare) state.PrepareBeforeExecute(); var local = state.ExecuteNonQuery(GetCommand(current)); - UnifiedCommand cmdState = new(state.Command); + UnifiedCommand cmdState = new(commandFactory, state.Command); commandFactory.PostProcess(in cmdState, current, local); total += local; @@ -308,7 +361,7 @@ private int ExecuteMultiSequential(IEnumerable source) haveMore = iterator.MoveNext(); } - Recycle(ref state); + state.UnifiedBatch.TryRecycle(); return total; } return total; @@ -385,16 +438,16 @@ public Task ExecuteAsync(ImmutableArray values, int batchSize = -1, [MethodImpl(MethodImplOptions.AggressiveInlining)] private Task ExecuteMultiAsync(ReadOnlyMemory source, int batchSize, CancellationToken cancellationToken) { -//#if NET6_0_OR_GREATER -// if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken); -//#endif + //#if NET6_0_OR_GREATER + // if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken); + //#endif return ExecuteMultiSequentialAsync(source, cancellationToken); } private async Task ExecuteMultiSequentialAsync(ReadOnlyMemory source, CancellationToken cancellationToken) { Debug.Assert(source.Length > 1); - AsyncCommandState state = new(); + var state = AsyncCommandState.Create(); try { if (commandFactory.CanPrepare) state.PrepareBeforeExecute(); @@ -402,7 +455,7 @@ private async Task ExecuteMultiSequentialAsync(ReadOnlyMemory source var current = source.Span[0]; var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken); - UnifiedCommand cmdState = new(state.Command); + UnifiedCommand cmdState = new(commandFactory, state.Command); commandFactory.PostProcess(in cmdState, current, local); total += local; @@ -415,37 +468,41 @@ private async Task ExecuteMultiSequentialAsync(ReadOnlyMemory source total += local; } - Recycle(state); + state.UnifiedBatch.TryRecycle(); return total; } finally { await state.DisposeAsync(); + state.Recycle(); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] private Task ExecuteMultiAsync(IAsyncEnumerable source, int batchSize, CancellationToken cancellationToken) { -//#if NET6_0_OR_GREATER -// if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken); -//#endif + //#if NET6_0_OR_GREATER + // if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken); + //#endif return ExecuteMultiSequentialAsync(source, cancellationToken); } private async Task ExecuteMultiSequentialAsync(IAsyncEnumerable source, CancellationToken cancellationToken) { - AsyncCommandState state = new(); + AsyncCommandState? state = null; var iterator = source.GetAsyncEnumerator(cancellationToken); try { int total = 0; if (await iterator.MoveNextAsync()) { + state ??= AsyncCommandState.Create(); + var current = iterator.Current; bool haveMore = await iterator.MoveNextAsync(); if (haveMore && commandFactory.CanPrepare) state.PrepareBeforeExecute(); + var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken); - UnifiedCommand cmdState = new(state.Command); + UnifiedCommand cmdState = new(commandFactory, state.Command); commandFactory.PostProcess(in cmdState, current, local); total += local; @@ -459,41 +516,46 @@ private async Task ExecuteMultiSequentialAsync(IAsyncEnumerable sour haveMore = await iterator.MoveNextAsync(); } - Recycle(state); - return total; + state.UnifiedBatch.TryRecycle(); } return total; } finally { await iterator.DisposeAsync(); - await state.DisposeAsync(); + if (state is not null) + { + await state.DisposeAsync(); + state.Recycle(); + } } } [MethodImpl(MethodImplOptions.AggressiveInlining)] private Task ExecuteMultiAsync(IEnumerable source, int batchSize, CancellationToken cancellationToken) { -//#if NET6_0_OR_GREATER -// if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken); -//#endif +#if NET6_0_OR_GREATER + if (UseBatch(batchSize)) return ExecuteMultiBatchAsync(source, batchSize, cancellationToken); +#endif return ExecuteMultiSequentialAsync(source, cancellationToken); } private async Task ExecuteMultiSequentialAsync(IEnumerable source, CancellationToken cancellationToken) { - AsyncCommandState state = new(); + AsyncCommandState? state = null; var iterator = source.GetEnumerator(); try { int total = 0; if (iterator.MoveNext()) { + state ??= AsyncCommandState.Create(); + var current = iterator.Current; bool haveMore = iterator.MoveNext(); if (haveMore && commandFactory.CanPrepare) state.PrepareBeforeExecute(); var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken); - UnifiedCommand cmdState = new(state.Command); + UnifiedCommand cmdState = new(commandFactory, state.Command); commandFactory.PostProcess(in cmdState, current, local); total += local; @@ -507,18 +569,66 @@ private async Task ExecuteMultiSequentialAsync(IEnumerable source, C haveMore = iterator.MoveNext(); } - Recycle(state); - return total; + state.UnifiedBatch.TryRecycle(); } return total; } finally { iterator.Dispose(); - await state.DisposeAsync(); + if (state is not null) + { + await state.DisposeAsync(); + state.Recycle(); + } } } +#if NET6_0_OR_GREATER + private async Task ExecuteMultiBatchAsync(IEnumerable source, int batchSize, CancellationToken cancellationToken) + { + AsyncCommandState? state = null; + var buffer = GetMultiBatchBuffer(ref batchSize); + try + { + int sum = 0, ppOffset = 0; + foreach (var arg in source) + { + state ??= AsyncCommandState.Create(); + Add(ref state.UnifiedBatch, arg); + if (buffer is not null) buffer[ppOffset++] = arg; + + if (state.UnifiedBatch.GroupCount == batchSize) + { + sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken); + PostProcessMultiBatch(in state.UnifiedBatch, buffer); + ppOffset = 0; + } + } + + if (state is not null) + { + if (state.UnifiedBatch.GroupCount != 0) + { + sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken); + PostProcessMultiBatch(in state.UnifiedBatch, buffer); + } + state.UnifiedBatch.TryRecycle(); + } + return sum; + } + finally + { + RecycleMultiBatchBuffer(buffer); + if (state is not null) + { + await state.DisposeAsync(); + state.Recycle(); + } + } + } +#endif + [MethodImpl(MethodImplOptions.AggressiveInlining)] private Task ExecuteMultiAsync(TArgs[] source, int offset, int count, int batchSize, CancellationToken cancellationToken) { @@ -529,7 +639,8 @@ private Task ExecuteMultiAsync(TArgs[] source, int offset, int count, int b } private async Task ExecuteMultiSequentialAsync(TArgs[] source, int offset, int count, CancellationToken cancellationToken) { - AsyncCommandState state = new(); + Debug.Assert(count > 0); + var state = AsyncCommandState.Create(); try { // count is now actually "end" @@ -540,7 +651,7 @@ private async Task ExecuteMultiSequentialAsync(TArgs[] source, int offset, var current = source[offset++]; var local = await state.ExecuteNonQueryAsync(GetCommand(current), cancellationToken); - UnifiedCommand cmdState = new(state.Command); + UnifiedCommand cmdState = new(commandFactory, state.Command); commandFactory.PostProcess(in cmdState, current, local); total += local; @@ -553,41 +664,54 @@ private async Task ExecuteMultiSequentialAsync(TArgs[] source, int offset, total += local; } - Recycle(state); + state.UnifiedBatch.TryRecycle(); return total; } finally { await state.DisposeAsync(); + state.Recycle(); } } #if NET6_0_OR_GREATER - private async Task ExecuteMultiBatchAsync(TArgs[] source, int offset, int count, int batchSize, CancellationToken cancellationToken) // TODO: sub-batching + private async Task ExecuteMultiBatchAsync(TArgs[] source, int offset, int count, int batchSize, CancellationToken cancellationToken) { Debug.Assert(source.Length > 1); - UnifiedCommand batch = default; + AsyncCommandState? state = null; var end = offset + count; try { - for (int i = offset ; i < end; i++) + int sum = 0, ppOffset = offset; + for (int i = offset; i < end; i++) { - if (!batch.HasBatch) batch = new(connection.CreateBatch()); - AddCommand(ref batch, source[i]); + state ??= AsyncCommandState.Create(); + Add(ref state.UnifiedBatch, source[i]); + if (state.UnifiedBatch.GroupCount == batchSize) + { + sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken); + PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source); + } } - if (!batch.HasBatch) return 0; - var result = await batch.AssertBatch.ExecuteNonQueryAsync(cancellationToken); - - if (commandFactory.RequirePostProcess) + if (state is not null) { - batch.PostProcess(new ReadOnlySpan(source, offset, count), commandFactory); + if (state.UnifiedBatch.GroupCount != 0) + { + sum += await state.ExecuteNonQueryUnifiedAsync(cancellationToken); + PostProcessMultiBatch(in state.UnifiedBatch, ref ppOffset, source); + } + state.UnifiedBatch.TryRecycle(); } - return result; + return sum; } finally { - batch.Cleanup(); + if (state is not null) + { + await state.DisposeAsync(); + state.Recycle(); + } } } #endif diff --git a/src/Dapper.AOT/CommandT.Execute.cs b/src/Dapper.AOT/CommandT.Execute.cs index 461ece74..2a975b32 100644 --- a/src/Dapper.AOT/CommandT.Execute.cs +++ b/src/Dapper.AOT/CommandT.Execute.cs @@ -14,8 +14,9 @@ public int Execute(TArgs args) SyncCommandState state = default; try { - var result = state.ExecuteNonQuery(GetCommand(args)); - PostProcessAndRecycle(ref state, args, result); + GetUnifiedBatch(out state.UnifiedBatch, args); + var result = state.ExecuteNonQueryUnified(); + PostProcessAndRecycleUnified(state.UnifiedBatch, args, result); return result; } finally @@ -29,16 +30,18 @@ public int Execute(TArgs args) /// public async Task ExecuteAsync(TArgs args, CancellationToken cancellationToken = default) { - AsyncCommandState state = new(); + var state = AsyncCommandState.Create(); try { - var result = await state.ExecuteNonQueryAsync(GetCommand(args), cancellationToken); - PostProcessAndRecycle(state, args, result); + GetUnifiedBatch(out state.UnifiedBatch, args); + var result = await state.ExecuteNonQueryUnifiedAsync(cancellationToken); + PostProcessAndRecycleUnified(in state.UnifiedBatch, args, result); return result; } finally { await state.DisposeAsync(); + state.Recycle(); } } } diff --git a/src/Dapper.AOT/CommandT.ExecuteScalar.cs b/src/Dapper.AOT/CommandT.ExecuteScalar.cs index 89cdab3e..fdc5e94b 100644 --- a/src/Dapper.AOT/CommandT.ExecuteScalar.cs +++ b/src/Dapper.AOT/CommandT.ExecuteScalar.cs @@ -1,4 +1,5 @@ using Dapper.Internal; +using System.Data.Common; using System.Threading; using System.Threading.Tasks; @@ -29,7 +30,7 @@ partial struct Command /// public async Task ExecuteScalarAsync(TArgs args, CancellationToken cancellationToken = default) { - AsyncCommandState state = new(); + var state = AsyncCommandState.Create(); try { var result = await state.ExecuteScalarAsync(GetCommand(args), cancellationToken); @@ -39,6 +40,7 @@ partial struct Command finally { await state.DisposeAsync(); + state.Recycle(); } } @@ -65,7 +67,7 @@ public T ExecuteScalar(TArgs args) /// public async Task ExecuteScalarAsync(TArgs args, CancellationToken cancellationToken = default) { - AsyncCommandState state = new(); + var state = AsyncCommandState.Create(); try { var result = await state.ExecuteScalarAsync(GetCommand(args), cancellationToken); @@ -75,6 +77,7 @@ public async Task ExecuteScalarAsync(TArgs args, CancellationToken cancell finally { await state.DisposeAsync(); + state.Recycle(); } } } diff --git a/src/Dapper.AOT/CommandT.Query.cs b/src/Dapper.AOT/CommandT.Query.cs index 069b28b1..57ebff1a 100644 --- a/src/Dapper.AOT/CommandT.Query.cs +++ b/src/Dapper.AOT/CommandT.Query.cs @@ -25,7 +25,8 @@ public List QueryBuffered(TArgs args, [DapperAot] RowFactory? SyncQueryState state = default; try { - state.ExecuteReader(GetCommand(args), CommandBehavior.SingleResult | CommandBehavior.SequentialAccess); + GetUnifiedBatch(out state.CommandState.UnifiedBatch, args); + state.ExecuteReaderUnified(CommandBehavior.SingleResult | CommandBehavior.SequentialAccess); List results; if (state.Reader.Read()) @@ -51,7 +52,7 @@ public List QueryBuffered(TArgs args, [DapperAot] RowFactory? // consume entire results (avoid unobserved TDS error messages) while (state.Reader.NextResult()) { } - PostProcessAndRecycle(ref state, args, state.Reader.CloseAndCapture()); + PostProcessAndRecycleUnified(in state.CommandState.UnifiedBatch, args, state.Reader.CloseAndCapture()); return results; } finally @@ -66,7 +67,7 @@ public List QueryBuffered(TArgs args, [DapperAot] RowFactory? public async Task> QueryBufferedAsync(TArgs args, [DapperAot] RowFactory? rowFactory = null, int rowCountHint = 0, CancellationToken cancellationToken = default) { - AsyncQueryState state = new(); + var state = AsyncQueryState.Create(); try { await state.ExecuteReaderAsync(GetCommand(args), CommandBehavior.SingleResult | CommandBehavior.SequentialAccess, cancellationToken); @@ -105,7 +106,7 @@ public List QueryBuffered(TArgs args, [DapperAot] RowFactory? public async IAsyncEnumerable QueryUnbufferedAsync(TArgs args, [DapperAot] RowFactory? rowFactory = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - AsyncQueryState state = new(); + var state = AsyncQueryState.Create(); try { await state.ExecuteReaderAsync(GetCommand(args), CommandBehavior.SingleResult | CommandBehavior.SequentialAccess, cancellationToken); @@ -213,8 +214,7 @@ static CommandBehavior SingleFlags(OneRowFlags flags) RowFactory? rowFactory, CancellationToken cancellationToken) { - AsyncQueryState state = new(); - + var state = AsyncQueryState.Create(); try { await state.ExecuteReaderAsync(GetCommand(args), SingleFlags(flags), cancellationToken); diff --git a/src/Dapper.AOT/CommandT.cs b/src/Dapper.AOT/CommandT.cs index c20f7d12..c3a1765e 100644 --- a/src/Dapper.AOT/CommandT.cs +++ b/src/Dapper.AOT/CommandT.cs @@ -36,7 +36,7 @@ namespace Dapper; this.connection = connection!; this.transaction = null; this.sql = sql; - this.commandType = commandType; + this.commandType = commandType == 0 ? DapperAotExtensions.GetCommandType(sql) : commandType; this.timeout = timeout; this.commandFactory = commandFactory ?? CommandFactory.Default; @@ -86,20 +86,65 @@ private DbCommand GetCommand(TArgs args) return cmd; } + private void GetUnifiedBatch(out UnifiedBatch batch, TArgs args) // it will most likely turn out to be a batch of one, note + { + if (commandType == CommandType.Text && commandFactory.UseBatch(sql)) + { +#if NET6_0_OR_GREATER + if (connection.CanCreateBatch) + { + var dbBatch = commandFactory.GetBatch(connection, sql, commandType, args); + dbBatch.Connection = connection; + dbBatch.Timeout = timeout; + dbBatch.Transaction = transaction; + batch = new(commandFactory, dbBatch); + } + else +#endif + { + var cmd = commandFactory.CreateNewCommand(connection); + cmd.Connection = connection; + cmd.Transaction = transaction; + cmd.CommandTimeout = timeout; + batch = new(commandFactory, cmd); + commandFactory.AddCommands(in batch, sql, args); + } + batch.Command.UnsafeMoveToFinal(); + } + else + { + var cmd = GetCommand(args); + batch = new UnifiedBatch(commandFactory, cmd); + } + } + internal void PostProcessAndRecycle(ref SyncQueryState state, TArgs args, int rowCount) { Debug.Assert(state.Command is not null); - commandFactory.PostProcess(new(state.Command!), args, rowCount); + commandFactory.PostProcess(new UnifiedCommand(commandFactory, state.Command!), args, rowCount); if (commandFactory.TryRecycle(state.Command!)) { state.Command = null; } } + internal void PostProcessAndRecycleUnified(in UnifiedBatch batch, TArgs args, int rowCount) + { + if (batch.Mode is BatchMode.SingleCommandDbBatch) + { + commandFactory.PostProcess(in batch, args, rowCount); + } + else + { + commandFactory.PostProcess(in batch.Command, args, rowCount); + } + batch.Command.TryRecycle(); + } + internal void PostProcessAndRecycle(AsyncQueryState state, TArgs args, int rowCount) { Debug.Assert(state.Command is not null); - commandFactory.PostProcess(new(state.Command!), args, rowCount); + commandFactory.PostProcess(new UnifiedCommand(commandFactory, state.Command!), args, rowCount); if (commandFactory.TryRecycle(state.Command!)) { state.Command = null; @@ -109,7 +154,7 @@ internal void PostProcessAndRecycle(AsyncQueryState state, TArgs args, int rowCo internal void PostProcessAndRecycle(ref SyncCommandState state, TArgs args, int rowCount) { Debug.Assert(state.Command is not null); - commandFactory.PostProcess(new(state.Command!), args, rowCount); + commandFactory.PostProcess(new UnifiedCommand(commandFactory, state.Command!), args, rowCount); if (commandFactory.TryRecycle(state.Command!)) { state.Command = null; @@ -119,7 +164,7 @@ internal void PostProcessAndRecycle(ref SyncCommandState state, TArgs args, int internal void PostProcessAndRecycle(AsyncCommandState state, TArgs args, int rowCount) { Debug.Assert(state.Command is not null); - commandFactory.PostProcess(new(state.Command!), args, rowCount); + commandFactory.PostProcess(new UnifiedCommand(commandFactory, state.Command!), args, rowCount); if (commandFactory.TryRecycle(state.Command!)) { state.Command = null; @@ -164,7 +209,7 @@ public Task ExecuteReaderAsync(TArgs args, CommandBehavior behavio public async Task ExecuteReaderAsync(TArgs args, CommandBehavior behavior = CommandBehavior.Default, CancellationToken cancellationToken = default) where TReader : WrappedDbDataReader, new() { - AsyncQueryState? state = new(); + var state = AsyncQueryState.Create(); try { await state.ExecuteReaderAsync(GetCommand(args), behavior, cancellationToken); diff --git a/src/Dapper.AOT/Internal/AsyncCommandState.cs b/src/Dapper.AOT/Internal/AsyncCommandState.cs index 4922de8b..b8810b8c 100644 --- a/src/Dapper.AOT/Internal/AsyncCommandState.cs +++ b/src/Dapper.AOT/Internal/AsyncCommandState.cs @@ -11,8 +11,14 @@ namespace Dapper.Internal // split out because of async state machine semantics; see https://github.com/DapperLib/DapperAOT/issues/27 internal class AsyncCommandState : IAsyncDisposable { + private static AsyncCommandState? _spare; + protected AsyncCommandState() {} + + public static AsyncCommandState Create() => Interlocked.Exchange(ref _spare, null) ?? new(); + private DbConnection? connection; public DbCommand? Command; + public UnifiedBatch UnifiedBatch; private int _flags; const int @@ -48,6 +54,47 @@ static async Task Awaited(Task pending, DbCommand command, Command } } + private Task OnBeforeExecuteUnifiedAsync(CancellationToken cancellationToken) + { + Debug.Assert(UnifiedBatch.Connection is not null); + connection = UnifiedBatch.Connection!; + + if (connection.State == ConnectionState.Open) + { + if ((_flags & FLAG_PREPARE_COMMMAND) == 0) + { + // nothing to do + return Task.CompletedTask; + } + else + { + // just need to prepare + _flags &= ~FLAG_PREPARE_COMMMAND; // once is enough + return UnifiedBatch.PrepareAsync(cancellationToken); + } + } + else + { + _flags |= FLAG_CLOSE_CONNECTION; + if ((_flags & FLAG_PREPARE_COMMMAND) == 0) + { + // just need to open + return connection.OpenAsync(cancellationToken); + } + else + { + _flags &= ~FLAG_PREPARE_COMMMAND; // once is enough + return OpenAndPrepareAsync(this, cancellationToken); + + static async Task OpenAndPrepareAsync(AsyncCommandState state, CancellationToken cancellationToken) + { + await state.UnifiedBatch.Connection!.OpenAsync(cancellationToken); + await state.UnifiedBatch.PrepareAsync(cancellationToken); + } + } + } + } + [MemberNotNull(nameof(Command))] private Task OnBeforeExecuteAsync(DbCommand command, CancellationToken cancellationToken) { @@ -65,6 +112,7 @@ private Task OnBeforeExecuteAsync(DbCommand command, CancellationToken cancellat else { // just need to prepare + _flags &= ~FLAG_PREPARE_COMMMAND; // once is enough #if NETCOREAPP3_1_OR_GREATER return command.PrepareAsync(cancellationToken); #else @@ -83,6 +131,7 @@ private Task OnBeforeExecuteAsync(DbCommand command, CancellationToken cancellat } else { + _flags &= ~FLAG_PREPARE_COMMMAND; // once is enough return OpenAndPrepareAsync(command, cancellationToken); static async Task OpenAndPrepareAsync(DbCommand command, CancellationToken cancellationToken) @@ -98,6 +147,20 @@ static async Task OpenAndPrepareAsync(DbCommand command, CancellationToken cance } } + public Task ExecuteNonQueryUnifiedAsync(CancellationToken cancellationToken) + { + var pending = OnBeforeExecuteUnifiedAsync(cancellationToken); + return pending.IsCompletedSuccessfully() ? UnifiedBatch.ExecuteNonQueryAsync(cancellationToken) + : Awaited(pending, this, cancellationToken); + + static async Task Awaited(Task pending, AsyncCommandState state, CancellationToken cancellationToken) + { + await pending; + return await state.UnifiedBatch.ExecuteNonQueryAsync(cancellationToken); + } + } + + [MemberNotNull(nameof(Command))] public Task ExecuteNonQueryAsync(DbCommand command, CancellationToken cancellationToken) @@ -115,18 +178,24 @@ static async Task Awaited(Task pending, DbCommand command, CancellationToke public virtual ValueTask DisposeAsync() { + var tmp = UnifiedBatch; + UnifiedBatch = default; + tmp.Cleanup(); + var cmd = Command; Command = null; var conn = connection; connection = null; + var flags = _flags; + _flags = 0; + if (cmd is not null) { - if (conn is not null && (_flags & FLAG_CLOSE_CONNECTION) != 0) + if (conn is not null && (flags & FLAG_CLOSE_CONNECTION) != 0) { // need to close the connection and dispose the command - _flags &= ~FLAG_CLOSE_CONNECTION; return DisposeCommandAndCloseConnectionAsync(conn, cmd); #if NETCOREAPP3_1_OR_GREATER @@ -157,7 +226,7 @@ static ValueTask DisposeCommandAndCloseConnectionAsync(DbConnection conn, DbComm } else { - if (conn is not null && (_flags & FLAG_CLOSE_CONNECTION) != 0) + if (conn is not null && (flags & FLAG_CLOSE_CONNECTION) != 0) { #if NETCOREAPP3_1_OR_GREATER return new(conn.CloseAsync()); @@ -176,18 +245,27 @@ static ValueTask DisposeCommandAndCloseConnectionAsync(DbConnection conn, DbComm public virtual void Dispose() { + var tmp = UnifiedBatch; + UnifiedBatch = default; + tmp.Cleanup(); + var cmd = Command; Command = null; cmd?.Dispose(); var conn = connection; connection = null; - if (conn is not null && (_flags & FLAG_CLOSE_CONNECTION) != 0) + + var flags = _flags; + _flags = 0; + + if (conn is not null && (flags & FLAG_CLOSE_CONNECTION) != 0) { - _flags &= ~FLAG_CLOSE_CONNECTION; conn.Close(); } } + + public virtual void Recycle() => Interlocked.Exchange(ref _spare, this); } } diff --git a/src/Dapper.AOT/Internal/AsyncQueryState.cs b/src/Dapper.AOT/Internal/AsyncQueryState.cs index bbfe5a9f..0e3e2569 100644 --- a/src/Dapper.AOT/Internal/AsyncQueryState.cs +++ b/src/Dapper.AOT/Internal/AsyncQueryState.cs @@ -14,7 +14,6 @@ internal interface IQueryState { DbDataReader? Reader { get; } DbCommand? Command{ get; set; } - void Dispose(); ValueTask DisposeAsync(); } @@ -110,6 +109,15 @@ public override void Dispose() Reader?.Dispose(); base.Dispose(); } + + public override void Recycle() => Interlocked.Exchange(ref _spare, this); + + protected AsyncQueryState() : base() { } + + public static new AsyncQueryState Create() => Interlocked.Exchange(ref _spare, null) ?? new(); + + private static AsyncQueryState? _spare; + } } diff --git a/src/Dapper.AOT/Internal/SyncCommandState.cs b/src/Dapper.AOT/Internal/SyncCommandState.cs index 85594995..932c9bd5 100644 --- a/src/Dapper.AOT/Internal/SyncCommandState.cs +++ b/src/Dapper.AOT/Internal/SyncCommandState.cs @@ -10,6 +10,9 @@ internal struct SyncCommandState // note mutable; deliberately not : IDisposable { private DbConnection? connection; public DbCommand? Command; + + public UnifiedBatch UnifiedBatch; + private int _flags; const int @@ -32,6 +35,12 @@ public DbDataReader ExecuteReader(DbCommand command, CommandBehavior flags) return command.ExecuteReader(flags); } + public DbDataReader ExecuteReaderUnified(CommandBehavior flags) + { + OnBeforeExecuteUnified(); + return UnifiedBatch.ExecuteReader(flags); + } + [MemberNotNull(nameof(Command))] private void OnBeforeExecute(DbCommand command) { @@ -51,6 +60,29 @@ private void OnBeforeExecute(DbCommand command) } } + private void OnBeforeExecuteUnified() + { + connection = UnifiedBatch.Connection; + Debug.Assert(connection is not null); + + if (connection!.State != ConnectionState.Open) + { + connection.Open(); + _flags |= FLAG_CLOSE_CONNECTION; + } + if ((_flags & FLAG_PREPARE_COMMMAND) != 0) + { + _flags &= ~FLAG_PREPARE_COMMMAND; + UnifiedBatch.Prepare(); + } + } + + public int ExecuteNonQueryUnified() + { + OnBeforeExecuteUnified(); + return UnifiedBatch.ExecuteNonQuery(); + } + [MemberNotNull(nameof(Command))] public int ExecuteNonQuery(DbCommand command) { @@ -60,10 +92,16 @@ public int ExecuteNonQuery(DbCommand command) public void Dispose() { + var tmp = UnifiedBatch; + UnifiedBatch = default; + tmp.Cleanup(); + var cmd = Command; Command = null; cmd?.Dispose(); + UnifiedBatch.Cleanup(); + var conn = connection; connection = null; if (conn is not null && (_flags & FLAG_CLOSE_CONNECTION) != 0) @@ -71,6 +109,8 @@ public void Dispose() _flags &= ~FLAG_CLOSE_CONNECTION; conn.Close(); } + + _flags = 0; } } diff --git a/src/Dapper.AOT/Internal/SyncQueryState.cs b/src/Dapper.AOT/Internal/SyncQueryState.cs index 44271042..8d2d53c0 100644 --- a/src/Dapper.AOT/Internal/SyncQueryState.cs +++ b/src/Dapper.AOT/Internal/SyncQueryState.cs @@ -18,7 +18,7 @@ internal struct SyncQueryState : IQueryState // note mutable; deliberately not : set => Command = value; } - private SyncCommandState commandState; + public SyncCommandState CommandState; public DbDataReader? Reader; public int[]? Leased; private int fieldCount; @@ -32,19 +32,23 @@ public void Dispose() { Return(); Reader?.Dispose(); - commandState.Dispose(); + CommandState.Dispose(); } public DbCommand? Command { - readonly get => commandState.Command; - set => commandState.Command = value; + readonly get => CommandState.Command; + set => CommandState.Command = value; } #pragma warning disable CS8774 // Member must have a non-null value when exiting. - validated [MemberNotNull(nameof(Reader), nameof(Command))] public void ExecuteReader(DbCommand command, CommandBehavior flags) - => Reader = commandState.ExecuteReader(command, flags); + => Reader = CommandState.ExecuteReader(command, flags); + + [MemberNotNull(nameof(Reader), nameof(Command))] + public void ExecuteReaderUnified(CommandBehavior flags) + => Reader = CommandState.ExecuteReaderUnified(flags); public Span Lease() { diff --git a/src/Dapper.AOT/RowFactory.cs b/src/Dapper.AOT/RowFactory.cs index 99c9fa0e..33d0f628 100644 --- a/src/Dapper.AOT/RowFactory.cs +++ b/src/Dapper.AOT/RowFactory.cs @@ -104,6 +104,8 @@ protected static T GetValueExact(DbDataReader reader, int fieldOffset) internal const int MAX_STACK_TOKENS = 64; + internal static CommandFactory NotRequired => null!; // for times when we don't *actually* need a row-factory, but we want to be clear + internal static Span Lease(int fieldCount, ref int[]? lease) { if (lease is null || lease.Length < fieldCount) diff --git a/src/Dapper.AOT/UnifiedBatch.cs b/src/Dapper.AOT/UnifiedBatch.cs new file mode 100644 index 00000000..cdc8985c --- /dev/null +++ b/src/Dapper.AOT/UnifiedBatch.cs @@ -0,0 +1,215 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Dapper; + +internal enum BatchMode +{ + None, + SingleCommandDbCommand, + SingleCommandDbBatch, + MultiCommandDbCommand, + MultiCommandDbBatchCommand, +} + +#pragma warning disable IDE0079 // following will look unnecessary on up-level +#pragma warning disable CS1574 // DbBatchCommand will not resolve on down-level TFMs +/// +/// Represents the state associated with multiple or instances (where supported). +/// +/// Only the current command is available to the caller. +#pragma warning disable CS1574 +#pragma warning restore IDE0079 +public readonly struct UnifiedBatch +{ + + private readonly BatchMode mode; + internal readonly UnifiedCommand Command; // avoid duplication by offloading a lot of details here + + private readonly int commandStart, commandCount, groupCount; // these are used to restrict the commands that are available to a single consumer + internal int GroupCount => groupCount; + internal BatchMode Mode => mode; + + internal UnifiedBatch(CommandFactory commandFactory, DbCommand command) + { + Command = new UnifiedCommand(commandFactory, command); + commandStart = 0; + commandCount = groupCount = 1; + mode = BatchMode.SingleCommandDbCommand; + Debug.Assert(Command.CommandCount == 1); + } + +#if NET6_0_OR_GREATER + internal UnifiedBatch(CommandFactory commandFactory, DbBatch batch) + { + if (batch.BatchCommands.Count == 0) Throw(); + + Command = new UnifiedCommand(commandFactory, batch); + commandStart = 0; + commandCount = batch.BatchCommands.Count; + groupCount = 1; + mode = BatchMode.SingleCommandDbBatch; + static void Throw() => throw new ArgumentException( + message: "When creating a " + nameof(UnifiedBatch) + " for an existing batch, the batch cannot be empty", + paramName: nameof(batch)); + } +#endif + + internal UnifiedBatch(CommandFactory commandFactory, DbConnection connection, DbTransaction? transaction) + { +#if NET6_0_OR_GREATER + if (connection is { CanCreateBatch: true }) + { + var batch = commandFactory.CreateNewBatch(connection); + batch.Connection = connection; + batch.Transaction = transaction; + Command = new UnifiedCommand(commandFactory, batch); + mode = BatchMode.MultiCommandDbBatchCommand; + } + else +#endif + { + var cmd = commandFactory.CreateNewCommand(connection); + cmd.Connection = connection; + cmd.Transaction = transaction; + Command = new UnifiedCommand(commandFactory, cmd); + mode = BatchMode.MultiCommandDbCommand; + } + commandStart = 0; + commandCount = groupCount = 1; + Debug.Assert(Command.CommandCount == 1); + } + +#if NET6_0_OR_GREATER + internal DbBatch? Batch => Command.Batch; +#endif + + private int GetCommandIndex(int localIndex) + { + if (localIndex < 0 || localIndex >= commandCount) Throw(); + return commandStart + localIndex; + static void Throw() => throw new IndexOutOfRangeException(); + } + + /// + /// Returns the parameters of the corresponding command + /// + public DbParameterCollection this[int commandIndex] => Command[GetCommandIndex(commandIndex)]; + + /// + public override string ToString() => Command.ToString(); + + /// + internal DbConnection? Connection => Command.Connection; + + /// + internal DbTransaction? Transaction => Command.Transaction; + + /// + public string CommandText + { + get => Command.CommandText; + [Obsolete("When possible, " + nameof(SetCommand) + " should be preferred", false)] + set => Command.CommandText = value; + } + + /// + /// Initialize the and + /// + public void SetCommand(string commandText, CommandType commandType = CommandType.Text) + => Command.SetCommand(commandText, commandType); + + /// + public CommandType CommandType + { + get => Command.CommandType; + [Obsolete("When possible, " + nameof(SetCommand) + " should be preferred", false)] + set => Command.CommandType = value; + } + + /// + public int TimeoutSeconds + { + get => Command.TimeoutSeconds; + set => Command.TimeoutSeconds = value; + } + + /// + public DbParameterCollection Parameters => Command.Parameters; + + /// +#if DEBUG + [Obsolete("Prefer " + nameof(AddParameter))] +#endif + public DbParameter CreateParameter() => Command.CreateParameter(); + + /// + public DbParameter AddParameter() => Command.AddParameter(); + + internal bool IsLastCommand => Command.CommandCount == Command.Index + 1; + + /// + /// Creates and initializes new command, returning .the parameters collection. + /// + public DbParameterCollection AddCommand(string commandText, CommandType commandType = CommandType.Text) + { + Debug.Assert(mode is BatchMode.SingleCommandDbBatch); + return Command.AddCommand(commandText, commandType); + } + + internal void OverwriteNextBatchGroup() + { + Debug.Assert(mode is BatchMode.MultiCommandDbCommand or BatchMode.MultiCommandDbBatchCommand); + Debug.Assert(!IsLastCommand); + Command.UnsafeAdvance(); + Unsafe.AsRef(in commandStart) = Command.Index; + Unsafe.AsRef(in groupCount)++; + } + + internal void CreateNextBatchGroup(string commandText, CommandType commandType) + { + Debug.Assert(mode is BatchMode.MultiCommandDbCommand or BatchMode.MultiCommandDbBatchCommand); + Debug.Assert(IsLastCommand); + AddCommand(commandText, commandType); + Unsafe.AsRef(in commandStart) = Command.Index; + Unsafe.AsRef(in commandCount) = 1; + Unsafe.AsRef(in groupCount)++; + } + + internal bool IsDefault => Command.IsDefault; + + internal CommandFactory CommandFactory => Command.CommandFactory; + + internal int ExecuteNonQuery() => GroupCount == 0 ? 0 : Command.ExecuteNonQuery(); + + internal Task ExecuteNonQueryAsync(CancellationToken cancellationToken) => GroupCount == 0 ? TaskZero : Command.ExecuteNonQueryAsync(cancellationToken); + + internal void Cleanup() => Command.Cleanup(); + + internal void Trim() => Command.Trim(); + + internal void TryRecycle() => Command.TryRecycle(); + + internal DbDataReader ExecuteReader(CommandBehavior flags) + => Command.ExecuteReader(flags); + + internal void Prepare() => Command.Prepare(); + + internal Task PrepareAsync(CancellationToken cancellationToken) => Command.PrepareAsync(cancellationToken); + + internal void UnsafeMoveBeforeFirst() + { + Command.UnsafeMoveTo(-1); + Unsafe.AsRef(in commandStart) = 0; + Unsafe.AsRef(in commandCount) = 0; + Unsafe.AsRef(in groupCount) = 0; + } + + static readonly Task TaskZero = Task.FromResult(0); +} diff --git a/src/Dapper.AOT/UnifiedCommand.cs b/src/Dapper.AOT/UnifiedCommand.cs index ad618a24..6193660d 100644 --- a/src/Dapper.AOT/UnifiedCommand.cs +++ b/src/Dapper.AOT/UnifiedCommand.cs @@ -3,6 +3,8 @@ using System.Data; using System.Data.Common; using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; namespace Dapper; @@ -18,208 +20,619 @@ namespace Dapper; public readonly struct UnifiedCommand { + // could be: + // a) a DbCommand instance for a single operation + // b) a DbBatch instance for a multi-command operation using the DbBatch API + // c) a List instance for a multi-command operation using the legacy API + private readonly object _source; + + private readonly CommandFactory _commandFactory; + + internal CommandFactory CommandFactory => _commandFactory; + + // 0 for "a"; for "b" and "c", this is the index of the current operation + private readonly int _index; + +#if NET6_0_OR_GREATER + // may be used by the multi-command API to create new parameters + private readonly DbCommand? _spareCommandForParameters; +#endif + + /* + private object? Source => _source switch + { + DbCommand cmd => cmd, + List list => list[_index], +#if NET6_0_OR_GREATER + DbBatch batch => batch.BatchCommands[_index], +#endif + _ => null, + }; + */ + + internal int CommandCount => _source switch + { + DbCommand cmd => 1, + List list => list.Count, +#if NET6_0_OR_GREATER + DbBatch batch => batch.BatchCommands.Count, +#endif + _ => 0, + }; + + /// + public override string ToString() => CommandText; + + /// + internal DbConnection? Connection => _source switch + { + DbCommand cmd => cmd.Connection, + List list => list[_index].Connection, +#if NET6_0_OR_GREATER + DbBatch batch => batch.Connection, +#endif + _ => null, + }; + + /// + internal DbTransaction? Transaction => _source switch + { + DbCommand cmd => cmd.Transaction, + List list => list[_index].Transaction, +#if NET6_0_OR_GREATER + DbBatch batch => batch.Transaction, +#endif + _ => null, + }; + /// - /// The associated with the current operation; this may be null for batch commands. + /// The associated with the current operation, if appropriate. /// - public DbCommand? Command + public DbCommand? Command => _source switch { - get - { + DbCommand cmd => cmd, + List list => list[_index], + _ => null, + }; + #if NET6_0_OR_GREATER - return batchCommand is null ? dbCommand : null; -#else - return dbCommand; + /// + /// The associated with the current operation, if appropriate. + /// + public DbBatchCommand? BatchCommand => _source switch + { + DbBatch batch => batch.BatchCommands[_index], + _ => null, + }; + /// + /// The associated with the current operation, if appropriate. + /// + internal DbBatch? Batch => _source as DbBatch; + + + #endif - } - } /// public string CommandText { - get + get => _source switch { + DbCommand cmd => cmd.CommandText, + List list => list[_index].CommandText, #if NET6_0_OR_GREATER - if (batchCommand is not null) return batchCommand.CommandText; + DbBatch batch => batch.BatchCommands[_index].CommandText, #endif - return AssertCommand.CommandText; - } + _ => "", + }; + [Obsolete("When possible, " + nameof(SetCommand) + " should be preferred", false)] set { -#if NET6_0_OR_GREATER - if (batchCommand is not null) + switch (_source) { - batchCommand.CommandText = value; - return; - } + case DbCommand cmd: + cmd.CommandText = value; + break; + case List list: + list[_index].CommandText = value; + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + batch.BatchCommands[_index].CommandText = value; + break; #endif - AssertCommand.CommandText = value; + } } } + internal bool IsDefault => _source is null; + /// public DbParameterCollection Parameters { - get - { + get { + return _source switch + { + DbCommand cmd => cmd.Parameters, + List list => list[_index].Parameters, #if NET6_0_OR_GREATER - if (batchCommand is not null) return batchCommand.Parameters; + DbBatch batch => batch.BatchCommands[_index].Parameters, #endif - return AssertCommand.Parameters; + _ => Throw(), + }; + static DbParameterCollection Throw() => throw new InvalidOperationException("The command is not initialized"); } } /// public CommandType CommandType { - get + get => _source switch { + DbCommand cmd => cmd.CommandType, + List list => list[_index].CommandType, #if NET6_0_OR_GREATER - if (batchCommand is not null) return batchCommand.CommandType; + DbBatch batch => batch.BatchCommands[_index].CommandType, #endif - return AssertCommand.CommandType; - } + _ => CommandType.Text, + }; + [Obsolete("When possible, " + nameof(SetCommand) + " should be preferred", false)] set { -#if NET6_0_OR_GREATER - if (batchCommand is not null) + switch (_source) { - batchCommand.CommandType = value; - return; - } + case DbCommand cmd: + cmd.CommandType = value; + break; + case List list: + list[_index].CommandType = value; + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + batch.BatchCommands[_index].CommandType = value; + break; #endif - AssertCommand.CommandType = value; + } } } /// public int TimeoutSeconds { - get + get => _source switch { + DbCommand cmd => cmd.CommandTimeout, + List list => list[_index].CommandTimeout, #if NET6_0_OR_GREATER - if (batch is not null) return batch.Timeout; + DbBatch batch => batch.Timeout, #endif - return AssertCommand.CommandTimeout; - } + _ => 0, + }; set { -#if NET6_0_OR_GREATER - if (batch is not null) + switch (_source) { - batch.Timeout = value; - return; + case DbCommand cmd: + cmd.CommandTimeout = value; + break; + case List list: + list[_index].CommandTimeout = value; + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + batch.Timeout = value; + break; +#endif } + } + } + + /// + /// Create a parameter and add it to the parameter collection + /// + public DbParameter AddParameter() + { + var p = _commandFactory.CreateNewParameter(in this); + Parameters.Add(p); + return p; + } + + /// +#if DEBUG + [Obsolete("Prefer " + nameof(AddParameter))] #endif - AssertCommand.CommandTimeout = value; + public DbParameter CreateParameter() => _commandFactory.CreateNewParameter(in this); + + internal DbParameter DefaultCreateParameter() + { + switch (_source) + { + case DbCommand cmd: + return cmd.CreateParameter(); + case List list: + var activeCmd = list[_index]; + return activeCmd.CreateParameter(); +#if NET6_0_OR_GREATER + case DbBatch batch: + var bc = batch.BatchCommands[_index]; +#if NET8_0_OR_GREATER // https://github.com/dotnet/runtime/issues/82326 + if (bc.CanCreateParameter) return bc.CreateParameter(); +#endif // NET8 + return (_spareCommandForParameters ?? UnsafeBatchWithCommandForParameters()).CreateParameter(); +#endif // NET6 + default: + return Throw(); + } + static DbParameter Throw() => throw new InvalidOperationException("It was not possible to create a parameter for this command"); } - private DbCommand AssertCommand + internal UnifiedCommand(CommandFactory commandFactory, DbCommand command) { - get + _source = command; + _index = 0; +#if NET6_0_OR_GREATER + _spareCommandForParameters = null; +#endif + _commandFactory = commandFactory; + } + +#if NET6_0_OR_GREATER + + internal UnifiedCommand(CommandFactory commandFactory, DbBatch batch) + { + _source = batch; + _spareCommandForParameters = null; + + var bc = batch.BatchCommands; + if (bc.Count == 0) { - return dbCommand ?? Throw(); - static DbCommand Throw() => throw new InvalidOperationException($"The {nameof(UnifiedCommand)} is not associated with a valid command"); + // initialize the first command + bc.Add(commandFactory.CreateNewCommand(batch)); } + _index = 0; + _commandFactory = commandFactory; } - private readonly DbCommand? dbCommand; + private DbCommand UnsafeBatchWithCommandForParameters() + { + return _spareCommandForParameters + ?? (Connection is { } conn ? (Unsafe.AsRef(in _spareCommandForParameters) = _commandFactory.CreateNewCommand(conn)) : null) + ?? Throw(); + static DbCommand Throw() => throw new InvalidOperationException("It was not possible to create command parameters for this batch; the connection may be null"); + } +#endif - /// - public DbParameter CreateParameter() + internal int RecordsAffected { -#if NET8_0_OR_GREATER // https://github.com/dotnet/runtime/issues/82326 - if (batchCommand is { CanCreateParameter: true }) + get { - return batchCommand.CreateParameter(); +#if NET6_0_OR_GREATER + if (_source is DbBatch batch) + { + return batch.BatchCommands[_index].RecordsAffected; + } +#endif + return -1; } + } + + /// + /// Creates a new command and moves into that context + /// + internal DbParameterCollection AddCommand(string? commandText, CommandType commandType) + { + DbParameterCollection? result = null; + switch (_source) + { + case DbCommand cmd when cmd.Connection is not null: + if (Index != 0) ThrowNotLast(); + // swap for a list, then! + var newCmd = _commandFactory.CreateNewCommand(cmd.Connection); + newCmd.Connection = cmd.Connection; + newCmd.Transaction = cmd.Transaction; + if (commandText is not null) + { + newCmd.CommandText = commandText; + newCmd.CommandType = commandType; + } + Unsafe.AsRef(in _source) = new List { cmd, newCmd }; + result = newCmd.Parameters; + break; + case List list: + if (Index != list.Count - 1) ThrowNotLast(); + foreach (var item in list) + { + if (item.Connection is not null) + { + newCmd = _commandFactory.CreateNewCommand(item.Connection); + newCmd.Connection = item.Connection; + newCmd.Transaction = item.Transaction; + if (commandText is not null) + { + newCmd.CommandText = commandText; + newCmd.CommandType = commandType; + } + list.Add(newCmd); + result = newCmd.Parameters; + break; + } + } + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + if (Index != batch.BatchCommands.Count - 1) ThrowNotLast(); + var bc = _commandFactory.CreateNewCommand(batch); + if (commandText is not null) + { + bc.CommandText = commandText; + bc.CommandType = commandType; + } + batch.BatchCommands.Add(bc); + result = bc.Parameters; + break; #endif - return (dbCommand ?? UnsafeWithCommandForParameters()).CreateParameter(); + } + if (result is null) Throw(); + Unsafe.AsRef(in _index)++; + return result!; + + static void Throw() => throw new NotSupportedException("It was not possible to create a new command in this batch; the connection may be invalid"); + static void ThrowNotLast() => throw new InvalidOperationException(nameof(AddCommand) + " would overwrite existing commands; command may be incorrectly positioned"); } - internal UnifiedCommand(DbCommand command) + internal int Index => _index; + + internal void UnsafeMoveToFinal() => Unsafe.AsRef(in _index) = CommandCount - 1; + internal void UnsafeMoveTo(int index) => Unsafe.AsRef(in _index) = index; + internal void UnsafeAdvance() => Unsafe.AsRef(in _index)++; + + internal DbParameterCollection this[int index] { - dbCommand = command; + get + { + return _source switch + { + DbCommand cmd when index == 0 => cmd.Parameters, + List list => list[index].Parameters, #if NET6_0_OR_GREATER - batchCommand = null; - batch = null; + DbBatch batch => batch.BatchCommands[index].Parameters, #endif + _ => Throw() + }; + static DbParameterCollection Throw() => throw new IndexOutOfRangeException(); + } } + internal void Cleanup() + { +#if NET6_0_OR_GREATER + var spare = _spareCommandForParameters; +#endif + var source = _source; + Unsafe.AsRef(in this) = default; // best efforts to prevent double-stomp #if NET6_0_OR_GREATER - private readonly DbBatch? batch; - private readonly DbBatchCommand? batchCommand; + spare?.Dispose(); +#endif - /// - /// The associated with the current operation - this may be null for non-batch operations. - /// - public DbBatchCommand? BatchCommand => batchCommand; + switch (source) + { + case DbCommand cmd: + cmd.Dispose(); + break; + case List list: + foreach (var cmd in list) + { + cmd.Dispose(); + } + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + batch.Dispose(); + break; +#endif + } + } - internal UnifiedCommand(DbBatch batch) + internal void Trim() // removes all *excess* commands { - this.batch = batch; - batchCommand = null; - dbCommand = null; + switch (_source) + { + case List list: + int remove = list.Count - (_index + 1); + if (remove > 0) + { + list.RemoveRange(_index + 1, remove); + } + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + var bc = batch.BatchCommands; + for (int i = bc.Count - 1; i > _index; i--) + { + bc.RemoveAt(i); + } + break; +#endif + } } - internal DbBatchCommandCollection AssertBatchCommands => AssertBatch.BatchCommands; - internal DbBatch AssertBatch + internal void Prepare() { - get + switch (_source) { - return batch ?? Throw(); - static DbBatch Throw() => throw new InvalidOperationException($"The {nameof(UnifiedCommand)} is not associated with a valid command-batch"); + case DbCommand cmd: + cmd.Prepare(); + break; + case List list: + foreach (var cmd in list) + { + cmd.Prepare(); + } + break; +#if NET6_0_OR_GREATER + case DbBatch batch: + batch.Prepare(); + break; +#endif } } - internal void PostProcess(IEnumerable source, CommandFactory commandFactory) + internal Task PrepareAsync(CancellationToken cancellationToken) { - int i = 0; - var commands = AssertBatchCommands; - foreach (var arg in source) + return _source switch { - var cmd = commands[i++]; - UnsafeSetBatchCommand(cmd); - commandFactory.PostProcess(in this, arg, cmd.RecordsAffected); +#if NETCOREAPP3_0_OR_GREATER + DbCommand cmd => cmd.PrepareAsync(cancellationToken), +#else + DbCommand cmd => PrepareSingleAsync(cmd), +#endif + List list => PrepareAsync(list, cancellationToken), +#if NET6_0_OR_GREATER + DbBatch batch => batch.PrepareAsync(cancellationToken), +#endif + _ => Task.CompletedTask, + }; + +#if NETCOREAPP3_0_OR_GREATER + + static async Task PrepareAsync(List list, CancellationToken cancellationToken) + { + foreach (var cmd in list) + await cmd.PrepareAsync(cancellationToken); } - UnsafeSetBatchCommand(null); +#else +// best we can do without the missing API + static Task PrepareSingleAsync(DbCommand cmd) + { + cmd.Prepare(); + return Task.CompletedTask; + } + static Task PrepareAsync(List list, CancellationToken _) + { + foreach (var cmd in list) + cmd.Prepare(); + return Task.CompletedTask; + } +#endif + + + } + + void AssertFinal([CallerMemberName]string caller = "") + { + if (Index != CommandCount - 1) Throw(caller); + static void Throw(string caller) => throw new InvalidOperationException(caller + " can only be invoked when in the last command position"); } - internal void PostProcess(ReadOnlySpan source, CommandFactory commandFactory) + internal int ExecuteNonQuery() { - int i = 0; - var commands = AssertBatchCommands; - foreach (var arg in source) + AssertFinal(); + switch (_source) { - var cmd = commands[i++]; - UnsafeSetBatchCommand(cmd); - commandFactory.PostProcess(in this, arg, cmd.RecordsAffected); + case DbCommand cmd: + return cmd.ExecuteNonQuery(); + case List list: + int sum = 0; + foreach (var cmd in list) + { + sum += cmd.ExecuteNonQuery(); + } + return sum; +#if NET6_0_OR_GREATER + case DbBatch batch: + return batch.ExecuteNonQuery(); +#endif + default: + return 0; } - UnsafeSetBatchCommand(null); } - internal bool HasBatch => batch is not null; - - internal DbBatchCommand UnsafeCreateNewCommand() => Unsafe.AsRef(in batchCommand) = AssertBatch.CreateBatchCommand(); + internal DbDataReader ExecuteReader(CommandBehavior flags) + { + AssertFinal(); + return _source switch + { + DbCommand cmd => cmd.ExecuteReader(flags), +#if NET6_0_OR_GREATER + DbBatch batch => batch.ExecuteReader(flags), +#endif + null => throw new InvalidOperationException(), + _ => throw new NotImplementedException($"ExecuteReader for {_source.GetType().Name} is not yet implemented; poke Marc"), + }; + } - internal void UnsafeSetBatchCommand(DbBatchCommand? value) => Unsafe.AsRef(in batchCommand) = value; + internal Task ExecuteNonQueryAsync(CancellationToken cancellationToken) + { + AssertFinal(); + return _source switch + { + DbCommand cmd => cmd.ExecuteNonQueryAsync(cancellationToken), + List list => ExecuteListAsync(list, cancellationToken), +#if NET6_0_OR_GREATER + DbBatch batch => batch.ExecuteNonQueryAsync(cancellationToken), #endif + _ => TaskZero, + }; - private DbCommand UnsafeWithCommandForParameters() + static async Task ExecuteListAsync(List list, CancellationToken cancellationToken) + { + int sum = 0; + foreach (var cmd in list) + { + sum += await cmd.ExecuteNonQueryAsync(cancellationToken); + } + return sum; + } + } + + /// + /// Initialize the and + /// + public void SetCommand(string commandText, CommandType commandType = CommandType.Text) { - return dbCommand + switch (_source) + { + // note we're trying to avoid triggering any unnecessary side-effects and + // cache-invalidations that could be triggered from setters + case DbCommand cmd: + if (cmd.CommandText != commandText) cmd.CommandText = commandText; + if (cmd.CommandType != commandType) cmd.CommandType = commandType; + break; + case List list: + var activeCmd = list[_index]; + if (activeCmd.CommandText != commandText) activeCmd.CommandText = commandText; + if (activeCmd.CommandType != commandType) activeCmd.CommandType = commandType; + break; #if NET6_0_OR_GREATER - ?? (Unsafe.AsRef(in dbCommand) = batch?.Connection?.CreateCommand()) + case DbBatch batch: + var bc = batch.BatchCommands[_index]; + if (bc.CommandText != commandText) bc.CommandText = commandText; + if (bc.CommandType != commandType) bc.CommandType = commandType; + break; #endif - ?? Throw(); - static DbCommand Throw() => throw new InvalidOperationException("It was not possible to create command parameters for this batch; the connection may be null"); + } } - internal void Cleanup() + internal void TryRecycle() { - dbCommand?.Dispose(); + if (_source switch + { + // note we're trying to avoid triggering any unnecessary side-effects and + // cache-invalidations that could be triggered from setters + DbCommand cmd => _commandFactory.TryRecycle(cmd), + // note we don't expect to recycle list usage in this way; we're only expecting + // single-arg scenarios #if NET6_0_OR_GREATER - batch?.Dispose(); + DbBatch batch => _commandFactory.TryRecycle(batch), #endif + _ => false, + }) + { + // wipe the source - someone else can see it + Unsafe.AsRef(in _source) = null!; + } } + + private static readonly Task TaskZero = Task.FromResult(0); } \ No newline at end of file diff --git a/src/Dapper.AOT/WrappedDbDataReader.cs b/src/Dapper.AOT/WrappedDbDataReader.cs index 31854dc8..140db587 100644 --- a/src/Dapper.AOT/WrappedDbDataReader.cs +++ b/src/Dapper.AOT/WrappedDbDataReader.cs @@ -47,9 +47,10 @@ public sealed override void Close() { if (IsClosed) return; var state = this.state; // snapshot + this.state = null; if (state is not null) { - commandFactory.PostProcessObject(new(state.Command!), args, state.Reader.CloseAndCapture()); + commandFactory.PostProcessObject(new(RowFactory.NotRequired, state.Command!), args, state.Reader.CloseAndCapture()); if (commandFactory.TryRecycle(state.Command!)) { state.Command = null; @@ -229,7 +230,7 @@ private async Task CloseAsyncImpl() var state = this.state; // snapshot if (state is not null) { - commandFactory.PostProcessObject(new(state.Command!), args, await state.Reader.CloseAndCaptureAsync()); + commandFactory.PostProcessObject(new(RowFactory.NotRequired, state.Command!), args, await state.Reader.CloseAndCaptureAsync()); if (commandFactory.TryRecycle(state.Command!)) { state.Command = null; diff --git a/test/Dapper.AOT.Test/GeneralSqlParseTests.cs b/test/Dapper.AOT.Test/GeneralSqlParseTests.cs new file mode 100644 index 00000000..ee080ba0 --- /dev/null +++ b/test/Dapper.AOT.Test/GeneralSqlParseTests.cs @@ -0,0 +1,120 @@ + +using Dapper.Internal.SqlParsing; +using Xunit; +using static global::Dapper.SqlAnalysis.SqlSyntax; +namespace Dapper.AOT.Test; + +public class GeneralSqlParseTests +{ + [Fact] + public void BatchifyNonStrippedPostgresql() => Assert.Equal( + [ + new("something;"), + new(" something else;"), + new(""" + -- comment + ; + """), + new("more;"), + ], GeneralSqlParser.Parse(""" + something; + something else; + -- comment + ; + ; + more + """, PostgreSql, strip: false)); + + [Fact] + public void BatchifyStrippedPostgresql() => Assert.Equal( + [ + new("something;"), + new("something else;"), + new("more;"), + ], GeneralSqlParser.Parse(""" + something; + something else; + -- comment + ; + ; + more + """, PostgreSql, strip: true)); + + [Fact] + public void BatchifyNonStrippedSqlServer() => Assert.Equal( + [ + new(""" + something; + something else; + -- comment + ; + ; + more + """), + ], GeneralSqlParser.Parse(""" + something; + something else; + -- comment + ; + ; + more + """, SqlServer, strip: false)); + + [Fact] + public void BatchifyStrippedSqlServer() => Assert.Equal( + [ + new(""" + something;something else;more + """), + ], GeneralSqlParser.Parse(""" + something; + something else; + -- comment + ; + ; + more + """, SqlServer, strip: true)); + + [Fact] + public void BatchifyNonStrippedSqlServer_Go() => Assert.Equal( + [ + new("something"), + new("something ' GO ' else;"), + ], GeneralSqlParser.Parse(""" + something + GO + something ' GO ' else; + """, SqlServer, strip: true)); + + [Fact] + public void DetectArgs() => Assert.Equal( + [ + new (""" + select * from SomeTable where Id = @foo and Name = '@bar'; + """, new("@foo", 35)), + ], GeneralSqlParser.Parse(""" + select * from SomeTable + where Id = @foo and Name = '@bar' + """, PostgreSql, strip: true)); + + [Fact] + public void DetectArgsAndBatchify() => Assert.Equal( + [ + new("select * from SomeTable where Id = @foo and Name = '@bar';", new("@foo", 35)), + new("insert Bar (Id, X) values ($1, @@IDENTITY);", new("$1", 27)), + new("insert Blap (Id) values ($1, @foo);", new("$1", 25), new("@foo", 29)), + ], GeneralSqlParser.Parse(""" + select * from SomeTable where Id = @foo and Name = '@bar' -- $4 + ; + insert Bar (Id, X) /* @abc */ values ($1, @@IDENTITY); + insert Blap (Id) values ($1, @foo) + """, PostgreSql, strip: true)); + + [Fact] + public void StringEscapingSqlServer() => Assert.Equal( + [ + new("select ' @a '' @b ' as [ @c ]] @d ];") // no vars + ], GeneralSqlParser.Parse(""" + select ' @a '' @b ' as [ @c ]] @d ]; + """, SqlServer, strip: true)); +} diff --git a/test/Dapper.AOT.Test/Integration/BatchPostgresql.cs b/test/Dapper.AOT.Test/Integration/BatchPostgresql.cs index 55fb266c..f48ea28c 100644 --- a/test/Dapper.AOT.Test/Integration/BatchPostgresql.cs +++ b/test/Dapper.AOT.Test/Integration/BatchPostgresql.cs @@ -7,7 +7,7 @@ namespace Dapper.AOT.Test.Integration; [Collection(SharedPostgresqlClient.Collection)] public class BatchPostgresql { - private PostgresqlFixture _fixture; + private readonly PostgresqlFixture _fixture; public BatchPostgresql(PostgresqlFixture fixture) { @@ -45,21 +45,18 @@ private class CustomHandler : CommandFactory<(int x, string y)> public override void AddParameters(in UnifiedCommand command, (int x, string y) args) { - var ps = command.Parameters; - var p = command.CreateParameter(); + var p = command.AddParameter(); p.DbType = DbType.Int32; p.Direction = ParameterDirection.Input; p.Value = AsValue(args.x); p.ParameterName = "x"; - ps.Add(p); - p = command.CreateParameter(); + p = command.AddParameter(); p.DbType = DbType.AnsiString; p.Size = 40; p.Direction = ParameterDirection.Input; p.Value = AsValue(args.y); p.ParameterName = "y"; - ps.Add(p); } public override void UpdateParameters(in UnifiedCommand command, (int x, string y) args) diff --git a/test/Dapper.AOT.Test/Integration/BatchTests.cs b/test/Dapper.AOT.Test/Integration/BatchTests.cs index d154bdce..018e2d80 100644 --- a/test/Dapper.AOT.Test/Integration/BatchTests.cs +++ b/test/Dapper.AOT.Test/Integration/BatchTests.cs @@ -135,12 +135,11 @@ internal class CustomHandler : CommandFactory public override void AddParameters(in UnifiedCommand command, string name) { - var p = command.CreateParameter(); + var p = command.AddParameter(); p.ParameterName = "name"; p.DbType = DbType.String; p.Size = 400; p.Value = AsValue(name); - command.Parameters.Add(p); } public override void UpdateParameters(in UnifiedCommand command, string name) diff --git a/test/Dapper.AOT.Test/Integration/NpgsqlRewriteManual.cs b/test/Dapper.AOT.Test/Integration/NpgsqlRewriteManual.cs new file mode 100644 index 00000000..c7a5468e --- /dev/null +++ b/test/Dapper.AOT.Test/Integration/NpgsqlRewriteManual.cs @@ -0,0 +1,197 @@ +#if NET6_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using Xunit; + +namespace Dapper.AOT.Test.Integration; + +[Collection(SharedPostgresqlClient.Collection)] +public class NpgsqlRewriteManual +{ + private readonly PostgresqlFixture _fixture; + + public NpgsqlRewriteManual(PostgresqlFixture fixture) + { + _fixture = fixture; + fixture.NpgsqlConnection.Execute(""" + CREATE TABLE IF NOT EXISTS rewrite_test( + id integer PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + name varchar(40) NOT NULL CHECK (name <> '') + ); + TRUNCATE rewrite_test; + """); + } + + const string CompositeQuery = """ + TRUNCATE rewrite_test RESTART IDENTITY; + + INSERT INTO rewrite_test(name) + VALUES (@x); + + INSERT INTO rewrite_test(name) + VALUES (@x || @y); + + INSERT INTO rewrite_test(name) + VALUES (@y); + + select id, name + from rewrite_test + order by id + """; + private static readonly object CompositeArgs = new { x = "abc", y = "def" }; + + private static void AssertResults(IEnumerable results) + { + var list = results.AsList(); + Assert.Equal(3, list.Count); + Assert.Equal(1, list[0].Id); + Assert.Equal("abc", list[0].Name); + Assert.Equal(2, list[1].Id); + Assert.Equal("abcdef", list[1].Name); + Assert.Equal(3, list[2].Id); + Assert.Equal("def", list[2].Name); + } + public record struct RewriteTestRow(int Id, string Name); + + [Fact] + public void DapperVanillaConsistency() + => AssertResults(_fixture.NpgsqlConnection.Query(CompositeQuery, CompositeArgs)); + + [Fact] + public void ManuallyImplementedBasic() + => AssertResults(_fixture.NpgsqlConnection + .Command(CompositeQuery, CommandType.Text, handler: BasicCommandFactory.Instance) + .Query(CompositeArgs, true, rowFactory: BasicRowFactory.Instance)); + + [Fact] + public void ManuallyImplementedFancy() + => AssertResults(_fixture.NpgsqlConnection + .Command(CompositeQuery, CommandType.Text, handler: FancyCommandFactory.Instance) + .Query(CompositeArgs, true, rowFactory: BasicRowFactory.Instance)); + + sealed class BasicCommandFactory : CommandFactory + { + public static BasicCommandFactory Instance { get; } = new(); + private BasicCommandFactory() { } + + public override void AddParameters(in UnifiedCommand command, object args) + { + var typed = Cast(args, static () => new { x = "abc", y = "def" }); + var ps = command.Parameters; + DbParameter p = command.AddParameter(); + p.ParameterName = "x"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(typed.x); + + p = command.AddParameter(); + p.ParameterName = "y"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(typed.y); + } + + public override void UpdateParameters(in UnifiedCommand command, object args) + { + var typed = Cast(args, static () => new { x = "abc", y = "def" }); + var ps = command.Parameters; + ps[0].Value = AsValue(typed.x); + ps[1].Value = AsValue(typed.y); + } + } + + sealed class BasicRowFactory : RowFactory + { + public static BasicRowFactory Instance { get; } = new(); + + // note we're ignoring tokenize etc for simplicity; just do raw + public override RewriteTestRow Read(DbDataReader reader, ReadOnlySpan tokens, int columnOffset, object? state) + => new(reader.GetInt32(columnOffset), reader.GetString(columnOffset + 1)); + } + + sealed class FancyCommandFactory : CommandFactory + { + // trying to prove the feature for https://github.com/DapperLib/DapperAOT/issues/78 + public static FancyCommandFactory Instance { get; } = new(); + private FancyCommandFactory() { } + + // assert that the SQL is what we expected; we do *not* want to parse and split + // SQL at runtime, so we only do this if we already figured out how to do it + // (otherwise, we'll defer to the underlying ADO.NET provider) + public override bool UseBatch(string sql) => sql == CompositeQuery; + + public override void AddCommands(in UnifiedBatch batch, string sql, object args) + { + var typed = Cast(args, static () => new { x = "abc", y = "def" }); + + batch.AddCommand(""" + TRUNCATE rewrite_test RESTART IDENTITY + """); + + // note: not allowed to reuse parameters between commands; throws if you try + + batch.AddCommand(""" + INSERT INTO rewrite_test(name) + VALUES ($1) + """); + + var p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(typed.x); + + batch.AddCommand(""" + INSERT INTO rewrite_test(name) + VALUES ($1 || $2) + """); + + p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(typed.x); + + p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(typed.y); + + batch.AddCommand(""" + INSERT INTO rewrite_test(name) + VALUES ($1) + """); + + p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(typed.y); + + batch.AddCommand(""" + select id, name + from rewrite_test + order by id + """); + } + + public override void PostProcess(in UnifiedBatch batch, object args, int rowCount) + { + // example usage + // args.Something = Cast(commands[commandIndex].Parameters[1].Value); + } + + public override bool RequirePostProcess => true; + + public override void PostProcess(in UnifiedCommand command, object args, int rowCount) + => throw new NotImplementedException("we don't expect to get here!"); + + public override void AddParameters(in UnifiedCommand command, object args) + => throw new NotImplementedException("we don't expect to get here!"); + + public override void UpdateParameters(in UnifiedCommand command, object args) + => throw new NotImplementedException("we don't expect to get here!"); + } +} + + +#endif \ No newline at end of file diff --git a/test/Dapper.AOT.Test/NpgsqlParseTests.cs b/test/Dapper.AOT.Test/NpgsqlParseTests.cs new file mode 100644 index 00000000..336de7c0 --- /dev/null +++ b/test/Dapper.AOT.Test/NpgsqlParseTests.cs @@ -0,0 +1,231 @@ +using Dapper.Internal.SqlParsing; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using Xunit; + +namespace Dapper.AOT.Test; +// inspired, with love, from https://github.com/npgsql/npgsql/blob/main/test/Npgsql.Tests/SqlQueryParserTests.cs + +public class NpgsqlParseTests +{ + [Fact] + public void Parameter_simple() + { + var result = ParseCommand("SELECT :p1, :p2", [(":p1", "foo"), (":p2", "bar")]).Single(); + Assert.Equal("SELECT $1, $2", result.FinalCommandText); + Assert.Equal(Args(("$1", "foo"), ("$2", "bar")), result.Parameters); + } + + static (string name, object value)[] Args(params (string name, object value)[] args) => args; + + [Fact] + public void Parameter_name_with_dot() + { + var p = (":a.parameter", "foo"); + var results = ParseCommand("INSERT INTO data (field_char5) VALUES (:a.parameter)", p); + var result = Assert.Single(results); + Assert.Equal("INSERT INTO data (field_char5) VALUES ($1)", result.FinalCommandText); + Assert.Equal(Args(("$1", "foo")), result.Parameters); + } + + [Theory] + [InlineData(@"SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat')")] + [InlineData(@"SELECT 'cat'::tsquery @> 'cat & rat'::tsquery")] + [InlineData(@"SELECT 'cat'::tsquery <@ 'cat & rat'::tsquery")] + [InlineData(@"SELECT 'b''la'")] + [InlineData(@"SELECT 'type(''m.response'')#''O''%'")] + [InlineData(@"SELECT 'abc'':str''a:str'")] + [InlineData(@"SELECT 1 FROM "":str""")] + [InlineData(@"SELECT 1 FROM 'yo'::str")] + [InlineData("SELECT $\u00ffabc0$literal string :str :int$\u00ffabc0 $\u00ffabc0$")] + [InlineData("SELECT $$:str$$")] + public void Untouched(string sql) + { + var results = ParseCommand(sql, (":param", "foo")); + var result = Assert.Single(results); + Assert.Equal(sql, result.FinalCommandText); + Assert.Empty(result.Parameters); + } + + [Theory] + [InlineData(@"SELECT 1<:param")] + [InlineData(@"SELECT 1>:param")] + [InlineData(@"SELECT 1<>:param")] + [InlineData("SELECT--comment\r:param")] + public void Parameter_gets_bound(string sql) + { + var p = (":param", "foo"); + var results = ParseCommand(sql, p); + var result = Assert.Single(results); + Assert.Equal(Args(("$1", "foo")), result.Parameters); + } + + [Fact] //, IssueLink("https://github.com/npgsql/npgsql/issues/1177")] + public void Parameter_gets_bound_non_ascii() + { + var p = ("漢字", "foo"); + var results = Assert.Single(ParseCommand("SELECT @漢字", p)); + Assert.Equal("SELECT $1", results.FinalCommandText); + Assert.Equal(Args(("$1", "foo")), results.Parameters); + } + + [Theory] + [InlineData(@"SELECT e'ab\'c:param'")] + [InlineData(@"SELECT/*/* -- nested comment :int /*/* *//*/ **/*/*/*/1")] + [InlineData(@"SELECT 1, + -- Comment, @param and also :param + 2")] + public void Parameter_does_not_get_bound(string sql) + { + var p = (":param", "foo"); + var results = Assert.Single(ParseCommand(sql, p)); + Assert.Equal(sql, results.FinalCommandText); + Assert.Empty(results.Parameters); + } + + [Fact] + public void Non_conforming_string() + { + var result = ParseCommand(@"SELECT 'abc\':str''a:str'").Single(); + Assert.Equal(@"SELECT 'abc\':str''a:str'", result.FinalCommandText); + Assert.Empty(result.Parameters); + } + + [Fact] + public void Multiquery_with_parameters() + { + var parameters = Args(("p1", "abc"), ("p2", "def"), ("p3", "ghi")); + + var results = ParseCommand("SELECT @p3, @p1; SELECT @p2, @p3", parameters); + + Assert.Equal(2, results.Count); + + var result = results[0]; + Assert.Equal("SELECT $1, $2", result.FinalCommandText); + Assert.Equal(Args(("$1", "ghi"), ("$2", "abc")), result.Parameters); + + result = results[1]; + Assert.Equal("SELECT $1, $2", result.FinalCommandText); + Assert.Equal(Args(("$1", "def"), ("$2", "ghi")), result.Parameters); + } + + // N/A in this context of Dapper, although we should investigate further + //[Fact] + //public void No_output_parameters() + //{ + // var p = new NpgsqlParameter("p", DbType.String) { Direction = ParameterDirection.Output }; + // Assert.That(() => ParseCommand("SELECT @p", p), Throws.Exception); + //} + + [Fact] + public void Missing_parameter_is_ignored() + { + var results = ParseCommand("SELECT @p; SELECT 1"); + Assert.Equal(2, results.Count); + Assert.Equal("SELECT @p", results[0].FinalCommandText); + Assert.Equal("SELECT 1", results[1].FinalCommandText); + Assert.Empty(results[0].Parameters); + Assert.Empty(results[1].Parameters); + } + + [Fact] + public void Consecutive_semicolons() + { + var results = ParseCommand(";;SELECT 1"); + + Assert.Equal("SELECT 1", Assert.Single(results).FinalCommandText); + + // // Npgsql behaviour, discussed with roji - above is fine + // Assert.Equal(3, results.Count); + // Assert.Equal("", results[0].FinalCommandText); + // Assert.Equal("", results[1].FinalCommandText); + // Assert.Equal("SELECT 1", results[2].FinalCommandText); + } + + [Fact] + public void Trailing_semicolon() + { + var results = Assert.Single(ParseCommand("SELECT 1;")); + Assert.Equal("SELECT 1", results.FinalCommandText); + } + + [Fact] + public void Empty() + { + var results = ParseCommand(""); + Assert.Equal("", Assert.Single(results).FinalCommandText); + } + + [Fact] + public void Semicolon_in_parentheses() + { + var results = Assert.Single(ParseCommand("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1; SELECT 1)")); + Assert.Equal("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1; SELECT 1)", results.FinalCommandText); + } + + [Fact] + public void Semicolon_after_parentheses() + { + var results = ParseCommand("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1); SELECT 1"); + Assert.Equal(2, results.Count); + Assert.Equal("CREATE OR REPLACE RULE test AS ON UPDATE TO test DO (SELECT 1)", results[0].FinalCommandText); + Assert.Equal("SELECT 1", results[1].FinalCommandText); + } + + [Fact] + public void Reduce_number_of_statements() + { + Assert.Equal(2, ParseCommand("SELECT 1; SELECT 2").Count); + Assert.Single(ParseCommand("SELECT 1")); + } + + [Fact] + public void Trim_whitespace() + { + var result = Assert.Single(ParseCommand(" SELECT 1\t", strip: true)); + Assert.Equal("SELECT 1", result.FinalCommandText); + Assert.Empty(result.Parameters); + } + + #region Setup / Teardown / Utils + + public class ParseResult + { + public string FinalCommandText { get; set; } = ""; + public (string name, object value)[] Parameters { get; set; } = []; + internal OrdinalResult Result { get; set; } + } + + List ParseCommand(string sql, params (string name, object value)[] parameters) + => ParseCommand(sql, false, parameters); + List ParseCommand(string sql, bool strip, params (string name, object value)[] parameters) + { + var parsed = GeneralSqlParser.Parse(sql, SqlAnalysis.SqlSyntax.PostgreSql, strip); + if (parsed.Count == 0) + { + return new List + { + new ParseResult + { + FinalCommandText = "", + Parameters = [], + Result = sql == "" ? OrdinalResult.NoChange : OrdinalResult.Success, + } + }; + } + return parsed.ConvertAll(cmd => + { + var result = cmd.TryMakeOrdinal(parameters, p => p.name, + (p, i) => (CommandBatch.OrdinalNaming(p.name, i), p.value), out var args, out var sql); + return new ParseResult + { + FinalCommandText = sql, + Parameters = args.ToArray(), + Result = result + }; + }); + } + + #endregion +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Verifiers/DAP245.cs b/test/Dapper.AOT.Test/Verifiers/DAP245.cs new file mode 100644 index 00000000..d5ec9a07 --- /dev/null +++ b/test/Dapper.AOT.Test/Verifiers/DAP245.cs @@ -0,0 +1,25 @@ +using Dapper.CodeAnalysis; +using System.Threading.Tasks; +using Xunit; +using Diagnostics = Dapper.CodeAnalysis.DapperAnalyzer.Diagnostics; +namespace Dapper.AOT.Test.Verifiers; + +public class DAP245 : Verifier +{ + [Fact] + public Task UseGoWithoutDelimiter() => SqlVerifyAsync(""" + INSERT {|#0:GO|} ({|#1:GO|}) VALUES (42); + SELECT {|#2:GO|} FROM {|#3:GO|}; + """, + Diagnostic(Diagnostics.DangerousNonDelimitedIdentifier).WithLocation(0).WithArguments("GO"), + Diagnostic(Diagnostics.DangerousNonDelimitedIdentifier).WithLocation(1).WithArguments("GO"), + Diagnostic(Diagnostics.DangerousNonDelimitedIdentifier).WithLocation(2).WithArguments("GO"), + Diagnostic(Diagnostics.DangerousNonDelimitedIdentifier).WithLocation(3).WithArguments("GO")); + + [Fact] + public Task UseGoWithDelimited() => SqlVerifyAsync(""" + INSERT [GO] ([GO]) VALUES (42); + SELECT [GO] FROM [GO]; + """); + +} \ No newline at end of file diff --git a/test/UsageBenchmark/BatchInsertBenchmarks.cs b/test/UsageBenchmark/BatchInsertBenchmarks.cs index a0db9166..fa18f3ec 100644 --- a/test/UsageBenchmark/BatchInsertBenchmarks.cs +++ b/test/UsageBenchmark/BatchInsertBenchmarks.cs @@ -341,12 +341,11 @@ public override DbCommand GetCommand(DbConnection connection, string sql, Comman public override void AddParameters(in UnifiedCommand command, Customer obj) { - var p = command.CreateParameter(); + var p = command.AddParameter(); p.ParameterName = "name"; p.DbType = DbType.String; p.Size = 400; p.Value = AsValue(obj.Name); - command.Parameters.Add(p); } public override void UpdateParameters(in UnifiedCommand command, Customer obj) diff --git a/test/UsageBenchmark/CommandRewriteBenchmarks.cs b/test/UsageBenchmark/CommandRewriteBenchmarks.cs new file mode 100644 index 00000000..c5aa9bc6 --- /dev/null +++ b/test/UsageBenchmark/CommandRewriteBenchmarks.cs @@ -0,0 +1,396 @@ +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Configs; +using Npgsql; +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Testcontainers.PostgreSql; + +namespace Dapper; + +[ShortRunJob, MemoryDiagnoser, GroupBenchmarksBy(BenchmarkLogicalGroupRule.ByCategory), CategoriesColumn] +public class CommandRewriteBenchmarks : IAsyncDisposable +{ + private readonly NpgsqlConnection npgsql = new(); + + private readonly PostgreSqlContainer _postgresContainer = new PostgreSqlBuilder() + .WithImage("postgres:15-alpine") + .Build(); + + + public CommandRewriteBenchmarks() + { + _postgresContainer.StartAsync().GetAwaiter().GetResult(); // yes, I know + npgsql.ConnectionString = _postgresContainer.GetConnectionString(); + npgsql.Open(); + npgsql.Execute(""" + CREATE TABLE IF NOT EXISTS RewriteCustomers( + Id integer GENERATED ALWAYS AS IDENTITY, + Name varchar(40) NOT NULL + ); + """); + } + + [GlobalSetup] + public void Setup() + { + npgsql.Execute("TRUNCATE RewriteCustomers RESTART IDENTITY;"); + } + + public class MyArgsType + { + public string? Name0 { get; set; } + public string? Name1 { get; set; } + public string? Name2 { get; set; } + public string? Name3 { get; set; } + } + + public static readonly MyArgsType Args = new() { Name0 = "abc", Name1 = "def", Name2 = "ghi", Name3 = "jkl" }; + + // for our test, we're going to do 4 operations and try rewriting it as batch + public const string BasicSql = """ + insert into RewriteCustomers (Name) values (@Name0); + insert into RewriteCustomers (Name) values (@Name1); + insert into RewriteCustomers (Name) values (@Name2); + insert into RewriteCustomers (Name) values (@Name3); + """; + + [Benchmark(Baseline = true)] + public int Dapper() => npgsql.Execute(BasicSql, Args); + + // note these are hand written + + [Benchmark] + public int DapperAOT() => npgsql.Command(BasicSql, handler: BasicCommand.NonCached).Execute(Args); + + [Benchmark] + public int DapperAOT_Cached() => npgsql.Command(BasicSql, handler: BasicCommand.Cached).Execute(Args); + + [Benchmark] + public int DapperAOT_Batch() => npgsql.Command(BasicSql, handler: RewriteCommand.NonCached).Execute(Args); + + [Benchmark] + public int DapperAOT_BatchCached() => npgsql.Command(BasicSql, handler: RewriteCommand.Cached).Execute(Args); + + + [Benchmark] + public Task DapperAOT_Async() => npgsql.Command(BasicSql, handler: BasicCommand.NonCached).ExecuteAsync(Args); + + [Benchmark] + public Task DapperAOT_CachedAsync() => npgsql.Command(BasicSql, handler: BasicCommand.Cached).ExecuteAsync(Args); + + [Benchmark] + public Task DapperAOT_BatchAsync() => npgsql.Command(BasicSql, handler: RewriteCommand.NonCached).ExecuteAsync(Args); + + [Benchmark] + public Task DapperAOT_BatchCachedAsync() => npgsql.Command(BasicSql, handler: RewriteCommand.Cached).ExecuteAsync(Args); + + static DbCommand CreateCommand(DbConnection connection) + { + // note that connection.CreateCommand has a stash for command recycling, + // that isn't available if you "new" etc + var cmd = connection.CreateCommand(); + cmd.Connection = connection; + cmd.CommandText = BasicSql; + cmd.CommandType = CommandType.Text; + + var ps = cmd.Parameters; + var p = cmd.CreateParameter(); + p.ParameterName = "Name0"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name0; + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "Name1"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name1; + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "Name2"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name2; + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "Name3"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name3; + ps.Add(p); + + return cmd; + } + + static void UpdateCommand(DbCommand cmd, DbConnection connection) + { + cmd.Connection = connection; + var ps = cmd.Parameters; + ps[0].Value = Args.Name0; + ps[1].Value = Args.Name1; + ps[2].Value = Args.Name2; + ps[3].Value = Args.Name3; + } + + [Benchmark] + public int AdoNetCommand() + { + using var cmd = CreateCommand(npgsql); + return cmd.ExecuteNonQuery(); + } + + [Benchmark] + public async Task AdoNetCommandAsync() + { + using var cmd = CreateCommand(npgsql); + return await cmd.ExecuteNonQueryAsync(); + } + + static DbCommand? _spareCommand; + + private DbCommand GetCached() + { + var cmd = Interlocked.Exchange(ref _spareCommand, null); + if (cmd is not null) + { + UpdateCommand(cmd, npgsql); + return cmd; + } + return CreateCommand(npgsql); + } + private void PutCommand(DbCommand command) + { + command.Connection = null; + Interlocked.Exchange(ref _spareCommand, command)?.Dispose(); + } + + [Benchmark] + public int AdoNetCommandCached() + { + var cmd = GetCached(); + var result = cmd.ExecuteNonQuery(); + PutCommand(cmd); + return result; + } + + [Benchmark] + public async Task AdoNetCommandCachedAsync() + { + var cmd = GetCached(); + var result = await cmd.ExecuteNonQueryAsync(); + PutCommand(cmd); + return result; + } + + private static NpgsqlBatch CreateBatch(NpgsqlConnection connection) + { + // note that CreateBatch has obj reuse that new() lacks + var batch = connection.CanCreateBatch ? connection.CreateBatch() : new NpgsqlBatch(); + batch.Connection = connection; + var commands = batch.BatchCommands; + + var cmd = batch.CreateBatchCommand(); + cmd.CommandText = "insert into RewriteCustomers(Name) values($1)"; + cmd.CommandType = CommandType.Text; + var p = cmd.CanCreateParameter ? cmd.CreateParameter() : new NpgsqlParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name0; + cmd.Parameters.Add(p); + commands.Add(cmd); + + cmd = batch.CreateBatchCommand(); + cmd.CommandText = "insert into RewriteCustomers(Name) values($1)"; + cmd.CommandType = CommandType.Text; + p = cmd.CanCreateParameter ? cmd.CreateParameter() : new NpgsqlParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name1; + cmd.Parameters.Add(p); + commands.Add(cmd); + + cmd = batch.CreateBatchCommand(); + cmd.CommandText = "insert into RewriteCustomers(Name) values($1)"; + cmd.CommandType = CommandType.Text; + p = cmd.CanCreateParameter ? cmd.CreateParameter() : new NpgsqlParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name2; + cmd.Parameters.Add(p); + commands.Add(cmd); + + cmd = batch.CreateBatchCommand(); + cmd.CommandText = "insert into RewriteCustomers(Name) values($1)"; + cmd.CommandType = CommandType.Text; + p = cmd.CanCreateParameter ? cmd.CreateParameter() : new NpgsqlParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = Args.Name3; + cmd.Parameters.Add(p); + commands.Add(cmd); + + return batch; + } + + static void UpdateBatch(NpgsqlBatch batch, NpgsqlConnection connection) + { + batch.Connection = connection; + var commands = batch.BatchCommands; + commands[0].Parameters[0].Value = Args.Name0; + commands[1].Parameters[0].Value = Args.Name1; + commands[2].Parameters[0].Value = Args.Name2; + commands[3].Parameters[0].Value = Args.Name3; + } + + static NpgsqlBatch? _spareBatch; + + [Benchmark] + public int AdoNetBatch() + { + using var batch = CreateBatch(npgsql); + return batch.ExecuteNonQuery(); + } + + [Benchmark] + public int AdoNetBatchCached() + { + var batch = Interlocked.Exchange(ref _spareBatch, null); + if (batch is null) + { + batch = CreateBatch(npgsql); + } + else + { + UpdateBatch(batch, npgsql); + } + var result = batch.ExecuteNonQuery(); + batch.Connection = null; + Interlocked.Exchange(ref _spareBatch, batch)?.Dispose(); + return result; + } + + public async ValueTask DisposeAsync() + { + await npgsql.DisposeAsync(); + await _postgresContainer.DisposeAsync(); + GC.SuppressFinalize(this); + } + + private sealed class BasicCommand : CommandFactory + { + private readonly bool cached; + private BasicCommand(bool cached) => this.cached = cached; + public static BasicCommand NonCached { get; } = new BasicCommand(false); + public static BasicCommand Cached { get; } = new BasicCommand(true); + + public override void AddParameters(in UnifiedCommand command, MyArgsType args) + { + var p = command.AddParameter(); + p.ParameterName = "Name0"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name0); + + p = command.AddParameter(); + p.ParameterName = "Name1"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name1); + + p = command.AddParameter(); + p.ParameterName = "Name2"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name2); + + p = command.AddParameter(); + p.ParameterName = "Name3"; + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name3); + } + + public override void UpdateParameters(in UnifiedCommand command, MyArgsType args) + { + var ps = command.Parameters; + ps[0].Value = AsValue(args.Name0); + ps[1].Value = AsValue(args.Name1); + ps[2].Value = AsValue(args.Name2); + ps[3].Value = AsValue(args.Name3); + } + + private static DbCommand? _spareCommand; + public override bool TryRecycle(DbCommand command) + => cached && TryRecycle(ref _spareCommand, command); + + + public override DbCommand GetCommand(DbConnection connection, string sql, CommandType commandType, MyArgsType args) + => (cached ? TryReuse(ref _spareCommand, sql, commandType, args) : null) + ?? base.GetCommand(connection, sql, commandType, args); + } + + private sealed class RewriteCommand : CommandFactory + { + private readonly bool cached; + private RewriteCommand(bool cached) => this.cached = cached; + public static RewriteCommand Cached { get; } = new RewriteCommand(true); + public static RewriteCommand NonCached { get; } = new RewriteCommand(false); + + public override void AddParameters(in UnifiedCommand command, MyArgsType args) + => throw new NotSupportedException(); // we don't expect to get here (in reality, we would have both versions) + + public override bool UseBatch(string sql) => sql == BasicSql; // assert that we're doing the right thing + + public override void AddCommands(in UnifiedBatch batch, string sql, MyArgsType args) + { + // the first command is initialized automatically + batch.SetCommand("insert into RewriteCustomers(Name) values($1)"); + var p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name0); + + // the fact that the NpgSql command is the same is a coincidence of the test + batch.AddCommand("insert into RewriteCustomers(Name) values($1)"); + p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name1); + + batch.AddCommand("insert into RewriteCustomers(Name) values($1)"); + p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name2); + + batch.AddCommand("insert into RewriteCustomers(Name) values($1)"); + p = batch.AddParameter(); + p.DbType = DbType.String; + p.Size = -1; + p.Value = AsValue(args.Name3); + } + + public override void UpdateParameters(in UnifiedBatch command, MyArgsType args) + { + command[0][0].Value = AsValue(args.Name0); + command[1][0].Value = AsValue(args.Name1); + command[2][0].Value = AsValue(args.Name2); + command[3][0].Value = AsValue(args.Name3); + } + + private static DbBatch? _spareBatch; + + public override bool TryRecycle(DbBatch batch) + => cached && TryRecycle(ref _spareBatch, batch); + + public override DbBatch GetBatch(DbConnection connection, string sql, CommandType commandType, MyArgsType args) + => (cached ? TryReuse(ref _spareBatch, args) : null) + ?? base.GetBatch(connection, sql, commandType, args); + } +} \ No newline at end of file diff --git a/test/UsageBenchmark/ListIterationBenchmarks.cs b/test/UsageBenchmark/ListIterationBenchmarks.cs index ac10fa82..ee73f2b3 100644 --- a/test/UsageBenchmark/ListIterationBenchmarks.cs +++ b/test/UsageBenchmark/ListIterationBenchmarks.cs @@ -7,7 +7,7 @@ namespace Dapper; [ShortRunJob, MemoryDiagnoser] public class ListIterationBenchmarks { - private readonly List customers = new(); + private readonly List customers = []; [Params(0, 1, 10, 100, 1000)] public int Count { get; set; } diff --git a/test/UsageBenchmark/Program.cs b/test/UsageBenchmark/Program.cs index 2ba424f7..30c1a7a7 100644 --- a/test/UsageBenchmark/Program.cs +++ b/test/UsageBenchmark/Program.cs @@ -13,70 +13,99 @@ static class Program #else static async Task Main() { - await using (var obj = new BatchInsertBenchmarks()) + await Task.Yield(); + var obj = new CommandRewriteBenchmarks(); + obj.Setup(); + for (int i = 0; i < 10; i++) { - await RunAllInserts(obj, 10, false); - await RunAllInserts(obj, 10, true); - } - - using (var obj = new QueryBenchmarks()) - { - await RunAllQueries(obj, 10, false); - await RunAllQueries(obj, 10, true); - } - - static async Task RunAllInserts(BatchInsertBenchmarks obj, int count, bool isOpen) - { - obj.Count = count; - obj.IsOpen = isOpen; - obj.Setup(); - - obj.DebugState(); - - Console.WriteLine(obj.Manual()); - Console.WriteLine(await obj.ManualAsync()); - Console.WriteLine(obj.Dapper()); - Console.WriteLine(await obj.DapperAsync()); - Console.WriteLine(obj.DapperAot()); - Console.WriteLine(await obj.DapperAotAsync()); + Console.WriteLine(obj.DapperAOT()); + Console.WriteLine(obj.DapperAOT_Cached()); + Console.WriteLine(obj.DapperAOT_Batch()); + Console.WriteLine(obj.DapperAOT_BatchCached()); - Console.WriteLine(obj.DapperAotManual()); - Console.WriteLine(await obj.DapperAotAsync()); + Console.WriteLine(await obj.DapperAOT_Async()); + Console.WriteLine(await obj.DapperAOT_CachedAsync()); + Console.WriteLine(await obj.DapperAOT_BatchAsync()); + Console.WriteLine(await obj.DapperAOT_BatchCachedAsync()); - Console.WriteLine(obj.DapperAot_PreparedManual()); - Console.WriteLine(await obj.DapperAot_PreparedAsync()); + Console.WriteLine(obj.AdoNetCommand()); + Console.WriteLine(obj.AdoNetBatch()); + Console.WriteLine(obj.AdoNetCommandCached()); + Console.WriteLine(obj.AdoNetBatchCached()); - Console.WriteLine(obj.EntityFramework()); - Console.WriteLine(await obj.EntityFrameworkAsync()); + Console.WriteLine(await obj.AdoNetCommandAsync()); + Console.WriteLine(await obj.AdoNetCommandCachedAsync()); - Console.WriteLine(obj.SqlBulkCopyFastMember()); - Console.WriteLine(obj.SqlBulkCopyDapper()); - Console.WriteLine(await obj.SqlBulkCopyFastMemberAsync()); - - Console.WriteLine(obj.NpgsqlDapperAotNoBatch()); - Console.WriteLine(obj.NpgsqlDapperAotFullBatch()); + // for profiling single methods etc + // _ = obj.AdoNetCommand(); } - static async Task RunAllQueries(QueryBenchmarks obj, int count, bool isOpen) - { - obj.Count = count; - obj.IsOpen = isOpen; - obj.Setup(); - - Console.WriteLine(obj.DapperDynamic()); - Console.WriteLine(obj.Dapper()); - Console.WriteLine(obj.DapperAotDynamic()); - Console.WriteLine(obj.DapperAot()); - Console.WriteLine(obj.EntityFramework()); - - Console.WriteLine(await obj.DapperDynamicAsync()); - Console.WriteLine(await obj.DapperAsync()); - Console.WriteLine(await obj.DapperAotDynamicAsync()); - Console.WriteLine(await obj.DapperAotAsync()); - Console.WriteLine(await obj.EntityFrameworkAsync()); - } + //await using (var obj = new BatchInsertBenchmarks()) + //{ + // await RunAllInserts(obj, 10, false); + // await RunAllInserts(obj, 10, true); + //} + + //using (var obj = new QueryBenchmarks()) + //{ + // await RunAllQueries(obj, 10, false); + // await RunAllQueries(obj, 10, true); + //} + + //static async Task RunAllInserts(BatchInsertBenchmarks obj, int count, bool isOpen) + //{ + // obj.Count = count; + // obj.IsOpen = isOpen; + // obj.Setup(); + + // obj.DebugState(); + + // Console.WriteLine(obj.Manual()); + // Console.WriteLine(await obj.ManualAsync()); + + // Console.WriteLine(obj.Dapper()); + // Console.WriteLine(await obj.DapperAsync()); + + // Console.WriteLine(obj.DapperAot()); + // Console.WriteLine(await obj.DapperAotAsync()); + + // Console.WriteLine(obj.DapperAotManual()); + // Console.WriteLine(await obj.DapperAotAsync()); + + // Console.WriteLine(obj.DapperAot_PreparedManual()); + // Console.WriteLine(await obj.DapperAot_PreparedAsync()); + + // Console.WriteLine(obj.EntityFramework()); + // Console.WriteLine(await obj.EntityFrameworkAsync()); + + // Console.WriteLine(obj.SqlBulkCopyFastMember()); + // Console.WriteLine(obj.SqlBulkCopyDapper()); + // Console.WriteLine(await obj.SqlBulkCopyFastMemberAsync()); + + // Console.WriteLine(obj.NpgsqlDapperAotNoBatch()); + // Console.WriteLine(obj.NpgsqlDapperAotFullBatch()); + //} + + //static async Task RunAllQueries(QueryBenchmarks obj, int count, bool isOpen) + //{ + // obj.Count = count; + // obj.IsOpen = isOpen; + // obj.Setup(); + + // Console.WriteLine(obj.DapperDynamic()); + // Console.WriteLine(obj.Dapper()); + // Console.WriteLine(obj.DapperAotDynamic()); + // Console.WriteLine(obj.DapperAot()); + // Console.WriteLine(obj.EntityFramework()); + + // Console.WriteLine(await obj.DapperDynamicAsync()); + // Console.WriteLine(await obj.DapperAsync()); + // Console.WriteLine(await obj.DapperAotDynamicAsync()); + // Console.WriteLine(await obj.DapperAotAsync()); + // Console.WriteLine(await obj.EntityFrameworkAsync()); + //} } #endif } \ No newline at end of file diff --git a/test/UsageBenchmark/UsageBenchmark.csproj b/test/UsageBenchmark/UsageBenchmark.csproj index b3eff8e1..d9f4886e 100644 --- a/test/UsageBenchmark/UsageBenchmark.csproj +++ b/test/UsageBenchmark/UsageBenchmark.csproj @@ -6,11 +6,15 @@ $(InterceptorsPreviewNamespaces);Dapper.AOT + + BatchInsertBenchmarks.cs + - + +