Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multiple endpoint from configs #341

Merged
merged 8 commits into from Jan 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -2,8 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.Logging;

namespace Microsoft.Azure.SignalR.AspNet
Expand All @@ -13,9 +11,14 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase
private readonly TimeSpan? _ttl;

public ServiceEndpointManager(ServiceOptions options, ILoggerFactory loggerFactory) :
base(GetEndpoints(options).ToArray(),
base(options,
loggerFactory?.CreateLogger<ServiceEndpointManager>())
{
if (Endpoints.Length == 0)
{
throw new ArgumentException(ServiceEndpointProvider.ConnectionStringNotFound);
}

_ttl = options.AccessTokenLifetime;
}

Expand All @@ -28,17 +31,5 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end

return new ServiceEndpointProvider(endpoint, _ttl);
}

private static IEnumerable<ServiceEndpoint> GetEndpoints(ServiceOptions options)
{
// TODO: support multiple endpoints
var connectionString = options.ConnectionString;
if (string.IsNullOrEmpty(connectionString))
{
throw new ArgumentException(ServiceEndpointProvider.ConnectionStringNotFound);
}

yield return new ServiceEndpoint(connectionString);
}
}
}
1 change: 1 addition & 0 deletions src/Microsoft.Azure.SignalR.AspNet/OwinExtensions.cs
Expand Up @@ -206,6 +206,7 @@ private static void RunAzureSignalRCore(IAppBuilder builder, string applicationN
}

var endpoint = new ServiceEndpointManager(options, logger);
configuration.Resolver.Register(typeof(IServiceEndpointManager), () => endpoint);

// Get the one from DI or new a default one
var router = configuration.Resolver.Resolve<IEndpointRouter>() ?? new DefaultRouter();
Expand Down
76 changes: 71 additions & 5 deletions src/Microsoft.Azure.SignalR.AspNet/ServiceOptions.cs
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Linq;
using System.Security.Claims;
using Microsoft.Owin;

Expand All @@ -12,12 +13,12 @@ namespace Microsoft.Azure.SignalR.AspNet
/// <summary>
/// Configurable options when using Azure SignalR Service.
/// </summary>
public class ServiceOptions
public class ServiceOptions : IServiceEndpointOptions
{
/// <summary>
/// Gets or sets the connection string of Azure SignalR Service instance.
/// </summary>
public string ConnectionString { get; set; } = GetDefaultConnectionString();
public string ConnectionString { get; set; }

/// <summary>
/// Gets or sets the total number of connections from SDK to Azure SignalR Service. Default value is 5.
Expand All @@ -36,10 +37,75 @@ public class ServiceOptions
/// </summary>
public TimeSpan AccessTokenLifetime { get; set; } = Constants.DefaultAccessTokenLifetime;

private static string GetDefaultConnectionString()
/// <summary>
/// TODO: expose to customer
/// Gets or sets list of endpoints
/// </summary>
internal ServiceEndpoint[] Endpoints { get; set; }

ServiceEndpoint[] IServiceEndpointOptions.Endpoints => Endpoints;

public ServiceOptions()
{
var count = ConfigurationManager.ConnectionStrings.Count;
string connectionString = null;
var endpoints = new List<ServiceEndpoint>();
for (int i = 0; i < count; i++)
{
var setting = ConfigurationManager.ConnectionStrings[i];
var (isDefault, endpoint) = GetEndpoint(setting.Name, () => setting.ConnectionString);
if (endpoint != null)
{
if (isDefault)
{
connectionString = endpoint.ConnectionString;
}

endpoints.Add(endpoint);
}
}

if (endpoints.Count == 0)
{
// Fallback to use AppSettings
foreach(var key in ConfigurationManager.AppSettings.AllKeys)
{
var (isDefault, endpoint) = GetEndpoint(key, () => ConfigurationManager.AppSettings[key]);
if (endpoint != null)
{
if (isDefault)
{
connectionString = endpoint.ConnectionString;
}

endpoints.Add(endpoint);
}
}
}

// Load connection string from "Azure:SignalR:ConnectionString" section or key starts with "Azure:SignalR:ConnectionString:" when default key doesn't exist or holds an empty value.
if (string.IsNullOrEmpty(connectionString))
{
connectionString = endpoints.FirstOrDefault()?.ConnectionString;
}

ConnectionString = connectionString;
Endpoints = endpoints.ToArray();
}

private static (bool isDefault, ServiceEndpoint endpoint) GetEndpoint(string key, Func<string> valueGetter)
{
return ConfigurationManager.ConnectionStrings[Constants.ConnectionStringDefaultKey]?.ConnectionString
?? ConfigurationManager.AppSettings[Constants.ConnectionStringDefaultKey];
if (key == Constants.ConnectionStringDefaultKey && !string.IsNullOrEmpty(valueGetter()))
{
return (true, new ServiceEndpoint(valueGetter()));
}

if (key.StartsWith(Constants.ConnectionStringKeyPrefix) && !string.IsNullOrEmpty(valueGetter()))
{
return (false, new ServiceEndpoint(key, valueGetter()));
}

return (false, null);
}
}
}
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Expand Up @@ -9,8 +9,13 @@ internal static class Constants
{
public const string ConnectionStringDefaultKey = "Azure:SignalR:ConnectionString";

public static readonly string ConnectionStringSecondaryKey =
$"ConnectionStrings:{ConnectionStringDefaultKey}";

public static readonly string ConnectionStringKeyPrefix = $"{ConnectionStringDefaultKey}:";

public static readonly string ConnectionStringSecondaryKeyPrefix = $"{ConnectionStringSecondaryKey}:";

// Default access token lifetime
public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);

