Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,4 @@ workspace/
# Evaluation outputs
eval_outputs/
builds/
results/
59 changes: 52 additions & 7 deletions benchmarks/openagentsafety/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,53 @@
from benchmarks.utils.models import EvalInstance, EvalMetadata, EvalOutput
from openhands.sdk import LLM, Agent, Conversation, get_logger
from openhands.sdk.workspace import RemoteWorkspace
from openhands.tools.preset.default import get_default_tools
from openhands.workspace import DockerWorkspace


logger = get_logger(__name__)


class ServerCompatibleAgent(Agent):
    """Agent whose serialized form is accepted by the OpenHands server.

    The server's API rejects a handful of LLM configuration fields:
    - extra_headers
    - reasoning_summary
    - litellm_extra_body

    ``model_dump`` is overridden to strip those fields from the payload
    before it is sent, and to pin the ``kind`` discriminator to ``"Agent"``.

    TODO: This is a temporary workaround. The proper fix should be in the SDK
    itself, where Agent.model_dump() should have an option to exclude
    server-incompatible fields.
    See: https://github.com/OpenHands/benchmarks/issues/100
    """

    def model_dump(self, **kwargs):
        """Serialize the agent, dropping LLM fields the server rejects."""
        payload = super().model_dump(**kwargs)

        # Strip server-rejected fields from the nested LLM dump, if present.
        llm_payload = payload.get("llm")
        if isinstance(llm_payload, dict):
            forbidden_fields = (
                "extra_headers",
                "reasoning_summary",
                "litellm_extra_body",
            )
            for field in forbidden_fields:
                if field in llm_payload:
                    logger.debug(
                        f"Excluding forbidden field '{field}' from agent LLM serialization"
                    )
                    llm_payload.pop(field)

        # Pin the discriminator to the base class name so the server accepts
        # the payload (the subclass name would otherwise leak through).
        payload["kind"] = "Agent"

        return payload


