Skip to content
This repository has been archived by the owner on Dec 18, 2018. It is now read-only.

fix issue with incorrect user detection when Invoking for User #747

Merged
merged 14 commits into from
Oct 6, 2017
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private long _nextInvocationId = 0;
private readonly HubConnectionList _connections = new HubConnectionList();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove empty space

public override Task AddGroupAsync(string connectionId, string groupName)
{
if (connectionId == null)
Expand Down Expand Up @@ -138,17 +138,14 @@ public override Task InvokeGroupAsync(string groupName, string methodName, objec

public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal);
});
return InvokeAllWhere(methodName, args, connection =>
string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal));
}

public override Task OnConnectedAsync(HubConnectionContext connection)
{
// Set the hub groups feature
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());

_connections.Add(connection);
return Task.CompletedTask;
}
Expand All @@ -166,7 +163,7 @@ private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMes
if (connection.Output.TryWrite(hubMessage))
{
break;
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove space

}
}

Expand Down Expand Up @@ -194,4 +191,4 @@ private class HubGroupsFeature : IHubGroupsFeature
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}
}
}
}
15 changes: 15 additions & 0 deletions src/Microsoft.AspNetCore.SignalR.Core/DefaultUserIdProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Security.Claims;

namespace Microsoft.AspNetCore.SignalR.Core
{
public class DefaultUserIdProvider : IUserIdProvider
{
public string GetUserId(HubConnectionContext connection)
{
return connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
}
}
}
18 changes: 17 additions & 1 deletion src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
Expand All @@ -25,12 +26,15 @@ public class HubConnectionContext
private readonly ConnectionContext _connectionContext;
private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource();
private readonly TaskCompletionSource<object> _abortCompletedTcs = new TaskCompletionSource<object>();
private readonly IUserIdProvider _userIdProvider;
private string _userIdCache = null;
Copy link
Member

@davidfowl davidfowl Oct 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nit: Just call this _userId.


public HubConnectionContext(WritableChannel<HubMessage> output, ConnectionContext connectionContext)
public HubConnectionContext(WritableChannel<HubMessage> output, ConnectionContext connectionContext, IUserIdProvider userIdProvider)
{
_output = output;
_connectionContext = connectionContext;
ConnectionAbortedToken = _connectionAbortedTokenSource.Token;
_userIdProvider = userIdProvider;
}

private IHubFeature HubFeature => Features.Get<IHubFeature>();
Expand Down Expand Up @@ -67,6 +71,18 @@ public virtual void Abort()
Task.Factory.StartNew(_abortedCallback, this);
}

public string UserIdentifier
Copy link
Member

@davidfowl davidfowl Oct 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not comfortable with the lazy evaluation of this. I think it should be set by the HubEndPoint after negotiate.

Copy link
Contributor Author

@ivankarpey ivankarpey Oct 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I don't really like the idea of having public setter for that sort of things. Since I think that for end-user connection context should be something "immutable". Don't you think so?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that too, make it an internal setter. We'll need to see what mocking this object looks like though.

{
get
{
if (String.IsNullOrEmpty(_userIdCache))
{
_userIdCache = _userIdProvider.GetUserId(this);
}
return _userIdCache;
}
}

internal void Abort(Exception exception)
{
AbortException = ExceptionDispatchInfo.Capture(exception);
Expand Down
8 changes: 6 additions & 2 deletions src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Core.Internal;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal;
Expand Down Expand Up @@ -39,20 +40,23 @@ public class HubEndPoint<THub> : IInvocationBinder where THub : Hub
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IHubProtocolResolver _protocolResolver;
private readonly IOptions<HubOptions> _hubOptions;
private readonly IUserIdProvider _userIdProvider;

public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub> hubContext,
IOptions<HubOptions> hubOptions,
ILogger<HubEndPoint<THub>> logger,
IServiceScopeFactory serviceScopeFactory)
IServiceScopeFactory serviceScopeFactory,
IUserIdProvider userIdProvider)
{
_protocolResolver = protocolResolver;
_lifetimeManager = lifetimeManager;
_hubContext = hubContext;
_hubOptions = hubOptions;
_logger = logger;
_serviceScopeFactory = serviceScopeFactory;
_userIdProvider = userIdProvider;

DiscoverHubMethods();
}
Expand All @@ -65,7 +69,7 @@ public async Task OnConnectedAsync(ConnectionContext connection)
// all the relevant state for a SignalR Hub connection.
connection.Features.Set<IHubFeature>(new HubFeature());

