diff --git a/openhands-agent-server/openhands/agent_server/conversation_service.py b/openhands-agent-server/openhands/agent_server/conversation_service.py index 611f0dbc51..9d5fbba60e 100644 --- a/openhands-agent-server/openhands/agent_server/conversation_service.py +++ b/openhands-agent-server/openhands/agent_server/conversation_service.py @@ -189,6 +189,32 @@ async def start_conversation( ) return conversation_info, False + # Dynamically register tools from client's registry + if request.tool_module_qualnames: + import importlib + + for tool_name, module_qualname in request.tool_module_qualnames.items(): + try: + # Import the module to trigger tool auto-registration + importlib.import_module(module_qualname) + logger.debug( + f"Tool '{tool_name}' registered via module '{module_qualname}'" + ) + except ImportError as e: + logger.warning( + f"Failed to import module '{module_qualname}' for tool " + f"'{tool_name}': {e}. Tool will not be available." + ) + # Continue even if some tools fail to register + # The agent will fail gracefully if it tries to use unregistered + # tools + if request.tool_module_qualnames: + logger.info( + f"Dynamically registered {len(request.tool_module_qualnames)} " + f"tools for conversation {conversation_id}: " + f"{list(request.tool_module_qualnames.keys())}" + ) + stored = StoredConversation(id=conversation_id, **request.model_dump()) event_service = await self._start_event_service(stored) initial_message = request.initial_message diff --git a/openhands-agent-server/openhands/agent_server/models.py b/openhands-agent-server/openhands/agent_server/models.py index a19080f5e9..6258c95dc5 100644 --- a/openhands-agent-server/openhands/agent_server/models.py +++ b/openhands-agent-server/openhands/agent_server/models.py @@ -97,6 +97,14 @@ class StartConversationRequest(BaseModel): default_factory=dict, description="Secrets available in the conversation", ) + tool_module_qualnames: dict[str, str] = Field( + default_factory=dict, + description=( + "Mapping of tool names to their module qualnames from the client's " + "registry. These modules will be dynamically imported on the server " + "to register the tools for this conversation." + ), + ) class StoredConversation(StartConversationRequest): diff --git a/openhands-agent-server/openhands/agent_server/tool_router.py b/openhands-agent-server/openhands/agent_server/tool_router.py index b37a32e9e8..2583afd296 100644 --- a/openhands-agent-server/openhands/agent_server/tool_router.py +++ b/openhands-agent-server/openhands/agent_server/tool_router.py @@ -4,12 +4,13 @@ from openhands.sdk.tool.registry import list_registered_tools from openhands.tools.preset.default import register_default_tools -from openhands.tools.preset.planning import register_planning_tools tool_router = APIRouter(prefix="/tools", tags=["Tools"]) +# Register default tools for backward compatibility +# Planning tools and other custom tools are now dynamically registered +# when creating a RemoteConversation register_default_tools(enable_browser=True) -register_planning_tools() # Tool listing diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py index 9b53ee3ea7..fbd15d81ca 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py @@ -451,6 +451,9 @@ def __init__( self._client = workspace.client if conversation_id is None: + # Import here to avoid circular imports + from openhands.sdk.tool.registry import get_tool_module_qualnames + payload = { "agent": agent.model_dump( mode="json", context={"expose_secrets": True} @@ -462,6 +465,8 @@ def __init__( "workspace": LocalWorkspace( working_dir=self.workspace.working_dir ).model_dump(), + # Include tool module qualnames for dynamic registration on server + "tool_module_qualnames": get_tool_module_qualnames(), } resp = _send_request( self._client, "POST", "/api/conversations", json=payload diff --git a/openhands-sdk/openhands/sdk/tool/registry.py b/openhands-sdk/openhands/sdk/tool/registry.py index 0095b2b770..2d1942c66e 100644 --- a/openhands-sdk/openhands/sdk/tool/registry.py +++ b/openhands-sdk/openhands/sdk/tool/registry.py @@ -29,6 +29,7 @@ _LOCK = RLock() _REG: dict[str, Resolver] = {} +_MODULE_QUALNAMES: dict[str, str] = {} # Maps tool name to module qualname def _resolver_from_instance(name: str, tool: ToolDefinition) -> Resolver: @@ -137,11 +138,22 @@ def register_tool( "(3) a callable factory returning a Sequence[ToolDefinition]" ) + # Track the module qualname for this tool + module_qualname = None + if isinstance(factory, type): + module_qualname = factory.__module__ + elif callable(factory): + module_qualname = getattr(factory, "__module__", None) + elif isinstance(factory, ToolDefinition): + module_qualname = factory.__class__.__module__ + with _LOCK: # TODO: throw exception when registering duplicate name tools if name in _REG: logger.warning(f"Duplicate tool name registerd {name}") _REG[name] = resolver + if module_qualname: + _MODULE_QUALNAMES[name] = module_qualname def resolve_tool( @@ -159,3 +171,14 @@ def resolve_tool( def list_registered_tools() -> list[str]: with _LOCK: return list(_REG.keys()) + + +def get_tool_module_qualnames() -> dict[str, str]: + """Get a mapping of tool names to their module qualnames. + + Returns: + A dictionary mapping tool names to module qualnames (e.g., + {"glob": "openhands.tools.glob.definition"}). + """ + with _LOCK: + return dict(_MODULE_QUALNAMES) diff --git a/tests/agent_server/test_conversation_router.py b/tests/agent_server/test_conversation_router.py index dc3106083c..9d332e35c2 100644 --- a/tests/agent_server/test_conversation_router.py +++ b/tests/agent_server/test_conversation_router.py @@ -1169,3 +1169,109 @@ def test_generate_conversation_title_invalid_params( assert response.status_code == 422 # Validation error finally: client.app.dependency_overrides.clear() + + +def test_start_conversation_with_tool_module_qualnames( + client, mock_conversation_service, sample_conversation_info +): + """Test start_conversation endpoint with tool_module_qualnames field.""" + + # Mock the service response + mock_conversation_service.start_conversation.return_value = ( + sample_conversation_info, + True, + ) + + # Override the dependency + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + request_data = { + "agent": { + "llm": { + "model": "gpt-4o", + "api_key": "test-key", + "usage_id": "test-llm", + }, + "tools": [ + {"name": "glob"}, + {"name": "grep"}, + {"name": "planning_file_editor"}, + ], + }, + "workspace": {"working_dir": "/tmp/test"}, + "tool_module_qualnames": { + "glob": "openhands.tools.glob.definition", + "grep": "openhands.tools.grep.definition", + "planning_file_editor": ( + "openhands.tools.planning_file_editor.definition" + ), + }, + } + + response = client.post("/api/conversations", json=request_data) + + assert response.status_code == 201 + data = response.json() + assert data["id"] == str(sample_conversation_info.id) + + # Verify service was called + mock_conversation_service.start_conversation.assert_called_once() + call_args = mock_conversation_service.start_conversation.call_args + request_arg = call_args[0][0] + assert hasattr(request_arg, "tool_module_qualnames") + assert request_arg.tool_module_qualnames == { + "glob": "openhands.tools.glob.definition", + "grep": "openhands.tools.grep.definition", + "planning_file_editor": ("openhands.tools.planning_file_editor.definition"), + } + finally: + client.app.dependency_overrides.clear() + + +def test_start_conversation_without_tool_module_qualnames( + client, mock_conversation_service, sample_conversation_info +): + """Test start_conversation endpoint without tool_module_qualnames field.""" + + # Mock the service response + mock_conversation_service.start_conversation.return_value = ( + sample_conversation_info, + True, + ) + + # Override the dependency + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + request_data = { + "agent": { + "llm": { + "model": "gpt-4o", + "api_key": "test-key", + "usage_id": "test-llm", + }, + "tools": [{"name": "TerminalTool"}], + }, + "workspace": {"working_dir": "/tmp/test"}, + } + + response = client.post("/api/conversations", json=request_data) + + assert response.status_code == 201 + data = response.json() + assert data["id"] == str(sample_conversation_info.id) + + # Verify service was called + mock_conversation_service.start_conversation.assert_called_once() + call_args = mock_conversation_service.start_conversation.call_args + request_arg = call_args[0][0] + assert hasattr(request_arg, "tool_module_qualnames") + # Should default to empty dict + assert request_arg.tool_module_qualnames == {} + finally: + client.app.dependency_overrides.clear() diff --git a/tests/sdk/test_registry_qualnames.py b/tests/sdk/test_registry_qualnames.py new file mode 100644 index 0000000000..2d691c3ec6 --- /dev/null +++ b/tests/sdk/test_registry_qualnames.py @@ -0,0 +1,68 @@ +"""Tests for tool registry module qualname tracking.""" + +from openhands.sdk.tool.registry import ( + get_tool_module_qualnames, + list_registered_tools, + register_tool, +) + + +def test_get_tool_module_qualnames_with_class(): + """Test that module qualnames are tracked when registering a class.""" + from openhands.tools.glob import GlobTool + + # Register the GlobTool class + register_tool("test_glob_class", GlobTool) + + # Get the module qualnames + qualnames = get_tool_module_qualnames() + + # Verify the tool is tracked with its module + assert "test_glob_class" in qualnames + assert qualnames["test_glob_class"] == "openhands.tools.glob.definition" + + +def test_get_tool_module_qualnames_with_callable(): + """Test that module qualnames are tracked when registering a callable.""" + + def test_factory(conv_state): + return [] + + # Register the callable + register_tool("test_callable", test_factory) + + # Get the module qualnames + qualnames = get_tool_module_qualnames() + + # Verify the tool is tracked with its module + assert "test_callable" in qualnames + assert "test_registry_qualnames" in qualnames["test_callable"] + + +def test_get_tool_module_qualnames_after_import(): + """Test that importing a tool module registers it with qualname.""" + # Import glob tool module to trigger auto-registration + import openhands.tools.glob.definition # noqa: F401 + + # Get registered tools + registered_tools = list_registered_tools() + + # Should have glob registered + assert "glob" in registered_tools + + # Get module qualnames + qualnames = get_tool_module_qualnames() + + # Verify glob has its module qualname tracked + if "glob" in qualnames: + assert qualnames["glob"] == "openhands.tools.glob.definition" + + +def test_get_tool_module_qualnames_returns_copy(): + """Test that get_tool_module_qualnames returns a copy, not the original dict.""" + qualnames1 = get_tool_module_qualnames() + qualnames2 = get_tool_module_qualnames() + + # Should be equal but not the same object + assert qualnames1 == qualnames2 + assert qualnames1 is not qualnames2