Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions ai21/clients/common/beta/assistant/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,23 @@ def cancel(
@abstractmethod
def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List[ToolOutput]) -> RunResponse:
pass

@abstractmethod
def _poll_for_status(
self, *, thread_id: str, run_id: str, poll_interval: float, poll_timeout: float
) -> RunResponse:
pass

@abstractmethod
def create_and_poll(
self,
*,
thread_id: str,
assistant_id: str,
description: str | NotGiven,
optimization: Optimization | NotGiven,
poll_interval_sec: float,
poll_timeout_sec: float,
**kwargs,
) -> RunResponse:
pass
79 changes: 78 additions & 1 deletion ai21/clients/studio/resources/beta/assistant/thread_runs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

import asyncio
import time
from typing import List

from ai21.clients.common.beta.assistant.runs import BaseRuns
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models.assistant.assistant import Optimization
from ai21.models.assistant.run import ToolOutput
from ai21.models.assistant.run import (
ToolOutput,
TERMINATED_RUN_STATUSES,
DEFAULT_RUN_POLL_INTERVAL,
DEFAULT_RUN_POLL_TIMEOUT,
)
from ai21.models.responses.run_response import RunResponse
from ai21.types import NotGiven, NOT_GIVEN

Expand Down Expand Up @@ -55,6 +62,41 @@ def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List
response_cls=RunResponse,
)

def _poll_for_status(
self, *, thread_id: str, run_id: str, poll_interval: float, poll_timeout: float
) -> RunResponse:
start_time = time.time()

while True:
run = self.retrieve(thread_id=thread_id, run_id=run_id)

if run.status in TERMINATED_RUN_STATUSES:
return run

if (time.time() - start_time) > poll_timeout:
return run

time.sleep(poll_interval)

def create_and_poll(
self,
*,
thread_id: str,
assistant_id: str,
description: str | NotGiven = NOT_GIVEN,
optimization: Optimization | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
) -> RunResponse:
run = self.create(
thread_id=thread_id, assistant_id=assistant_id, description=description, optimization=optimization, **kwargs
)

return self._poll_for_status(
thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
)


class AsyncThreadRuns(AsyncStudioResource, BaseRuns):
async def create(
Expand Down Expand Up @@ -102,3 +144,38 @@ async def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs
body=body,
response_cls=RunResponse,
)

async def _poll_for_status(
self, *, thread_id: str, run_id: str, poll_interval: float, poll_timeout: float
) -> RunResponse:
start_time = time.time()

while True:
run = await self.retrieve(thread_id=thread_id, run_id=run_id)

if run.status in TERMINATED_RUN_STATUSES:
return run

if (time.time() - start_time) > poll_timeout:
return run

await asyncio.sleep(poll_interval)

async def create_and_poll(
self,
*,
thread_id: str,
assistant_id: str,
description: str | NotGiven = NOT_GIVEN,
optimization: Optimization | NotGiven = NOT_GIVEN,
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
**kwargs,
) -> RunResponse:
run = await self.create(
thread_id=thread_id, assistant_id=assistant_id, description=description, optimization=optimization, **kwargs
)

return await self._poll_for_status(
thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
)
6 changes: 5 additions & 1 deletion ai21/models/assistant/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Any, List
from typing import Literal, Any, List, Set

from typing_extensions import TypedDict

Expand All @@ -15,6 +15,10 @@
"requires_action",
]

TERMINATED_RUN_STATUSES: Set[RunStatus] = {"completed", "failed", "expired", "cancelled", "requires_action"}
DEFAULT_RUN_POLL_INTERVAL: float = 1 # seconds
DEFAULT_RUN_POLL_TIMEOUT: float = 60 # seconds


class ToolOutput(TypedDict):
tool_call_id: str
Expand Down
16 changes: 2 additions & 14 deletions examples/studio/assistant/assistant.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import time

from ai21 import AI21Client

TIMEOUT = 20


def main():
ai21_client = AI21Client()
Expand All @@ -19,25 +15,17 @@ def main():
]
)

run = ai21_client.beta.threads.runs.create(
run = ai21_client.beta.threads.runs.create_and_poll(
thread_id=thread.id,
assistant_id=assistant.id,
)

start = time.time()

while run.status == "in_progress":
run = ai21_client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
if time.time() - start > TIMEOUT:
break
time.sleep(1)

if run.status == "completed":
messages = ai21_client.beta.threads.messages.list(thread_id=thread.id)
print("Messages:")
print("\n".join(f"{msg.role}: {msg.content['text']}" for msg in messages.results))
else:
raise Exception(f"Run failed. Status: {run.status}")
print(f"Run status: {run.status}")


if __name__ == "__main__":
Expand Down
18 changes: 4 additions & 14 deletions examples/studio/assistant/async_assistant.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import asyncio
import time

from ai21 import AsyncAI21Client


TIMEOUT = 20


async def main():
ai21_client = AsyncAI21Client()

Expand All @@ -21,25 +17,19 @@ async def main():
]
)

run = await ai21_client.beta.threads.runs.create(
run = await ai21_client.beta.threads.runs.create_and_poll(
thread_id=thread.id,
assistant_id=assistant.id,
)

start = time.time()

while run.status == "in_progress":
run = await ai21_client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
if time.time() - start > TIMEOUT:
break
time.sleep(1)

if run.status == "completed":
messages = await ai21_client.beta.threads.messages.list(thread_id=thread.id)
print("Messages:")
print("\n".join(f"{msg.role}: {msg.content['text']}" for msg in messages.results))
else:
elif run.status == "failed":
raise Exception(f"Run failed. Status: {run.status}")
else:
print(f"Run status: {run.status}")


if __name__ == "__main__":
Expand Down
Loading