From d2747a56169c8c68cd9701291fab72509835f830 Mon Sep 17 00:00:00 2001 From: benshuk Date: Mon, 2 Dec 2024 11:49:36 +0200 Subject: [PATCH 1/4] fix: :truck: rename `get` method to `retrieve` --- ai21/clients/common/assistant/assistants.py | 2 +- ai21/clients/common/assistant/threads.py | 2 +- ai21/clients/studio/resources/assistant/studio_assistant.py | 4 ++-- ai21/clients/studio/resources/assistant/studio_thread.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ai21/clients/common/assistant/assistants.py b/ai21/clients/common/assistant/assistants.py index e773f89a..1beb2b7f 100644 --- a/ai21/clients/common/assistant/assistants.py +++ b/ai21/clients/common/assistant/assistants.py @@ -57,7 +57,7 @@ def list(self) -> ListAssistant: pass @abstractmethod - def get(self, assistant_id: str) -> Assistant: + def retrieve(self, assistant_id: str) -> Assistant: pass @abstractmethod diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py index 5025ccc2..d24ead91 100644 --- a/ai21/clients/common/assistant/threads.py +++ b/ai21/clients/common/assistant/threads.py @@ -15,5 +15,5 @@ def create(self, messages: List[Message], **kwargs) -> Thread: pass @abstractmethod - def get(self, thread_id: str) -> Thread: + def retrieve(self, thread_id: str) -> Thread: pass diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/assistant/studio_assistant.py index fa008695..f2a3c306 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/assistant/studio_assistant.py @@ -38,7 +38,7 @@ def create( return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - def get(self, assistant_id: str) -> Assistant: + def retrieve(self, assistant_id: str) -> Assistant: return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) def list(self) -> ListAssistant: @@ -97,7 +97,7 @@ async def create( return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - async def get(self, assistant_id: str) -> Assistant: + async def retrieve(self, assistant_id: str) -> Assistant: return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) async def list(self) -> ListAssistant: diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py index fd48563e..2ecd7199 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -14,7 +14,7 @@ def create(self, messages: List[Message], **kwargs) -> Thread: return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) - def get(self, thread_id: str) -> Thread: + def retrieve(self, thread_id: str) -> Thread: return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) @@ -24,5 +24,5 @@ async def create(self, messages: List[Message], **kwargs) -> Thread: return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) - async def get(self, thread_id: str) -> Thread: + async def retrieve(self, thread_id: str) -> Thread: return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) From 151f5acfa2d83fccb8b9bc459fce761324542ce5 Mon Sep 17 00:00:00 2001 From: benshuk Date: Mon, 2 Dec 2024 11:50:53 +0200 Subject: [PATCH 2/4] refactor: :truck: move `MessageResponse` type to a separate file --- ai21/models/assistant/message.py | 16 +--------------- ai21/models/responses/message_response.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 15 deletions(-) create mode 100644 ai21/models/responses/message_response.py diff --git a/ai21/models/assistant/message.py b/ai21/models/assistant/message.py index f3ec7d8d..22ea77a2 100644 --- a/ai21/models/assistant/message.py +++ b/ai21/models/assistant/message.py @@ -1,10 +1,7 @@ -from datetime import datetime -from typing import Literal, Optional +from typing import Literal from typing_extensions import TypedDict -from ai21.models.ai21_base_model import AI21BaseModel - ThreadMessageRole = Literal["assistant", "user"] @@ -16,14 +13,3 @@ class MessageContentText(TypedDict): class Message(TypedDict): role: ThreadMessageRole content: MessageContentText - - -class MessageResponse(AI21BaseModel): - id: str - created_at: datetime - updated_at: datetime - object: Literal["message"] = "message" - role: ThreadMessageRole - content: MessageContentText - run_id: Optional[str] = None - assistant_id: Optional[str] = None diff --git a/ai21/models/responses/message_response.py b/ai21/models/responses/message_response.py new file mode 100644 index 00000000..e19fa665 --- /dev/null +++ b/ai21/models/responses/message_response.py @@ -0,0 +1,20 @@ +from datetime import datetime +from typing import Literal, Optional, List + +from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.assistant.message import ThreadMessageRole, MessageContentText + + +class MessageResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["message"] = "message" + role: ThreadMessageRole + content: MessageContentText + run_id: Optional[str] = None + assistant_id: Optional[str] = None + + +class ListMessageResponse(AI21BaseModel): + results: List[MessageResponse] From b483e704743a2d9e8fbd5da99f97bf348263a4d5 Mon Sep 17 00:00:00 2001 From: benshuk Date: Mon, 2 Dec 2024 13:45:14 +0200 Subject: [PATCH 3/4] feat: :sparkles: add support for Message resource --- ai21/clients/common/assistant/messages.py | 20 ++++++++++ ai21/clients/common/assistant/threads.py | 2 + .../resources/assistant/studio_thread.py | 13 +++++++ .../assistant/studio_thread_message.py | 38 +++++++++++++++++++ 4 files changed, 73 insertions(+) create mode 100644 ai21/clients/common/assistant/messages.py create mode 100644 ai21/clients/studio/resources/assistant/studio_thread_message.py diff --git a/ai21/clients/common/assistant/messages.py b/ai21/clients/common/assistant/messages.py new file mode 100644 index 00000000..6200ca76 --- /dev/null +++ b/ai21/clients/common/assistant/messages.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from ai21.models.assistant.message import ThreadMessageRole, MessageContentText +from ai21.models.responses.message_response import MessageResponse, ListMessageResponse + + +class Messages(ABC): + _module_name = "messages" + + @abstractmethod + def create( + self, thread_id: str, *, role: ThreadMessageRole, content: MessageContentText, **kwargs + ) -> MessageResponse: + pass + + @abstractmethod + def list(self, thread_id: str) -> ListMessageResponse: + pass diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py index d24ead91..89b44965 100644 --- a/ai21/clients/common/assistant/threads.py +++ b/ai21/clients/common/assistant/threads.py @@ -3,12 +3,14 @@ from abc import ABC, abstractmethod from typing import List +from ai21.clients.common.assistant.messages import Messages from ai21.models.assistant.message import Message from ai21.models.responses.thread_response import Thread class Threads(ABC): _module_name = "threads" + messages: Messages @abstractmethod def create(self, messages: List[Message], **kwargs) -> Thread: diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py index 2ecd7199..4c41249c 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -3,12 +3,20 @@ from typing import List from ai21.clients.common.assistant.threads import Threads +from ai21.clients.studio.resources.assistant.studio_thread_message import StudioThreadMessage, AsyncStudioThreadMessage from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.http_client.async_http_client import AsyncAI21HTTPClient +from ai21.http_client.http_client import AI21HTTPClient from ai21.models.assistant.message import Message from ai21.models.responses.thread_response import Thread class StudioThread(StudioResource, Threads): + def __init__(self, client: AI21HTTPClient): + super().__init__(client) + + self.messages = StudioThreadMessage(client) + def create(self, messages: List[Message], **kwargs) -> Thread: body = dict(messages=messages) @@ -19,6 +27,11 @@ def retrieve(self, thread_id: str) -> Thread: class AsyncStudioThread(AsyncStudioResource, Threads): + def __init__(self, client: AsyncAI21HTTPClient): + super().__init__(client) + + self.messages = AsyncStudioThreadMessage(client) + async def create(self, messages: List[Message], **kwargs) -> Thread: body = dict(messages=messages) diff --git a/ai21/clients/studio/resources/assistant/studio_thread_message.py b/ai21/clients/studio/resources/assistant/studio_thread_message.py new file mode 100644 index 00000000..ce94c966 --- /dev/null +++ b/ai21/clients/studio/resources/assistant/studio_thread_message.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from ai21.clients.common.assistant.messages import Messages +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.models.assistant.message import ThreadMessageRole, MessageContentText +from ai21.models.responses.message_response import MessageResponse, ListMessageResponse + + +class StudioThreadMessage(StudioResource, Messages): + def create( + self, thread_id: str, *, role: ThreadMessageRole, content: MessageContentText, **kwargs + ) -> MessageResponse: + body = dict( + role=role, + content=content, + ) + + return self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse) + + def list(self, thread_id: str) -> ListMessageResponse: + return self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse) + + +class AsyncStudioThreadMessage(AsyncStudioResource, Messages): + async def create( + self, thread_id: str, *, role: ThreadMessageRole, content: MessageContentText, **kwargs + ) -> MessageResponse: + body = dict( + role=role, + content=content, + ) + + return await self._post( + path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse + ) + + async def list(self, thread_id: str) -> ListMessageResponse: + return await self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse) From a8c345e81f2b29dd14e5ec118d5f4068792198df Mon Sep 17 00:00:00 2001 From: benshuk Date: Mon, 2 Dec 2024 15:00:10 +0200 Subject: [PATCH 4/4] refactor: :art: reformat functions with 2+ arguments --- ai21/clients/common/assistant/messages.py | 7 ++++++- ai21/clients/common/assistant/threads.py | 6 +++++- .../studio/resources/assistant/studio_thread.py | 12 ++++++++++-- .../resources/assistant/studio_thread_message.py | 14 ++++++++++++-- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/ai21/clients/common/assistant/messages.py b/ai21/clients/common/assistant/messages.py index 6200ca76..0ccbde79 100644 --- a/ai21/clients/common/assistant/messages.py +++ b/ai21/clients/common/assistant/messages.py @@ -11,7 +11,12 @@ class Messages(ABC): @abstractmethod def create( - self, thread_id: str, *, role: ThreadMessageRole, content: MessageContentText, **kwargs + self, + thread_id: str, + *, + role: ThreadMessageRole, + content: MessageContentText, + **kwargs, ) -> MessageResponse: pass diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py index 89b44965..9b64580d 100644 --- a/ai21/clients/common/assistant/threads.py +++ b/ai21/clients/common/assistant/threads.py @@ -13,7 +13,11 @@ class Threads(ABC): messages: Messages @abstractmethod - def create(self, messages: List[Message], **kwargs) -> Thread: + def create( + self, + messages: List[Message], + **kwargs, + ) -> Thread: pass @abstractmethod diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py index 4c41249c..60c88e22 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -17,7 +17,11 @@ def __init__(self, client: AI21HTTPClient): self.messages = StudioThreadMessage(client) - def create(self, messages: List[Message], **kwargs) -> Thread: + def create( + self, + messages: List[Message], + **kwargs, + ) -> Thread: body = dict(messages=messages) return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) @@ -32,7 +36,11 @@ def __init__(self, client: AsyncAI21HTTPClient): self.messages = AsyncStudioThreadMessage(client) - async def create(self, messages: List[Message], **kwargs) -> Thread: + async def create( + self, + messages: List[Message], + **kwargs, + ) -> Thread: body = dict(messages=messages) return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) diff --git a/ai21/clients/studio/resources/assistant/studio_thread_message.py b/ai21/clients/studio/resources/assistant/studio_thread_message.py index ce94c966..95cb1b47 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread_message.py +++ b/ai21/clients/studio/resources/assistant/studio_thread_message.py @@ -8,7 +8,12 @@ class StudioThreadMessage(StudioResource, Messages): def create( - self, thread_id: str, *, role: ThreadMessageRole, content: MessageContentText, **kwargs + self, + thread_id: str, + *, + role: ThreadMessageRole, + content: MessageContentText, + **kwargs, ) -> MessageResponse: body = dict( role=role, @@ -23,7 +28,12 @@ def list(self, thread_id: str) -> ListMessageResponse: class AsyncStudioThreadMessage(AsyncStudioResource, Messages): async def create( - self, thread_id: str, *, role: ThreadMessageRole, content: MessageContentText, **kwargs + self, + thread_id: str, + *, + role: ThreadMessageRole, + content: MessageContentText, + **kwargs, ) -> MessageResponse: body = dict( role=role,