def convert_numpy_types(obj: Any) -> Any:
"""Recursively convert numpy types to Python native types."""
if isinstance(obj, np.integer):
Expand Down Expand Up @@ -392,13 +432,18 @@ def evaluate_instance(

from pydantic import ValidationError

# Setup tools
tools = get_default_tools(
enable_browser=False,
)
# Use the correct tool names that the server supports
# Server supports: ["BashTool","FileEditorTool","TaskTrackerTool","BrowserToolSet"]
from openhands.sdk import Tool

tools = [
Tool(name="BashTool", params={}),
Tool(name="FileEditorTool", params={}),
Tool(name="TaskTrackerTool", params={}),
]

# Create agent
agent = Agent(llm=self.metadata.llm, tools=tools)
# Create agent with server-compatible serialization
agent = ServerCompatibleAgent(llm=self.metadata.llm, tools=tools)

# Collect events
received_events = []
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,14 @@ members = [
"vendor/software-agent-sdk/openhands-workspace",
"vendor/software-agent-sdk/openhands-agent-server",
]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
"--strict-markers",
"--strict-config",
"-ra",
]
43 changes: 36 additions & 7 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import importlib
import inspect
import json
from contextlib import ExitStack
from pathlib import Path
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -315,16 +316,31 @@ def test_benchmark_metrics_collection(
# Setup benchmark-specific mocks
mock_conversation = _setup_mocks_for_benchmark(benchmark_name, expected_metrics)

# Mock common dependencies to avoid actual LLM calls
with (
# Import the benchmark module to check what functions exist
import importlib

benchmark_module = importlib.import_module(f"benchmarks.{benchmark_name}.run_infer")

# Build list of patches - only patch functions that exist in the module
patches = [
patch(
f"benchmarks.{benchmark_name}.run_infer.Conversation",
return_value=mock_conversation,
),
patch(f"benchmarks.{benchmark_name}.run_infer.Agent"),
patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools"),
patch.dict("os.environ", {"TAVILY_API_KEY": "test-key"}),
):
]

# Only patch get_default_tools if it exists in the module
if hasattr(benchmark_module, "get_default_tools"):
patches.append(
patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools")
)

# Mock common dependencies to avoid actual LLM calls
with ExitStack() as stack:
for patch_obj in patches:
stack.enter_context(patch_obj)
# Add benchmark-specific patches
if benchmark_name == "swe_bench":
with patch(
Expand Down Expand Up @@ -397,15 +413,28 @@ def test_metrics_with_zero_cost(mock_workspace):
# Setup mocks
mock_conversation = _setup_mocks_for_benchmark(benchmark_name, zero_metrics)

with (
# Import the benchmark module to check what functions exist
benchmark_module = importlib.import_module(f"benchmarks.{benchmark_name}.run_infer")

# Build list of patches - only patch functions that exist in the module
patches = [
patch(
f"benchmarks.{benchmark_name}.run_infer.Conversation",
return_value=mock_conversation,
),
patch(f"benchmarks.{benchmark_name}.run_infer.Agent"),
patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools"),
patch.dict("os.environ", {"TAVILY_API_KEY": "test-key"}),
):
]

# Only patch get_default_tools if it exists in the module
if hasattr(benchmark_module, "get_default_tools"):
patches.append(
patch(f"benchmarks.{benchmark_name}.run_infer.get_default_tools")
)

with ExitStack() as stack:
for patch_obj in patches:
stack.enter_context(patch_obj)
if benchmark_name == "swe_bench":
with patch(
f"benchmarks.{benchmark_name}.run_infer.get_instruction",
Expand Down
76 changes: 76 additions & 0 deletions tests/test_openagentsafety_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Test for OpenAgentSafety 422 error fix."""

from pydantic import SecretStr

from benchmarks.openagentsafety.run_infer import ServerCompatibleAgent
from openhands.sdk import LLM, Tool


def test_server_compatible_agent_removes_forbidden_llm_fields():
    """ServerCompatibleAgent.model_dump() must strip server-rejected LLM fields."""
    # An LLM configured with every field the server is known to reject.
    llm = LLM(
        model="test-model",
        api_key=SecretStr("test-key"),
        extra_headers={"X-Custom": "value"},
        reasoning_summary="detailed",
        litellm_extra_body={"custom": "data"},
        temperature=0.7,
    )
    agent = ServerCompatibleAgent(
        llm=llm, tools=[Tool(name="BashTool", params={})]
    )

    # Serialize exactly as the payload would be built for the server.
    dumped = agent.model_dump()
    llm_dump = dumped["llm"]

    # None of the forbidden fields may survive serialization.
    for forbidden in ("extra_headers", "reasoning_summary", "litellm_extra_body"):
        assert forbidden not in llm_dump

    # Legitimate LLM fields pass through untouched.
    assert llm_dump["model"] == "test-model"
    assert llm_dump["temperature"] == 0.7

    # The discriminator must read "Agent" for the server to accept the payload.
    assert dumped["kind"] == "Agent"


def test_server_compatible_agent_with_minimal_llm():
    """Serialization must succeed when the LLM carries no forbidden fields."""
    # Minimal LLM: nothing for the override to strip.
    agent = ServerCompatibleAgent(
        llm=LLM(model="test-model", temperature=0.5),
        tools=[Tool(name="BashTool", params={})],
    )

    # Dump should complete without error and keep the plain fields.
    dumped = agent.model_dump()
    assert dumped["llm"]["model"] == "test-model"
    assert dumped["llm"]["temperature"] == 0.5
    assert dumped["kind"] == "Agent"


def test_server_compatible_agent_preserves_tools():
    """model_dump() must keep the full tool list intact, in order."""
    tool_names = ["BashTool", "FileEditorTool", "TaskTrackerTool"]
    agent = ServerCompatibleAgent(
        llm=LLM(model="test-model"),
        tools=[Tool(name=name, params={}) for name in tool_names],
    )

    # Serialized tools must match the input list one-for-one.
    dumped_tools = agent.model_dump()["tools"]
    assert len(dumped_tools) == 3
    assert [tool["name"] for tool in dumped_tools] == tool_names