Skip to content

Commit

Permalink
Add CustomSqlConnectionFactory class wrapper to prevent possible conf…
Browse files Browse the repository at this point in the history
…licts
  • Loading branch information
a-belevich committed Aug 20, 2015
1 parent d860404 commit 30c216f
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Data.SqlClient;
using System.Reflection;
using NServiceBus.Transports.SQLServer;
using NServiceBus.Transports.SQLServer.Config;
using NUnit.Framework;

Expand All @@ -18,10 +19,10 @@ public void sql_connection_factory_exists_by_default()
busConfig.UseTransport<SqlServerTransport>();

var builder = Activate(busConfig, new SqlConnectionFactoryConfig());
var factory = builder.Build<Func<string, SqlConnection>>();
var factory = builder.Build<CustomSqlConnectionFactory>();

Assert.IsNotNull(factory);
Assert.AreEqual(defaultFactoryMethod, factory.Method);
Assert.AreEqual(defaultFactoryMethod, factory.OpenNewConnection.Method);
}

[Test]
Expand All @@ -35,10 +36,10 @@ public void sql_connection_factory_can_be_customized()
.UseCustomSqlConnectionFactory(testFactory);

var builder = Activate(busConfig, new SqlConnectionFactoryConfig());
var factory = builder.Build<Func<string, SqlConnection>>();
var factory = builder.Build<CustomSqlConnectionFactory>();

