Skip to content

Commit

Permalink
Related to #3303 - Added support for ICredentialsProvider and perhaps…
Browse files Browse the repository at this point in the history
… ICredentialsRefresher (RabbitMQ).
  • Loading branch information
phatboyg committed Apr 25, 2024
1 parent 2af9f8f commit c287ee3
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
namespace MassTransit
#nullable enable
namespace MassTransit
{
using System;
using System.Threading.Tasks;
using RabbitMQ.Client;


public delegate Task RefreshConnectionFactoryCallback(ConnectionFactory connectionFactory);


public interface IRabbitMqHostConfigurator
{
/// <summary>
Expand All @@ -18,11 +15,21 @@ public interface IRabbitMqHostConfigurator

RefreshConnectionFactoryCallback OnRefreshConnectionFactory { set; }

/// <summary>
/// Sets the credential provider, overriding the default username/password credentials
/// </summary>
ICredentialsProvider CredentialsProvider { set; }

/// <summary>
/// Sets the credentials refresher, allowing access token based credentials to be refreshed
/// </summary>
ICredentialsRefresher CredentialsRefresher { set; }

/// <summary>
/// Configure the use of SSL to connection to RabbitMQ
/// </summary>
/// <param name="configureSsl"></param>
void UseSsl(Action<IRabbitMqSslConfigurator> configureSsl);
/// <param name="configure"></param>
void UseSsl(Action<IRabbitMqSslConfigurator>? configure = null);

/// <summary>
/// Specifies the heartbeat interval, in seconds, used to maintain the connection to RabbitMQ.
Expand Down Expand Up @@ -94,6 +101,6 @@ public interface IRabbitMqHostConfigurator
/// <summary>
/// Sets the connection name for the connection to RabbitMQ
/// </summary>
void ConnectionName(string connectionName);
void ConnectionName(string? connectionName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ public interface RabbitMqHostSettings
/// </summary>
uint? MaxMessageSize { get; }

/// <summary>
/// The credential provider, overriding the default username/password credentials
/// </summary>
ICredentialsProvider CredentialsProvider { get; }

/// <summary>
/// The credentials refresher, allowing access token based credentials to be refreshed
/// </summary>
ICredentialsRefresher CredentialsRefresher { get; }

/// <summary>
/// Called prior to the connection factory being used to connect, so that any settings can be updated.
/// Typically this would be the username/password in response to an expired token, etc.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace MassTransit;

using System.Threading.Tasks;
using RabbitMQ.Client;


public delegate Task RefreshConnectionFactoryCallback(ConnectionFactory connectionFactory);
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#nullable enable
namespace MassTransit.RabbitMqTransport.Configuration
{
using System;
Expand Down Expand Up @@ -55,7 +56,7 @@ public ConfigurationHostSettings()
public LocalCertificateSelectionCallback CertificateSelectionCallback { get; set; }
public RemoteCertificateValidationCallback CertificateValidationCallback { get; set; }
public IRabbitMqEndpointResolver EndpointResolver { get; set; }
public string ClientProvidedName { get; set; }
public string? ClientProvidedName { get; set; }
public bool PublisherConfirmation { get; set; }
public Uri HostAddress => _hostAddress.Value;
public ushort RequestedChannelMax { get; set; }
Expand All @@ -65,6 +66,9 @@ public ConfigurationHostSettings()
public TimeSpan ContinuationTimeout { get; set; }
public uint? MaxMessageSize { get; set; }

public ICredentialsProvider? CredentialsProvider { get; set; }
public ICredentialsRefresher? CredentialsRefresher { get; set; }

public Task Refresh(ConnectionFactory connectionFactory)
{
return OnRefreshConnectionFactory?.Invoke(connectionFactory) ?? Task.CompletedTask;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
#nullable enable
namespace MassTransit.RabbitMqTransport.Configuration
{
using System;
using RabbitMQ.Client;


public class RabbitMqHostConfigurator :
IRabbitMqHostConfigurator
{
static readonly char[] _pathSeparator = { '/' };
static readonly char[] _pathSeparator = ['/'];
readonly ConfigurationHostSettings _settings;

public RabbitMqHostConfigurator(Uri hostAddress, string connectionName = null)
public RabbitMqHostConfigurator(Uri hostAddress, string? connectionName = null)
{
_settings = hostAddress.GetConfigurationHostSettings();

if (_settings.Port == 5671)
{
UseSsl(s =>
{
});
}
UseSsl();

_settings.VirtualHost = Uri.UnescapeDataString(GetVirtualHost(hostAddress));

if (!string.IsNullOrEmpty(connectionName))
_settings.ClientProvidedName = connectionName;
}

public RabbitMqHostConfigurator(string host, string virtualHost, ushort port = 5672, string connectionName = null)
public RabbitMqHostConfigurator(string host, string virtualHost, ushort port = 5672, string? connectionName = null)
{
_settings = new ConfigurationHostSettings
{
Expand All @@ -53,11 +51,11 @@ public bool PublisherConfirmation
set => _settings.PublisherConfirmation = value;
}

public void UseSsl(Action<IRabbitMqSslConfigurator> configureSsl)
public void UseSsl(Action<IRabbitMqSslConfigurator>? configure = null)
{
var configurator = new RabbitMqSslConfigurator(_settings);

configureSsl(configurator);
configure?.Invoke(configurator);

_settings.Ssl = true;
_settings.ClientCertificatePassphrase = configurator.CertificatePassphrase;
Expand Down Expand Up @@ -116,6 +114,16 @@ public void Password(string password)
_settings.Password = password;
}

public ICredentialsProvider CredentialsProvider
{
set => _settings.CredentialsProvider = value;
}

public ICredentialsRefresher CredentialsRefresher
{
set => _settings.CredentialsRefresher = value;
}

public void UseCluster(Action<IRabbitMqClusterConfigurator> configureCluster)
{
var configurator = new RabbitMqClusterConfigurator(_settings);
Expand All @@ -139,6 +147,11 @@ public void RequestedConnectionTimeout(TimeSpan timeSpan)
_settings.RequestedConnectionTimeout = timeSpan;
}

public void ConnectionName(string? connectionName)
{
_settings.ClientProvidedName = connectionName;
}

string GetVirtualHost(Uri address)
{
var segments = address.AbsolutePath.Split(_pathSeparator, StringSplitOptions.RemoveEmptyEntries);
Expand All @@ -151,10 +164,5 @@ string GetVirtualHost(Uri address)

throw new FormatException("The host path must be empty or contain a single virtual host name");
}

public void ConnectionName(string connectionName)
{
_settings.ClientProvidedName = connectionName;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public static ConnectionFactory GetConnectionFactory(this RabbitMqHostSettings s
factory.UserName = "";
factory.Password = "";
}
else if (settings.CredentialsProvider != null)
factory.CredentialsProvider = settings.CredentialsProvider;
else
{
if (!string.IsNullOrWhiteSpace(settings.Username))
Expand All @@ -74,6 +76,8 @@ public static ConnectionFactory GetConnectionFactory(this RabbitMqHostSettings s
factory.Password = settings.Password;
}

factory.CredentialsRefresher = settings.CredentialsRefresher;

ApplySslOptions(settings, factory.Ssl);

factory.ClientProperties ??= new Dictionary<string, object>();
Expand Down
2 changes: 2 additions & 0 deletions tests/MassTransit.Benchmark/RabbitMqOptionSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ public RabbitMqOptionSet()

public TimeSpan ContinuationTimeout => TimeSpan.FromSeconds(20);
public uint? MaxMessageSize { get; set; }
public ICredentialsProvider CredentialsProvider { get; set; }
public ICredentialsRefresher CredentialsRefresher { get; set; }

public Task Refresh(ConnectionFactory connectionFactory)
{
Expand Down

0 comments on commit c287ee3

Please sign in to comment.