Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ai21/clients/common/assistant/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def list(self) -> ListAssistant:
pass

@abstractmethod
def get(self, assistant_id: str) -> Assistant:
def retrieve(self, assistant_id: str) -> Assistant:
pass

@abstractmethod
Expand Down
25 changes: 25 additions & 0 deletions ai21/clients/common/assistant/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
from ai21.models.responses.message_response import MessageResponse, ListMessageResponse


class Messages(ABC):
_module_name = "messages"

@abstractmethod
def create(
self,
thread_id: str,
*,
role: ThreadMessageRole,
content: MessageContentText,
**kwargs,
) -> MessageResponse:
pass

@abstractmethod
def list(self, thread_id: str) -> ListMessageResponse:
pass
10 changes: 8 additions & 2 deletions ai21/clients/common/assistant/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
from abc import ABC, abstractmethod
from typing import List

from ai21.clients.common.assistant.messages import Messages
from ai21.models.assistant.message import Message
from ai21.models.responses.thread_response import Thread


class Threads(ABC):
_module_name = "threads"
messages: Messages

@abstractmethod
def create(self, messages: List[Message], **kwargs) -> Thread:
def create(
self,
messages: List[Message],
**kwargs,
) -> Thread:
pass

@abstractmethod
def get(self, thread_id: str) -> Thread:
def retrieve(self, thread_id: str) -> Thread:
pass
4 changes: 2 additions & 2 deletions ai21/clients/studio/resources/assistant/studio_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create(

return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)

def get(self, assistant_id: str) -> Assistant:
def retrieve(self, assistant_id: str) -> Assistant:
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)

def list(self) -> ListAssistant:
Expand Down Expand Up @@ -97,7 +97,7 @@ async def create(

return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)

async def get(self, assistant_id: str) -> Assistant:
async def retrieve(self, assistant_id: str) -> Assistant:
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)

async def list(self) -> ListAssistant:
Expand Down
29 changes: 25 additions & 4 deletions ai21/clients/studio/resources/assistant/studio_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,47 @@
from typing import List

from ai21.clients.common.assistant.threads import Threads
from ai21.clients.studio.resources.assistant.studio_thread_message import StudioThreadMessage, AsyncStudioThreadMessage
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.assistant.message import Message
from ai21.models.responses.thread_response import Thread


class StudioThread(StudioResource, Threads):
def create(self, messages: List[Message], **kwargs) -> Thread:
def __init__(self, client: AI21HTTPClient):
super().__init__(client)

self.messages = StudioThreadMessage(client)

def create(
self,
messages: List[Message],
**kwargs,
) -> Thread:
body = dict(messages=messages)

return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread)

def get(self, thread_id: str) -> Thread:
def retrieve(self, thread_id: str) -> Thread:
return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread)


class AsyncStudioThread(AsyncStudioResource, Threads):
async def create(self, messages: List[Message], **kwargs) -> Thread:
def __init__(self, client: AsyncAI21HTTPClient):
super().__init__(client)

self.messages = AsyncStudioThreadMessage(client)

async def create(
self,
messages: List[Message],
**kwargs,
) -> Thread:
body = dict(messages=messages)

return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread)

async def get(self, thread_id: str) -> Thread:
async def retrieve(self, thread_id: str) -> Thread:
return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread)
48 changes: 48 additions & 0 deletions ai21/clients/studio/resources/assistant/studio_thread_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from ai21.clients.common.assistant.messages import Messages
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
from ai21.models.responses.message_response import MessageResponse, ListMessageResponse


class StudioThreadMessage(StudioResource, Messages):
def create(
self,
thread_id: str,
*,
role: ThreadMessageRole,
content: MessageContentText,
**kwargs,
) -> MessageResponse:
body = dict(
role=role,
content=content,
)

return self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse)

def list(self, thread_id: str) -> ListMessageResponse:
return self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse)


class AsyncStudioThreadMessage(AsyncStudioResource, Messages):
async def create(
self,
thread_id: str,
*,
role: ThreadMessageRole,
content: MessageContentText,
**kwargs,
) -> MessageResponse:
body = dict(
role=role,
content=content,
)

return await self._post(
path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse
)

async def list(self, thread_id: str) -> ListMessageResponse:
return await self._get(path=f"/threads/{thread_id}/{self._module_name}", response_cls=ListMessageResponse)
16 changes: 1 addition & 15 deletions ai21/models/assistant/message.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from datetime import datetime
from typing import Literal, Optional
from typing import Literal

from typing_extensions import TypedDict

from ai21.models.ai21_base_model import AI21BaseModel

ThreadMessageRole = Literal["assistant", "user"]


Expand All @@ -16,14 +13,3 @@ class MessageContentText(TypedDict):
class Message(TypedDict):
role: ThreadMessageRole
content: MessageContentText


class MessageResponse(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: Literal["message"] = "message"
role: ThreadMessageRole
content: MessageContentText
run_id: Optional[str] = None
assistant_id: Optional[str] = None
20 changes: 20 additions & 0 deletions ai21/models/responses/message_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from datetime import datetime
from typing import Literal, Optional, List

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText


class MessageResponse(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: Literal["message"] = "message"
role: ThreadMessageRole
content: MessageContentText
run_id: Optional[str] = None
assistant_id: Optional[str] = None


class ListMessageResponse(AI21BaseModel):
results: List[MessageResponse]
Loading