mistral[minor]: Function calling and with_structured_output (langchai…
baskaryan authored and gkorland committed Mar 30, 2024
1 parent 4e00ffe commit cfc0e81
Showing 3 changed files with 604 additions and 181 deletions.
237 changes: 214 additions & 23 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -2,6 +2,7 @@

import importlib.util
import logging
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
@@ -10,15 +11,19 @@
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@@ -36,19 +41,22 @@
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatGenerationChunk,
-    ChatResult,
-)
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
+)
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.utils.function_calling import convert_to_openai_tool
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
-from mistralai.constants import (
-    ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
-)
+from mistralai.constants import ENDPOINT as DEFAULT_MISTRAL_ENDPOINT
from mistralai.exceptions import (
MistralAPIException,
MistralConnectionException,
@@ -57,12 +65,8 @@
from mistralai.models.chat_completion import (
ChatCompletionResponse as MistralChatCompletionResponse,
)
-from mistralai.models.chat_completion import (
-    ChatMessage as MistralChatMessage,
-)
-from mistralai.models.chat_completion import (
-    DeltaMessage as MistralDeltaMessage,
-)
+from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
+from mistralai.models.chat_completion import DeltaMessage as MistralDeltaMessage

logger = logging.getLogger(__name__)

@@ -89,14 +93,22 @@ def _convert_mistral_chat_message_to_message(
    _message: MistralChatMessage,
) -> BaseMessage:
    role = _message.role
+    content = cast(Union[str, List], _message.content)
    if role == "user":
-        return HumanMessage(content=_message.content)
+        return HumanMessage(content=content)
    elif role == "assistant":
-        return AIMessage(content=_message.content)
+        additional_kwargs: Dict = {}
+        if hasattr(_message, "tool_calls") and getattr(_message, "tool_calls"):
+            additional_kwargs["tool_calls"] = [
+                tc.model_dump() for tc in getattr(_message, "tool_calls")
+            ]
+        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
-        return SystemMessage(content=_message.content)
+        return SystemMessage(content=content)
+    elif role == "tool":
+        return ToolMessage(content=content, name=_message.name)  # type: ignore[attr-defined]
    else:
-        return ChatMessage(content=_message.content, role=role)
+        return ChatMessage(content=content, role=role)
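A minimal sketch of what this conversion yields (illustrative, not part of the diff; assumes the private helper is importable from this module):

from langchain_mistralai.chat_models import _convert_mistral_chat_message_to_message
from mistralai.models.chat_completion import ChatMessage as MistralChatMessage

raw = MistralChatMessage(role="assistant", content="Hello!")
converted = _convert_mistral_chat_message_to_message(raw)
# -> AIMessage(content='Hello!'). When the API response carries tool calls, each
# call's model_dump() dict surfaces under converted.additional_kwargs["tool_calls"].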


async def acompletion_with_retry(
@@ -119,14 +131,19 @@ async def _completion_with_retry(**kwargs: Any) -> Any:


def _convert_delta_to_message_chunk(
-    _obj: MistralDeltaMessage, default_class: Type[BaseMessageChunk]
+    _delta: MistralDeltaMessage, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
-    role = getattr(_obj, "role")
-    content = getattr(_obj, "content", "")
+    role = getattr(_delta, "role")
+    content = getattr(_delta, "content", "")
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content)
+        additional_kwargs: Dict = {}
+        if hasattr(_delta, "tool_calls") and getattr(_delta, "tool_calls"):
+            additional_kwargs["tool_calls"] = [
+                tc.model_dump() for tc in getattr(_delta, "tool_calls")
+            ]
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
@@ -143,9 +160,26 @@ def _convert_message_to_mistral_chat_message(
    elif isinstance(message, HumanMessage):
        mistral_message = MistralChatMessage(role="user", content=message.content)
    elif isinstance(message, AIMessage):
-        mistral_message = MistralChatMessage(role="assistant", content=message.content)
+        if "tool_calls" in message.additional_kwargs:
+            from mistralai.models.chat_completion import (  # type: ignore[attr-defined]
+                ToolCall as MistralToolCall,
+            )
+
+            tool_calls = [
+                MistralToolCall.model_validate(tc)
+                for tc in message.additional_kwargs["tool_calls"]
+            ]
+        else:
+            tool_calls = None
+        mistral_message = MistralChatMessage(
+            role="assistant", content=message.content, tool_calls=tool_calls
+        )
    elif isinstance(message, SystemMessage):
        mistral_message = MistralChatMessage(role="system", content=message.content)
+    elif isinstance(message, ToolMessage):
+        mistral_message = MistralChatMessage(
+            role="tool", content=message.content, name=message.name
+        )
    else:
        raise ValueError(f"Got unknown type {message}")
    return mistral_message
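And a hedged sketch of the reverse direction handled here (the tool-call id and function payload are hypothetical values in the OpenAI format shown in the docstrings below):

from langchain_core.messages import AIMessage
from langchain_mistralai.chat_models import _convert_message_to_mistral_chat_message

ai = AIMessage(
    content="",
    additional_kwargs={
        "tool_calls": [
            {
                "id": "call_abc123",  # hypothetical id
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ]
    },
)
mistral_msg = _convert_message_to_mistral_chat_message(ai)
# mistral_msg.tool_calls holds MistralToolCall objects validated from those dicts.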
@@ -368,6 +402,163 @@ async def _agenerate(
)
return self._create_chat_result(response)

def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind tool-like objects to this chat model.

    Assumes model is compatible with OpenAI tool-calling API.

    Args:
        tools: A list of tool definitions to bind to this chat model.
            Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
            models, callables, and BaseTools will be automatically converted to
            their schema dictionary representation.
        tool_choice: Which tool to require the model to call. Must be the name
            of the single provided function, "auto" to automatically determine
            which function to call (if any), or a dict of the form
            {"type": "function", "function": {"name": <<tool_name>>}}.
            Passed through via **kwargs.
        **kwargs: Any additional parameters to pass to the
            :class:`~langchain.runnable.Runnable` constructor.
    """
    formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
    return super().bind(tools=formatted_tools, **kwargs)
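A usage sketch for the method above (illustrative, not part of this diff; GetWeather is a hypothetical tool):

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_mistralai import ChatMistralAI


class GetWeather(BaseModel):
    """Get the current weather in a given city."""

    city: str = Field(..., description="City name, e.g. Paris")


llm = ChatMistralAI(model="mistral-large-latest")
llm_with_tools = llm.bind_tools([GetWeather])
msg = llm_with_tools.invoke("What's the weather in Paris?")
# If the model decides to call the tool, the call arrives in
# msg.additional_kwargs["tool_calls"] in OpenAI format.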

@beta()
def with_structured_output(
    self,
    schema: Union[Dict, Type[BaseModel]],
    *,
    include_raw: bool = False,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
    """Model wrapper that returns outputs formatted to match the given schema.

    Args:
        schema: The output schema as a dict or a Pydantic class. If a Pydantic
            class then the model output will be an object of that class. If a
            dict then the model output will be a dict. With a Pydantic class the
            returned attributes will be validated, whereas with a dict they will
            not be. If `method` is "function_calling" and `schema` is a dict,
            then the dict must match the OpenAI function-calling spec.
        include_raw: If False then only the parsed structured output is
            returned. If an error occurs during model output parsing it will be
            raised. If True then both the raw model response (a BaseMessage) and
            the parsed model response will be returned. If an error occurs
            during output parsing it will be caught and returned as well. The
            final output is always a dict with keys "raw", "parsed", and
            "parsing_error".

    Returns:
        A Runnable that takes any ChatModel input and returns as output:

        If include_raw is True then a dict with keys:
            raw: BaseMessage
            parsed: Optional[_DictOrPydantic]
            parsing_error: Optional[BaseException]

        If include_raw is False then just _DictOrPydantic is returned, where
        _DictOrPydantic depends on the schema:

        If schema is a Pydantic class then _DictOrPydantic is the Pydantic
        class. If schema is a dict then _DictOrPydantic is a dict.

    Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
        .. code-block:: python

            from langchain_mistralai import ChatMistralAI
            from langchain_core.pydantic_v1 import BaseModel


            class AnswerWithJustification(BaseModel):
                '''An answer to the user question along with justification for the answer.'''

                answer: str
                justification: str


            llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
            structured_llm = llm.with_structured_output(AnswerWithJustification)

            structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
            # -> AnswerWithJustification(
            #     answer='They weigh the same',
            #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
            # )

    Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
        .. code-block:: python

            from langchain_mistralai import ChatMistralAI
            from langchain_core.pydantic_v1 import BaseModel


            class AnswerWithJustification(BaseModel):
                '''An answer to the user question along with justification for the answer.'''

                answer: str
                justification: str


            llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
            structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

            structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
            # -> {
            #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
            #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
            #     'parsing_error': None
            # }

    Example: Function-calling, dict schema (method="function_calling", include_raw=False):
        .. code-block:: python

            from langchain_mistralai import ChatMistralAI
            from langchain_core.pydantic_v1 import BaseModel
            from langchain_core.utils.function_calling import convert_to_openai_tool


            class AnswerWithJustification(BaseModel):
                '''An answer to the user question along with justification for the answer.'''

                answer: str
                justification: str


            dict_schema = convert_to_openai_tool(AnswerWithJustification)
            llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
            structured_llm = llm.with_structured_output(dict_schema)

            structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
            # -> {
            #     'answer': 'They weigh the same',
            #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
            # }
    """  # noqa: E501
    if kwargs:
        raise ValueError(f"Received unsupported arguments {kwargs}")
    is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
    llm = self.bind_tools([schema], tool_choice="any")
    if is_pydantic_schema:
        output_parser: OutputParserLike = PydanticToolsParser(
            tools=[schema], first_tool_only=True
        )
    else:
        key_name = convert_to_openai_tool(schema)["function"]["name"]
        output_parser = JsonOutputKeyToolsParser(
            key_name=key_name, first_tool_only=True
        )

    if include_raw:
        parser_assign = RunnablePassthrough.assign(
            parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
        )
        parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
        parser_with_fallback = parser_assign.with_fallbacks(
            [parser_none], exception_key="parsing_error"
        )
        return RunnableMap(raw=llm) | parser_with_fallback
    else:
        return llm | output_parser
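To make the include_raw branch concrete, here is a minimal, hedged sketch of the same fallback pattern with stand-in runnables (the stand-ins are hypothetical; the real chain binds the model and a tools parser as above):

from operator import itemgetter

from langchain_core.runnables import RunnableLambda, RunnableMap, RunnablePassthrough

llm = RunnableLambda(lambda _: "raw-model-output")  # stand-in for the tool-bound model
output_parser = RunnableLambda(str.upper)  # stand-in for the tools output parser

parser_assign = RunnablePassthrough.assign(
    parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
chain = RunnableMap(raw=llm) | parser_assign.with_fallbacks(
    [parser_none], exception_key="parsing_error"
)

print(chain.invoke("hi"))
# -> {'raw': 'raw-model-output', 'parsed': 'RAW-MODEL-OUTPUT', 'parsing_error': None}
# If parsing raised, the fallback would return parsed=None and put the
# exception under 'parsing_error' instead.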

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
