Skip to content

Commit

Permalink
Merge pull request #866 from PrefectHQ/message-bump
Browse files Browse the repository at this point in the history
Support openai 1.14+
  • Loading branch information
jlowin committed Mar 13, 2024
2 parents 9601033 + e5dcd41 commit ed4ae7a
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 28 deletions.
7 changes: 5 additions & 2 deletions src/marvin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
)
from .ai.images import paint, image
from .ai.audio import speak_async, speak, speech, transcribe, transcribe_async
from . import beta

if settings.auto_import_beta_modules:
from . import beta

try:
from ._version import version as __version__
Expand Down Expand Up @@ -48,9 +50,10 @@
"transcribe",
"transcribe_async",
# --- beta ---
"beta",
]

if settings.auto_import_beta_modules:
__all__.append("beta")

# compatibility with Marvin v1
ai_fn = fn
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from marvin.utilities.logging import get_logger

from .threads import Thread, ThreadMessage
from .threads import Message, Thread

if TYPE_CHECKING:
from .runs import Run
Expand Down Expand Up @@ -81,7 +81,7 @@ async def say_async(
thread: Optional[Thread] = None,
return_user_message: bool = False,
**run_kwargs,
) -> list[ThreadMessage]:
) -> list[Message]:
"""
A convenience method for adding a user message to the assistant's
default thread, running the assistant, and returning the assistant's
Expand Down
18 changes: 12 additions & 6 deletions src/marvin/beta/assistants/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from datetime import datetime

import openai
from openai.types.beta.threads import ThreadMessage

# for openai < 1.14.0
try:
from openai.types.beta.threads import ThreadMessage as Message
# for openai >= 1.14.0
except ImportError:
from openai.types.beta.threads import Message
from openai.types.beta.threads.runs.run_step import RunStep
from rich import box
from rich.console import Console
Expand Down Expand Up @@ -38,7 +44,7 @@
# for obj in combined:
# if isinstance(obj, RunStep):
# pprint_run_step(obj)
# elif isinstance(obj, ThreadMessage):
# elif isinstance(obj, Message):
# pprint_message(obj)


Expand Down Expand Up @@ -135,14 +141,14 @@ def download_temp_file(file_id: str, suffix: str = None):
return temp_file_path


def pprint_message(message: ThreadMessage):
def pprint_message(message: Message):
"""
Pretty-prints a single message using the rich library, highlighting the
speaker's role, the message text, any available images, and the message
timestamp in a panel format.
Args:
message (ThreadMessage): A message object
message (Message): A message object
"""
console = Console()
role_colors = {
Expand Down Expand Up @@ -192,7 +198,7 @@ def pprint_message(message: ThreadMessage):
console.print(panel)


def pprint_messages(messages: list[ThreadMessage]):
def pprint_messages(messages: list[Message]):
"""
Iterates over a list of messages and pretty-prints each one.
Expand All @@ -201,7 +207,7 @@ def pprint_messages(messages: list[ThreadMessage]):
timestamp in a panel format.
Args:
messages (list[ThreadMessage]): A list of ThreadMessage objects to be
messages (list[Message]): A list of Message objects to be
printed.
"""
for message in messages:
Expand Down
25 changes: 15 additions & 10 deletions src/marvin/beta/assistants/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import time
from typing import TYPE_CHECKING, Callable, Optional, Union

from openai.types.beta.threads import ThreadMessage
# for openai < 1.14.0
try:
from openai.types.beta.threads import ThreadMessage as Message
# for openai >= 1.14.0
except ImportError:
from openai.types.beta.threads import Message
from pydantic import BaseModel, Field, PrivateAttr

import marvin.utilities.openai
Expand Down Expand Up @@ -34,7 +39,7 @@ class Thread(BaseModel, ExposeSyncMethodsMixin):

id: Optional[str] = None
metadata: dict = {}
messages: list[ThreadMessage] = Field([], repr=False)
messages: list[Message] = Field([], repr=False)

def __enter__(self):
return run_sync(self.__aenter__)
Expand Down Expand Up @@ -67,7 +72,7 @@ async def create_async(self, messages: list[str] = None):
@expose_sync_method("add")
async def add_async(
self, message: str, file_paths: Optional[list[str]] = None, role: str = "user"
) -> ThreadMessage:
) -> Message:
"""
Add a user message to the thread.
"""
Expand All @@ -87,7 +92,7 @@ async def add_async(
response = await client.beta.threads.messages.create(
thread_id=self.id, role=role, content=message, file_ids=file_ids
)
return ThreadMessage.model_validate(response.model_dump())
return response

@expose_sync_method("get_messages")
async def get_messages_async(
Expand All @@ -96,7 +101,7 @@ async def get_messages_async(
before_message: Optional[str] = None,
after_message: Optional[str] = None,
json_compatible: bool = False,
) -> list[Union[ThreadMessage, dict]]:
) -> list[Union[Message, dict]]:
"""
Asynchronously retrieves messages from the thread.
Expand All @@ -107,12 +112,12 @@ async def get_messages_async(
after_message (str, optional): The ID of the message to start the list from,
retrieving messages sent after this one.
json_compatible (bool, optional): If True, returns messages as dictionaries.
If False, returns messages as ThreadMessage
If False, returns messages as Message
objects. Default is False.
Returns:
list[Union[ThreadMessage, dict]]: A list of messages from the thread, either
as dictionaries or ThreadMessage objects,
list[Union[Message, dict]]: A list of messages from the thread, either
as dictionaries or Message objects,
depending on the value of json_compatible.
"""

Expand All @@ -130,7 +135,7 @@ async def get_messages_async(
order="desc",
)

T = dict if json_compatible else ThreadMessage
T = dict if json_compatible else Message

return parse_as(list[T], reversed(response.model_dump()["data"]))

Expand Down Expand Up @@ -238,7 +243,7 @@ async def run_async(self, interval_seconds: int = None):
logger.error(f"Error refreshing thread: {exc}")
await asyncio.sleep(interval_seconds)

async def get_latest_messages(self) -> list[ThreadMessage]:
async def get_latest_messages(self) -> list[Message]:
limit = 20

# Loop to get all new messages in batches of 20
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/beta/chat_ui/chat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from marvin.beta.assistants.threads import Thread, ThreadMessage
from marvin.beta.assistants.threads import Message, Thread


def find_free_port():
Expand Down Expand Up @@ -47,7 +47,7 @@ async def post_message(
message_queue.put(dict(thread_id=thread_id, message=content))

@app.get("/api/messages/")
async def get_messages(thread_id: str) -> list[ThreadMessage]:
async def get_messages(thread_id: str) -> list[Message]:
thread = Thread(id=thread_id)
return await thread.get_messages_async(limit=100)

Expand Down
6 changes: 6 additions & 0 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ class Settings(MarvinSettings):
# ai settings
ai: AISettings = Field(default_factory=AISettings)

# beta settings
auto_import_beta_modules: bool = Field(
True,
description="If True, the marvin.beta module will be automatically imported when marvin is imported.",
)

# log settings
log_level: str = Field(
default="INFO",
Expand Down
12 changes: 6 additions & 6 deletions src/marvin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ class ImageUrl(MarvinType):
detail: str = "auto"


class MessageImageURLContent(MarvinType):
class ImageFileContentBlock(MarvinType):
"""Schema for messages containing images"""

type: Literal["image_url"] = "image_url"
image_url: ImageUrl


class MessageTextContent(MarvinType):
class TextContentBlock(MarvinType):
"""Schema for messages containing text"""

type: Literal["text"] = "text"
Expand All @@ -106,7 +106,7 @@ class MessageTextContent(MarvinType):
class BaseMessage(MarvinType):
"""Base schema for messages"""

content: Union[str, list[Union[MessageImageURLContent, MessageTextContent]]]
content: Union[str, list[Union[ImageFileContentBlock, TextContentBlock]]]
role: str


Expand Down Expand Up @@ -305,15 +305,15 @@ def from_path(cls, path: Union[str, Path]) -> "Image":
def from_url(cls, url: str) -> "Image":
return cls(url=url)

def to_message_content(self) -> MessageImageURLContent:
def to_message_content(self) -> ImageFileContentBlock:
if self.url:
return MessageImageURLContent(
return ImageFileContentBlock(
image_url=dict(url=self.url, detail=self.detail)
)
elif self.data:
b64_image = base64.b64encode(self.data).decode("utf-8")
path = f"data:image/{self.format};base64,{b64_image}"
return MessageImageURLContent(image_url=dict(url=path, detail=self.detail))
return ImageFileContentBlock(image_url=dict(url=path, detail=self.detail))
else:
raise ValueError("Image source is not specified")

Expand Down

0 comments on commit ed4ae7a

Please sign in to comment.