-
Notifications
You must be signed in to change notification settings - Fork 4
/
MockExtension.cs
92 lines (83 loc) · 3.91 KB
/
MockExtension.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using NSubstitute;
namespace MockDbContextTests
{
    /// <summary>
    /// Extension methods that build NSubstitute-backed <see cref="DbSet{TEntity}"/> fakes over in-memory data
    /// and wire them into a substituted <see cref="DbContext"/>.
    /// </summary>
    public static class MockExtension
    {
        /// <summary>
        /// Creates a mocked generic DbSet of the given entity type, backed by the data supplied.
        /// The set answers synchronous LINQ queries only; use
        /// <see cref="GenerateMockDbSetForAsync{TEntity}"/> when the code under test issues async queries.
        /// </summary>
        /// <typeparam name="TEntity">The entity type exposed by the set.</typeparam>
        /// <param name="queryableEnumerable">The data to emulate in the set.</param>
        /// <returns>A mocked DbSet using NSubstitute.</returns>
        public static DbSet<TEntity> GenerateMockDbSet<TEntity>(this IEnumerable<TEntity> queryableEnumerable) where TEntity : class
        {
            var queryable = queryableEnumerable as IQueryable<TEntity> ?? queryableEnumerable.AsQueryable();
            var mockSet = Substitute.For<DbSet<TEntity>, IQueryable<TEntity>>();
            ConfigureQueryable(mockSet, queryable, queryable.Provider);
            return mockSet;
        }

        /// <summary>
        /// Generates a mock/fake DbSet that can also be consumed through async query calls
        /// (it additionally implements <see cref="IDbAsyncEnumerable{T}"/>).
        /// </summary>
        /// <typeparam name="TEntity">The entity type exposed by the set.</typeparam>
        /// <param name="queryableEnumerable">The data to emulate in the set.</param>
        /// <returns>A mocked DbSet using NSubstitute.</returns>
        public static DbSet<TEntity> GenerateMockDbSetForAsync<TEntity>(this IEnumerable<TEntity> queryableEnumerable) where TEntity : class
        {
            var queryable = queryableEnumerable as IQueryable<TEntity> ?? queryableEnumerable.AsQueryable();
            var mockSet = Substitute.For<DbSet<TEntity>, IQueryable<TEntity>, IDbAsyncEnumerable<TEntity>>();

            // Async support: use the factory overload so every call gets a fresh async enumerator.
            // A single shared instance would be exhausted after the first async enumeration.
            var castAsyncEnum = (IDbAsyncEnumerable<TEntity>)mockSet;
            castAsyncEnum.GetAsyncEnumerator().Returns(_ => new TestDbAsyncEnumerator<TEntity>(queryable.GetEnumerator()));

            // Route queries through TestDbAsyncQueryProvider (defined elsewhere in the project), which wraps the
            // in-memory LINQ provider; presumably this lets composed queries run asynchronously.
            ConfigureQueryable(mockSet, queryable, new TestDbAsyncQueryProvider<TEntity>(queryable.Provider));
            return mockSet;
        }

        /// <summary>
        /// Makes the DbContext's Set&lt;TEntity&gt;() return a mocked set over the supplied data that can be used
        /// with asynchronous calls.
        /// </summary>
        /// <typeparam name="TEntity">The entity type of the set.</typeparam>
        /// <param name="context">The context to configure; must be an NSubstitute substitute, since
        /// <c>Returns</c> is used on its <c>Set&lt;TEntity&gt;()</c> call.</param>
        /// <param name="queryableEnumerable">The enumerable object to expose as a DbSet.</param>
        public static void AddToDbSetForAsync<TEntity>(this DbContext context, IEnumerable<TEntity> queryableEnumerable) where TEntity : class
        {
            // Build the set first: creating it makes substitute calls that would otherwise
            // interfere with configuring context.Set<TEntity>().
            var set = queryableEnumerable.GenerateMockDbSetForAsync();
            context.Set<TEntity>().Returns(set);
        }

        /// <summary>
        /// Makes the DbContext's Set&lt;TEntity&gt;() return a mocked set over the supplied data
        /// (synchronous queries only; use <see cref="AddToDbSetForAsync{TEntity}"/> for async calls).
        /// </summary>
        /// <typeparam name="TEntity">The entity type of the set.</typeparam>
        /// <param name="context">The context to configure; must be an NSubstitute substitute, since
        /// <c>Returns</c> is used on its <c>Set&lt;TEntity&gt;()</c> call.</param>
        /// <param name="queryableEnumerable">The enumerable object to expose as a DbSet.</param>
        public static void AddToDbSet<TEntity>(this DbContext context, IEnumerable<TEntity> queryableEnumerable) where TEntity : class
        {
            // Use the synchronous generator, matching this method's documented contract
            // (the async variant is reserved for AddToDbSetForAsync).
            var set = queryableEnumerable.GenerateMockDbSet();
            context.Set<TEntity>().Returns(set);
        }

        /// <summary>
        /// Stubs the strongly-typed <c>Include(path)</c> call on a substituted query so it returns the query itself,
        /// letting code under test chain Include without changing the in-memory results.
        /// </summary>
        /// <typeparam name="T">The element type of the query.</typeparam>
        /// <typeparam name="TProperty">The type of the property selected by <paramref name="path"/>.</typeparam>
        /// <param name="source">The query to configure; must be an NSubstitute substitute
        /// (e.g. a set produced by the generators above) for <c>Returns</c> to succeed.</param>
        /// <param name="path">The include expression whose call should be stubbed.</param>
        /// <returns><paramref name="source"/>, to allow chaining.</returns>
        public static IQueryable<T> MockInclude<T, TProperty>(this IQueryable<T> source, Expression<Func<T, TProperty>> path)
        {
            source.Include(path).Returns(source);
            return source;
        }

        /// <summary>
        /// Wires the IQueryable members, AsNoTracking and string-based Include of a substituted set to in-memory data.
        /// </summary>
        /// <typeparam name="TEntity">The entity type of the set.</typeparam>
        /// <param name="mockSet">The substituted set to configure.</param>
        /// <param name="queryable">The in-memory data backing the set.</param>
        /// <param name="provider">The query provider the set should report.</param>
        private static void ConfigureQueryable<TEntity>(DbSet<TEntity> mockSet, IQueryable<TEntity> queryable, IQueryProvider provider) where TEntity : class
        {
            var castMockSet = (IQueryable<TEntity>)mockSet;
            castMockSet.Provider.Returns(provider);
            castMockSet.Expression.Returns(queryable.Expression);
            castMockSet.ElementType.Returns(queryable.ElementType);

            // Factory overload: hand out a new enumerator per call. Returning one pre-built enumerator
            // would make the set appear empty on every enumeration after the first.
            castMockSet.GetEnumerator().Returns(_ => queryable.GetEnumerator());

            // Operators that would return a modified query just return the same fake set.
            castMockSet.AsNoTracking().Returns(castMockSet);
            mockSet.Include(Arg.Any<string>()).Returns(mockSet);
        }
    }
}