Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ai21/clients/common/beta/assistant/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC, abstractmethod

from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
from ai21.models.assistant.message import ThreadMessageRole, ThreadMessageContent
from ai21.models.responses.message_response import MessageResponse, ListMessageResponse


Expand All @@ -15,7 +15,7 @@ def create(
thread_id: str,
*,
role: ThreadMessageRole,
content: MessageContentText,
content: ThreadMessageContent,
**kwargs,
) -> MessageResponse:
pass
Expand Down
6 changes: 3 additions & 3 deletions ai21/clients/studio/resources/beta/assistant/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.assistant.message import Message
from ai21.models.assistant.message import Message, modify_message_content
from ai21.models.responses.thread_response import ThreadResponse


Expand All @@ -24,7 +24,7 @@ def create(
messages: List[Message],
**kwargs,
) -> ThreadResponse:
body = dict(messages=messages)
body = dict(messages=[modify_message_content(message) for message in messages])

return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)

Expand All @@ -44,7 +44,7 @@ async def create(
messages: List[Message],
**kwargs,
) -> ThreadResponse:
body = dict(messages=messages)
body = dict(messages=[modify_message_content(message) for message in messages])

return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)

Expand Down
17 changes: 6 additions & 11 deletions ai21/clients/studio/resources/beta/assistant/thread_messages.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations


from ai21.clients.common.beta.assistant.messages import BaseMessages
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
from ai21.models.assistant.message import ThreadMessageRole, modify_message_content, Message, ThreadMessageContent
from ai21.models.responses.message_response import MessageResponse, ListMessageResponse


Expand All @@ -12,13 +13,10 @@ def create(
thread_id: str,
*,
role: ThreadMessageRole,
content: MessageContentText,
content: ThreadMessageContent,
**kwargs,
) -> MessageResponse:
body = dict(
role=role,
content=content,
)
body = modify_message_content(Message(role=role, content=content))

return self._post(path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse)

Expand All @@ -32,13 +30,10 @@ async def create(
thread_id: str,
*,
role: ThreadMessageRole,
content: MessageContentText,
content: ThreadMessageContent,
**kwargs,
) -> MessageResponse:
body = dict(
role=role,
content=content,
)
body = modify_message_content(Message(role=role, content=content))

return await self._post(
path=f"/threads/{thread_id}/{self._module_name}", body=body, response_cls=MessageResponse
Expand Down
21 changes: 18 additions & 3 deletions ai21/models/assistant/message.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
from typing import Literal
from __future__ import annotations
from typing import Literal, Union

from typing_extensions import TypedDict


ThreadMessageRole = Literal["assistant", "user"]


class MessageContentText(TypedDict):
class ThreadMessageContentText(TypedDict):
type: Literal["text"]
text: str


ThreadMessageContent = Union[str, ThreadMessageContentText]


class Message(TypedDict):
role: ThreadMessageRole
content: MessageContentText
content: ThreadMessageContent


def modify_message_content(message: Message) -> Message:
role = message["role"]
content = message["content"]

if isinstance(content, str):
content = ThreadMessageContentText(type="text", text=content)

return Message(role=role, content=content)
4 changes: 2 additions & 2 deletions ai21/models/responses/message_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Literal, Optional, List

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.assistant.message import ThreadMessageRole, MessageContentText
from ai21.models.assistant.message import ThreadMessageRole, ThreadMessageContentText


class MessageResponse(AI21BaseModel):
Expand All @@ -11,7 +11,7 @@ class MessageResponse(AI21BaseModel):
updated_at: datetime
object: Literal["message"] = "message"
role: ThreadMessageRole
content: MessageContentText
content: ThreadMessageContentText
run_id: Optional[str] = None
assistant_id: Optional[str] = None

Expand Down
5 changes: 1 addition & 4 deletions examples/studio/assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ def main():
messages=[
{
"role": "user",
"content": {
"type": "text",
"text": "Hello",
},
"content": "Hello",
},
]
)
Expand Down
5 changes: 1 addition & 4 deletions examples/studio/assistant/async_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ async def main():
messages=[
{
"role": "user",
"content": {
"type": "text",
"text": "Hello",
},
"content": "Hello",
},
]
)
Expand Down
Loading