Skip to content

Commit

Permalink
remove NetworkConnection.isAuthenticated (#167)
Browse files Browse the repository at this point in the history
* remove NetworkConnection.isAuthenticated

this field is not needed anymore,  we simply don't register
handlers until the connection is authenticated.
Eliminates extra state tracking from NetworkConnections

BREAKING CHANGE: Remove isAuthenticated

* Fix typo

* Fix smells

* Remove smells
  • Loading branch information
paulpach committed Apr 10, 2020
1 parent eaaf59f commit 8a0e0b3
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 57 deletions.
10 changes: 2 additions & 8 deletions Assets/Mirror/Authenticators/BasicAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ public class AuthResponseMessage : MessageBase
public override void OnServerAuthenticate(NetworkConnection conn)
{
// wait for AuthRequestMessage from client
conn.RegisterHandler<AuthRequestMessage>(OnAuthRequestMessage, false);
conn.RegisterHandler<AuthRequestMessage>(OnAuthRequestMessage);
}

public override void OnClientAuthenticate(NetworkConnection conn)
{
conn.RegisterHandler<AuthResponseMessage>(OnAuthResponseMessage, false);
conn.RegisterHandler<AuthResponseMessage>(OnAuthResponseMessage);

var authRequestMessage = new AuthRequestMessage
{
Expand Down Expand Up @@ -76,9 +76,6 @@ public void OnAuthRequestMessage(NetworkConnection conn, AuthRequestMessage msg)

conn.Send(authResponseMessage);

// must set NetworkConnection isAuthenticated = false
conn.isAuthenticated = false;

// disconnect the client after 1 second so that response message gets delivered
StartCoroutine(DelayedDisconnect(conn, 1));
}
Expand All @@ -103,9 +100,6 @@ public void OnAuthResponseMessage(NetworkConnection conn, AuthResponseMessage ms
{
Debug.LogErrorFormat("Authentication Response: {0}", msg.Message);

// Set this on the client for local reference
conn.isAuthenticated = false;

// disconnect the client
conn.Disconnect();
}
Expand Down
25 changes: 22 additions & 3 deletions Assets/Mirror/Authenticators/TimeoutAuthenticator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Mirror.Authenticators
Expand All @@ -17,19 +18,36 @@ public class TimeoutAuthenticator : NetworkAuthenticator

public void Awake()
{
Authenticator.OnClientAuthenticated += base.OnClientAuthenticate;
Authenticator.OnServerAuthenticated += base.OnServerAuthenticate;
Authenticator.OnClientAuthenticated += HandleClientAuthenticated;
Authenticator.OnServerAuthenticated += HandleServerAuthenticated;
}

private readonly HashSet<NetworkConnection> pendingAuthentication = new HashSet<NetworkConnection>();

private void HandleServerAuthenticated(NetworkConnection connection)
{
pendingAuthentication.Remove(connection);
base.OnClientAuthenticate(connection);
}

private void HandleClientAuthenticated(NetworkConnection connection)
{
pendingAuthentication.Remove(connection);
base.OnServerAuthenticate(connection);
}

public override void OnClientAuthenticate(NetworkConnection conn)
{
pendingAuthentication.Add(conn);
Authenticator.OnClientAuthenticate(conn);

if (Timeout > 0)
StartCoroutine(BeginAuthentication(conn));
}

public override void OnServerAuthenticate(NetworkConnection conn)
{
pendingAuthentication.Add(conn);
Authenticator.OnServerAuthenticate(conn);
if (Timeout > 0)
StartCoroutine(BeginAuthentication(conn));
Expand All @@ -41,10 +59,11 @@ IEnumerator BeginAuthentication(NetworkConnection conn)

yield return new WaitForSecondsRealtime(Timeout);

if (!conn.isAuthenticated)
if (pendingAuthentication.Contains(conn))
{
if (LogFilter.Debug) Debug.Log($"Authentication Timeout {conn}");

pendingAuthentication.Remove(conn);
conn.Disconnect();
}
}
Expand Down
4 changes: 2 additions & 2 deletions Assets/Mirror/Runtime/INetworkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ public interface INetworkConnection
{
void Disconnect();

void RegisterHandler<T>(Action<NetworkConnection, T> handler, bool requireAuthentication = true)
void RegisterHandler<T>(Action<NetworkConnection, T> handler)
where T : IMessageBase, new();

void RegisterHandler<T>(Action<T> handler, bool requireAuthentication = true) where T : IMessageBase, new();
void RegisterHandler<T>(Action<T> handler) where T : IMessageBase, new();

void UnregisterHandler<T>() where T : IMessageBase;

Expand Down
7 changes: 2 additions & 5 deletions Assets/Mirror/Runtime/NetworkClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,6 @@ async Task OnConnected()

public void OnAuthenticated(NetworkConnection conn)
{
// set connection to authenticated
conn.isAuthenticated = true;

Authenticated?.Invoke(conn);
}

Expand Down Expand Up @@ -279,7 +276,7 @@ internal void RegisterHostHandlers(NetworkConnection connection)
{
connection.RegisterHandler<ObjectDestroyMessage>(OnHostClientObjectDestroy);
connection.RegisterHandler<ObjectHideMessage>(OnHostClientObjectHide);
connection.RegisterHandler<NetworkPongMessage>(msg => { }, false);
connection.RegisterHandler<NetworkPongMessage>(msg => { });
connection.RegisterHandler<SpawnMessage>(OnHostClientSpawn);
// host mode reuses objects in the server
// so we don't need to spawn them
Expand All @@ -294,7 +291,7 @@ internal void RegisterMessageHandlers(NetworkConnection connection)
{
connection.RegisterHandler<ObjectDestroyMessage>(OnObjectDestroy);
connection.RegisterHandler<ObjectHideMessage>(OnObjectHide);
connection.RegisterHandler<NetworkPongMessage>(Time.OnClientPong, false);
connection.RegisterHandler<NetworkPongMessage>(Time.OnClientPong);
connection.RegisterHandler<SpawnMessage>(OnSpawn);
connection.RegisterHandler<ObjectSpawnStartedMessage>(OnObjectSpawnStarted);
connection.RegisterHandler<ObjectSpawnFinishedMessage>(OnObjectSpawnFinished);
Expand Down
28 changes: 7 additions & 21 deletions Assets/Mirror/Runtime/NetworkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ public class NetworkConnection : INetworkConnection
/// </remarks>
private readonly IConnection connection;

/// <summary>
/// Flag that indicates the client has been authenticated.
/// </summary>
public bool isAuthenticated;

/// <summary>
/// General purpose object to hold authentication data, character selection, tokens, etc.
/// associated with the connection for reference after Authentication completes.
Expand Down Expand Up @@ -110,7 +105,7 @@ public virtual void Disconnect()
connection.Disconnect();
}

private static NetworkMessageDelegate MessageHandler<T, C>(Action<C, T> handler, bool requireAuthenication)
private static NetworkMessageDelegate MessageHandler<T, C>(Action<C, T> handler)
where T : IMessageBase, new()
where C : NetworkConnection
{
Expand All @@ -128,18 +123,9 @@ void AdapterFunction(NetworkConnection conn, NetworkReader reader, int channelId
//
// let's catch them all and then disconnect that connection to avoid
// further attacks.
T message = default;
T message = default(T) != null ? default(T) : new T();
try
{
if (requireAuthenication && !conn.isAuthenticated)
{
// message requires authentication, but the connection was not authenticated
Debug.LogWarning($"Closing connection: {conn}. Received message {typeof(T)} that required authentication, but the user has not authenticated yet");
conn.Disconnect();
return;
}

message = default(T) != null ? default(T) : new T();
{
message.Deserialize(reader);
}
finally
Expand All @@ -159,15 +145,15 @@ void AdapterFunction(NetworkConnection conn, NetworkReader reader, int channelId
/// <typeparam name="T">Message type</typeparam>
/// <param name="handler">Function handler which will be invoked for when this message type is received.</param>
/// <param name="requireAuthentication">True if the message requires an authenticated connection</param>
public void RegisterHandler<T>(Action<NetworkConnection, T> handler, bool requireAuthentication = true)
public void RegisterHandler<T>(Action<NetworkConnection, T> handler)
where T : IMessageBase, new()
{
int msgType = MessagePacker.GetId<T>();
if (LogFilter.Debug && messageHandlers.ContainsKey(msgType))
{
Debug.Log("NetworkServer.RegisterHandler replacing " + msgType);
}
messageHandlers[msgType] = MessageHandler(handler, requireAuthentication);
messageHandlers[msgType] = MessageHandler(handler);
}

/// <summary>
Expand All @@ -177,9 +163,9 @@ public void RegisterHandler<T>(Action<NetworkConnection, T> handler, bool requir
/// <typeparam name="T">Message type</typeparam>
/// <param name="handler">Function handler which will be invoked for when this message type is received.</param>
/// <param name="requireAuthentication">True if the message requires an authenticated connection</param>
public void RegisterHandler<T>(Action<T> handler, bool requireAuthentication = true) where T : IMessageBase, new()
public void RegisterHandler<T>(Action<T> handler) where T : IMessageBase, new()
{
RegisterHandler<T>((_, value) => { handler(value); }, requireAuthentication);
RegisterHandler<T>((_, value) => { handler(value); });
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion Assets/Mirror/Runtime/NetworkManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ void OnServerRemovePlayerMessageInternal(NetworkConnection conn, RemovePlayerMes
void RegisterClientMessages(NetworkConnection connection)
{
connection.RegisterHandler<NotReadyMessage>(OnClientNotReadyMessageInternal);
connection.RegisterHandler<SceneMessage>(OnClientSceneInternal, false);
connection.RegisterHandler<SceneMessage>(OnClientSceneInternal);
}

// called after successful authentication
Expand Down
9 changes: 4 additions & 5 deletions Assets/Mirror/Runtime/NetworkServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ internal void RegisterMessageHandlers(NetworkConnection connection)
connection.RegisterHandler<ReadyMessage>(OnClientReadyMessage);
connection.RegisterHandler<CommandMessage>(OnCommandMessage);
connection.RegisterHandler<RemovePlayerMessage>(OnRemovePlayerMessage);
connection.RegisterHandler<NetworkPingMessage>(Time.OnServerPing, false);
}

/// <summary>
Expand Down Expand Up @@ -208,7 +207,7 @@ private void Cleanup()
{

if (authenticator != null)
{
{
authenticator.OnServerAuthenticated -= OnAuthenticated;
Connected.RemoveListener(authenticator.OnServerAuthenticateInternal);
}
Expand Down Expand Up @@ -236,7 +235,7 @@ public void AddConnection(NetworkConnection conn)
// connection cannot be null here or conn.connectionId
// would throw NRE
connections.Add(conn);
RegisterMessageHandlers(conn);
conn.RegisterHandler<NetworkPingMessage>(Time.OnServerPing);
}
}

Expand Down Expand Up @@ -426,8 +425,8 @@ internal void OnAuthenticated(NetworkConnection conn)
{
if (LogFilter.Debug) Debug.Log("Server authenticate client:" + conn);

// set connection to authenticated
conn.isAuthenticated = true;
// connection has been authenticated, now we can handle other messages
RegisterMessageHandlers(conn);

Authenticated?.Invoke(conn);
}
Expand Down
5 changes: 1 addition & 4 deletions Assets/Mirror/Tests/Common/LocalConnections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ namespace Mirror.Tests

public static class LocalConnections
{
public static (NetworkConnection, NetworkConnection) PipedConnections(bool authenticated = false)
public static (NetworkConnection, NetworkConnection) PipedConnections()
{
(IConnection c1, IConnection c2) = PipeConnection.CreatePipe();
var toServer = new NetworkConnection(c2);
var toClient = new NetworkConnection(c1);

toServer.isAuthenticated = authenticated;
toClient.isAuthenticated = authenticated;

return (toServer, toClient);
}

Expand Down
6 changes: 2 additions & 4 deletions Assets/Mirror/Tests/Runtime/NetworkServerTest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Collections;
using System.Linq;
using System.Threading.Tasks;
using NSubstitute;
using NUnit.Framework;
Expand Down Expand Up @@ -172,8 +171,7 @@ public void ConnectedEventTest()
public void ConnectionTest()
{
transport.AcceptCompletionSource.SetResult(tconn42);
NetworkConnection conn = server.connections.First();
Assert.That(conn.isAuthenticated);
Assert.That(server.connections, Has.Count.EqualTo(1));
}

[Test]
Expand Down Expand Up @@ -373,7 +371,7 @@ public void HideForConnection()
{
// add connection

NetworkConnection connectionToClient = Substitute.For<NetworkConnection>((IConnection) null);
NetworkConnection connectionToClient = Substitute.For<NetworkConnection>((IConnection)null);

NetworkIdentity identity = new GameObject().AddComponent<NetworkIdentity>();

Expand Down
5 changes: 1 addition & 4 deletions Assets/Mirror/Tests/Runtime/NetworkServerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ public class NetworkServerTests
connectionToClient = server.connections.First();
connectionToServer = new NetworkConnection(tconn);
connectionToClient.isAuthenticated = true;
connectionToServer.isAuthenticated = true;
message = new WovenTestMessage
{
IntValue = 1,
Expand Down Expand Up @@ -182,7 +179,7 @@ public IEnumerator RegisterMessage2()

Action<NetworkConnection, WovenTestMessage> func = Substitute.For<Action<NetworkConnection, WovenTestMessage>>();

connectionToClient.RegisterHandler<WovenTestMessage> (func);
connectionToClient.RegisterHandler<WovenTestMessage>(func);

connectionToServer.Send(message);

Expand Down

0 comments on commit 8a0e0b3

Please sign in to comment.