-
-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
TestBase.cs
131 lines (115 loc) · 4.91 KB
/
TestBase.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
using System;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.Threading;
using Xunit;
namespace Dapper.Tests
{
/// <summary>
/// Holds a single shared <typeparamref name="TProvider"/> instance for each provider type.
/// </summary>
/// <typeparam name="TProvider">The concrete <see cref="DatabaseProvider"/> type to instantiate.</typeparam>
public static class DatabaseProvider<TProvider> where TProvider : DatabaseProvider
{
    // Created by the static field initializer via Activator.CreateInstance<TProvider>().
    private static readonly TProvider _instance = Activator.CreateInstance<TProvider>();

    /// <summary>
    /// Gets the shared <typeparamref name="TProvider"/> instance.
    /// </summary>
    public static TProvider Instance
    {
        get { return _instance; }
    }
}
/// <summary>
/// Base class for test database providers: exposes an ADO.NET provider factory and
/// a connection string, and builds connections and parameters from them.
/// </summary>
public abstract class DatabaseProvider
{
    /// <summary>Gets the ADO.NET factory used to create connections and parameters.</summary>
    public abstract DbProviderFactory Factory { get; }

    /// <summary>Does nothing in the base class; derived providers may override it.</summary>
    public virtual void Dispose() { }

    /// <summary>Gets the connection string assigned to every connection this provider creates.</summary>
    public abstract string GetConnectionString();

    /// <summary>
    /// Reads a connection string from an environment variable, falling back to a default.
    /// </summary>
    /// <param name="name">Name of the environment variable to read.</param>
    /// <param name="defaultConnectionString">Value returned when the variable lookup yields <c>null</c>.</param>
    /// <returns>The environment variable's value, or <paramref name="defaultConnectionString"/>.</returns>
    protected string GetConnectionString(string name, string defaultConnectionString) =>
        Environment.GetEnvironmentVariable(name) ?? defaultConnectionString;

