Skip to content
This repository has been archived by the owner on Dec 24, 2022. It is now read-only.

Add "include" feature to LoadSelectAsync and LoadSingleByIdAsync #480

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -6,6 +6,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using ServiceStack.Logging;
Expand Down Expand Up @@ -402,22 +403,34 @@ internal static Task<List<TOutputModel>> SqlProcedureAsync<TOutputModel>(this ID
return dbCmd.ConvertToListAsync<TOutputModel>(sql, token);
}

internal static async Task<T> LoadSingleByIdAsync<T>(this IDbCommand dbCmd, object value, CancellationToken token)
internal static async Task<T> LoadSingleByIdAsync<T>(this IDbCommand dbCmd, object value, string[] include = null, CancellationToken token = default(CancellationToken))
{
var row = await dbCmd.SingleByIdAsync<T>(value, token);
if (row == null)
return default(T);

await dbCmd.LoadReferencesAsync(row, token);
await dbCmd.LoadReferencesAsync(row, include, token);

return row;
}

public static async Task LoadReferencesAsync<T>(this IDbCommand dbCmd, T instance, CancellationToken token)
public static async Task LoadReferencesAsync<T>(this IDbCommand dbCmd, T instance, string[] include = null, CancellationToken token = default(CancellationToken))
{
var loadRef = new LoadReferencesAsync<T>(dbCmd, instance);
var fieldDefs = loadRef.FieldDefs;

foreach (var fieldDef in loadRef.FieldDefs)
if (!include.IsEmpty())
{
// Check that any include values aren't reference fields of the specified type
var fields = fieldDefs.Select(q => q.FieldName);
var invalid = include.Except<string>(fields).ToList();
if (invalid.Count > 0)
throw new ArgumentException("Fields '{0}' are not Reference Properties of Type '{1}'".Fmt(invalid.Join("', '"), typeof(T).Name));

fieldDefs = fieldDefs.Where(fd => include.Contains(fd.FieldName)).ToList();
}

foreach (var fieldDef in fieldDefs)
{
dbCmd.Parameters.Clear();
var listInterface = fieldDef.FieldType.GetTypeWithGenericInterfaceOf(typeof(IList<>));
Expand All @@ -432,11 +445,23 @@ public static async Task LoadReferencesAsync<T>(this IDbCommand dbCmd, T instanc
}
}

internal static async Task<List<Into>> LoadListWithReferences<Into, From>(this IDbCommand dbCmd, SqlExpression<From> expr = null, CancellationToken token = default(CancellationToken))
internal static async Task<List<Into>> LoadListWithReferences<Into, From>(this IDbCommand dbCmd, SqlExpression<From> expr = null, string[] include = null, CancellationToken token = default(CancellationToken))
{
var loadList = new LoadListAsync<Into, From>(dbCmd, expr);
var loadList = new LoadListAsync<Into, From>(dbCmd, expr);

var fieldDefs = loadList.FieldDefs;
if (!include.IsEmpty())
{
// Check that any include values aren't reference fields of the specified From type
var fields = fieldDefs.Select(q => q.FieldName);
var invalid = include.Except<string>(fields).ToList();
if (invalid.Count > 0)
throw new ArgumentException("Fields '{0}' are not Reference Properties of Type '{1}'".Fmt(invalid.Join("', '"), typeof(From).Name));

fieldDefs = loadList.FieldDefs.Where(fd => include.Contains(fd.FieldName)).ToList();
}

foreach (var fieldDef in loadList.FieldDefs)
foreach (var fieldDef in fieldDefs)
{
var listInterface = fieldDef.FieldType.GetTypeWithGenericInterfaceOf(typeof(IList<>));
if (listInterface != null)
Expand Down
Expand Up @@ -129,27 +129,27 @@ internal static Task<long> RowCountAsync(this IDbCommand dbCmd, string sql, Canc
return dbCmd.ScalarAsync<long>("SELECT COUNT(*) FROM ({0}) AS COUNT".Fmt(sql), token);
}

internal static Task<List<T>> LoadSelectAsync<T>(this IDbCommand dbCmd, Func<SqlExpression<T>, SqlExpression<T>> expression, CancellationToken token = default(CancellationToken))
internal static Task<List<T>> LoadSelectAsync<T>(this IDbCommand dbCmd, Func<SqlExpression<T>, SqlExpression<T>> expression, string[] include = null, CancellationToken token = default(CancellationToken))
{
var expr = dbCmd.GetDialectProvider().SqlExpression<T>();
expr = expression(expr);
return dbCmd.LoadListWithReferences<T, T>(expr, token);
return dbCmd.LoadListWithReferences<T, T>(expr, include, token);
}

internal static Task<List<T>> LoadSelectAsync<T>(this IDbCommand dbCmd, SqlExpression<T> expression = null, CancellationToken token = default(CancellationToken))
internal static Task<List<T>> LoadSelectAsync<T>(this IDbCommand dbCmd, SqlExpression<T> expression = null, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbCmd.LoadListWithReferences<T, T>(expression, token);
return dbCmd.LoadListWithReferences<T, T>(expression, include, token);
}

internal static Task<List<Into>> LoadSelectAsync<Into, From>(this IDbCommand dbCmd, SqlExpression<From> expression, CancellationToken token = default(CancellationToken))
internal static Task<List<Into>> LoadSelectAsync<Into, From>(this IDbCommand dbCmd, SqlExpression<From> expression, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbCmd.LoadListWithReferences<Into, From>(expression, token);
return dbCmd.LoadListWithReferences<Into, From>(expression, include, token);
}

internal static Task<List<T>> LoadSelectAsync<T>(this IDbCommand dbCmd, Expression<Func<T, bool>> predicate, CancellationToken token = default(CancellationToken))
internal static Task<List<T>> LoadSelectAsync<T>(this IDbCommand dbCmd, Expression<Func<T, bool>> predicate, string[] include = null, CancellationToken token = default(CancellationToken))
{
var expr = dbCmd.GetDialectProvider().SqlExpression<T>().Where(predicate);
return dbCmd.LoadListWithReferences<T, T>(expr, token);
return dbCmd.LoadListWithReferences<T, T>(expr, include, token);
}

internal static Task<T> ExprConvertToAsync<T>(this IDataReader dataReader, IOrmLiteDialectProvider dialectProvider, CancellationToken token)
Expand Down
17 changes: 13 additions & 4 deletions src/ServiceStack.OrmLite/OrmLiteReadApiAsync.cs
Expand Up @@ -649,18 +649,27 @@ public static Task<long> LongScalarAsync(this IDbConnection dbConn, Cancellation
/// Returns the first result with all its references loaded, using a primary key id. E.g:
/// <para>db.LoadSingleById&lt;Person&gt;(1)</para>
/// </summary>
public static Task<T> LoadSingleByIdAsync<T>(this IDbConnection dbConn, object idValue, CancellationToken token=default(CancellationToken))
public static Task<T> LoadSingleByIdAsync<T>(this IDbConnection dbConn, object idValue, string[] include = null, CancellationToken token=default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadSingleByIdAsync<T>(idValue, token));
return dbConn.Exec(dbCmd => dbCmd.LoadSingleByIdAsync<T>(idValue, include, token));
}

/// <summary>
/// Returns the first result with all its references loaded, using a primary key id. E.g:
/// <para>db.LoadSingleById&lt;Person&gt;(1, include = x => new{ x.Address })</para>
/// </summary>
public static Task<T> LoadSingleByIdAsync<T>(this IDbConnection dbConn, object idValue, Func<T, object> include, CancellationToken token = default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadSingleByIdAsync<T>(idValue, include(typeof(T).CreateInstance<T>()).GetType().AllAnonFields(), token));
}

/// <summary>
/// Loads all the related references onto the instance. E.g:
/// <para>db.LoadReferencesAsync(customer)</para>
/// </summary>
public static Task LoadReferencesAsync<T>(this IDbConnection dbConn, T instance, CancellationToken token = default(CancellationToken))
public static Task LoadReferencesAsync<T>(this IDbConnection dbConn, T instance, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadReferencesAsync(instance, token));
return dbConn.Exec(dbCmd => dbCmd.LoadReferencesAsync(instance, include, token));
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions src/ServiceStack.OrmLite/OrmLiteReadExpressionsApiAsync.cs
Expand Up @@ -173,35 +173,35 @@ public static Task<long> RowCountAsync(this IDbConnection dbConn, string sql, Ca
/// Returns results with references from using a LINQ Expression. E.g:
/// <para>db.LoadSelectAsync&lt;Person&gt;(x =&gt; x.Age &gt; 40)</para>
/// </summary>
public static Task<List<T>> LoadSelectAsync<T>(this IDbConnection dbConn, Expression<Func<T, bool>> predicate, CancellationToken token = default(CancellationToken))
public static Task<List<T>> LoadSelectAsync<T>(this IDbConnection dbConn, Expression<Func<T, bool>> predicate, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync(predicate));
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync(predicate, include, token));
}

/// <summary>
/// Returns results with references from using an SqlExpression lambda. E.g:
/// <para>db.LoadSelectAsync&lt;Person&gt;(q =&gt; q.Where(x =&gt; x.Age &gt; 40))</para>
/// </summary>
public static Task<List<T>> LoadSelectAsync<T>(this IDbConnection dbConn, Func<SqlExpression<T>, SqlExpression<T>> expression, CancellationToken token = default(CancellationToken))
public static Task<List<T>> LoadSelectAsync<T>(this IDbConnection dbConn, Func<SqlExpression<T>, SqlExpression<T>> expression, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync(expression));
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync(expression, include, token));
}

