diff --git a/.gitignore b/.gitignore index 459fad58..a123918e 100644 --- a/.gitignore +++ b/.gitignore @@ -217,3 +217,4 @@ workspace/ # Evaluation outputs eval_outputs/ builds/ +results/ diff --git a/benchmarks/openagentsafety/run_infer.py b/benchmarks/openagentsafety/run_infer.py index 5fd09292..8e06ce8f 100644 --- a/benchmarks/openagentsafety/run_infer.py +++ b/benchmarks/openagentsafety/run_infer.py @@ -19,13 +19,53 @@ from benchmarks.utils.models import EvalInstance, EvalMetadata, EvalOutput from openhands.sdk import LLM, Agent, Conversation, get_logger from openhands.sdk.workspace import RemoteWorkspace -from openhands.tools.preset.default import get_default_tools from openhands.workspace import DockerWorkspace logger = get_logger(__name__) +class ServerCompatibleAgent(Agent): + """Agent that excludes forbidden LLM fields during serialization for server compatibility. + + The OpenHands server rejects certain LLM fields that are not accepted in the API: + - extra_headers + - reasoning_summary + - litellm_extra_body + + This agent class overrides model_dump to exclude these fields when serializing. + + TODO: This is a temporary workaround. The proper fix should be in the SDK itself, + where Agent.model_dump() should have an option to exclude server-incompatible fields. + See: https://github.com/OpenHands/benchmarks/issues/100 + """ + + def model_dump(self, **kwargs): + """Override model_dump to exclude forbidden LLM fields.""" + # Get the standard dump + data = super().model_dump(**kwargs) + + # Clean the LLM fields if present + if "llm" in data and isinstance(data["llm"], dict): + forbidden_fields = { + "extra_headers", + "reasoning_summary", + "litellm_extra_body", + } + for field in forbidden_fields: + if field in data["llm"]: + logger.debug( + f"Excluding forbidden field '{field}' from agent LLM serialization" + ) + del data["llm"][field] + + # Ensure the kind field is set to 'Agent' for server compatibility + # (in case the parent class uses the actual class name) + data["kind"] = "Agent" + + return data + + def convert_numpy_types(obj: Any) -> Any: """Recursively convert numpy types to Python native types.""" if isinstance(obj, np.integer): @@ -392,13 +432,18 @@ def evaluate_instance( from pydantic import ValidationError - # Setup tools - tools = get_default_tools( - enable_browser=False, - ) + # Use the correct tool names that the server supports + # Server supports: ["BashTool","FileEditorTool","TaskTrackerTool","BrowserToolSet"] + from openhands.sdk import Tool + + tools = [ + Tool(name="BashTool", params={}), + Tool(name="FileEditorTool", params={}), + Tool(name="TaskTrackerTool", params={}), + ] - # Create agent - agent = Agent(llm=self.metadata.llm, tools=tools) + # Create agent with server-compatible serialization + agent = ServerCompatibleAgent(llm=self.metadata.llm, tools=tools) # Collect events received_events = [] diff --git a/pyproject.toml b/pyproject.toml index a7cf1cbe..345caa5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,3 +82,14 @@ members = [ "vendor/software-agent-sdk/openhands-workspace", "vendor/software-agent-sdk/openhands-agent-server", ] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "--strict-markers", + "--strict-config", + "-ra", +] diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 89585191..1cd7abb4 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -11,6 +11,7 @@ import importlib import inspect import json +from contextlib import ExitStack from pathlib import Path from unittest.mock import MagicMock, patch @@ -315,16 +316,31 @@ def test_benchmark_metrics_collection( # Setup benchmark-specific mocks mock_conversation = _setup_mocks_for_benchmark(benchmark_name, expected_metrics) - # Mock common dependencies to avoid actual LLM calls - with ( + # Import the benchmark module to check what functions exist + import importlib + + benchmark_module = importlib.import_module(f"benchmarks.{benchmark_name}.run_infer") + + # Build list of patches - only patch functions that exist in the module + patches = [ patch( f"benchmarks.{benchmark_name}.run_infer.Conversation", return_value=mock_conversation, ), patch(f"benchmarks.{benchmark_name}.run_infer.Agent"), - patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools"), patch.dict("os.environ", {"TAVILY_API_KEY": "test-key"}), - ): + ] + + # Only patch get_default_tools if it exists in the module + if hasattr(benchmark_module, "get_default_tools"): + patches.append( + patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools") + ) + + # Mock common dependencies to avoid actual LLM calls + with ExitStack() as stack: + for patch_obj in patches: + stack.enter_context(patch_obj) # Add benchmark-specific patches if benchmark_name == "swe_bench": with patch( @@ -397,15 +413,28 @@ def test_metrics_with_zero_cost(mock_workspace): # Setup mocks mock_conversation = _setup_mocks_for_benchmark(benchmark_name, zero_metrics) - with ( + # Import the benchmark module to check what functions exist + benchmark_module = importlib.import_module(f"benchmarks.{benchmark_name}.run_infer") + + # Build list of patches - only patch functions that exist in the module + patches = [ patch( f"benchmarks.{benchmark_name}.run_infer.Conversation", return_value=mock_conversation, ), patch(f"benchmarks.{benchmark_name}.run_infer.Agent"), - patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools"), patch.dict("os.environ", {"TAVILY_API_KEY": "test-key"}), - ): + ] + + # Only patch get_default_tools if it exists in the module + if hasattr(benchmark_module, "get_default_tools"): + patches.append( + patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools") + ) + + with ExitStack() as stack: + for patch_obj in patches: + stack.enter_context(patch_obj) if benchmark_name == "swe_bench": with patch( f"benchmarks.{benchmark_name}.run_infer.get_instruction", diff --git a/tests/test_openagentsafety_fix.py b/tests/test_openagentsafety_fix.py new file mode 100644 index 00000000..75f0f91c --- /dev/null +++ b/tests/test_openagentsafety_fix.py @@ -0,0 +1,76 @@ +"""Test for OpenAgentSafety 422 error fix.""" + +from pydantic import SecretStr + +from benchmarks.openagentsafety.run_infer import ServerCompatibleAgent +from openhands.sdk import LLM, Tool + + +def test_server_compatible_agent_removes_forbidden_llm_fields(): + """Test that ServerCompatibleAgent.model_dump() excludes forbidden LLM fields.""" + # Create an LLM with forbidden fields + llm = LLM( + model="test-model", + api_key=SecretStr("test-key"), + extra_headers={"X-Custom": "value"}, + reasoning_summary="detailed", + litellm_extra_body={"custom": "data"}, + temperature=0.7, + ) + + # Create agent with this LLM + tools = [Tool(name="BashTool", params={})] + agent = ServerCompatibleAgent(llm=llm, tools=tools) + + # Serialize the agent as would be sent to server + agent_data = agent.model_dump() + + # Verify forbidden LLM fields are excluded + assert "extra_headers" not in agent_data["llm"] + assert "reasoning_summary" not in agent_data["llm"] + assert "litellm_extra_body" not in agent_data["llm"] + + # Verify other LLM fields are preserved + assert agent_data["llm"]["model"] == "test-model" + assert agent_data["llm"]["temperature"] == 0.7 + + # Verify the kind field is set to "Agent" for server compatibility + assert agent_data["kind"] == "Agent" + + +def test_server_compatible_agent_with_minimal_llm(): + """Test that the agent works with an LLM without forbidden fields.""" + # Create a minimal LLM + llm = LLM( + model="test-model", + temperature=0.5, + ) + + # Create agent + tools = [Tool(name="BashTool", params={})] + agent = ServerCompatibleAgent(llm=llm, tools=tools) + + # Verify it serializes without errors + agent_data = agent.model_dump() + assert agent_data["llm"]["model"] == "test-model" + assert agent_data["llm"]["temperature"] == 0.5 + assert agent_data["kind"] == "Agent" + + +def test_server_compatible_agent_preserves_tools(): + """Test that tools are properly preserved in serialization.""" + # Create agent with multiple tools + llm = LLM(model="test-model") + tools = [ + Tool(name="BashTool", params={}), + Tool(name="FileEditorTool", params={}), + Tool(name="TaskTrackerTool", params={}), + ] + agent = ServerCompatibleAgent(llm=llm, tools=tools) + + # Serialize and verify tools are preserved + agent_data = agent.model_dump() + assert len(agent_data["tools"]) == 3 + assert agent_data["tools"][0]["name"] == "BashTool" + assert agent_data["tools"][1]["name"] == "FileEditorTool" + assert agent_data["tools"][2]["name"] == "TaskTrackerTool"