diff --git a/ai21/clients/azure/ai21_azure_client.py b/ai21/clients/azure/ai21_azure_client.py index b2d59aee..39f45ee4 100644 --- a/ai21/clients/azure/ai21_azure_client.py +++ b/ai21/clients/azure/ai21_azure_client.py @@ -8,6 +8,7 @@ from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat AzureADTokenProvider = Callable[[], str] +_DEFAULT_AZURE_VERSION = "v1" class BaseAzureClient(ABC): @@ -42,11 +43,18 @@ def _get_azure_ad_token(self) -> Optional[str]: return None + def _add_version_to_url(self, base_url: str, api_version: str) -> str: + if api_version: + return f"{base_url}/{api_version}" + + return f"{base_url}/{_DEFAULT_AZURE_VERSION}" + class AsyncAI21AzureClient(BaseAzureClient, AsyncAI21HTTPClient): def __init__( self, base_url: str, + api_version: str = _DEFAULT_AZURE_VERSION, api_key: Optional[str] = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, @@ -62,6 +70,7 @@ def __init__( raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token") headers = self._prepare_headers(headers=default_headers or {}) + base_url = self._add_version_to_url(base_url=base_url, api_version=api_version) super().__init__( api_key=api_key, @@ -81,6 +90,7 @@ class AI21AzureClient(BaseAzureClient, AI21HTTPClient): def __init__( self, base_url: str, + api_version: str = _DEFAULT_AZURE_VERSION, api_key: Optional[str] = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, @@ -96,6 +106,7 @@ def __init__( raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token") headers = self._prepare_headers(headers=default_headers or {}) + base_url = self._add_version_to_url(base_url=base_url, api_version=api_version) super().__init__( api_key=api_key,