nemoguardrails/logging/callbacks.py (2 additions, 0 deletions)

@@ -109,6 +109,8 @@ async def on_chat_model_start(
if msg.type == "human"
else "Bot"
if msg.type == "ai"
else "Tool"
if msg.type == "tool"
else "System"
)
+ "[/]"
Expand Down
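
The two added lines extend the chained conditional expression in on_chat_model_start that maps a LangChain message type to the label used in the logged prompt, so "tool" messages are rendered as "Tool" instead of falling through to "System". A minimal standalone sketch of that mapping; the surrounding string concatenation in the handler is assumed from the visible context lines and the "User" label is inferred from the new test's assertions:

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)


def label_for(msg: BaseMessage) -> str:
    # Mirrors the chained conditional from the hunk above:
    # human -> User, ai -> Bot, tool -> Tool, anything else -> System.
    return (
        "User"
        if msg.type == "human"
        else "Bot"
        if msg.type == "ai"
        else "Tool"
        if msg.type == "tool"
        else "System"
    )


assert label_for(HumanMessage(content="Hello")) == "User"
assert label_for(AIMessage(content="Hi there")) == "Bot"
assert label_for(ToolMessage(content="Tool result", tool_call_id="t1")) == "Tool"
assert label_for(SystemMessage(content="System message")) == "System"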

tests/test_callbacks.py (45 additions, 1 deletion)

@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from unittest.mock import patch
from uuid import uuid4

import pytest
from langchain.schema import Generation, LLMResult
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration

from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var

@@ -168,3 +169,46 @@ async def test_multiple_generations_token_accumulation():
assert llm_stats.get_stat("total_tokens") == 19
assert llm_stats.get_stat("total_prompt_tokens") == 12
assert llm_stats.get_stat("total_completion_tokens") == 7


@pytest.mark.asyncio
async def test_tool_message_labeling_in_logging():
    """Test that tool messages are labeled as 'Tool' in logging output."""
    llm_call_info = LLMCallInfo()
    llm_call_info_var.set(llm_call_info)

    llm_stats = LLMStats()
    llm_stats_var.set(llm_stats)

    explain_info = ExplainInfo()
    explain_info_var.set(explain_info)

    handler = LoggingCallbackHandler()

    messages = [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there"),
        SystemMessage(content="System message"),
        ToolMessage(content="Tool result", tool_call_id="test_tool_call"),
    ]

    with patch("nemoguardrails.logging.callbacks.log") as mock_log:
        await handler.on_chat_model_start(
            serialized={},
            messages=[messages],
            run_id=uuid4(),
        )

    mock_log.info.assert_called()

    logged_prompt = None
    for call in mock_log.info.call_args_list:
        if "Prompt Messages" in str(call):
            logged_prompt = call[0][1]
            break

    assert logged_prompt is not None
    assert "[cyan]User[/]" in logged_prompt
    assert "[cyan]Bot[/]" in logged_prompt
    assert "[cyan]System[/]" in logged_prompt
    assert "[cyan]Tool[/]" in logged_prompt