diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py index c78fe5c43d..467b1ff253 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/base1.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/base1.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from importlib import import_module from typing import TYPE_CHECKING +from urllib.parse import urlparse if TYPE_CHECKING: from typing import Any @@ -372,6 +373,16 @@ def _is_openai_reasoning_model(model: str) -> bool: return bool(_OPENAI_REASONING_MODEL_PATTERN.match(model.lower())) +def _is_openai_api_endpoint(api_base: str | None) -> bool: + """Return True when api_base points at OpenAI's own HTTPS API host.""" + if not api_base: + return False + parsed = urlparse(api_base) + return ( + parsed.scheme == "https" and (parsed.hostname or "").lower() == "api.openai.com" + ) + + class OpenAICompatibleLLMParameters(BaseChatCompletionParameters): """See https://docs.litellm.ai/docs/providers/openai_compatible/.""" @@ -463,6 +474,14 @@ def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]: else: validated.pop("reasoning_effort", None) + # The custom_openai/ prefix has no entry in the cost price map. Strip + # it only for OpenAI's own endpoint; other gateways price the same + # model name differently, so leave their cost unresolved. 
def test_openai_compatible_validate_sets_cost_model_for_openai_endpoint() -> None:
    # With api_base on OpenAI's endpoint, the model keeps its provider prefix
    # while cost_model carries the bare name.
    adapter_metadata = {
        "api_base": "https://api.openai.com/v1",
        "api_key": "test-key",
        "model": "gpt-4o",
    }

    result = OpenAICompatibleLLMParameters.validate(adapter_metadata)

    assert result["model"] == "custom_openai/gpt-4o"
    assert result["cost_model"] == "gpt-4o"


def test_openai_compatible_validate_cost_model_keeps_openai_subprefix() -> None:
    # Only the leading custom_openai/ segment is removed; a nested openai/
    # segment stays in cost_model.
    adapter_metadata = {
        "api_base": "https://api.openai.com/v1",
        "model": "custom_openai/openai/gpt-4o",
    }

    result = OpenAICompatibleLLMParameters.validate(adapter_metadata)

    assert result["cost_model"] == "openai/gpt-4o"


def test_openai_compatible_validate_no_cost_model_for_other_gateway() -> None:
    # A non-OpenAI api_base must not produce a cost_model entry at all.
    adapter_metadata = {
        "api_base": "https://gateway.example.com/v1",
        "model": "gpt-4o",
    }

    result = OpenAICompatibleLLMParameters.validate(adapter_metadata)

    assert "cost_model" not in result


def test_openai_compatible_validate_cost_model_stable_on_revalidation() -> None:
    # Feeding validate() its own output must still yield the bare model name
    # in cost_model.
    adapter_metadata = {
        "api_base": "https://api.openai.com/v1",
        "model": "gpt-4o",
    }

    initial = OpenAICompatibleLLMParameters.validate(adapter_metadata)
    revalidated = OpenAICompatibleLLMParameters.validate(initial)

    assert revalidated["cost_model"] == "gpt-4o"