From 7e46a30cce25e9367101e00617245a603186fce1 Mon Sep 17 00:00:00 2001 From: benshuk Date: Tue, 7 Jan 2025 14:16:40 +0200 Subject: [PATCH 1/5] feat: :sparkles: add support for status polling for `Run` --- ai21/clients/common/beta/assistant/runs.py | 6 ++++ .../resources/beta/assistant/thread_runs.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ai21/clients/common/beta/assistant/runs.py b/ai21/clients/common/beta/assistant/runs.py index 0ea681a6..652e6488 100644 --- a/ai21/clients/common/beta/assistant/runs.py +++ b/ai21/clients/common/beta/assistant/runs.py @@ -65,3 +65,9 @@ def cancel( @abstractmethod def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List[ToolOutput]) -> RunResponse: pass + + @abstractmethod + def poll_for_status( + self, *, thread_id: str, run_id: str, polling_interval: int, polling_timeout: int + ) -> RunResponse: + pass diff --git a/ai21/clients/studio/resources/beta/assistant/thread_runs.py b/ai21/clients/studio/resources/beta/assistant/thread_runs.py index 04908fc5..e5e9060e 100644 --- a/ai21/clients/studio/resources/beta/assistant/thread_runs.py +++ b/ai21/clients/studio/resources/beta/assistant/thread_runs.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +import time from typing import List from ai21.clients.common.beta.assistant.runs import BaseRuns @@ -55,6 +57,22 @@ def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List response_cls=RunResponse, ) + def poll_for_status( + self, *, thread_id: str, run_id: str, polling_interval: int = 1, timeout: int = 60 + ) -> RunResponse: + start_time = time.time() + run = self.retrieve(thread_id=thread_id, run_id=run_id) + + while run.status == "in_progress": + run = self.retrieve(thread_id=thread_id, run_id=run_id) + + if time.time() - start_time > timeout: + break + else: + time.sleep(polling_interval) + + return run + class AsyncThreadRuns(AsyncStudioResource, BaseRuns): async def create( @@ -102,3 +120,19 @@ async def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs body=body, response_cls=RunResponse, ) + + async def poll_for_status( + self, *, thread_id: str, run_id: str, polling_interval: int = 1, timeout: int = 60 + ) -> RunResponse: + start_time = time.time() + run = await self.retrieve(thread_id=thread_id, run_id=run_id) + + while run.status == "in_progress": + run = await self.retrieve(thread_id=thread_id, run_id=run_id) + + if time.time() - start_time > timeout: + break + else: + await asyncio.sleep(polling_interval) + + return run From a8a40af5450fe16fe88ff563114c9c6de3eb98f0 Mon Sep 17 00:00:00 2001 From: benshuk Date: Tue, 7 Jan 2025 16:54:42 +0200 Subject: [PATCH 2/5] chore: :wrench: constants --- ai21/models/assistant/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ai21/models/assistant/run.py b/ai21/models/assistant/run.py index ed118545..850cf765 100644 --- a/ai21/models/assistant/run.py +++ b/ai21/models/assistant/run.py @@ -1,4 +1,4 @@ -from typing import Literal, Any, List +from typing import Literal, Any, List, Set from typing_extensions import TypedDict @@ -15,6 +15,10 @@ "requires_action", ] +TERMINATED_RUN_STATUSES: Set[RunStatus] = {"completed", "failed", "expired", "cancelled", "requires_action"} +DEFAULT_RUN_POLL_INTERVAL: float = 1 # seconds +DEFAULT_RUN_POLL_TIMEOUT: float = 60 # seconds + class ToolOutput(TypedDict): tool_call_id: str From 5ba357065b3d9e3c4d5e0226ced30f2a634979c8 Mon Sep 17 00:00:00 2001 From: benshuk Date: Tue, 7 Jan 2025 16:55:52 +0200 Subject: [PATCH 3/5] refactor: :recycle: polling implementation --- ai21/clients/common/beta/assistant/runs.py | 18 ++- .../resources/beta/assistant/thread_runs.py | 109 +++++++++++++++--- 2 files changed, 106 insertions(+), 21 deletions(-) diff --git a/ai21/clients/common/beta/assistant/runs.py b/ai21/clients/common/beta/assistant/runs.py index 652e6488..a9327129 100644 --- a/ai21/clients/common/beta/assistant/runs.py +++ b/ai21/clients/common/beta/assistant/runs.py @@ -67,7 +67,21 @@ def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List pass @abstractmethod - def poll_for_status( - self, *, thread_id: str, run_id: str, polling_interval: int, polling_timeout: int + def _poll_for_status( + self, *, thread_id: str, run_id: str, poll_interval: float, poll_timeout: float + ) -> RunResponse: + pass + + @abstractmethod + def create_and_poll( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven, + optimization: Optimization | NotGiven, + poll_interval: float, + poll_timeout: float, + **kwargs, ) -> RunResponse: pass diff --git a/ai21/clients/studio/resources/beta/assistant/thread_runs.py b/ai21/clients/studio/resources/beta/assistant/thread_runs.py index e5e9060e..b9fb722c 100644 --- a/ai21/clients/studio/resources/beta/assistant/thread_runs.py +++ b/ai21/clients/studio/resources/beta/assistant/thread_runs.py @@ -7,7 +7,12 @@ from ai21.clients.common.beta.assistant.runs import BaseRuns from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource from ai21.models.assistant.assistant import Optimization -from ai21.models.assistant.run import ToolOutput +from ai21.models.assistant.run import ( + ToolOutput, + TERMINATED_RUN_STATUSES, + DEFAULT_RUN_POLL_INTERVAL, + DEFAULT_RUN_POLL_TIMEOUT, +) from ai21.models.responses.run_response import RunResponse from ai21.types import NotGiven, NOT_GIVEN @@ -57,21 +62,54 @@ def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List response_cls=RunResponse, ) - def poll_for_status( - self, *, thread_id: str, run_id: str, polling_interval: int = 1, timeout: int = 60 + def _poll_for_status( + self, *, thread_id: str, run_id: str, poll_interval: float, poll_timeout: float ) -> RunResponse: start_time = time.time() - run = self.retrieve(thread_id=thread_id, run_id=run_id) - while run.status == "in_progress": + while True: run = self.retrieve(thread_id=thread_id, run_id=run_id) - if time.time() - start_time > timeout: - break - else: - time.sleep(polling_interval) + if run.status in TERMINATED_RUN_STATUSES: + return run - return run + if (time.time() - start_time) > poll_timeout: + return run + + time.sleep(poll_interval) + + def create_and_poll( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven = NOT_GIVEN, + optimization: Optimization | NotGiven = NOT_GIVEN, + poll_interval: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout: float = DEFAULT_RUN_POLL_TIMEOUT, + **kwargs, + ) -> RunResponse: + """ + Create a run and poll for its status until it is no longer in progress or the timeout is reached. + + Args: + thread_id: The ID of the thread. + assistant_id: The ID of the assistant. + description: The description of the run. + optimization: The optimization level to use. + poll_interval: The interval in seconds to poll for the run status. + poll_timeout: The timeout in seconds to wait for the run to complete. + + Returns: + The run response. + """ + run = self.create( + thread_id=thread_id, assistant_id=assistant_id, description=description, optimization=optimization, **kwargs + ) + + return self._poll_for_status( + thread_id=thread_id, run_id=run.id, poll_interval=poll_interval, poll_timeout=poll_timeout + ) class AsyncThreadRuns(AsyncStudioResource, BaseRuns): @@ -121,18 +159,51 @@ async def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs response_cls=RunResponse, ) - async def poll_for_status( - self, *, thread_id: str, run_id: str, polling_interval: int = 1, timeout: int = 60 + async def _poll_for_status( + self, *, thread_id: str, run_id: str, poll_interval: float, poll_timeout: float ) -> RunResponse: start_time = time.time() - run = await self.retrieve(thread_id=thread_id, run_id=run_id) - while run.status == "in_progress": + while True: run = await self.retrieve(thread_id=thread_id, run_id=run_id) - if time.time() - start_time > timeout: - break - else: - await asyncio.sleep(polling_interval) + if run.status in TERMINATED_RUN_STATUSES: + return run - return run + if (time.time() - start_time) > poll_timeout: + return run + + await asyncio.sleep(poll_interval) + + async def create_and_poll( + self, + *, + thread_id: str, + assistant_id: str, + description: str | NotGiven = NOT_GIVEN, + optimization: Optimization | NotGiven = NOT_GIVEN, + poll_interval: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout: float = DEFAULT_RUN_POLL_TIMEOUT, + **kwargs, + ) -> RunResponse: + """ + Create a run and poll for its status until it is no longer in progress or the timeout is reached. + + Args: + thread_id: The ID of the thread. + assistant_id: The ID of the assistant. + description: The description of the run. + optimization: The optimization level to use. + poll_interval: The interval in seconds to poll for the run status. + poll_timeout: The timeout in seconds to wait for the run to complete. + + Returns: + The run response. + """ + run = await self.create( + thread_id=thread_id, assistant_id=assistant_id, description=description, optimization=optimization, **kwargs + ) + + return await self._poll_for_status( + thread_id=thread_id, run_id=run.id, poll_interval=poll_interval, poll_timeout=poll_timeout + ) From ddef14cd1a2992710922de879e46033d5304d46f Mon Sep 17 00:00:00 2001 From: benshuk Date: Tue, 7 Jan 2025 17:12:18 +0200 Subject: [PATCH 4/5] docs: :memo: update assistant examples --- examples/studio/assistant/assistant.py | 16 ++-------------- examples/studio/assistant/async_assistant.py | 18 ++++-------------- 2 files changed, 6 insertions(+), 28 deletions(-) diff --git a/examples/studio/assistant/assistant.py b/examples/studio/assistant/assistant.py index bbd1eb1d..a3ba465c 100644 --- a/examples/studio/assistant/assistant.py +++ b/examples/studio/assistant/assistant.py @@ -1,9 +1,5 @@ -import time - from ai21 import AI21Client -TIMEOUT = 20 - def main(): ai21_client = AI21Client() @@ -19,25 +15,17 @@ def main(): ] ) - run = ai21_client.beta.threads.runs.create( + run = ai21_client.beta.threads.runs.create_and_poll( thread_id=thread.id, assistant_id=assistant.id, ) - start = time.time() - - while run.status == "in_progress": - run = ai21_client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - if time.time() - start > TIMEOUT: - break - time.sleep(1) - if run.status == "completed": messages = ai21_client.beta.threads.messages.list(thread_id=thread.id) print("Messages:") print("\n".join(f"{msg.role}: {msg.content['text']}" for msg in messages.results)) else: - raise Exception(f"Run failed. Status: {run.status}") + print(f"Run status: {run.status}") if __name__ == "__main__": diff --git a/examples/studio/assistant/async_assistant.py b/examples/studio/assistant/async_assistant.py index 19cf3d04..f0cf4b90 100644 --- a/examples/studio/assistant/async_assistant.py +++ b/examples/studio/assistant/async_assistant.py @@ -1,12 +1,8 @@ import asyncio -import time from ai21 import AsyncAI21Client -TIMEOUT = 20 - - async def main(): ai21_client = AsyncAI21Client() @@ -21,25 +17,19 @@ async def main(): ] ) - run = await ai21_client.beta.threads.runs.create( + run = await ai21_client.beta.threads.runs.create_and_poll( thread_id=thread.id, assistant_id=assistant.id, ) - start = time.time() - - while run.status == "in_progress": - run = await ai21_client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) - if time.time() - start > TIMEOUT: - break - time.sleep(1) - if run.status == "completed": messages = await ai21_client.beta.threads.messages.list(thread_id=thread.id) print("Messages:") print("\n".join(f"{msg.role}: {msg.content['text']}" for msg in messages.results)) - else: + elif run.status == "failed": raise Exception(f"Run failed. Status: {run.status}") + else: + print(f"Run status: {run.status}") if __name__ == "__main__": From 88be228bdcf3229a7f2dcbdf01add40f00a0036f Mon Sep 17 00:00:00 2001 From: benshuk Date: Tue, 7 Jan 2025 17:19:20 +0200 Subject: [PATCH 5/5] refactor: :recycle: rename `poll_interval` and `poll_timeout` to indicate seconds --- ai21/clients/common/beta/assistant/runs.py | 4 +- .../resources/beta/assistant/thread_runs.py | 40 +++---------------- 2 files changed, 8 insertions(+), 36 deletions(-) diff --git a/ai21/clients/common/beta/assistant/runs.py b/ai21/clients/common/beta/assistant/runs.py index a9327129..77c9af4e 100644 --- a/ai21/clients/common/beta/assistant/runs.py +++ b/ai21/clients/common/beta/assistant/runs.py @@ -80,8 +80,8 @@ def create_and_poll( assistant_id: str, description: str | NotGiven, optimization: Optimization | NotGiven, - poll_interval: float, - poll_timeout: float, + poll_interval_sec: float, + poll_timeout_sec: float, **kwargs, ) -> RunResponse: pass diff --git a/ai21/clients/studio/resources/beta/assistant/thread_runs.py b/ai21/clients/studio/resources/beta/assistant/thread_runs.py index b9fb722c..19038cca 100644 --- a/ai21/clients/studio/resources/beta/assistant/thread_runs.py +++ b/ai21/clients/studio/resources/beta/assistant/thread_runs.py @@ -85,30 +85,16 @@ def create_and_poll( assistant_id: str, description: str | NotGiven = NOT_GIVEN, optimization: Optimization | NotGiven = NOT_GIVEN, - poll_interval: float = DEFAULT_RUN_POLL_INTERVAL, - poll_timeout: float = DEFAULT_RUN_POLL_TIMEOUT, + poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT, **kwargs, ) -> RunResponse: - """ - Create a run and poll for its status until it is no longer in progress or the timeout is reached. - - Args: - thread_id: The ID of the thread. - assistant_id: The ID of the assistant. - description: The description of the run. - optimization: The optimization level to use. - poll_interval: The interval in seconds to poll for the run status. - poll_timeout: The timeout in seconds to wait for the run to complete. - - Returns: - The run response. - """ run = self.create( thread_id=thread_id, assistant_id=assistant_id, description=description, optimization=optimization, **kwargs ) return self._poll_for_status( - thread_id=thread_id, run_id=run.id, poll_interval=poll_interval, poll_timeout=poll_timeout + thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec ) @@ -182,28 +168,14 @@ async def create_and_poll( assistant_id: str, description: str | NotGiven = NOT_GIVEN, optimization: Optimization | NotGiven = NOT_GIVEN, - poll_interval: float = DEFAULT_RUN_POLL_INTERVAL, - poll_timeout: float = DEFAULT_RUN_POLL_TIMEOUT, + poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL, + poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT, **kwargs, ) -> RunResponse: - """ - Create a run and poll for its status until it is no longer in progress or the timeout is reached. - - Args: - thread_id: The ID of the thread. - assistant_id: The ID of the assistant. - description: The description of the run. - optimization: The optimization level to use. - poll_interval: The interval in seconds to poll for the run status. - poll_timeout: The timeout in seconds to wait for the run to complete. - - Returns: - The run response. - """ run = await self.create( thread_id=thread_id, assistant_id=assistant_id, description=description, optimization=optimization, **kwargs ) return await self._poll_for_status( - thread_id=thread_id, run_id=run.id, poll_interval=poll_interval, poll_timeout=poll_timeout + thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec )