Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .git-hooks/check_api_key.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

# Check for `api_key=` in staged changes
if git diff --cached | grep -q "api_key="; then
if git diff --cached | grep -q -E '\bapi_key=[^"]'; then
echo "❌ Commit blocked: Found 'api_key=' in staged changes."
exit 1 # Prevent commit
fi
Expand Down
25 changes: 23 additions & 2 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
from ai21.clients.azure.ai21_azure_client import AI21AzureClient, AsyncAI21AzureClient
from ai21.clients.studio.ai21_client import AI21Client
from ai21.clients.studio.async_ai21_client import AsyncAI21Client

from ai21.errors import (
AI21APIError,
AI21Error,
APITimeoutError,
MissingApiKeyError,
ModelPackageDoesntExistError,
AI21Error,
TooManyRequestsError,
)
from ai21.logger import setup_logger
from ai21.version import VERSION


__version__ = VERSION
setup_logger()

Expand Down Expand Up @@ -44,6 +44,18 @@ def _import_vertex_client():
return AI21VertexClient


def _import_launchpad_client():
    """Import and return the synchronous ``AI21LaunchpadClient`` class.

    The import is performed at call time rather than at module import time, so
    the Launchpad client module is only loaded when this function is invoked.
    """
    from ai21.clients.launchpad.ai21_launchpad_client import (
        AI21LaunchpadClient as launchpad_client_cls,
    )

    return launchpad_client_cls


def _import_async_launchpad_client():
    """Import and return the asynchronous ``AsyncAI21LaunchpadClient`` class.

    The import is performed at call time rather than at module import time, so
    the Launchpad client module is only loaded when this function is invoked.
    """
    from ai21.clients.launchpad.ai21_launchpad_client import (
        AsyncAI21LaunchpadClient as async_launchpad_client_cls,
    )

    return async_launchpad_client_cls


def _import_async_vertex_client():
from ai21.clients.vertex.ai21_vertex_client import AsyncAI21VertexClient

Expand All @@ -66,6 +78,13 @@ def __getattr__(name: str) -> Any:

if name == "AsyncAI21VertexClient":
return _import_async_vertex_client()

if name == "AI21LaunchpadClient":
return _import_launchpad_client()

if name == "AsyncAI21LaunchpadClient":
return _import_async_launchpad_client()

except ImportError as e:
raise ImportError('Please install "ai21[AWS]" for Bedrock, or "ai21[Vertex]" for Vertex') from e

Expand All @@ -87,4 +106,6 @@ def __getattr__(name: str) -> Any:
"AsyncAI21BedrockClient",
"AI21VertexClient",
"AsyncAI21VertexClient",
"AI21LaunchpadClient",
"AsyncAI21LaunchpadClient",
]
Empty file.
Empty file.
218 changes: 218 additions & 0 deletions ai21/clients/launchpad/ai21_launchpad_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from __future__ import annotations

from typing import Any, Dict, Optional

import httpx

from google.auth.credentials import Credentials as GCPCredentials

from ai21.clients.common.auth.gcp_authorization import GCPAuthorization
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.request_options import RequestOptions


# Region used when the client is constructed without an explicit ``region``.
_DEFAULT_GCP_REGION = "us-central1"
# Template for the client's base URL; ``{region}`` is filled with the configured region.
_LAUNCHPAD_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1"
# Template for the per-request path. ``{endpoint}`` is the method suffix after the
# colon (the clients in this module use "rawPredict" or "streamRawPredict").
_LAUNCHPAD_PATH_FORMAT = "/projects/{project_id}/locations/{region}/endpoints/{endpoint_id}:{endpoint}"


class BaseAI21LaunchpadClient:
    """State and helpers shared by the sync and async Launchpad clients.

    Stores the GCP region, project id, endpoint id and authentication material,
    and derives from them the base URL, the request path and the
    ``Authorization`` header.
    """

    def __init__(
        self,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
        access_token: Optional[str] = None,
        credentials: Optional[GCPCredentials] = None,
    ):
        """Store the connection settings.

        Args:
            region: GCP region; falls back to ``_DEFAULT_GCP_REGION`` when falsy.
            project_id: GCP project id. Required when ``access_token`` is given.
            endpoint_id: Id of the endpoint placed in the request path.
            access_token: Pre-obtained bearer token; when set it is used as-is.
            credentials: GCP credentials used to obtain a token when no
                ``access_token`` is given.

        Raises:
            ValueError: If ``access_token`` is set without ``project_id``.
        """
        if access_token is not None and project_id is None:
            raise ValueError("Field project_id is required when setting access_token")

        self._region = region or _DEFAULT_GCP_REGION
        self._access_token = access_token
        self._project_id = project_id
        self._endpoint_id = endpoint_id
        self._credentials = credentials
        self._gcp_auth = GCPAuthorization()

    def _get_base_url(self) -> str:
        """Return the base URL for the configured region."""
        return _LAUNCHPAD_BASE_URL_FORMAT.format(region=self._region)

    def _get_access_token(self) -> str:
        """Return a bearer token for the request.

        An explicit ``access_token`` wins. Otherwise credentials are looked up
        (once, then cached on the instance), refreshed, and their token returned.

        Raises:
            ValueError: If no credentials could be obtained.
            RuntimeError: If the refreshed credentials carry no token.
        """
        if self._access_token is not None:
            return self._access_token

        if self._credentials is None:
            # The lookup also reports the project id, which replaces the stored one.
            self._credentials, self._project_id = self._gcp_auth.get_gcp_credentials(
                project_id=self._project_id,
            )

        if self._credentials is None:
            raise ValueError("Could not get credentials for GCP project")

        self._gcp_auth.refresh_auth(self._credentials)

        token = self._credentials.token
        if token is None:
            raise RuntimeError(f"Could not get access token for GCP project {self._project_id}")

        return token

    def _build_path(
        self,
        project_id: str,
        region: str,
        model: str,
        endpoint: str,
    ) -> str:
        """Return the request path for ``endpoint`` on the configured endpoint id.

        ``model`` is accepted to keep the signature stable for callers, but the
        path template addresses the deployment by endpoint id and has no model
        placeholder.

        Raises:
            ValueError: If the endpoint id or the project id is missing. Without
                this check ``str.format`` would render the literal text "None"
                into the URL and the request would go to a non-existent resource.
        """
        if not self._endpoint_id:
            raise ValueError("Field endpoint_id is required to build the Launchpad request path")

        if not project_id:
            raise ValueError("Field project_id is required to build the Launchpad request path")

        return _LAUNCHPAD_PATH_FORMAT.format(
            project_id=project_id,
            region=region,
            endpoint_id=self._endpoint_id,
            endpoint=endpoint,
        )

    def _get_authorization_header(self) -> Dict[str, Any]:
        """Return the ``Authorization: Bearer <token>`` header as a dict."""
        return {"Authorization": f"Bearer {self._get_access_token()}"}


