Skip to content
This repository has been archived by the owner on Dec 18, 2018. It is now read-only.

fix issue with incorrect user detection when Invoking for User #747

Merged
merged 14 commits into from
Oct 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,14 @@ public override Task InvokeGroupAsync(string groupName, string methodName, objec

public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal);
});
return InvokeAllWhere(methodName, args, connection =>
string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal));
}

public override Task OnConnectedAsync(HubConnectionContext connection)
{
// Set the hub groups feature
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());

_connections.Add(connection);
return Task.CompletedTask;
}
Expand Down
15 changes: 15 additions & 0 deletions src/Microsoft.AspNetCore.SignalR.Core/DefaultUserIdProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Security.Claims;

namespace Microsoft.AspNetCore.SignalR.Core
{
public class DefaultUserIdProvider : IUserIdProvider
{
public string GetUserId(HubConnectionContext connection)
{
return connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
}
}
}
2 changes: 2 additions & 0 deletions src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public virtual void Abort()
Task.Factory.StartNew(_abortedCallback, this);
}

public string UserIdentifier { get; internal set; }

internal void Abort(Exception exception)
{
AbortException = ExceptionDispatchInfo.Capture(exception);
Expand Down
8 changes: 7 additions & 1 deletion src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Core.Internal;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal;
Expand Down Expand Up @@ -39,20 +40,23 @@ public class HubEndPoint<THub> : IInvocationBinder where THub : Hub
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IHubProtocolResolver _protocolResolver;
private readonly IOptions<HubOptions> _hubOptions;
private readonly IUserIdProvider _userIdProvider;

public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub> hubContext,
IOptions<HubOptions> hubOptions,
ILogger<HubEndPoint<THub>> logger,
IServiceScopeFactory serviceScopeFactory)
IServiceScopeFactory serviceScopeFactory,
IUserIdProvider userIdProvider)
{
_protocolResolver = protocolResolver;
_lifetimeManager = lifetimeManager;
_hubContext = hubContext;
_hubOptions = hubOptions;
_logger = logger;
_serviceScopeFactory = serviceScopeFactory;
_userIdProvider = userIdProvider;

DiscoverHubMethods();
}
Expand All @@ -72,6 +76,8 @@ public async Task OnConnectedAsync(ConnectionContext connection)
return;
}

connectionContext.UserIdentifier = _userIdProvider.GetUserId(connectionContext);

// Hubs support multiple producers so we set up this loop to copy
// data written to the HubConnectionContext's channel to the transport channel
var protocolReaderWriter = connectionContext.ProtocolReaderWriter;
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.AspNetCore.SignalR.Core/IUserIdProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

namespace Microsoft.AspNetCore.SignalR.Core
{
public interface IUserIdProvider
{
string GetUserId(HubConnectionContext connection);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;

namespace Microsoft.Extensions.DependencyInjection
Expand All @@ -15,6 +16,7 @@ public static ISignalRBuilder AddSignalRCore(this IServiceCollection services)
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
services.AddSingleton(typeof(IHubContext<,>), typeof(HubContext<,>));
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
services.AddSingleton(typeof(IUserIdProvider), typeof(DefaultUserIdProvider));
services.AddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>));

services.AddAuthorization();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ public override Task OnConnectedAsync(HubConnectionContext connection)
previousConnectionTask = WriteAsync(connection, message);
});

if (connection.User.Identity.IsAuthenticated)
if (!string.IsNullOrEmpty(connection.UserIdentifier))
{
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
redisSubscriptions.Add(userChannel);

var previousUserTask = Task.CompletedTask;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Xunit;

Expand Down
6 changes: 3 additions & 3 deletions test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -850,15 +850,15 @@ public async Task HubsCanSendToUser(Type hubType)

dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType));

using (var firstClient = new TestClient())
using (var secondClient = new TestClient())
using (var firstClient = new TestClient(addClaimId: true))
using (var secondClient = new TestClient(addClaimId: true))
{
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);

await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();

await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.Identity.Name, "test").OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();

// check that 'secondConnection' has received the group send
var hubMessage = await secondClient.ReadAsync().OrTimeout();
Expand Down
14 changes: 11 additions & 3 deletions test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TestClient : IDisposable, IInvocationBinder
public Channel<byte[]> Application { get; }
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;

public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null)
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false)
{
var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks };
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
Expand All @@ -38,7 +38,15 @@ public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = nul
_transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);

Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));

var claimValue = Interlocked.Increment(ref _id).ToString();
var claims = new List<Claim>{ new Claim(ClaimTypes.Name, claimValue) };
if (addClaimId)
{
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
}

Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();

protocol = protocol ?? new JsonHubProtocol();
Expand Down Expand Up @@ -182,4 +190,4 @@ Type IInvocationBinder.GetReturnType(string invocationId)
return typeof(object);
}
}
}
}