/// <summary>
/// Returns results with references from using an SqlExpression lambda. E.g:
/// <para>db.LoadSelectAsync(db.From&lt;Person&gt;().Where(x =&gt; x.Age &gt; 40))</para>
/// </summary>
public static Task<List<T>> LoadSelectAsync<T>(this IDbConnection dbConn, SqlExpression<T> expression, CancellationToken token = default(CancellationToken))
public static Task<List<T>> LoadSelectAsync<T>(this IDbConnection dbConn, SqlExpression<T> expression, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync(expression));
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync(expression, include, token));
}

/// <summary>
/// Project results with references from a number of joined tables into a different model
/// </summary>
public static Task<List<Into>> LoadSelectAsync<Into, From>(this IDbConnection dbConn, SqlExpression<From> expression, CancellationToken token = default(CancellationToken))
public static Task<List<Into>> LoadSelectAsync<Into, From>(this IDbConnection dbConn, SqlExpression<From> expression, string[] include = null, CancellationToken token = default(CancellationToken))
{
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync<Into, From>(expression));
return dbConn.Exec(dbCmd => dbCmd.LoadSelectAsync<Into, From>(expression, include, token));
}
}

Expand Down
91 changes: 91 additions & 0 deletions tests/ServiceStack.OrmLiteV45.Tests/LoadReferencesTests.cs
Expand Up @@ -141,5 +141,96 @@ public async Task Can_Save_and_Load_References_Async()
Assert.That(dbCustomer.PrimaryAddress, Is.Not.Null);
Assert.That(dbCustomer.Orders.Count, Is.EqualTo(2));
}

[Test]
public async Task Can_load_only_included_references_async()
{
var customer = new Customer
{
Name = "Customer 1",
PrimaryAddress = new CustomerAddress
{
AddressLine1 = "1 Humpty Street",
City = "Humpty Doo",
State = "Northern Territory",
Country = "Australia"
},
Orders = new[] {
new Order { LineItem = "Line 1", Qty = 1, Cost = 1.99m },
new Order { LineItem = "Line 2", Qty = 2, Cost = 2.99m },
}.ToList(),
};

await db.SaveAsync(customer);
Assert.That(customer.Id, Is.GreaterThan(0));

await db.SaveReferencesAsync(customer, customer.PrimaryAddress);
Assert.That(customer.PrimaryAddress.CustomerId, Is.EqualTo(customer.Id));

await db.SaveReferencesAsync(customer, customer.Orders);
Assert.That(customer.Orders.All(x => x.CustomerId == customer.Id));

// LoadSelectAsync overload 1
var dbCustomers = await db.LoadSelectAsync<Customer>(db.From<Customer>().Where(q => q.Id == customer.Id), include: new[] { "PrimaryAddress" });
Assert.That(dbCustomers.Count, Is.EqualTo(1));
Assert.That(dbCustomers[0].Name, Is.EqualTo("Customer 1"));
Assert.That(dbCustomers[0].Orders, Is.Null);
Assert.That(dbCustomers[0].PrimaryAddress, Is.Not.Null);

// LoadSelectAsync overload 2
dbCustomers = await db.LoadSelectAsync<Customer>(q => q.Id == customer.Id, include: new[] { "PrimaryAddress" });
Assert.That(dbCustomers.Count, Is.EqualTo(1));
Assert.That(dbCustomers[0].Name, Is.EqualTo("Customer 1"));
Assert.That(dbCustomers[0].Orders, Is.Null);
Assert.That(dbCustomers[0].PrimaryAddress, Is.Not.Null);

// LoadSelectAsync overload 3
dbCustomers = await db.LoadSelectAsync<Customer>(q => q.Where(x => x.Id == customer.Id), include: new[] { "PrimaryAddress" });
Assert.That(dbCustomers.Count, Is.EqualTo(1));
Assert.That(dbCustomers[0].Name, Is.EqualTo("Customer 1"));
Assert.That(dbCustomers[0].Orders, Is.Null);
Assert.That(dbCustomers[0].PrimaryAddress, Is.Not.Null);

// LoadSingleById overload 1
var dbCustomer = await db.LoadSingleByIdAsync<Customer>(customer.Id, include: new[] { "PrimaryAddress" });
Assert.That(dbCustomer.Name, Is.EqualTo("Customer 1"));
Assert.That(dbCustomer.Orders, Is.Null);
Assert.That(dbCustomer.PrimaryAddress, Is.Not.Null);

// LoadSingleById overload 2
dbCustomer = await db.LoadSingleByIdAsync<Customer>(customer.Id, include: x => new { x.PrimaryAddress });
Assert.That(dbCustomer.Name, Is.EqualTo("Customer 1"));
Assert.That(dbCustomer.Orders, Is.Null);
Assert.That(dbCustomer.PrimaryAddress, Is.Not.Null);


// Invalid field name
try
{
dbCustomers = await db.LoadSelectAsync<Customer>(q => q.Id == customer.Id, include: new[] { "InvalidOption1", "InvalidOption2" });
Assert.Fail();
}
catch (System.ArgumentException ex)
{
}
catch (System.Exception ex)
{
Assert.Fail();
}


try
{
dbCustomer = await db.LoadSingleByIdAsync<Customer>(customer.Id, include: new[] { "InvalidOption1", "InvalidOption2" });
Assert.Fail();
}
catch (System.ArgumentException ex)
{
}
catch (System.Exception ex)
{
Assert.Fail();
}
}
}
}