diff --git a/.git-hooks/check_api_key.sh b/.git-hooks/check_api_key.sh index 6bb06596..2d294fed 100755 --- a/.git-hooks/check_api_key.sh +++ b/.git-hooks/check_api_key.sh @@ -1,7 +1,7 @@ #!/bin/bash # Check for `api_key=` in staged changes -if git diff --cached | grep -q "api_key="; then +if git diff --cached | grep -q -E '\bapi_key=[^"]'; then echo "❌ Commit blocked: Found 'api_key=' in staged changes." exit 1 # Prevent commit fi diff --git a/ai21/__init__.py b/ai21/__init__.py index 0548e4ce..82c34d13 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -4,18 +4,18 @@ from ai21.clients.azure.ai21_azure_client import AI21AzureClient, AsyncAI21AzureClient from ai21.clients.studio.ai21_client import AI21Client from ai21.clients.studio.async_ai21_client import AsyncAI21Client - from ai21.errors import ( AI21APIError, + AI21Error, APITimeoutError, MissingApiKeyError, ModelPackageDoesntExistError, - AI21Error, TooManyRequestsError, ) from ai21.logger import setup_logger from ai21.version import VERSION + __version__ = VERSION setup_logger() @@ -44,6 +44,18 @@ def _import_vertex_client(): return AI21VertexClient +def _import_launchpad_client(): + from ai21.clients.launchpad.ai21_launchpad_client import AI21LaunchpadClient + + return AI21LaunchpadClient + + +def _import_async_launchpad_client(): + from ai21.clients.launchpad.ai21_launchpad_client import AsyncAI21LaunchpadClient + + return AsyncAI21LaunchpadClient + + def _import_async_vertex_client(): from ai21.clients.vertex.ai21_vertex_client import AsyncAI21VertexClient @@ -66,6 +78,13 @@ def __getattr__(name: str) -> Any: if name == "AsyncAI21VertexClient": return _import_async_vertex_client() + + if name == "AI21LaunchpadClient": + return _import_launchpad_client() + + if name == "AsyncAI21LaunchpadClient": + return _import_async_launchpad_client() + except ImportError as e: raise ImportError('Please install "ai21[AWS]" for Bedrock, or "ai21[Vertex]" for Vertex') from e @@ -87,4 +106,6 @@ def __getattr__(name: str) -> Any: "AsyncAI21BedrockClient", "AI21VertexClient", "AsyncAI21VertexClient", + "AI21LaunchpadClient", + "AsyncAI21LaunchpadClient", ] diff --git a/ai21/clients/common/auth/__init__.py b/ai21/clients/common/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/clients/vertex/gcp_authorization.py b/ai21/clients/common/auth/gcp_authorization.py similarity index 100% rename from ai21/clients/vertex/gcp_authorization.py rename to ai21/clients/common/auth/gcp_authorization.py diff --git a/ai21/clients/launchpad/__init__.py b/ai21/clients/launchpad/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/clients/launchpad/ai21_launchpad_client.py b/ai21/clients/launchpad/ai21_launchpad_client.py new file mode 100644 index 00000000..a0aa1dfb --- /dev/null +++ b/ai21/clients/launchpad/ai21_launchpad_client.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + +import httpx + +from google.auth.credentials import Credentials as GCPCredentials + +from ai21.clients.common.auth.gcp_authorization import GCPAuthorization +from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat +from ai21.http_client.async_http_client import AsyncAI21HTTPClient +from ai21.http_client.http_client import AI21HTTPClient +from ai21.models.request_options import RequestOptions + + +_DEFAULT_GCP_REGION = "us-central1" +_LAUNCHPAD_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1" +_LAUNCHPAD_PATH_FORMAT = "/projects/{project_id}/locations/{region}/endpoints/{endpoint_id}:{endpoint}" + + +class BaseAI21LaunchpadClient: + def __init__( + self, + region: Optional[str] = None, + project_id: Optional[str] = None, + endpoint_id: Optional[str] = None, + access_token: Optional[str] = None, + credentials: Optional[GCPCredentials] = None, + ): + if access_token is not None and project_id is None: + raise ValueError("Field project_id is required when setting access_token") + self._region = region or _DEFAULT_GCP_REGION + self._access_token = access_token + self._project_id = project_id + self._endpoint_id = endpoint_id + self._credentials = credentials + self._gcp_auth = GCPAuthorization() + + def _get_base_url(self) -> str: + return _LAUNCHPAD_BASE_URL_FORMAT.format(region=self._region) + + def _get_access_token(self) -> str: + if self._access_token is not None: + return self._access_token + + if self._credentials is None: + self._credentials, self._project_id = self._gcp_auth.get_gcp_credentials( + project_id=self._project_id, + ) + + if self._credentials is None: + raise ValueError("Could not get credentials for GCP project") + + self._gcp_auth.refresh_auth(self._credentials) + + if self._credentials.token is None: + raise RuntimeError(f"Could not get access token for GCP project {self._project_id}") + + return self._credentials.token + + def _build_path( + self, + project_id: str, + region: str, + model: str, + endpoint: str, + ) -> str: + return _LAUNCHPAD_PATH_FORMAT.format( + project_id=project_id, + region=region, + endpoint_id=self._endpoint_id, + model=model, + endpoint=endpoint, + ) + + def _get_authorization_header(self) -> Dict[str, Any]: + access_token = self._get_access_token() + return {"Authorization": f"Bearer {access_token}"} + + +class AI21LaunchpadClient(BaseAI21LaunchpadClient, AI21HTTPClient): + def __init__( + self, + region: Optional[str] = None, + project_id: Optional[str] = None, + endpoint_id: Optional[str] = None, + base_url: Optional[str] = None, + access_token: Optional[str] = None, + credentials: Optional[GCPCredentials] = None, + headers: Dict[str, str] | None = None, + timeout_sec: Optional[float] = None, + num_retries: Optional[int] = None, + http_client: Optional[httpx.Client] = None, + ): + BaseAI21LaunchpadClient.__init__( + self, + region=region, + project_id=project_id, + endpoint_id=endpoint_id, + access_token=access_token, + credentials=credentials, + ) + + if base_url is None: + base_url = self._get_base_url() + + AI21HTTPClient.__init__( + self, + base_url=base_url, + timeout_sec=timeout_sec, + num_retries=num_retries, + headers=headers, + client=http_client, + requires_api_key=False, + ) + + self.chat = StudioChat(self) + # Override the chat.create method to match the completions endpoint, + # so it wouldn't get to the old J2 completion endpoint + self.chat.create = self.chat.completions.create + + def _build_request(self, options: RequestOptions) -> httpx.Request: + options = self._prepare_options(options) + + return super()._build_request(options) + + def _prepare_options(self, options: RequestOptions) -> RequestOptions: + body = options.body + + model = body.pop("model") + stream = body.pop("stream", False) + endpoint = "streamRawPredict" if stream else "rawPredict" + headers = self._prepare_headers() + path = self._build_path( + project_id=self._project_id, + region=self._region, + model=model, + endpoint=endpoint, + ) + + return options.replace( + body=body, + path=path, + headers=headers, + ) + + def _prepare_headers(self) -> Dict[str, Any]: + return self._get_authorization_header() + + +class AsyncAI21LaunchpadClient(BaseAI21LaunchpadClient, AsyncAI21HTTPClient): + def __init__( + self, + region: Optional[str] = None, + project_id: Optional[str] = None, + endpoint_id: Optional[str] = None, + base_url: Optional[str] = None, + access_token: Optional[str] = None, + credentials: Optional[GCPCredentials] = None, + headers: Dict[str, str] | None = None, + timeout_sec: Optional[float] = None, + num_retries: Optional[int] = None, + http_client: Optional[httpx.AsyncClient] = None, + ): + BaseAI21LaunchpadClient.__init__( + self, + region=region, + project_id=project_id, + endpoint_id=endpoint_id, + access_token=access_token, + credentials=credentials, + ) + + if base_url is None: + base_url = self._get_base_url() + + AsyncAI21HTTPClient.__init__( + self, + base_url=base_url, + timeout_sec=timeout_sec, + num_retries=num_retries, + headers=headers, + client=http_client, + requires_api_key=False, + ) + + self.chat = AsyncStudioChat(self) + # Override the chat.create method to match the completions endpoint, + # so it wouldn't get to the old J2 completion endpoint + self.chat.create = self.chat.completions.create + + def _build_request(self, options: RequestOptions) -> httpx.Request: + options = self._prepare_options(options) + + return super()._build_request(options) + + def _prepare_options(self, options: RequestOptions) -> RequestOptions: + body = options.body + + model = body.pop("model") + stream = body.pop("stream", False) + endpoint = "streamRawPredict" if stream else "rawPredict" + headers = self._prepare_headers() + path = self._build_path( + project_id=self._project_id, + region=self._region, + model=model, + endpoint=endpoint, + ) + + return options.replace( + body=body, + path=path, + headers=headers, + ) + + def _prepare_headers(self) -> Dict[str, Any]: + return self._get_authorization_header() diff --git a/ai21/clients/vertex/ai21_vertex_client.py b/ai21/clients/vertex/ai21_vertex_client.py index f4f06692..9cad09f3 100644 --- a/ai21/clients/vertex/ai21_vertex_client.py +++ b/ai21/clients/vertex/ai21_vertex_client.py @@ -1,16 +1,18 @@ from __future__ import annotations -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional import httpx + from google.auth.credentials import Credentials as GCPCredentials -from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat -from ai21.clients.vertex.gcp_authorization import GCPAuthorization +from ai21.clients.common.auth.gcp_authorization import GCPAuthorization +from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient from ai21.models.request_options import RequestOptions + _DEFAULT_GCP_REGION = "us-central1" _VERTEX_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1" _VERTEX_PATH_FORMAT = "/projects/{project_id}/locations/{region}/publishers/ai21/models/{model}:{endpoint}" diff --git a/examples/launchpad/__init__.py b/examples/launchpad/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/launchpad/async_chat_completions.py b/examples/launchpad/async_chat_completions.py new file mode 100644 index 00000000..aea556d5 --- /dev/null +++ b/examples/launchpad/async_chat_completions.py @@ -0,0 +1,21 @@ +import asyncio + +from ai21 import AsyncAI21LaunchpadClient +from ai21.models.chat import ChatMessage + + +client = AsyncAI21LaunchpadClient(endpoint_id="") + + +async def main(): + messages = ChatMessage(content="What is the meaning of life?", role="user") + + completion = await client.chat.completions.create( + model="jamba-1.6-large", + messages=[messages], + ) + + print(completion) + + +asyncio.run(main()) diff --git a/examples/launchpad/chat_completions.py b/examples/launchpad/chat_completions.py new file mode 100644 index 00000000..ca0f5f1f --- /dev/null +++ b/examples/launchpad/chat_completions.py @@ -0,0 +1,16 @@ +from ai21 import AI21LaunchpadClient +from ai21.models.chat import ChatMessage + + +client = AI21LaunchpadClient(endpoint_id="") + +messages = ChatMessage(content="What is the meaning of life?", role="user") + +completion = client.chat.completions.create( + model="jamba-1.6-large", + messages=[messages], + stream=True, +) + + +print(completion) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 74c8506b..a1d640df 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -1,11 +1,12 @@ import boto3 -import pytest import httpx +import pytest + from google.auth.credentials import Credentials from google.auth.transport.requests import Request +from ai21.clients.common.auth.gcp_authorization import GCPAuthorization from ai21.clients.vertex.ai21_vertex_client import AI21VertexClient -from ai21.clients.vertex.gcp_authorization import GCPAuthorization @pytest.fixture diff --git a/tests/unittests/test_gcp_authorization.py b/tests/unittests/test_gcp_authorization.py index bc8b2bfc..e499aa6a 100644 --- a/tests/unittests/test_gcp_authorization.py +++ b/tests/unittests/test_gcp_authorization.py @@ -2,11 +2,13 @@ import google.auth.exceptions import pytest + from google.auth.transport.requests import Request -from ai21.clients.vertex.gcp_authorization import GCPAuthorization +from ai21.clients.common.auth.gcp_authorization import GCPAuthorization from ai21.errors import CredentialsError + _TEST_PROJECT_ID = "test-project" diff --git a/tests/unittests/test_imports.py b/tests/unittests/test_imports.py index 7b6359f4..b9a8c0d8 100644 --- a/tests/unittests/test_imports.py +++ b/tests/unittests/test_imports.py @@ -3,6 +3,7 @@ from ai21 import * # noqa: F403 from ai21 import __all__ + EXPECTED_ALL = [ "AI21EnvConfig", "AI21Client", @@ -20,6 +21,8 @@ "AsyncAI21BedrockClient", "AI21VertexClient", "AsyncAI21VertexClient", + "AI21LaunchpadClient", + "AsyncAI21LaunchpadClient", ]