diff --git a/ai21/clients/common/assistant/assistants.py b/ai21/clients/common/assistant/assistants.py index e773f89a..1beb2b7f 100644 --- a/ai21/clients/common/assistant/assistants.py +++ b/ai21/clients/common/assistant/assistants.py @@ -57,7 +57,7 @@ def list(self) -> ListAssistant: pass @abstractmethod - def get(self, assistant_id: str) -> Assistant: + def retrieve(self, assistant_id: str) -> Assistant: pass @abstractmethod diff --git a/ai21/clients/common/assistant/messages.py b/ai21/clients/common/assistant/messages.py new file mode 100644 index 00000000..0ccbde79 --- /dev/null +++ b/ai21/clients/common/assistant/messages.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from ai21.models.assistant.message import ThreadMessageRole, MessageContentText +from ai21.models.responses.message_response import MessageResponse, ListMessageResponse + + +class Messages(ABC): + _module_name = "messages" + + @abstractmethod + def create( + self, + thread_id: str, + *, + role: ThreadMessageRole, + content: MessageContentText, + **kwargs, + ) -> MessageResponse: + pass + + @abstractmethod + def list(self, thread_id: str) -> ListMessageResponse: + pass diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py index 5025ccc2..9b64580d 100644 --- a/ai21/clients/common/assistant/threads.py +++ b/ai21/clients/common/assistant/threads.py @@ -3,17 +3,23 @@ from abc import ABC, abstractmethod from typing import List +from ai21.clients.common.assistant.messages import Messages from ai21.models.assistant.message import Message from ai21.models.responses.thread_response import Thread class Threads(ABC): _module_name = "threads" + messages: Messages @abstractmethod - def create(self, messages: List[Message], **kwargs) -> Thread: + def create( + self, + messages: List[Message], + **kwargs, + ) -> Thread: pass @abstractmethod - def get(self, thread_id: str) -> Thread: + def retrieve(self, thread_id: str) -> Thread: pass diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/assistant/studio_assistant.py index fa008695..f2a3c306 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/assistant/studio_assistant.py @@ -38,7 +38,7 @@ def create( return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - def get(self, assistant_id: str) -> Assistant: + def retrieve(self, assistant_id: str) -> Assistant: return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) def list(self) -> ListAssistant: @@ -97,7 +97,7 @@ async def create( return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - async def get(self, assistant_id: str) -> Assistant: + async def retrieve(self, assistant_id: str) -> Assistant: return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) async def list(self) -> ListAssistant: diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py index fd48563e..60c88e22 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -3,26 +3,47 @@ from typing import List from ai21.clients.common.assistant.threads import Threads +from ai21.clients.studio.resources.assistant.studio_thread_message import StudioThreadMessage, AsyncStudioThreadMessage from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.http_client.async_http_client import AsyncAI21HTTPClient +from ai21.http_client.http_client import AI21HTTPClient from ai21.models.assistant.message import Message from ai21.models.responses.thread_response import Thread class StudioThread(StudioResource, Threads): - def create(self, messages: List[Message], **kwargs) -> Thread: + def __init__(self, client: AI21HTTPClient): + super().__init__(client) + + self.messages = StudioThreadMessage(client) + + def create( + self, + messages: List[Message], + **kwargs, + ) -> Thread: body = dict(messages=messages) return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) - def get(self, thread_id: str) -> Thread: + def retrieve(self, thread_id: str) -> Thread: return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) class AsyncStudioThread(AsyncStudioResource, Threads): - async def create(self, messages: List[Message], **kwargs) -> Thread: + def __init__(self, client: AsyncAI21HTTPClient): + super().__init__(client) + + self.messages = AsyncStudioThreadMessage(client) + + async def create( + self, + messages: List[Message], + **kwargs, + ) -> Thread: body = dict(messages=messages) return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) - async def get(self, thread_id: str) -> Thread: + async def retrieve(self, thread_id: str) -> Thread: return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) diff --git a/ai21/clients/studio/resources/assistant/studio_thread_message.py b/ai21/clients/studio/resources/assistant/studio_thread_message.py new file mode 100644 index 00000000..95cb1b47 --- /dev/null +++ b/ai21/clients/studio/resources/assistant/studio_thread_message.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from ai21.clients.common.assistant.messages import Messages +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.models.assistant.message import ThreadMessageRole, MessageContentText +from ai21.models.responses.message_response import MessageResponse, ListMessageResponse + + +class StudioThreadMessage(StudioResource, Messages): + def create( + self, + thread_id: str, + *, + role: ThreadMessageRole, + content: MessageContentText, + **kwargs, + ) -> MessageResponse: + body = dict( + role=role, + content=content, + ) + + return self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse) + + def list(self, thread_id: str) -> ListMessageResponse: + return self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse) + + +class AsyncStudioThreadMessage(AsyncStudioResource, Messages): + async def create( + self, + thread_id: str, + *, + role: ThreadMessageRole, + content: MessageContentText, + **kwargs, + ) -> MessageResponse: + body = dict( + role=role, + content=content, + ) + + return await self._post( + path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse + ) + + async def list(self, thread_id: str) -> ListMessageResponse: + return await self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse) diff --git a/ai21/models/assistant/message.py b/ai21/models/assistant/message.py index f3ec7d8d..22ea77a2 100644 --- a/ai21/models/assistant/message.py +++ b/ai21/models/assistant/message.py @@ -1,10 +1,7 @@ -from datetime import datetime -from typing import Literal, Optional +from typing import Literal from typing_extensions import TypedDict -from ai21.models.ai21_base_model import AI21BaseModel - ThreadMessageRole = Literal["assistant", "user"] @@ -16,14 +13,3 @@ class MessageContentText(TypedDict): class Message(TypedDict): role: ThreadMessageRole content: MessageContentText - - -class MessageResponse(AI21BaseModel): - id: str - created_at: datetime - updated_at: datetime - object: Literal["message"] = "message" - role: ThreadMessageRole - content: MessageContentText - run_id: Optional[str] = None - assistant_id: Optional[str] = None diff --git a/ai21/models/responses/message_response.py b/ai21/models/responses/message_response.py new file mode 100644 index 00000000..e19fa665 --- /dev/null +++ b/ai21/models/responses/message_response.py @@ -0,0 +1,20 @@ +from datetime import datetime +from typing import Literal, Optional, List + +from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.assistant.message import ThreadMessageRole, MessageContentText + + +class MessageResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["message"] = "message" + role: ThreadMessageRole + content: MessageContentText + run_id: Optional[str] = None + assistant_id: Optional[str] = None + + +class ListMessageResponse(AI21BaseModel): + results: List[MessageResponse]