diff --git a/ai21/clients/common/assistant/assistant.py b/ai21/clients/common/assistant/assistants.py similarity index 85% rename from ai21/clients/common/assistant/assistant.py rename to ai21/clients/common/assistant/assistants.py index c2ad588a..e773f89a 100644 --- a/ai21/clients/common/assistant/assistant.py +++ b/ai21/clients/common/assistant/assistants.py @@ -3,18 +3,13 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from ai21.models.responses.assistant_response import ( - AssistantResponse, - Optimization, - ToolResources, - Tool, - ListAssistantResponse, -) +from ai21.models.assistant.assistant import Optimization, Tool, ToolResources +from ai21.models.responses.assistant_response import Assistant, ListAssistant from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given -class Assistant(ABC): +class Assistants(ABC): _module_name = "assistants" @abstractmethod @@ -29,7 +24,7 @@ def create( tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> AssistantResponse: + ) -> Assistant: pass def _create_body( @@ -58,11 +53,11 @@ def _create_body( ) @abstractmethod - def list(self) -> ListAssistantResponse: + def list(self) -> ListAssistant: pass @abstractmethod - def get(self, assistant_id: str) -> AssistantResponse: + def get(self, assistant_id: str) -> Assistant: pass @abstractmethod @@ -78,5 +73,5 @@ def modify( models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> AssistantResponse: + ) -> Assistant: pass diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py new file mode 100644 index 00000000..5025ccc2 --- /dev/null +++ b/ai21/clients/common/assistant/threads.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from ai21.models.assistant.message import Message +from ai21.models.responses.thread_response import Thread + + +class Threads(ABC): + _module_name = "threads" + + @abstractmethod + def create(self, messages: List[Message], **kwargs) -> Thread: + pass + + @abstractmethod + def get(self, thread_id: str) -> Thread: + pass diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/assistant/studio_assistant.py index 82014880..fa008695 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/assistant/studio_assistant.py @@ -2,21 +2,17 @@ from typing import List -from ai21.clients.common.assistant.assistant import Assistant +from ai21.clients.common.assistant.assistants import Assistants from ai21.clients.studio.resources.studio_resource import ( AsyncStudioResource, StudioResource, ) -from ai21.models.responses.assistant_response import ( - AssistantResponse, - Tool, - ToolResources, - ListAssistantResponse, -) +from ai21.models.assistant.assistant import Tool, ToolResources +from ai21.models.responses.assistant_response import Assistant, ListAssistant from ai21.types import NotGiven, NOT_GIVEN -class StudioAssistant(StudioResource, Assistant): +class StudioAssistant(StudioResource, Assistants): def create( self, name: str, @@ -28,7 +24,7 @@ def create( tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -40,13 +36,13 @@ def create( **kwargs, ) - return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - def get(self, assistant_id: str) -> AssistantResponse: - return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) + def get(self, assistant_id: str) -> Assistant: + return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) - def list(self) -> ListAssistantResponse: - return self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse) + def list(self) -> ListAssistant: + return self._get(path=f"/{self._module_name}", response_cls=ListAssistant) def modify( self, @@ -60,7 +56,7 @@ def modify( models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -72,10 +68,10 @@ def modify( tool_resources=tool_resources, ) - return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse) + return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant) -class AsyncStudioAssistant(AsyncStudioResource, Assistant): +class AsyncStudioAssistant(AsyncStudioResource, Assistants): async def create( self, name: str, @@ -87,7 +83,7 @@ async def create( tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -99,13 +95,13 @@ async def create( **kwargs, ) - return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) + return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - async def get(self, assistant_id: str) -> AssistantResponse: - return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) + async def get(self, assistant_id: str) -> Assistant: + return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) - async def list(self) -> ListAssistantResponse: - return await self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse) + async def list(self) -> ListAssistant: + return await self._get(path=f"/{self._module_name}", response_cls=ListAssistant) async def modify( self, @@ -119,7 +115,7 @@ async def modify( models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -131,4 +127,4 @@ async def modify( tool_resources=tool_resources, ) - return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse) + return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant) diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py new file mode 100644 index 00000000..fd48563e --- /dev/null +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import List + +from ai21.clients.common.assistant.threads import Threads +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.models.assistant.message import Message +from ai21.models.responses.thread_response import Thread + + +class StudioThread(StudioResource, Threads): + def create(self, messages: List[Message], **kwargs) -> Thread: + body = dict(messages=messages) + + return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) + + def get(self, thread_id: str) -> Thread: + return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) + + +class AsyncStudioThread(AsyncStudioResource, Threads): + async def create(self, messages: List[Message], **kwargs) -> Thread: + body = dict(messages=messages) + + return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) + + async def get(self, thread_id: str) -> Thread: + return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) diff --git a/ai21/clients/studio/resources/beta/async_beta.py b/ai21/clients/studio/resources/beta/async_beta.py index 521c7a13..1bb13bbe 100644 --- a/ai21/clients/studio/resources/beta/async_beta.py +++ b/ai21/clients/studio/resources/beta/async_beta.py @@ -1,4 +1,5 @@ from ai21.clients.studio.resources.assistant.studio_assistant import AsyncStudioAssistant +from ai21.clients.studio.resources.assistant.studio_thread import AsyncStudioThread from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -8,5 +9,6 @@ class AsyncBeta(AsyncStudioResource): def __init__(self, client: AsyncAI21HTTPClient): super().__init__(client) - self.conversational_rag = AsyncStudioConversationalRag(client) self.assistants = AsyncStudioAssistant(client) + self.conversational_rag = AsyncStudioConversationalRag(client) + self.threads = AsyncStudioThread(client) diff --git a/ai21/clients/studio/resources/beta/beta.py b/ai21/clients/studio/resources/beta/beta.py index 8560597a..affede10 100644 --- a/ai21/clients/studio/resources/beta/beta.py +++ b/ai21/clients/studio/resources/beta/beta.py @@ -1,4 +1,5 @@ from ai21.clients.studio.resources.assistant.studio_assistant import StudioAssistant +from ai21.clients.studio.resources.assistant.studio_thread import StudioThread from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.http_client.http_client import AI21HTTPClient @@ -8,5 +9,6 @@ class Beta(StudioResource): def __init__(self, client: AI21HTTPClient): super().__init__(client) - self.conversational_rag = StudioConversationalRag(client) self.assistants = StudioAssistant(client) + self.conversational_rag = StudioConversationalRag(client) + self.threads = StudioThread(client) diff --git a/ai21/models/assistant/__init__.py b/ai21/models/assistant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/models/assistant/assistant.py b/ai21/models/assistant/assistant.py new file mode 100644 index 00000000..44c3ed06 --- /dev/null +++ b/ai21/models/assistant/assistant.py @@ -0,0 +1,12 @@ +from typing import Optional, Literal + +from typing_extensions import TypedDict + +Optimization = Literal["cost", "latency"] +Tool = Literal["rag", "internet_research", "plan_approval"] + + +class ToolResources(TypedDict, total=False): + rag: Optional[dict] + internet_research: Optional[dict] + plan_approval: Optional[dict] diff --git a/ai21/models/assistant/message.py b/ai21/models/assistant/message.py new file mode 100644 index 00000000..f3ec7d8d --- /dev/null +++ b/ai21/models/assistant/message.py @@ -0,0 +1,29 @@ +from datetime import datetime +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from ai21.models.ai21_base_model import AI21BaseModel + +ThreadMessageRole = Literal["assistant", "user"] + + +class MessageContentText(TypedDict): + type: Literal["text"] + text: str + + +class Message(TypedDict): + role: ThreadMessageRole + content: MessageContentText + + +class MessageResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["message"] = "message" + role: ThreadMessageRole + content: MessageContentText + run_id: Optional[str] = None + assistant_id: Optional[str] = None diff --git a/ai21/models/responses/assistant_response.py b/ai21/models/responses/assistant_response.py index d7b0f186..3263b5a8 100644 --- a/ai21/models/responses/assistant_response.py +++ b/ai21/models/responses/assistant_response.py @@ -1,26 +1,15 @@ from datetime import datetime from typing import Optional, List, Literal -from typing_extensions import TypedDict - from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.assistant.assistant import ToolResources -Optimization = Literal["cost", "latency"] -Tool = Literal["rag", "internet_research", "plan_approval"] - - -class ToolResources(TypedDict, total=False): - rag: Optional[dict] - internet_research: Optional[dict] - plan_approval: Optional[dict] - - -class AssistantResponse(AI21BaseModel): +class Assistant(AI21BaseModel): id: str created_at: datetime updated_at: datetime - object: str + object: Literal["assistant"] = "assistant" name: str description: Optional[str] = None optimization: str @@ -33,5 +22,5 @@ class AssistantResponse(AI21BaseModel): tool_resources: Optional[ToolResources] = None -class ListAssistantResponse(AI21BaseModel): - results: List[AssistantResponse] +class ListAssistant(AI21BaseModel): + results: List[Assistant] diff --git a/ai21/models/responses/thread_response.py b/ai21/models/responses/thread_response.py new file mode 100644 index 00000000..b2c9bbc7 --- /dev/null +++ b/ai21/models/responses/thread_response.py @@ -0,0 +1,15 @@ +from datetime import datetime +from typing import List, Literal + +from ai21.models.ai21_base_model import AI21BaseModel + + +class Thread(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["thread"] = "thread" + + +class ListThread(AI21BaseModel): + results: List[Thread]