Expand Down
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.Azure.SignalR
{
internal interface IServiceEndpointOptions
{
ServiceEndpoint[] Endpoints { get; }
string ConnectionString { get; }
}
}
40 changes: 25 additions & 15 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs
Expand Up @@ -65,34 +65,44 @@ public override int GetHashCode()

internal static (string, EndpointType) ParseKey(string key)
{
if (key == Constants.ConnectionStringDefaultKey)
if (key == Constants.ConnectionStringDefaultKey || key == Constants.ConnectionStringSecondaryKey)
{
return (string.Empty, EndpointType.Primary);
}

if (key.StartsWith(Constants.ConnectionStringKeyPrefix))
{
// Azure:SignalR:ConnectionString:<name>:<type>
var status = key.Substring(Constants.ConnectionStringKeyPrefix.Length);
var parts = status.Split(':');
if (parts.Length == 1)
return ParseKeyWithPrefix(key, Constants.ConnectionStringKeyPrefix);
}

if (key.StartsWith(Constants.ConnectionStringSecondaryKey))
{
return ParseKeyWithPrefix(key, Constants.ConnectionStringSecondaryKey);
}

throw new ArgumentException($"Invalid format: {key}", nameof(key));
}

private static (string, EndpointType) ParseKeyWithPrefix(string key, string prefix)
{
var status = key.Substring(prefix.Length);
var parts = status.Split(':');
if (parts.Length == 1)
{
return (parts[0], EndpointType.Primary);
}
else
{
if (Enum.TryParse<EndpointType>(parts[1], true, out var endpointStatus))
{
return (parts[0], EndpointType.Primary);
return (parts[0], endpointStatus);
}
else
{
if (Enum.TryParse<EndpointType>(parts[1], true, out var endpointStatus))
{
return (parts[0], endpointStatus);
}
else
{
return (status, EndpointType.Primary);
}
return (status, EndpointType.Primary);
}
}

throw new ArgumentException($"Invalid format: {key}", nameof(key));
}
}
}
Expand Up @@ -12,48 +12,55 @@ namespace Microsoft.Azure.SignalR
{
internal abstract class ServiceEndpointManagerBase : IServiceEndpointManager
{
private readonly ServiceEndpoint[] _endpoints;
private readonly ServiceEndpoint[] _primaryEndpoints;
private readonly ILogger _logger;

public ServiceEndpointManagerBase(IReadOnlyCollection<ServiceEndpoint> endpoints, ILogger logger)
protected ServiceEndpoint[] Endpoints { get; }

public ServiceEndpointManagerBase(IServiceEndpointOptions options, ILogger logger)
: this(GetEndpoints(options).ToArray(), logger)
{
if (endpoints.Count == 0)
{
throw new AzureSignalRNoEndpointAvailableException();
}
}

// for test purpose
internal ServiceEndpointManagerBase(ServiceEndpoint[] endpoints, ILogger logger)
{
Endpoints = endpoints;

_logger = logger ?? NullLogger.Instance;

var groupedEndpoints = endpoints.GroupBy(s => s.Endpoint).Select(s =>
if (Endpoints.Length != 0)
{
var items = s.ToList();
if (items.Count > 1)
var groupedEndpoints = Endpoints.GroupBy(s => s.Endpoint).Select(s =>
{
// By default pick up the primary endpoint, otherwise the first one
var item = items.FirstOrDefault(i => i.EndpointType == EndpointType.Primary) ?? items.FirstOrDefault();
Log.DuplicateEndpointFound(_logger, items.Count, item.Endpoint, item.ToString());
return item;
}
var items = s.ToList();
if (items.Count > 1)
{
// By default pick up the primary endpoint, otherwise the first one
var item = items.FirstOrDefault(i => i.EndpointType == EndpointType.Primary) ?? items.FirstOrDefault();
Log.DuplicateEndpointFound(_logger, items.Count, item.Endpoint, item.ToString());
return item;
}

return items[0];
});
return items[0];
});

_endpoints = groupedEndpoints.ToArray();
Endpoints = groupedEndpoints.ToArray();

_primaryEndpoints = _endpoints.Where(s => s.EndpointType == EndpointType.Primary).ToArray();
_primaryEndpoints = Endpoints.Where(s => s.EndpointType == EndpointType.Primary).ToArray();

if (_primaryEndpoints.Length == 0)
{
throw new AzureSignalRNoPrimaryEndpointException();
if (_primaryEndpoints.Length == 0)
{
throw new AzureSignalRNoPrimaryEndpointException();
}
}
}

public abstract IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint);

public IReadOnlyList<ServiceEndpoint> GetAvailableEndpoints()
{
return _endpoints;
return Endpoints;
}

/// <summary>
Expand All @@ -65,6 +72,33 @@ public IReadOnlyList<ServiceEndpoint> GetPrimaryEndpoints()
return _primaryEndpoints;
}

private static IEnumerable<ServiceEndpoint> GetEndpoints(IServiceEndpointOptions options)
{
if (options == null)
{
yield break;
}

var endpoints = options.Endpoints;
var connectionString = options.ConnectionString;

// ConnectionString can be set by custom Csonfigure
// Return both the one from ConnectionString and from Endpoints
// TODO: Better way if Endpoints already contains ConnectionString one?
if (!string.IsNullOrEmpty(connectionString))
{
yield return new ServiceEndpoint(options.ConnectionString);
}

if (endpoints != null)
{
foreach (var endpoint in endpoints)
{
yield return endpoint;
}
}
}

private static class Log
{
private static readonly Action<ILogger, int, string, string, Exception> _duplicateEndpointFound =
Expand Down
Expand Up @@ -2,9 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

Expand All @@ -15,27 +12,20 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase
private readonly TimeSpan? _ttl;

public ServiceEndpointManager(IOptions<ServiceOptions> options, ILoggerFactory loggerFactory) :
base(GetEndpoints(options?.Value).ToArray(),
loggerFactory?.CreateLogger<ServiceEndpointManager>())
base(options.Value,
loggerFactory.CreateLogger<ServiceEndpointManager>())
{
if (Endpoints.Length == 0)
{
throw new ArgumentException(ServiceEndpointProvider.ConnectionStringNotFound);
}

_ttl = options.Value?.AccessTokenLifetime;
}

public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint)
{
return new ServiceEndpointProvider(endpoint, _ttl);
}

private static IEnumerable<ServiceEndpoint> GetEndpoints(ServiceOptions options)
{
// TODO: support multiple endpoints
var connectionString = options?.ConnectionString;
if (string.IsNullOrEmpty(connectionString))
{
throw new ArgumentException(ServiceEndpointProvider.ConnectionStringNotFound);
}

yield return new ServiceEndpoint(connectionString);
}
}
}