Skip to content

Commit

Permalink
feat: Added support for default_headers for azure_openai. (#3211)
Browse files Browse the repository at this point in the history
* hope ci/cd works

* pls work

* remove depracated patch
  • Loading branch information
a-s-gorski committed May 17, 2024
1 parent 99f6a2c commit 2d48192
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions packages/phoenix-evals/src/phoenix/evals/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
Expand Down Expand Up @@ -98,6 +99,8 @@ class OpenAIModel(BaseModel):
azure_deployment: Optional[str] = field(default=None)
azure_ad_token: Optional[str] = field(default=None)
azure_ad_token_provider: Optional[Callable[[], str]] = field(default=None)
default_headers: Optional[Mapping[str, str]] = field(default=None)
"""Default headers required by AzureOpenAI"""

# Deprecated fields
model_name: Optional[str] = field(default=None)
Expand Down Expand Up @@ -178,6 +181,7 @@ def _init_open_ai(self) -> None:
azure_ad_token_provider=azure_options.azure_ad_token_provider,
api_key=self.api_key,
organization=self.organization,
default_headers=self.default_headers,
)
self._async_client = self._openai.AsyncAzureOpenAI(
azure_endpoint=azure_options.azure_endpoint,
Expand All @@ -187,6 +191,7 @@ def _init_open_ai(self) -> None:
azure_ad_token_provider=azure_options.azure_ad_token_provider,
api_key=self.api_key,
organization=self.organization,
default_headers=self.default_headers,
)
# return early since we don't need to check the model
return
Expand Down
24 changes: 24 additions & 0 deletions packages/phoenix-evals/tests/phoenix/evals/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,30 @@ def test_azure_openai_model(monkeypatch):
assert isinstance(model._client, AzureOpenAI)


def test_azure_openai_model_added_custom_header(monkeypatch):
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
header_key = "header"
header_value = "my-example-header-value"
default_headers = {header_key: header_value}
model = OpenAIModel(
model="gpt-4-turbo-preview",
api_version="2023-07-01-preview",
azure_endpoint="https://example-endpoint.openai.azure.com",
default_headers=default_headers,
)

assert isinstance(model._client, AzureOpenAI)
# check if custom header is added to headers
assert (
header_key in model._client.default_headers
and model._client.default_headers.get(header_key) == header_value
)
assert (
header_key in model._async_client.default_headers
and model._async_client.default_headers.get(header_key) == header_value
)


def test_azure_fails_when_missing_options(monkeypatch):
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
# Test missing api_version
Expand Down

0 comments on commit 2d48192

Please sign in to comment.