Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/model-providers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ from openarmature.llm import OpenAIProvider, UserMessage

async def main() -> None:
provider = OpenAIProvider(
base_url="http://localhost:8000/v1", # any OpenAI-compatible endpoint
base_url="http://localhost:8000", # any OpenAI-compatible endpoint; host root only, /v1 added by provider
model="some-model",
api_key="optional-for-local-servers",
)
Expand Down
56 changes: 55 additions & 1 deletion src/openarmature/llm/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import uuid
from collections.abc import Sequence
from typing import Any, Literal, cast
from urllib.parse import urlparse

import httpx
import jsonschema
Expand Down Expand Up @@ -111,6 +112,19 @@ class OpenAIProvider:
drives the conformance fixtures by intercepting HTTP calls and
returning canned responses, exercising the same wire-mapping
code production traffic would.

**``base_url`` shape.** Pass the host root only — e.g.
``"https://api.openai.com"`` or ``"http://localhost:8000"``. The
provider appends ``/v1/chat/completions`` and ``/v1/models``
itself. A trailing ``/v1`` on ``base_url`` raises ``ValueError``:
httpx joins paths by appending, so a ``/v1`` suffix on
``base_url`` would produce a doubled ``/v1/v1/...`` wire path that
silently 404/405s on most backends (some — like Bifrost — return
200 for ``GET /v1/v1/models`` while rejecting ``POST
/v1/v1/chat/completions``, leaving the readiness probe green and
every completion broken). Trailing slashes are stripped; other
non-empty paths (proxy prefixes like ``/api/openai-proxy``) are
left intact for intentional proxy setups.
"""

def __init__(
Expand All @@ -124,7 +138,7 @@ def __init__(
force_prompt_augmentation_fallback: bool = False,
genai_system: str = "openai",
) -> None:
self.base_url = base_url.rstrip("/")
self.base_url = _validate_and_normalize_base_url(base_url)
self.model = model
# ``force_prompt_augmentation_fallback`` switches structured-output
# calls from the native response_format wire path to the
Expand Down Expand Up @@ -589,6 +603,46 @@ def _parse_response(
)


# ---------------------------------------------------------------------------
# base_url validation
# ---------------------------------------------------------------------------


# Rejects base_urls that end in /v1 or /v1/ because httpx joins paths by
# appending — a base_url with a trailing /v1 produces a doubled /v1/v1/...
# wire path. The failure mode is sneaky: some backends (Bifrost was the
# motivating case) return 200 for GET /v1/v1/models while rejecting POST
# /v1/v1/chat/completions, so the readiness probe stays green while every
# completion fails. Strict rejection is safer than silent strip — it keeps
# the bug visible at construction time.
def _validate_and_normalize_base_url(base_url: str) -> str:
"""Validate ``base_url`` and return its normalized form.