var connectionContext = new HubConnectionContext(output, connection);
var connectionContext = new HubConnectionContext(output, connection, _userIdProvider);

if (!await ProcessNegotiate(connectionContext))
{
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.AspNetCore.SignalR.Core/IUserIdProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

namespace Microsoft.AspNetCore.SignalR.Core
{
public interface IUserIdProvider
{
string GetUserId(HubConnectionContext connection);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;

namespace Microsoft.Extensions.DependencyInjection
Expand All @@ -15,6 +16,7 @@ public static ISignalRBuilder AddSignalRCore(this IServiceCollection services)
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
services.AddSingleton(typeof(IHubContext<,>), typeof(HubContext<,>));
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
services.AddSingleton(typeof(IUserIdProvider), typeof(DefaultUserIdProvider));
services.AddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>));

services.AddAuthorization();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ public override Task OnConnectedAsync(HubConnectionContext connection)
previousConnectionTask = WriteAsync(connection, message);
});

if (connection.User.Identity.IsAuthenticated)
if (!String.IsNullOrEmpty(connection.UserIdentifier))
Copy link
Member

@davidfowl davidfowl Oct 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: lowercase s in string

{
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
redisSubscriptions.Add(userChannel);

var previousUserTask = Task.CompletedTask;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Xunit;

Expand All @@ -17,8 +18,8 @@ public async Task InvokeAllAsyncWritesToAllConnectionsOutput()
var output2 = Channel.CreateUnbounded<HubMessage>();

var manager = new DefaultHubLifetimeManager<MyHub>();
var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
var connection1 = new HubConnectionContext(output1, client1.Connection, new DefaultUserIdProvider());
var connection2 = new HubConnectionContext(output2, client2.Connection, new DefaultUserIdProvider());

await manager.OnConnectedAsync(connection1);
await manager.OnConnectedAsync(connection2);
Expand Down Expand Up @@ -51,8 +52,8 @@ public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput()
var output2 = Channel.CreateUnbounded<HubMessage>();

var manager = new DefaultHubLifetimeManager<MyHub>();
var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
var connection1 = new HubConnectionContext(output1, client1.Connection, new DefaultUserIdProvider());
var connection2 = new HubConnectionContext(output2, client2.Connection, new DefaultUserIdProvider());

await manager.OnConnectedAsync(connection1);
await manager.OnConnectedAsync(connection2);
Expand Down Expand Up @@ -82,8 +83,8 @@ public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput()
var output2 = Channel.CreateUnbounded<HubMessage>();

var manager = new DefaultHubLifetimeManager<MyHub>();
var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
var connection1 = new HubConnectionContext(output1, client1.Connection, new DefaultUserIdProvider());
var connection2 = new HubConnectionContext(output2, client2.Connection, new DefaultUserIdProvider());

await manager.OnConnectedAsync(connection1);
await manager.OnConnectedAsync(connection2);
Expand All @@ -110,7 +111,7 @@ public async Task InvokeConnectionAsyncWritesToConnectionOutput()
{
var output = Channel.CreateUnbounded<HubMessage>();
var manager = new DefaultHubLifetimeManager<MyHub>();
var connection = new HubConnectionContext(output, client.Connection);
var connection = new HubConnectionContext(output, client.Connection, new DefaultUserIdProvider());

await manager.OnConnectedAsync(connection);

Expand Down
6 changes: 3 additions & 3 deletions test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -850,15 +850,15 @@ public async Task HubsCanSendToUser(Type hubType)

dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType));

using (var firstClient = new TestClient())
using (var secondClient = new TestClient())
using (var firstClient = new TestClient(addClaimId: true))
using (var secondClient = new TestClient(addClaimId: true))
{
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);

await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();

await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.Identity.Name, "test").OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();

// check that 'secondConnection' has received the group send
var hubMessage = await secondClient.ReadAsync().OrTimeout();
Expand Down
14 changes: 11 additions & 3 deletions test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TestClient : IDisposable, IInvocationBinder
public Channel<byte[]> Application { get; }
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;

public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null)
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false)
{
var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks };
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
Expand All @@ -38,7 +38,15 @@ public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = nul
_transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);

Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));

var claimValue = Interlocked.Increment(ref _id).ToString();
var claims = new List<Claim>{ new Claim(ClaimTypes.Name, claimValue) };
if (addClaimId)
{
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
}

Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();

protocol = protocol ?? new JsonHubProtocol();
Expand Down Expand Up @@ -182,4 +190,4 @@ Type IInvocationBinder.GetReturnType(string invocationId)
return typeof(object);
}
}
}
}