class AI21LaunchpadClient(BaseAI21LaunchpadClient, AI21HTTPClient):
    """Synchronous client that sends chat requests to a Launchpad endpoint."""

    def __init__(
        self,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
        base_url: Optional[str] = None,
        access_token: Optional[str] = None,
        credentials: Optional[GCPCredentials] = None,
        headers: Dict[str, str] | None = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        http_client: Optional[httpx.Client] = None,
    ):
        """Initialise the Launchpad settings and the underlying HTTP client.

        Args:
            region: GCP region; a default is used when omitted.
            project_id: GCP project id.
            endpoint_id: Id of the endpoint placed in the request path.
            base_url: Overrides the region-derived base URL when given.
            access_token: Pre-obtained bearer token.
            credentials: GCP credentials used when no ``access_token`` is given.
            headers: Extra headers forwarded to the HTTP client.
            timeout_sec: Request timeout forwarded to the HTTP client.
            num_retries: Retry count forwarded to the HTTP client.
            http_client: Existing ``httpx.Client`` forwarded to the HTTP client.
        """
        BaseAI21LaunchpadClient.__init__(
            self,
            region=region,
            project_id=project_id,
            endpoint_id=endpoint_id,
            access_token=access_token,
            credentials=credentials,
        )

        if base_url is None:
            base_url = self._get_base_url()

        AI21HTTPClient.__init__(
            self,
            base_url=base_url,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
            headers=headers,
            client=http_client,
            requires_api_key=False,
        )

        self.chat = StudioChat(self)
        # Override the chat.create method to match the completions endpoint,
        # so it wouldn't get to the old J2 completion endpoint
        self.chat.create = self.chat.completions.create

    def _build_request(self, options: RequestOptions) -> httpx.Request:
        """Rewrite ``options`` for Launchpad, then delegate to the HTTP client."""
        return super()._build_request(self._prepare_options(options))

    def _prepare_options(self, options: RequestOptions) -> RequestOptions:
        """Return a copy of ``options`` with the Launchpad path, headers and body.

        ``model`` and ``stream`` are removed from the body; ``stream`` selects
        between the "streamRawPredict" and "rawPredict" path suffixes.
        """
        # Work on a shallow copy: popping from ``options.body`` directly would
        # mutate the caller's options, so preparing the same options a second
        # time would fail with KeyError on "model".
        body = dict(options.body)

        model = body.pop("model")
        stream = body.pop("stream", False)
        endpoint = "streamRawPredict" if stream else "rawPredict"

        # Headers are resolved before the path: fetching the token may fill in
        # ``self._project_id`` from the discovered credentials.
        headers = self._prepare_headers()
        path = self._build_path(
            project_id=self._project_id,
            region=self._region,
            model=model,
            endpoint=endpoint,
        )

        return options.replace(
            body=body,
            path=path,
            headers=headers,
        )

    def _prepare_headers(self) -> Dict[str, Any]:
        """Return the per-request headers (the bearer ``Authorization`` header)."""
        return self._get_authorization_header()


class AsyncAI21LaunchpadClient(BaseAI21LaunchpadClient, AsyncAI21HTTPClient):
    """Asynchronous client that sends chat requests to a Launchpad endpoint."""

    def __init__(
        self,
        region: Optional[str] = None,
        project_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
        base_url: Optional[str] = None,
        access_token: Optional[str] = None,
        credentials: Optional[GCPCredentials] = None,
        headers: Dict[str, str] | None = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        http_client: Optional[httpx.AsyncClient] = None,
    ):
        """Initialise the Launchpad settings and the underlying async HTTP client.

        Args:
            region: GCP region; a default is used when omitted.
            project_id: GCP project id.
            endpoint_id: Id of the endpoint placed in the request path.
            base_url: Overrides the region-derived base URL when given.
            access_token: Pre-obtained bearer token.
            credentials: GCP credentials used when no ``access_token`` is given.
            headers: Extra headers forwarded to the HTTP client.
            timeout_sec: Request timeout forwarded to the HTTP client.
            num_retries: Retry count forwarded to the HTTP client.
            http_client: Existing ``httpx.AsyncClient`` forwarded to the HTTP client.
        """
        BaseAI21LaunchpadClient.__init__(
            self,
            region=region,
            project_id=project_id,
            endpoint_id=endpoint_id,
            access_token=access_token,
            credentials=credentials,
        )

        if base_url is None:
            base_url = self._get_base_url()

        AsyncAI21HTTPClient.__init__(
            self,
            base_url=base_url,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
            headers=headers,
            client=http_client,
            requires_api_key=False,
        )

        self.chat = AsyncStudioChat(self)
        # Override the chat.create method to match the completions endpoint,
        # so it wouldn't get to the old J2 completion endpoint
        self.chat.create = self.chat.completions.create

    def _build_request(self, options: RequestOptions) -> httpx.Request:
        """Rewrite ``options`` for Launchpad, then delegate to the HTTP client."""
        return super()._build_request(self._prepare_options(options))

    def _prepare_options(self, options: RequestOptions) -> RequestOptions:
        """Return a copy of ``options`` with the Launchpad path, headers and body.

        ``model`` and ``stream`` are removed from the body; ``stream`` selects
        between the "streamRawPredict" and "rawPredict" path suffixes.
        """
        # Work on a shallow copy: popping from ``options.body`` directly would
        # mutate the caller's options, so preparing the same options a second
        # time would fail with KeyError on "model".
        body = dict(options.body)

        model = body.pop("model")
        stream = body.pop("stream", False)
        endpoint = "streamRawPredict" if stream else "rawPredict"

        # Headers are resolved before the path: fetching the token may fill in
        # ``self._project_id`` from the discovered credentials.
        headers = self._prepare_headers()
        path = self._build_path(
            project_id=self._project_id,
            region=self._region,
            model=model,
            endpoint=endpoint,
        )

        return options.replace(
            body=body,
            path=path,
            headers=headers,
        )

    def _prepare_headers(self) -> Dict[str, Any]:
        """Return the per-request headers (the bearer ``Authorization`` header)."""
        return self._get_authorization_header()
8 changes: 5 additions & 3 deletions ai21/clients/vertex/ai21_vertex_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from typing import Optional, Dict, Any
from typing import Any, Dict, Optional

import httpx

from google.auth.credentials import Credentials as GCPCredentials

from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
from ai21.clients.vertex.gcp_authorization import GCPAuthorization
from ai21.clients.common.auth.gcp_authorization import GCPAuthorization
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.request_options import RequestOptions


_DEFAULT_GCP_REGION = "us-central1"
_VERTEX_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1"
_VERTEX_PATH_FORMAT = "/projects/{project_id}/locations/{region}/publishers/ai21/models/{model}:{endpoint}"
Expand Down
Empty file added examples/launchpad/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions examples/launchpad/async_chat_completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import asyncio

from ai21 import AsyncAI21LaunchpadClient
from ai21.models.chat import ChatMessage


# Replace the placeholder with the id of your own Launchpad endpoint.
client = AsyncAI21LaunchpadClient(endpoint_id="<your_endpoint_id>")


async def main():
    """Send a single user message through the async client and print the result."""
    user_message = ChatMessage(role="user", content="What is the meaning of life?")
    conversation = [user_message]

    completion = await client.chat.completions.create(
        messages=conversation,
        model="jamba-1.6-large",
    )

    print(completion)


asyncio.run(main())
16 changes: 16 additions & 0 deletions examples/launchpad/chat_completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ai21 import AI21LaunchpadClient
from ai21.models.chat import ChatMessage


# Replace the placeholder with the id of your own Launchpad endpoint.
client = AI21LaunchpadClient(endpoint_id="<your_endpoint_id>")

# A single-turn conversation consisting of one user message.
messages = ChatMessage(role="user", content="What is the meaning of life?")

# NOTE(review): with stream=True the call presumably returns a stream object rather
# than a finished completion, so the print below may show the stream's repr and not
# the generated text - confirm whether iterating over the chunks was intended.
completion = client.chat.completions.create(
    messages=[messages],
    model="jamba-1.6-large",
    stream=True,
)


print(completion)
5 changes: 3 additions & 2 deletions tests/unittests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import boto3
import pytest
import httpx
import pytest

from google.auth.credentials import Credentials
from google.auth.transport.requests import Request

from ai21.clients.common.auth.gcp_authorization import GCPAuthorization
from ai21.clients.vertex.ai21_vertex_client import AI21VertexClient
from ai21.clients.vertex.gcp_authorization import GCPAuthorization


@pytest.fixture
Expand Down
4 changes: 3 additions & 1 deletion tests/unittests/test_gcp_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import google.auth.exceptions
import pytest

from google.auth.transport.requests import Request

from ai21.clients.vertex.gcp_authorization import GCPAuthorization
from ai21.clients.common.auth.gcp_authorization import GCPAuthorization
from ai21.errors import CredentialsError


_TEST_PROJECT_ID = "test-project"


Expand Down
3 changes: 3 additions & 0 deletions tests/unittests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ai21 import * # noqa: F403
from ai21 import __all__


EXPECTED_ALL = [
"AI21EnvConfig",
"AI21Client",
Expand All @@ -20,6 +21,8 @@
"AsyncAI21BedrockClient",
"AI21VertexClient",
"AsyncAI21VertexClient",
"AI21LaunchpadClient",
"AsyncAI21LaunchpadClient",
]


Expand Down
Loading