diff --git a/MCPForUnity/Editor/Services/EditorStateCache.cs b/MCPForUnity/Editor/Services/EditorStateCache.cs index 24fec0f14..04edb76ac 100644 --- a/MCPForUnity/Editor/Services/EditorStateCache.cs +++ b/MCPForUnity/Editor/Services/EditorStateCache.cs @@ -7,6 +7,7 @@ using UnityEditorInternal; using UnityEditor.SceneManagement; using UnityEngine; +using System.Collections.Generic; namespace MCPForUnity.Editor.Services { @@ -15,7 +16,7 @@ namespace MCPForUnity.Editor.Services /// Updated on the main thread via Editor callbacks and periodic update ticks. /// [InitializeOnLoad] - internal static class EditorStateCache + public static class EditorStateCache { private static readonly object LockObj = new(); private static long _sequence; @@ -42,6 +43,11 @@ internal static class EditorStateCache private static bool _lastTrackedTestsRunning; private static string _lastTrackedActivityPhase; + // Selection state tracking for state-aware tool filtering + private static int _lastTrackedActiveInstanceID; + private static string _lastTrackedActiveGameObjectName; + private static int _lastTrackedSelectionCount; + private static JObject _cached; private sealed class EditorStateSnapshot @@ -75,6 +81,9 @@ private sealed class EditorStateSnapshot [JsonProperty("transport")] public EditorStateTransport Transport { get; set; } + + [JsonProperty("advice")] + public EditorStateAdvice Advice { get; set; } } private sealed class EditorStateUnity @@ -105,6 +114,24 @@ private sealed class EditorStateEditor [JsonProperty("active_scene")] public EditorStateActiveScene ActiveScene { get; set; } + + [JsonProperty("selection")] + public EditorStateSelection Selection { get; set; } + } + + private sealed class EditorStateSelection + { + [JsonProperty("has_selection")] + public bool HasSelection { get; set; } + + [JsonProperty("active_instance_id")] + public int ActiveInstanceID { get; set; } + + [JsonProperty("active_game_object_name")] + public string ActiveGameObjectName { get; set; } + + [JsonProperty("selection_count")] + public int SelectionCount { get; set; } } private sealed class EditorStatePlayMode @@ -230,6 +257,21 @@ private sealed class EditorStateLastRun public object Counts { get; set; } } + private sealed class EditorStateAdvice + { + [JsonProperty("ready_for_tools")] + public bool ReadyForTools { get; set; } + + [JsonProperty("blocking_reasons")] + public string[] BlockingReasons { get; set; } + + [JsonProperty("recommended_retry_after_ms")] + public long? RecommendedRetryAfterMs { get; set; } + + [JsonProperty("recommended_next_action")] + public string RecommendedNextAction { get; set; } + } + private sealed class EditorStateTransport { [JsonProperty("unity_bridge_connected")] @@ -249,6 +291,7 @@ static EditorStateCache() EditorApplication.update += OnUpdate; EditorApplication.playModeStateChanged += _ => ForceUpdate("playmode"); + Selection.selectionChanged += () => ForceUpdate("selection"); AssemblyReloadEvents.beforeAssemblyReload += () => { @@ -296,6 +339,11 @@ private static void OnUpdate() bool isUpdating = EditorApplication.isUpdating; bool testsRunning = TestRunStatus.IsRunning; + // Selection state reading for state-aware tool filtering + int activeInstanceID = Selection.activeInstanceID; + string activeGameObjectName = Selection.activeGameObject?.name ?? string.Empty; + int selectionCount = Selection.count; + var activityPhase = "idle"; if (testsRunning) { @@ -326,7 +374,10 @@ private static void OnUpdate() || _lastTrackedIsPaused != isPaused || _lastTrackedIsUpdating != isUpdating || _lastTrackedTestsRunning != testsRunning - || _lastTrackedActivityPhase != activityPhase; + || _lastTrackedActivityPhase != activityPhase + || _lastTrackedActiveInstanceID != activeInstanceID + || _lastTrackedActiveGameObjectName != activeGameObjectName + || _lastTrackedSelectionCount != selectionCount; if (!hasChanges) { @@ -344,6 +395,9 @@ private static void OnUpdate() _lastTrackedIsUpdating = isUpdating; _lastTrackedTestsRunning = testsRunning; _lastTrackedActivityPhase = activityPhase; + _lastTrackedActiveInstanceID = activeInstanceID; + _lastTrackedActiveGameObjectName = activeGameObjectName; + _lastTrackedSelectionCount = selectionCount; _lastUpdateTimeSinceStartup = now; ForceUpdate("tick"); @@ -404,6 +458,11 @@ private static JObject BuildSnapshot(string reason) activityPhase = "playmode_transition"; } + // Read current selection state directly for snapshot + int currentActiveInstanceID = Selection.activeInstanceID; + string currentActiveGameObjectName = Selection.activeGameObject?.name ?? string.Empty; + int currentSelectionCount = Selection.count; + var snapshot = new EditorStateSnapshot { SchemaVersion = "unity-mcp/editor_state@2", @@ -431,6 +490,13 @@ private static JObject BuildSnapshot(string reason) Path = scenePath, Guid = sceneGuid, Name = scene.name ?? string.Empty + }, + Selection = new EditorStateSelection + { + HasSelection = currentSelectionCount > 0, + ActiveInstanceID = currentActiveInstanceID, + ActiveGameObjectName = currentActiveGameObjectName, + SelectionCount = currentSelectionCount } }, Activity = new EditorStateActivity @@ -482,20 +548,51 @@ private static JObject BuildSnapshot(string reason) { UnityBridgeConnected = null, LastMessageUnixMs = null - } + }, + Advice = BuildEditorStateAdvice(isCompiling, testsRunning) }; return JObject.FromObject(snapshot); } - public static JObject GetSnapshot() + private static EditorStateAdvice BuildEditorStateAdvice(bool isCompiling, bool testsRunning) + { + var blockingReasons = new List(); + + if (isCompiling) + { + blockingReasons.Add("compiling"); + } + + if (_domainReloadPending) + { + blockingReasons.Add("domain_reload"); + } + + if (testsRunning) + { + blockingReasons.Add("tests_running"); + } + + bool readyForTools = blockingReasons.Count == 0; + + return new EditorStateAdvice + { + ReadyForTools = readyForTools, + BlockingReasons = blockingReasons.ToArray(), + RecommendedRetryAfterMs = isCompiling ? 1000 : null, + RecommendedNextAction = isCompiling ? "wait_for_compile" : null + }; + } + + public static JObject GetSnapshot(bool forceRefresh = false) { lock (LockObj) { // Defensive: if something went wrong early, rebuild once. - if (_cached == null) + if (_cached == null || forceRefresh) { - _cached = BuildSnapshot("rebuild"); + _cached = BuildSnapshot(forceRefresh ? "get_snapshot_force" : "get_snapshot_rebuild"); } // Always return a fresh clone to prevent mutation bugs. diff --git a/MCPForUnity/Editor/Services/ToolDiscoveryService.cs b/MCPForUnity/Editor/Services/ToolDiscoveryService.cs index b5b86c0a2..ac7cc1455 100644 --- a/MCPForUnity/Editor/Services/ToolDiscoveryService.cs +++ b/MCPForUnity/Editor/Services/ToolDiscoveryService.cs @@ -226,16 +226,6 @@ private void EnsurePreferenceInitialized(ToolMetadata metadata) { bool defaultValue = metadata.AutoRegister || metadata.IsBuiltIn; EditorPrefs.SetBool(key, defaultValue); - return; - } - - if (metadata.IsBuiltIn && !metadata.AutoRegister) - { - bool currentValue = EditorPrefs.GetBool(key, metadata.AutoRegister); - if (currentValue == metadata.AutoRegister) - { - EditorPrefs.SetBool(key, true); - } } } diff --git a/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs b/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs index 3d8584fd9..44f1d3b8c 100644 --- a/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs +++ b/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs @@ -14,5 +14,6 @@ public interface IMcpTransportClient Task StartAsync(); Task StopAsync(); Task VerifyAsync(); + Task ReregisterToolsAsync(); } } diff --git a/MCPForUnity/Editor/Services/Transport/TransportManager.cs b/MCPForUnity/Editor/Services/Transport/TransportManager.cs index 1204e7014..b544273b4 100644 --- a/MCPForUnity/Editor/Services/Transport/TransportManager.cs +++ b/MCPForUnity/Editor/Services/Transport/TransportManager.cs @@ -42,16 +42,6 @@ private IMcpTransportClient GetOrCreateClient(TransportMode mode) }; } - private IMcpTransportClient GetClient(TransportMode mode) - { - return mode switch - { - TransportMode.Http => _httpClient, - TransportMode.Stdio => _stdioClient, - _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), - }; - } - public async Task StartAsync(TransportMode mode) { IMcpTransportClient client = GetOrCreateClient(mode); @@ -128,6 +118,20 @@ public TransportState GetState(TransportMode mode) public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected; + /// + /// Gets the active transport client for the specified mode. + /// Returns null if the client hasn't been created yet. + /// + public IMcpTransportClient GetClient(TransportMode mode) + { + return mode switch + { + TransportMode.Http => _httpClient, + TransportMode.Stdio => _stdioClient, + _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), + }; + } + private void UpdateState(TransportMode mode, TransportState state) { switch (mode) diff --git a/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs b/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs index ea3ed1a22..9dadc4336 100644 --- a/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs +++ b/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs @@ -46,5 +46,12 @@ public Task VerifyAsync() return Task.FromResult(running); } + public Task ReregisterToolsAsync() + { + // Stdio transport doesn't support dynamic tool reregistration + // Tools are registered at server startup + return Task.CompletedTask; + } + } } diff --git a/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs b/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs index b94c0836f..6d2f76838 100644 --- a/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs +++ b/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs @@ -506,6 +506,29 @@ private async Task SendRegisterToolsAsync(CancellationToken token) McpLog.Info($"[WebSocket] Sent {tools.Count} tools registration", false); } + public async Task ReregisterToolsAsync() + { + if (!IsConnected || _lifecycleCts == null) + { + McpLog.Warn("[WebSocket] Cannot reregister tools: not connected"); + return; + } + + try + { + await SendRegisterToolsAsync(_lifecycleCts.Token).ConfigureAwait(false); + McpLog.Info("[WebSocket] Tool reregistration completed", false); + } + catch (System.OperationCanceledException) + { + McpLog.Warn("[WebSocket] Tool reregistration cancelled"); + } + catch (System.Exception ex) + { + McpLog.Error($"[WebSocket] Tool reregistration failed: {ex.Message}"); + } + } + private async Task HandleExecuteAsync(JObject payload, CancellationToken token) { string commandId = payload.Value("id"); diff --git a/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs b/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs index dc5a3eabf..8644a7ad6 100644 --- a/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs +++ b/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs @@ -1,9 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using MCPForUnity.Editor.Constants; using MCPForUnity.Editor.Helpers; using MCPForUnity.Editor.Services; +using MCPForUnity.Editor.Services.Transport; using MCPForUnity.Editor.Tools; using UnityEditor; using UnityEngine.UIElements; @@ -231,6 +233,30 @@ private void HandleToggleChange(ToolMetadata tool, bool enabled, bool updateSumm { UpdateSummary(); } + + // Trigger tool reregistration with connected MCP server + ReregisterToolsAsync(); + } + + private void ReregisterToolsAsync() + { + // Fire and forget - don't block UI + ThreadPool.QueueUserWorkItem(_ => + { + try + { + var transportManager = MCPServiceLocator.TransportManager; + var client = transportManager.GetClient(TransportMode.Http); + if (client != null && client.IsConnected) + { + client.ReregisterToolsAsync().Wait(); + } + } + catch (Exception ex) + { + McpLog.Warn($"Failed to reregister tools: {ex.Message}"); + } + }); } private void SetAllToolsState(bool enabled) @@ -253,6 +279,9 @@ private void SetAllToolsState(bool enabled) } UpdateSummary(); + + // Trigger tool reregistration after bulk change + ReregisterToolsAsync(); } private void UpdateSummary() diff --git a/Server/src/core/tool_filter_decorator.py b/Server/src/core/tool_filter_decorator.py new file mode 100644 index 000000000..aa1f5b4bd --- /dev/null +++ b/Server/src/core/tool_filter_decorator.py @@ -0,0 +1,243 @@ +"""Tool filter decorator for state-aware prerequisite checking. + +This module provides the prerequisite_check decorator that allows tools to declare +their dependencies on Unity editor state (compilation, selection, play mode, etc.). +""" + +import functools +import inspect +import logging +import threading +from typing import Callable, Any, Final + +from models import MCPResponse + +logger = logging.getLogger("mcp-for-unity-server") + +__all__ = ["ToolPrerequisite", "prerequisite_check", "tool_prerequisites"] + + +class ToolPrerequisite: + """Defines conditions under which a tool is available. + + Args: + require_no_compile: Tool hidden when Unity is compiling + require_selection: Tool hidden when no GameObject is selected + require_paused_for_destructive: Tool hidden during play mode (unless paused) + require_no_tests: Tool hidden while tests are running + """ + + def __init__( + self, + require_no_compile: bool = False, + require_selection: bool = False, + require_paused_for_destructive: bool = False, + require_no_tests: bool = False, + ): + self.require_no_compile = require_no_compile + self.require_selection = require_selection + self.require_paused_for_destructive = require_paused_for_destructive + self.require_no_tests = require_no_tests + + def is_met(self, editor_state: dict) -> tuple[bool, str | None]: + """Evaluate if prerequisites are met. + + Args: + editor_state: The current editor state from get_editor_state() + + Returns: + (is_met, blocking_reason) tuple + """ + advice = editor_state.get("advice", {}) + blocking_reasons = advice.get("blocking_reasons", []) if isinstance(advice, dict) else [] + editor = editor_state.get("editor", {}) + selection = editor.get("selection", {}) if isinstance(editor, dict) else {} + play_mode = editor.get("play_mode", {}) if isinstance(editor, dict) else {} + tests = editor_state.get("tests", {}) if isinstance(editor_state, dict) else {} + + # Check compilation prerequisite + if self.require_no_compile: + if "compiling" in blocking_reasons: + return False, "compiling" + + # Check domain reload prerequisite + if self.require_no_compile: + if "domain_reload" in blocking_reasons: + return False, "domain_reload" + + # Check tests prerequisite + if self.require_no_tests: + if isinstance(tests, dict) and tests.get("is_running") is True: + return False, "tests_running" + + # Check selection prerequisite + # Only hide if we know there's no selection (fail-open if state is unknown) + if self.require_selection: + if isinstance(selection, dict): + has_selection = selection.get("has_selection") + if has_selection is False: + return False, "no_selection" + + # Check paused for destructive operations + if self.require_paused_for_destructive: + if isinstance(play_mode, dict): + is_playing = play_mode.get("is_playing") + is_paused = play_mode.get("is_paused") + if is_playing is True and is_paused is False: + return False, "play_mode_active" + + return True, None + + +# Global storage for tool prerequisites +# Key: tool name, Value: ToolPrerequisite instance +# Thread-safe: use _prerequisites_lock for all modifications +# Note: `Final` ensures the reference is not rebound; the dict contents remain mutable +_prerequisites_lock = threading.Lock() +tool_prerequisites: Final[dict[str, ToolPrerequisite]] = {} + + +def prerequisite_check( + require_no_compile: bool = False, + require_selection: bool = False, + require_paused_for_destructive: bool = False, + require_no_tests: bool = False, +) -> Callable: + """Decorator that adds prerequisite checks to MCP tools. + + The decorator stores the prerequisite rules and evaluates them before tool execution. + If prerequisites are not met, the tool call returns an MCPResponse with an error. + + Args: + require_no_compile: Tool hidden when Unity is compiling + require_selection: Tool hidden when no GameObject is selected + require_paused_for_destructive: Tool hidden during play mode (unless paused) + require_no_tests: Tool hidden while tests are running + + Usage: + @prerequisite_check(require_no_compile=True, require_selection=True) + @mcp_for_unity_tool(description="Modify selected GameObject") + async def my_tool(ctx: Context, ...): + ... + + Returns: + Decorator function + """ + + def decorator(func: Callable) -> Callable: + # Store prerequisites for this tool (used by filtering middleware) + tool_name = func.__name__ + with _prerequisites_lock: + tool_prerequisites[tool_name] = ToolPrerequisite( + require_no_compile=require_no_compile, + require_selection=require_selection, + require_paused_for_destructive=require_paused_for_destructive, + require_no_tests=require_no_tests, + ) + + @functools.wraps(func) + def _sync_wrapper(*args, **kwargs) -> Any: + # Check prerequisites before executing + # Note: For direct tool calls, we check here + # For tool list filtering, the middleware handles it + ctx = None + if args and hasattr(args[0], "get_state"): + ctx = args[0] + elif "ctx" in kwargs: + ctx = kwargs["ctx"] + + if ctx: + try: + from services.resources.editor_state import get_editor_state + import asyncio + + # Run async get_editor_state in sync context + loop = None + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if loop and not loop.is_closed(): + state_resp = loop.run_until_complete(get_editor_state(ctx)) + state_data = state_resp.data if hasattr(state_resp, "data") else None + if isinstance(state_data, dict): + # Reuse the already-registered ToolPrerequisite instance + prereq = tool_prerequisites.get(tool_name) + if prereq is not None: + is_met, blocking_reason = prereq.is_met(state_data) + if not is_met: + advice = state_data.get("advice", {}) + blocking_reasons = advice.get("blocking_reasons", []) if isinstance(advice, dict) else [] + from models import MCPResponse + return MCPResponse( + success=False, + error="prerequisite_failed", + message=f"Tool '{tool_name}' is not available: {blocking_reason}", + data={ + "tool": tool_name, + "blocking_reason": blocking_reason, + "current_state": { + "ready_for_tools": advice.get("ready_for_tools"), + "blocking_reasons": blocking_reasons, + } if isinstance(state_data.get("advice"), dict) else None + } + ) + except RuntimeError as e: + # Event loop is already running in another thread - proceed (fail-safe) + logger.warning(f"Event loop conflict checking prerequisites for '{tool_name}': {e}") + except Exception as e: + # If we can't check prerequisites, proceed (fail-safe) + logger.warning(f"Failed to check prerequisites for '{tool_name}': {e}") + + # Prerequisites met or couldn't check - proceed with original function + return func(*args, **kwargs) + + @functools.wraps(func) + async def _async_wrapper(*args, **kwargs) -> Any: + # Check prerequisites before executing + ctx = None + if args and hasattr(args[0], "get_state"): + ctx = args[0] + elif "ctx" in kwargs: + ctx = kwargs["ctx"] + + if ctx: + try: + from services.resources.editor_state import get_editor_state + + state_resp = await get_editor_state(ctx) + state_data = state_resp.data if hasattr(state_resp, "data") else None + if isinstance(state_data, dict): + # Reuse the already-registered ToolPrerequisite instance + prereq = tool_prerequisites.get(tool_name) + if prereq is not None: + is_met, blocking_reason = prereq.is_met(state_data) + if not is_met: + advice = state_data.get("advice", {}) + blocking_reasons = advice.get("blocking_reasons", []) if isinstance(advice, dict) else [] + from models import MCPResponse + return MCPResponse( + success=False, + error="prerequisite_failed", + message=f"Tool '{tool_name}' is not available: {blocking_reason}", + data={ + "tool": tool_name, + "blocking_reason": blocking_reason, + "current_state": { + "ready_for_tools": advice.get("ready_for_tools"), + "blocking_reasons": blocking_reasons, + } if isinstance(state_data.get("advice"), dict) else None + } + ) + except Exception as e: + # If we can't check prerequisites, proceed (fail-safe) + logger.warning(f"Failed to check prerequisites for '{tool_name}': {e}") + + # Prerequisites met or couldn't check - proceed with original function + return await func(*args, **kwargs) + + return _async_wrapper if inspect.iscoroutinefunction(func) else _sync_wrapper + + return decorator diff --git a/Server/src/services/filter_middleware.py b/Server/src/services/filter_middleware.py new file mode 100644 index 000000000..9a8e95a50 --- /dev/null +++ b/Server/src/services/filter_middleware.py @@ -0,0 +1,105 @@ +"""Tool filtering middleware for state-aware tool visibility. + +This module provides middleware that filters the tool list based on editor state +before sending it to the LLM. Tools that don't meet their prerequisites are hidden. + +TODO: Integrate get_tools_matching_state() into the FastMCP tool listing flow. +Currently, tools are filtered at execution time via the @prerequisite_check decorator, +but filtering at list-time would prevent LLMs from seeing unavailable tools entirely. +Integration point: Add a middleware hook to modify the tools list returned to MCP clients. +""" + +import logging +from typing import Any + +from fastmcp import Context + +from core.tool_filter_decorator import tool_prerequisites + +logger = logging.getLogger("mcp-for-unity-server") + + +async def get_tools_matching_state( + ctx: Context, + all_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Filter tools based on current editor state. + + Args: + ctx: The MCP context + all_tools: List of all registered tool dictionaries + + Returns: + Filtered list of tools that meet their prerequisites + """ + try: + from services.resources.editor_state import get_editor_state + + # Query current editor state + state_resp = await get_editor_state(ctx) + state_data = state_resp.data if hasattr(state_resp, "data") else None + + if not isinstance(state_data, dict): + # Fail-safe: if we can't get state, return all tools + logger.warning("Failed to query editor state, returning all tools (fail-safe)") + return all_tools + + # Filter tools based on their prerequisites + filtered_tools = [] + for tool in all_tools: + tool_name = tool.get("name", "") + prereq = tool_prerequisites.get(tool_name) + + if prereq is None: + # No prerequisites - always visible + filtered_tools.append(tool) + continue + + # Check if prerequisites are met + is_met, blocking_reason = prereq.is_met(state_data) + + if is_met: + filtered_tools.append(tool) + logger.debug(f"Tool '{tool_name}' visible: all prerequisites met") + else: + logger.debug( + f"Tool '{tool_name}' hidden: {blocking_reason}" + ) + + return filtered_tools + + except Exception as e: + # Fail-safe: on error, return all tools + logger.error(f"Error filtering tools: {e}, returning all tools (fail-safe)") + return all_tools + + +class FilterResult: + """Outcome of prerequisite evaluation for a tool. + + NOTE: This class is reserved for future use in providing detailed filtering + results to clients (e.g., for debugging or UI feedback). Currently used in tests + to verify filtering behavior. + + Attributes: + tool_name: The name of the tool + is_visible: Whether the tool should be visible + blocking_reason: The reason why the tool is hidden (if applicable) + """ + + def __init__( + self, + tool_name: str, + is_visible: bool, + blocking_reason: str | None = None, + ): + self.tool_name = tool_name + self.is_visible = is_visible + self.blocking_reason = blocking_reason + + def to_dict(self) -> dict[str, Any]: + return { + "tool_name": self.tool_name, + "is_visible": self.is_visible, + "blocking_reason": self.blocking_reason, + } diff --git a/Server/src/services/resources/editor_state.py b/Server/src/services/resources/editor_state.py index 4de79429e..47c68ffac 100644 --- a/Server/src/services/resources/editor_state.py +++ b/Server/src/services/resources/editor_state.py @@ -35,10 +35,18 @@ class EditorStateActiveScene(BaseModel): name: str | None = None +class EditorStateSelection(BaseModel): + has_selection: bool | None = None + active_instance_id: int | None = None + active_game_object_name: str | None = None + selection_count: int | None = None + + class EditorStateEditor(BaseModel): is_focused: bool | None = None play_mode: EditorStatePlayMode | None = None active_scene: EditorStateActiveScene | None = None + selection: EditorStateSelection | None = None class EditorStateActivity(BaseModel): diff --git a/Server/src/services/tools/__init__.py b/Server/src/services/tools/__init__.py index 91c1f9e00..1ee8175f5 100644 --- a/Server/src/services/tools/__init__.py +++ b/Server/src/services/tools/__init__.py @@ -8,6 +8,7 @@ from fastmcp import Context, FastMCP from core.telemetry_decorator import telemetry_tool from core.logging_decorator import log_execution +from core.tool_filter_decorator import prerequisite_check, tool_prerequisites from utils.module_discovery import discover_modules from services.registry import get_registered_tools diff --git a/Server/src/services/tools/manage_asset.py b/Server/src/services/tools/manage_asset.py index e5328dfa3..ffa22d8e7 100644 --- a/Server/src/services/tools/manage_asset.py +++ b/Server/src/services/tools/manage_asset.py @@ -10,6 +10,7 @@ from services.registry import mcp_for_unity_tool from services.tools import get_unity_instance_from_context +from core.tool_filter_decorator import prerequisite_check from services.tools.utils import parse_json_payload, coerce_int, normalize_properties from transport.unity_transport import send_with_unity_instance from transport.legacy.unity_connection import async_send_command_with_retry @@ -27,6 +28,7 @@ destructiveHint=True, ), ) +@prerequisite_check(require_no_compile=True) async def manage_asset( ctx: Context, action: Annotated[Literal["import", "create", "modify", "delete", "duplicate", "move", "rename", "search", "get_info", "create_folder", "get_components"], "Perform CRUD operations on assets."], diff --git a/Server/src/services/tools/manage_components.py b/Server/src/services/tools/manage_components.py index 2c5c0c94d..940350387 100644 --- a/Server/src/services/tools/manage_components.py +++ b/Server/src/services/tools/manage_components.py @@ -7,6 +7,7 @@ from fastmcp import Context from services.registry import mcp_for_unity_tool from services.tools import get_unity_instance_from_context +from core.tool_filter_decorator import prerequisite_check from transport.unity_transport import send_with_unity_instance from transport.legacy.unity_connection import async_send_command_with_retry from services.tools.utils import parse_json_payload, normalize_properties @@ -16,6 +17,7 @@ @mcp_for_unity_tool( description="Manages components on GameObjects (add, remove, set_property). For reading component data, use the mcpforunity://scene/gameobject/{id}/components resource." ) +@prerequisite_check(require_selection=True) async def manage_components( ctx: Context, action: Annotated[ diff --git a/Server/src/services/tools/manage_gameobject.py b/Server/src/services/tools/manage_gameobject.py index fa0dca647..70003404f 100644 --- a/Server/src/services/tools/manage_gameobject.py +++ b/Server/src/services/tools/manage_gameobject.py @@ -5,6 +5,7 @@ from services.registry import mcp_for_unity_tool from services.tools import get_unity_instance_from_context +from core.tool_filter_decorator import prerequisite_check from transport.unity_transport import send_with_unity_instance from transport.legacy.unity_connection import async_send_command_with_retry from services.tools.utils import coerce_bool, parse_json_payload, normalize_vector3 @@ -51,6 +52,7 @@ def _normalize_component_properties(value: Any) -> tuple[dict[str, dict[str, Any destructiveHint=True, ), ) +@prerequisite_check(require_paused_for_destructive=True) async def manage_gameobject( ctx: Context, action: Annotated[Literal["create", "modify", "delete", "duplicate", diff --git a/Server/src/services/tools/manage_scene.py b/Server/src/services/tools/manage_scene.py index 2a29b906e..e9250462d 100644 --- a/Server/src/services/tools/manage_scene.py +++ b/Server/src/services/tools/manage_scene.py @@ -6,6 +6,7 @@ from services.registry import mcp_for_unity_tool from services.tools import get_unity_instance_from_context from services.tools.utils import coerce_int, coerce_bool +from core.tool_filter_decorator import prerequisite_check from transport.unity_transport import send_with_unity_instance from transport.legacy.unity_connection import async_send_command_with_retry from services.tools.preflight import preflight @@ -18,6 +19,7 @@ destructiveHint=True, ), ) +@prerequisite_check(require_no_compile=True) async def manage_scene( ctx: Context, action: Annotated[Literal[ diff --git a/Server/src/services/tools/manage_script.py b/Server/src/services/tools/manage_script.py index dc9ab5fe9..42af96b36 100644 --- a/Server/src/services/tools/manage_script.py +++ b/Server/src/services/tools/manage_script.py @@ -8,6 +8,7 @@ from services.registry import mcp_for_unity_tool from services.tools import get_unity_instance_from_context +from core.tool_filter_decorator import prerequisite_check from transport.unity_transport import send_with_unity_instance import transport.legacy.unity_connection diff --git a/Server/tests/integration/test_tool_filtering_integration.py b/Server/tests/integration/test_tool_filtering_integration.py new file mode 100644 index 000000000..e8e68b25d --- /dev/null +++ b/Server/tests/integration/test_tool_filtering_integration.py @@ -0,0 +1,420 @@ +"""Integration tests for tool filtering middleware. + +Tests the complete flow from editor state query to tool filtering. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastmcp import Context + +from services.filter_middleware import get_tools_matching_state, FilterResult +from core.tool_filter_decorator import tool_prerequisites, ToolPrerequisite, prerequisite_check + + +class TestCompilationFiltering: + """Test that tools are hidden during compilation.""" + + @pytest.mark.asyncio + async def test_tool_hidden_when_compiling(self): + """Tools with require_no_compile should be hidden during compilation.""" + # Setup: Register a tool that requires no compilation + tool_name = "test_modify_script" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite(require_no_compile=True) + + # Mock editor state with compilation in progress + mock_ctx = MagicMock(spec=Context) + mock_state = { + "advice": { + "blocking_reasons": ["compiling"], + "ready_for_tools": False + }, + "editor": { + "selection": {"has_selection": True} + } + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + # All tools list + all_tools = [ + {"name": tool_name, "description": "Modify script"}, + {"name": "read_only_tool", "description": "Read something"} + ] + + # Run filtering + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + # The compilation-sensitive tool should be hidden + filtered_names = [t["name"] for t in filtered] + assert tool_name not in filtered_names + assert "read_only_tool" in filtered_names + + # Cleanup + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) + + @pytest.mark.asyncio + async def test_tool_visible_when_not_compiling(self): + """Tools with require_no_compile should be visible when not compiling.""" + tool_name = "test_modify_script" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite(require_no_compile=True) + + # Mock editor state without compilation + mock_ctx = MagicMock(spec=Context) + mock_state = { + "advice": { + "blocking_reasons": [], + "ready_for_tools": True + }, + "editor": { + "selection": {"has_selection": True} + } + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + all_tools = [ + {"name": tool_name, "description": "Modify script"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + # Tool should be visible + assert len(filtered) == 1 + assert filtered[0]["name"] == tool_name + + # Cleanup + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) + + +class TestSelectionFiltering: + """Test that tools requiring selection are hidden when nothing selected.""" + + @pytest.mark.asyncio + async def test_tool_hidden_without_selection(self): + """Tools with require_selection should be hidden when no selection.""" + tool_name = "test_adjust_transform" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite(require_selection=True) + + # Mock editor state with no selection + mock_ctx = MagicMock(spec=Context) + mock_state = { + "advice": {"blocking_reasons": [], "ready_for_tools": True}, + "editor": { + "selection": {"has_selection": False} + } + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + all_tools = [ + {"name": tool_name, "description": "Adjust transform"}, + {"name": "read_only_tool", "description": "Read something"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + filtered_names = [t["name"] for t in filtered] + assert tool_name not in filtered_names + assert "read_only_tool" in filtered_names + + # Cleanup + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) + + @pytest.mark.asyncio + async def test_tool_visible_with_selection(self): + """Tools with require_selection should be visible when GameObject selected.""" + tool_name = "test_adjust_transform" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite(require_selection=True) + + # Mock editor state with selection + mock_ctx = MagicMock(spec=Context) + mock_state = { + "advice": {"blocking_reasons": [], "ready_for_tools": True}, + "editor": { + "selection": {"has_selection": True} + } + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + all_tools = [ + {"name": tool_name, "description": "Adjust transform"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + assert len(filtered) == 1 + assert filtered[0]["name"] == tool_name + + # Cleanup + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) + + +class TestPlayModeFiltering: + """Test that destructive tools are hidden during active play mode.""" + + @pytest.mark.asyncio + async def test_destructive_tool_hidden_in_play_mode(self): + """Destructive tools should be hidden during active play mode.""" + tool_name = "test_delete_gameobject" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite( + require_selection=True, + require_paused_for_destructive=True + ) + + # Mock editor state in active play mode + mock_ctx = MagicMock(spec=Context) + mock_state = { + "advice": {"blocking_reasons": [], "ready_for_tools": True}, + "editor": { + "selection": {"has_selection": True}, + "play_mode": {"is_playing": True, "is_paused": False} + } + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + all_tools = [ + {"name": tool_name, "description": "Delete GameObject"}, + {"name": "safe_tool", "description": "Safe operation"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + filtered_names = [t["name"] for t in filtered] + assert tool_name not in filtered_names + assert "safe_tool" in filtered_names + + # Cleanup + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) + + @pytest.mark.asyncio + async def test_destructive_tool_visible_when_paused(self): + """Destructive tools should be visible when play mode is paused.""" + tool_name = "test_delete_gameobject" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite( + require_selection=True, + require_paused_for_destructive=True + ) + + # Mock editor state in paused play mode + mock_ctx = MagicMock(spec=Context) + mock_state = { + "advice": {"blocking_reasons": [], "ready_for_tools": True}, + "editor": { + "selection": {"has_selection": True}, + "play_mode": {"is_playing": True, "is_paused": True} + } + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + all_tools = [ + {"name": tool_name, "description": "Delete GameObject"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + assert len(filtered) == 1 + assert filtered[0]["name"] == tool_name + + # Cleanup + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) + + +class TestFailsafeBehavior: + """Test fail-safe behavior when state query fails.""" + + @pytest.mark.asyncio + async def test_returns_all_tools_on_state_query_error(self): + """When state query fails, all tools should be returned (fail-safe).""" + # Mock context + mock_ctx = MagicMock(spec=Context) + + # Mock get_editor_state to raise exception + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_get_state.side_effect = Exception("Unity connection lost") + + all_tools = [ + {"name": "tool1", "description": "Tool 1"}, + {"name": "tool2", "description": "Tool 2"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + # Should return all tools (fail-safe) + assert len(filtered) == 2 + assert filtered[0]["name"] == "tool1" + assert filtered[1]["name"] == "tool2" + + @pytest.mark.asyncio + async def test_returns_all_tools_on_invalid_state_data(self): + """When state data is invalid, all tools should be returned.""" + mock_ctx = MagicMock(spec=Context) + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = None # Invalid data + mock_get_state.return_value = mock_resp + + all_tools = [ + {"name": "tool1", "description": "Tool 1"} + ] + + filtered = await get_tools_matching_state(mock_ctx, all_tools) + + # Should return all tools (fail-safe) + assert len(filtered) == 1 + + +class TestFilterResult: + """Test FilterResult data class.""" + + def test_to_dict_contains_all_fields(self): + """FilterResult.to_dict() should contain all fields.""" + result = FilterResult( + tool_name="test_tool", + is_visible=False, + blocking_reason="compiling" + ) + + result_dict = result.to_dict() + + assert result_dict["tool_name"] == "test_tool" + assert result_dict["is_visible"] is False + assert result_dict["blocking_reason"] == "compiling" + + def test_to_dict_with_none_blocking_reason(self): + """FilterResult.to_dict() should handle None blocking_reason.""" + result = FilterResult( + tool_name="test_tool", + is_visible=True, + blocking_reason=None + ) + + result_dict = result.to_dict() + + assert result_dict["blocking_reason"] is None + + +class TestAsyncDecoratorWrapper: + """Test the async wrapper path of the prerequisite_check decorator. + + These are integration tests because they verify the complete decorator flow + including interaction with get_editor_state. + """ + + @pytest.mark.asyncio + async def test_async_wrapper_prereq_met(self): + """Async tool should execute when prerequisites are met.""" + # Create a mock async tool + @prerequisite_check(require_selection=True) + async def mock_async_tool(ctx): + return "tool_executed" + + # Mock context with editor state + mock_ctx = MagicMock() + mock_state = { + "advice": {"blocking_reasons": [], "ready_for_tools": True}, + "editor": {"selection": {"has_selection": True}} + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + result = await mock_async_tool(mock_ctx) + + # Tool should execute successfully + assert result == "tool_executed" + + # Clean up + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop("mock_async_tool", None) + + @pytest.mark.asyncio + async def test_async_wrapper_prereq_not_met(self): + """Async tool should return error response when prerequisites not met.""" + @prerequisite_check(require_selection=True) + async def mock_async_tool(ctx): + return "should_not_execute" + + mock_ctx = MagicMock() + mock_state = { + "advice": {"blocking_reasons": [], "ready_for_tools": True}, + "editor": {"selection": {"has_selection": False}} + } + + with patch("services.resources.editor_state.get_editor_state", new_callable=AsyncMock) as mock_get_state: + mock_resp = MagicMock() + mock_resp.data = mock_state + mock_get_state.return_value = mock_resp + + result = await mock_async_tool(mock_ctx) + + # Should return MCPResponse error, not execute tool + assert hasattr(result, "success") + assert result.success is False + assert result.error == "prerequisite_failed" + assert "no_selection" in result.message + + # Clean up + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop("mock_async_tool", None) + + @pytest.mark.asyncio + async def test_async_wrapper_reuses_registered_prerequisite(self): + """Async wrapper should reuse the registered ToolPrerequisite instance.""" + @prerequisite_check( + require_no_compile=True, + require_selection=True, + require_paused_for_destructive=True, + require_no_tests=True + ) + async def mock_tool(ctx): + return "executed" + + tool_name = "mock_tool" + prereq = tool_prerequisites.get(tool_name) + + # Verify the registered instance has all flags set + assert prereq is not None + assert prereq.require_no_compile is True + assert prereq.require_selection is True + assert prereq.require_paused_for_destructive is True + assert prereq.require_no_tests is True + + # Clean up + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites.pop(tool_name, None) diff --git a/Server/tests/unit/__init__.py b/Server/tests/unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Server/tests/unit/test_tool_filter_decorator.py b/Server/tests/unit/test_tool_filter_decorator.py new file mode 100644 index 000000000..4f7c993a1 --- /dev/null +++ b/Server/tests/unit/test_tool_filter_decorator.py @@ -0,0 +1,330 @@ +"""Unit tests for tool_filter_decorator module. + +Tests the ToolPrerequisite class logic without external dependencies. +Integration tests for the full decorator flow are in tests/integration/test_tool_filtering_integration.py +""" + +import pytest +from unittest.mock import patch + +from core.tool_filter_decorator import ToolPrerequisite, tool_prerequisites + + +class TestToolPrerequisite: + """Test ToolPrerequisite.is_met() method with various editor states.""" + + def test_no_prerequisites_always_pass(self): + """Tool with no prerequisites should always be available.""" + prereq = ToolPrerequisite() + state = {} + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_require_no_compile_pass_when_not_compiling(self): + """Tool should be available when Unity is not compiling.""" + prereq = ToolPrerequisite(require_no_compile=True) + state = { + "advice": {"blocking_reasons": []} + } + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_require_no_compile_fail_when_compiling(self): + """Tool should be hidden when Unity is compiling.""" + prereq = ToolPrerequisite(require_no_compile=True) + state = { + "advice": {"blocking_reasons": ["compiling"]} + } + + is_met, reason = prereq.is_met(state) + + assert is_met is False + assert reason == "compiling" + + def test_require_no_compile_fail_when_domain_reload(self): + """Tool should be hidden during domain reload.""" + prereq = ToolPrerequisite(require_no_compile=True) + state = { + "advice": {"blocking_reasons": ["domain_reload"]} + } + + is_met, reason = prereq.is_met(state) + + assert is_met is False + assert reason == "domain_reload" + + def test_require_selection_pass_when_has_selection(self): + """Tool should be available when GameObject is selected.""" + prereq = ToolPrerequisite(require_selection=True) + state = { + "editor": { + "selection": { + "has_selection": True + } + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_require_selection_fail_when_no_selection(self): + """Tool should be hidden when no GameObject is selected.""" + prereq = ToolPrerequisite(require_selection=True) + state = { + "editor": { + "selection": { + "has_selection": False + } + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is False + assert reason == "no_selection" + + def test_require_selection_pass_when_unknown(self): + """Tool should be available when selection state is unknown (fail-open).""" + prereq = ToolPrerequisite(require_selection=True) + state = { + "editor": { + "selection": { + "has_selection": None + } + } + } + + is_met, reason = prereq.is_met(state) + + # Fail-open: unknown state should not hide the tool + assert is_met is True + assert reason is None + + def test_require_paused_for_destructive_fail_in_play_mode(self): + """Tool should be hidden during active play mode.""" + prereq = ToolPrerequisite(require_paused_for_destructive=True) + state = { + "editor": { + "play_mode": { + "is_playing": True, + "is_paused": False + } + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is False + assert reason == "play_mode_active" + + def test_require_paused_for_destructive_pass_when_paused(self): + """Tool should be available when play mode is paused.""" + prereq = ToolPrerequisite(require_paused_for_destructive=True) + state = { + "editor": { + "play_mode": { + "is_playing": True, + "is_paused": True + } + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_require_paused_for_destructive_pass_when_not_playing(self): + """Tool should be available when not in play mode.""" + prereq = ToolPrerequisite(require_paused_for_destructive=True) + state = { + "editor": { + "play_mode": { + "is_playing": False, + "is_paused": False + } + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_require_no_tests_fail_when_tests_running(self): + """Tool should be hidden when tests are running.""" + prereq = ToolPrerequisite(require_no_tests=True) + state = { + "tests": { + "is_running": True + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is False + assert reason == "tests_running" + + def test_require_no_tests_pass_when_tests_not_running(self): + """Tool should be available when tests are not running.""" + prereq = ToolPrerequisite(require_no_tests=True) + state = { + "tests": { + "is_running": False + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_combined_prerequisites_all_pass(self): + """Tool should be available when all prerequisites are met.""" + prereq = ToolPrerequisite( + require_no_compile=True, + require_selection=True, + require_paused_for_destructive=True + ) + state = { + "advice": {"blocking_reasons": []}, + "editor": { + "selection": {"has_selection": True}, + "play_mode": {"is_playing": False, "is_paused": False} + }, + "tests": {"is_running": False} + } + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + def test_combined_prerequisites_first_blocking_wins(self): + """Should return first blocking reason, not all of them.""" + prereq = ToolPrerequisite( + require_no_compile=True, + require_selection=True + ) + state = { + "advice": {"blocking_reasons": ["compiling"]}, + "editor": { + "selection": {"has_selection": False} + } + } + + is_met, reason = prereq.is_met(state) + + assert is_met is False + # Compilation is checked first, so it should be the blocking reason + assert reason == "compiling" + + def test_missing_editor_section_graceful(self): + """Should handle missing editor section gracefully.""" + prereq = ToolPrerequisite(require_selection=True) + state = { + "advice": {"blocking_reasons": []} + } + + is_met, reason = prereq.is_met(state) + + # Missing selection data means we can't confirm no selection + # Fail-open behavior: tool should be visible + assert is_met is True + assert reason is None + + def test_missing_advice_section_graceful(self): + """Should handle missing advice section gracefully.""" + prereq = ToolPrerequisite(require_no_compile=True) + state = {} + + is_met, reason = prereq.is_met(state) + + assert is_met is True + assert reason is None + + +class TestPrerequisiteDecoratorRegistration: + """Test that decorator properly registers tools.""" + + def test_tool_prerequisites_is_dict(self): + """tool_prerequisites should be a dictionary.""" + assert isinstance(tool_prerequisites, dict) + + def test_decorator_registers_tool(self): + """Decorator should add tool to global registry.""" + # Create a dummy function to decorate + def sample_tool(ctx): + return None + + # Manually register like the decorator does + from core.tool_filter_decorator import _prerequisites_lock + tool_name = "test_sample_tool" + with _prerequisites_lock: + tool_prerequisites[tool_name] = ToolPrerequisite(require_selection=True) + + # Verify it was registered + assert tool_name in tool_prerequisites + assert tool_prerequisites[tool_name].require_selection is True + + # Clean up + with _prerequisites_lock: + del tool_prerequisites[tool_name] + + +class TestConcurrentAccess: + """Test thread-safe concurrent access to tool_prerequisites dictionary.""" + + def test_concurrent_read_with_registration(self): + """Concurrent reads during registration should be thread-safe.""" + import threading + + results = [] + errors = [] + + def register_tools(): + try: + for i in range(5): + tool_name = f"concurrent_tool_{i}" + with patch("core.tool_filter_decorator._prerequisites_lock"): + tool_prerequisites[tool_name] = ToolPrerequisite( + require_selection=(i % 2 == 0) + ) + except Exception as e: + errors.append(e) + + def read_tools(): + try: + for _ in range(10): + # Simulate reads + _ = list(tool_prerequisites.keys()) + except Exception as e: + errors.append(e) + + # Start threads + threads = [ + threading.Thread(target=register_tools), + threading.Thread(target=read_tools), + threading.Thread(target=read_tools), + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + # No errors should occur + assert len(errors) == 0 + + # Clean up + with patch("core.tool_filter_decorator._prerequisites_lock"): + for i in range(5): + tool_prerequisites.pop(f"concurrent_tool_{i}", None) diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs new file mode 100644 index 000000000..3ae8b7e35 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs @@ -0,0 +1,115 @@ +using NUnit.Framework; +using MCPForUnity.Editor.Constants; +using MCPForUnity.Editor.Services; +using UnityEditor; + +namespace MCPForUnity.Editor.Tests.EditMode.Services +{ + [TestFixture] + public class ToolDiscoveryServiceTests + { + private const string TestToolName = "test_tool_for_testing"; + + [SetUp] + public void SetUp() + { + // Clean up any test preferences + string testKey = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + if (EditorPrefs.HasKey(testKey)) + { + EditorPrefs.DeleteKey(testKey); + } + } + + [TearDown] + public void TearDown() + { + // Clean up test preferences after each test + string testKey = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + if (EditorPrefs.HasKey(testKey)) + { + EditorPrefs.DeleteKey(testKey); + } + } + + [Test] + public void SetToolEnabled_WritesToEditorPrefs() + { + // Arrange + var service = new ToolDiscoveryService(); + + // Act + service.SetToolEnabled(TestToolName, false); + + // Assert + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + Assert.IsTrue(EditorPrefs.HasKey(key), "Preference key should exist after SetToolEnabled"); + Assert.IsFalse(EditorPrefs.GetBool(key, true), "Preference should be set to false"); + } + + [Test] + public void IsToolEnabled_ReturnsTrue_WhenNoPreferenceExistsAndToolIsBuiltIn() + { + // Arrange - Ensure no preference exists + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + if (EditorPrefs.HasKey(key)) + { + EditorPrefs.DeleteKey(key); + } + + var service = new ToolDiscoveryService(); + + // Act - For a non-existent tool, IsToolEnabled should return false + // (since metadata.AutoRegister defaults to false for non-existent tools) + bool result = service.IsToolEnabled(TestToolName); + + // Assert - Non-existent tools return false (no metadata found) + Assert.IsFalse(result, "Non-existent tool should return false"); + } + + [Test] + public void IsToolEnabled_ReturnsStoredValue_WhenPreferenceExists() + { + // Arrange + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + EditorPrefs.SetBool(key, false); // Store false value + var service = new ToolDiscoveryService(); + + // Act + bool result = service.IsToolEnabled(TestToolName); + + // Assert + Assert.IsFalse(result, "Should return the stored preference value (false)"); + } + + [Test] + public void IsToolEnabled_ReturnsTrue_WhenPreferenceSetToTrue() + { + // Arrange + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + EditorPrefs.SetBool(key, true); + var service = new ToolDiscoveryService(); + + // Act + bool result = service.IsToolEnabled(TestToolName); + + // Assert + Assert.IsTrue(result, "Should return the stored preference value (true)"); + } + + [Test] + public void ToolToggle_PersistsAcrossServiceInstances() + { + // Arrange + var service1 = new ToolDiscoveryService(); + service1.SetToolEnabled(TestToolName, false); + + // Act - Create a new service instance + var service2 = new ToolDiscoveryService(); + bool result = service2.IsToolEnabled(TestToolName); + + // Assert - The disabled state should persist + Assert.IsFalse(result, "Tool state should persist across service instances"); + } + } +} diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs.meta b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs.meta new file mode 100644 index 000000000..cea36ebb1 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7a8b9c0d1e2f3a4b5c6d7e8f9a0b1c2d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/TestProjects/UnityMCPTests/Assets/Tests/Editor.meta b/TestProjects/UnityMCPTests/Assets/Tests/Editor.meta new file mode 100644 index 000000000..b6e169d0d --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/Editor.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: ea478f98f983e4c4daafd9a07cf03e3b +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/TestProjects/UnityMCPTests/Assets/Tests/Editor/State.meta b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State.meta new file mode 100644 index 000000000..a4a8007ae --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 778e8246ab09497419985acaee595f6c +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/TestProjects/UnityMCPTests/Assets/Tests/Editor/State/EditorStateSelectionTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State/EditorStateSelectionTests.cs new file mode 100644 index 000000000..0ba8ce18c --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State/EditorStateSelectionTests.cs @@ -0,0 +1,340 @@ +using NUnit.Framework; +using UnityEditor; +using UnityEngine; + +namespace MCPForUnity.Editor.Services +{ + /// + /// Tests for EditorStateCache selection tracking and compilation state. + /// Part of the State-Aware Tool Filtering feature (001-state-aware-tool-filtering). + /// + [TestFixture] + public class EditorStateSelectionTests + { + private GameObject _testGameObject; + private const string TestObjectName = "TestSelectionObject"; + + [SetUp] + public void SetUp() + { + // Create a test GameObject for selection tests + _testGameObject = new GameObject(TestObjectName); + } + + [TearDown] + public void TearDown() + { + // Clear selection and clean up + Selection.activeInstanceID = 0; + if (_testGameObject != null) + { + Object.DestroyImmediate(_testGameObject); + } + } + + #region Selection State Tests + + [Test] + public void GetSnapshot_WithNoSelection_HasSelectionIsFalse() + { + // Arrange - Ensure no selection + Selection.activeInstanceID = 0; + + // Act - Get the snapshot (force refresh to get current state) + var snapshot = EditorStateCache.GetSnapshot(forceRefresh: true); + + // Assert - Verify selection state + Assert.IsNotNull(snapshot, "Snapshot should not be null"); + + var editor = snapshot["editor"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(editor, "Editor section should exist"); + + var selection = editor["selection"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(selection, "Selection section should exist"); + + var hasSelection = selection["has_selection"] as Newtonsoft.Json.Linq.JValue; + Assert.IsFalse( + hasSelection != null && (bool)hasSelection, + "has_selection should be false when nothing is selected" + ); + + var selectionCount = selection["selection_count"] as Newtonsoft.Json.Linq.JValue; + Assert.AreEqual( + 0, + selectionCount != null ? (int)selectionCount : 0, + "selection_count should be 0 when nothing is selected" + ); + } + + [Test] + public void GetSnapshot_WithGameObjectSelected_HasSelectionIsTrue() + { + // Arrange - Select the test GameObject + Selection.activeGameObject = _testGameObject; + + // Act - Get the snapshot (force refresh to get current state) + var snapshot = EditorStateCache.GetSnapshot(forceRefresh: true); + + // Assert - Verify selection state + var editor = snapshot["editor"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(editor, "Editor section should exist"); + + var selection = editor["selection"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(selection, "Selection section should exist"); + + var hasSelection = selection["has_selection"] as Newtonsoft.Json.Linq.JValue; + Assert.IsTrue( + hasSelection != null && (bool)hasSelection, + "has_selection should be true when GameObject is selected" + ); + + var gameObjectName = selection["active_game_object_name"] as Newtonsoft.Json.Linq.JValue; + Assert.AreEqual( + TestObjectName, + gameObjectName != null ? (string)gameObjectName : string.Empty, + "active_game_object_name should match selected GameObject" + ); + + var selectionCount = selection["selection_count"] as Newtonsoft.Json.Linq.JValue; + Assert.AreEqual( + 1, + selectionCount != null ? (int)selectionCount : 0, + "selection_count should be 1 when one GameObject is selected" + ); + } + + [Test] + public void GetSnapshot_SelectionChange_TriggerUpdate() + { + // Arrange - Start with no selection + Selection.activeInstanceID = 0; + var initialSnapshot = EditorStateCache.GetSnapshot(forceRefresh: true); + var initialSelection = initialSnapshot["editor"]?["selection"]?["has_selection"] as Newtonsoft.Json.Linq.JValue; + bool initialVal = initialSelection != null && (bool)initialSelection; + + // Act - Select a GameObject + Selection.activeGameObject = _testGameObject; + + // Force an update by getting a new snapshot + // The cache updates automatically via Selection.selectionChanged callback + var updatedSnapshot = EditorStateCache.GetSnapshot(forceRefresh: true); + var updatedSelection = updatedSnapshot["editor"]?["selection"]?["has_selection"] as Newtonsoft.Json.Linq.JValue; + bool updatedVal = updatedSelection != null && (bool)updatedSelection; + + // Assert - Selection state should have changed + Assert.AreNotEqual( + initialVal, + updatedVal, + "Selection state should change after GameObject selection" + ); + } + + [Test] + public void GetSnapshot_ActiveInstanceID_IsCorrect() + { + // Arrange - Select the test GameObject + Selection.activeGameObject = _testGameObject; + int expectedInstanceId = _testGameObject.GetInstanceID(); + + // Act - Get the snapshot (force refresh to get current state) + var snapshot = EditorStateCache.GetSnapshot(forceRefresh: true); + + // Assert - Verify instance ID + var editor = snapshot["editor"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(editor, "Editor section should exist"); + + var selection = editor["selection"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(selection, "Selection section should exist"); + + var instanceIdToken = selection["active_instance_id"] as Newtonsoft.Json.Linq.JValue; + int actualInstanceId = instanceIdToken != null ? (int)instanceIdToken : 0; + + Assert.AreEqual( + expectedInstanceId, + actualInstanceId, + "active_instance_id should match selected GameObject instance ID" + ); + } + + #endregion + + #region Compilation State Tests + + [Test] + public void GetSnapshot_CompilationState_IsTracked() + { + // Act - Get the snapshot + var snapshot = EditorStateCache.GetSnapshot(); + + // Assert - Verify compilation state exists + var compilation = snapshot["compilation"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(compilation, "Compilation section should exist"); + + // Check that compilation state fields are present + Assert.IsTrue( + compilation.ContainsKey("is_compiling"), + "is_compiling field should exist" + ); + + Assert.IsTrue( + compilation.ContainsKey("is_domain_reload_pending"), + "is_domain_reload_pending field should exist" + ); + } + + [Test] + public void GetSnapshot_CompilationState_WhenNotCompiling() + { + // Arrange - Wait for any pending compilation to finish + while (EditorApplication.isCompiling) + { + System.Threading.Thread.Sleep(100); + } + + // Act - Get the snapshot + var snapshot = EditorStateCache.GetSnapshot(); + + // Assert - Verify not compiling + var compilation = snapshot["compilation"] as Newtonsoft.Json.Linq.JObject; + var isCompilingToken = compilation["is_compiling"] as Newtonsoft.Json.Linq.JValue; + bool isCompiling = isCompilingToken != null && (bool)isCompilingToken; + + Assert.IsFalse( + isCompiling, + "is_compiling should be false when not compiling" + ); + } + + #endregion + + #region Blocking Reasons Tests + + [Test] + public void GetSnapshot_Advice_ContainsBlockingReasons() + { + // Act - Get the snapshot + var snapshot = EditorStateCache.GetSnapshot(); + + // Assert - Verify advice section exists + var advice = snapshot["advice"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(advice, "Advice section should exist"); + + Assert.IsTrue( + advice.ContainsKey("blocking_reasons"), + "blocking_reasons field should exist in advice" + ); + + Assert.IsTrue( + advice.ContainsKey("ready_for_tools"), + "ready_for_tools field should exist in advice" + ); + } + + [Test] + public void GetSnapshot_WhenCompiling_CompilingInBlockingReasons() + { + // This test would require triggering compilation, which is complex + // For now, we just verify the structure exists + var snapshot = EditorStateCache.GetSnapshot(); + var advice = snapshot["advice"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(advice, "Advice section should exist"); + + var blockingReasons = advice["blocking_reasons"] as Newtonsoft.Json.Linq.JArray; + Assert.IsNotNull( + blockingReasons, + "blocking_reasons should be a JArray" + ); + } + + #endregion + + #region Play Mode State Tests + + [Test] + public void GetSnapshot_PlayModeState_IsTracked() + { + // Act - Get the snapshot (should be in edit mode for these tests) + var snapshot = EditorStateCache.GetSnapshot(); + + // Assert - Verify play mode state exists + var editor = snapshot["editor"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(editor, "Editor section should exist"); + + var playMode = editor["play_mode"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(playMode, "Play mode section should exist"); + + Assert.IsTrue( + playMode.ContainsKey("is_playing"), + "is_playing field should exist" + ); + + Assert.IsTrue( + playMode.ContainsKey("is_paused"), + "is_paused field should exist" + ); + } + + [Test] + public void GetSnapshot_InEditMode_IsPlayingIsFalse() + { + // Arrange - Ensure we're not in play mode + if (EditorApplication.isPlaying) + { + EditorApplication.isPlaying = false; + } + + // Act - Get the snapshot + var snapshot = EditorStateCache.GetSnapshot(); + + // Assert - Verify not in play mode + var editor = snapshot["editor"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(editor, "Editor section should exist"); + + var playMode = editor["play_mode"] as Newtonsoft.Json.Linq.JObject; + Assert.IsNotNull(playMode, "Play mode section should exist"); + + var isPlayingToken = playMode["is_playing"] as Newtonsoft.Json.Linq.JValue; + bool isPlaying = isPlayingToken != null && (bool)isPlayingToken; + + Assert.IsFalse( + isPlaying, + "is_playing should be false in edit mode" + ); + } + + #endregion + + #region Snapshot Structure Tests + + [Test] + public void GetSnapshot_HasRequiredFields() + { + // Act - Get the snapshot + var snapshot = EditorStateCache.GetSnapshot(); + + // Assert - Verify top-level fields exist + Assert.IsTrue(snapshot.ContainsKey("schema_version"), "schema_version should exist"); + Assert.IsTrue(snapshot.ContainsKey("observed_at_unix_ms"), "observed_at_unix_ms should exist"); + Assert.IsTrue(snapshot.ContainsKey("sequence"), "sequence should exist"); + Assert.IsTrue(snapshot.ContainsKey("editor"), "editor should exist"); + Assert.IsTrue(snapshot.ContainsKey("advice"), "advice should exist"); + } + + [Test] + public void GetSnapshot_ReturnsNewCloneEachTime() + { + // Act - Get two snapshots + var snapshot1 = EditorStateCache.GetSnapshot(); + var snapshot2 = EditorStateCache.GetSnapshot(); + + // Assert - They should be different object instances + Assert.AreNotSame( + snapshot1, + snapshot2, + "GetSnapshot should return a new clone each time" + ); + } + + #endregion + } +} diff --git a/TestProjects/UnityMCPTests/Assets/Tests/Editor/State/EditorStateSelectionTests.cs.meta b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State/EditorStateSelectionTests.cs.meta new file mode 100644 index 000000000..b0df81459 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State/EditorStateSelectionTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 37781283349281c4a88c1686efa6709a +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: