Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 72 additions & 54 deletions agentrun/memory_collection/memory_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""

import asyncio
import json
import os
from typing import (
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
# 延迟初始化
self._memory_store = None
self._ots_client = None
self._init_lock = asyncio.Lock()

@staticmethod
def _default_user_id_extractor(req: Any) -> str:
Expand Down Expand Up @@ -153,10 +155,18 @@ def _default_agent_id_extractor(req: Any) -> str:
return "default_agent"

async def _get_memory_store(self):
"""获取或创建 AsyncMemoryStore 实例"""
"""获取或创建 AsyncMemoryStore 实例(双检锁,并发安全)"""
if self._memory_store is not None:
return self._memory_store

async with self._init_lock:
# 拿到锁后再检查一次,防止并发请求重复初始化
if self._memory_store is not None:
return self._memory_store
return await self._init_memory_store()

async def _init_memory_store(self):
"""内部初始化方法,由 _get_memory_store 在持锁状态下调用"""
try:
# 导入依赖
from tablestore_for_agent_memory.base.base_memory_store import (
Expand Down Expand Up @@ -228,7 +238,7 @@ async def _get_memory_store(self):
)
await self._memory_store.init_table()
await self._memory_store.init_search_index()
logger.info(f"Tables and indexes initialized successfully")
logger.info("Tables and indexes initialized successfully")
except Exception as e:
# 如果表已存在,会抛出异常,这是正常的
logger.info(
Expand Down Expand Up @@ -384,10 +394,13 @@ async def wrap_invoke_agent(
metadata={"agent_id": agent_id},
)

try:
await memory_store.put_session(session)
except Exception as e:
logger.error(f"Failed to save session: {e}", exc_info=True)
async def _put_session_bg():
try:
await memory_store.put_session(session)
except Exception as e:
logger.error(f"Failed to save session: {e}", exc_info=True)

Comment on lines +397 to +402
asyncio.create_task(_put_session_bg())

# 构建输入消息列表(包含所有历史消息)
input_messages = []
Expand Down Expand Up @@ -465,57 +478,62 @@ async def wrap_invoke_agent(
yield event

# 保存完整的对话轮次(输入 + 输出)
# 只有当有文本内容或工具调用时才保存
# 使用 fire-and-forget 避免阻塞流式响应关闭
if agent_response_content or tool_calls or tool_results:
try:
# 构建助手响应消息
assistant_message: Dict[str, Any] = {
"role": "assistant",
}

# 添加文本内容(如果有)
if agent_response_content:
assistant_message["content"] = agent_response_content
else:
# OpenAI 格式要求:如果有 tool_calls,content 可以为 null
assistant_message["content"] = None

# 添加工具调用(如果有)
if tool_calls:
assistant_message["tool_calls"] = list(
tool_calls.values()
# 构建助手响应消息
assistant_message: Dict[str, Any] = {
"role": "assistant",
}

if agent_response_content:
assistant_message["content"] = agent_response_content
else:
assistant_message["content"] = None

if tool_calls:
assistant_message["tool_calls"] = list(tool_calls.values())

output_messages = input_messages + [assistant_message]

if tool_results:
output_messages.extend(tool_results)

conversation_message = Message(
session_id=session_id,
message_id=f"msg_{uuid.uuid4().hex[:16]}",
content=json.dumps(output_messages, ensure_ascii=False),
)

async def _save_conversation_bg(
ms=memory_store,
msg=conversation_message,
sess=session,
n_msgs=len(output_messages),
text_len=len(agent_response_content),
n_tc=len(tool_calls),
n_tr=len(tool_results),
):
try:
await ms.put_message(msg)
sess.update_time = microseconds_timestamp()
await ms.update_session(sess)
logger.debug(
"Saved conversation: %d messages,"
" text length: %d chars,"
" tool_calls: %d, tool_results: %d",
n_msgs,
text_len,
n_tc,
n_tr,
)
except Exception as e:
logger.error(
"Failed to save conversation: %s",
e,
exc_info=True,
)
Comment on lines +516 to 534

# 构建完整的消息列表
output_messages = input_messages + [assistant_message]

# 添加工具执行结果(如果有)
if tool_results:
output_messages.extend(tool_results)

# 将完整的对话历史存储为一条消息
# content 字段存储 JSON 格式的消息列表
conversation_message = Message(
session_id=session_id,
message_id=f"msg_{uuid.uuid4().hex[:16]}",
content=json.dumps(output_messages, ensure_ascii=False),
)
await memory_store.put_message(conversation_message)

# 更新 Session 时间
session.update_time = microseconds_timestamp()
await memory_store.update_session(session)

logger.debug(
f"Saved conversation: {len(output_messages)} messages,"
f" text length: {len(agent_response_content)} chars,"
f" tool_calls: {len(tool_calls)}, tool_results:"
f" {len(tool_results)}"
)
except Exception as e:
logger.error(
f"Failed to save conversation: {e}", exc_info=True
)
asyncio.create_task(_save_conversation_bg())

except Exception as e:
logger.error(f"Error in agent handler: {e}", exc_info=True)
Expand Down
18 changes: 18 additions & 0 deletions tests/unittests/memory_collection/test_memory_conversation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for AgentRun Memory Conversation / AgentRun 记忆对话测试"""

import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
Expand All @@ -8,6 +9,11 @@
from agentrun.server.model import AgentRequest, Message, MessageRole


async def _flush_bg_tasks():
"""Let fire-and-forget background tasks complete before assertions."""
await asyncio.sleep(0.05)

Comment on lines +12 to +15

@pytest.fixture
def mock_memory_collection():
"""Mock MemoryCollection"""
Expand Down Expand Up @@ -185,6 +191,9 @@ async def mock_agent(request: AgentRequest):
# Verify results
assert results == ["Hello", ", ", "world!"]

# Wait for fire-and-forget background tasks to complete
await _flush_bg_tasks()

# Verify memory store calls
assert mock_memory_store.put_session.called
assert mock_memory_store.put_message.called
Expand Down Expand Up @@ -252,6 +261,9 @@ async def mock_agent(request: AgentRequest):
async for event in memory.wrap_invoke_agent(request, mock_agent):
results.append(event)

# Wait for fire-and-forget background tasks to complete
await _flush_bg_tasks()

# Verify agent still responds
assert results == ["Still works!"]

Expand Down Expand Up @@ -339,6 +351,9 @@ async def mock_agent(request: AgentRequest):
assert results[0] == "Let me search for that..."
assert results[3] == "Based on the search, it's sunny today."

# Wait for fire-and-forget background tasks to complete
await _flush_bg_tasks()

# Verify message was saved with tool calls
assert mock_memory_store.put_message.called
saved_message = mock_memory_store.put_message.call_args[0][0]
Expand Down Expand Up @@ -437,6 +452,9 @@ async def mock_agent(request: AgentRequest):
# Verify all events were passed through
assert len(results) == 4

# Wait for fire-and-forget background tasks to complete
await _flush_bg_tasks()

# Verify message was saved with accumulated tool call
assert mock_memory_store.put_message.called
saved_message = mock_memory_store.put_message.call_args[0][0]
Expand Down
Loading