Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def supports_image_generation(self) -> bool:
return True

def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
transport = TransportKwargs.from_request(request)
transport = TransportKwargs.from_request(request, flatten_extra_body=False)
with _handle_non_provider_errors(self.provider_name):
response = self._router.completion(
model=request.model,
Expand All @@ -87,7 +87,7 @@ def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
return parse_chat_completion_response(response)

async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
transport = TransportKwargs.from_request(request)
transport = TransportKwargs.from_request(request, flatten_extra_body=False)
with _handle_non_provider_errors(self.provider_name):
response = await self._router.acompletion(
model=request.model,
Expand All @@ -98,7 +98,7 @@ async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionRes
return await aparse_chat_completion_response(response)

def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
transport = TransportKwargs.from_request(request)
transport = TransportKwargs.from_request(request, flatten_extra_body=False)
with _handle_non_provider_errors(self.provider_name):
response = self._router.embedding(
model=request.model,
Expand All @@ -110,7 +110,7 @@ def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response)

async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
transport = TransportKwargs.from_request(request)
transport = TransportKwargs.from_request(request, flatten_extra_body=False)
with _handle_non_provider_errors(self.provider_name):
response = await self._router.aembedding(
model=request.model,
Expand All @@ -122,7 +122,7 @@ async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response)

def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE, flatten_extra_body=False)
with _handle_non_provider_errors(self.provider_name):
if request.messages is not None:
response = self._router.completion(
Expand All @@ -148,7 +148,7 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp
return ImageGenerationResponse(images=images, usage=usage, raw=response)

async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE, flatten_extra_body=False)
with _handle_non_provider_errors(self.provider_name):
if request.messages is not None:
response = await self._router.acompletion(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ class TransportKwargs:
Adapters call ``TransportKwargs.from_request(request)`` instead of
manually handling ``extra_body`` / ``extra_headers`` on every request type.

- ``body``: API-level keyword arguments with ``extra_body`` keys merged
into the top level (mirroring how LiteLLM flattens them).
- ``body``: API-level keyword arguments. By default ``extra_body`` keys are
merged into the top level; pass ``flatten_extra_body=False`` to preserve
``extra_body`` as a nested dict (needed by LiteLLM).
- ``headers``: Extra HTTP headers to attach to the outgoing request.
"""

Expand All @@ -129,19 +130,36 @@ class TransportKwargs:
headers: dict[str, str]

@classmethod
def from_request(cls, request: Any, *, exclude: frozenset[str] = frozenset()) -> TransportKwargs:
def from_request(
cls,
request: Any,
*,
exclude: frozenset[str] = frozenset(),
# TODO: remove flatten_extra_body after LiteLLMBridgeClient is retired
flatten_extra_body: bool = True,
) -> TransportKwargs:
"""Build transport-ready kwargs from a canonical request dataclass.

1. Collects all non-None optional fields (respecting *exclude*).
2. Pops ``extra_body`` and merges its keys into the top-level body dict.
2. Handles ``extra_body`` based on *flatten_extra_body*:
- ``True`` (default): merges its keys into the top-level body dict.
- ``False``: preserves it as ``extra_body`` in the body dict so
that callers like LiteLLM can forward it without param validation.
3. Pops ``extra_headers`` into a separate headers dict.
"""
optional_fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS)

extra_body = getattr(request, "extra_body", None) or {}
extra_headers = getattr(request, "extra_headers", None) or {}

return cls(body={**optional_fields, **extra_body}, headers=dict(extra_headers))
if flatten_extra_body:
body = {**optional_fields, **extra_body}
else:
body = {**optional_fields}
if extra_body:
body["extra_body"] = extra_body

return cls(body=body, headers=dict(extra_headers))

@staticmethod
def _collect_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def _identity(x: Any) -> Any:
# Known keyword arguments extracted into request fields for each modality.
# Note: `extra_body` and `extra_headers` appear in every set but receive special
# treatment in `consolidate_kwargs` (merged with provider-level overrides) and in
# `TransportKwargs` (extra_body is flattened into the request body, extra_headers
# are forwarded as HTTP headers). They are NOT regular model parameters.
# `TransportKwargs` (extra_body is either flattened into the request body or
# preserved as a nested dict depending on the adapter; extra_headers are
# forwarded as HTTP headers). They are NOT regular model parameters.
_COMPLETION_REQUEST_FIELDS = frozenset(
{
"temperature",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ def test_completion_maps_canonical_fields_from_litellm_response(
mock_router.completion.assert_called_once_with(
model="stub-model",
messages=[{"role": "user", "content": "hello"}],
extra_headers={"x-trace": "1"},
tools=[{"type": "function", "function": {"name": "lookup"}}],
temperature=0.2,
top_p=0.8,
max_tokens=256,
foo="bar",
extra_body={"foo": "bar"},
extra_headers={"x-trace": "1"},
)


Expand All @@ -86,6 +86,30 @@ async def test_acompletion_maps_canonical_fields_from_litellm_response(
)


def test_completion_passes_extra_body_as_distinct_kwarg(
mock_router: MagicMock,
bridge_client: LiteLLMBridgeClient,
) -> None:
response = _build_chat_response(content="ok", reasoning_content=None, tool_calls=[], usage=None)
mock_router.completion.return_value = response

request = ChatCompletionRequest(
model="stub-model",
messages=[{"role": "user", "content": "hello"}],
temperature=0.5,
extra_body={"reasoning_effort": "high"},
)
bridge_client.completion(request)

mock_router.completion.assert_called_once_with(
model="stub-model",
messages=[{"role": "user", "content": "hello"}],
temperature=0.5,
extra_body={"reasoning_effort": "high"},
extra_headers=None,
)


def test_embeddings_maps_vectors_and_usage(
mock_router: MagicMock,
bridge_client: LiteLLMBridgeClient,
Expand All @@ -107,9 +131,9 @@ def test_embeddings_maps_vectors_and_usage(
mock_router.embedding.assert_called_once_with(
model="stub-model",
input=["a", "b"],
extra_headers=None,
encoding_format="float",
dimensions=32,
extra_headers=None,
)


Expand Down Expand Up @@ -148,8 +172,8 @@ def test_generate_image_uses_chat_completion_path_when_messages_provided(
mock_router.completion.assert_called_once_with(
model="stub-model",
messages=messages,
extra_body={"n": 1},
extra_headers=None,
n=1,
)
mock_router.image_generation.assert_not_called()

Expand Down Expand Up @@ -178,7 +202,7 @@ def test_generate_image_uses_diffusion_path_without_messages(
assert result.usage.total_tokens == 21
assert result.usage.generated_images == 2
mock_router.image_generation.assert_called_once_with(
prompt="make an image", model="stub-model", extra_headers=None, n=2
prompt="make an image", model="stub-model", extra_body={"n": 2}, extra_headers=None
)


Expand Down Expand Up @@ -249,7 +273,7 @@ async def test_agenerate_image_uses_diffusion_path_without_messages(
assert result.usage is not None
assert result.usage.generated_images == 1
mock_router.aimage_generation.assert_awaited_once_with(
prompt="async image", model="stub-model", extra_headers=None, n=1
prompt="async image", model="stub-model", extra_body={"n": 1}, extra_headers=None
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
TransportKwargs,
)

# --- TransportKwargs.from_request: extra_body flattening ---
# --- TransportKwargs.from_request: extra_body flattening (default) ---


def test_extra_body_keys_are_flattened_into_body() -> None:
Expand Down Expand Up @@ -46,6 +46,31 @@ def test_extra_body_empty_dict_produces_no_extra_keys() -> None:
assert "extra_body" not in transport.body


# --- TransportKwargs.from_request: extra_body preserved (opt-in) ---


def test_extra_body_preserved_when_flatten_disabled() -> None:
request = ChatCompletionRequest(
model="m",
messages=[],
temperature=0.7,
extra_body={"reasoning_effort": "high", "seed": 42},
)
transport = TransportKwargs.from_request(request, flatten_extra_body=False)

assert transport.body["temperature"] == 0.7
assert transport.body["extra_body"] == {"reasoning_effort": "high", "seed": 42}
assert "reasoning_effort" not in transport.body
assert "seed" not in transport.body


def test_extra_body_empty_dict_not_injected_when_flatten_disabled() -> None:
request = ChatCompletionRequest(model="m", messages=[], extra_body={})
transport = TransportKwargs.from_request(request, flatten_extra_body=False)

assert "extra_body" not in transport.body


# --- TransportKwargs.from_request: extra_headers separation ---


Expand Down
Loading