From dcf45bd6ab69521536213e5577f762911f3328cd Mon Sep 17 00:00:00 2001 From: Andrew Allbright Date: Tue, 9 Jan 2024 09:04:57 -0500 Subject: [PATCH] Finish refactor --- .../Transports/SSL/CertificateSettings.cs | 8 ++- Assets/Mirror/Transports/SSL/SSLSettings.cs | 26 ---------- .../Transports/SSL/TransportSecurity.cs | 52 ++++++++++++++++++- .../SimpleWeb/Server/ICreateStream.cs | 2 + .../SimpleWeb/Server/ServerSslHelper.cs | 2 +- .../SimpleWeb/Server/SimpleWebServer.cs | 4 +- .../SimpleWeb/SimpleWebTransport.cs | 10 ++-- 7 files changed, 66 insertions(+), 38 deletions(-) diff --git a/Assets/Mirror/Transports/SSL/CertificateSettings.cs b/Assets/Mirror/Transports/SSL/CertificateSettings.cs index 6d4e158d6e1..3bab7195193 100644 --- a/Assets/Mirror/Transports/SSL/CertificateSettings.cs +++ b/Assets/Mirror/Transports/SSL/CertificateSettings.cs @@ -3,7 +3,6 @@ using System.Security.Cryptography.X509Certificates; using UnityEditor; using UnityEngine; -using UnityEngine.Serialization; namespace Mirror { @@ -32,6 +31,11 @@ public X509Certificate2 Certificate } } + public string CertPassword() + { + return File.ReadAllText(PasswordFilePath); + } + private X509Certificate2 NewPasswordProtectedCertificate() { if (!ValidateCertificatePath(CertificatePath)) @@ -44,7 +48,7 @@ private X509Certificate2 NewPasswordProtectedCertificate() Debug.LogError("Password file path is invalid (" + PasswordFilePath + "). Unable to create certificate."); return null; } - string password = File.ReadAllText(PasswordFilePath); + string password = CertPassword(); return new X509Certificate2(CertificatePath, password); } diff --git a/Assets/Mirror/Transports/SSL/SSLSettings.cs b/Assets/Mirror/Transports/SSL/SSLSettings.cs index af809528821..228b18cfb90 100644 --- a/Assets/Mirror/Transports/SSL/SSLSettings.cs +++ b/Assets/Mirror/Transports/SSL/SSLSettings.cs @@ -1,10 +1,5 @@ using System; -using System.IO; -using System.Net.Security; -using System.Net.Sockets; using System.Security.Authentication; -using System.Security.Cryptography.X509Certificates; -using UnityEditor; using UnityEngine; namespace Mirror @@ -17,26 +12,5 @@ public class SSLSettings [Tooltip("Protocol to use for ssl (default: TLS 1.2)")] public SslProtocols SSLProtocol = SslProtocols.Tls12; - - public Stream CreateStream(NetworkStream stream, X509Certificate2 certificate) - { - if (!SSLEnabled) - { - Debug.LogError("SSL is not enabled. Unable to create stream."); - return null; - } - - SslStream sslStream = new(stream, true, AcceptClient); - sslStream.AuthenticateAsServer(certificate, false, SSLProtocol, false); - - return sslStream; - } - - // Always accept client - private bool AcceptClient(object sender, X509Certificate certificate, X509Chain chain, - SslPolicyErrors sslPolicyErrors) - { - return true; - } } } diff --git a/Assets/Mirror/Transports/SSL/TransportSecurity.cs b/Assets/Mirror/Transports/SSL/TransportSecurity.cs index 31d33307882..f7eb2a4567b 100644 --- a/Assets/Mirror/Transports/SSL/TransportSecurity.cs +++ b/Assets/Mirror/Transports/SSL/TransportSecurity.cs @@ -1,10 +1,60 @@ +using System; +using System.IO; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using Mirror.SimpleWeb; using UnityEngine; namespace Mirror { - public class TransportSecurity : MonoBehaviour + public class TransportSecurity : MonoBehaviour, ICreateStream { public SSLSettings sslSettings; public CertificateSettings certificateSettings; + public SSLSettings GetSslSettings() + { + return sslSettings; + } + public bool TryCreateStream(IConnection conn) + { + NetworkStream stream = conn.Client.GetStream(); + if (sslSettings.SSLEnabled) + { + try + { + conn.Stream = CreateStream(stream); + return true; + } + catch (Exception e) + { + Debug.LogError($"[SWT-ServerSslHelper]: Create SSLStream Failed: {e.Message}"); + return false; + } + } + + conn.Stream = stream; + return true; + } + public Stream CreateStream(NetworkStream stream) + { + if (!sslSettings.SSLEnabled) + { + Debug.LogError("SSL is not enabled. Unable to create stream."); + return null; + } + + SslStream sslStream = new SslStream(stream, true, AcceptClient); + sslStream.AuthenticateAsServer(certificateSettings.Certificate, false, sslSettings.SSLProtocol, false); + + return sslStream; + } + + // Always accept client + private bool AcceptClient(object sender, X509Certificate certificate, X509Chain chain, + SslPolicyErrors sslPolicyErrors) + { + return true; + } } } diff --git a/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ICreateStream.cs b/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ICreateStream.cs index e88d1ea3ec2..5bfca603b1a 100644 --- a/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ICreateStream.cs +++ b/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ICreateStream.cs @@ -1,3 +1,5 @@ +using System.Security.Cryptography.X509Certificates; + namespace Mirror.SimpleWeb { public interface ICreateStream diff --git a/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ServerSslHelper.cs b/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ServerSslHelper.cs index 5652fd7083c..5f50b393354 100644 --- a/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ServerSslHelper.cs +++ b/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/ServerSslHelper.cs @@ -22,7 +22,7 @@ public SslConfig(bool enabled, string certPath, string certPassword, SslProtocol this.sslProtocols = sslProtocols; } } - internal class ServerSslHelper: ICreateStream + public class ServerSslHelper: ICreateStream { readonly SslConfig config; readonly X509Certificate2 certificate; diff --git a/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/SimpleWebServer.cs b/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/SimpleWebServer.cs index 008780ee679..aae0c43f680 100644 --- a/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/SimpleWebServer.cs +++ b/Assets/Mirror/Transports/SimpleWeb/SimpleWeb/Server/SimpleWebServer.cs @@ -17,13 +17,13 @@ public class SimpleWebServer public bool Active { get; private set; } - public SimpleWebServer(int maxMessagesPerTick, TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig) + public SimpleWebServer(int maxMessagesPerTick, TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, ICreateStream streamCreator) { this.maxMessagesPerTick = maxMessagesPerTick; // use max because bufferpool is used for both messages and handshake int max = Math.Max(maxMessageSize, handshakeMaxSize); bufferPool = new BufferPool(5, 20, max); - server = new WebSocketServer(tcpConfig, maxMessageSize, handshakeMaxSize, bufferPool, new ServerSslHelper(sslConfig)); + server = new WebSocketServer(tcpConfig, maxMessageSize, handshakeMaxSize, bufferPool, streamCreator); } public void Start(ushort port) diff --git a/Assets/Mirror/Transports/SimpleWeb/SimpleWebTransport.cs b/Assets/Mirror/Transports/SimpleWeb/SimpleWebTransport.cs index 63954875a6e..7d9b5491f56 100644 --- a/Assets/Mirror/Transports/SimpleWeb/SimpleWebTransport.cs +++ b/Assets/Mirror/Transports/SimpleWeb/SimpleWebTransport.cs @@ -269,16 +269,14 @@ public override void ServerStart() if (ServerActive()) Log.Warn("[SWT-ServerStart]: Server Already Started"); - SslConfig config; if (transportSecurity && transportSecurity.enabled) - { - config = transportSecurity.GetSslSettings(); - } + server = new SimpleWebServer(serverMaxMsgsPerTick, TcpConfig, maxMessageSize, maxHandshakeSize, transportSecurity); else { - config = new SslConfig(false, "", "", System.Security.Authentication.SslProtocols.None); + SslConfig sslConfig = new SslConfig(false, "", "", System.Security.Authentication.SslProtocols.None); + ServerSslHelper serverSslHelper = new ServerSslHelper(sslConfig); + server = new SimpleWebServer(serverMaxMsgsPerTick, TcpConfig, maxMessageSize, maxHandshakeSize, serverSslHelper); } - server = new SimpleWebServer(serverMaxMsgsPerTick, TcpConfig, maxMessageSize, maxHandshakeSize, config); server.onConnect += OnServerConnected.Invoke; server.onDisconnect += OnServerDisconnected.Invoke;