diff --git a/README.md b/README.md index 6a04ecd0..5b5aa484 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ - [AWS](#AWS) - [SageMaker](#SageMaker) - [Bedrock](#Bedrock) + - [Azure](#Azure) ## Examples (tl;dr) @@ -530,4 +531,33 @@ response = client.completion.create( ) ``` +### Azure + +If you wish to interact with your Azure endpoint on Azure AI Studio, you can use the `AI21AzureClient` +and `AsyncAI21AzureClient`. + +The following models are supported on Azure: + +- `jamba-instruct` + +```python +from ai21 import AI21AzureClient +from ai21.models.chat import ChatMessage + +client = AI21AzureClient( + base_url="https://your-endpoint.inference.ai.azure.com/v1/chat/completions", + api_key="", +) + +messages = [ + ChatMessage(content="You are a helpful assistant", role="system"), + ChatMessage(content="What is the meaning of life?", role="user") +] + +response = client.chat.completions.create( + model="jamba-instruct", + messages=messages, +) +``` + Happy prompting! 🚀 diff --git a/ai21/__init__.py b/ai21/__init__.py index 6e3757d1..bf6eda84 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -1,8 +1,10 @@ from typing import Any from ai21.ai21_env_config import AI21EnvConfig +from ai21.clients.azure.ai21_azure_client import AI21AzureClient, AsyncAI21AzureClient from ai21.clients.studio.ai21_client import AI21Client from ai21.clients.studio.async_ai21_client import AsyncAI21Client + from ai21.errors import ( AI21APIError, APITimeoutError, @@ -65,4 +67,6 @@ def __getattr__(name: str) -> Any: "AI21SageMakerClient", "BedrockModelID", "SageMaker", + "AI21AzureClient", + "AsyncAI21AzureClient", ] diff --git a/ai21/ai21_http_client/base_ai21_http_client.py b/ai21/ai21_http_client/base_ai21_http_client.py index 0d1dab7f..1201df9a 100644 --- a/ai21/ai21_http_client/base_ai21_http_client.py +++ b/ai21/ai21_http_client/base_ai21_http_client.py @@ -27,7 +27,6 @@ def __init__( timeout_sec: Optional[int] = None, num_retries: Optional[int] = None, via: Optional[str] = 
None, - http_client: Optional[HttpClient] = None, ): self._api_key = api_key diff --git a/ai21/clients/azure/__init__.py b/ai21/clients/azure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/clients/azure/ai21_azure_client.py b/ai21/clients/azure/ai21_azure_client.py new file mode 100644 index 00000000..b2d59aee --- /dev/null +++ b/ai21/clients/azure/ai21_azure_client.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from abc import ABC +from typing import Optional, Callable, Dict + +from ai21.ai21_http_client.ai21_http_client import AI21HTTPClient +from ai21.ai21_http_client.async_ai21_http_client import AsyncAI21HTTPClient +from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat + +AzureADTokenProvider = Callable[[], str] + + +class BaseAzureClient(ABC): + _azure_endpoint: str + _api_key: Optional[str] + _azure_ad_token: Optional[str] + _azure_ad_token_provider: Optional[AzureADTokenProvider] + + def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]: + azure_ad_token = self._get_azure_ad_token() + + if azure_ad_token is not None and "Authorization" not in headers: + return { + "Authorization": f"Bearer {azure_ad_token}", + **headers, + } + + if self._api_key is not None: + return { + "api-key": self._api_key, + **headers, + } + + return headers + + def _get_azure_ad_token(self) -> Optional[str]: + if self._azure_ad_token is not None: + return self._azure_ad_token + + if self._azure_ad_token_provider is not None: + return self._azure_ad_token_provider() + + return None + + +class AsyncAI21AzureClient(BaseAzureClient, AsyncAI21HTTPClient): + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + azure_ad_token: str | None = None, + azure_ad_token_provider: AzureADTokenProvider | None = None, + default_headers: Dict[str, str] | None = None, + timeout_sec: int | None = None, + num_retries: int | None = None, + ): + self._api_key = api_key + self._azure_ad_token = 
azure_ad_token + self._azure_ad_token_provider = azure_ad_token_provider + + if self._api_key is None and self._azure_ad_token_provider is None and self._azure_ad_token is None: + raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token") + + headers = self._prepare_headers(headers=default_headers or {}) + + super().__init__( + api_key=api_key, + base_url=base_url, + headers=headers, + timeout_sec=timeout_sec, + num_retries=num_retries, + ) + + self.chat = AsyncStudioChat(self) + # Override the chat.create method to match the completions endpoint, + # so it wouldn't get to the old J2 completion endpoint + self.chat.create = self.chat.completions.create + + +class AI21AzureClient(BaseAzureClient, AI21HTTPClient): + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + azure_ad_token: str | None = None, + azure_ad_token_provider: AzureADTokenProvider | None = None, + default_headers: Dict[str, str] | None = None, + timeout_sec: int | None = None, + num_retries: int | None = None, + ): + self._api_key = api_key + self._azure_ad_token = azure_ad_token + self._azure_ad_token_provider = azure_ad_token_provider + + if self._api_key is None and self._azure_ad_token_provider is None and self._azure_ad_token is None: + raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token") + + headers = self._prepare_headers(headers=default_headers or {}) + + super().__init__( + api_key=api_key, + base_url=base_url, + headers=headers, + timeout_sec=timeout_sec, + num_retries=num_retries, + ) + + self.chat = StudioChat(self) + # Override the chat.create method to match the completions endpoint, + # so it wouldn't get to the old J2 completion endpoint + self.chat.create = self.chat.completions.create diff --git a/examples/azure/__init__.py b/examples/azure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/azure/async_azure_chat_completions.py 
b/examples/azure/async_azure_chat_completions.py new file mode 100644 index 00000000..9553d30e --- /dev/null +++ b/examples/azure/async_azure_chat_completions.py @@ -0,0 +1,23 @@ +import asyncio + +from ai21 import AsyncAI21AzureClient +from ai21.models.chat import ChatMessage + + +async def chat_completions(): + client = AsyncAI21AzureClient( + base_url="", + api_key="", + ) + + messages = ChatMessage(content="What is the meaning of life?", role="user") + + completion = await client.chat.completions.create( + model="jamba-instruct", + messages=[messages], + ) + + print(completion.to_json()) + + +asyncio.run(chat_completions()) diff --git a/examples/azure/azure_chat_completions.py b/examples/azure/azure_chat_completions.py new file mode 100644 index 00000000..a0bbf1d8 --- /dev/null +++ b/examples/azure/azure_chat_completions.py @@ -0,0 +1,17 @@ +from ai21 import AI21AzureClient + +from ai21.models.chat import ChatMessage + +client = AI21AzureClient( + base_url="", + api_key="", +) + +messages = ChatMessage(content="What is the meaning of life?", role="user") + +completion = client.chat.completions.create( + model="jamba-instruct", + messages=[messages], +) + +print(completion.to_json()) diff --git a/tests/unittests/clients/azure/__init__.py b/tests/unittests/clients/azure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/clients/azure/test_chat_completions.py b/tests/unittests/clients/azure/test_chat_completions.py new file mode 100644 index 00000000..390a980d --- /dev/null +++ b/tests/unittests/clients/azure/test_chat_completions.py @@ -0,0 +1,10 @@ +import pytest + +from ai21 import AI21AzureClient + + +def test__azure_client__when_init_with_no_auth__should_raise_error(): + with pytest.raises(ValueError) as e: + AI21AzureClient(base_url="http://some_endpoint_url") + + assert str(e.value) == "Must provide either api_key or azure_ad_token_provider or azure_ad_token" diff --git a/tests/unittests/test_imports.py 
b/tests/unittests/test_imports.py index b75bb664..76a677ab 100644 --- a/tests/unittests/test_imports.py +++ b/tests/unittests/test_imports.py @@ -4,19 +4,21 @@ from ai21 import __all__ EXPECTED_ALL = [ - "AI21EnvConfig", - "AI21Client", - "AsyncAI21Client", "AI21APIError", - "APITimeoutError", - "AI21Error", - "MissingApiKeyError", - "ModelPackageDoesntExistError", - "TooManyRequestsError", + "AI21AzureClient", "AI21BedrockClient", + "AI21Client", + "AI21EnvConfig", + "AI21Error", "AI21SageMakerClient", + "APITimeoutError", + "AsyncAI21AzureClient", + "AsyncAI21Client", "BedrockModelID", + "MissingApiKeyError", + "ModelPackageDoesntExistError", "SageMaker", + "TooManyRequestsError", ]