Assert.IsNotNull(factory);
Assert.AreEqual(factory, testFactory);
Assert.AreEqual(testFactory, factory.OpenNewConnection);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class AmbientTransactionReceiveStrategy : IReceiveStrategy
readonly TableBasedQueue errorQueue;
readonly Func<TransportMessage, bool> tryProcessMessageCallback;
readonly TransactionOptions transactionOptions;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public AmbientTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, Func<string, SqlConnection> sqlConnectionFactory, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
public AmbientTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, CustomSqlConnectionFactory sqlConnectionFactory, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
{
this.pipelineExecutor = pipelineExecutor;
this.tryProcessMessageCallback = tryProcessMessageCallback;
Expand All @@ -34,7 +34,7 @@ public ReceiveResult TryReceiveFrom(TableBasedQueue queue)
{
using (var scope = new TransactionScope(TransactionScopeOption.Required, transactionOptions))
{
using (var connection = sqlConnectionFactory(connectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(connectionString))
{
using (pipelineExecutor.SetConnection(connectionString, connection))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@ static SqlConnection DefaultOpenNewConnection(string connectionString)

public override void SetUpDefaults(SettingsHolder settings)
{
Func<string, SqlConnection> factoryMethod = DefaultOpenNewConnection;
settings.SetDefault(CustomSqlConnectionFactorySettingKey, factoryMethod);
settings.SetDefault(CustomSqlConnectionFactorySettingKey, new CustomSqlConnectionFactory(DefaultOpenNewConnection));
}

public override void Configure(FeatureConfigurationContext context, string connectionStringWithSchema)
{
var factoryMethod = (Func<string, SqlConnection>)context.Settings.Get(CustomSqlConnectionFactorySettingKey);
context.Container.ConfigureComponent(b => factoryMethod, DependencyLifecycle.SingleInstance);
var factory = (CustomSqlConnectionFactory)context.Settings.Get(CustomSqlConnectionFactorySettingKey);
context.Container.ConfigureComponent(b => factory, DependencyLifecycle.SingleInstance);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected override void Configure(FeatureConfigurationContext context, string co
context.Container.ConfigureComponent<SqlServerQueueCreator>(DependencyLifecycle.InstancePerCall);

var errorQueue = ErrorQueueSettings.GetConfiguredErrorQueue(context.Settings);
context.Container.ConfigureComponent(b => new ReceiveStrategyFactory(b.Build<PipelineExecutor>(), b.Build<LocalConnectionParams>(), errorQueue, b.Build<Func<string, SqlConnection>>()), DependencyLifecycle.InstancePerCall);
context.Container.ConfigureComponent(b => new ReceiveStrategyFactory(b.Build<PipelineExecutor>(), b.Build<LocalConnectionParams>(), errorQueue, b.Build<CustomSqlConnectionFactory>()), DependencyLifecycle.InstancePerCall);

context.Container.ConfigureComponent<SqlServerPollingDequeueStrategy>(DependencyLifecycle.InstancePerCall);
context.Container.ConfigureComponent<SqlServerStorageContext>(DependencyLifecycle.InstancePerUnitOfWork);
Expand Down
15 changes: 15 additions & 0 deletions src/NServiceBus.SqlServer/CustomSqlConnectionFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
namespace NServiceBus.Transports.SQLServer
{
using System;
using System.Data.SqlClient;

class CustomSqlConnectionFactory
{
public readonly Func<string, SqlConnection> OpenNewConnection;

public CustomSqlConnectionFactory(Func<string, SqlConnection> factory)
{
this.OpenNewConnection = factory;
}
}
}
1 change: 1 addition & 0 deletions src/NServiceBus.SqlServer/NServiceBus.SqlServer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
<Compile Include="ReceiveTaskTracker.cs" />
<Compile Include="SecondaryReceiveConfiguration.cs" />
<Compile Include="SecondaryReceiveSettings.cs" />
<Compile Include="CustomSqlConnectionFactory.cs" />
<Compile Include="SqlServerSettingsExtensions.cs" />
<Compile Include="SqlServerStorageContext.cs" />
<Compile Include="SqlServerTransport.cs" />
Expand Down
6 changes: 3 additions & 3 deletions src/NServiceBus.SqlServer/NativeTransactionReceiveStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class NativeTransactionReceiveStrategy : IReceiveStrategy
readonly TableBasedQueue errorQueue;
readonly Func<TransportMessage, bool> tryProcessMessageCallback;
readonly IsolationLevel isolationLevel;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public NativeTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, Func<string, SqlConnection> sqlConnectionFactory, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
public NativeTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, CustomSqlConnectionFactory sqlConnectionFactory, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
{
this.pipelineExecutor = pipelineExecutor;
this.tryProcessMessageCallback = tryProcessMessageCallback;
Expand All @@ -27,7 +27,7 @@ public NativeTransactionReceiveStrategy(string connectionString, TableBasedQueue

public ReceiveResult TryReceiveFrom(TableBasedQueue queue)
{
using (var connection = sqlConnectionFactory(connectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(connectionString))
{
using (pipelineExecutor.SetConnection(connectionString, connection))
{
Expand Down
6 changes: 3 additions & 3 deletions src/NServiceBus.SqlServer/NoTransactionReceiveStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ class NoTransactionReceiveStrategy : IReceiveStrategy
readonly string connectionString;
readonly TableBasedQueue errorQueue;
readonly Func<TransportMessage, bool> tryProcessMessageCallback;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public NoTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, Func<string, SqlConnection> sqlConnectionFactory)
public NoTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, CustomSqlConnectionFactory sqlConnectionFactory)
{
this.connectionString = connectionString;
this.errorQueue = errorQueue;
Expand All @@ -21,7 +21,7 @@ public NoTransactionReceiveStrategy(string connectionString, TableBasedQueue err
public ReceiveResult TryReceiveFrom(TableBasedQueue queue)
{
MessageReadResult readResult;
using (var connection = sqlConnectionFactory(connectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(connectionString))
{
readResult = queue.TryReceive(connection);
if (readResult.IsPoison)
Expand Down
6 changes: 3 additions & 3 deletions src/NServiceBus.SqlServer/QueuePurger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace NServiceBus.Transports.SQLServer
class QueuePurger : IQueuePurger
{
readonly LocalConnectionParams localConnectionParams;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public QueuePurger(SecondaryReceiveConfiguration secondaryReceiveConfiguration, LocalConnectionParams localConnectionParams, Func<string, SqlConnection> sqlConnectionFactory)
public QueuePurger(SecondaryReceiveConfiguration secondaryReceiveConfiguration, LocalConnectionParams localConnectionParams, CustomSqlConnectionFactory sqlConnectionFactory)
{
this.secondaryReceiveConfiguration = secondaryReceiveConfiguration;
this.localConnectionParams = localConnectionParams;
Expand All @@ -26,7 +26,7 @@ public void Purge(Address address)

void Purge(IEnumerable<string> tableNames)
{
using (var connection = sqlConnectionFactory(localConnectionParams.ConnectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(localConnectionParams.ConnectionString))
{
foreach (var tableName in tableNames)
{
Expand Down
4 changes: 2 additions & 2 deletions src/NServiceBus.SqlServer/ReceiveStrategyFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class ReceiveStrategyFactory
readonly PipelineExecutor pipelineExecutor;
readonly Address errorQueueAddress;
readonly LocalConnectionParams localConnectionParams;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public ReceiveStrategyFactory(PipelineExecutor pipelineExecutor, LocalConnectionParams localConnectionParams, Address errorQueueAddress, Func<string, SqlConnection> sqlConnectionFactory)
public ReceiveStrategyFactory(PipelineExecutor pipelineExecutor, LocalConnectionParams localConnectionParams, Address errorQueueAddress, CustomSqlConnectionFactory sqlConnectionFactory)
{
this.pipelineExecutor = pipelineExecutor;
this.errorQueueAddress = errorQueueAddress;
Expand Down
8 changes: 4 additions & 4 deletions src/NServiceBus.SqlServer/SqlServerMessageSender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class SqlServerMessageSender : ISendMessages
{
readonly IConnectionStringProvider connectionStringProvider;
readonly PipelineExecutor pipelineExecutor;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public SqlServerMessageSender(IConnectionStringProvider connectionStringProvider, PipelineExecutor pipelineExecutor, Func<string, SqlConnection> sqlConnectionFactory)
public SqlServerMessageSender(IConnectionStringProvider connectionStringProvider, PipelineExecutor pipelineExecutor, CustomSqlConnectionFactory sqlConnectionFactory)
{
this.connectionStringProvider = connectionStringProvider;
this.pipelineExecutor = pipelineExecutor;
Expand Down Expand Up @@ -45,7 +45,7 @@ public void Send(TransportMessage message, SendOptions sendOptions)
}
else
{
using (var connection = sqlConnectionFactory(connectionInfo.ConnectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(connectionInfo.ConnectionString))
{
queue.Send(message, sendOptions, connection);
}
Expand All @@ -57,7 +57,7 @@ public void Send(TransportMessage message, SendOptions sendOptions)
// Suppress so that even if DTC is on, we won't escalate
using (var tx = new TransactionScope(TransactionScopeOption.Suppress))
{
using (var connection = sqlConnectionFactory(connectionInfo.ConnectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(connectionInfo.ConnectionString))
{
queue.Send(message, sendOptions, connection);
}
Expand Down
7 changes: 3 additions & 4 deletions src/NServiceBus.SqlServer/SqlServerQueueCreator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ [RowVersion] ASC
public void CreateQueueIfNecessary(Address address, string account)
{
var connectionParams = connectionStringProvider.GetForDestination(address);
using (var connection = sqlConnectionFactory(connectionParams.ConnectionString))
using (var connection = sqlConnectionFactory.OpenNewConnection(connectionParams.ConnectionString))
{
connection.Open();
var sql = string.Format(Ddl, connectionParams.Schema, address.GetTableName());

using (var command = new SqlCommand(sql, connection) {CommandType = CommandType.Text})
Expand All @@ -43,9 +42,9 @@ public void CreateQueueIfNecessary(Address address, string account)
}

readonly IConnectionStringProvider connectionStringProvider;
readonly Func<string, SqlConnection> sqlConnectionFactory;
readonly CustomSqlConnectionFactory sqlConnectionFactory;

public SqlServerQueueCreator(IConnectionStringProvider connectionStringProvider, Func<string, SqlConnection> sqlConnectionFactory)
public SqlServerQueueCreator(IConnectionStringProvider connectionStringProvider, CustomSqlConnectionFactory sqlConnectionFactory)
{
this.connectionStringProvider = connectionStringProvider;
this.sqlConnectionFactory = sqlConnectionFactory;
Expand Down
2 changes: 1 addition & 1 deletion src/NServiceBus.SqlServer/SqlServerSettingsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public static TransportExtensions<SqlServerTransport> PauseAfterReceiveFailure(t
/// <returns></returns>
public static TransportExtensions<SqlServerTransport> UseCustomSqlConnectionFactory(this TransportExtensions<SqlServerTransport> transportExtensions, Func<string, SqlConnection> sqlConnectionFactory)
{
transportExtensions.GetSettings().Set(SqlConnectionFactoryConfig.CustomSqlConnectionFactorySettingKey, sqlConnectionFactory);
transportExtensions.GetSettings().Set(SqlConnectionFactoryConfig.CustomSqlConnectionFactorySettingKey, new CustomSqlConnectionFactory(sqlConnectionFactory));
return transportExtensions;
}
}
Expand Down

0 comments on commit 30c216f

Please sign in to comment.