diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index d93b4dc90..487b66925 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -243,15 +243,18 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List: def _store_reasoning_traces(response) -> None: """Store reasoning traces from response in context variable. - Extracts reasoning content from response.additional_kwargs["reasoning_content"] - if available. Otherwise, falls back to extracting from tags in the - response content (and removes the tags from content). + Tries multiple extraction methods in order of preference: + 1. content_blocks with type="reasoning" (LangChain v1 standard) + 2. additional_kwargs["reasoning_content"] (provider-specific) + 3. tags in content (legacy fallback) Args: response: The LLM response object """ + reasoning_content = _extract_reasoning_from_content_blocks(response) - reasoning_content = _extract_reasoning_content(response) + if not reasoning_content: + reasoning_content = _extract_reasoning_from_additional_kwargs(response) if not reasoning_content: # Some LLM providers (e.g., certain NVIDIA models) embed reasoning in tags @@ -263,14 +266,27 @@ def _store_reasoning_traces(response) -> None: reasoning_trace_var.set(reasoning_content) -def _extract_reasoning_content(response): +def _extract_reasoning_from_content_blocks(response) -> Optional[str]: + """Extract reasoning from content_blocks with type='reasoning'. + + This is the LangChain v1 standard for structured content blocks. + """ + if hasattr(response, "content_blocks"): + for block in response.content_blocks: + if block.get("type") == "reasoning": + return block.get("reasoning") + return None + + +def _extract_reasoning_from_additional_kwargs(response) -> Optional[str]: + """Extract reasoning from additional_kwargs['reasoning_content']. + + This is used by some providers for backward compatibility. + """ if hasattr(response, "additional_kwargs"): additional_kwargs = response.additional_kwargs - if ( - isinstance(additional_kwargs, dict) - and "reasoning_content" in additional_kwargs - ): - return additional_kwargs["reasoning_content"] + if isinstance(additional_kwargs, dict): + return additional_kwargs.get("reasoning_content") return None @@ -317,10 +333,26 @@ def _extract_and_remove_think_tags(response) -> Optional[str]: def _store_tool_calls(response) -> None: """Extract and store tool calls from response in context.""" - tool_calls = getattr(response, "tool_calls", None) + tool_calls = _extract_tool_calls_from_content_blocks(response) + if not tool_calls: + tool_calls = _extract_tool_calls_from_attribute(response) tool_calls_var.set(tool_calls) +def _extract_tool_calls_from_content_blocks(response) -> List | None: + if hasattr(response, "content_blocks"): + tool_calls = [] + for block in response.content_blocks: + if block.get("type") == "tool_call": + tool_calls.append(block) + return tool_calls if tool_calls else None + return None + + +def _extract_tool_calls_from_attribute(response) -> List | None: + return getattr(response, "tool_calls", None) + + def _store_response_metadata(response) -> None: """Store response metadata excluding content for metadata preservation. diff --git a/tests/test_actions_llm_utils.py b/tests/test_actions_llm_utils.py index 8f0accbd2..b15e496e8 100644 --- a/tests/test_actions_llm_utils.py +++ b/tests/test_actions_llm_utils.py @@ -13,12 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest +from langchain_core.messages import AIMessage + from nemoguardrails.actions.llm.utils import ( - _extract_and_remove_think_tags, + _extract_reasoning_from_additional_kwargs, + _extract_reasoning_from_content_blocks, + _extract_tool_calls_from_attribute, + _extract_tool_calls_from_content_blocks, _infer_provider_from_module, _store_reasoning_traces, + _store_tool_calls, ) -from nemoguardrails.context import reasoning_trace_var +from nemoguardrails.context import reasoning_trace_var, tool_calls_var + + +@pytest.fixture(autouse=True) +def reset_context_vars(): + reasoning_token = reasoning_trace_var.set(None) + tool_calls_token = tool_calls_var.set(None) + + yield + + reasoning_trace_var.reset(reasoning_token) + tool_calls_var.reset(tool_calls_token) class MockOpenAILLM: @@ -131,176 +149,394 @@ class Wrapper3(Wrapper2): class MockResponse: - def __init__(self, content="", additional_kwargs=None): - self.content = content - self.additional_kwargs = additional_kwargs or {} + def __init__(self, content_blocks=None, additional_kwargs=None, tool_calls=None): + if content_blocks is not None: + self.content_blocks = content_blocks + if additional_kwargs is not None: + self.additional_kwargs = additional_kwargs + if tool_calls is not None: + self.tool_calls = tool_calls -def test_store_reasoning_traces_from_additional_kwargs(): - reasoning_trace_var.set(None) +def test_extract_reasoning_from_content_blocks_single_reasoning(): + response = MockResponse( + content_blocks=[ + {"type": "reasoning", "reasoning": "foo"}, + ] + ) + reasoning = _extract_reasoning_from_content_blocks(response) + assert reasoning == "foo" + +def test_extract_reasoning_from_content_blocks_with_text_and_reasoning(): response = MockResponse( - content="The answer is 42", - additional_kwargs={"reasoning_content": "Let me think about this..."}, + content_blocks=[ + {"type": "text", "text": "bar"}, + {"type": "reasoning", "reasoning": "Let me think about this problem..."}, + ] ) + reasoning = _extract_reasoning_from_content_blocks(response) + assert reasoning == "Let me think about this problem..." + + +def test_extract_reasoning_from_content_blocks_returns_first_reasoning(): + response = MockResponse( + content_blocks=[ + {"type": "reasoning", "reasoning": "First thought"}, + {"type": "reasoning", "reasoning": "Second thought"}, + ] + ) + reasoning = _extract_reasoning_from_content_blocks(response) + assert reasoning == "First thought" - _store_reasoning_traces(response) - assert reasoning_trace_var.get() == "Let me think about this..." +def test_extract_reasoning_from_content_blocks_no_reasoning(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "Hello"}, + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + ] + ) + reasoning = _extract_reasoning_from_content_blocks(response) + assert reasoning is None -def test_store_reasoning_traces_from_think_tags(): - reasoning_trace_var.set(None) +def test_extract_reasoning_from_content_blocks_no_attribute(): + response = MockResponse() + reasoning = _extract_reasoning_from_content_blocks(response) + assert reasoning is None + +def test_extract_reasoning_from_additional_kwargs_with_reasoning_content(): response = MockResponse( - content="Let me think about this...The answer is 42" + additional_kwargs={"reasoning_content": "Let me think about this problem..."} ) + reasoning = _extract_reasoning_from_additional_kwargs(response) + assert reasoning == "Let me think about this problem..." + + +def test_extract_reasoning_from_additional_kwargs_no_reasoning_content(): + response = MockResponse(additional_kwargs={"other_field": "some value"}) + reasoning = _extract_reasoning_from_additional_kwargs(response) + assert reasoning is None - _store_reasoning_traces(response) - assert reasoning_trace_var.get() == "Let me think about this..." - assert response.content == "The answer is 42" +def test_extract_reasoning_from_additional_kwargs_no_attribute(): + response = MockResponse() + reasoning = _extract_reasoning_from_additional_kwargs(response) + assert reasoning is None -def test_store_reasoning_traces_multiline_think_tags(): - reasoning_trace_var.set(None) +def test_extract_reasoning_from_additional_kwargs_not_dict(): + response = MockResponse(additional_kwargs="not a dict") + reasoning = _extract_reasoning_from_additional_kwargs(response) + assert reasoning is None + +def test_extract_tool_calls_from_content_blocks_single_tool_call(): + expected_tool_call = { + "type": "tool_call", + "name": "foo", + "args": {"a": "b"}, + "id": "abc_123", + } + response = MockResponse(content_blocks=[expected_tool_call]) + tool_calls = _extract_tool_calls_from_content_blocks(response) + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0] == expected_tool_call + + +def test_extract_tool_calls_from_content_blocks_multiple_tool_calls(): response = MockResponse( - content="Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solutionThe answer is 42" + content_blocks=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + {"type": "tool_call", "name": "bar", "args": {"c": "d"}, "id": "abc_234"}, + ] ) + tool_calls = _extract_tool_calls_from_content_blocks(response) + assert tool_calls is not None + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "foo" + assert tool_calls[1]["name"] == "bar" - _store_reasoning_traces(response) - assert ( - reasoning_trace_var.get() - == "Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution" +def test_extract_tool_calls_from_content_blocks_mixed_content(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "Hello"}, + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + {"type": "reasoning", "reasoning": "Thinking..."}, + {"type": "tool_call", "name": "bar", "args": {"c": "d"}, "id": "abc_234"}, + ] ) - assert response.content == "The answer is 42" - + tool_calls = _extract_tool_calls_from_content_blocks(response) + assert tool_calls is not None + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "foo" + assert tool_calls[1]["name"] == "bar" -def test_store_reasoning_traces_prefers_additional_kwargs(): - reasoning_trace_var.set(None) +def test_extract_tool_calls_from_content_blocks_no_tool_calls(): response = MockResponse( - content="This should not be usedThe answer is 42", - additional_kwargs={"reasoning_content": "This should be used"}, + content_blocks=[ + {"type": "text", "text": "Hello"}, + {"type": "reasoning", "reasoning": "Thinking..."}, + ] ) + tool_calls = _extract_tool_calls_from_content_blocks(response) + assert tool_calls is None - _store_reasoning_traces(response) - assert reasoning_trace_var.get() == "This should be used" +def test_extract_tool_calls_from_content_blocks_no_attribute(): + response = MockResponse() + tool_calls = _extract_tool_calls_from_content_blocks(response) + assert tool_calls is None -def test_store_reasoning_traces_no_reasoning_content(): - reasoning_trace_var.set(None) +def test_extract_tool_calls_from_attribute_with_tool_calls(): + response = MockResponse( + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + {"type": "tool_call", "name": "bar", "args": {"c": "d"}, "id": "abc_234"}, + ] + ) + tool_calls = _extract_tool_calls_from_attribute(response) + assert tool_calls is not None + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "foo" + assert tool_calls[1]["name"] == "bar" - response = MockResponse(content="The answer is 42") - _store_reasoning_traces(response) +def test_extract_tool_calls_from_attribute_no_attribute(): + response = MockResponse() + tool_calls = _extract_tool_calls_from_attribute(response) + assert tool_calls is None + - assert reasoning_trace_var.get() is None +def test_store_reasoning_traces_from_content_blocks(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "The answer is 42."}, + {"type": "reasoning", "reasoning": "Let me think about this problem..."}, + ] + ) + _store_reasoning_traces(response) + reasoning = reasoning_trace_var.get() + assert reasoning == "Let me think about this problem..." -def test_store_reasoning_traces_empty_reasoning_content(): - reasoning_trace_var.set(None) +def test_store_reasoning_traces_from_additional_kwargs(): response = MockResponse( - content="The answer is 42", additional_kwargs={"reasoning_content": ""} + additional_kwargs={"reasoning_content": "Provider specific reasoning"} ) - _store_reasoning_traces(response) - assert reasoning_trace_var.get() is None + reasoning = reasoning_trace_var.get() + assert reasoning == "Provider specific reasoning" + +def test_store_reasoning_traces_prefers_content_blocks_over_additional_kwargs(): + response = MockResponse( + content_blocks=[ + {"type": "reasoning", "reasoning": "Content blocks reasoning"}, + ], + additional_kwargs={"reasoning_content": "Additional kwargs reasoning"}, + ) + _store_reasoning_traces(response) -def test_store_reasoning_traces_incomplete_think_tags(): - reasoning_trace_var.set(None) + reasoning = reasoning_trace_var.get() + assert reasoning == "Content blocks reasoning" - response = MockResponse(content="This is incomplete") +def test_store_reasoning_traces_fallback_to_additional_kwargs(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "No reasoning here"}, + ], + additional_kwargs={"reasoning_content": "Fallback reasoning"}, + ) _store_reasoning_traces(response) - assert reasoning_trace_var.get() is None + reasoning = reasoning_trace_var.get() + assert reasoning == "Fallback reasoning" -def test_store_reasoning_traces_no_content_attribute(): - reasoning_trace_var.set(None) +def test_store_reasoning_traces_no_reasoning(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "Just text"}, + ] + ) + _store_reasoning_traces(response) - class ResponseWithoutContent: - def __init__(self): - self.additional_kwargs = {} + reasoning = reasoning_trace_var.get() + assert reasoning is None - response = ResponseWithoutContent() - _store_reasoning_traces(response) +def test_store_tool_calls_from_content_blocks(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "Hello"}, + { + "type": "tool_call", + "name": "search", + "args": {"query": "weather"}, + "id": "call_1", + }, + { + "type": "tool_call", + "name": "calculator", + "args": {"expr": "2+2"}, + "id": "call_2", + }, + ] + ) + _store_tool_calls(response) + + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "search" + assert tool_calls[1]["name"] == "calculator" - assert reasoning_trace_var.get() is None +def test_store_tool_calls_from_attribute(): + response = MockResponse( + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + {"type": "tool_call", "name": "bar", "args": {"c": "d"}, "id": "abc_234"}, + ] + ) + _store_tool_calls(response) + + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "foo" + assert tool_calls[1]["name"] == "bar" -def test_store_reasoning_traces_removes_think_tags_with_whitespace(): - reasoning_trace_var.set(None) +def test_store_tool_calls_prefers_content_blocks_over_attribute(): response = MockResponse( - content=" reasoning here \n\n Final answer " + content_blocks=[ + {"type": "tool_call", "name": "from_blocks", "args": {}, "id": "1"}, + ], + tool_calls=[ + {"type": "tool_call", "name": "from_attribute", "args": {}, "id": "2"}, + ], ) + _store_tool_calls(response) - _store_reasoning_traces(response) + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0]["name"] == "from_blocks" - assert reasoning_trace_var.get() == "reasoning here" - assert response.content == "Final answer" +def test_store_tool_calls_fallback_to_attribute(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "No tool calls here"}, + ], + tool_calls=[ + {"type": "tool_call", "name": "fallback_tool", "args": {}, "id": "1"}, + ], + ) + _store_tool_calls(response) -def test_extract_and_remove_think_tags_basic(): - response = MockResponse(content="reasoninganswer") + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0]["name"] == "fallback_tool" - result = _extract_and_remove_think_tags(response) - assert result == "reasoning" - assert response.content == "answer" +def test_store_tool_calls_no_tool_calls(): + response = MockResponse( + content_blocks=[ + {"type": "text", "text": "Just text"}, + ] + ) + _store_tool_calls(response) + tool_calls = tool_calls_var.get() + assert tool_calls is None -def test_extract_and_remove_think_tags_multiline(): - response = MockResponse(content="line1\nline2\nline3final answer") - result = _extract_and_remove_think_tags(response) +def test_store_reasoning_traces_with_real_aimessage_from_content_blocks(): + message = AIMessage( + content="The answer is 42.", + additional_kwargs={"reasoning_content": "Let me think about this problem..."}, + ) - assert result == "line1\nline2\nline3" - assert response.content == "final answer" + _store_reasoning_traces(message) + reasoning = reasoning_trace_var.get() + assert reasoning == "Let me think about this problem..." -def test_extract_and_remove_think_tags_no_tags(): - response = MockResponse(content="just a normal response") - result = _extract_and_remove_think_tags(response) +def test_store_reasoning_traces_with_real_aimessage_no_reasoning(): + message = AIMessage( + content="The answer is 42.", + additional_kwargs={"other_field": "some value"}, + ) - assert result is None - assert response.content == "just a normal response" + _store_reasoning_traces(message) + reasoning = reasoning_trace_var.get() + assert reasoning is None -def test_extract_and_remove_think_tags_incomplete(): - response = MockResponse(content="incomplete") - result = _extract_and_remove_think_tags(response) +def test_store_tool_calls_with_real_aimessage_from_content_blocks(): + message = AIMessage( + "", + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"} + ], + ) - assert result is None - assert response.content == "incomplete" + _store_tool_calls(message) + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0]["type"] == "tool_call" + assert tool_calls[0]["name"] == "foo" + assert tool_calls[0]["args"] == {"a": "b"} + assert tool_calls[0]["id"] == "abc_123" -def test_extract_and_remove_think_tags_no_content_attribute(): - class ResponseWithoutContent: - pass - response = ResponseWithoutContent() +def test_store_tool_calls_with_real_aimessage_mixed_content(): + message = AIMessage( + "foo", + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"} + ], + ) - result = _extract_and_remove_think_tags(response) + _store_tool_calls(message) - assert result is None + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0]["type"] == "tool_call" + assert tool_calls[0]["name"] == "foo" -def test_extract_and_remove_think_tags_wrong_order(): - response = MockResponse(content=" text here ") +def test_store_tool_calls_with_real_aimessage_multiple_tool_calls(): + message = AIMessage( + "", + tool_calls=[ + {"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}, + {"type": "tool_call", "name": "bar", "args": {"c": "d"}, "id": "abc_234"}, + ], + ) - result = _extract_and_remove_think_tags(response) + _store_tool_calls(message) - assert result is None - assert response.content == " text here " + tool_calls = tool_calls_var.get() + assert tool_calls is not None + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "foo" + assert tool_calls[1]["name"] == "bar"