    /// <summary>
    /// Creates a connection from <see cref="Factory"/>, assigns the connection string and opens it.
    /// </summary>
    /// <returns>An open connection; the caller owns it and must dispose it.</returns>
    /// <exception cref="InvalidOperationException">
    /// The factory returned no connection, or the connection is not open after <c>Open()</c>.
    /// </exception>
    public DbConnection GetOpenConnection()
    {
        var conn = CreateConfiguredConnection();
        try
        {
            conn.Open();
            if (conn.State != ConnectionState.Open) throw new InvalidOperationException("should be open!");
            return conn;
        }
        catch
        {
            // The caller never receives the connection on failure, so release it here.
            conn.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Creates a connection from <see cref="Factory"/> and assigns the connection string without opening it.
    /// </summary>
    /// <returns>A closed connection; the caller owns it and must dispose it.</returns>
    /// <exception cref="InvalidOperationException">
    /// The factory returned no connection, or the new connection is not in the closed state.
    /// </exception>
    public DbConnection GetClosedConnection()
    {
        var conn = CreateConfiguredConnection();
        if (conn.State != ConnectionState.Closed)
        {
            // Do not leak a connection that is handed to nobody.
            conn.Dispose();
            throw new InvalidOperationException("should be closed!");
        }
        return conn;
    }

    /// <summary>
    /// Creates a provider-specific parameter with the given name and value.
    /// </summary>
    /// <param name="name">Value assigned to <c>ParameterName</c>.</param>
    /// <param name="value">Parameter value; <c>null</c> is stored as <see cref="DBNull.Value"/>.</param>
    /// <returns>The new parameter.</returns>
    /// <exception cref="InvalidOperationException">The factory returned no parameter.</exception>
    public DbParameter CreateRawParameter(string name, object value)
    {
        var p = Factory.CreateParameter();
        if (p == null)
        {
            throw new InvalidOperationException("The provider factory did not create a parameter.");
        }
        p.ParameterName = name;
        p.Value = value ?? DBNull.Value;
        return p;
    }

    // Creates a connection from the factory and assigns the provider's connection string,
    // disposing the connection if the assignment throws.
    private DbConnection CreateConfiguredConnection()
    {
        var conn = Factory.CreateConnection();
        if (conn == null)
        {
            throw new InvalidOperationException("The provider factory did not create a connection.");
        }
        try
        {
            conn.ConnectionString = GetConnectionString();
        }
        catch
        {
            conn.Dispose();
            throw;
        }
        return conn;
    }
}
/// <summary>
/// Base provider for SQL Server: supplies the connection string and can open
/// connections with Multiple Active Result Sets (MARS) enabled.
/// </summary>
public abstract class SqlServerDatabaseProvider : DatabaseProvider
{
    /// <summary>
    /// Gets the connection string from the <c>SqlServerConnectionString</c> environment
    /// variable, falling back to a local <c>tempdb</c> connection with integrated security.
    /// </summary>
    public override string GetConnectionString() =>
        GetConnectionString("SqlServerConnectionString", "Data Source=.;Initial Catalog=tempdb;Integrated Security=True");

    /// <summary>
    /// Opens a connection, optionally with <c>MultipleActiveResultSets</c> turned on.
    /// </summary>
    /// <param name="mars">
    /// <c>true</c> to enable MARS on the connection string; <c>false</c> to use the
    /// parameterless <c>GetOpenConnection()</c>.
    /// </param>
    /// <returns>An open connection; the caller owns it and must dispose it.</returns>
    /// <exception cref="InvalidOperationException">
    /// The factory returned no builder or connection, or the connection is not open after <c>Open()</c>.
    /// </exception>
    public DbConnection GetOpenConnection(bool mars)
    {
        if (!mars) return GetOpenConnection();

        var scsb = Factory.CreateConnectionStringBuilder();
        if (scsb == null)
        {
            throw new InvalidOperationException("The provider factory did not create a connection string builder.");
        }
        scsb.ConnectionString = GetConnectionString();
        // Set the keyword through the builder's indexer rather than late-bound dynamic dispatch.
        scsb["MultipleActiveResultSets"] = true;

        var conn = Factory.CreateConnection();
        if (conn == null)
        {
            throw new InvalidOperationException("The provider factory did not create a connection.");
        }
        try
        {
            conn.ConnectionString = scsb.ConnectionString;
            conn.Open();
            if (conn.State != ConnectionState.Open) throw new InvalidOperationException("should be open!");
            return conn;
        }
        catch
        {
            // The caller never receives the connection on failure, so release it here.
            conn.Dispose();
            throw;
        }
    }
}
/// <summary>
/// SQL Server provider whose factory is <c>System.Data.SqlClient.SqlClientFactory.Instance</c>.
/// </summary>
public sealed class SystemSqlClientProvider : SqlServerDatabaseProvider
{
    /// <summary>Gets the System.Data.SqlClient provider factory singleton.</summary>
    public override DbProviderFactory Factory
    {
        get { return System.Data.SqlClient.SqlClientFactory.Instance; }
    }
}
#if MSSQLCLIENT
/// <summary>
/// SQL Server provider whose factory is <c>Microsoft.Data.SqlClient.SqlClientFactory.Instance</c>.
/// Compiled only when the <c>MSSQLCLIENT</c> symbol is defined.
/// </summary>
public sealed class MicrosoftSqlClientProvider : SqlServerDatabaseProvider
{
    /// <summary>Gets the Microsoft.Data.SqlClient provider factory singleton.</summary>
    public override DbProviderFactory Factory
    {
        get { return Microsoft.Data.SqlClient.SqlClientFactory.Instance; }
    }
}
#endif
/// <summary>
/// Base class for database tests: exposes the shared provider, a lazily opened
/// connection, and helpers for creating further connections.
/// </summary>
/// <typeparam name="TProvider">The <see cref="DatabaseProvider"/> type the tests run against.</typeparam>
public abstract class TestBase<TProvider> : IDisposable where TProvider : DatabaseProvider
{
    /// <summary>
    /// Calls <c>Skip.If</c> for <c>Microsoft.Data.SqlClient.SqlConnection</c> on the current connection.
    /// Does nothing when the <c>MSSQLCLIENT</c> symbol is not defined.
    /// </summary>
    protected void SkipIfMsDataClient()
    {
        // Guarded with the same symbol as MicrosoftSqlClientProvider so this file still
        // compiles when Microsoft.Data.SqlClient is not part of the build.
        // NOTE(review): Skip is defined elsewhere; presumably it skips the test when the
        // connection is of the given type - confirm.
#if MSSQLCLIENT
        Skip.If<Microsoft.Data.SqlClient.SqlConnection>(connection);
#endif
    }

    /// <summary>Returns a new open connection from the provider; the caller must dispose it.</summary>
    protected DbConnection GetOpenConnection() => Provider.GetOpenConnection();

    /// <summary>Returns a new closed connection from the provider; the caller must dispose it.</summary>
    protected DbConnection GetClosedConnection() => Provider.GetClosedConnection();

    /// <summary>Backing field for <see cref="connection"/>; null until first use and after <see cref="Dispose"/>.</summary>
    protected DbConnection _connection;

    /// <summary>
    /// Gets the connection owned by this test instance, opening it through the provider on first access.
    /// </summary>
    protected DbConnection connection => _connection ?? (_connection = Provider.GetOpenConnection());

    /// <summary>Gets the shared provider instance for <typeparamref name="TProvider"/>.</summary>
    public TProvider Provider { get; } = DatabaseProvider<TProvider>.Instance;

    /// <summary>Gets or sets the current thread's culture.</summary>
    protected static CultureInfo ActiveCulture
    {
        get { return Thread.CurrentThread.CurrentCulture; }
        set { Thread.CurrentThread.CurrentCulture = value; }
    }

    // Static constructor: writes diagnostic information about the Dapper assembly, the
    // connection string, the provider factory and the runtime version to the console,
    // then attempts to load the native SQL types assemblies.
    static TestBase()
    {
        Console.WriteLine("Dapper: " + typeof(SqlMapper).AssemblyQualifiedName);
        var provider = DatabaseProvider<TProvider>.Instance;
        Console.WriteLine("Using Connectionstring: {0}", provider.GetConnectionString());
        var factory = provider.Factory;
        Console.WriteLine("Using Provider: {0}", factory.GetType().FullName);
        Console.WriteLine(".NET: " + Environment.Version);
        Console.Write("Loading native assemblies for SQL types...");
        try
        {
            SqlServerTypesLoader.LoadNativeAssemblies(AppDomain.CurrentDomain.BaseDirectory);
            Console.WriteLine("done.");
        }
        catch (Exception ex)
        {
            // Best effort: a load failure is reported on the console and otherwise ignored.
            Console.WriteLine("failed.");
            Console.Error.WriteLine(ex.Message);
        }
    }

    /// <summary>
    /// Disposes the lazily opened connection, if any, and then calls <c>Dispose</c> on the provider.
    /// </summary>
    public virtual void Dispose()
    {
        _connection?.Dispose();
        _connection = null;
        // NOTE(review): Provider is the shared static instance, so this runs once per
        // test instance; the base DatabaseProvider.Dispose is a no-op.
        Provider?.Dispose();
    }
}
/// <summary>
/// Holds the shared "NonParallel" name as a compile-time constant.
/// </summary>
// NOTE(review): presumably used as a test collection name to keep some tests from
// running in parallel - confirm against the usages.
public static class NonParallelDefinition
{
    /// <summary>The constant name <c>"NonParallel"</c>.</summary>
    public const string Name = "NonParallel";
}
}