diff --git a/MCPForUnity/Editor/Constants/EditorPrefKeys.cs b/MCPForUnity/Editor/Constants/EditorPrefKeys.cs index ffaa31c0..30dcd2bb 100644 --- a/MCPForUnity/Editor/Constants/EditorPrefKeys.cs +++ b/MCPForUnity/Editor/Constants/EditorPrefKeys.cs @@ -11,6 +11,7 @@ internal static class EditorPrefKeys internal const string ValidationLevel = "MCPForUnity.ValidationLevel"; internal const string UnitySocketPort = "MCPForUnity.UnitySocketPort"; internal const string ResumeHttpAfterReload = "MCPForUnity.ResumeHttpAfterReload"; + internal const string ResumeStdioAfterReload = "MCPForUnity.ResumeStdioAfterReload"; internal const string UvxPathOverride = "MCPForUnity.UvxPath"; internal const string ClaudeCliPathOverride = "MCPForUnity.ClaudeCliPath"; diff --git a/MCPForUnity/Editor/Services/BridgeControlService.cs b/MCPForUnity/Editor/Services/BridgeControlService.cs index c67efd1b..4c18da04 100644 --- a/MCPForUnity/Editor/Services/BridgeControlService.cs +++ b/MCPForUnity/Editor/Services/BridgeControlService.cs @@ -49,13 +49,21 @@ private static BridgeVerificationResult BuildVerificationResult(TransportState s }; } - public bool IsRunning => _transportManager.GetState().IsConnected; + public bool IsRunning + { + get + { + var mode = ResolvePreferredMode(); + return _transportManager.IsRunning(mode); + } + } public int CurrentPort { get { - var state = _transportManager.GetState(); + var mode = ResolvePreferredMode(); + var state = _transportManager.GetState(mode); if (state.Port.HasValue) { return state.Port.Value; @@ -67,7 +75,7 @@ public int CurrentPort } public bool IsAutoConnectMode => StdioBridgeHost.IsAutoConnectMode(); - public TransportMode? ActiveMode => _transportManager.ActiveMode; + public TransportMode? ActiveMode => ResolvePreferredMode(); public async Task StartAsync() { @@ -92,7 +100,8 @@ public async Task StopAsync() { try { - await _transportManager.StopAsync(); + var mode = ResolvePreferredMode(); + await _transportManager.StopAsync(mode); } catch (Exception ex) { @@ -102,17 +111,17 @@ public async Task StopAsync() public async Task VerifyAsync() { - var mode = _transportManager.ActiveMode ?? ResolvePreferredMode(); - bool pingSucceeded = await _transportManager.VerifyAsync(); - var state = _transportManager.GetState(); + var mode = ResolvePreferredMode(); + bool pingSucceeded = await _transportManager.VerifyAsync(mode); + var state = _transportManager.GetState(mode); return BuildVerificationResult(state, mode, pingSucceeded); } public BridgeVerificationResult Verify(int port) { - var mode = _transportManager.ActiveMode ?? ResolvePreferredMode(); - bool pingSucceeded = _transportManager.VerifyAsync().GetAwaiter().GetResult(); - var state = _transportManager.GetState(); + var mode = ResolvePreferredMode(); + bool pingSucceeded = _transportManager.VerifyAsync(mode).GetAwaiter().GetResult(); + var state = _transportManager.GetState(mode); if (mode == TransportMode.Stdio) { diff --git a/MCPForUnity/Editor/Services/HttpBridgeReloadHandler.cs b/MCPForUnity/Editor/Services/HttpBridgeReloadHandler.cs index 16b8bd87..0422a92e 100644 --- a/MCPForUnity/Editor/Services/HttpBridgeReloadHandler.cs +++ b/MCPForUnity/Editor/Services/HttpBridgeReloadHandler.cs @@ -24,8 +24,8 @@ private static void OnBeforeAssemblyReload() { try { - var bridge = MCPServiceLocator.Bridge; - bool shouldResume = bridge.IsRunning && bridge.ActiveMode == TransportMode.Http; + var transport = MCPServiceLocator.TransportManager; + bool shouldResume = transport.IsRunning(TransportMode.Http); if (shouldResume) { @@ -36,9 +36,9 @@ private static void OnBeforeAssemblyReload() EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload); } - if (bridge.IsRunning) + if (shouldResume) { - var stopTask = bridge.StopAsync(); + var stopTask = transport.StopAsync(TransportMode.Http); stopTask.ContinueWith(t => { if (t.IsFaulted && t.Exception != null) @@ -59,7 +59,9 @@ private static void OnAfterAssemblyReload() bool resume = false; try { - resume = EditorPrefs.GetBool(EditorPrefKeys.ResumeHttpAfterReload, false); + // Only resume HTTP if it is still the selected transport. + bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true); + resume = useHttp && EditorPrefs.GetBool(EditorPrefKeys.ResumeHttpAfterReload, false); if (resume) { EditorPrefs.DeleteKey(EditorPrefKeys.ResumeHttpAfterReload); @@ -90,7 +92,7 @@ private static void OnAfterAssemblyReload() { try { - var startTask = MCPServiceLocator.Bridge.StartAsync(); + var startTask = MCPServiceLocator.TransportManager.StartAsync(TransportMode.Http); startTask.ContinueWith(t => { if (t.IsFaulted) @@ -123,7 +125,7 @@ private static void OnAfterAssemblyReload() { try { - bool started = await MCPServiceLocator.Bridge.StartAsync(); + bool started = await MCPServiceLocator.TransportManager.StartAsync(TransportMode.Http); if (!started) { McpLog.Warn("Failed to resume HTTP MCP bridge after domain reload"); diff --git a/MCPForUnity/Editor/Services/StdioBridgeReloadHandler.cs b/MCPForUnity/Editor/Services/StdioBridgeReloadHandler.cs new file mode 100644 index 00000000..32020351 --- /dev/null +++ b/MCPForUnity/Editor/Services/StdioBridgeReloadHandler.cs @@ -0,0 +1,104 @@ +using System; +using UnityEditor; +using MCPForUnity.Editor.Constants; +using MCPForUnity.Editor.Helpers; +using MCPForUnity.Editor.Services.Transport; +using MCPForUnity.Editor.Services.Transport.Transports; + +namespace MCPForUnity.Editor.Services +{ + /// + /// Ensures the legacy stdio bridge resumes after domain reloads, mirroring the HTTP handler. + /// + [InitializeOnLoad] + internal static class StdioBridgeReloadHandler + { + static StdioBridgeReloadHandler() + { + AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload; + AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload; + } + + private static void OnBeforeAssemblyReload() + { + try + { + // Only persist resume intent when stdio is the active transport and the bridge is running. + bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true); + bool isRunning = MCPServiceLocator.TransportManager.IsRunning(TransportMode.Stdio); + bool shouldResume = !useHttp && isRunning; + + if (shouldResume) + { + EditorPrefs.SetBool(EditorPrefKeys.ResumeStdioAfterReload, true); + + // Stop only the stdio bridge; leave HTTP untouched if it is running concurrently. + var stopTask = MCPServiceLocator.TransportManager.StopAsync(TransportMode.Stdio); + stopTask.ContinueWith(t => + { + if (t.IsFaulted && t.Exception != null) + { + McpLog.Warn($"Error stopping stdio bridge before reload: {t.Exception.GetBaseException()?.Message}"); + } + }, System.Threading.Tasks.TaskScheduler.Default); + } + else + { + EditorPrefs.DeleteKey(EditorPrefKeys.ResumeStdioAfterReload); + } + } + catch (Exception ex) + { + McpLog.Warn($"Failed to persist stdio reload flag: {ex.Message}"); + } + } + + private static void OnAfterAssemblyReload() + { + bool resume = false; + try + { + resume = EditorPrefs.GetBool(EditorPrefKeys.ResumeStdioAfterReload, false); + bool useHttp = EditorPrefs.GetBool(EditorPrefKeys.UseHttpTransport, true); + resume = resume && !useHttp; + if (resume) + { + EditorPrefs.DeleteKey(EditorPrefKeys.ResumeStdioAfterReload); + } + } + catch (Exception ex) + { + McpLog.Warn($"Failed to read stdio reload flag: {ex.Message}"); + } + + if (!resume) + { + return; + } + + // Restart via TransportManager so state stays in sync; if it fails (port busy), rely on UI to retry. + TryStartBridgeImmediate(); + } + + private static void TryStartBridgeImmediate() + { + var startTask = MCPServiceLocator.TransportManager.StartAsync(TransportMode.Stdio); + startTask.ContinueWith(t => + { + if (t.IsFaulted) + { + var baseEx = t.Exception?.GetBaseException(); + McpLog.Warn($"Failed to resume stdio bridge after reload: {baseEx?.Message}"); + return; + } + if (!t.Result) + { + McpLog.Warn("Failed to resume stdio bridge after domain reload"); + return; + } + + MCPForUnity.Editor.Windows.MCPForUnityEditorWindow.RequestHealthVerification(); + }, System.Threading.Tasks.TaskScheduler.Default); + } + } +} diff --git a/MCPForUnity/Editor/Services/StdioBridgeReloadHandler.cs.meta b/MCPForUnity/Editor/Services/StdioBridgeReloadHandler.cs.meta new file mode 100644 index 00000000..d4e43fa9 --- /dev/null +++ b/MCPForUnity/Editor/Services/StdioBridgeReloadHandler.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6e603c72a87974cf5b495cd683165fbf +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/MCPForUnity/Editor/Services/Transport/TransportManager.cs b/MCPForUnity/Editor/Services/Transport/TransportManager.cs index 7a6afe92..d221ab83 100644 --- a/MCPForUnity/Editor/Services/Transport/TransportManager.cs +++ b/MCPForUnity/Editor/Services/Transport/TransportManager.cs @@ -10,8 +10,10 @@ namespace MCPForUnity.Editor.Services.Transport /// public class TransportManager { - private IMcpTransportClient _active; - private TransportMode? _activeMode; + private IMcpTransportClient _httpClient; + private IMcpTransportClient _stdioClient; + private TransportState _httpState = TransportState.Disconnected("http"); + private TransportState _stdioState = TransportState.Disconnected("stdio"); private Func _webSocketFactory; private Func _stdioFactory; @@ -22,8 +24,8 @@ public TransportManager() () => new StdioTransportClient()); } - public IMcpTransportClient ActiveTransport => _active; - public TransportMode? ActiveMode => _activeMode; + public IMcpTransportClient ActiveTransport => null; // Deprecated single-transport accessor + public TransportMode? ActiveMode => null; // Deprecated single-transport accessor public void Configure( Func webSocketFactory, @@ -33,68 +35,115 @@ public void Configure( _stdioFactory = stdioFactory ?? throw new ArgumentNullException(nameof(stdioFactory)); } - public async Task StartAsync(TransportMode mode) + private IMcpTransportClient GetOrCreateClient(TransportMode mode) { - await StopAsync(); - - IMcpTransportClient next = mode switch + return mode switch { - TransportMode.Stdio => _stdioFactory(), - TransportMode.Http => _webSocketFactory(), - _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode") - } ?? throw new InvalidOperationException($"Factory returned null for transport mode {mode}"); + TransportMode.Http => _httpClient ??= _webSocketFactory(), + TransportMode.Stdio => _stdioClient ??= _stdioFactory(), + _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), + }; + } - bool started = await next.StartAsync(); - if (!started) + private IMcpTransportClient GetClient(TransportMode mode) + { + return mode switch { - await next.StopAsync(); - _active = null; - _activeMode = null; - return false; - } - - _active = next; - _activeMode = mode; - return true; + TransportMode.Http => _httpClient, + TransportMode.Stdio => _stdioClient, + _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), + }; } - public async Task StopAsync() + public async Task StartAsync(TransportMode mode) { - if (_active != null) + IMcpTransportClient client = GetOrCreateClient(mode); + + bool started = await client.StartAsync(); + if (!started) { try { - await _active.StopAsync(); + await client.StopAsync(); } catch (Exception ex) { - McpLog.Warn($"Error while stopping transport {_active.TransportName}: {ex.Message}"); - } - finally - { - _active = null; - _activeMode = null; + McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}"); } + UpdateState(mode, TransportState.Disconnected(client.TransportName, "Failed to start")); + return false; } + + UpdateState(mode, client.State ?? TransportState.Connected(client.TransportName)); + return true; } - public async Task VerifyAsync() + public async Task StopAsync(TransportMode? mode = null) { - if (_active == null) + async Task StopClient(IMcpTransportClient client, TransportMode clientMode) { - return false; + if (client == null) return; + try { await client.StopAsync(); } + catch (Exception ex) { McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}"); } + finally { UpdateState(clientMode, TransportState.Disconnected(client.TransportName)); } + } + + if (mode == null) + { + await StopClient(_httpClient, TransportMode.Http); + await StopClient(_stdioClient, TransportMode.Stdio); + return; + } + + if (mode == TransportMode.Http) + { + await StopClient(_httpClient, TransportMode.Http); + } + else + { + await StopClient(_stdioClient, TransportMode.Stdio); } - return await _active.VerifyAsync(); } - public TransportState GetState() + public async Task VerifyAsync(TransportMode mode) { - if (_active == null) + IMcpTransportClient client = GetClient(mode); + if (client == null) { - return TransportState.Disconnected(_activeMode?.ToString()?.ToLowerInvariant() ?? "unknown", "Transport not started"); + return false; } - return _active.State ?? TransportState.Disconnected(_active.TransportName, "No state reported"); + bool ok = await client.VerifyAsync(); + var state = client.State ?? TransportState.Disconnected(client.TransportName, "No state reported"); + UpdateState(mode, state); + return ok; + } + + public TransportState GetState(TransportMode mode) + { + return mode switch + { + TransportMode.Http => _httpState, + TransportMode.Stdio => _stdioState, + _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), + }; + } + + public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected; + + private void UpdateState(TransportMode mode, TransportState state) + { + switch (mode) + { + case TransportMode.Http: + _httpState = state; + break; + case TransportMode.Stdio: + _stdioState = state; + break; + default: + throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"); + } } } diff --git a/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs b/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs index 5cc1585a..ab127b48 100644 --- a/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs +++ b/MCPForUnity/Editor/Services/Transport/Transports/StdioBridgeHost.cs @@ -57,7 +57,6 @@ public static class StdioBridgeHost private static int mainThreadId; private static int currentUnityPort = 6400; private static bool isAutoConnectMode = false; - private static bool shouldRestartAfterReload = false; private const ulong MaxFrameBytes = 64UL * 1024 * 1024; private const int FrameIOTimeoutMs = 30000; @@ -162,8 +161,6 @@ static StdioBridgeHost() } } EditorApplication.quitting += Stop; - AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload; - AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload; EditorApplication.playModeStateChanged += _ => { if (ShouldAutoStartBridge()) @@ -406,10 +403,6 @@ public static void Start() listenerTask = Task.Run(() => ListenerLoopAsync(cts.Token)); CommandRegistry.Initialize(); EditorApplication.update += ProcessCommands; - try { AssemblyReloadEvents.beforeAssemblyReload -= OnBeforeAssemblyReload; } catch { } - try { AssemblyReloadEvents.beforeAssemblyReload += OnBeforeAssemblyReload; } catch { } - try { AssemblyReloadEvents.afterAssemblyReload -= OnAfterAssemblyReload; } catch { } - try { AssemblyReloadEvents.afterAssemblyReload += OnAfterAssemblyReload; } catch { } try { EditorApplication.quitting -= Stop; } catch { } try { EditorApplication.quitting += Stop; } catch { } heartbeatSeq++; @@ -470,8 +463,6 @@ public static void Stop() } try { EditorApplication.update -= ProcessCommands; } catch { } - try { AssemblyReloadEvents.beforeAssemblyReload -= OnBeforeAssemblyReload; } catch { } - try { AssemblyReloadEvents.afterAssemblyReload -= OnAfterAssemblyReload; } catch { } try { EditorApplication.quitting -= Stop; } catch { } try @@ -1023,47 +1014,6 @@ private static bool IsValidJson(string text) return false; } - private static void OnBeforeAssemblyReload() - { - if (isRunning) - { - shouldRestartAfterReload = true; - } - try { Stop(); } catch { } - } - - private static void OnAfterAssemblyReload() - { - WriteHeartbeat(false, "idle"); - LogBreadcrumb("Idle"); - bool shouldResume = ShouldAutoStartBridge() || shouldRestartAfterReload; - if (shouldRestartAfterReload) - { - shouldRestartAfterReload = false; - } - if (!shouldResume) - { - return; - } - - // If we're not compiling, try to bring the bridge up immediately to avoid depending on editor focus. - if (!IsCompiling()) - { - try - { - Start(); - return; // Successful immediate start; no need to schedule a delayed retry - } - catch (Exception ex) - { - // Fall through to delayed retry if immediate start fails - McpLog.Warn($"Immediate STDIO bridge restart after reload failed: {ex.Message}"); - } - } - - // Fallback path when compiling or if immediate start failed - ScheduleInitRetry(); - } private static void WriteHeartbeat(bool reloading, string reason = null) {