Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ messages = [

chat_completions = client.chat.completions.create(
messages=messages,
model="jamba-1.5-mini",
model="jamba-1.6-mini-2025-03",
)
```

Expand Down Expand Up @@ -208,7 +208,7 @@ client = AsyncAI21Client(
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-1.5-mini",
model="jamba-1.6-mini-2025-03",
)

print(response)
Expand Down Expand Up @@ -346,7 +346,7 @@ client = AsyncAI21Client()
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-1.5-mini",
model="jamba-1.6-mini-2025-03",
stream=True,
)
async for chunk in response:
Expand Down Expand Up @@ -705,7 +705,7 @@ messages = [
]

response = client.chat.completions.create(
model="jamba-1.5-mini",
model="jamba-1.6-mini-2025-03",
messages=messages,
)
```
Expand Down
30 changes: 8 additions & 22 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import warnings
from typing import Optional, Dict, Any

from typing import Any, Dict, Optional

import boto3
import httpx
Expand All @@ -10,15 +9,16 @@
from ai21.ai21_env_config import AI21EnvConfig
from ai21.clients.aws.aws_authorization import AWSAuthorization
from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder
from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion
from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException
from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat
from ai21.clients.studio.resources.studio_completion import (
AsyncStudioCompletion,
StudioCompletion,
)
from ai21.errors import AccessDenied, APITimeoutError, ModelErrorException, NotFound
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.request_options import RequestOptions

_logger = logging.getLogger(__name__)


_BEDROCK_URL_FORMAT = "https://bedrock-runtime.{region}.amazonaws.com"

Expand Down Expand Up @@ -76,7 +76,6 @@ def _prepare_headers(self, url: str, body: Dict[str, Any]) -> dict:
class AI21BedrockClient(AI21HTTPClient, BaseBedrockClient):
def __init__(
self,
model_id: Optional[str] = None,
base_url: Optional[str] = None,
region: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
Expand All @@ -85,12 +84,6 @@ def __init__(
session: Optional[boto3.Session] = None,
http_client: Optional[httpx.Client] = None,
):
if model_id is not None:
warnings.warn(
"Please consider using the 'model' parameter in the "
"'create' method calls instead of the constructor.",
DeprecationWarning,
)
self._region = region or AI21EnvConfig.aws_region
if base_url is None:
base_url = _BEDROCK_URL_FORMAT.format(region=self._region)
Expand Down Expand Up @@ -128,7 +121,6 @@ def _get_streaming_decoder(self) -> _AWSEventStreamDecoder:
class AsyncAI21BedrockClient(AsyncAI21HTTPClient, BaseBedrockClient):
def __init__(
self,
model_id: Optional[str] = None,
base_url: Optional[str] = None,
region: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
Expand All @@ -137,12 +129,6 @@ def __init__(
session: Optional[boto3.Session] = None,
http_client: Optional[httpx.AsyncClient] = None,
):
if model_id is not None:
warnings.warn(
"Please consider using the 'model' parameter in the "
"'create' method calls instead of the constructor.",
DeprecationWarning,
)
self._region = region or AI21EnvConfig.aws_region
if base_url is None:
base_url = _BEDROCK_URL_FORMAT.format(region=self._region)
Expand Down
24 changes: 3 additions & 21 deletions ai21/clients/common/completion_base.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,17 @@
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import List, Dict, Optional
from typing import Dict, List

from ai21.models import Penalty, CompletionsResponse
from ai21.models import CompletionsResponse, Penalty
from ai21.models._pydantic_compatibility import _to_dict
from ai21.types import NOT_GIVEN, NotGiven
from ai21.utils.typing import remove_not_given
from ai21.models._pydantic_compatibility import _to_dict


class Completion(ABC):
_module_name = "complete"

def _get_model(self, model: Optional[str], model_id: Optional[str]) -> str:
    """Resolve the model name from ``model`` or the deprecated ``model_id``.

    Args:
        model: Preferred way of naming the model.
        model_id: Deprecated alias; passing any non-``None`` value emits a
            ``DeprecationWarning``.

    Returns:
        ``model`` when it is truthy, otherwise ``model_id``.

    Raises:
        ValueError: If both values are supplied, or if neither is.
    """
    legacy_given = model_id is not None
    if legacy_given:
        # stacklevel=2 attributes the warning to whoever called this method.
        warnings.warn(
            "The 'model_id' parameter is deprecated and will be removed in a future version."
            " Please use 'model' instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    if model and model_id:
        raise ValueError("Please provide only 'model' as 'model_id' is deprecated.")

    resolved = model or model_id
    if not resolved:
        raise ValueError("model should be provided 'create' method call")

    return resolved

@abstractmethod
def create(
self,
Expand Down
19 changes: 14 additions & 5 deletions ai21/clients/studio/resources/chat/async_chat_completions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

from typing import List, Optional, Any, Literal, overload
from typing import Any, List, Literal, Optional, overload

from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat import ChatCompletionChunk, ChatCompletionResponse
from ai21.models.chat.chat_message import ChatMessageParam
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.async_stream import AsyncStream
from ai21.types import NotGiven, NOT_GIVEN
from ai21.types import NOT_GIVEN, NotGiven


__all__ = ["AsyncChatCompletions"]

Expand Down Expand Up @@ -74,8 +75,10 @@ async def create(
" instead of ai21.models when working with chat completions."
)

model = self._get_model(model=model)

body = self._create_body(
model=self._get_model(model=model, model_id=kwargs.pop("model_id", None)),
model=model,
messages=messages,
stop=stop,
temperature=temperature,
Expand All @@ -96,3 +99,9 @@ async def create(
stream_cls=AsyncStream[ChatCompletionChunk],
response_cls=ChatCompletionResponse,
)

def _get_model(self, model: str) -> str:
    """Return the model name to put in the request body.

    The name is run through ``self._check_model`` only when the owning
    client's class is named exactly ``AsyncAI21Client``; for any other
    client class the value is returned untouched.

    Args:
        model: Model name passed to ``create``.

    Returns:
        The result of ``self._check_model`` for ``AsyncAI21Client``,
        otherwise ``model`` unchanged.
    """
    owner_name = self._client.__class__.__name__

    # NOTE(review): presumably other clients use their own model identifiers,
    # hence no check for them; confirm against the client implementations.
    if owner_name != "AsyncAI21Client":
        return model

    return self._check_model(model=model)
32 changes: 18 additions & 14 deletions ai21/clients/studio/resources/chat/base_chat_completions.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
from __future__ import annotations

import warnings

from abc import ABC
from typing import List, Optional, Union, Any, Dict, Literal
from typing import Any, Dict, List, Literal, Optional, Union

from ai21.models._pydantic_compatibility import _to_dict
from ai21.models.chat.chat_message import ChatMessageParam
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.types import NotGiven
from ai21.utils.typing import remove_not_given
from ai21.models._pydantic_compatibility import _to_dict


# Text of the DeprecationWarning shown for the jamba-1.5 model names.
# The leading and trailing newlines are part of the message.
_MODEL_DEPRECATION_WARNING = (
    "\n"
    "The 'jamba-1.5-mini' and 'jamba-1.5-large' models are deprecated and will\n"
    "be removed in a future version.\n"
    "Please use jamba-mini-1.6-2025-03 or jamba-large-1.6-2025-03 instead.\n"
)


class BaseChatCompletions(ABC):
_module_name = "chat/completions"

def _get_model(self, model: Optional[str], model_id: Optional[str]) -> str:
if model_id is not None:
def _check_model(self, model: Optional[str]) -> str:
if not model:
raise ValueError("model should be provided 'create' method call")

if model in ["jamba-1.5-mini", "jamba-1.5-large"]:
warnings.warn(
"The 'model_id' parameter is deprecated and will be removed in a future version."
" Please use 'model' instead.",
_MODEL_DEPRECATION_WARNING,
DeprecationWarning,
stacklevel=2,
stacklevel=3,
)

if model_id and model:
raise ValueError("Please provide only 'model' as 'model_id' is deprecated.")

if not model and not model_id:
raise ValueError("model should be provided 'create' method call")

return model or model_id
return model

def _create_body(
self,
Expand Down
21 changes: 15 additions & 6 deletions ai21/clients/studio/resources/chat/chat_completions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

from typing import List, Optional, Any, Literal, overload
from typing import Any, List, Literal, Optional, overload

from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat import ChatCompletionChunk, ChatCompletionResponse
from ai21.models.chat.chat_message import ChatMessageParam
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.stream import Stream
from ai21.types import NotGiven, NOT_GIVEN
from ai21.types import NOT_GIVEN, NotGiven


__all__ = ["ChatCompletions"]

Expand Down Expand Up @@ -56,7 +57,7 @@ def create(
def create(
self,
messages: List[ChatMessageParam],
model: Optional[str] = None,
model: str,
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
Expand All @@ -74,8 +75,10 @@ def create(
" instead of ai21.models when working with chat completions."
)

model = self._get_model(model=model)

body = self._create_body(
model=self._get_model(model=model, model_id=kwargs.pop("model_id", None)),
model=model,
messages=messages,
stop=stop,
temperature=temperature,
Expand All @@ -96,3 +99,9 @@ def create(
stream_cls=Stream[ChatCompletionChunk],
response_cls=ChatCompletionResponse,
)

def _get_model(self, model: str) -> str:
    """Return the model name to put in the request body.

    ``self._check_model`` is applied only when the owning client's class is
    named exactly ``AI21Client``; every other client class gets ``model``
    back unchanged.

    Args:
        model: Model name passed to ``create``.

    Returns:
        The result of ``self._check_model`` for ``AI21Client``, otherwise
        ``model`` unchanged.
    """
    is_studio_client = self._client.__class__.__name__ == "AI21Client"

    # NOTE(review): matching on the class name means subclasses of AI21Client
    # skip the check; confirm that is intended.
    return self._check_model(model=model) if is_studio_client else model
15 changes: 8 additions & 7 deletions ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

from typing import List, Dict, Optional
from typing import Dict, List

from ai21.clients.common.completion_base import Completion
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models import Penalty, CompletionsResponse
from ai21.clients.studio.resources.studio_resource import (
AsyncStudioResource,
StudioResource,
)
from ai21.models import CompletionsResponse, Penalty
from ai21.types import NOT_GIVEN, NotGiven


class StudioCompletion(StudioResource, Completion):
def create(
self,
prompt: str,
model: str,
*,
model: Optional[str] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
Expand All @@ -28,7 +31,6 @@ def create(
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
model = self._get_model(model=model, model_id=kwargs.pop("model_id", None))
path = self._get_completion_path(model=model)
body = self._create_body(
model=model,
Expand All @@ -54,8 +56,8 @@ class AsyncStudioCompletion(AsyncStudioResource, Completion):
async def create(
self,
prompt: str,
model: str,
*,
model: Optional[str] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
Expand All @@ -70,7 +72,6 @@ async def create(
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
model = self._get_model(model=model, model_id=kwargs.pop("model_id", None))
path = self._get_completion_path(model=model)
body = self._create_body(
model=model,
Expand Down
41 changes: 0 additions & 41 deletions examples/studio/async_chat.py

This file was deleted.

Loading
Loading