Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: add to call response types #310

Merged
merged 4 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mirascope/anthropic/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@


class AnthropicCall(
BaseCall[AnthropicCallResponse, AnthropicCallResponseChunk, AnthropicTool]
BaseCall[
AnthropicCallResponse,
AnthropicCallResponseChunk,
AnthropicTool,
MessageParam,
]
):
"""A base class for calling Anthropic's Claude models.

Expand Down Expand Up @@ -70,6 +75,7 @@ def call(
A `AnthropicCallResponse` instance.
"""
messages, kwargs, tool_types = self._setup_anthropic_kwargs(kwargs)
user_message_param = self._get_possible_user_message(messages)
client = get_wrapped_client(
Anthropic(api_key=self.api_key, base_url=self.base_url), self
)
Expand All @@ -87,6 +93,7 @@ def call(
)
return AnthropicCallResponse(
response=message,
user_message_param=user_message_param,
tool_types=tool_types,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
Expand All @@ -108,6 +115,7 @@ async def call_async(
A `AnthropicCallResponse` instance.
"""
messages, kwargs, tool_types = self._setup_anthropic_kwargs(kwargs)
user_message_param = self._get_possible_user_message(messages)
client = get_wrapped_async_client(
AsyncAnthropic(api_key=self.api_key, base_url=self.base_url), self
)
Expand All @@ -126,6 +134,7 @@ async def call_async(
)
return AnthropicCallResponse(
response=message,
user_message_param=user_message_param,
tool_types=tool_types,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
Expand All @@ -147,6 +156,7 @@ def stream(
An `AnthropicCallResponseChunk` for each chunk of the response.
"""
messages, kwargs, tool_types = self._setup_anthropic_kwargs(kwargs)
user_message_param = self._get_possible_user_message(messages)
client = get_wrapped_client(
Anthropic(api_key=self.api_key, base_url=self.base_url), self
)
Expand All @@ -162,13 +172,15 @@ def stream(
for chunk in message_stream:
yield AnthropicCallResponseChunk(
chunk=chunk, # type: ignore
user_message_param=user_message_param,
tool_types=tool_types,
response_format=self.call_params.response_format,
)
else:
for chunk in stream: # type: ignore
yield AnthropicCallResponseChunk(
chunk=chunk, # type: ignore
user_message_param=user_message_param,
tool_types=tool_types,
response_format=self.call_params.response_format,
)
Expand All @@ -187,6 +199,7 @@ async def stream_async(
An `AnthropicCallResponseChunk` for each chunk of the response.
"""
messages, kwargs, tool_types = self._setup_anthropic_kwargs(kwargs)
user_message_param = self._get_possible_user_message(messages)
client = get_wrapped_async_client(
AsyncAnthropic(api_key=self.api_key, base_url=self.base_url), self
)
Expand All @@ -203,13 +216,15 @@ async def stream_async(
async for chunk in message_stream: # type: ignore
yield AnthropicCallResponseChunk(
chunk=chunk, # type: ignore
user_message_param=user_message_param,
tool_types=tool_types,
response_format=self.call_params.response_format,
)
else:
async for chunk in stream: # type: ignore
yield AnthropicCallResponseChunk(
chunk=chunk, # type: ignore
user_message_param=user_message_param,
tool_types=tool_types,
response_format=self.call_params.response_format,
)
Expand Down
2 changes: 2 additions & 0 deletions mirascope/anthropic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class BookRecommender(AnthropicCall):
"""

response_format: Optional[Literal["json"]] = None
user_message_param: Optional[MessageParam] = None

@property
def message_param(self) -> MessageParam:
Expand Down Expand Up @@ -212,6 +213,7 @@ class Math(AnthropicCall):
"""

response_format: Optional[Literal["json"]] = None
user_message_param: Optional[MessageParam] = None

@property
def type(
Expand Down
16 changes: 14 additions & 2 deletions mirascope/base/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,22 @@

from .prompts import BasePrompt, BasePromptT
from .tools import BaseTool
from .types import BaseCallParams, BaseCallResponse, BaseCallResponseChunk, BaseConfig
from .types import (
BaseCallParams,
BaseCallResponse,
BaseCallResponseChunk,
BaseConfig,
)

# Type variables used to parameterize `BaseCall` so each provider module can
# bind its own concrete response, streaming-chunk, tool, and message-parameter
# types (e.g. Anthropic's `MessageParam`, Cohere's `ChatMessage`).
BaseCallResponseT = TypeVar("BaseCallResponseT", bound=BaseCallResponse)
BaseCallResponseChunkT = TypeVar("BaseCallResponseChunkT", bound=BaseCallResponseChunk)
BaseToolT = TypeVar("BaseToolT", bound=BaseTool)
# Bound to `Any`: message-parameter shapes differ per provider (TypedDicts,
# pydantic models, plain dicts), so no common base class is available.
MessageParamT = TypeVar("MessageParamT", bound=Any)


class BaseCall(
BasePrompt,
Generic[BaseCallResponseT, BaseCallResponseChunkT, BaseToolT],
Generic[BaseCallResponseT, BaseCallResponseChunkT, BaseToolT, MessageParamT],
ABC,
):
"""The base class abstract interface for calling LLMs."""
Expand Down Expand Up @@ -146,5 +152,11 @@ def _setup(
kwargs["tools"] = [tool_type.tool_schema() for tool_type in tool_types]
return kwargs, tool_types

def _get_possible_user_message(
self, messages: list[Any]
) -> Optional[MessageParamT]:
"""Returns the most recent message if it's a user message, otherwise `None`."""
return messages[-1] if messages[-1]["role"] == "user" else None


BaseCallT = TypeVar("BaseCallT", bound=BaseCall)
3 changes: 3 additions & 0 deletions mirascope/base/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class BaseCallResponse(BaseModel, Generic[ResponseT, BaseToolT], ABC):
"""

response: ResponseT
user_message_param: Optional[Any] = None
tool_types: Optional[list[Type[BaseToolT]]] = None
start_time: float # The start time of the completion in ms
end_time: float # The end time of the completion in ms
Expand All @@ -139,6 +140,7 @@ class BaseCallResponse(BaseModel, Generic[ResponseT, BaseToolT], ABC):
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

@property
@abstractmethod
def message_param(self) -> Any:
"""Returns the assistant's response as a message parameter."""
... # pragma: no cover
Expand Down Expand Up @@ -207,6 +209,7 @@ class BaseCallResponseChunk(BaseModel, Generic[ChunkT, BaseToolT], ABC):
"""

chunk: ChunkT
user_message_param: Optional[Any] = None
tool_types: Optional[list[Type[BaseToolT]]] = None
cost: Optional[float] = None # The cost of the completion in dollars
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
Expand Down
20 changes: 17 additions & 3 deletions mirascope/cohere/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from .utils import cohere_api_calculate_cost


class CohereCall(BaseCall[CohereCallResponse, CohereCallResponseChunk, CohereTool]):
class CohereCall(
BaseCall[CohereCallResponse, CohereCallResponseChunk, CohereTool, ChatMessage]
):
"""A base class for calling Cohere's chat models.

Example:
Expand Down Expand Up @@ -81,6 +83,7 @@ def call(
)
return CohereCallResponse(
response=response,
user_message_param=ChatMessage(message=message, role="user"), # type: ignore
tool_types=tool_types,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
Expand Down Expand Up @@ -120,6 +123,7 @@ async def call_async(
)
return CohereCallResponse(
response=response,
user_message_param=ChatMessage(message=message, role="user"), # type: ignore
tool_types=tool_types,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
Expand Down Expand Up @@ -149,8 +153,13 @@ def stream(
response_chunk_type=CohereCallResponseChunk,
tool_types=tool_types,
)
user_message_param = ChatMessage(message=message, role="user") # type: ignore
for event in chat_stream(message=message, **kwargs):
yield CohereCallResponseChunk(chunk=event, tool_types=tool_types)
yield CohereCallResponseChunk(
chunk=event,
user_message_param=user_message_param,
tool_types=tool_types,
)

@retry
async def stream_async(
Expand All @@ -176,8 +185,13 @@ async def stream_async(
response_chunk_type=CohereCallResponseChunk,
tool_types=tool_types,
)
user_message_param = ChatMessage(message=message, role="user") # type: ignore
async for event in chat_stream(message=message, **kwargs):
yield CohereCallResponseChunk(chunk=event, tool_types=tool_types)
yield CohereCallResponseChunk(
chunk=event,
user_message_param=user_message_param,
tool_types=tool_types,
)

############################## PRIVATE METHODS ###################################

Expand Down
2 changes: 2 additions & 0 deletions mirascope/cohere/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class BookRecommender(CohereCall):

# We need to skip validation since it's a pydantic_v1 model and breaks validation.
response: SkipValidation[NonStreamedChatResponse]
user_message_param: SkipValidation[Optional[ChatMessage]] = None

@property
def message_param(self) -> ChatMessage:
Expand Down Expand Up @@ -256,6 +257,7 @@ class Math(CohereCall):
"""

chunk: SkipValidation[StreamedChatResponse]
user_message_param: SkipValidation[Optional[ChatMessage]] = None

@property
def event_type(
Expand Down
36 changes: 28 additions & 8 deletions mirascope/gemini/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

from google.generativeai import GenerativeModel # type: ignore
from google.generativeai.types import ContentsType # type: ignore
from google.generativeai.types import ContentDict, ContentsType # type: ignore
from tenacity import AsyncRetrying, Retrying

from ..base import BaseCall, retry
Expand All @@ -32,7 +32,9 @@
logger = logging.getLogger("mirascope")


class GeminiCall(BaseCall[GeminiCallResponse, GeminiCallResponseChunk, GeminiTool]):
class GeminiCall(
BaseCall[GeminiCallResponse, GeminiCallResponseChunk, GeminiTool, ContentDict]
):
'''A class for prompting Google's Gemini Chat API.

This prompt supports the message types: USER, MODEL, TOOL
Expand Down Expand Up @@ -103,15 +105,18 @@ def call(
tool_types=tool_types,
model_name=model_name,
)
messages = self.messages()
user_message_param = self._get_possible_user_message(messages)
start_time = datetime.datetime.now().timestamp() * 1000
response = generate_content(
self.messages(),
messages,
stream=False,
tools=kwargs.pop("tools") if "tools" in kwargs else None,
**kwargs,
)
return GeminiCallResponse(
response=response,
user_message_param=user_message_param,
tool_types=tool_types,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
Expand Down Expand Up @@ -145,15 +150,18 @@ async def call_async(
tool_types=tool_types,
model_name=model_name,
)
messages = self.messages()
user_message_param = self._get_possible_user_message(messages)
start_time = datetime.datetime.now().timestamp() * 1000
response = await generate_content_async(
self.messages(),
messages,
stream=False,
tools=kwargs.pop("tools") if "tools" in kwargs else None,
**kwargs,
)
return GeminiCallResponse(
response=response,
user_message_param=user_message_param,
tool_types=tool_types,
start_time=start_time,
end_time=datetime.datetime.now().timestamp() * 1000,
Expand Down Expand Up @@ -185,14 +193,20 @@ def stream(
tool_types=tool_types,
model_name=model_name,
)
messages = self.messages()
user_message_param = self._get_possible_user_message(messages)
stream = generate_content(
self.messages(),
messages,
stream=True,
tools=kwargs.pop("tools") if "tools" in kwargs else None,
**kwargs,
)
for chunk in stream:
yield GeminiCallResponseChunk(chunk=chunk, tool_types=tool_types)
yield GeminiCallResponseChunk(
chunk=chunk,
user_message_param=user_message_param,
tool_types=tool_types,
)

@retry
async def stream_async(
Expand Down Expand Up @@ -220,13 +234,19 @@ async def stream_async(
tool_types=tool_types,
model_name=model_name,
)
messages = self.messages()
user_message_param = self._get_possible_user_message(messages)
stream = generate_content_async(
self.messages(),
messages,
stream=True,
tools=kwargs.pop("tools") if "tools" in kwargs else None,
**kwargs,
)
if inspect.iscoroutine(stream):
stream = await stream
async for chunk in stream:
yield GeminiCallResponseChunk(chunk=chunk, tool_types=tool_types)
yield GeminiCallResponseChunk(
chunk=chunk,
user_message_param=user_message_param,
tool_types=tool_types,
)
4 changes: 4 additions & 0 deletions mirascope/gemini/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class BookRecommender(GeminiPrompt):
```
"""

user_message_param: Optional[ContentDict] = None

@property
def message_param(self) -> ContentDict:
"""Returns the models's response as a message parameter."""
Expand Down Expand Up @@ -188,6 +190,8 @@ class Math(GeminiCall):
```
"""

user_message_param: Optional[ContentDict] = None

@property
def content(self) -> str:
"""Returns the chunk content for the 0th choice."""
Expand Down
Loading
Loading