diff --git a/README.md b/README.md index b3bbeb80..71856588 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ - [Installation](#Installation) 💿 - [Usage - Chat Completions](#Usage) - [Conversational RAG (Beta)](#Conversational-RAG-Beta) +- [Assistants (Beta)](#Assistants-Beta) - [Older Models Support Usage](#Older-Models-Support-Usage) - [More Models](#More-Models) - [Streaming](#Streaming) @@ -388,6 +389,41 @@ For a more detailed example, see the chat [sync](examples/studio/conversational_ --- +### Assistants (Beta) + +Create assistants to help you with your tasks. + +```python +from time import sleep +from ai21 import AI21Client +from ai21.models.assistant.message import Message + +messages = [ + Message(content={"type": "text", "text": "Youre message here"}, role="user"), +] + +client = AI21Client() + +assistant = client.beta.assistants.create(name="My assistant") +thread = client.beta.threads.create(messages=messages) +run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id) + +while run.status == "in_progress": + run = client.beta.threads.runs.get(thread_id=thread.id, run_id=run.id) + sleep(1) + +if run.status == "completed": + messages = client.beta.threads.messages.list(thread_id=thread.id) + print(messages) +else: + # handle error or required action + pass +``` + +For a more detailed example, see [sync](examples/studio/assistant/assistant.py) and [async](examples/studio/assistant/async_assistant.py) examples. + +--- + ### File Upload ```python diff --git a/ai21/clients/common/assistant/__init__.py b/ai21/clients/common/beta/__init__.py similarity index 100% rename from ai21/clients/common/assistant/__init__.py rename to ai21/clients/common/beta/__init__.py diff --git a/ai21/clients/studio/resources/assistant/__init__.py b/ai21/clients/common/beta/assistant/__init__.py similarity index 100% rename from ai21/clients/studio/resources/assistant/__init__.py rename to ai21/clients/common/beta/assistant/__init__.py diff --git a/ai21/clients/common/assistant/assistants.py b/ai21/clients/common/beta/assistant/assistants.py similarity index 85% rename from ai21/clients/common/assistant/assistants.py rename to ai21/clients/common/beta/assistant/assistants.py index 1beb2b7f..65bbabf2 100644 --- a/ai21/clients/common/assistant/assistants.py +++ b/ai21/clients/common/beta/assistant/assistants.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List from ai21.models.assistant.assistant import Optimization, Tool, ToolResources -from ai21.models.responses.assistant_response import Assistant, ListAssistant +from ai21.models.responses.assistant_response import AssistantResponse, ListAssistant from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given @@ -19,12 +19,11 @@ def create( *, description: str | NotGiven = NOT_GIVEN, optimization: Optimization | NotGiven = NOT_GIVEN, - avatar: str | NotGiven = NOT_GIVEN, models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> Assistant: + ) -> AssistantResponse: pass def _create_body( @@ -33,7 +32,6 @@ def _create_body( name: str, description: str | NotGiven, optimization: str | NotGiven, - avatar: str | NotGiven, models: List[str] | NotGiven, tools: List[str] | NotGiven, tool_resources: dict | NotGiven, @@ -44,7 +42,6 @@ def _create_body( "name": name, "description": description, "optimization": optimization, - "avatar": avatar, "models": models, "tools": tools, "tool_resources": tool_resources, @@ -57,7 +54,7 @@ def list(self) -> ListAssistant: pass @abstractmethod - def retrieve(self, assistant_id: str) -> Assistant: + def retrieve(self, assistant_id: str) -> AssistantResponse: pass @abstractmethod @@ -68,10 +65,9 @@ def modify( name: str | NotGiven = NOT_GIVEN, description: str | NotGiven = NOT_GIVEN, optimization: Optimization | NotGiven = NOT_GIVEN, - avatar: str | NotGiven = NOT_GIVEN, is_archived: bool | NotGiven = NOT_GIVEN, models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> Assistant: + ) -> AssistantResponse: pass diff --git a/ai21/clients/common/assistant/messages.py b/ai21/clients/common/beta/assistant/messages.py similarity index 100% rename from ai21/clients/common/assistant/messages.py rename to ai21/clients/common/beta/assistant/messages.py diff --git a/ai21/clients/common/beta/assistant/runs.py b/ai21/clients/common/beta/assistant/runs.py new file mode 100644 index 00000000..656bf26e --- /dev/null +++ b/ai21/clients/common/beta/assistant/runs.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from ai21.models.assistant.assistant import Optimization +from ai21.models.assistant.run import ToolOutput +from ai21.models.responses.run_response import RunResponse +from ai21.types import NOT_GIVEN, NotGiven +from ai21.utils.typing import remove_not_given + + +class Runs(ABC): + _module_name = "runs" + + @abstractmethod + def create( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven = NOT_GIVEN, + optimization: Optimization | NotGiven = NOT_GIVEN, + **kwargs, + ) -> RunResponse: + pass + + def _create_body( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven, + optimization: str | NotGiven, + **kwargs, + ) -> dict: + return remove_not_given( + { + "thread_id": thread_id, + "assistant_id": assistant_id, + "description": description, + "optimization": optimization, + **kwargs, + } + ) + + @abstractmethod + def retrieve( + self, + *, + thread_id: str, + run_id: str, + ) -> RunResponse: + pass + + @abstractmethod + def cancel( + self, + *, + thread_id: str, + run_id: str, + ) -> RunResponse: + pass + + @abstractmethod + def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List[ToolOutput]) -> RunResponse: + pass diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/beta/assistant/threads.py similarity index 63% rename from ai21/clients/common/assistant/threads.py rename to ai21/clients/common/beta/assistant/threads.py index 9b64580d..75674fc9 100644 --- a/ai21/clients/common/assistant/threads.py +++ b/ai21/clients/common/beta/assistant/threads.py @@ -3,9 +3,9 @@ from abc import ABC, abstractmethod from typing import List -from ai21.clients.common.assistant.messages import Messages +from ai21.clients.common.beta.assistant.messages import Messages from ai21.models.assistant.message import Message -from ai21.models.responses.thread_response import Thread +from ai21.models.responses.thread_response import ThreadResponse class Threads(ABC): @@ -17,9 +17,9 @@ def create( self, messages: List[Message], **kwargs, - ) -> Thread: + ) -> ThreadResponse: pass @abstractmethod - def retrieve(self, thread_id: str) -> Thread: + def retrieve(self, thread_id: str) -> ThreadResponse: pass diff --git a/ai21/clients/studio/resources/beta/assistant/__init__.py b/ai21/clients/studio/resources/beta/assistant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/beta/assistant/assistant.py similarity index 78% rename from ai21/clients/studio/resources/assistant/studio_assistant.py rename to ai21/clients/studio/resources/beta/assistant/assistant.py index f2a3c306..7b49a4c6 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/beta/assistant/assistant.py @@ -2,44 +2,42 @@ from typing import List -from ai21.clients.common.assistant.assistants import Assistants +from ai21.clients.common.beta.assistant.assistants import Assistants from ai21.clients.studio.resources.studio_resource import ( AsyncStudioResource, StudioResource, ) from ai21.models.assistant.assistant import Tool, ToolResources -from ai21.models.responses.assistant_response import Assistant, ListAssistant +from ai21.models.responses.assistant_response import AssistantResponse, ListAssistant from ai21.types import NotGiven, NOT_GIVEN -class StudioAssistant(StudioResource, Assistants): +class Assistant(StudioResource, Assistants): def create( self, name: str, *, description: str | NotGiven = NOT_GIVEN, optimization: str | NotGiven = NOT_GIVEN, - avatar: str | NotGiven = NOT_GIVEN, models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> Assistant: + ) -> AssistantResponse: body = self._create_body( name=name, description=description, optimization=optimization, - avatar=avatar, models=models, tools=tools, tool_resources=tool_resources, **kwargs, ) - return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) + return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) - def retrieve(self, assistant_id: str) -> Assistant: - return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) + def retrieve(self, assistant_id: str) -> AssistantResponse: + return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) def list(self) -> ListAssistant: return self._get(path=f"/{self._module_name}", response_cls=ListAssistant) @@ -51,54 +49,50 @@ def modify( name: str | NotGiven = NOT_GIVEN, description: str | NotGiven = NOT_GIVEN, optimization: str | NotGiven = NOT_GIVEN, - avatar: str | NotGiven = NOT_GIVEN, is_archived: bool | NotGiven = NOT_GIVEN, models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> Assistant: + ) -> AssistantResponse: body = self._create_body( name=name, description=description, optimization=optimization, - avatar=avatar, is_archived=is_archived, models=models, tools=tools, tool_resources=tool_resources, ) - return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant) + return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse) -class AsyncStudioAssistant(AsyncStudioResource, Assistants): +class AsyncAssistant(AsyncStudioResource, Assistants): async def create( self, name: str, *, description: str | NotGiven = NOT_GIVEN, optimization: str | NotGiven = NOT_GIVEN, - avatar: str | NotGiven = NOT_GIVEN, models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> Assistant: + ) -> AssistantResponse: body = self._create_body( name=name, description=description, optimization=optimization, - avatar=avatar, models=models, tools=tools, tool_resources=tool_resources, **kwargs, ) - return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) + return await self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) - async def retrieve(self, assistant_id: str) -> Assistant: - return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) + async def retrieve(self, assistant_id: str) -> AssistantResponse: + return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) async def list(self) -> ListAssistant: return await self._get(path=f"/{self._module_name}", response_cls=ListAssistant) @@ -110,21 +104,19 @@ async def modify( name: str | NotGiven = NOT_GIVEN, description: str | NotGiven = NOT_GIVEN, optimization: str | NotGiven = NOT_GIVEN, - avatar: str | NotGiven = NOT_GIVEN, is_archived: bool | NotGiven = NOT_GIVEN, models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> Assistant: + ) -> AssistantResponse: body = self._create_body( name=name, description=description, optimization=optimization, - avatar=avatar, is_archived=is_archived, models=models, tools=tools, tool_resources=tool_resources, ) - return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant) + return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse) diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/beta/assistant/thread.py similarity index 52% rename from ai21/clients/studio/resources/assistant/studio_thread.py rename to ai21/clients/studio/resources/beta/assistant/thread.py index 60c88e22..b2fdf57b 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/beta/assistant/thread.py @@ -2,48 +2,51 @@ from typing import List -from ai21.clients.common.assistant.threads import Threads -from ai21.clients.studio.resources.assistant.studio_thread_message import StudioThreadMessage, AsyncStudioThreadMessage +from ai21.clients.common.beta.assistant.threads import Threads +from ai21.clients.studio.resources.beta.assistant.thread_message import ThreadMessage, AsyncThreadMessage +from ai21.clients.studio.resources.beta.assistant.thread_run import AsyncThreadRun, ThreadRun from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient from ai21.models.assistant.message import Message -from ai21.models.responses.thread_response import Thread +from ai21.models.responses.thread_response import ThreadResponse -class StudioThread(StudioResource, Threads): +class Thread(StudioResource, Threads): def __init__(self, client: AI21HTTPClient): super().__init__(client) - self.messages = StudioThreadMessage(client) + self.messages = ThreadMessage(client) + self.runs = ThreadRun(client) def create( self, messages: List[Message], **kwargs, - ) -> Thread: + ) -> ThreadResponse: body = dict(messages=messages) - return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) + return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) - def retrieve(self, thread_id: str) -> Thread: - return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) + def retrieve(self, thread_id: str) -> ThreadResponse: + return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse) -class AsyncStudioThread(AsyncStudioResource, Threads): +class AsyncThread(AsyncStudioResource, Threads): def __init__(self, client: AsyncAI21HTTPClient): super().__init__(client) - self.messages = AsyncStudioThreadMessage(client) + self.messages = AsyncThreadMessage(client) + self.runs = AsyncThreadRun(client) async def create( self, messages: List[Message], **kwargs, - ) -> Thread: + ) -> ThreadResponse: body = dict(messages=messages) - return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) + return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) - async def retrieve(self, thread_id: str) -> Thread: - return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) + async def retrieve(self, thread_id: str) -> ThreadResponse: + return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse) diff --git a/ai21/clients/studio/resources/assistant/studio_thread_message.py b/ai21/clients/studio/resources/beta/assistant/thread_message.py similarity index 89% rename from ai21/clients/studio/resources/assistant/studio_thread_message.py rename to ai21/clients/studio/resources/beta/assistant/thread_message.py index 95cb1b47..f8a8ee10 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread_message.py +++ b/ai21/clients/studio/resources/beta/assistant/thread_message.py @@ -1,12 +1,12 @@ from __future__ import annotations -from ai21.clients.common.assistant.messages import Messages +from ai21.clients.common.beta.assistant.messages import Messages from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource from ai21.models.assistant.message import ThreadMessageRole, MessageContentText from ai21.models.responses.message_response import MessageResponse, ListMessageResponse -class StudioThreadMessage(StudioResource, Messages): +class ThreadMessage(StudioResource, Messages): def create( self, thread_id: str, @@ -26,7 +26,7 @@ def list(self, thread_id: str) -> ListMessageResponse: return self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse) -class AsyncStudioThreadMessage(AsyncStudioResource, Messages): +class AsyncThreadMessage(AsyncStudioResource, Messages): async def create( self, thread_id: str, diff --git a/ai21/clients/studio/resources/beta/assistant/thread_run.py b/ai21/clients/studio/resources/beta/assistant/thread_run.py new file mode 100644 index 00000000..71955363 --- /dev/null +++ b/ai21/clients/studio/resources/beta/assistant/thread_run.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import List + +from ai21.clients.common.beta.assistant.runs import Runs +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.models.assistant.assistant import Optimization +from ai21.models.assistant.run import ToolOutput +from ai21.models.responses.run_response import RunResponse +from ai21.types import NotGiven, NOT_GIVEN + + +class ThreadRun(StudioResource, Runs): + def create( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven = NOT_GIVEN, + optimization: Optimization | NotGiven = NOT_GIVEN, + **kwargs, + ) -> RunResponse: + body = self._create_body( + thread_id=thread_id, + assistant_id=assistant_id, + description=description, + optimization=optimization, + **kwargs, + ) + + return self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=RunResponse) + + def retrieve( + self, + *, + thread_id: str, + run_id: str, + ) -> RunResponse: + return self._get(path=f"/threads/{thread_id}/{self._module_name}/{run_id}", response_cls=RunResponse) + + def cancel( + self, + *, + thread_id: str, + run_id: str, + ) -> RunResponse: + return self._post(path=f"/threads/{thread_id}/{self._module_name}/{run_id}/cancel", response_cls=RunResponse) + + def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List[ToolOutput]) -> RunResponse: + body = dict(tool_outputs=tool_outputs) + + return self._post( + path=f"/threads/{thread_id}/{self._module_name}/{run_id}/submit_tool_outputs", + body=body, + response_cls=RunResponse, + ) + + +class AsyncThreadRun(AsyncStudioResource, Runs): + async def create( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven = NOT_GIVEN, + optimization: Optimization | NotGiven = NOT_GIVEN, + **kwargs, + ) -> RunResponse: + body = self._create_body( + thread_id=thread_id, + assistant_id=assistant_id, + description=description, + optimization=optimization, + **kwargs, + ) + + return await self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=RunResponse) + + async def retrieve( + self, + *, + thread_id: str, + run_id: str, + ) -> RunResponse: + return await self._get(path=f"/threads/{thread_id}/{self._module_name}/{run_id}", response_cls=RunResponse) + + async def cancel( + self, + *, + thread_id: str, + run_id: str, + ) -> RunResponse: + return await self._post( + path=f"/threads/{thread_id}/{self._module_name}/{run_id}/cancel", response_cls=RunResponse + ) + + async def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List[ToolOutput]) -> RunResponse: + body = dict(tool_outputs=tool_outputs) + + return await self._post( + path=f"/threads/{thread_id}/{self._module_name}/{run_id}/submit_tool_outputs", + body=body, + response_cls=RunResponse, + ) diff --git a/ai21/clients/studio/resources/beta/async_beta.py b/ai21/clients/studio/resources/beta/async_beta.py index 1bb13bbe..5c93b9f6 100644 --- a/ai21/clients/studio/resources/beta/async_beta.py +++ b/ai21/clients/studio/resources/beta/async_beta.py @@ -1,5 +1,5 @@ -from ai21.clients.studio.resources.assistant.studio_assistant import AsyncStudioAssistant -from ai21.clients.studio.resources.assistant.studio_thread import AsyncStudioThread +from ai21.clients.studio.resources.beta.assistant.assistant import AsyncAssistant +from ai21.clients.studio.resources.beta.assistant.thread import AsyncThread from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -9,6 +9,6 @@ class AsyncBeta(AsyncStudioResource): def __init__(self, client: AsyncAI21HTTPClient): super().__init__(client) - self.assistants = AsyncStudioAssistant(client) + self.assistants = AsyncAssistant(client) self.conversational_rag = AsyncStudioConversationalRag(client) - self.threads = AsyncStudioThread(client) + self.threads = AsyncThread(client) diff --git a/ai21/clients/studio/resources/beta/beta.py b/ai21/clients/studio/resources/beta/beta.py index affede10..62b55e80 100644 --- a/ai21/clients/studio/resources/beta/beta.py +++ b/ai21/clients/studio/resources/beta/beta.py @@ -1,5 +1,5 @@ -from ai21.clients.studio.resources.assistant.studio_assistant import StudioAssistant -from ai21.clients.studio.resources.assistant.studio_thread import StudioThread +from ai21.clients.studio.resources.beta.assistant.assistant import Assistant +from ai21.clients.studio.resources.beta.assistant.thread import Thread from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.http_client.http_client import AI21HTTPClient @@ -9,6 +9,6 @@ class Beta(StudioResource): def __init__(self, client: AI21HTTPClient): super().__init__(client) - self.assistants = StudioAssistant(client) + self.assistants = Assistant(client) self.conversational_rag = StudioConversationalRag(client) - self.threads = StudioThread(client) + self.threads = Thread(client) diff --git a/ai21/models/assistant/run.py b/ai21/models/assistant/run.py new file mode 100644 index 00000000..ed118545 --- /dev/null +++ b/ai21/models/assistant/run.py @@ -0,0 +1,41 @@ +from typing import Literal, Any, List + +from typing_extensions import TypedDict + + +RunStatus = Literal[ + "cancelled", + "cancelling", + "completed", + "expired", + "failed", + "incomplete", + "in_progress", + "queued", + "requires_action", +] + + +class ToolOutput(TypedDict): + tool_call_id: str + output: Any + + +class Function(TypedDict): + name: str + arguments: Any + + +class ToolCall(TypedDict): + type: Literal["function"] + id: str + function: Function + + +class SubmitToolCallOutputs(TypedDict): + tool_calls: List[ToolOutput] + + +class RequiredAction(TypedDict): + type: Literal["submit_tool_outputs"] + submit_tool_outputs: SubmitToolCallOutputs diff --git a/ai21/models/responses/assistant_response.py b/ai21/models/responses/assistant_response.py index 3263b5a8..0bf40c5f 100644 --- a/ai21/models/responses/assistant_response.py +++ b/ai21/models/responses/assistant_response.py @@ -5,7 +5,7 @@ from ai21.models.assistant.assistant import ToolResources -class Assistant(AI21BaseModel): +class AssistantResponse(AI21BaseModel): id: str created_at: datetime updated_at: datetime @@ -23,4 +23,4 @@ class Assistant(AI21BaseModel): class ListAssistant(AI21BaseModel): - results: List[Assistant] + results: List[AssistantResponse] diff --git a/ai21/models/responses/run_response.py b/ai21/models/responses/run_response.py new file mode 100644 index 00000000..0a260235 --- /dev/null +++ b/ai21/models/responses/run_response.py @@ -0,0 +1,19 @@ +from datetime import datetime +from typing import Optional + +from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.assistant.assistant import Optimization +from ai21.models.assistant.run import RunStatus, RequiredAction + + +class RunResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + thread_id: str + assistant_id: str + description: Optional[str] = None + status: RunStatus + optimization: Optimization + execution_id: Optional[str] = None + required_action: Optional[RequiredAction] = None diff --git a/ai21/models/responses/thread_response.py b/ai21/models/responses/thread_response.py index b2c9bbc7..f7206a40 100644 --- a/ai21/models/responses/thread_response.py +++ b/ai21/models/responses/thread_response.py @@ -4,7 +4,7 @@ from ai21.models.ai21_base_model import AI21BaseModel -class Thread(AI21BaseModel): +class ThreadResponse(AI21BaseModel): id: str created_at: datetime updated_at: datetime @@ -12,4 +12,4 @@ class Thread(AI21BaseModel): class ListThread(AI21BaseModel): - results: List[Thread] + results: List[ThreadResponse] diff --git a/examples/studio/assistant/assistant.py b/examples/studio/assistant/assistant.py new file mode 100644 index 00000000..ebe3100f --- /dev/null +++ b/examples/studio/assistant/assistant.py @@ -0,0 +1,47 @@ +import time + +from ai21 import AI21Client + +TIMEOUT = 20 + + +def main(): + ai21_client = AI21Client() + + assistant = ai21_client.beta.assistants.create(name="My Assistant") + + thread = ai21_client.beta.threads.create( + messages=[ + { + "role": "user", + "content": { + "type": "text", + "text": "Hello", + }, + }, + ] + ) + + run = ai21_client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + ) + + start = time.time() + + while run.status == "in_progress": + run = ai21_client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + if time.time() - start > TIMEOUT: + break + time.sleep(1) + + if run.status == "completed": + messages = ai21_client.beta.threads.messages.list(thread_id=thread.id) + print("Messages:") + print("\n".join(f"{msg.role}: {msg.content['text']}" for msg in messages.results)) + else: + raise Exception(f"Run failed. Status: {run.status}") + + +if __name__ == "__main__": + main() diff --git a/examples/studio/assistant/async_assistant.py b/examples/studio/assistant/async_assistant.py new file mode 100644 index 00000000..3f68805f --- /dev/null +++ b/examples/studio/assistant/async_assistant.py @@ -0,0 +1,49 @@ +import asyncio +import time + +from ai21 import AsyncAI21Client + + +TIMEOUT = 20 + + +async def main(): + ai21_client = AsyncAI21Client() + + assistant = await ai21_client.beta.assistants.create(name="My Assistant") + + thread = await ai21_client.beta.threads.create( + messages=[ + { + "role": "user", + "content": { + "type": "text", + "text": "Hello", + }, + }, + ] + ) + + run = await ai21_client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + ) + + start = time.time() + + while run.status == "in_progress": + run = await ai21_client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + if time.time() - start > TIMEOUT: + break + time.sleep(1) + + if run.status == "completed": + messages = await ai21_client.beta.threads.messages.list(thread_id=thread.id) + print("Messages:") + print("\n".join(f"{msg.role}: {msg.content['text']}" for msg in messages.results)) + else: + raise Exception(f"Run failed. Status: {run.status}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 5cd0d4dd..4a581b0e 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -55,6 +55,8 @@ def test_studio(test_file_name: str): ("chat/async_stream_chat_completions.py",), ("conversational_rag/conversational_rag.py",), ("conversational_rag/async_conversational_rag.py",), + ("assistant/assistant.py",), + ("assistant/async_assistant.py",), ], ids=[ "when_chat__should_return_ok", @@ -62,6 +64,8 @@ def test_studio(test_file_name: str): "when_stream_chat_completions__should_return_ok", "when_conversational_rag__should_return_ok", "when_async_conversational_rag__should_return_ok", + "when_assistant__should_return_ok", + "when_async_assistant__should_return_ok", ], ) async def test_async_studio(test_file_name: str):