Skip to content

Commit

Permalink
Make usage of NHibernate compatible with SET NOCOUNT ON DB-wide set…
Browse files Browse the repository at this point in the history
…ting #112

-Allow customizing of SQL Transport connections with injectable factory, like NHibernate does.
  • Loading branch information
a-belevich committed Aug 11, 2015
1 parent 914190c commit 8dec49a
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ class AmbientTransactionReceiveStrategy : IReceiveStrategy
readonly TableBasedQueue errorQueue;
readonly Func<TransportMessage, bool> tryProcessMessageCallback;
readonly TransactionOptions transactionOptions;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public AmbientTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
public AmbientTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, Func<string, SqlConnection> sqlConnectionFactory, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
{
this.pipelineExecutor = pipelineExecutor;
this.tryProcessMessageCallback = tryProcessMessageCallback;
this.errorQueue = errorQueue;
this.connectionString = connectionString;
this.sqlConnectionFactory = sqlConnectionFactory;

transactionOptions = new TransactionOptions
{
IsolationLevel = transactionSettings.IsolationLevel,
Expand All @@ -31,9 +34,8 @@ public ReceiveResult TryReceiveFrom(TableBasedQueue queue)
{
using (var scope = new TransactionScope(TransactionScopeOption.Required, transactionOptions))
{
using (var connection = new SqlConnection(connectionString))
using (var connection = sqlConnectionFactory(connectionString))
{
connection.Open();
using (pipelineExecutor.SetConnection(connectionString, connection))
{
var readResult = queue.TryReceive(connection);
Expand Down
41 changes: 41 additions & 0 deletions src/NServiceBus.SqlServer/Config/SqlConnectionFactoryConfig.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
namespace NServiceBus.Transports.SQLServer.Config
{
using System;
using System.Data.SqlClient;
using NServiceBus.Features;
using NServiceBus.Settings;

class SqlConnectionFactoryConfig : ConfigBase
{
internal const string CustomSqlConnectionFactorySettingKey = "SqlServer.CustomSqlConnectionFactory";

static SqlConnection DefaultOpenNewConnection(string connectionString)
{
var connection = new SqlConnection(connectionString);

try
{
connection.Open();
}
catch (Exception)
{
connection.Dispose();
throw;
}

return connection;
}

public override void SetUpDefaults(SettingsHolder settings)
{
Func<string, SqlConnection> factoryMethod = DefaultOpenNewConnection;
settings.SetDefault(CustomSqlConnectionFactorySettingKey, factoryMethod);
}

public override void Configure(FeatureConfigurationContext context, string connectionStringWithSchema)
{
var factoryMethod = (Func<string, SqlConnection>)context.Settings.Get(CustomSqlConnectionFactorySettingKey);
context.Container.ConfigureComponent(b => factoryMethod, DependencyLifecycle.SingleInstance);
}
}
}
6 changes: 4 additions & 2 deletions src/NServiceBus.SqlServer/Config/SqlServerTransportFeature.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace NServiceBus.Features
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Data.SqlClient;
using System.Linq;
using NServiceBus.Transports.SQLServer.Config;
using Pipeline;
Expand All @@ -18,7 +19,8 @@ class SqlServerTransportFeature : ConfigureTransport
new CallbackConfig(),
new CircuitBreakerConfig(),
new ConnectionConfig(ConfigurationManager.ConnectionStrings.Cast<ConnectionStringSettings>().ToList()),
new PurgingConfig()
new PurgingConfig(),
new SqlConnectionFactoryConfig()
};

public SqlServerTransportFeature()
Expand Down Expand Up @@ -62,7 +64,7 @@ protected override void Configure(FeatureConfigurationContext context, string co
context.Container.ConfigureComponent<SqlServerQueueCreator>(DependencyLifecycle.InstancePerCall);

var errorQueue = ErrorQueueSettings.GetConfiguredErrorQueue(context.Settings);
context.Container.ConfigureComponent(b => new ReceiveStrategyFactory(b.Build<PipelineExecutor>(), b.Build<LocalConnectionParams>(), errorQueue), DependencyLifecycle.InstancePerCall);
context.Container.ConfigureComponent(b => new ReceiveStrategyFactory(b.Build<PipelineExecutor>(), b.Build<LocalConnectionParams>(), errorQueue, b.Build<Func<string, SqlConnection>>()), DependencyLifecycle.InstancePerCall);

context.Container.ConfigureComponent<SqlServerPollingDequeueStrategy>(DependencyLifecycle.InstancePerCall);
context.Container.ConfigureComponent<SqlServerStorageContext>(DependencyLifecycle.InstancePerUnitOfWork);
Expand Down
1 change: 1 addition & 0 deletions src/NServiceBus.SqlServer/NServiceBus.SqlServer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
<Compile Include="Config\ConfigBase.cs" />
<Compile Include="Config\ConnectionConfig.cs" />
<Compile Include="Config\PurgingConfig.cs" />
<Compile Include="Config\SqlConnectionFactoryConfig.cs" />
<Compile Include="Config\ValidateOutboxOrAmbientTransactionsEnabled.cs" />
<Compile Include="ConnectionInfo.cs" />
<Compile Include="ConnectionParams.cs" />
Expand Down
7 changes: 4 additions & 3 deletions src/NServiceBus.SqlServer/NativeTransactionReceiveStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@ class NativeTransactionReceiveStrategy : IReceiveStrategy
readonly TableBasedQueue errorQueue;
readonly Func<TransportMessage, bool> tryProcessMessageCallback;
readonly IsolationLevel isolationLevel;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public NativeTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
public NativeTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, Func<string, SqlConnection> sqlConnectionFactory, PipelineExecutor pipelineExecutor, TransactionSettings transactionSettings)
{
this.pipelineExecutor = pipelineExecutor;
this.tryProcessMessageCallback = tryProcessMessageCallback;
this.errorQueue = errorQueue;
this.connectionString = connectionString;
this.sqlConnectionFactory = sqlConnectionFactory;
isolationLevel = GetSqlIsolationLevel(transactionSettings.IsolationLevel);
}

public ReceiveResult TryReceiveFrom(TableBasedQueue queue)
{
using (var connection = new SqlConnection(connectionString))
using (var connection = sqlConnectionFactory(connectionString))
{
connection.Open();
using (pipelineExecutor.SetConnection(connectionString, connection))
{
using (var transaction = connection.BeginTransaction(isolationLevel))
Expand Down
7 changes: 4 additions & 3 deletions src/NServiceBus.SqlServer/NoTransactionReceiveStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@ class NoTransactionReceiveStrategy : IReceiveStrategy
readonly string connectionString;
readonly TableBasedQueue errorQueue;
readonly Func<TransportMessage, bool> tryProcessMessageCallback;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public NoTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback)
public NoTransactionReceiveStrategy(string connectionString, TableBasedQueue errorQueue, Func<TransportMessage, bool> tryProcessMessageCallback, Func<string, SqlConnection> sqlConnectionFactory)
{
this.connectionString = connectionString;
this.errorQueue = errorQueue;
this.tryProcessMessageCallback = tryProcessMessageCallback;
this.sqlConnectionFactory = sqlConnectionFactory;
}

public ReceiveResult TryReceiveFrom(TableBasedQueue queue)
{
MessageReadResult readResult;
using (var connection = new SqlConnection(connectionString))
using (var connection = sqlConnectionFactory(connectionString))
{
connection.Open();
readResult = queue.TryReceive(connection);
if (readResult.IsPoison)
{
Expand Down
9 changes: 5 additions & 4 deletions src/NServiceBus.SqlServer/QueuePurger.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
namespace NServiceBus.Transports.SQLServer
{
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
Expand All @@ -9,11 +10,13 @@ namespace NServiceBus.Transports.SQLServer
class QueuePurger : IQueuePurger
{
readonly LocalConnectionParams localConnectionParams;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public QueuePurger(SecondaryReceiveConfiguration secondaryReceiveConfiguration, LocalConnectionParams localConnectionParams)
public QueuePurger(SecondaryReceiveConfiguration secondaryReceiveConfiguration, LocalConnectionParams localConnectionParams, Func<string, SqlConnection> sqlConnectionFactory)
{
this.secondaryReceiveConfiguration = secondaryReceiveConfiguration;
this.localConnectionParams = localConnectionParams;
this.sqlConnectionFactory = sqlConnectionFactory;
}

public void Purge(Address address)
Expand All @@ -23,10 +26,8 @@ public void Purge(Address address)

void Purge(IEnumerable<string> tableNames)
{
using (var connection = new SqlConnection(localConnectionParams.ConnectionString))
using (var connection = sqlConnectionFactory(localConnectionParams.ConnectionString))
{
connection.Open();

foreach (var tableName in tableNames)
{
using (var command = new SqlCommand(string.Format(SqlPurge, localConnectionParams.Schema, tableName), connection)
Expand Down
12 changes: 8 additions & 4 deletions src/NServiceBus.SqlServer/ReceiveStrategyFactory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
namespace NServiceBus.Transports.SQLServer
{
using System;
using System.Data.SqlClient;
using NServiceBus.Pipeline;
using NServiceBus.Unicast.Transport;

Expand All @@ -9,26 +10,29 @@ class ReceiveStrategyFactory
readonly PipelineExecutor pipelineExecutor;
readonly Address errorQueueAddress;
readonly LocalConnectionParams localConnectionParams;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public ReceiveStrategyFactory(PipelineExecutor pipelineExecutor, LocalConnectionParams localConnectionParams, Address errorQueueAddress)
public ReceiveStrategyFactory(PipelineExecutor pipelineExecutor, LocalConnectionParams localConnectionParams, Address errorQueueAddress, Func<string, SqlConnection> sqlConnectionFactory)
{
this.pipelineExecutor = pipelineExecutor;
this.errorQueueAddress = errorQueueAddress;
this.localConnectionParams = localConnectionParams;
this.sqlConnectionFactory = sqlConnectionFactory;
}

public IReceiveStrategy Create(TransactionSettings settings, Func<TransportMessage, bool> tryProcessMessageCallback)
{
var errorQueue = new TableBasedQueue(errorQueueAddress, localConnectionParams.Schema);

if (settings.IsTransactional)
{
if (settings.SuppressDistributedTransactions)
{
return new NativeTransactionReceiveStrategy(localConnectionParams.ConnectionString, errorQueue, tryProcessMessageCallback, pipelineExecutor, settings);
return new NativeTransactionReceiveStrategy(localConnectionParams.ConnectionString, errorQueue, tryProcessMessageCallback, sqlConnectionFactory, pipelineExecutor, settings);
}
return new AmbientTransactionReceiveStrategy(localConnectionParams.ConnectionString, errorQueue, tryProcessMessageCallback, pipelineExecutor, settings);
return new AmbientTransactionReceiveStrategy(localConnectionParams.ConnectionString, errorQueue, tryProcessMessageCallback, sqlConnectionFactory, pipelineExecutor, settings);
}
return new NoTransactionReceiveStrategy(localConnectionParams.ConnectionString, errorQueue, tryProcessMessageCallback);
return new NoTransactionReceiveStrategy(localConnectionParams.ConnectionString, errorQueue, tryProcessMessageCallback, sqlConnectionFactory);
}
}
}
10 changes: 5 additions & 5 deletions src/NServiceBus.SqlServer/SqlServerMessageSender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ class SqlServerMessageSender : ISendMessages
{
readonly IConnectionStringProvider connectionStringProvider;
readonly PipelineExecutor pipelineExecutor;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public SqlServerMessageSender(IConnectionStringProvider connectionStringProvider, PipelineExecutor pipelineExecutor)
public SqlServerMessageSender(IConnectionStringProvider connectionStringProvider, PipelineExecutor pipelineExecutor, Func<string, SqlConnection> sqlConnectionFactory)
{
this.connectionStringProvider = connectionStringProvider;
this.pipelineExecutor = pipelineExecutor;
this.sqlConnectionFactory = sqlConnectionFactory;
}

public void Send(TransportMessage message, SendOptions sendOptions)
Expand Down Expand Up @@ -43,9 +45,8 @@ public void Send(TransportMessage message, SendOptions sendOptions)
}
else
{
using (var connection = new SqlConnection(connectionInfo.ConnectionString))
using (var connection = sqlConnectionFactory(connectionInfo.ConnectionString))
{
connection.Open();
queue.Send(message, sendOptions, connection);
}
}
Expand All @@ -56,9 +57,8 @@ public void Send(TransportMessage message, SendOptions sendOptions)
// Suppress so that even if DTC is on, we won't escalate
using (var tx = new TransactionScope(TransactionScopeOption.Suppress))
{
using (var connection = new SqlConnection(connectionInfo.ConnectionString))
using (var connection = sqlConnectionFactory(connectionInfo.ConnectionString))
{
connection.Open();
queue.Send(message, sendOptions, connection);
}

Expand Down
9 changes: 6 additions & 3 deletions src/NServiceBus.SqlServer/SqlServerQueueCreator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
namespace NServiceBus.Transports.SQLServer
{
using System;
using System.Data;
using System.Data.SqlClient;

Expand Down Expand Up @@ -29,10 +30,10 @@ [RowVersion] ASC
public void CreateQueueIfNecessary(Address address, string account)
{
var connectionParams = connectionStringProvider.GetForDestination(address);
using (var connection = new SqlConnection(connectionParams.ConnectionString))
using (var connection = sqlConnectionFactory(connectionParams.ConnectionString))
{
connection.Open();
var sql = string.Format(Ddl, connectionParams.Schema, address.GetTableName());
connection.Open();

using (var command = new SqlCommand(sql, connection) {CommandType = CommandType.Text})
{
Expand All @@ -42,10 +43,12 @@ public void CreateQueueIfNecessary(Address address, string account)
}

readonly IConnectionStringProvider connectionStringProvider;
readonly Func<string, SqlConnection> sqlConnectionFactory;

public SqlServerQueueCreator(IConnectionStringProvider connectionStringProvider)
public SqlServerQueueCreator(IConnectionStringProvider connectionStringProvider, Func<string, SqlConnection> sqlConnectionFactory)
{
this.connectionStringProvider = connectionStringProvider;
this.sqlConnectionFactory = sqlConnectionFactory;
}
}
}
13 changes: 13 additions & 0 deletions src/NServiceBus.SqlServer/SqlServerSettingsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
{
using System;
using System.Collections.Generic;
using System.Data.SqlClient;
using System.Linq;
using Configuration.AdvanceExtensibility;
using NServiceBus.Transports.SQLServer;
Expand Down Expand Up @@ -141,5 +142,17 @@ public static TransportExtensions<SqlServerTransport> PauseAfterReceiveFailure(t
transportExtensions.GetSettings().Set(CircuitBreakerConfig.CircuitBreakerDelayAfterFailureSettingsKey, pauseTime);
return transportExtensions;
}

/// <summary>
/// Overrides the default time SQL Connections factory.
/// </summary>
/// <param name="transportExtensions"></param>
/// <param name="sqlConnectionFactory">Factory for creating and opening new SQL Connections.</param>
/// <returns></returns>
public static TransportExtensions<SqlServerTransport> UseCustomSqlConnectionFactory(this TransportExtensions<SqlServerTransport> transportExtensions, Func<string, SqlConnection> sqlConnectionFactory)
{
transportExtensions.GetSettings().Set(SqlConnectionFactoryConfig.CustomSqlConnectionFactorySettingKey, sqlConnectionFactory);
return transportExtensions;
}
}
}

0 comments on commit 8dec49a

Please sign in to comment.