Skip to content

Commit

Permalink
openai[patch]: Assign message id in ChatOpenAI (langchain-ai#17837)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored and gkorland committed Mar 30, 2024
1 parent dc9da5a commit 42d9c4e
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
The LangChain message.
"""
role = _dict.get("role")
id_ = _dict.get("id")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
return HumanMessage(content=_dict.get("content", ""), id=id_)
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
Expand All @@ -103,11 +104,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
return SystemMessage(content=_dict.get("content", ""), id=id_)
elif role == "function":
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name"))
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name"), id=id_
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
Expand All @@ -116,9 +119,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
id=id_,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role)
return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)


def _convert_message_to_dict(message: BaseMessage) -> dict:
Expand Down Expand Up @@ -171,6 +175,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
id_ = _dict.get("id")
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
Expand All @@ -183,19 +188,23 @@ def _convert_delta_to_message_chunk(
additional_kwargs["tool_calls"] = _dict["tool_calls"]

if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
return HumanMessageChunk(content=content, id=id_)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
return AIMessageChunk(
content=content, additional_kwargs=additional_kwargs, id=id_
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
return SystemMessageChunk(content=content, id=id_)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=id_
)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
return ChatMessageChunk(content=content, role=role, id=id_)
else:
return default_class(content=content) # type: ignore
return default_class(content=content, id=id_) # type: ignore


class _FunctionCall(TypedDict):
Expand Down

0 comments on commit 42d9c4e

Please sign in to comment.