Strips trailing slashes. Raises :class:`ValueError` when the path
component ends in ``/v1`` (with or without a trailing slash) — the
provider appends ``/v1/`` segments itself, so a base_url with a
``/v1`` suffix would produce a doubled path on the wire. Other
non-empty paths (e.g., proxy prefixes like ``/api/openai-proxy``)
are left intact.
"""
normalized = base_url.rstrip("/")
# ``rstrip`` on the full URL is a no-op when a query string or
# fragment follows the path (e.g., ``https://host/v1/?token=...``
# ends in ``c`` so the URL-level rstrip leaves the parsed path's
# trailing slash intact). Strip the parsed path itself so the
# suffix check catches those shapes too.
path = urlparse(normalized).path.rstrip("/")
if path == "/v1" or path.endswith("/v1"):
raise ValueError(
f"OpenAIProvider base_url must not end with '/v1' — the provider "
f"appends '/v1/chat/completions' and '/v1/models' itself, and "
f"httpx would produce a doubled '/v1/v1/...' wire path. Pass the "
f"host root instead (e.g., 'https://api.openai.com'). "
f"Got: {base_url!r}"
)
return normalized


# ---------------------------------------------------------------------------
# Wire-format helpers
# ---------------------------------------------------------------------------
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,76 @@ def test_validate_tools_duplicate_names_rejected() -> None:
)


# ---------------------------------------------------------------------------
# OpenAIProvider base_url validation
# ---------------------------------------------------------------------------


def test_openai_provider_rejects_v1_suffix() -> None:
    """Constructing the provider with a bare ``/v1`` path suffix raises ``ValueError``."""
    url = "http://localhost:8090/v1"
    with pytest.raises(ValueError, match=r"base_url must not end with '/v1'"):
        OpenAIProvider(base_url=url, model="m", api_key="k")


def test_openai_provider_rejects_v1_suffix_with_trailing_slash() -> None:
    """A ``/v1/`` suffix (slash after ``v1``) is rejected just like ``/v1``."""
    url = "http://localhost:8090/v1/"
    with pytest.raises(ValueError, match=r"base_url must not end with '/v1'"):
        OpenAIProvider(base_url=url, model="m", api_key="k")


def test_openai_provider_rejects_openai_cloud_with_v1() -> None:
    """The OpenAI cloud endpoint written with its ``/v1`` suffix is rejected."""
    # OpenAI's own docs give api.openai.com/v1 as the API endpoint, so
    # this is the shape users are most likely to paste in; for
    # OpenAIProvider's base_url the /v1 has to be left off.
    documented_endpoint = "https://api.openai.com/v1"
    with pytest.raises(ValueError, match=r"base_url must not end with '/v1'"):
        OpenAIProvider(base_url=documented_endpoint, model="gpt-4", api_key="k")


def test_openai_provider_accepts_host_root() -> None:
    """A host-root URL is accepted and stored unchanged."""
    host_root = "https://api.openai.com"
    provider = OpenAIProvider(base_url=host_root, model="gpt-4", api_key="k")
    assert provider.base_url == host_root


def test_openai_provider_accepts_host_root_with_trailing_slash() -> None:
    """A trailing slash on a host-root URL is stripped from the stored value."""
    expected = "http://localhost:8090"
    provider = OpenAIProvider(base_url="http://localhost:8090/", model="m", api_key="k")
    assert provider.base_url == expected


def test_openai_provider_accepts_non_v1_path() -> None:
    """A path that does not end in ``/v1`` is kept exactly as given."""
    # Proxy prefixes (Cloudflare AI Gateway, internal reverse proxies)
    # are deliberate, so the provider must not touch them.
    proxy_url = "https://gateway.example.com/openai-proxy"
    provider = OpenAIProvider(base_url=proxy_url, model="m", api_key="k")
    assert provider.base_url == proxy_url


def test_openai_provider_accepts_v1_in_middle_of_path() -> None:
    """A ``/v1`` segment that is not the final path segment is allowed."""
    # The rejection applies to a trailing /v1 only; a proxy route with
    # /v1 somewhere mid-path is treated as intentional.
    proxy_url = "https://gateway.example.com/v1/openai-proxy"
    provider = OpenAIProvider(base_url=proxy_url, model="m", api_key="k")
    assert provider.base_url == proxy_url


def test_openai_provider_rejects_v1_with_query_string() -> None:
    """A ``/v1/`` path followed by a query string is still rejected."""
    # With a query string after the path, the URL no longer ends in
    # "/", so a URL-level rstrip("/") leaves the path's trailing slash
    # in place. The parsed path's own trailing slash MUST be stripped
    # before the suffix check, or this shape slips through.
    url = "https://host/v1/?token=abc"
    with pytest.raises(ValueError, match=r"base_url must not end with '/v1'"):
        OpenAIProvider(base_url=url, model="m", api_key="k")


def test_openai_provider_rejects_v1_with_fragment() -> None:
    """A ``/v1/`` path followed by a URL fragment is still rejected."""
    # Same shape as a path followed by a query string, except the
    # trailing component here is a fragment.
    url = "https://host/v1/#frag"
    with pytest.raises(ValueError, match=r"base_url must not end with '/v1'"):
        OpenAIProvider(base_url=url, model="m", api_key="k")


# ---------------------------------------------------------------------------
# Error categories — canonical string contract + __cause__ preservation
# ---------------------------------------------------------------------------
Expand Down