Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions unstract/sdk1/src/unstract/sdk1/adapters/base1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from abc import ABC, abstractmethod
from importlib import import_module
from typing import TYPE_CHECKING
from urllib.parse import urlparse

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -372,6 +373,16 @@ def _is_openai_reasoning_model(model: str) -> bool:
return bool(_OPENAI_REASONING_MODEL_PATTERN.match(model.lower()))


def _is_openai_api_endpoint(api_base: str | None) -> bool:
"""Return True when api_base points at OpenAI's own HTTPS API host."""
if not api_base:
return False
parsed = urlparse(api_base)
return (
parsed.scheme == "https" and (parsed.hostname or "").lower() == "api.openai.com"
)


class OpenAICompatibleLLMParameters(BaseChatCompletionParameters):
"""See https://docs.litellm.ai/docs/providers/openai_compatible/."""

Expand Down Expand Up @@ -463,6 +474,14 @@ def validate(adapter_metadata: dict[str, "Any"]) -> dict[str, "Any"]:
else:
validated.pop("reasoning_effort", None)

# The custom_openai/ prefix has no entry in the cost price map. Strip
# it only for OpenAI's own endpoint; other gateways price the same
# model name differently, so leave their cost unresolved.
if _is_openai_api_endpoint(validated.get("api_base")):
validated["cost_model"] = validated["model"][
len(_CUSTOM_OPENAI_PROVIDER_PREFIX) :
]

return validated

@staticmethod
Expand Down
50 changes: 50 additions & 0 deletions unstract/sdk1/tests/test_openai_compatible_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,56 @@ def test_openai_compatible_validate_no_reasoning_unchanged() -> None:
assert "reasoning_effort" not in validated


def test_openai_compatible_validate_sets_cost_model_for_openai_endpoint() -> None:
# On OpenAI's endpoint, cost_model drops the prefix so pricing resolves.
validated = OpenAICompatibleLLMParameters.validate(
{
"api_base": "https://api.openai.com/v1",
"api_key": "test-key",
"model": "gpt-4o",
}
)

assert validated["model"] == "custom_openai/gpt-4o"
assert validated["cost_model"] == "gpt-4o"


def test_openai_compatible_validate_cost_model_keeps_openai_subprefix() -> None:
validated = OpenAICompatibleLLMParameters.validate(
{
"api_base": "https://api.openai.com/v1",
"model": "custom_openai/openai/gpt-4o",
}
)

assert validated["cost_model"] == "openai/gpt-4o"


def test_openai_compatible_validate_no_cost_model_for_other_gateway() -> None:
# Non-OpenAI gateways price the same name differently; leave it unresolved.
validated = OpenAICompatibleLLMParameters.validate(
{
"api_base": "https://gateway.example.com/v1",
"model": "gpt-4o",
}
)

assert "cost_model" not in validated


def test_openai_compatible_validate_cost_model_stable_on_revalidation() -> None:
# validate() may run on its own previous output; cost_model must survive.
first = OpenAICompatibleLLMParameters.validate(
{
"api_base": "https://api.openai.com/v1",
"model": "gpt-4o",
}
)
second = OpenAICompatibleLLMParameters.validate(first)

assert second["cost_model"] == "gpt-4o"


def test_openai_compatible_adapter_uses_distinct_description_and_icon() -> None:
metadata = OpenAICompatibleLLMAdapter.get_metadata()

Comment thread
greptile-apps[bot] marked this conversation as resolved.
Expand Down