feat: Enhance WeCom AI Bot integration with long connection support #5930
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the WeCom AI Bot integration by introducing support for a long connection mode, leveraging WebSockets for real-time communication. This change moves beyond the traditional webhook-only approach, enabling more immediate message processing and response capabilities. The refactoring ensures that users can choose their preferred connection method, improving the overall reliability and responsiveness of the bot's interaction with the WeCom platform.
Hey - I've found 3 issues and left some high-level feedback:
- The new `WecomAIBotLongConnectionClient` keeps per-`req_id` locks forever in `_req_locks`; consider cleaning up entries after `send_command` completes to avoid unbounded growth in long-running processes.
- The default `wecom_ai_bot_connection_mode` in `ChatProviderTemplate` has been changed to `long_connection`; this may surprise existing users who expect webhook behavior and have no WS credentials configured, so it might be safer to keep `webhook` as the default and let users explicitly opt in to long connection.
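One way to act on the `_req_locks` cleanup suggestion is a refcounted registry that forgets a lock once its last holder or waiter is done. This is a hypothetical sketch (the `ReqLockRegistry` name and `hold` API are illustrative, not code from the PR), shown only to demonstrate the cleanup pattern:

```python
import asyncio
from contextlib import asynccontextmanager


class ReqLockRegistry:
    """Per-key asyncio locks that are dropped once the last user releases them."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}
        self._refs: dict[str, int] = {}

    @asynccontextmanager
    async def hold(self, key: str):
        # count holders *and* waiters so a popped lock is never still in use
        lock = self._locks.setdefault(key, asyncio.Lock())
        self._refs[key] = self._refs.get(key, 0) + 1
        try:
            async with lock:
                yield
        finally:
            self._refs[key] -= 1
            if self._refs[key] == 0:
                self._locks.pop(key, None)
                self._refs.pop(key, None)


async def main() -> int:
    reg = ReqLockRegistry()

    async def worker() -> None:
        async with reg.hold("req-1"):
            await asyncio.sleep(0)

    await asyncio.gather(*(worker() for _ in range(5)))
    return len(reg._locks)


remaining = asyncio.run(main())
print(remaining)  # 0: the dict is empty again once all holders finish
```

The refcount covers waiters as well as the current holder, which avoids the race you would get by popping the lock whenever `lock.locked()` happens to be false.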
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The new `WecomAIBotLongConnectionClient` keeps per-`req_id` locks forever in `_req_locks`; consider cleaning up entries after `send_command` completes to avoid unbounded growth in long-running processes.
- The default `wecom_ai_bot_connection_mode` in `ChatProviderTemplate` has been changed to `long_connection`; this may surprise existing users who expect webhook behavior and have no WS credentials configured, so it might be safer to keep `webhook` as the default and let users explicitly opt in to long connection.
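The safer default can be expressed as a small opt-in check. This is a hypothetical sketch (`resolve_connection_mode` and the `ws_credentials` key are illustrative names, not the actual template fields):

```python
# Hypothetical sketch: keep "webhook" as the default and require an explicit
# opt-in (plus credentials) before enabling long-connection mode.
DEFAULT_CONNECTION_MODE = "webhook"


def resolve_connection_mode(user_config: dict) -> str:
    mode = user_config.get("wecom_ai_bot_connection_mode", DEFAULT_CONNECTION_MODE)
    if mode == "long_connection" and not user_config.get("ws_credentials"):
        # fail fast instead of silently running a mode the user never configured
        raise ValueError("long_connection mode requires WS credentials")
    return mode


mode = resolve_connection_mode({})  # existing users keep webhook behavior
print(mode)
```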
## Individual Comments
### Comment 1
<location path="astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py" line_range="359" />
<code_context>
logger.error("处理欢迎消息时发生异常: %s", e)
return None
+ async def _process_long_connection_payload(
+ self,
+ payload: dict[str, Any],
</code_context>
<issue_to_address>
**issue (complexity):** Consider refactoring to share the enqueue/pending-response logic and encapsulate connection_mode-specific behavior behind a small transport helper to keep this adapter simpler and easier to reason about.
You can reduce the new complexity without changing behavior by (1) centralizing the enqueue/pending logic shared by webhook and long-connection paths and (2) hiding mode branching behind a tiny internal “transport” abstraction.
### 1. Deduplicate enqueue + pending response logic
`_process_message` (via `_enqueue_message`) and `_process_long_connection_payload` both:
- derive `session_id` and `stream_id`
- enqueue the message
- mark the queue as pending with different callback metadata
- optionally send an initial response
You can extract a helper that takes a minimal, mode-agnostic “callback context” and reuse it in both places.
For example:
```python
async def _handle_incoming_message(
self,
message_data: dict[str, Any],
session_id: str,
stream_id: str,
callback_params: dict[str, Any],
) -> None:
# unified enqueueing
await self._enqueue_message(message_data, callback_params, stream_id, session_id)
self.queue_mgr.set_pending_response(stream_id, callback_params)
# unified initial response (if needed)
initial_text = self.initial_respond_text
if not initial_text:
return
mode = callback_params.get("connection_mode")
req_id = callback_params.get("req_id")
if mode == "long_connection" and req_id:
await self._send_long_connection_respond_msg(
req_id=req_id,
body={
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": initial_text,
},
},
)
```
Then `_process_long_connection_payload` becomes:
```python
async def _process_long_connection_payload(
self,
payload: dict[str, Any],
) -> None:
cmd = payload.get("cmd")
headers = payload.get("headers") or {}
body = payload.get("body") or {}
req_id = headers.get("req_id")
if not isinstance(body, dict):
return
if cmd == "aibot_msg_callback":
session_id = self._extract_session_id(body)
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._handle_incoming_message(
message_data=body,
session_id=session_id,
stream_id=stream_id,
callback_params={
"req_id": req_id or "",
"connection_mode": "long_connection",
},
)
return
# the event callback branch stays unchanged
```
On the webhook side, you can call the same helper after you’ve computed `session_id` and `stream_id`, passing the HTTP callback context instead of reimplementing enqueue + pending logic.
This way, any future change to queue semantics or initial-stream behavior is done in one place.
---
### 2. Localize mode branching via a tiny transport object
The `connection_mode` checks in `__init__`, `run`, `webhook_callback`, and `terminate`, plus the `None`-checks for `server` / `long_connection_client`, make the class harder to reason about.
You can encapsulate the mode-specific start/stop/webhook behavior behind a very small internal helper object, without introducing a full strategy pattern.
For example, in `__init__` after you’ve created the concrete resources:
```python
class _WebhookTransport:
def __init__(self, server: WecomAIBotServer, queue_listener: WecomAIQueueListener):
self._server = server
self._queue_listener = queue_listener
async def run(self) -> None:
await asyncio.gather(self._server.start_server(), self._queue_listener.run())
async def terminate(self) -> None:
await self._server.shutdown()
async def webhook_callback(self, request: Any) -> Any:
if request.method == "GET":
return await self._server.handle_verify(request)
return await self._server.handle_callback(request)
class _LongConnectionTransport:
def __init__(
self,
client: WecomAIBotLongConnectionClient,
queue_listener: WecomAIQueueListener,
):
self._client = client
self._queue_listener = queue_listener
async def run(self) -> None:
await asyncio.gather(self._client.start(), self._queue_listener.run())
async def terminate(self) -> None:
await self._client.shutdown()
async def webhook_callback(self, request: Any) -> Any:
return "long_connection mode does not accept webhook callbacks", 400
```
Then wire it in `__init__`:
```python
self.transport = None
if self.connection_mode == "long_connection":
# long_connection_client initialization stays unchanged
self.transport = _LongConnectionTransport(
client=self.long_connection_client,
queue_listener=self.queue_listener,
)
else:
# api_client / server initialization stays unchanged
self.transport = _WebhookTransport(
server=self.server,
queue_listener=self.queue_listener,
)
```
And simplify the public methods:
```python
def run(self) -> Awaitable[Any]:
async def run_both() -> None:
# keep the unified_webhook_mode special case (it could move into _WebhookTransport if needed)
if (
self.connection_mode != "long_connection"
and self.unified_webhook_mode
and self.config.get("webhook_uuid")
):
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", self.config["webhook_uuid"])
await self.queue_listener.run()
return
await self.transport.run()
return run_both()
async def webhook_callback(self, request: Any) -> Any:
return await self.transport.webhook_callback(request)
async def terminate(self) -> None:
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.transport.terminate()
```
This:
- removes repeated `if self.connection_mode == ...` checks from the main class
- eliminates several `if not self.server` / `if not self.long_connection_client` guard paths
- makes it clear which operations are valid in each mode, with mode-specific code kept local.
</issue_to_address>
### Comment 2
<location path="astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py" line_range="138" />
<code_context>
+
async def send(self, message: MessageChain | None) -> None:
"""发送消息"""
raw = self.message_obj.raw_message
</code_context>
<issue_to_address>
**issue (complexity):** Consider extracting the repeated long-connection setup and sending logic into dedicated helper methods to simplify `send`/`send_streaming` and reduce duplication.
You can keep the new functionality but reduce complexity and duplication by:
1. **Centralizing the long‑connection context lookup**
2. **Isolating the long‑connection send logic into helpers**
3. **Separating text aggregation from transport logic for streaming**
### 1. Centralize connection mode / req_id lookup
Both `send` and `send_streaming` repeat:
```python
raw = self.message_obj.raw_message
assert isinstance(raw, dict), ...
stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get("connection_mode")
req_id = pending_response.get("callback_params", {}).get("req_id")
```
Extract this into a small helper, so conditions stay consistent:
```python
def _get_long_connection_context(self) -> tuple[str, str | None, str]:
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
pending = self.queue_mgr.get_pending_response(stream_id) or {}
callback = pending.get("callback_params", {}) or {}
return (
stream_id,
callback.get("connection_mode"),
callback.get("req_id"),
)
```
Then in `send`/`send_streaming`:
```python
stream_id, connection_mode, req_id = self._get_long_connection_context()
```
and use `stream_id`, `connection_mode`, `req_id` consistently.
### 2. Factor long‑connection branches into helpers
The top of `send` and the top of `send_streaming` both have large, mode‑dependent branches. You can move those into focused helpers so the public methods become much easier to read.
Example for non‑streaming:
```python
async def _try_send_via_long_connection(
self,
stream_id: str,
connection_mode: str | None,
req_id: str | None,
message: MessageChain | None,
) -> bool:
if (
connection_mode != "long_connection"
or not self.long_connection_sender
or not isinstance(req_id, str)
or not req_id
):
return False
# webhook-only short-circuit
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await super().send(MessageChain([]))
return True
# webhook unsupported-only + long connection
if self.webhook_client and message:
await self.webhook_client.send_message_chain(message, unsupported_only=True)
content = self._extract_plain_text_from_chain(message)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content,
},
},
)
await super().send(MessageChain([]))
return True
```
Then `send` becomes:
```python
async def send(self, message: MessageChain | None) -> None:
stream_id, connection_mode, req_id = self._get_long_connection_context()
if await self._try_send_via_long_connection(stream_id, connection_mode, req_id, message):
return
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await self._mark_stream_complete(stream_id)
await super().send(MessageChain([]))
return
if self.webhook_client and message:
await self.webhook_client.send_message_chain(message, unsupported_only=True)
if self.api_client:
await self.api_client.send_message_chain(stream_id, message)
await self._mark_stream_complete(stream_id)
await super().send(MessageChain([]))
```
`send_streaming` can mirror this pattern with its own `_try_send_streaming_via_long_connection` helper.
### 3. Separate text aggregation for streaming
The `increment_plain` logic and repeated long‑connection sends are currently embedded inside `send_streaming`. You can move that into a small helper so `send_streaming` only controls flow, not aggregation details:
```python
async def _send_stream_chunks_via_long_connection(
self,
stream_id: str,
req_id: str,
generator,
) -> None:
increment_plain = ""
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(chain, unsupported_only=True)
chain.squash_plain()
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": increment_plain,
},
},
)
```
Then the long‑connection branch in `send_streaming` shrinks to mode/routing decisions:
```python
async def _try_send_streaming_via_long_connection(
self, stream_id: str, connection_mode: str | None, req_id: str | None, generator, use_fallback: bool
) -> bool:
if (
connection_mode != "long_connection"
or not self.long_connection_sender
or not isinstance(req_id, str)
or not req_id
):
return False
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
merged_chain.chain.extend(chain.chain)
merged_chain.squash_plain()
await self.webhook_client.send_message_chain(merged_chain)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {"id": stream_id, "finish": True, "content": ""},
},
)
await super().send_streaming(generator, use_fallback)
return True
await self._send_stream_chunks_via_long_connection(stream_id, req_id, generator)
await super().send_streaming(generator, use_fallback)
return True
```
And `send_streaming`:
```python
async def send_streaming(self, generator, use_fallback=False) -> None:
stream_id, connection_mode, req_id = self._get_long_connection_context()
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if await self._try_send_streaming_via_long_connection(
stream_id, connection_mode, req_id, generator, use_fallback
):
return
# existing webhook / queue logic here...
```
These extractions keep all behaviors (webhook‑only, webhook+long‑connection, plain‑text extraction and streaming semantics) but reduce nested branching and duplication, making the code easier to reason about and safer to extend.
</issue_to_address>
### Comment 3
<location path="astrbot/core/platform/sources/wecom_ai_bot/wecomai_long_connection.py" line_range="35" />
<code_context>
+ self._ws: aiohttp.ClientWebSocketResponse | None = None
+ self._shutdown_event = asyncio.Event()
+ self._send_lock = asyncio.Lock()
+ self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
+ self._req_locks: dict[str, asyncio.Lock] = {}
+
</code_context>
<issue_to_address>
**issue (complexity):** Consider removing the per-request `_req_locks` and relying only on `_send_lock` plus `_response_waiters` to simplify the concurrency model in `send_command`.
You can simplify concurrency control by dropping `_req_locks` and relying on the existing global `_send_lock` plus the per-request future in `_response_waiters`.
Right now you have:
- `_send_lock` to serialize `ws.send_json`.
- `_response_waiters` future per `req_id`.
- Additional `_req_locks` per `req_id` and a retry loop inside `send_command`.
Given that normal usage should give each command a unique `req_id`, the per-`req_id` lock adds complexity without much benefit. Commands sharing a `req_id` can be treated as misuse.
You can remove `_req_locks` entirely and keep the retry behavior unchanged:
```python
# __init__
self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
# self._req_locks: dict[str, asyncio.Lock] = {} # remove this
```
```python
async def send_command(
self,
cmd: str,
req_id: str,
body: dict[str, Any] | None,
) -> bool:
"""发送长连接命令。"""
headers = {"req_id": req_id}
payload: dict[str, Any] = {"cmd": cmd, "headers": headers}
if body is not None:
payload["body"] = body
max_retries = 3
for attempt in range(max_retries + 1):
response = await self._send_and_wait_response(req_id, payload)
if not response:
if attempt < max_retries:
await asyncio.sleep(min(0.2 * (2**attempt), 2.0))
continue
return False
errcode = response.get("errcode")
if errcode in (0, None):
return True
if errcode == 6000 and attempt < max_retries:
backoff = min(0.2 * (2**attempt), 2.0)
logger.warning(
"[WecomAI][LongConn] 命令冲突(errcode=6000),将重试。cmd=%s req_id=%s attempt=%d",
cmd,
req_id,
attempt + 1,
)
await asyncio.sleep(backoff)
continue
logger.warning(
"[WecomAI][LongConn] 命令失败: cmd=%s req_id=%s errcode=%s errmsg=%s",
cmd,
req_id,
errcode,
response.get("errmsg"),
)
return False
return False
```
`_send_json` still uses `_send_lock` to ensure only one `send_json` runs at a time, and `_response_waiters` still guard responses by `req_id`. This keeps all existing behavior (including retries and error handling) while removing one layer of locking and making the concurrency model easier to reason about.
</issue_to_address>
| logger.error("处理欢迎消息时发生异常: %s", e) | ||
| return None | ||
|
|
||
| async def _process_long_connection_payload( |
There was a problem hiding this comment.
issue (complexity): Consider refactoring to share the enqueue/pending-response logic and encapsulate connection_mode-specific behavior behind a small transport helper to keep this adapter simpler and easier to reason about.
You can reduce the new complexity without changing behavior by (1) centralizing the enqueue/pending logic shared by webhook and long-connection paths and (2) hiding mode branching behind a tiny internal “transport” abstraction.
1. Deduplicate enqueue + pending response logic
_process_message (via _enqueue_message) and _process_long_connection_payload both:
- derive
session_idandstream_id - enqueue the message
- mark the queue as pending with different callback metadata
- optionally send an initial response
You can extract a helper that takes a minimal, mode-agnostic “callback context” and reuse it in both places.
For example:
async def _handle_incoming_message(
self,
message_data: dict[str, Any],
session_id: str,
stream_id: str,
callback_params: dict[str, Any],
) -> None:
# 统一排队
await self._enqueue_message(message_data, callback_params, stream_id, session_id)
self.queue_mgr.set_pending_response(stream_id, callback_params)
# 统一初始响应(如果需要)
initial_text = self.initial_respond_text
if not initial_text:
return
mode = callback_params.get("connection_mode")
req_id = callback_params.get("req_id")
if mode == "long_connection" and req_id:
await self._send_long_connection_respond_msg(
req_id=req_id,
body={
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": initial_text,
},
},
)Then _process_long_connection_payload becomes:
async def _process_long_connection_payload(
self,
payload: dict[str, Any],
) -> None:
cmd = payload.get("cmd")
headers = payload.get("headers") or {}
body = payload.get("body") or {}
req_id = headers.get("req_id")
if not isinstance(body, dict):
return
if cmd == "aibot_msg_callback":
session_id = self._extract_session_id(body)
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._handle_incoming_message(
message_data=body,
session_id=session_id,
stream_id=stream_id,
callback_params={
"req_id": req_id or "",
"connection_mode": "long_connection",
},
)
return
# event callback 分支保持不变On the webhook side, you can call the same helper after you’ve computed session_id and stream_id, passing the HTTP callback context instead of reimplementing enqueue + pending logic.
This way, any future change to queue semantics or initial-stream behavior is done in one place.
2. Localize mode branching via a tiny transport object
The connection_mode checks in __init__, run, webhook_callback, and terminate, plus the None-checks for server / long_connection_client, make the class harder to reason about.
You can encapsulate the mode-specific start/stop/webhook behavior behind a very small internal helper object, without introducing a full strategy pattern.
For example, in __init__ after you’ve created the concrete resources:
class _WebhookTransport:
def __init__(self, server: WecomAIBotServer, queue_listener: WecomAIQueueListener):
self._server = server
self._queue_listener = queue_listener
async def run(self) -> None:
await asyncio.gather(self._server.start_server(), self._queue_listener.run())
async def terminate(self) -> None:
await self._server.shutdown()
async def webhook_callback(self, request: Any) -> Any:
if request.method == "GET":
return await self._server.handle_verify(request)
return await self._server.handle_callback(request)
class _LongConnectionTransport:
def __init__(
self,
client: WecomAIBotLongConnectionClient,
queue_listener: WecomAIQueueListener,
):
self._client = client
self._queue_listener = queue_listener
async def run(self) -> None:
await asyncio.gather(self._client.start(), self._queue_listener.run())
async def terminate(self) -> None:
await self._client.shutdown()
async def webhook_callback(self, request: Any) -> Any:
return "long_connection mode does not accept webhook callbacks", 400Then wire it in __init__:
self.transport = None
if self.connection_mode == "long_connection":
# long_connection_client 初始化逻辑保持不变
self.transport = _LongConnectionTransport(
client=self.long_connection_client,
queue_listener=self.queue_listener,
)
else:
# api_client / server 初始化逻辑保持不变
self.transport = _WebhookTransport(
server=self.server,
queue_listener=self.queue_listener,
)And simplify the public methods:
def run(self) -> Awaitable[Any]:
async def run_both() -> None:
# 保留 unified_webhook_mode 特殊逻辑(如有需要,可以放进 _WebhookTransport)
if (
self.connection_mode != "long_connection"
and self.unified_webhook_mode
and self.config.get("webhook_uuid")
):
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", self.config["webhook_uuid"])
await self.queue_listener.run()
return
await self.transport.run()
return run_both()
async def webhook_callback(self, request: Any) -> Any:
return await self.transport.webhook_callback(request)
async def terminate(self) -> None:
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.transport.terminate()This:
- removes repeated
if self.connection_mode == ...checks from the main class - eliminates several
if not self.server/if not self.long_connection_clientguard paths - makes it clear which operations are valid in each mode, with mode-specific code kept local.
async def send(self, message: MessageChain | None) -> None:
    """发送消息"""
    raw = self.message_obj.raw_message
issue (complexity): Consider extracting the repeated long-connection setup and sending logic into dedicated helper methods to simplify send/send_streaming and reduce duplication.
You can keep the new functionality but reduce complexity and duplication by:
- Centralizing the long‑connection context lookup
- Isolating the long‑connection send logic into helpers
- Separating text aggregation from transport logic for streaming
1. Centralize connection mode / req_id lookup
Both send and send_streaming repeat:
raw = self.message_obj.raw_message
assert isinstance(raw, dict), ...
stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get("connection_mode")
req_id = pending_response.get("callback_params", {}).get("req_id")Extract this into a small helper, so conditions stay consistent:
def _get_long_connection_context(self) -> tuple[str, str | None, str]:
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
pending = self.queue_mgr.get_pending_response(stream_id) or {}
callback = pending.get("callback_params", {}) or {}
return (
stream_id,
callback.get("connection_mode"),
callback.get("req_id"),
)Then in send/send_streaming:
stream_id, connection_mode, req_id = self._get_long_connection_context()and use stream_id, connection_mode, req_id consistently.
2. Factor long‑connection branches into helpers
The top of send and the top of send_streaming both have large, mode‑dependent branches. You can move those into focused helpers so the public methods become much easier to read.
Example for non‑streaming:
async def _try_send_via_long_connection(
self,
stream_id: str,
connection_mode: str | None,
req_id: str | None,
message: MessageChain | None,
) -> bool:
if (
connection_mode != "long_connection"
or not self.long_connection_sender
or not isinstance(req_id, str)
or not req_id
):
return False
# webhook-only short-circuit
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await super().send(MessageChain([]))
return True
# webhook unsupported-only + long connection
if self.webhook_client and message:
await self.webhook_client.send_message_chain(message, unsupported_only=True)
content = self._extract_plain_text_from_chain(message)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content,
},
},
)
await super().send(MessageChain([]))
return TrueThen send becomes:
async def send(self, message: MessageChain | None) -> None:
stream_id, connection_mode, req_id = self._get_long_connection_context()
if await self._try_send_via_long_connection(stream_id, connection_mode, req_id, message):
return
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await self._mark_stream_complete(stream_id)
await super().send(MessageChain([]))
return
if self.webhook_client and message:
await self.webhook_client.send_message_chain(message, unsupported_only=True)
if self.api_client:
await self.api_client.send_message_chain(stream_id, message)
await self._mark_stream_complete(stream_id)
await super().send(MessageChain([]))send_streaming can mirror this pattern with its own _try_send_streaming_via_long_connection helper.
3. Separate text aggregation for streaming
The increment_plain logic and repeated long‑connection sends are currently embedded inside send_streaming. You can move that into a small helper so send_streaming only controls flow, not aggregation details:
async def _send_stream_chunks_via_long_connection(
self,
stream_id: str,
req_id: str,
generator,
) -> None:
increment_plain = ""
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(chain, unsupported_only=True)
chain.squash_plain()
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": increment_plain,
},
},
)Then the long‑connection branch in send_streaming shrinks to mode/routing decisions:
async def _try_send_streaming_via_long_connection(
self, stream_id: str, connection_mode: str | None, req_id: str | None, generator, use_fallback: bool
) -> bool:
if (
connection_mode != "long_connection"
or not self.long_connection_sender
or not isinstance(req_id, str)
or not req_id
):
return False
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
merged_chain.chain.extend(chain.chain)
merged_chain.squash_plain()
await self.webhook_client.send_message_chain(merged_chain)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {"id": stream_id, "finish": True, "content": ""},
},
)
await super().send_streaming(generator, use_fallback)
return True
await self._send_stream_chunks_via_long_connection(stream_id, req_id, generator)
await super().send_streaming(generator, use_fallback)
return TrueAnd send_streaming:
async def send_streaming(self, generator, use_fallback=False) -> None:
stream_id, connection_mode, req_id = self._get_long_connection_context()
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if await self._try_send_streaming_via_long_connection(
stream_id, connection_mode, req_id, generator, use_fallback
):
return
# existing webhook / queue logic here...These extractions keep all behaviors (webhook‑only, webhook+long‑connection, plain‑text extraction and streaming semantics) but reduce nested branching and duplication, making the code easier to reason about and safer to extend.
Code Review
This pull request introduces significant enhancements to the WeCom AI Bot adapter by adding support for long connections via WebSockets, while maintaining backward compatibility with the existing webhook mode. However, two notable issues were identified: a memory leak in the long connection client due to unmanaged request locks, and image-processing logic that is susceptible to SSRF and memory exhaustion (DoS) because it lacks URL validation and response size limits. There are also suggestions to further improve robustness and code maintainability.
self._shutdown_event = asyncio.Event()
self._send_lock = asyncio.Lock()
self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
self._req_locks: dict[str, asyncio.Lock] = {}
The _req_locks dictionary and the associated lock in send_command appear to be unnecessary and also cause a memory leak. The calls to send_command with the same req_id are already serialized by await in the calling code (e.g., in wecomai_event.py). Since there's no concurrency for the same req_id, this locking mechanism isn't needed.
Furthermore, this dictionary is never cleaned up, leading to a memory leak as it grows with each new request.
I recommend removing _req_locks (this line), and also removing the async with req_lock: block in the send_command method (lines 163-164) and un-indenting its content. This will simplify the code and resolve the memory leak.
if body is not None:
    payload["body"] = body

req_lock = self._req_locks.setdefault(req_id, asyncio.Lock())
The _req_locks dictionary stores an asyncio.Lock for every req_id. Since a new req_id is generated for every command (including periodic heartbeats), this dictionary will grow indefinitely over time, leading to a memory leak and eventual performance degradation or crash. It is recommended to remove the lock from the dictionary once the request is completed.
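If per-`req_id` locking were kept rather than removed, one way to resolve the leak is to drop the entry in a `finally` block once the guarded call finishes and the lock is free. This is a sketch with hypothetical names, not the PR's actual code:

```python
import asyncio
from typing import Any, Awaitable, Callable


class ReqLockRegistry:
    """Per-req_id locks that are discarded after use to avoid unbounded growth."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}

    async def run_locked(self, req_id: str, fn: Callable[[], Awaitable[Any]]) -> Any:
        lock = self._locks.setdefault(req_id, asyncio.Lock())
        try:
            async with lock:
                return await fn()
        finally:
            # Remove the entry once the lock is free; tasks that already
            # captured this lock object still serialize on it.
            if not lock.locked():
                self._locks.pop(req_id, None)


async def demo() -> int:
    registry = ReqLockRegistry()

    async def work() -> int:
        await asyncio.sleep(0)
        return 1

    total = sum([await registry.run_locked("req-1", work) for _ in range(3)])
    assert not registry._locks  # no leftover entries between requests
    return total
```

Since each WeCom command (including heartbeats) uses a fresh `req_id`, the registry stays empty between requests instead of growing by one entry per command.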
tasks = [
    process_encrypted_image(url, self.encoding_aes_key)
    for url in _img_url_to_process
    process_encrypted_image(url, aes_key or self.encoding_aes_key)
The process_encrypted_image utility function (called here) downloads data from an unvalidated URL provided in the message payload and reads the entire response into memory. This implementation is vulnerable to: 1) Server-Side Request Forgery (SSRF), as the URL is not restricted to trusted WeCom domains, and 2) Denial of Service (OOM), as an attacker could provide a URL to an extremely large file that exhausts the bot's memory. It is recommended to validate the URL domain and enforce a maximum download size limit.
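Both mitigations can be sketched as below. The `ALLOWED_HOSTS` value is a placeholder, not a verified WeCom CDN domain, and the capped reader assumes the response is consumed chunk by chunk rather than read whole into memory:

```python
from urllib.parse import urlparse

ALLOWED_HOSTS = {"wwcdn.weixin.qq.com"}  # placeholder: restrict to trusted WeCom hosts
MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10 MiB hard cap


def is_allowed_image_url(url: str) -> bool:
    """Accept only HTTPS URLs whose host is explicitly trusted (blocks SSRF)."""
    parsed = urlparse(url)
    return parsed.scheme == "https" and parsed.hostname in ALLOWED_HOSTS


def read_capped(chunks, limit: int = MAX_IMAGE_BYTES) -> bytes:
    """Accumulate response chunks, aborting once the size cap is exceeded (blocks OOM)."""
    buf = bytearray()
    for chunk in chunks:
        buf.extend(chunk)
        if len(buf) > limit:
            raise ValueError("encrypted image exceeds download size limit")
    return bytes(buf)
```

In the adapter, the URL check would run before `process_encrypted_image` issues any request, and the capped reader would replace reading the entire response body at once.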
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
    "connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
This block of code to retrieve connection_mode and req_id is duplicated in the send_streaming method (lines 209-213). To improve code maintainability and reduce redundancy, consider extracting this logic into a helper method.
For example, you could create a method like this:
def _get_response_params(self, stream_id: str) -> tuple[str | None, str | None]:
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
callback_params = pending_response.get("callback_params", {})
connection_mode = callback_params.get("connection_mode")
req_id = callback_params.get("req_id")
return connection_mode, req_idThen you can call this helper in both send and send_streaming to avoid repetition.
No docs changes were generated in this run (docs repo had no updates). Docs repo: AstrBotDevs/AstrBot-docs. AI change summary (not committed):
Experimental bot notice:
closes: #5929
Modifications / 改动点
Screenshots or Test Results / 运行截图或测试结果
Checklist / 检查清单
requirements.txt和pyproject.toml文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations inrequirements.txtandpyproject.toml.Summary by Sourcery
Add long-connection WebSocket support to the WeCom AI Bot adapter while keeping webhook mode compatible, and wire responses into the existing async messaging pipeline.
New Features:
Enhancements:
Documentation: