Skip to content
This repository has been archived by the owner on Dec 24, 2022. It is now read-only.

Commit

Permalink
Add support for auto splitting of IEnumerable params into multi db pa…
Browse files Browse the repository at this point in the history
…rams
  • Loading branch information
mythz committed Nov 23, 2017
1 parent 742f9ae commit 8cee8a2
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 34 deletions.
Expand Up @@ -56,7 +56,7 @@ internal static Task<List<TModel>> SelectAsync<TModel>(this IDbCommand dbCmd, Ty

internal static Task<List<TModel>> SelectAsync<TModel>(this IDbCommand dbCmd, Type fromTableType, string sqlFilter, object anonType, CancellationToken token)
{
if (anonType != null) dbCmd.SetParameters(fromTableType, anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters(fromTableType, anonType, excludeDefaults: false, sql: ref sqlFilter);
var sql = OrmLiteReadCommandExtensions.ToSelect<TModel>(dbCmd.GetDialectProvider(), fromTableType, sqlFilter);
return dbCmd.ConvertToListAsync<TModel>(sql, token);
}
Expand Down Expand Up @@ -105,7 +105,7 @@ internal static Task<T> SingleAsync<T>(this IDbCommand dbCmd, string sql, object
{
return OrmLiteUtils.IsScalar<T>()
? dbCmd.ScalarAsync<T>(sql, anonType, token)
: dbCmd.SetParameters<T>(anonType, excludeDefaults: false).ConvertToAsync<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql), token);
: dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).ConvertToAsync<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql), token);
}

internal static Task<List<T>> WhereAsync<T>(this IDbCommand dbCmd, string name, object value, CancellationToken token)
Expand All @@ -131,7 +131,7 @@ internal static Task<List<T>> SelectAsync<T>(this IDbCommand dbCmd, string sql,

internal static Task<List<T>> SelectAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
dbCmd.SetParameters<T>(anonType, excludeDefaults: false).CommandText = dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql);
dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).CommandText = dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql);
return dbCmd.ConvertToListAsync<T>(null, token);
}

Expand All @@ -149,7 +149,7 @@ internal static Task<List<T>> SqlListAsync<T>(this IDbCommand dbCmd, string sql,

internal static Task<List<T>> SqlListAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
dbCmd.SetParameters<T>(anonType, excludeDefaults: false).CommandText = sql;
dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).CommandText = sql;
return dbCmd.ConvertToListAsync<T>(null, token);
}

Expand Down Expand Up @@ -192,7 +192,7 @@ internal static Task<T> SqlScalarAsync<T>(this IDbCommand dbCmd, string sql, IEn

internal static Task<T> SqlScalarAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
return dbCmd.SetParameters<T>(anonType, excludeDefaults: false).ScalarAsync<T>(sql, token);
return dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).ScalarAsync<T>(sql, token);
}

internal static Task<T> SqlScalarAsync<T>(this IDbCommand dbCmd, string sql, Dictionary<string, object> dict, CancellationToken token)
Expand All @@ -207,12 +207,12 @@ internal static Task<List<T>> SelectNonDefaultsAsync<T>(this IDbCommand dbCmd, o

internal static Task<List<T>> SelectNonDefaultsAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
return dbCmd.SetParameters<T>(anonType, excludeDefaults: true).ConvertToListAsync<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql), token);
return dbCmd.SetParameters<T>(anonType, excludeDefaults: true, sql: ref sql).ConvertToListAsync<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql), token);
}

internal static Task<T> ScalarAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
return dbCmd.SetParameters<T>(anonType, excludeDefaults: false).ScalarAsync<T>(sql, token);
return dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).ScalarAsync<T>(sql, token);
}

internal static Task<T> ScalarAsync<T>(this IDataReader reader, IOrmLiteDialectProvider dialectProvider, CancellationToken token)
Expand All @@ -229,7 +229,7 @@ public static Task<long> LongScalarAsync(this IDbCommand dbCmd, CancellationToke

internal static Task<List<T>> ColumnAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);

return dbCmd.ColumnAsync<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql), token);
}
Expand All @@ -251,7 +251,7 @@ internal static Task<List<T>> ColumnAsync<T>(this IDataReader reader, IOrmLiteDi

internal static Task<HashSet<T>> ColumnDistinctAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);

return dbCmd.ColumnDistinctAsync<T>(sql, token);
}
Expand Down
Expand Up @@ -252,13 +252,13 @@ internal static Task<int> DeleteAllAsync(this IDbCommand dbCmd, Type tableType,

internal static Task<int> DeleteAsync<T>(this IDbCommand dbCmd, string sql, object anonType, CancellationToken token)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);
return dbCmd.ExecuteSqlAsync(dbCmd.GetDialectProvider().ToDeleteStatement(typeof(T), sql), token);
}

internal static Task<int> DeleteAsync(this IDbCommand dbCmd, Type tableType, string sql, object anonType, CancellationToken token)
{
if (anonType != null) dbCmd.SetParameters(tableType, anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters(tableType, anonType, excludeDefaults: false, sql: ref sql);
return dbCmd.ExecuteSqlAsync(dbCmd.GetDialectProvider().ToDeleteStatement(tableType, sql), token);
}

Expand Down
65 changes: 44 additions & 21 deletions src/ServiceStack.OrmLite/OrmLiteReadCommandExtensions.cs
Expand Up @@ -80,17 +80,27 @@ internal static void SetFilter<T>(this IDbCommand dbCmd, string name, object val

internal static IDbCommand SetFilters<T>(this IDbCommand dbCmd, object anonType, bool excludeDefaults)
{
dbCmd.SetParameters<T>(anonType, excludeDefaults); //needs to be called first
string ignore = null;
dbCmd.SetParameters<T>(anonType, excludeDefaults, ref ignore); //needs to be called first
dbCmd.CommandText = dbCmd.GetFilterSql<T>();
return dbCmd;
}

internal static IDbCommand SetParameters<T>(this IDbCommand dbCmd, object anonType, bool excludeDefaults)
internal static IDbCommand SetParameters<T>(this IDbCommand dbCmd, object anonType, bool excludeDefaults, ref string sql) =>
dbCmd.SetParameters(typeof(T), anonType, excludeDefaults, ref sql);

private static IEnumerable GetMultiValues(object value)
{
return dbCmd.SetParameters(typeof(T), anonType, excludeDefaults);
if (value is SqlInValues inValues)
return inValues.GetValues();

return (value is IEnumerable enumerable &&
!(enumerable is string ||
enumerable is IEnumerable<KeyValuePair<string, object>>)
) ? enumerable : null;
}

internal static IDbCommand SetParameters(this IDbCommand dbCmd, Type type, object anonType, bool excludeDefaults)
internal static IDbCommand SetParameters(this IDbCommand dbCmd, Type type, object anonType, bool excludeDefaults, ref string sql)
{
if (anonType == null)
return dbCmd;
Expand All @@ -104,16 +114,25 @@ internal static IDbCommand SetParameters(this IDbCommand dbCmd, Type type, objec
? dialectProvider.GetFieldDefinitionMap(modelDef)
: null;

var sqlCopy = sql; //C# doesn't allow changing ref params in lambda's

anonType.ToObjectDictionary().ForEachParam(modelDef, excludeDefaults, (propName, columnName, value) =>
{
var propType = value?.GetType() ?? typeof(object);
if (value is SqlInValues inValues)
var inValues = GetMultiValues(value);
if (inValues != null)
{
var i = 0;
foreach (var item in inValues.GetValues())
var sb = StringBuilderCache.Allocate();
foreach (var item in inValues)
{
var p = dbCmd.CreateParameter();
p.ParameterName = "v" + i++;
if (sb.Length > 0)
sb.Append(',');
sb.Append(dialectProvider.ParamString + p.ParameterName);
p.Direction = ParameterDirection.Input;
dialectProvider.InitDbParam(p, item.GetType());
Expand All @@ -130,6 +149,9 @@ internal static IDbCommand SetParameters(this IDbCommand dbCmd, Type type, objec
dbCmd.Parameters.Add(p);
}
var sqlIn = StringBuilderCache.ReturnAndFree(sb);
sqlCopy = sqlCopy?.Replace(dialectProvider.ParamString + propName, sqlIn);
}
else
{
Expand All @@ -155,14 +177,15 @@ internal static IDbCommand SetParameters(this IDbCommand dbCmd, Type type, objec
p.Value = value == null ?
DBNull.Value
: p.DbType == DbType.String ?
value.ToString() :
value;
: p.DbType == DbType.String ?
value.ToString() :
value;
dbCmd.Parameters.Add(p);
}
});

sql = sqlCopy;
return dbCmd;
}

Expand Down Expand Up @@ -337,7 +360,7 @@ internal static T Single<T>(this IDbCommand dbCmd, string sql, IEnumerable<IDbDa

internal static T Single<T>(this IDbCommand dbCmd, string sql, object anonType)
{
dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);

return OrmLiteUtils.IsScalar<T>()
? dbCmd.Scalar<T>(sql)
Expand Down Expand Up @@ -371,7 +394,7 @@ internal static List<T> Select<T>(this IDbCommand dbCmd, string sql, IEnumerable

internal static List<T> Select<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);
dbCmd.CommandText = dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql);

return dbCmd.ConvertToList<T>();
Expand All @@ -392,7 +415,7 @@ internal static List<TModel> Select<TModel>(this IDbCommand dbCmd, Type fromTabl

internal static List<T> Select<T>(this IDbCommand dbCmd, Type fromTableType, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters(fromTableType, anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters(fromTableType, anonType, excludeDefaults: false, sql: ref sql);
dbCmd.CommandText = ToSelect<T>(dbCmd.GetDialectProvider(), fromTableType, sql);

return dbCmd.ConvertToList<T>();
Expand Down Expand Up @@ -423,7 +446,7 @@ internal static List<T> SqlList<T>(this IDbCommand dbCmd, string sql, IEnumerabl

internal static List<T> SqlList<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);
dbCmd.CommandText = sql;

return dbCmd.ConvertToList<T>();
Expand Down Expand Up @@ -453,7 +476,7 @@ internal static List<T> SqlColumn<T>(this IDbCommand dbCmd, string sql, IEnumera

internal static List<T> SqlColumn<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
dbCmd.SetParameters<T>(anonType, excludeDefaults: false).CommandText = sql;
dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).CommandText = sql;
return dbCmd.ConvertToList<T>();
}

Expand All @@ -472,7 +495,7 @@ internal static T SqlScalar<T>(this IDbCommand dbCmd, string sql, IEnumerable<ID

internal static T SqlScalar<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);

return dbCmd.Scalar<T>(sql);
}
Expand All @@ -493,7 +516,7 @@ internal static List<T> SelectNonDefaults<T>(this IDbCommand dbCmd, object filte

internal static List<T> SelectNonDefaults<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: true);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: true, sql: ref sql);

return dbCmd.ConvertToList<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql));
}
Expand All @@ -505,7 +528,7 @@ internal static IEnumerable<T> SelectLazy<T>(this IDbCommand dbCmd, string sql,

internal static IEnumerable<T> SelectLazy<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);
var dialectProvider = dbCmd.GetDialectProvider();
dbCmd.CommandText = dialectProvider.ToSelectStatement(typeof(T), sql);

Expand Down Expand Up @@ -539,7 +562,7 @@ internal static IEnumerable<T> ColumnLazy<T>(this IDbCommand dbCmd, string sql,

internal static IEnumerable<T> ColumnLazy<T>(this IDbCommand dbCmd, string sql, object anonType)
{
foreach (var p in dbCmd.SetParameters<T>(anonType, excludeDefaults: false).ColumnLazy<T>(sql)) yield return p;
foreach (var p in dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).ColumnLazy<T>(sql)) yield return p;
}

private static IEnumerable<T> ColumnLazy<T>(this IDbCommand dbCmd, string sql)
Expand Down Expand Up @@ -603,7 +626,7 @@ internal static IEnumerable<T> SelectLazy<T>(this IDbCommand dbCmd)

internal static T Scalar<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);

return dbCmd.Scalar<T>(sql);
}
Expand Down Expand Up @@ -656,7 +679,7 @@ internal static long LastInsertId(this IDbCommand dbCmd)

internal static List<T> Column<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);

return dbCmd.Column<T>(dbCmd.GetDialectProvider().ToSelectStatement(typeof(T), sql));
}
Expand All @@ -683,7 +706,7 @@ internal static HashSet<T> ColumnDistinct<T>(this IDbCommand dbCmd, string sql,

internal static HashSet<T> ColumnDistinct<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
return dbCmd.SetParameters<T>(anonType, excludeDefaults: false).ColumnDistinct<T>(sql);
return dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql).ColumnDistinct<T>(sql);
}

internal static HashSet<T> ColumnDistinct<T>(this IDataReader reader, IOrmLiteDialectProvider dialectProvider)
Expand Down
4 changes: 2 additions & 2 deletions src/ServiceStack.OrmLite/OrmLiteWriteCommandExtensions.cs
Expand Up @@ -618,13 +618,13 @@ internal static int DeleteAll(this IDbCommand dbCmd, Type tableType)

internal static int Delete<T>(this IDbCommand dbCmd, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false, sql: ref sql);
return dbCmd.ExecuteSql(dbCmd.GetDialectProvider().ToDeleteStatement(typeof(T), sql));
}

internal static int Delete(this IDbCommand dbCmd, Type tableType, string sql, object anonType = null)
{
if (anonType != null) dbCmd.SetParameters(tableType, anonType, excludeDefaults: false);
if (anonType != null) dbCmd.SetParameters(tableType, anonType, excludeDefaults: false, sql: ref sql);
return dbCmd.ExecuteSql(dbCmd.GetDialectProvider().ToDeleteStatement(tableType, sql));
}

Expand Down
27 changes: 27 additions & 0 deletions tests/ServiceStack.OrmLite.Tests/OrmLiteSelectTests.cs
Expand Up @@ -5,6 +5,7 @@
using NUnit.Framework;
using ServiceStack.Common;
using ServiceStack.Common.Tests.Models;
using ServiceStack.Logging;
using ServiceStack.Text;

namespace ServiceStack.OrmLite.Tests
Expand Down Expand Up @@ -331,11 +332,37 @@ public void Can_Select_In_for_string_value()
var rows = db.Select<ModelWithIdAndName>("Name IN ({0})".Fmt(selectInNames.SqlInParams()),
new { values = selectInNames.SqlInValues() });
Assert.That(rows.Count, Is.EqualTo(selectInNames.Length));

rows = db.Select<ModelWithIdAndName>("Name IN (@values)",
new { values = selectInNames });
Assert.That(rows.Count, Is.EqualTo(selectInNames.Length));

rows = db.Select<ModelWithIdAndName>("Name IN (@p1, @p2)".PreNormalizeSql(db), new { p1 = "Name1", p2 = "Name2" });
Assert.That(rows.Count, Is.EqualTo(selectInNames.Length));
}
}

[Test]
public void Can_select_IN_using_array_or_List_params()
{
LogManager.LogFactory = new ConsoleLogFactory();
using (var db = OpenDbConnection())
{
db.DropAndCreateTable<ModelWithIdAndName>();
5.Times(x => db.Insert(ModelWithIdAndName.Create(x)));

var names = new[] { "Name2", "Name3" };
var rows = db.Select<ModelWithIdAndName>("Name IN (@names)", new { names });
Assert.That(rows.Count, Is.EqualTo(2));
Assert.That(rows.Map(x => x.Name), Is.EquivalentTo(names));

var ids = new List<int> { 2, 3 };
rows = db.Select<ModelWithIdAndName>("Id IN (@ids)", new { ids });
Assert.That(rows.Count, Is.EqualTo(2));
Assert.That(rows.Map(x => x.Id), Is.EquivalentTo(ids));
}
}

public class PocoFlag
{
public string Name { get; set; }
Expand Down

0 comments on commit 8cee8a2

Please sign in to comment.