28 changes: 28 additions & 0 deletions nemoguardrails/actions/llm/utils.py
@@ -26,6 +26,7 @@
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
from nemoguardrails.context import (
llm_call_info_var,
+    llm_response_metadata_var,
reasoning_trace_var,
tool_calls_var,
)
@@ -85,6 +86,7 @@ async def llm_call(
response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)

_store_tool_calls(response)
+    _store_response_metadata(response)
return _extract_content(response)


@@ -173,6 +175,20 @@ def _store_tool_calls(response) -> None:
tool_calls_var.set(tool_calls)


+def _store_response_metadata(response) -> None:
+    """Store response metadata excluding content for metadata preservation."""
+    if hasattr(response, "model_fields"):
+        metadata = {}
+        for field_name in response.model_fields:
+            if (
+                field_name != "content"
+            ):  # Exclude content since it may be modified by rails
+                metadata[field_name] = getattr(response, field_name)
+        llm_response_metadata_var.set(metadata)
+    else:
+        llm_response_metadata_var.set(None)
+
+
def _extract_content(response) -> str:
"""Extract text content from response."""
if hasattr(response, "content"):
@@ -655,3 +671,15 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]:
tool_calls_var.set(None)
return tool_calls
    return None
+
+
+def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
+    """Get the current response metadata and clear it from the context.
+
+    Returns:
+        Optional[dict]: The response metadata if it exists, None otherwise.
+    """
+    if metadata := llm_response_metadata_var.get():
+        llm_response_metadata_var.set(None)
+        return metadata
+    return None
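
For context, `_store_response_metadata` relies on the response being a Pydantic v2 model (as LangChain's `AIMessage` is), so `model_fields` enumerates every declared field. A minimal sketch of the capture step; the message values below are illustrative, not part of this PR:

```python
# Sketch only: assumes a LangChain AIMessage (a Pydantic v2 model whose
# `model_fields` lists all declared fields).
from langchain_core.messages import AIMessage

response = AIMessage(
    content="Paris is the capital of France.",
    response_metadata={"model_name": "gpt-4o", "finish_reason": "stop"},
    usage_metadata={"input_tokens": 12, "output_tokens": 8, "total_tokens": 20},
)

# Mirrors the loop in _store_response_metadata: copy every field except
# `content`, since output rails may still rewrite the content.
metadata = {
    name: getattr(response, name)
    for name in response.model_fields
    if name != "content"
}

assert "usage_metadata" in metadata and "content" not in metadata
```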
5 changes: 5 additions & 0 deletions nemoguardrails/context.py
@@ -42,3 +42,8 @@
tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
"tool_calls", default=None
)
+
+# The response metadata from the current LLM response.
+llm_response_metadata_var: contextvars.ContextVar[
+    Optional[dict]
+] = contextvars.ContextVar("llm_response_metadata", default=None)
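
The new variable mirrors `tool_calls_var`: a `contextvars.ContextVar`, so concurrent requests each see only their own LLM response. A minimal sketch of the read-once contract implemented by `get_and_clear_response_metadata_contextvar` above:

```python
import contextvars
from typing import Optional

# Same shape as the declaration above.
llm_response_metadata_var: contextvars.ContextVar[Optional[dict]] = (
    contextvars.ContextVar("llm_response_metadata", default=None)
)

llm_response_metadata_var.set({"response_metadata": {"finish_reason": "stop"}})

# Read once, then clear, so stale metadata never leaks into the next call.
metadata = llm_response_metadata_var.get()
llm_response_metadata_var.set(None)

assert metadata is not None and llm_response_metadata_var.get() is None
```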
84 changes: 67 additions & 17 deletions nemoguardrails/integrations/langchain/runnable_rails.py
@@ -393,11 +393,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> Any:
return passthrough_output

    def _format_chat_prompt_output(
-        self, result: Any, tool_calls: Optional[list] = None
+        self,
+        result: Any,
+        tool_calls: Optional[list] = None,
+        metadata: Optional[dict] = None,
    ) -> AIMessage:
        """Format output for ChatPromptValue input."""
        content = self._extract_content_from_result(result)
-        if tool_calls:
+
+        if metadata and isinstance(metadata, dict):
+            metadata_copy = metadata.copy()
+            metadata_copy.pop("content", None)
+            if tool_calls:
+                metadata_copy["tool_calls"] = tool_calls
+            return AIMessage(content=content, **metadata_copy)
+        elif tool_calls:
            return AIMessage(content=content, tool_calls=tool_calls)
        return AIMessage(content=content)

@@ -406,11 +416,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
return self._extract_content_from_result(result)

    def _format_message_output(
-        self, result: Any, tool_calls: Optional[list] = None
+        self,
+        result: Any,
+        tool_calls: Optional[list] = None,
+        metadata: Optional[dict] = None,
    ) -> AIMessage:
        """Format output for BaseMessage input types."""
        content = self._extract_content_from_result(result)
-        if tool_calls:
+
+        if metadata and isinstance(metadata, dict):
+            metadata_copy = metadata.copy()
+            metadata_copy.pop("content", None)
+            if tool_calls:
+                metadata_copy["tool_calls"] = tool_calls
+            return AIMessage(content=content, **metadata_copy)
+        elif tool_calls:
            return AIMessage(content=content, tool_calls=tool_calls)
        return AIMessage(content=content)

@@ -434,25 +454,50 @@ def _format_dict_output_for_dict_message_list(
}

    def _format_dict_output_for_base_message_list(
-        self, result: Any, output_key: str, tool_calls: Optional[list] = None
+        self,
+        result: Any,
+        output_key: str,
+        tool_calls: Optional[list] = None,
+        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Format dict output when user input was a list of BaseMessage objects."""
        content = self._extract_content_from_result(result)
-        if tool_calls:
+
+        if metadata and isinstance(metadata, dict):
+            metadata_copy = metadata.copy()
+            metadata_copy.pop("content", None)
+            if tool_calls:
+                metadata_copy["tool_calls"] = tool_calls
+            return {output_key: AIMessage(content=content, **metadata_copy)}
+        elif tool_calls:
            return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
        return {output_key: AIMessage(content=content)}

    def _format_dict_output_for_base_message(
-        self, result: Any, output_key: str, tool_calls: Optional[list] = None
+        self,
+        result: Any,
+        output_key: str,
+        tool_calls: Optional[list] = None,
+        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Format dict output when user input was a BaseMessage."""
        content = self._extract_content_from_result(result)
-        if tool_calls:
+
+        if metadata:
+            metadata_copy = metadata.copy()
+            if tool_calls:
+                metadata_copy["tool_calls"] = tool_calls
+            return {output_key: AIMessage(content=content, **metadata_copy)}
+        elif tool_calls:
            return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
        return {output_key: AIMessage(content=content)}

    def _format_dict_output(
-        self, input_dict: dict, result: Any, tool_calls: Optional[list] = None
+        self,
+        input_dict: dict,
+        result: Any,
+        tool_calls: Optional[list] = None,
+        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
"""Format output for dictionary input."""
output_key = self.passthrough_bot_output_key
@@ -471,13 +516,13 @@ def _format_dict_output(
)
elif all(isinstance(msg, BaseMessage) for msg in user_input):
            return self._format_dict_output_for_base_message_list(
-                result, output_key, tool_calls
+                result, output_key, tool_calls, metadata
            )
else:
return {output_key: result}
elif isinstance(user_input, BaseMessage):
            return self._format_dict_output_for_base_message(
-                result, output_key, tool_calls
+                result, output_key, tool_calls, metadata
            )

# Generic fallback for dictionaries
@@ -490,6 +535,7 @@ def _format_output(
        result: Any,
        context: Dict[str, Any],
        tool_calls: Optional[list] = None,
+        metadata: Optional[dict] = None,
    ) -> Any:
"""Format the output based on the input type and rails result.

@@ -512,17 +558,17 @@
return self._format_passthrough_output(result, context)

        if isinstance(input, ChatPromptValue):
-            return self._format_chat_prompt_output(result, tool_calls)
+            return self._format_chat_prompt_output(result, tool_calls, metadata)
        elif isinstance(input, StringPromptValue):
            return self._format_string_prompt_output(result)
        elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
-            return self._format_message_output(result, tool_calls)
+            return self._format_message_output(result, tool_calls, metadata)
        elif isinstance(input, list) and all(
            isinstance(msg, BaseMessage) for msg in input
        ):
-            return self._format_message_output(result, tool_calls)
+            return self._format_message_output(result, tool_calls, metadata)
        elif isinstance(input, dict):
-            return self._format_dict_output(input, result, tool_calls)
+            return self._format_dict_output(input, result, tool_calls, metadata)
elif isinstance(input, str):
return self._format_string_prompt_output(result)
else:
@@ -669,7 +715,9 @@ def _full_rails_invoke(
result = result[0]

        # Format and return the output based on input type
-        return self._format_output(input, result, context, res.tool_calls)
+        return self._format_output(
+            input, result, context, res.tool_calls, res.llm_metadata
+        )

async def ainvoke(
self,
@@ -731,7 +779,9 @@ async def _full_rails_ainvoke(
result = res.response

# Format and return the output based on input type
-        return self._format_output(input, result, context, res.tool_calls)
+        return self._format_output(
+            input, result, context, res.tool_calls, res.llm_metadata
+        )

def stream(
self,
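
End to end, the change means a `RunnableRails` wrapper can return an `AIMessage` that keeps the underlying model's metadata rather than a bare `content` string. A hedged usage sketch: the model name and config path are placeholders, not part of this PR:

```python
# Assumes "./config" holds a valid guardrails configuration and an OpenAI
# API key is available; both are illustrative assumptions.
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

config = RailsConfig.from_path("./config")
guardrails = RunnableRails(config, llm=ChatOpenAI(model="gpt-4o"))

result = guardrails.invoke([HumanMessage(content="What can you do?")])

# Previously the AIMessage carried only `content`; with this PR, fields such
# as response_metadata and usage_metadata are re-attached from the stored dict.
print(result.response_metadata)
print(result.usage_metadata)
```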
5 changes: 5 additions & 0 deletions nemoguardrails/rails/llm/llmrails.py
@@ -33,6 +33,7 @@
from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import (
get_and_clear_reasoning_trace_contextvar,
+    get_and_clear_response_metadata_contextvar,
get_and_clear_tool_calls_contextvar,
get_colang_history,
)
@@ -1086,6 +1087,7 @@ async def generate_async(
options.log.internal_events = True

tool_calls = get_and_clear_tool_calls_contextvar()
+    llm_metadata = get_and_clear_response_metadata_contextvar()

# If we have generation options, we prepare a GenerationResponse instance.
if options:
@@ -1106,6 +1108,9 @@
if tool_calls:
res.tool_calls = tool_calls

+        if llm_metadata:
+            res.llm_metadata = llm_metadata
+
if self.config.colang_version == "1.0":
# If output variables are specified, we extract their values
if options.output_vars:
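
On the `LLMRails` side, the metadata surfaces as `GenerationResponse.llm_metadata` whenever generation options are passed. A minimal sketch, with a placeholder config path:

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def main():
    rails = LLMRails(RailsConfig.from_path("./config"))  # placeholder path
    # Passing options makes generate_async return a GenerationResponse.
    res = await rails.generate_async(
        messages=[{"role": "user", "content": "Hi!"}],
        options={"log": {"llm_calls": True}},
    )
    # Populated from the context var just before the response is returned.
    print(res.llm_metadata)


asyncio.run(main())
```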
4 changes: 4 additions & 0 deletions nemoguardrails/rails/llm/options.py
@@ -412,6 +412,10 @@ class GenerationResponse(BaseModel):
default=None,
description="Tool calls extracted from the LLM response, if any.",
)
+    llm_metadata: Optional[dict] = Field(
+        default=None,
+        description="Metadata from the LLM response (additional_kwargs, response_metadata, usage_metadata, etc.)",
+    )


if __name__ == "__main__":
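
Since `llm_metadata` is a plain optional dict on the Pydantic model, it validates like any other field; a quick sketch (field values are illustrative):

```python
from nemoguardrails.rails.llm.options import GenerationResponse

res = GenerationResponse(
    response="Hello!",
    llm_metadata={"usage_metadata": {"total_tokens": 21}},
)
assert res.llm_metadata["usage_metadata"]["total_tokens"] == 21
```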