Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support openai 1.14+ #866

Merged
merged 4 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/marvin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
)
from .ai.images import paint, image
from .ai.audio import speak_async, speak, speech, transcribe, transcribe_async
from . import beta

if settings.auto_import_beta_modules:
from . import beta

try:
from ._version import version as __version__
Expand Down Expand Up @@ -48,9 +50,10 @@
"transcribe",
"transcribe_async",
# --- beta ---
"beta",
]

if settings.auto_import_beta_modules:
__all__.append("beta")

# compatibility with Marvin v1
ai_fn = fn
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from marvin.utilities.logging import get_logger

from .threads import Thread, ThreadMessage
from .threads import Message, Thread

if TYPE_CHECKING:
from .runs import Run
Expand Down Expand Up @@ -81,7 +81,7 @@ async def say_async(
thread: Optional[Thread] = None,
return_user_message: bool = False,
**run_kwargs,
) -> list[ThreadMessage]:
) -> list[Message]:
"""
A convenience method for adding a user message to the assistant's
default thread, running the assistant, and returning the assistant's
Expand Down
18 changes: 12 additions & 6 deletions src/marvin/beta/assistants/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from datetime import datetime

import openai
from openai.types.beta.threads import ThreadMessage

# for openai < 1.14.0
try:
from openai.types.beta.threads import ThreadMessage as Message
# for openai >= 1.14.0
except ImportError:
from openai.types.beta.threads import Message
from openai.types.beta.threads.runs.run_step import RunStep
from rich import box
from rich.console import Console
Expand Down Expand Up @@ -38,7 +44,7 @@
# for obj in combined:
# if isinstance(obj, RunStep):
# pprint_run_step(obj)
# elif isinstance(obj, ThreadMessage):
# elif isinstance(obj, Message):
# pprint_message(obj)


Expand Down Expand Up @@ -135,14 +141,14 @@ def download_temp_file(file_id: str, suffix: str = None):
return temp_file_path


def pprint_message(message: ThreadMessage):
def pprint_message(message: Message):
"""
Pretty-prints a single message using the rich library, highlighting the
speaker's role, the message text, any available images, and the message
timestamp in a panel format.

Args:
message (ThreadMessage): A message object
message (Message): A message object
"""
console = Console()
role_colors = {
Expand Down Expand Up @@ -192,7 +198,7 @@ def pprint_message(message: ThreadMessage):
console.print(panel)


def pprint_messages(messages: list[ThreadMessage]):
def pprint_messages(messages: list[Message]):
"""
Iterates over a list of messages and pretty-prints each one.

Expand All @@ -201,7 +207,7 @@ def pprint_messages(messages: list[ThreadMessage]):
timestamp in a panel format.

Args:
messages (list[ThreadMessage]): A list of ThreadMessage objects to be
messages (list[Message]): A list of Message objects to be
printed.
"""
for message in messages:
Expand Down
25 changes: 15 additions & 10 deletions src/marvin/beta/assistants/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import time
from typing import TYPE_CHECKING, Callable, Optional, Union

from openai.types.beta.threads import ThreadMessage
# for openai < 1.14.0
try:
from openai.types.beta.threads import ThreadMessage as Message
# for openai >= 1.14.0
except ImportError:
from openai.types.beta.threads import Message
from pydantic import BaseModel, Field, PrivateAttr

import marvin.utilities.openai
Expand Down Expand Up @@ -34,7 +39,7 @@ class Thread(BaseModel, ExposeSyncMethodsMixin):

id: Optional[str] = None
metadata: dict = {}
messages: list[ThreadMessage] = Field([], repr=False)
messages: list[Message] = Field([], repr=False)

def __enter__(self):
return run_sync(self.__aenter__)
Expand Down Expand Up @@ -67,7 +72,7 @@ async def create_async(self, messages: list[str] = None):
@expose_sync_method("add")
async def add_async(
self, message: str, file_paths: Optional[list[str]] = None, role: str = "user"
) -> ThreadMessage:
) -> Message:
"""
Add a user message to the thread.
"""
Expand All @@ -87,7 +92,7 @@ async def add_async(
response = await client.beta.threads.messages.create(
thread_id=self.id, role=role, content=message, file_ids=file_ids
)
return ThreadMessage.model_validate(response.model_dump())
return response

@expose_sync_method("get_messages")
async def get_messages_async(
Expand All @@ -96,7 +101,7 @@ async def get_messages_async(
before_message: Optional[str] = None,
after_message: Optional[str] = None,
json_compatible: bool = False,
) -> list[Union[ThreadMessage, dict]]:
) -> list[Union[Message, dict]]:
"""
Asynchronously retrieves messages from the thread.

