diff --git a/README.md b/README.md index b3bbeb80..f61214e4 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ messages = [ chat_completions = client.chat.completions.create( messages=messages, - model="jamba-1.5-mini", + model="jamba-1.6-mini-2025-03", ) ``` @@ -208,7 +208,7 @@ client = AsyncAI21Client( async def main(): response = await client.chat.completions.create( messages=messages, - model="jamba-1.5-mini", + model="jamba-1.6-mini-2025-03", ) print(response) @@ -346,7 +346,7 @@ client = AsyncAI21Client() async def main(): response = await client.chat.completions.create( messages=messages, - model="jamba-1.5-mini", + model="jamba-1.6-mini-2025-03", stream=True, ) async for chunk in response: @@ -705,7 +705,7 @@ messages = [ ] response = client.chat.completions.create( - model="jamba-1.5-mini", + model="jamba-1.6-mini-2025-03", messages=messages, ) ``` diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index 0eded54f..d2cf5cfe 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -1,7 +1,6 @@ import json -import logging -import warnings -from typing import Optional, Dict, Any + +from typing import Any, Dict, Optional import boto3 import httpx @@ -10,15 +9,16 @@ from ai21.ai21_env_config import AI21EnvConfig from ai21.clients.aws.aws_authorization import AWSAuthorization from ai21.clients.bedrock._stream_decoder import _AWSEventStreamDecoder -from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat -from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion -from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException +from ai21.clients.studio.resources.studio_chat import AsyncStudioChat, StudioChat +from ai21.clients.studio.resources.studio_completion import ( + AsyncStudioCompletion, + StudioCompletion, +) +from ai21.errors import AccessDenied, APITimeoutError, ModelErrorException, NotFound from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient from ai21.models.request_options import RequestOptions -_logger = logging.getLogger(__name__) - _BEDROCK_URL_FORMAT = "https://bedrock-runtime.{region}.amazonaws.com" @@ -76,7 +76,6 @@ def _prepare_headers(self, url: str, body: Dict[str, Any]) -> dict: class AI21BedrockClient(AI21HTTPClient, BaseBedrockClient): def __init__( self, - model_id: Optional[str] = None, base_url: Optional[str] = None, region: Optional[str] = None, headers: Optional[Dict[str, Any]] = None, @@ -85,12 +84,6 @@ def __init__( session: Optional[boto3.Session] = None, http_client: Optional[httpx.Client] = None, ): - if model_id is not None: - warnings.warn( - "Please consider using the 'model' parameter in the " - "'create' method calls instead of the constructor.", - DeprecationWarning, - ) self._region = region or AI21EnvConfig.aws_region if base_url is None: base_url = _BEDROCK_URL_FORMAT.format(region=self._region) @@ -128,7 +121,6 @@ def _get_streaming_decoder(self) -> _AWSEventStreamDecoder: class AsyncAI21BedrockClient(AsyncAI21HTTPClient, BaseBedrockClient): def __init__( self, - model_id: Optional[str] = None, base_url: Optional[str] = None, region: Optional[str] = None, headers: Optional[Dict[str, Any]] = None, @@ -137,12 +129,6 @@ def __init__( session: Optional[boto3.Session] = None, http_client: Optional[httpx.AsyncClient] = None, ): - if model_id is not None: - warnings.warn( - "Please consider using the 'model' parameter in the " - "'create' method calls instead of the constructor.", - DeprecationWarning, - ) self._region = region or AI21EnvConfig.aws_region if base_url is None: base_url = _BEDROCK_URL_FORMAT.format(region=self._region) diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index a57c9e1f..2e39e432 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -1,35 +1,17 @@ from __future__ import annotations -import warnings from abc import ABC, abstractmethod -from typing import List, Dict, Optional +from typing import Dict, List -from ai21.models import Penalty, CompletionsResponse +from ai21.models import CompletionsResponse, Penalty +from ai21.models._pydantic_compatibility import _to_dict from ai21.types import NOT_GIVEN, NotGiven from ai21.utils.typing import remove_not_given -from ai21.models._pydantic_compatibility import _to_dict class Completion(ABC): _module_name = "complete" - def _get_model(self, model: Optional[str], model_id: Optional[str]) -> str: - if model_id is not None: - warnings.warn( - "The 'model_id' parameter is deprecated and will be removed in a future version." - " Please use 'model' instead.", - DeprecationWarning, - stacklevel=2, - ) - - if model_id and model: - raise ValueError("Please provide only 'model' as 'model_id' is deprecated.") - - if not model and not model_id: - raise ValueError("model should be provided 'create' method call") - - return model or model_id - @abstractmethod def create( self, diff --git a/ai21/clients/studio/resources/chat/async_chat_completions.py b/ai21/clients/studio/resources/chat/async_chat_completions.py index e0df9e4b..0a820e1d 100644 --- a/ai21/clients/studio/resources/chat/async_chat_completions.py +++ b/ai21/clients/studio/resources/chat/async_chat_completions.py @@ -1,17 +1,18 @@ from __future__ import annotations -from typing import List, Optional, Any, Literal, overload +from typing import Any, List, Literal, Optional, overload -from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions +from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.models import ChatMessage as J2ChatMessage -from ai21.models.chat import ChatCompletionResponse, ChatCompletionChunk +from ai21.models.chat import ChatCompletionChunk, ChatCompletionResponse from ai21.models.chat.chat_message import ChatMessageParam from ai21.models.chat.document_schema import DocumentSchema from ai21.models.chat.response_format import ResponseFormat from ai21.models.chat.tool_defintions import ToolDefinition from ai21.stream.async_stream import AsyncStream -from ai21.types import NotGiven, NOT_GIVEN +from ai21.types import NOT_GIVEN, NotGiven + __all__ = ["AsyncChatCompletions"] @@ -74,8 +75,10 @@ async def create( " instead of ai21.models when working with chat completions." ) + model = self._get_model(model=model) + body = self._create_body( - model=self._get_model(model=model, model_id=kwargs.pop("model_id", None)), + model=model, messages=messages, stop=stop, temperature=temperature, @@ -96,3 +99,9 @@ async def create( stream_cls=AsyncStream[ChatCompletionChunk], response_cls=ChatCompletionResponse, ) + + def _get_model(self, model: str) -> str: + if self._client.__class__.__name__ == "AsyncAI21Client": + return self._check_model(model=model) + + return model diff --git a/ai21/clients/studio/resources/chat/base_chat_completions.py b/ai21/clients/studio/resources/chat/base_chat_completions.py index 5f719f7b..7a57092d 100644 --- a/ai21/clients/studio/resources/chat/base_chat_completions.py +++ b/ai21/clients/studio/resources/chat/base_chat_completions.py @@ -1,37 +1,41 @@ from __future__ import annotations import warnings + from abc import ABC -from typing import List, Optional, Union, Any, Dict, Literal +from typing import Any, Dict, List, Literal, Optional, Union +from ai21.models._pydantic_compatibility import _to_dict from ai21.models.chat.chat_message import ChatMessageParam from ai21.models.chat.document_schema import DocumentSchema from ai21.models.chat.response_format import ResponseFormat from ai21.models.chat.tool_defintions import ToolDefinition from ai21.types import NotGiven from ai21.utils.typing import remove_not_given -from ai21.models._pydantic_compatibility import _to_dict + + +_MODEL_DEPRECATION_WARNING = """ +The 'jamba-1.5-mini' and 'jamba-1.5-large' models are deprecated and will +be removed in a future version. +Please use jamba-mini-1.6-2025-03 or jamba-large-1.6-2025-03 instead. +""" class BaseChatCompletions(ABC): _module_name = "chat/completions" - def _get_model(self, model: Optional[str], model_id: Optional[str]) -> str: - if model_id is not None: + def _check_model(self, model: Optional[str]) -> str: + if not model: + raise ValueError("model should be provided 'create' method call") + + if model in ["jamba-1.5-mini", "jamba-1.5-large"]: warnings.warn( - "The 'model_id' parameter is deprecated and will be removed in a future version." - " Please use 'model' instead.", + _MODEL_DEPRECATION_WARNING, DeprecationWarning, - stacklevel=2, + stacklevel=3, ) - if model_id and model: - raise ValueError("Please provide only 'model' as 'model_id' is deprecated.") - - if not model and not model_id: - raise ValueError("model should be provided 'create' method call") - - return model or model_id + return model def _create_body( self, diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index 7352a85b..1fd5e363 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -1,17 +1,18 @@ from __future__ import annotations -from typing import List, Optional, Any, Literal, overload +from typing import Any, List, Literal, Optional, overload -from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions +from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import ChatMessage as J2ChatMessage -from ai21.models.chat import ChatCompletionResponse, ChatCompletionChunk +from ai21.models.chat import ChatCompletionChunk, ChatCompletionResponse from ai21.models.chat.chat_message import ChatMessageParam from ai21.models.chat.document_schema import DocumentSchema from ai21.models.chat.response_format import ResponseFormat from ai21.models.chat.tool_defintions import ToolDefinition from ai21.stream.stream import Stream -from ai21.types import NotGiven, NOT_GIVEN +from ai21.types import NOT_GIVEN, NotGiven + __all__ = ["ChatCompletions"] @@ -56,7 +57,7 @@ def create( def create( self, messages: List[ChatMessageParam], - model: Optional[str] = None, + model: str, max_tokens: int | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, top_p: float | NotGiven = NOT_GIVEN, @@ -74,8 +75,10 @@ def create( " instead of ai21.models when working with chat completions." ) + model = self._get_model(model=model) + body = self._create_body( - model=self._get_model(model=model, model_id=kwargs.pop("model_id", None)), + model=model, messages=messages, stop=stop, temperature=temperature, @@ -96,3 +99,9 @@ def create( stream_cls=Stream[ChatCompletionChunk], response_cls=ChatCompletionResponse, ) + + def _get_model(self, model: str) -> str: + if self._client.__class__.__name__ == "AI21Client": + return self._check_model(model=model) + + return model diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index db1930b5..82f3acf6 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import List, Dict, Optional +from typing import Dict, List from ai21.clients.common.completion_base import Completion -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import Penalty, CompletionsResponse +from ai21.clients.studio.resources.studio_resource import ( + AsyncStudioResource, + StudioResource, +) +from ai21.models import CompletionsResponse, Penalty from ai21.types import NOT_GIVEN, NotGiven @@ -12,8 +15,8 @@ class StudioCompletion(StudioResource, Completion): def create( self, prompt: str, + model: str, *, - model: Optional[str] = None, max_tokens: int | NotGiven = NOT_GIVEN, num_results: int | NotGiven = NOT_GIVEN, min_tokens: int | NotGiven = NOT_GIVEN, @@ -28,7 +31,6 @@ def create( logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: - model = self._get_model(model=model, model_id=kwargs.pop("model_id", None)) path = self._get_completion_path(model=model) body = self._create_body( model=model, @@ -54,8 +56,8 @@ class AsyncStudioCompletion(AsyncStudioResource, Completion): async def create( self, prompt: str, + model: str, *, - model: Optional[str] = None, max_tokens: int | NotGiven = NOT_GIVEN, num_results: int | NotGiven = NOT_GIVEN, min_tokens: int | NotGiven = NOT_GIVEN, @@ -70,7 +72,6 @@ async def create( logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: - model = self._get_model(model=model, model_id=kwargs.pop("model_id", None)) path = self._get_completion_path(model=model) body = self._create_body( model=model, diff --git a/examples/studio/async_chat.py b/examples/studio/async_chat.py deleted file mode 100644 index fc1c35e5..00000000 --- a/examples/studio/async_chat.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -This examples uses a deprecated method client.chat.create and instead -should be replaced with the `client.chat.completions.create` -""" - -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import RoleType, Penalty -from ai21.models import ChatMessage - -system = "You're a support engineer in a SaaS company" -messages = [ - ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), - ChatMessage(text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), - ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), -] - - -client = AsyncAI21Client() - - -async def main(): - response = await client.chat.create( - system=system, - messages=messages, - model="j2-ultra", - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - ) - - print(response) - - -asyncio.run(main()) diff --git a/examples/studio/chat.py b/examples/studio/chat.py deleted file mode 100644 index 6b62a849..00000000 --- a/examples/studio/chat.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -This examples uses a deprecated method client.chat.create and instead -should be replaced with the `client.chat.completions.create` -""" - -from ai21 import AI21Client -from ai21.models import RoleType, Penalty -from ai21.models import ChatMessage - -system = "You're a support engineer in a SaaS company" -messages = [ - ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), - ChatMessage(text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), - ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), -] - - -client = AI21Client() -response = client.chat.create( - system=system, - messages=messages, - model="j2-ultra", - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), -) - -print(response) diff --git a/examples/studio/chat/async_chat_completions.py b/examples/studio/chat/async_chat_completions.py index d7bbb2df..f8bae07d 100644 --- a/examples/studio/chat/async_chat_completions.py +++ b/examples/studio/chat/async_chat_completions.py @@ -18,7 +18,7 @@ async def main(): response = await client.chat.completions.create( messages=messages, - model="jamba-1.5-mini", + model="jamba-mini-1.6-2025-03", max_tokens=100, temperature=0.7, top_p=1.0, diff --git a/examples/studio/chat/async_stream_chat_completions.py b/examples/studio/chat/async_stream_chat_completions.py index 994205dc..f5653339 100644 --- a/examples/studio/chat/async_stream_chat_completions.py +++ b/examples/studio/chat/async_stream_chat_completions.py @@ -3,6 +3,7 @@ from ai21 import AsyncAI21Client from ai21.models.chat import ChatMessage + system = "You're a support engineer in a SaaS company" messages = [ ChatMessage(content=system, role="system"), @@ -17,7 +18,7 @@ async def main(): response = await client.chat.completions.create( messages=messages, - model="jamba-1.5-large", + model="jamba-mini-1.6-2025-03", max_tokens=100, stream=True, ) diff --git a/examples/studio/chat/chat_completions.py b/examples/studio/chat/chat_completions.py index 727312bd..fee17adf 100644 --- a/examples/studio/chat/chat_completions.py +++ b/examples/studio/chat/chat_completions.py @@ -1,5 +1,6 @@ from ai21 import AI21Client -from ai21.models.chat.chat_message import SystemMessage, UserMessage, AssistantMessage +from ai21.models.chat.chat_message import AssistantMessage, SystemMessage, UserMessage + system = "You're a support engineer in a SaaS company" messages = [ @@ -13,7 +14,7 @@ response = client.chat.completions.create( messages=messages, - model="jamba-1.5-mini", + model="jamba-mini-1.6-2025-03", max_tokens=100, temperature=0.7, top_p=1.0, diff --git a/examples/studio/chat/chat_completions_jamba_instruct.py b/examples/studio/chat/chat_completions_jamba_instruct.py index af84d0bc..21c51845 100644 --- a/examples/studio/chat/chat_completions_jamba_instruct.py +++ b/examples/studio/chat/chat_completions_jamba_instruct.py @@ -1,5 +1,6 @@ from ai21 import AI21Client -from ai21.models.chat.chat_message import SystemMessage, UserMessage, AssistantMessage +from ai21.models.chat.chat_message import AssistantMessage, SystemMessage, UserMessage + system = "You're a support engineer in a SaaS company" messages = [ diff --git a/examples/studio/chat/chat_documents.py b/examples/studio/chat/chat_documents.py index ca7135a7..d079c0cd 100644 --- a/examples/studio/chat/chat_documents.py +++ b/examples/studio/chat/chat_documents.py @@ -4,6 +4,7 @@ from ai21.logger import set_verbose from ai21.models.chat import ChatMessage, DocumentSchema + set_verbose(True) schnoodel = DocumentSchema( @@ -39,6 +40,10 @@ client = AI21Client() -response = client.chat.completions.create(messages=messages, model="jamba-1.5-mini", documents=documents) +response = client.chat.completions.create( + messages=messages, + model="jamba-mini-1.6-2025-03", + documents=documents, +) print(response) diff --git a/examples/studio/chat/chat_function_calling.py b/examples/studio/chat/chat_function_calling.py index e55feaf2..7cf61f99 100644 --- a/examples/studio/chat/chat_function_calling.py +++ b/examples/studio/chat/chat_function_calling.py @@ -46,7 +46,7 @@ def get_order_delivery_date(order_id: str) -> str: client = AI21Client() -response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools) +response = client.chat.completions.create(messages=messages, model="jamba-large-1.6-2025-03", tools=tools) """ AI models can be error-prone, it's crucial to ensure that the tool calls align with the expectations. The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function @@ -79,5 +79,5 @@ def get_order_delivery_date(order_id: str) -> str: tool_message = ToolMessage(role="tool", tool_call_id=tool_calls[0].id, content=delivery_date) messages.append(tool_message) - response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools) + response = client.chat.completions.create(messages=messages, model="jamba-large-1.6-2025-03", tools=tools) print(response.choices[0].message.content) diff --git a/examples/studio/chat/chat_function_calling_multiple_tools.py b/examples/studio/chat/chat_function_calling_multiple_tools.py index b61deeb2..cfdd7711 100644 --- a/examples/studio/chat/chat_function_calling_multiple_tools.py +++ b/examples/studio/chat/chat_function_calling_multiple_tools.py @@ -7,6 +7,7 @@ from ai21.models.chat.tool_defintions import ToolDefinition from ai21.models.chat.tool_parameters import ToolParameters + set_verbose(True) @@ -74,7 +75,7 @@ def get_sunset_hour(place: str, date: str) -> str: client = AI21Client() -response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools) +response = client.chat.completions.create(messages=messages, model="jamba-large-1.6-2025-03", tools=tools) """ AI models can be error-prone, it's crucial to ensure that the tool calls align with the expectations. The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function @@ -122,5 +123,5 @@ def get_sunset_hour(place: str, date: str) -> str: tool_message = ToolMessage(role="tool", tool_call_id=tool_id_called, content=str(result)) messages.append(tool_message) - response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools) + response = client.chat.completions.create(messages=messages, model="jamba-large-1.6-2025-03", tools=tools) print(response.choices[0].message.content) diff --git a/examples/studio/chat/chat_response_format.py b/examples/studio/chat/chat_response_format.py index 2c51be55..cc744303 100644 --- a/examples/studio/chat/chat_response_format.py +++ b/examples/studio/chat/chat_response_format.py @@ -1,4 +1,5 @@ import json + from enum import Enum from pydantic import BaseModel @@ -7,6 +8,7 @@ from ai21.logger import set_verbose from ai21.models.chat import ChatMessage, ResponseFormat + set_verbose(True) @@ -37,7 +39,7 @@ class ZooTicketsOrder(BaseModel): response = client.chat.completions.create( messages=messages, - model="jamba-1.5-large", + model="jamba-large-1.6-2025-03", max_tokens=800, temperature=0, response_format=ResponseFormat(type="json_object"), diff --git a/examples/studio/chat/stream_chat_completions.py b/examples/studio/chat/stream_chat_completions.py index 415b3260..6f4b42e5 100644 --- a/examples/studio/chat/stream_chat_completions.py +++ b/examples/studio/chat/stream_chat_completions.py @@ -1,6 +1,7 @@ from ai21 import AI21Client from ai21.models.chat import ChatMessage + system = "You're a support engineer in a SaaS company" messages = [ ChatMessage(content=system, role="system"), @@ -13,7 +14,7 @@ response = client.chat.completions.create( messages=messages, - model="jamba-1.5-large", + model="jamba-large-1.6-2025-03", max_tokens=100, stream=True, ) diff --git a/tests/integration_tests/clients/bedrock/test_chat_completions.py b/tests/integration_tests/clients/bedrock/test_chat_completions.py index c067f34a..00b4f2e3 100644 --- a/tests/integration_tests/clients/bedrock/test_chat_completions.py +++ b/tests/integration_tests/clients/bedrock/test_chat_completions.py @@ -4,6 +4,7 @@ from ai21.models._pydantic_compatibility import _to_dict from ai21.models.chat import ChatMessage + _SYSTEM_MSG = "You're a support engineer in a SaaS company" _MESSAGES = [ ChatMessage(content=_SYSTEM_MSG, role="system"), @@ -15,7 +16,7 @@ def test_chat_completions__when_stream__last_chunk_should_hold_bedrock_metrics() client = AI21BedrockClient() response = client.chat.completions.create( messages=_MESSAGES, - model=BedrockModelID.JAMBA_INSTRUCT_V1, + model=BedrockModelID.JAMBA_1_5_MINI, stream=True, ) @@ -29,7 +30,7 @@ async def test__async_chat_completions__when_stream__last_chunk_should_hold_bedr client = AsyncAI21BedrockClient() response = await client.chat.completions.create( messages=_MESSAGES, - model=BedrockModelID.JAMBA_INSTRUCT_V1, + model=BedrockModelID.JAMBA_1_5_MINI, stream=True, ) diff --git a/tests/integration_tests/clients/bedrock/test_completion.py b/tests/integration_tests/clients/bedrock/test_completion.py index 8424267c..edf1bc00 100644 --- a/tests/integration_tests/clients/bedrock/test_completion.py +++ b/tests/integration_tests/clients/bedrock/test_completion.py @@ -7,6 +7,7 @@ from ai21.models import Penalty from tests.integration_tests.skip_helpers import should_skip_bedrock_integration_tests + _PROMPT = "Once upon a time, in a land far, far away, there was a" @@ -43,7 +44,7 @@ def test_completion_penalties__should_return_response( completion_args = dict( prompt=_PROMPT, max_tokens=64, - model_id=BedrockModelID.J2_ULTRA_V1, + model=BedrockModelID.J2_ULTRA_V1, temperature=0, top_p=1, top_k_return=0, @@ -68,15 +69,6 @@ def test_completion_penalties__should_return_response( assert isinstance(completion.data.text, str) -@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.") -def test_completion__when_no_model_id__should_raise_exception(): - with pytest.raises(ValueError) as e: - client = AI21BedrockClient() - client.completion.create(prompt=_PROMPT) - - assert e.value.args[0] == "model should be provided 'create' method call" - - @pytest.mark.asyncio @pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.") @pytest.mark.parametrize( @@ -111,7 +103,7 @@ async def test_async_completion_penalties__should_return_response( completion_args = dict( prompt=_PROMPT, max_tokens=64, - model_id=BedrockModelID.J2_ULTRA_V1, + model=BedrockModelID.J2_ULTRA_V1, temperature=0, top_p=1, top_k_return=0, @@ -134,13 +126,3 @@ async def test_async_completion_penalties__should_return_response( assert len([completion.data.text for completion in response.completions]) == 1 for completion in response.completions: assert isinstance(completion.data.text, str) - - -@pytest.mark.asyncio -@pytest.mark.skipif(should_skip_bedrock_integration_tests(), reason="No keys supplied for AWS. Skipping.") -async def test_async_completion__when_no_model_id__should_raise_exception(): - with pytest.raises(ValueError) as e: - client = AsyncAI21BedrockClient() - await client.completion.create(prompt=_PROMPT) - - assert e.value.args[0] == "model should be provided 'create' method call" diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 18df3091..b10b4cc7 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -1,14 +1,22 @@ import json + from unittest.mock import patch -import pytest import httpx +import pytest from ai21 import AI21Client, AsyncAI21Client -from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk, ChoicesChunk, ChoiceDelta from ai21.models import RoleType +from ai21.models.chat import ( + ChatCompletionChunk, + ChatCompletionResponse, + ChatMessage, + ChoiceDelta, + ChoicesChunk, +) + -_MODEL = "jamba-instruct-preview" +_MODEL = "jamba-mini-1.6-2025-03" _MESSAGES = [ ChatMessage( content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", diff --git a/tests/unittests/clients/bedrock/test_chat_completions.py b/tests/unittests/clients/bedrock/test_chat_completions.py index 5ceec306..6f669c4a 100644 --- a/tests/unittests/clients/bedrock/test_chat_completions.py +++ b/tests/unittests/clients/bedrock/test_chat_completions.py @@ -1,19 +1,26 @@ import json -from typing import Optional, Union -from unittest.mock import Mock, patch, ANY + +from typing import Optional +from unittest.mock import ANY, Mock, patch import httpx import pytest + from pytest_mock import MockerFixture from ai21 import AI21EnvConfig from ai21.clients.aws.aws_authorization import AWSAuthorization -from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient, AsyncAI21BedrockClient +from ai21.clients.bedrock.ai21_bedrock_client import ( + AI21BedrockClient, + AsyncAI21BedrockClient, +) from ai21.clients.bedrock.bedrock_model_id import BedrockModelID -from ai21.models.chat import ChatMessage -from tests.unittests.commons import FAKE_CHAT_COMPLETION_RESPONSE_DICT, FAKE_AUTH_HEADERS - from ai21.models._pydantic_compatibility import _to_dict +from ai21.models.chat import ChatMessage +from tests.unittests.commons import ( + FAKE_AUTH_HEADERS, + FAKE_CHAT_COMPLETION_RESPONSE_DICT, +) _FAKE_RESPONSE_DICT = { @@ -115,39 +122,3 @@ async def test__options_in_async_request(mock_async_httpx_client: Mock): data=json.dumps({"messages": [_to_dict(message)]}).encode("utf-8"), files=None, ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_sync_client__model_and_model_id__should_raise_error", - "when_sync_client__no_model_and_no_model_id__should_raise_error", - "when_async_client__no_model_and_no_model_id__should_raise_error", - "when_async_client__no_model_and_no_model_id__should_raise_error", - ], - argvalues=[ - (BedrockModelID.JAMBA_INSTRUCT_V1, BedrockModelID.JAMBA_INSTRUCT_V1, AI21BedrockClient()), - (None, None, AI21BedrockClient()), - (BedrockModelID.JAMBA_INSTRUCT_V1, BedrockModelID.JAMBA_INSTRUCT_V1, AsyncAI21BedrockClient()), - (None, None, AsyncAI21BedrockClient()), - ], - argnames=["model", "model_id", "client"], -) -async def test_model_id_and_model_supported_params( - model: Optional[BedrockModelID], - model_id: Optional[BedrockModelID], - client: Union[AI21BedrockClient, AsyncAI21BedrockClient], -): - with pytest.raises(ValueError): - if isinstance(client, AsyncAI21BedrockClient): - await client.chat.completions.create( - model=model, - messages=[ChatMessage(content="This is a test", role="user")], - model_id=model_id, - ) - else: - client.chat.completions.create( - model=model, - messages=[ChatMessage(content="This is a test", role="user")], - model_id=model_id, - ) diff --git a/tests/unittests/clients/studio/resources/chat/test_chat_completions.py b/tests/unittests/clients/studio/resources/chat/test_chat_completions.py index 657dedc0..67a0d4a7 100644 --- a/tests/unittests/clients/studio/resources/chat/test_chat_completions.py +++ b/tests/unittests/clients/studio/resources/chat/test_chat_completions.py @@ -1,19 +1,22 @@ import uuid + from unittest.mock import AsyncMock import httpx import pytest + from pytest_mock import MockerFixture from ai21 import AI21Client, AsyncAI21Client from ai21.models import ChatMessage, RoleType from ai21.models.chat import ChatCompletionResponse, ResponseFormat -from ai21.models.chat.chat_message import UserMessage, SystemMessage, AssistantMessage +from ai21.models.chat.chat_message import AssistantMessage, SystemMessage, UserMessage from ai21.models.chat.document_schema import DocumentSchema from ai21.models.chat.function_tool_definition import FunctionToolDefinition from ai21.models.chat.tool_defintions import ToolDefinition from ai21.models.chat.tool_parameters import ToolParameters + _FAKE_API_KEY = "dummy_api_key" @@ -46,18 +49,6 @@ async def test_async_chat_create__when_bad_import_to_chat_message__raise_error() ) -def test__when_model_and_model_id__raise_error(): - client = AI21Client( - api_key=_FAKE_API_KEY, - ) - with pytest.raises(ValueError): - client.chat.completions.create( - model="jamba-1.5", - model_id="jamba-instruct", - messages=[ChatMessage(role=RoleType.USER, text="Hello")], - ) - - # ----------------------------------- Basic Happy Flow: ----------------------------------- # _FAKE_BASIC_HAPPY_FLOW_RESPONSE_JSON = { @@ -81,7 +72,7 @@ def test_chat_completion_basic_happy_flow(mocker: MockerFixture) -> None: mocked_client.send.return_value = httpx.Response(status_code=200, json=_FAKE_BASIC_HAPPY_FLOW_RESPONSE_JSON) client = AI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response: ChatCompletionResponse = client.chat.completions.create( - model="jamba-1.5-mini", messages=[UserMessage(role="user", content="Hello")] + model="jamba-mini-1.6-2025-03", messages=[UserMessage(role="user", content="Hello")] ) assert response.choices[0].message.content == _FAKE_BASIC_HAPPY_FLOW_EXPECTED_CONTENT @@ -94,7 +85,7 @@ async def test_async_chat_completion_basic_happy_flow(mocker: MockerFixture) -> ) async_client = AsyncAI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response: ChatCompletionResponse = await async_client.chat.completions.create( - model="jamba-1.5-mini", messages=[UserMessage(role="user", content="Hello")] + model="jamba-mini-1.6-2025-03", messages=[UserMessage(role="user", content="Hello")] ) assert response.choices[0].message.content == _FAKE_BASIC_HAPPY_FLOW_EXPECTED_CONTENT @@ -158,7 +149,7 @@ def test_chat_completion_with_tool_calls_happy_flow(mocker: MockerFixture) -> No mocked_client.send.return_value = httpx.Response(status_code=200, json=_FAKE_TOOL_CALLS_HAPPY_FLOW_RESPONSE_JSON) client = AI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response = client.chat.completions.create( - model="jamba-1.5-mini", messages=_FAKE_TOOL_CALL_TEST_MESSAGES, tools=_FAKE_TOOL_CALL_TEST_TOOLS + model="jamba-mini-1.6-2025-03", messages=_FAKE_TOOL_CALL_TEST_MESSAGES, tools=_FAKE_TOOL_CALL_TEST_TOOLS ) assert response.choices[0].message.tool_calls[0].function.name == _FAKE_TOOL_CALL_TEST_EXPECTED_FUNCTION_NAME assert ( @@ -174,7 +165,7 @@ async def test_async_chat_completion_with_tool_calls_happy_flow(mocker: MockerFi ) async_client = AsyncAI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response = await async_client.chat.completions.create( - model="jamba-1.5-mini", messages=_FAKE_TOOL_CALL_TEST_MESSAGES, tools=_FAKE_TOOL_CALL_TEST_TOOLS + model="jamba-mini-1.6-2025-03", messages=_FAKE_TOOL_CALL_TEST_MESSAGES, tools=_FAKE_TOOL_CALL_TEST_TOOLS ) assert response.choices[0].message.tool_calls[0].function.name == _FAKE_TOOL_CALL_TEST_EXPECTED_FUNCTION_NAME assert ( @@ -240,7 +231,7 @@ def test_chat_completion_with_documents_happy_flow(mocker: MockerFixture) -> Non mocked_client.send.return_value = httpx.Response(status_code=200, json=_FAKE_DOCUMENTS_TEST_RESPONSE_JSON) client = AI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response = client.chat.completions.create( - model="jamba-1.5-mini", messages=_FAKE_DOCUMENTS_TEST_MESSAGES, documents=_FAKE_DOCUMENTS_TEST_DOCUMENTS + model="jamba-mini-1.6-2025-03", messages=_FAKE_DOCUMENTS_TEST_MESSAGES, documents=_FAKE_DOCUMENTS_TEST_DOCUMENTS ) assert response.choices[0].message.content == _FAKE_DOCUMENTS_TEST_EXPECTED_CONTENT @@ -253,7 +244,7 @@ async def test_async_chat_completion_with_documents_happy_flow(mocker: MockerFix ) async_client = AsyncAI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response = await async_client.chat.completions.create( - model="jamba-1.5-mini", messages=_FAKE_DOCUMENTS_TEST_MESSAGES, documents=_FAKE_DOCUMENTS_TEST_DOCUMENTS + model="jamba-mini-1.6-2025-03", messages=_FAKE_DOCUMENTS_TEST_MESSAGES, documents=_FAKE_DOCUMENTS_TEST_DOCUMENTS ) assert response.choices[0].message.content == _FAKE_DOCUMENTS_TEST_EXPECTED_CONTENT @@ -310,7 +301,7 @@ def test_chat_completion_response_format_json_happy_flow(mocker: MockerFixture) ) client = AI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response = client.chat.completions.create( - model="jamba-1.5-mini", + model="jamba-mini-1.6-2025-03", messages=_FAKE_RESPONSE_FORMAT_JSON_TEST_MESSAGES, response_format=ResponseFormat(type="json_object"), ) @@ -325,7 +316,7 @@ async def test_async_chat_completion_response_format_json_happy_flow(mocker: Moc ) async_client = AsyncAI21Client(api_key=_FAKE_API_KEY, http_client=mocked_client) response = await async_client.chat.completions.create( - model="jamba-1.5-mini", + model="jamba-mini-1.6-2025-03", messages=_FAKE_RESPONSE_FORMAT_JSON_TEST_MESSAGES, response_format=ResponseFormat(type="json_object"), )