From debb61ab02b0a5d2590aff5a6f43910504776e88 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Mar 2024 17:37:52 +0200 Subject: [PATCH 01/15] feat: support chat completion in studio SDK --- ai21/clients/common/chat_base.py | 8 +- ai21/clients/studio/resources/studio_chat.py | 5 ++ .../resources/studio_chat_completion.py | 79 +++++++++++++++++++ ai21/models/__init__.py | 5 +- ai21/models/chat_message.py | 13 ++- ai21/models/logprobs.py | 22 ++++++ .../responses/chat_completion_response.py | 22 ++++++ ai21/models/responses/legacy/__init__.py | 0 .../responses/{ => legacy}/chat_response.py | 0 ai21/models/usage_info.py | 10 +++ examples/studio/chat.py | 8 +- examples/studio/chat_complete.py | 28 +++++++ .../clients/studio/test_chat_completions.py | 38 +++++++++ .../clients/studio/resources/conftest.py | 6 +- 14 files changed, 232 insertions(+), 12 deletions(-) create mode 100644 ai21/clients/studio/resources/studio_chat_completion.py create mode 100644 ai21/models/logprobs.py create mode 100644 ai21/models/responses/chat_completion_response.py create mode 100644 ai21/models/responses/legacy/__init__.py rename ai21/models/responses/{ => legacy}/chat_response.py (100%) create mode 100644 ai21/models/usage_info.py create mode 100644 examples/studio/chat_complete.py create mode 100644 tests/integration_tests/clients/studio/test_chat_completions.py diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index dee9fc4d..50c35022 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Any, Dict, Optional +from ai21.clients.studio.resources.studio_chat_completion import ChatCompletions from ai21.models import Penalty, ChatResponse, ChatMessage @@ -47,6 +48,11 @@ def create( """ pass + @property + @abstractmethod + def completions(self) -> ChatCompletions: + pass + def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: return ChatResponse.from_dict(json) @@ -69,7 +75,7 @@ def _create_body( return { "model": model, "system": system, - "messages": [message.to_dict() for message in messages], + "messages": [{"role": message.role, "text": message.content} for message in messages], "temperature": temperature, "maxTokens": max_tokens, "minTokens": min_tokens, diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index f37d9d74..dbae2ef3 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,6 +1,7 @@ from typing import List, Optional from ai21.clients.common.chat_base import Chat +from ai21.clients.studio.resources.studio_chat_completion import ChatCompletions from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import ChatMessage, Penalty, ChatResponse @@ -42,3 +43,7 @@ def create( url = f"{self._client.get_base_url()}/{model}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) + + @property + def completions(self) -> ChatCompletions: + return ChatCompletions(self._client) diff --git a/ai21/clients/studio/resources/studio_chat_completion.py b/ai21/clients/studio/resources/studio_chat_completion.py new file mode 100644 index 00000000..e722257d --- /dev/null +++ b/ai21/clients/studio/resources/studio_chat_completion.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import List, Optional, Union, Any, Dict + +from ai21.clients.studio.resources.studio_resource import StudioResource +from ai21.models import ChatMessage +from ai21.models.responses.chat_completion_response import ChatCompletionResponse +from ai21.types import NotGiven, NOT_GIVEN +from ai21.utils.typing import remove_not_given + + +class ChatCompletions(StudioResource): + _module_name = "chat/complete" + + def create( + self, + model: str, + messages: List[ChatMessage], + n: Optional[int] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + stop: Optional[Union[str, List[str]]] | NotGiven = NOT_GIVEN, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + **kwargs: Any, + ) -> ChatCompletionResponse: + body = self._create_body( + model=model, + messages=messages, + n=n, + logprobs=logprobs, + top_logprobs=top_logprobs, + stop=stop, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ) + + url = f"{self._client.get_base_url()}/{self._module_name}" + response = self._post(url=url, body=body) + return self._json_to_response(response) + + def _create_body( + self, + model: str, + messages: List[ChatMessage], + n: Optional[int] | NotGiven, + logprobs: Optional[bool] | NotGiven, + top_logprobs: Optional[int] | NotGiven, + max_tokens: Optional[int] | NotGiven, + temperature: Optional[float] | NotGiven, + top_p: Optional[float] | NotGiven, + stop: Optional[Union[str, List[str]]] | NotGiven, + frequency_penalty: Optional[float] | NotGiven, + presence_penalty: Optional[float] | NotGiven, + ) -> Dict[str, Any]: + return remove_not_given( + { + "model": model, + "messages": [{"role": message.role, "content": message.content} for message in messages], + "temperature": temperature, + "maxTokens": max_tokens, + "n": n, + "topP": top_p, + "logprobs": logprobs, + "topLogprobs": top_logprobs, + "stop": stop, + "frequencyPenalty": frequency_penalty, + "presencePenalty": presence_penalty, + } + ) + + def _json_to_response(self, json: Dict[str, Any]) -> ChatCompletionResponse: + return ChatCompletionResponse.from_dict(json) diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index 174215f5..5dc0a782 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -1,11 +1,10 @@ -from ai21.models.chat_message import ChatMessage from ai21.models.document_type import DocumentType from ai21.models.embed_type import EmbedType from ai21.models.improvement_type import ImprovementType +from ai21.models.chat_message import ChatMessage from ai21.models.paraphrase_style_type import ParaphraseStyleType from ai21.models.penalty import Penalty from ai21.models.responses.answer_response import AnswerResponse -from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason from ai21.models.responses.completion_response import ( CompletionsResponse, Completion, @@ -19,6 +18,7 @@ from ai21.models.responses.file_response import FileResponse from ai21.models.responses.gec_response import GECResponse, Correction, CorrectionType from ai21.models.responses.improvement_response import ImprovementsResponse, Improvement +from ai21.models.responses.legacy.chat_response import ChatResponse, ChatOutput, FinishReason from ai21.models.responses.library_answer_response import LibraryAnswerResponse, SourceDocument from ai21.models.responses.library_search_response import LibrarySearchResponse, LibrarySearchResult from ai21.models.responses.paraphrase_response import ParaphraseResponse, Suggestion @@ -28,7 +28,6 @@ from ai21.models.role_type import RoleType from ai21.models.summary_method import SummaryMethod - __all__ = [ "ChatMessage", "RoleType", diff --git a/ai21/models/chat_message.py b/ai21/models/chat_message.py index c7536a77..9c5cb499 100644 --- a/ai21/models/chat_message.py +++ b/ai21/models/chat_message.py @@ -1,4 +1,6 @@ +import warnings from dataclasses import dataclass +from typing import Optional from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin from ai21.models.role_type import RoleType @@ -7,4 +9,13 @@ @dataclass class ChatMessage(AI21BaseModelMixin): role: RoleType - text: str + text: Optional[str] = None + content: Optional[str] = None + + def __post_init__(self): + if self.text is None and self.content is None: + raise ValueError("'content' field or 'text' field must be provided") + + if self.text is not None and self.content is None: + warnings.warn("'text' field is deprecated, please use 'content' field instead", DeprecationWarning) + self.content = self.text diff --git a/ai21/models/logprobs.py b/ai21/models/logprobs.py new file mode 100644 index 00000000..a2bd8d3b --- /dev/null +++ b/ai21/models/logprobs.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import List + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class TopTokenData(AI21BaseModelMixin): + token: str + logprob: float + + +@dataclass +class LogprobsData(AI21BaseModelMixin): + token: str + logprob: float + top_logprobs: List[TopTokenData] + + +@dataclass +class Logprobs(AI21BaseModelMixin): + content: LogprobsData diff --git a/ai21/models/responses/chat_completion_response.py b/ai21/models/responses/chat_completion_response.py new file mode 100644 index 00000000..1050f9eb --- /dev/null +++ b/ai21/models/responses/chat_completion_response.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional, List + +from ai21.models import ChatMessage +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.models.logprobs import Logprobs +from ai21.models.usage_info import UsageInfo + + +@dataclass +class ChatCompletionResponseChoice(AI21BaseModelMixin): + index: int + message: ChatMessage + logprobs: Optional[Logprobs] = None + finish_reason: Optional[str] = None + + +@dataclass +class ChatCompletionResponse(AI21BaseModelMixin): + id: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo diff --git a/ai21/models/responses/legacy/__init__.py b/ai21/models/responses/legacy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/models/responses/chat_response.py b/ai21/models/responses/legacy/chat_response.py similarity index 100% rename from ai21/models/responses/chat_response.py rename to ai21/models/responses/legacy/chat_response.py diff --git a/ai21/models/usage_info.py b/ai21/models/usage_info.py new file mode 100644 index 00000000..a6195cb0 --- /dev/null +++ b/ai21/models/usage_info.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class UsageInfo(AI21BaseModelMixin): + prompt_tokens: int + completion_tokens: int + total_tokens: int diff --git a/examples/studio/chat.py b/examples/studio/chat.py index 516093d9..ada59562 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -3,16 +3,16 @@ system = "You're a support engineer in a SaaS company" messages = [ - ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), - ChatMessage(text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), - ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), + ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), ] client = AI21Client() response = client.chat.create( system=system, messages=messages, - model="j2-ultra", + model="j2-mid", count_penalty=Penalty( scale=0, apply_to_emojis=False, diff --git a/examples/studio/chat_complete.py b/examples/studio/chat_complete.py new file mode 100644 index 00000000..23e88aa7 --- /dev/null +++ b/examples/studio/chat_complete.py @@ -0,0 +1,28 @@ +from ai21 import AI21Client, AI21EnvConfig +from ai21.models import ChatMessage, RoleType + +system = "You're a support engineer in a SaaS company" +messages = [ + ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), + ChatMessage(content="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), +] + +AI21EnvConfig.api_host = "https://api-stage.ai21.com" +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="gaia-small", + n=2, + logprobs=True, + top_logprobs=2, + max_tokens=100, + temperature=0.7, + top_p=1.0, + stop=["\n"], + frequency_penalty=0.1, + presence_penalty=0.1, +) + +print(response) diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py new file mode 100644 index 00000000..3fbc000d --- /dev/null +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -0,0 +1,38 @@ +import pytest + +from ai21 import AI21Client +from ai21.models import ChatMessage, RoleType +from ai21.models.responses.chat_completion_response import ChatCompletionResponse + + +_MODEL = "gaia-small" +_MESSAGES = [ + ChatMessage( + content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", + role=RoleType.USER, + ), +] + + +# TODO: When the api is officially released, update the test to assert the response +@pytest.mark.skip(reason="API is not officially released") +def test_chat_completion(): + num_results = 5 + messages = _MESSAGES + + client = AI21Client() + response = client.chat.completions.create( + model=_MODEL, + messages=messages, + num_results=num_results, + max_tokens=64, + logprobs=True, + top_logprobs=0.6, + temperature=0.7, + stop=["\n"], + top_p=0.3, + frequency_penalty=0.2, + presence_penalty=0.4, + ) + + assert isinstance(response, ChatCompletionResponse) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 5a7f3f2c..141a9285 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -71,9 +71,9 @@ def get_studio_answer(): def get_studio_chat(): _DUMMY_MODEL = "dummy-chat-model" _DUMMY_MESSAGES = [ - ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), ChatMessage( - text="Hi Alice, I can help you with that. What seems to be the problem?", + content="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT, ), ] @@ -86,7 +86,7 @@ def get_studio_chat(): { "model": _DUMMY_MODEL, "system": _DUMMY_SYSTEM, - "messages": [message.to_dict() for message in _DUMMY_MESSAGES], + "messages": [{"text": message.content, "role": message.role} for message in _DUMMY_MESSAGES], "temperature": 0.7, "maxTokens": 300, "minTokens": 0, From ca5499f4addd0497f7380fa36085509c257f0f03 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Mar 2024 17:57:32 +0200 Subject: [PATCH 02/15] refactor: Moved chat message to different packages --- ai21/clients/common/chat_base.py | 2 +- .../studio/resources/studio_chat_completion.py | 4 ++-- ai21/models/chat/__init__.py | 0 ai21/models/chat/chat_message.py | 10 ++++++++++ ai21/models/chat_message.py | 13 +------------ examples/studio/chat.py | 10 ++++++---- .../unittests/clients/studio/resources/conftest.py | 6 +++--- 7 files changed, 23 insertions(+), 22 deletions(-) create mode 100644 ai21/models/chat/__init__.py create mode 100644 ai21/models/chat/chat_message.py diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index 50c35022..e5edb0b9 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -75,7 +75,7 @@ def _create_body( return { "model": model, "system": system, - "messages": [{"role": message.role, "text": message.content} for message in messages], + "messages": [message.to_dict() for message in messages], "temperature": temperature, "maxTokens": max_tokens, "minTokens": min_tokens, diff --git a/ai21/clients/studio/resources/studio_chat_completion.py b/ai21/clients/studio/resources/studio_chat_completion.py index e722257d..a628ffc8 100644 --- a/ai21/clients/studio/resources/studio_chat_completion.py +++ b/ai21/clients/studio/resources/studio_chat_completion.py @@ -3,7 +3,7 @@ from typing import List, Optional, Union, Any, Dict from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import ChatMessage +from ai21.models.chat import ChatMessage from ai21.models.responses.chat_completion_response import ChatCompletionResponse from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given @@ -62,7 +62,7 @@ def _create_body( return remove_not_given( { "model": model, - "messages": [{"role": message.role, "content": message.content} for message in messages], + "messages": [message.to_dict() for message in messages], "temperature": temperature, "maxTokens": max_tokens, "n": n, diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py new file mode 100644 index 00000000..e335b265 --- /dev/null +++ b/ai21/models/chat/chat_message.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + +from ai21.models import RoleType +from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin + + +@dataclass +class ChatMessage(AI21BaseModelMixin): + role: RoleType + content: str diff --git a/ai21/models/chat_message.py b/ai21/models/chat_message.py index 9c5cb499..c7536a77 100644 --- a/ai21/models/chat_message.py +++ b/ai21/models/chat_message.py @@ -1,6 +1,4 @@ -import warnings from dataclasses import dataclass -from typing import Optional from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin from ai21.models.role_type import RoleType @@ -9,13 +7,4 @@ @dataclass class ChatMessage(AI21BaseModelMixin): role: RoleType - text: Optional[str] = None - content: Optional[str] = None - - def __post_init__(self): - if self.text is None and self.content is None: - raise ValueError("'content' field or 'text' field must be provided") - - if self.text is not None and self.content is None: - warnings.warn("'text' field is deprecated, please use 'content' field instead", DeprecationWarning) - self.content = self.text + text: str diff --git a/examples/studio/chat.py b/examples/studio/chat.py index ada59562..c9575e40 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -1,13 +1,15 @@ from ai21 import AI21Client -from ai21.models import ChatMessage, RoleType, Penalty +from ai21.models import RoleType, Penalty +from ai21.models import ChatMessage system = "You're a support engineer in a SaaS company" messages = [ - ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), - ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), - ChatMessage(content="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), + ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT), + ChatMessage(text="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), ] + client = AI21Client() response = client.chat.create( system=system, diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 141a9285..5a7f3f2c 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -71,9 +71,9 @@ def get_studio_answer(): def get_studio_chat(): _DUMMY_MODEL = "dummy-chat-model" _DUMMY_MESSAGES = [ - ChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), + ChatMessage(text="Hello, I need help with a signup process.", role=RoleType.USER), ChatMessage( - content="Hi Alice, I can help you with that. What seems to be the problem?", + text="Hi Alice, I can help you with that. What seems to be the problem?", role=RoleType.ASSISTANT, ), ] @@ -86,7 +86,7 @@ def get_studio_chat(): { "model": _DUMMY_MODEL, "system": _DUMMY_SYSTEM, - "messages": [{"text": message.content, "role": message.role} for message in _DUMMY_MESSAGES], + "messages": [message.to_dict() for message in _DUMMY_MESSAGES], "temperature": 0.7, "maxTokens": 300, "minTokens": 0, From 76d2fadf1983766f22505eec514f4fc34fab26b6 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Sun, 24 Mar 2024 18:05:34 +0200 Subject: [PATCH 03/15] fix: Added __all__ --- ai21/models/chat/__init__.py | 3 +++ ai21/models/chat/chat_message.py | 2 ++ .../integration_tests/clients/studio/test_chat_completions.py | 3 ++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py index e69de29b..77182f32 100644 --- a/ai21/models/chat/__init__.py +++ b/ai21/models/chat/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .chat_message import ChatMessage as ChatMessage diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py index e335b265..36f0158e 100644 --- a/ai21/models/chat/chat_message.py +++ b/ai21/models/chat/chat_message.py @@ -3,6 +3,8 @@ from ai21.models import RoleType from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +__all__ = ["ChatMessage"] + @dataclass class ChatMessage(AI21BaseModelMixin): diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 3fbc000d..5241a977 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -1,7 +1,8 @@ import pytest from ai21 import AI21Client -from ai21.models import ChatMessage, RoleType +from ai21.models.chat import ChatMessage +from ai21.models import RoleType from ai21.models.responses.chat_completion_response import ChatCompletionResponse From a59e503e551a3a816db45935f6206ee7aa95946b Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 10:36:29 +0200 Subject: [PATCH 04/15] refactor: imports and file structure --- ai21/clients/studio/resources/studio_chat_completion.py | 3 +-- ai21/models/__init__.py | 2 +- ai21/models/chat/__init__.py | 2 ++ ai21/models/{responses => chat}/chat_completion_response.py | 4 +++- ai21/models/responses/{legacy => }/chat_response.py | 0 ai21/models/responses/legacy/__init__.py | 0 .../integration_tests/clients/studio/test_chat_completions.py | 2 +- 7 files changed, 8 insertions(+), 5 deletions(-) rename ai21/models/{responses => chat}/chat_completion_response.py (83%) rename ai21/models/responses/{legacy => }/chat_response.py (100%) delete mode 100644 ai21/models/responses/legacy/__init__.py diff --git a/ai21/clients/studio/resources/studio_chat_completion.py b/ai21/clients/studio/resources/studio_chat_completion.py index a628ffc8..c2837f42 100644 --- a/ai21/clients/studio/resources/studio_chat_completion.py +++ b/ai21/clients/studio/resources/studio_chat_completion.py @@ -3,8 +3,7 @@ from typing import List, Optional, Union, Any, Dict from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models.chat import ChatMessage -from ai21.models.responses.chat_completion_response import ChatCompletionResponse +from ai21.models.chat import ChatMessage, ChatCompletionResponse from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index 5dc0a782..8f68701e 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -18,7 +18,7 @@ from ai21.models.responses.file_response import FileResponse from ai21.models.responses.gec_response import GECResponse, Correction, CorrectionType from ai21.models.responses.improvement_response import ImprovementsResponse, Improvement -from ai21.models.responses.legacy.chat_response import ChatResponse, ChatOutput, FinishReason +from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason from ai21.models.responses.library_answer_response import LibraryAnswerResponse, SourceDocument from ai21.models.responses.library_search_response import LibrarySearchResponse, LibrarySearchResult from ai21.models.responses.paraphrase_response import ParaphraseResponse, Suggestion diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py index 77182f32..34d1a10a 100644 --- a/ai21/models/chat/__init__.py +++ b/ai21/models/chat/__init__.py @@ -1,3 +1,5 @@ from __future__ import annotations from .chat_message import ChatMessage as ChatMessage +from .chat_completion_response import ChatCompletionResponse as ChatCompletionResponse +from .chat_completion_response import ChatCompletionResponseChoice as ChatCompletionResponseChoice diff --git a/ai21/models/responses/chat_completion_response.py b/ai21/models/chat/chat_completion_response.py similarity index 83% rename from ai21/models/responses/chat_completion_response.py rename to ai21/models/chat/chat_completion_response.py index 1050f9eb..faf889a9 100644 --- a/ai21/models/responses/chat_completion_response.py +++ b/ai21/models/chat/chat_completion_response.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import Optional, List -from ai21.models import ChatMessage from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from ai21.models.chat import ChatMessage from ai21.models.logprobs import Logprobs from ai21.models.usage_info import UsageInfo +__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice"] + @dataclass class ChatCompletionResponseChoice(AI21BaseModelMixin): diff --git a/ai21/models/responses/legacy/chat_response.py b/ai21/models/responses/chat_response.py similarity index 100% rename from ai21/models/responses/legacy/chat_response.py rename to ai21/models/responses/chat_response.py diff --git a/ai21/models/responses/legacy/__init__.py b/ai21/models/responses/legacy/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 5241a977..166c6fda 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -3,7 +3,7 @@ from ai21 import AI21Client from ai21.models.chat import ChatMessage from ai21.models import RoleType -from ai21.models.responses.chat_completion_response import ChatCompletionResponse +from ai21.models.chat.chat_completion_response import ChatCompletionResponse _MODEL = "gaia-small" From 05e4f67be3fad8ba27adcb7b2e3ec62006d402c5 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 10:37:40 +0200 Subject: [PATCH 05/15] docs: Updated todo --- tests/integration_tests/clients/studio/test_chat_completions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 166c6fda..63d852e5 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -15,7 +15,7 @@ ] -# TODO: When the api is officially released, update the test to assert the response +# TODO: When the api is officially released, update the test to assert the actual response @pytest.mark.skip(reason="API is not officially released") def test_chat_completion(): num_results = 5 From 4f4a291ea6a87781f7b5ace51f6f53968026097a Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 10:48:25 +0200 Subject: [PATCH 06/15] fix: imports --- ai21/clients/common/chat_base.py | 2 +- ai21/clients/studio/resources/chat/__init__.py | 3 +++ .../{studio_chat_completion.py => chat/chat_completion.py} | 2 ++ ai21/clients/studio/resources/studio_chat.py | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 ai21/clients/studio/resources/chat/__init__.py rename ai21/clients/studio/resources/{studio_chat_completion.py => chat/chat_completion.py} (98%) diff --git a/ai21/clients/common/chat_base.py b/ai21/clients/common/chat_base.py index e5edb0b9..809fcc3c 100644 --- a/ai21/clients/common/chat_base.py +++ b/ai21/clients/common/chat_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Any, Dict, Optional -from ai21.clients.studio.resources.studio_chat_completion import ChatCompletions +from ai21.clients.studio.resources.chat import ChatCompletions from ai21.models import Penalty, ChatResponse, ChatMessage diff --git a/ai21/clients/studio/resources/chat/__init__.py b/ai21/clients/studio/resources/chat/__init__.py new file mode 100644 index 00000000..1a12d6c5 --- /dev/null +++ b/ai21/clients/studio/resources/chat/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +from .chat_completion import ChatCompletions as ChatCompletions diff --git a/ai21/clients/studio/resources/studio_chat_completion.py b/ai21/clients/studio/resources/chat/chat_completion.py similarity index 98% rename from ai21/clients/studio/resources/studio_chat_completion.py rename to ai21/clients/studio/resources/chat/chat_completion.py index c2837f42..47e82228 100644 --- a/ai21/clients/studio/resources/studio_chat_completion.py +++ b/ai21/clients/studio/resources/chat/chat_completion.py @@ -7,6 +7,8 @@ from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given +__all__ = ["ChatCompletions"] + class ChatCompletions(StudioResource): _module_name = "chat/complete" diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index dbae2ef3..70f8b01f 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,7 +1,7 @@ from typing import List, Optional from ai21.clients.common.chat_base import Chat -from ai21.clients.studio.resources.studio_chat_completion import ChatCompletions +from ai21.clients.studio.resources.chat import ChatCompletions from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.models import ChatMessage, Penalty, ChatResponse From 21d3da8a191916651b9789929b5ab9571b07bde0 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 11:22:14 +0200 Subject: [PATCH 07/15] fix: Added deprecation warning --- ai21/clients/studio/ai21_client.py | 2 +- ai21/clients/studio/resources/studio_chat.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index dfb781b5..8e5979d0 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -50,7 +50,7 @@ def __init__( http_client=http_client, ) self.completion = StudioCompletion(self._http_client) - self.chat = StudioChat(self._http_client) + self.chat: StudioChat = StudioChat(self._http_client) self.summarize = StudioSummarize(self._http_client) self.embed = StudioEmbed(self._http_client) self.gec = StudioGEC(self._http_client) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 70f8b01f..e163f441 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,3 +1,4 @@ +import warnings from typing import List, Optional from ai21.clients.common.chat_base import Chat @@ -25,6 +26,11 @@ def create( count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: + warnings.warn( + "This method is deprecated. Please use the `chat.completions.create` method in the client instead.", + DeprecationWarning, + stacklevel=2, + ) body = self._create_body( model=model, messages=messages, From 4bf975e869d019e55ef573aefd74fbbfdad3aa76 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 11:39:46 +0200 Subject: [PATCH 08/15] test: Added a unittest --- .../clients/studio/resources/conftest.py | 46 +++++++++++++++++++ .../studio/resources/test_studio_resources.py | 3 ++ 2 files changed, 49 insertions(+) diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 5a7f3f2c..88cf5ce1 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -2,6 +2,7 @@ from pytest_mock import MockerFixture from ai21.ai21_http_client import AI21HTTPClient +from ai21.clients.studio.resources.chat import ChatCompletions from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion @@ -43,7 +44,13 @@ SummarizeBySegmentResponse, SegmentSummary, ) +from ai21.models.chat import ( + ChatMessage as ChatCompletionChatMessage, + ChatCompletionResponse, + ChatCompletionResponseChoice, +) from ai21.models.responses.segmentation_response import Segment +from ai21.models.usage_info import UsageInfo from ai21.utils.typing import to_lower_camel_case @@ -110,6 +117,45 @@ def get_studio_chat(): ) +def get_chat_completions(): + _DUMMY_MODEL = "dummy-chat-model" + _DUMMY_MESSAGES = [ + ChatCompletionChatMessage(content="Hello, I need help with a signup process.", role=RoleType.USER), + ChatCompletionChatMessage( + content="Hi Alice, I can help you with that. What seems to be the problem?", + role=RoleType.ASSISTANT, + ), + ] + + return ( + ChatCompletions, + {"model": _DUMMY_MODEL, "messages": _DUMMY_MESSAGES}, + "chat/complete", + { + "model": _DUMMY_MODEL, + "messages": [message.to_dict() for message in _DUMMY_MESSAGES], + }, + ChatCompletionResponse( + id="some-id", + choices=[ + ChatCompletionResponseChoice( + index=0, + message=ChatCompletionChatMessage( + content="Hello, I need help with a signup process.", role=RoleType.USER + ), + finish_reason="dummy_reason", + logprobs=None, + ) + ], + usage=UsageInfo( + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + ), + ), + ) + + def get_studio_completion(**kwargs): _DUMMY_MODEL = "dummy-completion-model" _DUMMY_PROMPT = "dummy-prompt" diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 96fdc154..0c21d88c 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -17,6 +17,7 @@ get_studio_segmentation, get_studio_summarization, get_studio_summarize_by_segment, + get_chat_completions, ) _BASE_URL = "https://test.api.ai21.com/studio/v1" @@ -31,6 +32,7 @@ class TestStudioResources: ids=[ "studio_answer", "studio_chat", + "chat_completions", "studio_completion", "studio_completion_with_extra_args", "studio_embed", @@ -45,6 +47,7 @@ class TestStudioResources: argvalues=[ (get_studio_answer()), (get_studio_chat()), + (get_chat_completions()), (get_studio_completion()), (get_studio_completion(temperature=0.5, max_tokens=50)), (get_studio_embed()), From 2e7b83633872e396ed53de78b70f2cf8b1cc2e04 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 13:58:38 +0200 Subject: [PATCH 09/15] fix: CR --- examples/studio/chat.py | 4 ++++ examples/studio/chat_complete.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/studio/chat.py b/examples/studio/chat.py index c9575e40..0b044cfa 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -1,3 +1,7 @@ +""" +This examples uses a deprecated method client.chat.create and instead +should be replaced with the `client.chat.completions.create` +""" from ai21 import AI21Client from ai21.models import RoleType, Penalty from ai21.models import ChatMessage diff --git a/examples/studio/chat_complete.py b/examples/studio/chat_complete.py index 23e88aa7..06a5ec20 100644 --- a/examples/studio/chat_complete.py +++ b/examples/studio/chat_complete.py @@ -1,5 +1,6 @@ -from ai21 import AI21Client, AI21EnvConfig -from ai21.models import ChatMessage, RoleType +from ai21 import AI21Client +from ai21.models import RoleType +from ai21.models.chat import ChatMessage system = "You're a support engineer in a SaaS company" messages = [ @@ -8,7 +9,6 @@ ChatMessage(content="I am having trouble signing up for your product with my Google account.", role=RoleType.USER), ] -AI21EnvConfig.api_host = "https://api-stage.ai21.com" client = AI21Client() response = client.chat.completions.create( From 2a2e1d31f1d04e5e760284a2a7e321bab8413f75 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 15:56:31 +0200 Subject: [PATCH 10/15] fix: CR --- .../studio/resources/chat/chat_completion.py | 21 +++++++++++-------- examples/studio/chat.py | 1 + .../integration_tests/clients/test_bedrock.py | 1 + 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/ai21/clients/studio/resources/chat/chat_completion.py b/ai21/clients/studio/resources/chat/chat_completion.py index 47e82228..791232e0 100644 --- a/ai21/clients/studio/resources/chat/chat_completion.py +++ b/ai21/clients/studio/resources/chat/chat_completion.py @@ -17,15 +17,15 @@ def create( self, model: str, messages: List[ChatMessage], - n: Optional[int] | NotGiven = NOT_GIVEN, - logprobs: Optional[bool] | NotGiven = NOT_GIVEN, - top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, - max_tokens: Optional[int] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - stop: Optional[Union[str, List[str]]] | NotGiven = NOT_GIVEN, - frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, - presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + stop: str | List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: float | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> ChatCompletionResponse: body = self._create_body( @@ -40,6 +40,7 @@ def create( top_p=top_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + **kwargs, ) url = f"{self._client.get_base_url()}/{self._module_name}" @@ -59,6 +60,7 @@ def _create_body( stop: Optional[Union[str, List[str]]] | NotGiven, frequency_penalty: Optional[float] | NotGiven, presence_penalty: Optional[float] | NotGiven, + **kwargs: Any, ) -> Dict[str, Any]: return remove_not_given( { @@ -73,6 +75,7 @@ def _create_body( "stop": stop, "frequencyPenalty": frequency_penalty, "presencePenalty": presence_penalty, + **kwargs, } ) diff --git a/examples/studio/chat.py b/examples/studio/chat.py index 0b044cfa..ff457b24 100644 --- a/examples/studio/chat.py +++ b/examples/studio/chat.py @@ -2,6 +2,7 @@ This examples uses a deprecated method client.chat.create and instead should be replaced with the `client.chat.completions.create` """ + from ai21 import AI21Client from ai21.models import RoleType, Penalty from ai21.models import ChatMessage diff --git a/tests/integration_tests/clients/test_bedrock.py b/tests/integration_tests/clients/test_bedrock.py index 2e64efb0..02d0ee3e 100644 --- a/tests/integration_tests/clients/test_bedrock.py +++ b/tests/integration_tests/clients/test_bedrock.py @@ -1,6 +1,7 @@ """ Run this script after setting the environment variable called AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY """ + import subprocess from pathlib import Path From ad2e20509c7d8e1b226d7855fdf43f95c8c7e744 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 15:57:18 +0200 Subject: [PATCH 11/15] fix: model name --- examples/studio/chat_complete.py | 2 +- tests/integration_tests/clients/studio/test_chat_completions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/studio/chat_complete.py b/examples/studio/chat_complete.py index 06a5ec20..9248bdf6 100644 --- a/examples/studio/chat_complete.py +++ b/examples/studio/chat_complete.py @@ -13,7 +13,7 @@ response = client.chat.completions.create( messages=messages, - model="gaia-small", + model="new-model-name", n=2, logprobs=True, top_logprobs=2, diff --git a/tests/integration_tests/clients/studio/test_chat_completions.py b/tests/integration_tests/clients/studio/test_chat_completions.py index 63d852e5..a98f79eb 100644 --- a/tests/integration_tests/clients/studio/test_chat_completions.py +++ b/tests/integration_tests/clients/studio/test_chat_completions.py @@ -6,7 +6,7 @@ from ai21.models.chat.chat_completion_response import ChatCompletionResponse -_MODEL = "gaia-small" +_MODEL = "new-model-name" _MESSAGES = [ ChatMessage( content="Hello, I need help studying for the coming test, can you teach me about the US constitution? ", From e7aac373422ebe79608d7f08f1f31372747ef4f8 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 22:38:01 +0200 Subject: [PATCH 12/15] fix: alias --- .pre-commit-config.yaml | 2 -- ai21/clients/studio/resources/chat/__init__.py | 2 +- .../chat/{chat_completion.py => chat_completions.py} | 0 ai21/clients/studio/resources/studio_chat.py | 6 ------ ai21/models/__init__.py | 6 +++--- ai21/models/chat/__init__.py | 7 ++++--- ai21/models/{ => chat}/role_type.py | 2 ++ ai21/models/chat_message.py | 2 +- ai21/models/responses/chat_response.py | 2 +- examples/__init__.py | 0 examples/studio/__init__.py | 0 examples/studio/chat/__init__.py | 0 .../studio/{chat_complete.py => chat/chat_completions.py} | 0 pyproject.toml | 2 ++ 14 files changed, 14 insertions(+), 17 deletions(-) rename ai21/clients/studio/resources/chat/{chat_completion.py => chat_completions.py} (100%) rename ai21/models/{ => chat}/role_type.py (80%) create mode 100644 examples/__init__.py create mode 100644 examples/studio/__init__.py create mode 100644 examples/studio/chat/__init__.py rename examples/studio/{chat_complete.py => chat/chat_completions.py} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c210e34..41261070 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -108,8 +108,6 @@ repos: rev: v0.0.280 hooks: - id: ruff - args: - - --fix - repo: local hooks: - id: hadolint diff --git a/ai21/clients/studio/resources/chat/__init__.py b/ai21/clients/studio/resources/chat/__init__.py index 1a12d6c5..5c9e2645 100644 --- a/ai21/clients/studio/resources/chat/__init__.py +++ b/ai21/clients/studio/resources/chat/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -from .chat_completion import ChatCompletions as ChatCompletions +from .chat_completions import ChatCompletions as ChatCompletions diff --git a/ai21/clients/studio/resources/chat/chat_completion.py b/ai21/clients/studio/resources/chat/chat_completions.py similarity index 100% rename from ai21/clients/studio/resources/chat/chat_completion.py rename to ai21/clients/studio/resources/chat/chat_completions.py diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index e163f441..70f8b01f 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -1,4 +1,3 @@ -import warnings from typing import List, Optional from ai21.clients.common.chat_base import Chat @@ -26,11 +25,6 @@ def create( count_penalty: Optional[Penalty] = None, **kwargs, ) -> ChatResponse: - warnings.warn( - "This method is deprecated. Please use the `chat.completions.create` method in the client instead.", - DeprecationWarning, - stacklevel=2, - ) body = self._create_body( model=model, messages=messages, diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index 8f68701e..cc851fc9 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -1,10 +1,12 @@ +from ai21.models.chat.role_type import RoleType +from ai21.models.chat_message import ChatMessage from ai21.models.document_type import DocumentType from ai21.models.embed_type import EmbedType from ai21.models.improvement_type import ImprovementType -from ai21.models.chat_message import ChatMessage from ai21.models.paraphrase_style_type import ParaphraseStyleType from ai21.models.penalty import Penalty from ai21.models.responses.answer_response import AnswerResponse +from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason from ai21.models.responses.completion_response import ( CompletionsResponse, Completion, @@ -18,14 +20,12 @@ from ai21.models.responses.file_response import FileResponse from ai21.models.responses.gec_response import GECResponse, Correction, CorrectionType from ai21.models.responses.improvement_response import ImprovementsResponse, Improvement -from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason from ai21.models.responses.library_answer_response import LibraryAnswerResponse, SourceDocument from ai21.models.responses.library_search_response import LibrarySearchResponse, LibrarySearchResult from ai21.models.responses.paraphrase_response import ParaphraseResponse, Suggestion from ai21.models.responses.segmentation_response import SegmentationResponse from ai21.models.responses.summarize_by_segment_response import SummarizeBySegmentResponse, SegmentSummary, Highlight from ai21.models.responses.summarize_response import SummarizeResponse -from ai21.models.role_type import RoleType from ai21.models.summary_method import SummaryMethod __all__ = [ diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py index 34d1a10a..e1194eab 100644 --- a/ai21/models/chat/__init__.py +++ b/ai21/models/chat/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations -from .chat_message import ChatMessage as ChatMessage -from .chat_completion_response import ChatCompletionResponse as ChatCompletionResponse -from .chat_completion_response import ChatCompletionResponseChoice as ChatCompletionResponseChoice +from .chat_completion_response import ChatCompletionResponse +from .chat_completion_response import ChatCompletionResponseChoice +from .chat_message import ChatMessage +from .role_type import RoleType as RoleType diff --git a/ai21/models/role_type.py b/ai21/models/chat/role_type.py similarity index 80% rename from ai21/models/role_type.py rename to ai21/models/chat/role_type.py index a1630a23..39336e6a 100644 --- a/ai21/models/role_type.py +++ b/ai21/models/chat/role_type.py @@ -1,5 +1,7 @@ from enum import Enum +__all__ = ["RoleType"] + class RoleType(str, Enum): USER = "user" diff --git a/ai21/models/chat_message.py b/ai21/models/chat_message.py index c7536a77..788affee 100644 --- a/ai21/models/chat_message.py +++ b/ai21/models/chat_message.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin -from ai21.models.role_type import RoleType +from ai21.models.chat.role_type import RoleType @dataclass diff --git a/ai21/models/responses/chat_response.py b/ai21/models/responses/chat_response.py index e1a15a9f..4ca4ad63 100644 --- a/ai21/models/responses/chat_response.py +++ b/ai21/models/responses/chat_response.py @@ -2,7 +2,7 @@ from typing import Optional, List from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin -from ai21.models.role_type import RoleType +from ai21.models.chat.role_type import RoleType @dataclass diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/__init__.py b/examples/studio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/chat/__init__.py b/examples/studio/chat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/studio/chat_complete.py b/examples/studio/chat/chat_completions.py similarity index 100% rename from examples/studio/chat_complete.py rename to examples/studio/chat/chat_completions.py diff --git a/pyproject.toml b/pyproject.toml index 0c7def29..f864f1ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,3 +116,5 @@ prerelease = true [tool.ruff] line-length = 120 +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] From 27911dd505f0f60a4f860e533779de96ebfb7852 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 22:38:57 +0200 Subject: [PATCH 13/15] revert: ruff --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41261070..9c210e34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -108,6 +108,8 @@ repos: rev: v0.0.280 hooks: - id: ruff + args: + - --fix - repo: local hooks: - id: hadolint From debe0de4a088cbb9d8f14c677cad5088440486c4 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Mon, 25 Mar 2024 22:43:07 +0200 Subject: [PATCH 14/15] fix: circualr imports --- ai21/models/chat/chat_completion_response.py | 2 +- ai21/models/chat/chat_message.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ai21/models/chat/chat_completion_response.py b/ai21/models/chat/chat_completion_response.py index faf889a9..c1ad78f7 100644 --- a/ai21/models/chat/chat_completion_response.py +++ b/ai21/models/chat/chat_completion_response.py @@ -2,9 +2,9 @@ from typing import Optional, List from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin -from ai21.models.chat import ChatMessage from ai21.models.logprobs import Logprobs from ai21.models.usage_info import UsageInfo +from .chat_message import ChatMessage __all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice"] diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py index 36f0158e..933f93f8 100644 --- a/ai21/models/chat/chat_message.py +++ b/ai21/models/chat/chat_message.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -from ai21.models import RoleType from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin +from .role_type import RoleType __all__ = ["ChatMessage"] From d200d0edff8d3d441829c45f534b25927dbd7a4f Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Tue, 26 Mar 2024 15:32:51 +0200 Subject: [PATCH 15/15] fix: all import --- ai21/models/chat/__init__.py | 2 ++ ai21/models/chat/chat_completion_response.py | 2 -- ai21/models/chat/chat_message.py | 2 -- ai21/models/chat/role_type.py | 2 -- pyproject.toml | 2 -- 5 files changed, 2 insertions(+), 8 deletions(-) diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py index e1194eab..0fb4df66 100644 --- a/ai21/models/chat/__init__.py +++ b/ai21/models/chat/__init__.py @@ -4,3 +4,5 @@ from .chat_completion_response import ChatCompletionResponseChoice from .chat_message import ChatMessage from .role_type import RoleType as RoleType + +__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice", "ChatMessage", "RoleType"] diff --git a/ai21/models/chat/chat_completion_response.py b/ai21/models/chat/chat_completion_response.py index c1ad78f7..b4551e6b 100644 --- a/ai21/models/chat/chat_completion_response.py +++ b/ai21/models/chat/chat_completion_response.py @@ -6,8 +6,6 @@ from ai21.models.usage_info import UsageInfo from .chat_message import ChatMessage -__all__ = ["ChatCompletionResponse", "ChatCompletionResponseChoice"] - @dataclass class ChatCompletionResponseChoice(AI21BaseModelMixin): diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py index 933f93f8..f3f4a437 100644 --- a/ai21/models/chat/chat_message.py +++ b/ai21/models/chat/chat_message.py @@ -3,8 +3,6 @@ from ai21.models.ai21_base_model_mixin import AI21BaseModelMixin from .role_type import RoleType -__all__ = ["ChatMessage"] - @dataclass class ChatMessage(AI21BaseModelMixin): diff --git a/ai21/models/chat/role_type.py b/ai21/models/chat/role_type.py index 39336e6a..a1630a23 100644 --- a/ai21/models/chat/role_type.py +++ b/ai21/models/chat/role_type.py @@ -1,7 +1,5 @@ from enum import Enum -__all__ = ["RoleType"] - class RoleType(str, Enum): USER = "user" diff --git a/pyproject.toml b/pyproject.toml index f864f1ad..0c7def29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,5 +116,3 @@ prerelease = true [tool.ruff] line-length = 120 -[tool.ruff.per-file-ignores] -"__init__.py" = ["F401"]