Expand All @@ -107,12 +112,12 @@ async def get_messages_async(
after_message (str, optional): The ID of the message to start the list from,
retrieving messages sent after this one.
json_compatible (bool, optional): If True, returns messages as dictionaries.
If False, returns messages as ThreadMessage
If False, returns messages as Message
objects. Default is False.

Returns:
list[Union[ThreadMessage, dict]]: A list of messages from the thread, either
as dictionaries or ThreadMessage objects,
list[Union[Message, dict]]: A list of messages from the thread, either
as dictionaries or Message objects,
depending on the value of json_compatible.
"""

Expand All @@ -130,7 +135,7 @@ async def get_messages_async(
order="desc",
)

T = dict if json_compatible else ThreadMessage
T = dict if json_compatible else Message

return parse_as(list[T], reversed(response.model_dump()["data"]))

Expand Down Expand Up @@ -238,7 +243,7 @@ async def run_async(self, interval_seconds: int = None):
logger.error(f"Error refreshing thread: {exc}")
await asyncio.sleep(interval_seconds)

async def get_latest_messages(self) -> list[ThreadMessage]:
async def get_latest_messages(self) -> list[Message]:
limit = 20

# Loop to get all new messages in batches of 20
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/beta/chat_ui/chat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from marvin.beta.assistants.threads import Thread, ThreadMessage
from marvin.beta.assistants.threads import Message, Thread


def find_free_port():
Expand Down Expand Up @@ -47,7 +47,7 @@ async def post_message(
message_queue.put(dict(thread_id=thread_id, message=content))

@app.get("/api/messages/")
async def get_messages(thread_id: str) -> list[ThreadMessage]:
async def get_messages(thread_id: str) -> list[Message]:
thread = Thread(id=thread_id)
return await thread.get_messages_async(limit=100)

Expand Down
6 changes: 6 additions & 0 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ class Settings(MarvinSettings):
# ai settings
ai: AISettings = Field(default_factory=AISettings)

# beta settings
auto_import_beta_modules: bool = Field(
True,
description="If True, the marvin.beta module will be automatically imported when marvin is imported.",
)

# log settings
log_level: str = Field(
default="INFO",
Expand Down
12 changes: 6 additions & 6 deletions src/marvin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ class ImageUrl(MarvinType):
detail: str = "auto"


class MessageImageURLContent(MarvinType):
class ImageFileContentBlock(MarvinType):
"""Schema for messages containing images"""

type: Literal["image_url"] = "image_url"
image_url: ImageUrl


class MessageTextContent(MarvinType):
class TextContentBlock(MarvinType):
"""Schema for messages containing text"""

type: Literal["text"] = "text"
Expand All @@ -106,7 +106,7 @@ class MessageTextContent(MarvinType):
class BaseMessage(MarvinType):
"""Base schema for messages"""

content: Union[str, list[Union[MessageImageURLContent, MessageTextContent]]]
content: Union[str, list[Union[ImageFileContentBlock, TextContentBlock]]]
role: str


Expand Down Expand Up @@ -305,15 +305,15 @@ def from_path(cls, path: Union[str, Path]) -> "Image":
def from_url(cls, url: str) -> "Image":
return cls(url=url)

def to_message_content(self) -> MessageImageURLContent:
def to_message_content(self) -> ImageFileContentBlock:
if self.url:
return MessageImageURLContent(
return ImageFileContentBlock(
image_url=dict(url=self.url, detail=self.detail)
)
elif self.data:
b64_image = base64.b64encode(self.data).decode("utf-8")
path = f"data:image/{self.format};base64,{b64_image}"
return MessageImageURLContent(image_url=dict(url=path, detail=self.detail))
return ImageFileContentBlock(image_url=dict(url=path, detail=self.detail))
else:
raise ValueError("Image source is not specified")

Expand Down