diff --git a/openhands-agent-server/openhands/agent_server/api.py b/openhands-agent-server/openhands/agent_server/api.py index 0e8eb9acb3..0b10848dfa 100644 --- a/openhands-agent-server/openhands/agent_server/api.py +++ b/openhands-agent-server/openhands/agent_server/api.py @@ -29,11 +29,15 @@ ) from openhands.agent_server.sockets import sockets_router from openhands.agent_server.tool_router import tool_router +from openhands.agent_server.utils import patch_fastapi_discriminated_union_support from openhands.agent_server.vscode_router import vscode_router from openhands.agent_server.vscode_service import get_vscode_service from openhands.sdk.logger import DEBUG, get_logger +# Apply FastAPI patch for discriminated union support +patch_fastapi_discriminated_union_support() + logger = get_logger(__name__) diff --git a/openhands-agent-server/openhands/agent_server/utils.py b/openhands-agent-server/openhands/agent_server/utils.py index b95138403b..b874ef5f07 100644 --- a/openhands-agent-server/openhands/agent_server/utils.py +++ b/openhands-agent-server/openhands/agent_server/utils.py @@ -156,16 +156,7 @@ def patch_fastapi_discriminated_union_support(): Also extracts inline discriminated unions as separate schema components for better OpenAPI documentation and Swagger UI display. - - Skips patching if SKIP_FASTAPI_DISCRIMINATED_UNION_FIX environment variable is set. """ - # Skip patching if environment variable flag is defined - if os.environ.get("SKIP_FASTAPI_DISCRIMINATED_UNION_FIX"): - logger.debug( - "Skipping FastAPI discriminated union patch due to environment variable" - ) - return - try: import fastapi._compat.v2 as fastapi_v2 from fastapi import FastAPI diff --git a/openhands-sdk/openhands/sdk/utils/models.py b/openhands-sdk/openhands/sdk/utils/models.py index 39e520a29d..4e053d2350 100644 --- a/openhands-sdk/openhands/sdk/utils/models.py +++ b/openhands-sdk/openhands/sdk/utils/models.py @@ -12,8 +12,6 @@ TypeAdapter, ) -from openhands.agent_server.utils import patch_fastapi_discriminated_union_support - logger = logging.getLogger(__name__) _rebuild_required = True @@ -302,7 +300,3 @@ def __init_subclass__(cls, **kwargs): def _rebuild_if_required(): if _rebuild_required: rebuild_all() - - -# Always call the FastAPI patch after DiscriminatedUnionMixin definition -patch_fastapi_discriminated_union_support()