Skip to content
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- [Installation](#Installation) 💿
- [Usage - Chat Completions](#Usage)
- [Conversational RAG (Beta)](#Conversational-RAG-Beta)
- [Assistants (Beta)](#Assistants-Beta)
- [Older Models Support Usage](#Older-Models-Support-Usage)
- [More Models](#More-Models)
- [Streaming](#Streaming)
Expand Down Expand Up @@ -388,6 +389,41 @@ For a more detailed example, see the chat [sync](examples/studio/conversational_

---

### Assistants (Beta)

Create assistants to help you with your tasks.

```python
from time import sleep
from ai21 import AI21Client
from ai21.models.assistant.message import Message

messages = [
Message(content={"type": "text", "text": "Youre message here"}, role="user"),
]

client = AI21Client()

assistant = client.beta.assistants.create(name="My assistant")
thread = client.beta.threads.create(messages=messages)
run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id)

while run.status == "in_progress":
run = client.beta.threads.runs.get(thread_id=thread.id, run_id=run.id)
sleep(1)

if run.status == "completed":
messages = client.beta.threads.messages.list(thread_id=thread.id)
print(messages)
else:
# handle error or required action
pass
```

For a more detailed example, see [sync](examples/studio/assistant/assistant.py) and [async](examples/studio/assistant/async_assistant.py) examples.

---

### File Upload

```python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List

from ai21.models.assistant.assistant import Optimization, Tool, ToolResources
from ai21.models.responses.assistant_response import Assistant, ListAssistant
from ai21.models.responses.assistant_response import AssistantResponse, ListAssistant
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given

Expand All @@ -19,12 +19,11 @@ def create(
*,
description: str | NotGiven = NOT_GIVEN,
optimization: Optimization | NotGiven = NOT_GIVEN,
avatar: str | NotGiven = NOT_GIVEN,
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
**kwargs,
) -> Assistant:
) -> AssistantResponse:
pass

def _create_body(
Expand All @@ -33,7 +32,6 @@ def _create_body(
name: str,
description: str | NotGiven,
optimization: str | NotGiven,
avatar: str | NotGiven,
models: List[str] | NotGiven,
tools: List[str] | NotGiven,
tool_resources: dict | NotGiven,
Expand All @@ -44,7 +42,6 @@ def _create_body(
"name": name,
"description": description,
"optimization": optimization,
"avatar": avatar,
"models": models,
"tools": tools,
"tool_resources": tool_resources,
Expand All @@ -57,7 +54,7 @@ def list(self) -> ListAssistant:
pass

@abstractmethod
def retrieve(self, assistant_id: str) -> Assistant:
def retrieve(self, assistant_id: str) -> AssistantResponse:
pass

@abstractmethod
Expand All @@ -68,10 +65,9 @@ def modify(
name: str | NotGiven = NOT_GIVEN,
description: str | NotGiven = NOT_GIVEN,
optimization: Optimization | NotGiven = NOT_GIVEN,
avatar: str | NotGiven = NOT_GIVEN,
is_archived: bool | NotGiven = NOT_GIVEN,
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
) -> Assistant:
) -> AssistantResponse:
pass
67 changes: 67 additions & 0 deletions ai21/clients/common/beta/assistant/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from ai21.models.assistant.assistant import Optimization
from ai21.models.assistant.run import ToolOutput
from ai21.models.responses.run_response import RunResponse
from ai21.types import NOT_GIVEN, NotGiven
from ai21.utils.typing import remove_not_given


class Runs(ABC):
_module_name = "runs"

@abstractmethod
def create(
self,
*,
thread_id: str,
assistant_id: str,
description: str | NotGiven = NOT_GIVEN,
optimization: Optimization | NotGiven = NOT_GIVEN,
**kwargs,
) -> RunResponse:
pass

def _create_body(
self,
*,
thread_id: str,
assistant_id: str,
description: str | NotGiven,
optimization: str | NotGiven,
**kwargs,
) -> dict:
return remove_not_given(
{
"thread_id": thread_id,
"assistant_id": assistant_id,
"description": description,
"optimization": optimization,
**kwargs,
}
)

@abstractmethod
def retrieve(
self,
*,
thread_id: str,
run_id: str,
) -> RunResponse:
pass

@abstractmethod
def cancel(
self,
*,
thread_id: str,
run_id: str,
) -> RunResponse:
pass

@abstractmethod
def submit_tool_outputs(self, *, thread_id: str, run_id: str, tool_outputs: List[ToolOutput]) -> RunResponse:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from abc import ABC, abstractmethod
from typing import List

from ai21.clients.common.assistant.messages import Messages
from ai21.clients.common.beta.assistant.messages import Messages
from ai21.models.assistant.message import Message
from ai21.models.responses.thread_response import Thread
from ai21.models.responses.thread_response import ThreadResponse


class Threads(ABC):
Expand All @@ -17,9 +17,9 @@ def create(
self,
messages: List[Message],
**kwargs,
) -> Thread:
) -> ThreadResponse:
pass

@abstractmethod
def retrieve(self, thread_id: str) -> Thread:
def retrieve(self, thread_id: str) -> ThreadResponse:
pass
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,42 @@

from typing import List

from ai21.clients.common.assistant.assistants import Assistants
from ai21.clients.common.beta.assistant.assistants import Assistants
from ai21.clients.studio.resources.studio_resource import (
AsyncStudioResource,
StudioResource,
)
from ai21.models.assistant.assistant import Tool, ToolResources
from ai21.models.responses.assistant_response import Assistant, ListAssistant
from ai21.models.responses.assistant_response import AssistantResponse, ListAssistant
from ai21.types import NotGiven, NOT_GIVEN


class StudioAssistant(StudioResource, Assistants):
class Assistant(StudioResource, Assistants):
def create(
self,
name: str,
*,
description: str | NotGiven = NOT_GIVEN,
optimization: str | NotGiven = NOT_GIVEN,
avatar: str | NotGiven = NOT_GIVEN,
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
**kwargs,
) -> Assistant:
) -> AssistantResponse:
body = self._create_body(
name=name,
description=description,
optimization=optimization,
avatar=avatar,
models=models,
tools=tools,
tool_resources=tool_resources,
**kwargs,
)

return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)
return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)

def retrieve(self, assistant_id: str) -> Assistant:
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)
def retrieve(self, assistant_id: str) -> AssistantResponse:
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)

def list(self) -> ListAssistant:
return self._get(path=f"/{self._module_name}", response_cls=ListAssistant)
Expand All @@ -51,54 +49,50 @@ def modify(
name: str | NotGiven = NOT_GIVEN,
description: str | NotGiven = NOT_GIVEN,
optimization: str | NotGiven = NOT_GIVEN,
avatar: str | NotGiven = NOT_GIVEN,
is_archived: bool | NotGiven = NOT_GIVEN,
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
) -> Assistant:
) -> AssistantResponse:
body = self._create_body(
name=name,
description=description,
optimization=optimization,
avatar=avatar,
is_archived=is_archived,
models=models,
tools=tools,
tool_resources=tool_resources,
)

return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant)
return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse)


class AsyncStudioAssistant(AsyncStudioResource, Assistants):
class AsyncAssistant(AsyncStudioResource, Assistants):
async def create(
self,
name: str,
*,
description: str | NotGiven = NOT_GIVEN,
optimization: str | NotGiven = NOT_GIVEN,
avatar: str | NotGiven = NOT_GIVEN,
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
**kwargs,
) -> Assistant:
) -> AssistantResponse:
body = self._create_body(
name=name,
description=description,
optimization=optimization,
avatar=avatar,
models=models,
tools=tools,
tool_resources=tool_resources,
**kwargs,
)

return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)
return await self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)

async def retrieve(self, assistant_id: str) -> Assistant:
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)
async def retrieve(self, assistant_id: str) -> AssistantResponse:
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)

async def list(self) -> ListAssistant:
return await self._get(path=f"/{self._module_name}", response_cls=ListAssistant)
Expand All @@ -110,21 +104,19 @@ async def modify(
name: str | NotGiven = NOT_GIVEN,
description: str | NotGiven = NOT_GIVEN,
optimization: str | NotGiven = NOT_GIVEN,
avatar: str | NotGiven = NOT_GIVEN,
is_archived: bool | NotGiven = NOT_GIVEN,
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
) -> Assistant:
) -> AssistantResponse:
body = self._create_body(
name=name,
description=description,
optimization=optimization,
avatar=avatar,
is_archived=is_archived,
models=models,
tools=tools,
tool_resources=tool_resources,
)

return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant)
return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse)
Loading
Loading