diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index dee9fc4d..809fcc3c 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Any, Dict, Optional +from ai21.clients.studio.resources.chat import ChatCompletions from ai21.models import Penalty, ChatResponse, ChatMessage @@ -47,6 +48,11 @@ def create( """ pass + @property + @abstractmethod + def completions(self) -> ChatCompletions: + pass + def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: return ChatResponse.from_dict(json) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index dfb781b5..8e5979d0 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -50,7 +50,7 @@ def __init__( http_client=http_client, ) self.completion = StudioCompletion(self._http_client) - self.chat = StudioChat(self._http_client) + self.chat: StudioChat = StudioChat(self._http_client) self.summarize = StudioSummarize(self._http_client) self.embed = StudioEmbed(self._http_client) self.gec = StudioGEC(self._http_client) diff --git a/ai21/clients/studio/resources/chat/__init__.py b/ai21/clients/studio/resources/chat/__init__.py new file mode 100644 index 00000000..5c9e2645 --- /dev/null +++ b/ai21/clients/studio/resources/chat/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .chat_completions import ChatCompletions as ChatCompletions diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py new file mode 100644 index 00000000..791232e0 --- /dev/null +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import List, Optional, Union, Any, Dict + +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models.chat import ChatMessage, 
ChatCompletionResponse +from ai21.types import NotGiven, NOT_GIVEN +from ai21.utils.typing import remove_not_given + +__all__ = ["ChatCompletions"] + + +class ChatCompletions(StudioResource): + _module_name = "chat/complete" + + def create( + self, + model: str, + messages: List[ChatMessage], + n: int | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + stop: str | List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: float | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> ChatCompletionResponse: + body = self._create_body( + model=model, + messages=messages, + n=n, + logprobs=logprobs, + top_logprobs=top_logprobs, + stop=stop, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + **kwargs, + ) + + url = f"{self._client.get_base_url()}/{self._module_name}" + response = self._post(url=url, body=body) + return self._json_to_response(response) + + def _create_body( + self, + model: str, + messages: List[ChatMessage], + n: Optional[int] | NotGiven, + logprobs: Optional[bool] | NotGiven, + top_logprobs: Optional[int] | NotGiven, + max_tokens: Optional[int] | NotGiven, + temperature: Optional[float] | NotGiven, + top_p: Optional[float] | NotGiven, + stop: Optional[Union[str, List[str]]] | NotGiven, + frequency_penalty: Optional[float] | NotGiven, + presence_penalty: Optional[float] | NotGiven, + **kwargs: Any, + ) -> Dict[str, Any]: + return remove_not_given( + { + "model": model, + "messages": [message.to_dict() for message in messages], + "temperature": temperature, + "maxTokens": max_tokens, + "n": n, + "topP": top_p, + "logprobs": logprobs, + "topLogprobs": top_logprobs, + "stop": stop, + "frequencyPenalty": frequency_penalty, + "presencePenalty": 
presence_penalty, + **kwargs, + } + ) + + def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse: + return ChatCompletionResponse.from_dict(json) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f37d9d74..70f8b01f 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,6 +1,7 @@ from typing import List, Optional from ai21.clients.common.chat_base import Chat +from ai21.clients.studio.resources.chat import ChatCompletions from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import ChatMessage, Penalty, ChatResponse @@ -42,3 +43,7 @@ def create( url = f"{self._client.get_base_url()}/{model}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) + + @property + def completions(self) -> ChatCompletions: + return ChatCompletions(self._client) diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index 174215f5..cc851fc9 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -1,3 +1,4 @@ +from ai21.models.chat.role_type import RoleType from ai21.models.chat_message import ChatMessage from ai21.models.document_type import DocumentType from ai21.models.embed_type import EmbedType @@ -25,10 +26,8 @@ from ai21.models.responses.segmentation_response import SegmentationResponse from ai21.models.responses.summarize_by_segment_response import SummarizeBySegmentResponse, SegmentSummary, Highlight from ai21.models.responses.summarize_response import SummarizeResponse -from ai21.models.role_type import RoleType from ai21.models.summary_method import SummaryMethod - __all__ = [ "ChatMessage", "RoleType", diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py new file mode 100644 index 00000000..0fb4df66 --- /dev/null +++ b/ai21/models/chat/__init__.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from 
.chat_completion_response import ChatCompletionResponse +from .chat_completion_response import ChatCompletionResponseChoice +from .chat_message import ChatMessage +from .role_type import RoleType as RoleType + +__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice", "ChatMessage", "RoleType"] diff --git a/ai21/models/chat/chat_completion_response.py b/ai21/models/chat/chat_completion_response.py new file mode 100644 index 00000000..b4551e6b --- /dev/null +++ b/ai21/models/chat/chat_completion_response.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional, List + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.models.logprobs import Logprobs +from ai21.models.usage_info import UsageInfo +from .chat_message import ChatMessage + + +@dataclass +class ChatCompletionResponseChoice(AI21BaseModelMixin): + index: int + message: ChatMessage + logprobs: Optional[Logprobs] = None + finish_reason: Optional[str] = None + + +@dataclass +class ChatCompletionResponse(AI21BaseModelMixin): + id: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py new file mode 100644 index 00000000..f3f4a437 --- /dev/null +++ b/ai21/models/chat/chat_message.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from .role_type import RoleType + + +@dataclass +class ChatMessage(AI21BaseModelMixin): + role: RoleType + content: str diff --git a/ai21/models/role_type.py b/ai21/models/chat/role_type.py similarity index 100% rename from ai21/models/role_type.py rename to ai21/models/chat/role_type.py diff --git a/ai21/models/chat_message.py b/ai21/models/chat_message.py index c7536a77..788affee 100644 --- a/ai21/models/chat_message.py +++ b/ai21/models/chat_message.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from ai21.models.ai21_base_model_mixin import 
AI21BaseModelMixin -from ai21.models.role_type import RoleType +from ai21.models.chat.role_type import RoleType @dataclass diff --git a/ai21/models/logprobs.py b/ai21/models/logprobs.py new file mode 100644 index 00000000..a2bd8d3b --- /dev/null +++ b/ai21/models/logprobs.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import List + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class TopTokenData(AI21BaseModelMixin): + token: str + logprob: float + + +@dataclass +class LogprobsData(AI21BaseModelMixin): + token: str + logprob: float + top_logprobs: List[TopTokenData] + + +@dataclass +class Logprobs(AI21BaseModelMixin): + content: LogprobsData diff --git a/ai21/models/responses/chat_response.py b/ai21/models/responses/chat_response.py index e1a15a9f..4ca4ad63 100644 --- a/ai21/models/responses/chat_response.py +++ b/ai21/models/responses/chat_response.py @@ -2,7 +2,7 @@ from typing import Optional, List from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin -from ai21.models.role_type import RoleType +from ai21.models.chat.role_type import RoleType @dataclass diff --git a/ai21/models/usage_info.py b/ai21/models/usage_info.py new file mode 100644 index 00000000..a6195cb0 --- /dev/null +++ b/ai21/models/usage_info.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class UsageInfo(AI21BaseModelMixin): + prompt_tokens: int + completion_tokens: int + total_tokens: int diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/__init__.py b/examples/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/chat.py b/examples/studio/chat.py index 516093d9..ff457b24 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -1,5 +1,11 @@ +""" +This example uses the deprecated method `client.chat.create`; new code 
+should use `client.chat.completions.create` instead. +""" + from ai21 import AI21Client -from ai21.models import ChatMessage, RoleType, Penalty +from ai21.models import RoleType, Penalty +from ai21.models import ChatMessage system = "You're a support engineer in a SaaS company" messages = [ @@ -8,11 +14,12 @@ ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), ] + client = AI21Client() response = client.chat.create( system=system, messages=messages, - model="j2-ultra", + model="j2-mid", count_penalty=Penalty( scale=0, apply_to_emojis=False, diff --git a/examples/studio/chat/__init__.py b/examples/studio/chat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/chat/chat_completions.py b/examples/studio/chat/chat_completions.py new file mode 100644 index 00000000..9248bdf6 --- /dev/null +++ b/examples/studio/chat/chat_completions.py @@ -0,0 +1,28 @@ +from ai21 import AI21Client +from ai21.models import RoleType +from ai21.models.chat import ChatMessage + +system = "You're a support engineer in a SaaS company" +messages = [ + ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(content="Hi Alice, I can help you with that. 
What seems to be the problem?", role=RoleType.ASSISTANT), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), +] + +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="new-model-name", + n=2, + logprobs=True, + top_logprobs=2, + max_tokens=100, + temperature=0.7, + top_p=1.0, + stop=["\n"], + frequency_penalty=0.1, + presence_penalty=0.1, +) + +print(response) diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py new file mode 100644 index 00000000..a98f79eb --- /dev/null +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -0,0 +1,39 @@ +import pytest + +from ai21 import AI21Client +from ai21.models.chat import ChatMessage +from ai21.models import RoleType +from ai21.models.chat.chat_completion_response import ChatCompletionResponse + + +_MODEL = "new-model-name" +_MESSAGES = [ + ChatMessage( + content="Hello, I need help studying for the coming test, can you teach me about the US constitution? 
", + role=RoleType.USER, + ), +] + + +# TODO: When the api is officially released, update the test to assert the actual response +@pytest.mark.skip(reason="API is not officially released") +def test_chat_completion(): + num_results = 5 + messages = _MESSAGES + + client = AI21Client() + response = client.chat.completions.create( + model=_MODEL, + messages=messages, + n=num_results, + max_tokens=64, + logprobs=True, + top_logprobs=2, + temperature=0.7, + stop=["\n"], + top_p=0.3, + frequency_penalty=0.2, + presence_penalty=0.4, + ) + + assert isinstance(response, ChatCompletionResponse) diff --git a/tests/integration_tests/clients/test_bedrock.py b/tests/integration_tests/clients/test_bedrock.py index 2e64efb0..02d0ee3e 100644 --- a/tests/integration_tests/clients/test_bedrock.py +++ b/tests/integration_tests/clients/test_bedrock.py @@ -1,6 +1,7 @@ """ Run this script after setting the environment variable called AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY """ + import subprocess from pathlib import Path diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 5a7f3f2c..88cf5ce1 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -2,6 +2,7 @@ from pytest_mock import MockerFixture from ai21.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.resources.chat import ChatCompletions from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -43,7 +44,13 @@ SummarizeBySegmentResponse, SegmentSummary, ) +from ai21.models.chat import ( + ChatMessage as ChatCompletionChatMessage, + ChatCompletionResponse, + ChatCompletionResponseChoice, +) from ai21.models.responses.segmentation_response import Segment +from ai21.models.usage_info import UsageInfo from 
ai21.utils.typing import to_lower_camel_case @@ -110,6 +117,45 @@ def get_studio_chat(): ) +def get_chat_completions(): + _DUMMY_MODEL = "dummy-chat-model" + _DUMMY_MESSAGES = [ + ChatCompletionChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), + ChatCompletionChatMessage( + content="Hi Alice, I can help you with that. What seems to be the problem?", + role=RoleType.ASSISTANT, + ), + ] + + return ( + ChatCompletions, + {"model": _DUMMY_MODEL, "messages": _DUMMY_MESSAGES}, + "chat/complete", + { + "model": _DUMMY_MODEL, + "messages": [message.to_dict() for message in _DUMMY_MESSAGES], + }, + ChatCompletionResponse( + id="some-id", + choices=[ + ChatCompletionResponseChoice( + index=0, + message=ChatCompletionChatMessage( + content="Hello, I need help with a signup process.", role=RoleType.USER + ), + finish_reason="dummy_reason", + logprobs=None, + ) + ], + usage=UsageInfo( + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + ), + ), + ) + + def get_studio_completion(**kwargs): _DUMMY_MODEL = "dummy-completion-model" _DUMMY_PROMPT = "dummy-prompt" diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 96fdc154..0c21d88c 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -17,6 +17,7 @@ get_studio_segmentation, get_studio_summarization, get_studio_summarize_by_segment, + get_chat_completions, ) _BASE_URL = "https://test.api.ai21.com/studio/v1" @@ -31,6 +32,7 @@ class TestStudioResources: ids=[ "studio_answer", "studio_chat", + "chat_completions", "studio_completion", "studio_completion_with_extra_args", "studio_embed", @@ -45,6 +47,7 @@ class TestStudioResources: argvalues=[ (get_studio_answer()), (get_studio_chat()), + (get_chat_completions()), (get_studio_completion()), (get_studio_completion(temperature=0.5, 
max_tokens=50)), (get_studio_embed()),