From e61b17feea4f52330eade3638fbbc5b23152c776 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 13 Mar 2026 10:07:04 -0600 Subject: [PATCH 1/2] fix: preserve extra_body for LiteLLM to avoid UnsupportedParamsError (#409) TransportKwargs.from_request() flattened extra_body keys into top-level kwargs, causing LiteLLM to reject provider-specific params like reasoning_effort via its per-provider allowlist validation. Add a flatten_extra_body flag (default True for backward compat) so the LiteLLM bridge can opt out and preserve extra_body as a distinct kwarg that LiteLLM forwards without validation. Made-with: Cursor --- .../models/clients/adapters/litellm_bridge.py | 12 +++---- .../engine/models/clients/types.py | 23 ++++++++++-- .../models/clients/test_litellm_bridge.py | 36 +++++++++++++++---- .../engine/models/clients/test_parsing.py | 20 ++++++++++- 4 files changed, 75 insertions(+), 16 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index f5b861b4..017a4af7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -76,7 +76,7 @@ def supports_image_generation(self) -> bool: return True def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - transport = TransportKwargs.from_request(request) + transport = TransportKwargs.from_request(request, flatten_extra_body=False) with _handle_non_provider_errors(self.provider_name): response = self._router.completion( model=request.model, @@ -87,7 +87,7 @@ def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: return parse_chat_completion_response(response) async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - transport = TransportKwargs.from_request(request) + transport = TransportKwargs.from_request(request, flatten_extra_body=False) with _handle_non_provider_errors(self.provider_name): response = await self._router.acompletion( model=request.model, @@ -98,7 +98,7 @@ async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionRes return await aparse_chat_completion_response(response) def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - transport = TransportKwargs.from_request(request) + transport = TransportKwargs.from_request(request, flatten_extra_body=False) with _handle_non_provider_errors(self.provider_name): response = self._router.embedding( model=request.model, @@ -110,7 +110,7 @@ def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - transport = TransportKwargs.from_request(request) + transport = TransportKwargs.from_request(request, flatten_extra_body=False) with _handle_non_provider_errors(self.provider_name): response = await self._router.aembedding( model=request.model, @@ -122,7 +122,7 @@ async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) + transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE, flatten_extra_body=False) with _handle_non_provider_errors(self.provider_name): if request.messages is not None: response = self._router.completion( @@ -148,7 +148,7 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp return ImageGenerationResponse(images=images, usage=usage, raw=response) async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) + transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE, flatten_extra_body=False) with _handle_non_provider_errors(self.provider_name): if request.messages is not None: response = await self._router.acompletion( diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index add9142c..92693736 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -129,11 +129,21 @@ class TransportKwargs: headers: dict[str, str] @classmethod - def from_request(cls, request: Any, *, exclude: frozenset[str] = frozenset()) -> TransportKwargs: + def from_request( + cls, + request: Any, + *, + exclude: frozenset[str] = frozenset(), + # TODO: remove flatten_extra_body after LiteLLMBridgeClient is retired + flatten_extra_body: bool = True, + ) -> TransportKwargs: """Build transport-ready kwargs from a canonical request dataclass. 1. Collects all non-None optional fields (respecting *exclude*). - 2. Pops ``extra_body`` and merges its keys into the top-level body dict. + 2. Handles ``extra_body`` based on *flatten_extra_body*: + - ``True`` (default): merges its keys into the top-level body dict. + - ``False``: preserves it as ``extra_body`` in the body dict so + that callers like LiteLLM can forward it without param validation. 3. Pops ``extra_headers`` into a separate headers dict. """ optional_fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS) @@ -141,7 +151,14 @@ def from_request(cls, request: Any, *, exclude: frozenset[str] = frozenset()) -> extra_body = getattr(request, "extra_body", None) or {} extra_headers = getattr(request, "extra_headers", None) or {} - return cls(body={**optional_fields, **extra_body}, headers=dict(extra_headers)) + if flatten_extra_body: + body = {**optional_fields, **extra_body} + else: + body = {**optional_fields} + if extra_body: + body["extra_body"] = extra_body + + return cls(body=body, headers=dict(extra_headers)) @staticmethod def _collect_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py index 3c8b1793..34e73e6b 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py @@ -57,12 +57,12 @@ def test_completion_maps_canonical_fields_from_litellm_response( mock_router.completion.assert_called_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], - extra_headers={"x-trace": "1"}, tools=[{"type": "function", "function": {"name": "lookup"}}], temperature=0.2, top_p=0.8, max_tokens=256, - foo="bar", + extra_body={"foo": "bar"}, + extra_headers={"x-trace": "1"}, ) @@ -86,6 +86,30 @@ async def test_acompletion_maps_canonical_fields_from_litellm_response( ) +def test_completion_passes_extra_body_as_distinct_kwarg( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + response = _build_chat_response(content="ok", reasoning_content=None, tool_calls=[], usage=None) + mock_router.completion.return_value = response + + request = ChatCompletionRequest( + model="stub-model", + messages=[{"role": "user", "content": "hello"}], + temperature=0.5, + extra_body={"reasoning_effort": "high"}, + ) + bridge_client.completion(request) + + mock_router.completion.assert_called_once_with( + model="stub-model", + messages=[{"role": "user", "content": "hello"}], + temperature=0.5, + extra_body={"reasoning_effort": "high"}, + extra_headers=None, + ) + + def test_embeddings_maps_vectors_and_usage( mock_router: MagicMock, bridge_client: LiteLLMBridgeClient, @@ -107,9 +131,9 @@ def test_embeddings_maps_vectors_and_usage( mock_router.embedding.assert_called_once_with( model="stub-model", input=["a", "b"], - extra_headers=None, encoding_format="float", dimensions=32, + extra_headers=None, ) @@ -148,8 +172,8 @@ def test_generate_image_uses_chat_completion_path_when_messages_provided( mock_router.completion.assert_called_once_with( model="stub-model", messages=messages, + extra_body={"n": 1}, extra_headers=None, - n=1, ) mock_router.image_generation.assert_not_called() @@ -178,7 +202,7 @@ def test_generate_image_uses_diffusion_path_without_messages( assert result.usage.total_tokens == 21 assert result.usage.generated_images == 2 mock_router.image_generation.assert_called_once_with( - prompt="make an image", model="stub-model", extra_headers=None, n=2 + prompt="make an image", model="stub-model", extra_body={"n": 2}, extra_headers=None ) @@ -249,7 +273,7 @@ async def test_agenerate_image_uses_diffusion_path_without_messages( assert result.usage is not None assert result.usage.generated_images == 1 mock_router.aimage_generation.assert_awaited_once_with( - prompt="async image", model="stub-model", extra_headers=None, n=1 + prompt="async image", model="stub-model", extra_body={"n": 1}, extra_headers=None ) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py index a0afd98d..bf7ad6e5 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -13,7 +13,7 @@ TransportKwargs, ) -# --- TransportKwargs.from_request: extra_body flattening --- +# --- TransportKwargs.from_request: extra_body flattening (default) --- def test_extra_body_keys_are_flattened_into_body() -> None: @@ -46,6 +46,24 @@ def test_extra_body_empty_dict_produces_no_extra_keys() -> None: assert "extra_body" not in transport.body +# --- TransportKwargs.from_request: extra_body preserved (opt-in) --- + + +def test_extra_body_preserved_when_flatten_disabled() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + temperature=0.7, + extra_body={"reasoning_effort": "high", "seed": 42}, + ) + transport = TransportKwargs.from_request(request, flatten_extra_body=False) + + assert transport.body["temperature"] == 0.7 + assert transport.body["extra_body"] == {"reasoning_effort": "high", "seed": 42} + assert "reasoning_effort" not in transport.body + assert "seed" not in transport.body + + # --- TransportKwargs.from_request: extra_headers separation --- From 8cfdce38483898d6d0aaac53f7a949be5018dba9 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 13 Mar 2026 13:13:03 -0600 Subject: [PATCH 2/2] fix: address PR #412 review comments Update stale docstrings in TransportKwargs and facade.py to reflect the new flatten_extra_body flag, and add an edge-case test for empty extra_body with flatten disabled. Made-with: Cursor --- .../src/data_designer/engine/models/clients/types.py | 5 +++-- .../src/data_designer/engine/models/facade.py | 5 +++-- .../tests/engine/models/clients/test_parsing.py | 7 +++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 92693736..034170b5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -118,8 +118,9 @@ class TransportKwargs: Adapters call ``TransportKwargs.from_request(request)`` instead of manually handling ``extra_body`` / ``extra_headers`` on every request type. - - ``body``: API-level keyword arguments with ``extra_body`` keys merged - into the top level (mirroring how LiteLLM flattens them). + - ``body``: API-level keyword arguments. By default ``extra_body`` keys are + merged into the top level; pass ``flatten_extra_body=False`` to preserve + ``extra_body`` as a nested dict (needed by LiteLLM). - ``headers``: Extra HTTP headers to attach to the outgoing request. """ diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index e89d9851..d13e96ea 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -50,8 +50,9 @@ def _identity(x: Any) -> Any: # Known keyword arguments extracted into request fields for each modality. # Note: `extra_body` and `extra_headers` appear in every set but receive special # treatment in `consolidate_kwargs` (merged with provider-level overrides) and in -# `TransportKwargs` (extra_body is flattened into the request body, extra_headers -# are forwarded as HTTP headers). They are NOT regular model parameters. +# `TransportKwargs` (extra_body is either flattened into the request body or +# preserved as a nested dict depending on the adapter; extra_headers are +# forwarded as HTTP headers). They are NOT regular model parameters. _COMPLETION_REQUEST_FIELDS = frozenset( { "temperature", diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py index bf7ad6e5..e6662e3f 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -64,6 +64,13 @@ def test_extra_body_preserved_when_flatten_disabled() -> None: assert "seed" not in transport.body +def test_extra_body_empty_dict_not_injected_when_flatten_disabled() -> None: + request = ChatCompletionRequest(model="m", messages=[], extra_body={}) + transport = TransportKwargs.from_request(request, flatten_extra_body=False) + + assert "extra_body" not in transport.body + + # --- TransportKwargs.from_request: extra_headers separation ---