-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat(platform): add CLI Tester for plugin testing and debugging #4787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
- Add CLI platform adapter with Unix socket mode - Support isolated sessions with configurable TTL - Add whitelist exemption for CLI platform - Include astrbot-cli command-line tool - Support independent configuration file system
…efault - Rename "CLI Platform Adapter" to "CLI Tester" - Set default enable to false (disabled by default) - Update descriptions to emphasize testing and debugging purpose - Clarify design goal: build fast feedback loop for vibe coding
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey - 我发现了两个问题,并留下了一些高层面的反馈:
- 当前 CLI 适配器将诸如
/AstrBot/data/{config_file}和/tmp/astrbot.sock这样的路径写死;建议从现有的配置/根路径工具中派生这些路径,或者允许通过配置覆盖,这样在非 Docker 或非默认部署环境中该特性也能正常工作。 - 白名单检查对
cli平台无条件绕过校验;如果你预期有些环境会以更受控的方式使用 CLI,那么将该行为放在一个配置开关后面,可能会比硬编码豁免更安全。 - 对于 Unix 套接字服务器,你可能需要显式控制文件权限/所有权(例如通过
os.chmod或 umask),并在 JSON 载荷上增加最小的分帧/大小检查,以避免在多个客户端连接时因部分读取或超大读取导致的问题。
供 AI Agents 使用的提示词
请根据这次代码审查中的评论进行修改:
## 整体评论
- 当前 CLI 适配器将诸如 `/AstrBot/data/{config_file}` 和 `/tmp/astrbot.sock` 这样的路径写死;建议从现有的配置/根路径工具中派生这些路径,或者允许通过配置覆盖,这样在非 Docker 或非默认部署环境中该特性也能正常工作。
- 白名单检查对 `cli` 平台无条件绕过校验;如果你预期有些环境会以更受控的方式使用 CLI,那么将该行为放在一个配置开关后面,可能会比硬编码豁免更安全。
- 对于 Unix 套接字服务器,你可能需要显式控制文件权限/所有权(例如通过 `os.chmod` 或 umask),并在 JSON 载荷上增加最小的分帧/大小检查,以避免在多个客户端连接时因部分读取或超大读取导致的问题。
## 单独评论
### 评论 1
<location> `astrbot/core/platform/sources/cli/cli_adapter.py:302` </location>
<code_context>
+ while self._running:
+ try:
+ # 接受连接(非阻塞)
+ loop = asyncio.get_event_loop()
+ client_socket, _ = await loop.sock_accept(server_socket)
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** 在异步代码中使用 asyncio.get_event_loop() 并不推荐;asyncio.get_running_loop() 更安全且更具前向兼容性。
在 Python 3.10+ 中,`asyncio.get_event_loop()` 在异步代码中已被弃用,并且在某些策略下可能返回错误的事件循环。在 `_run_socket_mode` 中(以及类似的 `_handle_socket_client` / `_read_input` 中),请使用 `asyncio.get_running_loop()`,以确保你获取的是当前任务所在的事件循环。
建议实现:
```python
# 接受连接(非阻塞)
loop = asyncio.get_running_loop()
client_socket, _ = await loop.sock_accept(server_socket)
```
在 `astrbot/core/platform/sources/cli/cli_adapter.py` 中搜索其他在 `async def` 函数内部使用 `asyncio.get_event_loop()` 的地方,尤其是 `_handle_socket_client` 和 `_read_input`,并以类似方式替换:
- 当上下文为运行在事件循环上的异步代码时,将 `loop = asyncio.get_event_loop()` 替换为 `loop = asyncio.get_running_loop()`。
如果在纯同步初始化代码中(不在活动事件循环内)有 `asyncio.get_event_loop()` 的使用,应单独进行审查,因为它们可能需要不同的模式(例如使用 `asyncio.new_event_loop()` 显式创建事件循环,而不是使用 `get_running_loop()`)。
</issue_to_address>
### 评论 2
<location> `astrbot/core/platform/sources/cli/cli_adapter.py:134` </location>
<code_context>
+ logger.info("[ENTRY] CLIPlatformAdapter.run inputs={}")
+ return self._run_loop()
+
+ async def _run_loop(self) -> None:
+ """主运行循环
+
</code_context>
<issue_to_address>
**issue (complexity):** 建议通过抽取小的辅助函数、集中复用逻辑以及将模式映射到处理函数的方式来重构 CLI 适配器,使代码在不改变行为的前提下更易理解。
在不改变行为的前提下,你可以通过抽取一些聚焦的小型辅助函数并整合重复逻辑,显著降低复杂度。下面是一些具体、局部化的重构建议,可以保留现有设计,但让代码更易于推理。
---
### 1. 使用策略映射简化模式选择
`_run_loop` 目前在函数体中同时包含 TTY 检测和分支逻辑。你可以将模式解析抽取到一个小的辅助函数中,并将模式映射到可调用对象:
```python
# in __init__
self._mode_handlers: dict[str, callable[[], Awaitable[None]]] = {
"tty": self._run_tty_mode,
"file": self._run_file_mode,
"socket": self._run_socket_mode,
}
def _resolve_mode(self) -> str:
has_tty = sys.stdin.isatty()
if self.mode == "auto":
return "file" if not has_tty else "tty"
if self.mode in ("tty", "file", "socket"):
if self.mode == "tty" and not has_tty:
logger.warning(
"[PROCESS] TTY mode requested but no TTY detected. "
"CLI platform will not start."
)
return "" # or None
return self.mode
logger.error("[ERROR] Unknown mode: %s", self.mode)
return ""
```
```python
async def _run_loop(self) -> None:
logger.info("[PROCESS] Starting CLI loop")
if self.use_isolated_sessions:
self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())
mode = self._resolve_mode()
if not mode:
return
handler = self._mode_handlers.get(mode)
if handler:
logger.info("[PROCESS] Starting %s mode", mode)
await handler()
```
这样可以在不改变行为的前提下,压平分支结构,并把 TTY 逻辑局部化。
---
### 2. 抽取一个小型 `SessionTracker` 辅助类
`_convert_input` 和 `_cleanup_expired_sessions` 都会操作 `_session_timestamps`。将它们移入一个小型辅助类中,可以将消息构建和会话生命周期解耦:
```python
class _SessionTracker:
def __init__(self, ttl: int) -> None:
import time
self._ttl = ttl
self._timestamps: dict[str, float] = {}
self._time = time
def ensure_session(self, base_session_id: str, request_id: str | None) -> str:
if request_id is None:
return base_session_id
session_id = f"{base_session_id}_{request_id}"
if session_id not in self._timestamps:
self._timestamps[session_id] = self._time.time()
return session_id
def collect_expired(self) -> list[str]:
now = self._time.time()
expired = [
s for s, ts in list(self._timestamps.items())
if now - ts > self._ttl
]
for s in expired:
self._timestamps.pop(s, None)
return expired
```
将其接入适配器:
```python
# __init__
self._session_tracker = _SessionTracker(self.session_ttl)
```
```python
def _convert_input(self, text: str, request_id: str | None = None) -> AstrBotMessage:
...
if self.use_isolated_sessions and request_id:
message.session_id = self._session_tracker.ensure_session(
base_session_id="cli_session",
request_id=request_id,
)
else:
message.session_id = self.session_id
...
```
```python
async def _cleanup_expired_sessions(self) -> None:
logger.info("[ENTRY] _cleanup_expired_sessions started, TTL=%s seconds", self.session_ttl)
while self._running:
try:
await asyncio.sleep(10)
if not self.use_isolated_sessions:
continue
expired_sessions = self._session_tracker.collect_expired()
for session_id in expired_sessions:
logger.info("[PROCESS] Cleaning expired session: %s", session_id)
# TODO: DB cleanup if needed
if expired_sessions:
logger.info("[PROCESS] Cleaned %d expired sessions", len(expired_sessions))
except Exception as e:
logger.error("[ERROR] Session cleanup error: %s", e)
```
这样可以避免在多个位置直接修改 `_session_timestamps`,并将会话策略集中到一个地方。
---
### 3. 去重图片提取/归一化逻辑
你已经在 `_handle_socket_client` 中实现了图片提取和 base64 转换。如果 `CLIMessageEvent.send`(或类似逻辑)中有重叠代码,可以将其集中到一个可复用的辅助函数中:
```python
from astrbot.core.message.components import Image
def normalize_images_from_chain(message_chain: MessageChain) -> list[dict[str, Any]]:
import base64
images: list[dict[str, Any]] = []
for comp in message_chain.chain:
if not isinstance(comp, Image) or not comp.file:
continue
info: dict[str, Any] = {}
if comp.file.startswith("http"):
info["type"] = "url"
info["url"] = comp.file
elif comp.file.startswith("file:///"):
info["type"] = "file"
file_path = comp.file[8:]
info["path"] = file_path
try:
with open(file_path, "rb") as f:
raw = f.read()
info["base64_data"] = base64.b64encode(raw).decode("utf-8")
info["size"] = len(raw)
except Exception as e:
logger.error("[ERROR] Failed to read image file %s: %s", file_path, e)
info["error"] = str(e)
elif comp.file.startswith("base64://"):
info["type"] = "base64"
base64_data = comp.file[9:]
info["base64_data"] = base64_data
info["base64_length"] = len(base64_data)
images.append(info)
return images
```
在 `_handle_socket_client` 中使用:
```python
from .image_utils import normalize_images_from_chain # or local helper
...
message_chain = await asyncio.wait_for(response_future, timeout=30.0)
response_text = message_chain.get_plain_text()
images = normalize_images_from_chain(message_chain)
response = json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
)
await loop.sock_sendall(client_socket, response.encode("utf-8"))
```
并在 `CLIMessageEvent` 或其他复用该逻辑的代码路径中同样复用该辅助函数。
---
### 4. 将 `_handle_socket_client` 拆分为更小的步骤
在不改变行为的前提下,你可以把该方法拆分成解析 / 处理 / 序列化等辅助函数,使其更易测试且更短:
```python
def _parse_socket_request(self, raw: bytes) -> tuple[dict[str, Any], str, str]:
import json
request = json.loads(raw.decode("utf-8"))
message_text = request.get("message", "")
request_id = request.get("request_id", str(uuid.uuid4()))
return request, message_text, request_id
def _serialize_error(self, request_id: str, msg: str) -> bytes:
import json
return json.dumps(
{"status": "error", "error": msg, "request_id": request_id},
ensure_ascii=False,
).encode("utf-8")
def _serialize_success(self, request_id: str, chain: MessageChain) -> bytes:
import json
response_text = chain.get_plain_text()
images = normalize_images_from_chain(chain)
return json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
).encode("utf-8")
```
在 `_handle_socket_client` 中使用它们:
```python
async def _handle_socket_client(self, client_socket) -> None:
import json
logger.debug("[ENTRY] _handle_socket_client")
loop = asyncio.get_event_loop()
try:
data = await loop.sock_recv(client_socket, 4096)
if not data:
logger.debug("[PROCESS] Empty request, closing connection")
return
try:
_, message_text, request_id = self._parse_socket_request(data)
except json.JSONDecodeError:
await loop.sock_sendall(
client_socket, self._serialize_error("", "Invalid JSON format")
)
return
response_future: asyncio.Future[MessageChain] = asyncio.Future()
message = self._convert_input(message_text, request_id=request_id)
message_event = CLIMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
output_queue=self._output_queue,
response_future=response_future,
)
self.commit_event(message_event)
try:
chain = await asyncio.wait_for(response_future, timeout=30.0)
await loop.sock_sendall(
client_socket, self._serialize_success(request_id, chain)
)
except asyncio.TimeoutError:
await loop.sock_sendall(
client_socket, self._serialize_error(request_id, "Request timeout")
)
except Exception as e:
logger.error("[ERROR] Socket client handler error: %s", e)
import traceback
logger.error(traceback.format_exc())
finally:
client_socket.close()
logger.debug("[EXIT] _handle_socket_client return=None")
```
这样可以在保持现有行为的同时,将这个较长的方法拆分为更小、可单独测试的步骤。
---
这些改动在保持现有特性和流程的前提下,可以:
- 将会话策略从消息转换逻辑中分离出来;
- 降低模式选择时的分支复杂度;
- 去除图片处理中的重复代码;
- 将套接字客户端处理逻辑拆分得更短、更清晰。
</issue_to_address>帮我变得更有用!请在每条评论上点击 👍 或 👎,我会根据这些反馈来改进后续的评审。
Original comment in English
Hey - I've found 2 issues, and left some high level feedback:
- The CLI adapter currently hardcodes paths like
/AstrBot/data/{config_file}and/tmp/astrbot.sock; consider deriving these from existing config/root-path utilities or allowing them to be overridden so the feature works in non-Docker or non-default deployments. - The whitelist check unconditionally bypasses validation for the
cliplatform; if you expect some environments to use CLI in a more controlled way, it might be safer to gate this behavior behind a config flag rather than hardcoding the exemption. - For the Unix socket server, you may want to explicitly control file permissions/ownership (e.g., via
os.chmodor umask) and add minimal framing/size checks on the JSON payload to avoid issues with partial or oversized reads when multiple clients connect.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The CLI adapter currently hardcodes paths like `/AstrBot/data/{config_file}` and `/tmp/astrbot.sock`; consider deriving these from existing config/root-path utilities or allowing them to be overridden so the feature works in non-Docker or non-default deployments.
- The whitelist check unconditionally bypasses validation for the `cli` platform; if you expect some environments to use CLI in a more controlled way, it might be safer to gate this behavior behind a config flag rather than hardcoding the exemption.
- For the Unix socket server, you may want to explicitly control file permissions/ownership (e.g., via `os.chmod` or umask) and add minimal framing/size checks on the JSON payload to avoid issues with partial or oversized reads when multiple clients connect.
## Individual Comments
### Comment 1
<location> `astrbot/core/platform/sources/cli/cli_adapter.py:302` </location>
<code_context>
+ while self._running:
+ try:
+ # 接受连接(非阻塞)
+ loop = asyncio.get_event_loop()
+ client_socket, _ = await loop.sock_accept(server_socket)
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Using asyncio.get_event_loop() in async code is discouraged; asyncio.get_running_loop() is safer and future‑proof.
In Python 3.10+, `asyncio.get_event_loop()` is deprecated in async code and can return the wrong loop under some policies. In `_run_socket_mode` (and similarly in `_handle_socket_client` / `_read_input`), please use `asyncio.get_running_loop()` so you reliably get the loop for the current task.
Suggested implementation:
```python
# 接受连接(非阻塞)
loop = asyncio.get_running_loop()
client_socket, _ = await loop.sock_accept(server_socket)
```
Search in `astrbot/core/platform/sources/cli/cli_adapter.py` for other occurrences of `asyncio.get_event_loop()` used inside `async def` functions, especially in `_handle_socket_client` and `_read_input`, and replace them similarly:
- Change `loop = asyncio.get_event_loop()` to `loop = asyncio.get_running_loop()` where the surrounding context is async code running on the event loop.
If there are any uses of `asyncio.get_event_loop()` in purely synchronous initialization code (outside of an active event loop), those should be reviewed separately, as they may need a different pattern (e.g., explicitly creating a loop with `asyncio.new_event_loop()` rather than `get_running_loop()`).
</issue_to_address>
### Comment 2
<location> `astrbot/core/platform/sources/cli/cli_adapter.py:134` </location>
<code_context>
+ logger.info("[ENTRY] CLIPlatformAdapter.run inputs={}")
+ return self._run_loop()
+
+ async def _run_loop(self) -> None:
+ """主运行循环
+
</code_context>
<issue_to_address>
**issue (complexity):** Consider refactoring the CLI adapter by extracting small helpers, centralizing shared logic, and mapping modes to handlers to make the code easier to follow without changing behavior.
You can reduce complexity meaningfully without changing behavior by extracting a few focused helpers and consolidating duplicated logic. Here are concrete, localized refactors that keep the existing design but make it easier to reason about.
---
### 1. Simplify mode selection with a strategy map
`_run_loop` currently mixes TTY detection and branching logic inline. You can pull the mode resolution into a small helper and map modes to callables:
```python
# in __init__
self._mode_handlers: dict[str, callable[[], Awaitable[None]]] = {
"tty": self._run_tty_mode,
"file": self._run_file_mode,
"socket": self._run_socket_mode,
}
def _resolve_mode(self) -> str:
has_tty = sys.stdin.isatty()
if self.mode == "auto":
return "file" if not has_tty else "tty"
if self.mode in ("tty", "file", "socket"):
if self.mode == "tty" and not has_tty:
logger.warning(
"[PROCESS] TTY mode requested but no TTY detected. "
"CLI platform will not start."
)
return "" # or None
return self.mode
logger.error("[ERROR] Unknown mode: %s", self.mode)
return ""
```
```python
async def _run_loop(self) -> None:
logger.info("[PROCESS] Starting CLI loop")
if self.use_isolated_sessions:
self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())
mode = self._resolve_mode()
if not mode:
return
handler = self._mode_handlers.get(mode)
if handler:
logger.info("[PROCESS] Starting %s mode", mode)
await handler()
```
This flattens the branching and localizes TTY logic without changing behavior.
---
### 2. Extract a small `SessionTracker` helper
`_convert_input` and `_cleanup_expired_sessions` both manipulate `_session_timestamps`. Move this into a tiny helper to decouple message construction from session lifecycle:
```python
class _SessionTracker:
def __init__(self, ttl: int) -> None:
import time
self._ttl = ttl
self._timestamps: dict[str, float] = {}
self._time = time
def ensure_session(self, base_session_id: str, request_id: str | None) -> str:
if request_id is None:
return base_session_id
session_id = f"{base_session_id}_{request_id}"
if session_id not in self._timestamps:
self._timestamps[session_id] = self._time.time()
return session_id
def collect_expired(self) -> list[str]:
now = self._time.time()
expired = [
s for s, ts in list(self._timestamps.items())
if now - ts > self._ttl
]
for s in expired:
self._timestamps.pop(s, None)
return expired
```
Wire it into the adapter:
```python
# __init__
self._session_tracker = _SessionTracker(self.session_ttl)
```
```python
def _convert_input(self, text: str, request_id: str | None = None) -> AstrBotMessage:
...
if self.use_isolated_sessions and request_id:
message.session_id = self._session_tracker.ensure_session(
base_session_id="cli_session",
request_id=request_id,
)
else:
message.session_id = self.session_id
...
```
```python
async def _cleanup_expired_sessions(self) -> None:
logger.info("[ENTRY] _cleanup_expired_sessions started, TTL=%s seconds", self.session_ttl)
while self._running:
try:
await asyncio.sleep(10)
if not self.use_isolated_sessions:
continue
expired_sessions = self._session_tracker.collect_expired()
for session_id in expired_sessions:
logger.info("[PROCESS] Cleaning expired session: %s", session_id)
# TODO: DB cleanup if needed
if expired_sessions:
logger.info("[PROCESS] Cleaned %d expired sessions", len(expired_sessions))
except Exception as e:
logger.error("[ERROR] Session cleanup error: %s", e)
```
This removes direct mutation of `_session_timestamps` from multiple places and keeps session policy in one spot.
---
### 3. Deduplicate image extraction/normalization
You already do image extraction and base64 conversion in `_handle_socket_client`. If `CLIMessageEvent.send` (or similar) has overlapping logic, centralize it in a reusable helper:
```python
from astrbot.core.message.components import Image
def normalize_images_from_chain(message_chain: MessageChain) -> list[dict[str, Any]]:
import base64
images: list[dict[str, Any]] = []
for comp in message_chain.chain:
if not isinstance(comp, Image) or not comp.file:
continue
info: dict[str, Any] = {}
if comp.file.startswith("http"):
info["type"] = "url"
info["url"] = comp.file
elif comp.file.startswith("file:///"):
info["type"] = "file"
file_path = comp.file[8:]
info["path"] = file_path
try:
with open(file_path, "rb") as f:
raw = f.read()
info["base64_data"] = base64.b64encode(raw).decode("utf-8")
info["size"] = len(raw)
except Exception as e:
logger.error("[ERROR] Failed to read image file %s: %s", file_path, e)
info["error"] = str(e)
elif comp.file.startswith("base64://"):
info["type"] = "base64"
base64_data = comp.file[9:]
info["base64_data"] = base64_data
info["base64_length"] = len(base64_data)
images.append(info)
return images
```
Then in `_handle_socket_client`:
```python
from .image_utils import normalize_images_from_chain # or local helper
...
message_chain = await asyncio.wait_for(response_future, timeout=30.0)
response_text = message_chain.get_plain_text()
images = normalize_images_from_chain(message_chain)
response = json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
)
await loop.sock_sendall(client_socket, response.encode("utf-8"))
```
And reuse the same helper wherever you currently replicate that logic in `CLIMessageEvent` or other code paths.
---
### 4. Factor `_handle_socket_client` into smaller steps
Without changing behavior, you can split the method into parsing / processing / serializing helpers to make it testable and shorter:
```python
def _parse_socket_request(self, raw: bytes) -> tuple[dict[str, Any], str, str]:
import json
request = json.loads(raw.decode("utf-8"))
message_text = request.get("message", "")
request_id = request.get("request_id", str(uuid.uuid4()))
return request, message_text, request_id
def _serialize_error(self, request_id: str, msg: str) -> bytes:
import json
return json.dumps(
{"status": "error", "error": msg, "request_id": request_id},
ensure_ascii=False,
).encode("utf-8")
def _serialize_success(self, request_id: str, chain: MessageChain) -> bytes:
import json
response_text = chain.get_plain_text()
images = normalize_images_from_chain(chain)
return json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
).encode("utf-8")
```
Use them in `_handle_socket_client`:
```python
async def _handle_socket_client(self, client_socket) -> None:
import json
logger.debug("[ENTRY] _handle_socket_client")
loop = asyncio.get_event_loop()
try:
data = await loop.sock_recv(client_socket, 4096)
if not data:
logger.debug("[PROCESS] Empty request, closing connection")
return
try:
_, message_text, request_id = self._parse_socket_request(data)
except json.JSONDecodeError:
await loop.sock_sendall(
client_socket, self._serialize_error("", "Invalid JSON format")
)
return
response_future: asyncio.Future[MessageChain] = asyncio.Future()
message = self._convert_input(message_text, request_id=request_id)
message_event = CLIMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
output_queue=self._output_queue,
response_future=response_future,
)
self.commit_event(message_event)
try:
chain = await asyncio.wait_for(response_future, timeout=30.0)
await loop.sock_sendall(
client_socket, self._serialize_success(request_id, chain)
)
except asyncio.TimeoutError:
await loop.sock_sendall(
client_socket, self._serialize_error(request_id, "Request timeout")
)
except Exception as e:
logger.error("[ERROR] Socket client handler error: %s", e)
import traceback
logger.error(traceback.format_exc())
finally:
client_socket.close()
logger.debug("[EXIT] _handle_socket_client return=None")
```
This keeps the same behavior but breaks the long method into smaller, independently testable pieces.
---
These changes preserve the existing feature set and flow, but:
- Isolate session policy from message conversion.
- Reduce branching complexity in mode selection.
- Remove duplication in image handling.
- Shorten and clarify the socket client handler into clear sub-steps.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| logger.info("[ENTRY] CLIPlatformAdapter.run inputs={}") | ||
| return self._run_loop() | ||
|
|
||
| async def _run_loop(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (complexity): 建议通过抽取小的辅助函数、集中复用逻辑以及将模式映射到处理函数的方式来重构 CLI 适配器,使代码在不改变行为的前提下更易理解。
在不改变行为的前提下,你可以通过抽取一些聚焦的小型辅助函数并整合重复逻辑,显著降低复杂度。下面是一些具体、局部化的重构建议,可以保留现有设计,但让代码更易于推理。
1. 使用策略映射简化模式选择
_run_loop 目前在函数体中同时包含 TTY 检测和分支逻辑。你可以将模式解析抽取到一个小的辅助函数中,并将模式映射到可调用对象:
# in __init__
self._mode_handlers: dict[str, callable[[], Awaitable[None]]] = {
"tty": self._run_tty_mode,
"file": self._run_file_mode,
"socket": self._run_socket_mode,
}
def _resolve_mode(self) -> str:
has_tty = sys.stdin.isatty()
if self.mode == "auto":
return "file" if not has_tty else "tty"
if self.mode in ("tty", "file", "socket"):
if self.mode == "tty" and not has_tty:
logger.warning(
"[PROCESS] TTY mode requested but no TTY detected. "
"CLI platform will not start."
)
return "" # or None
return self.mode
logger.error("[ERROR] Unknown mode: %s", self.mode)
return ""async def _run_loop(self) -> None:
logger.info("[PROCESS] Starting CLI loop")
if self.use_isolated_sessions:
self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())
mode = self._resolve_mode()
if not mode:
return
handler = self._mode_handlers.get(mode)
if handler:
logger.info("[PROCESS] Starting %s mode", mode)
await handler()这样可以在不改变行为的前提下,压平分支结构,并把 TTY 逻辑局部化。
2. 抽取一个小型 SessionTracker 辅助类
_convert_input 和 _cleanup_expired_sessions 都会操作 _session_timestamps。将它们移入一个小型辅助类中,可以将消息构建和会话生命周期解耦:
class _SessionTracker:
def __init__(self, ttl: int) -> None:
import time
self._ttl = ttl
self._timestamps: dict[str, float] = {}
self._time = time
def ensure_session(self, base_session_id: str, request_id: str | None) -> str:
if request_id is None:
return base_session_id
session_id = f"{base_session_id}_{request_id}"
if session_id not in self._timestamps:
self._timestamps[session_id] = self._time.time()
return session_id
def collect_expired(self) -> list[str]:
now = self._time.time()
expired = [
s for s, ts in list(self._timestamps.items())
if now - ts > self._ttl
]
for s in expired:
self._timestamps.pop(s, None)
return expired将其接入适配器:
# __init__
self._session_tracker = _SessionTracker(self.session_ttl)def _convert_input(self, text: str, request_id: str | None = None) -> AstrBotMessage:
...
if self.use_isolated_sessions and request_id:
message.session_id = self._session_tracker.ensure_session(
base_session_id="cli_session",
request_id=request_id,
)
else:
message.session_id = self.session_id
...async def _cleanup_expired_sessions(self) -> None:
logger.info("[ENTRY] _cleanup_expired_sessions started, TTL=%s seconds", self.session_ttl)
while self._running:
try:
await asyncio.sleep(10)
if not self.use_isolated_sessions:
continue
expired_sessions = self._session_tracker.collect_expired()
for session_id in expired_sessions:
logger.info("[PROCESS] Cleaning expired session: %s", session_id)
# TODO: DB cleanup if needed
if expired_sessions:
logger.info("[PROCESS] Cleaned %d expired sessions", len(expired_sessions))
except Exception as e:
logger.error("[ERROR] Session cleanup error: %s", e)这样可以避免在多个位置直接修改 _session_timestamps,并将会话策略集中到一个地方。
3. 去重图片提取/归一化逻辑
你已经在 _handle_socket_client 中实现了图片提取和 base64 转换。如果 CLIMessageEvent.send(或类似逻辑)中有重叠代码,可以将其集中到一个可复用的辅助函数中:
from astrbot.core.message.components import Image
def normalize_images_from_chain(message_chain: MessageChain) -> list[dict[str, Any]]:
import base64
images: list[dict[str, Any]] = []
for comp in message_chain.chain:
if not isinstance(comp, Image) or not comp.file:
continue
info: dict[str, Any] = {}
if comp.file.startswith("http"):
info["type"] = "url"
info["url"] = comp.file
elif comp.file.startswith("file:///"):
info["type"] = "file"
file_path = comp.file[8:]
info["path"] = file_path
try:
with open(file_path, "rb") as f:
raw = f.read()
info["base64_data"] = base64.b64encode(raw).decode("utf-8")
info["size"] = len(raw)
except Exception as e:
logger.error("[ERROR] Failed to read image file %s: %s", file_path, e)
info["error"] = str(e)
elif comp.file.startswith("base64://"):
info["type"] = "base64"
base64_data = comp.file[9:]
info["base64_data"] = base64_data
info["base64_length"] = len(base64_data)
images.append(info)
return images在 _handle_socket_client 中使用:
from .image_utils import normalize_images_from_chain # or local helper
...
message_chain = await asyncio.wait_for(response_future, timeout=30.0)
response_text = message_chain.get_plain_text()
images = normalize_images_from_chain(message_chain)
response = json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
)
await loop.sock_sendall(client_socket, response.encode("utf-8"))并在 CLIMessageEvent 或其他复用该逻辑的代码路径中同样复用该辅助函数。
4. 将 _handle_socket_client 拆分为更小的步骤
在不改变行为的前提下,你可以把该方法拆分成解析 / 处理 / 序列化等辅助函数,使其更易测试且更短:
def _parse_socket_request(self, raw: bytes) -> tuple[dict[str, Any], str, str]:
import json
request = json.loads(raw.decode("utf-8"))
message_text = request.get("message", "")
request_id = request.get("request_id", str(uuid.uuid4()))
return request, message_text, request_id
def _serialize_error(self, request_id: str, msg: str) -> bytes:
import json
return json.dumps(
{"status": "error", "error": msg, "request_id": request_id},
ensure_ascii=False,
).encode("utf-8")
def _serialize_success(self, request_id: str, chain: MessageChain) -> bytes:
import json
response_text = chain.get_plain_text()
images = normalize_images_from_chain(chain)
return json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
).encode("utf-8")在 _handle_socket_client 中使用它们:
async def _handle_socket_client(self, client_socket) -> None:
import json
logger.debug("[ENTRY] _handle_socket_client")
loop = asyncio.get_event_loop()
try:
data = await loop.sock_recv(client_socket, 4096)
if not data:
logger.debug("[PROCESS] Empty request, closing connection")
return
try:
_, message_text, request_id = self._parse_socket_request(data)
except json.JSONDecodeError:
await loop.sock_sendall(
client_socket, self._serialize_error("", "Invalid JSON format")
)
return
response_future: asyncio.Future[MessageChain] = asyncio.Future()
message = self._convert_input(message_text, request_id=request_id)
message_event = CLIMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
output_queue=self._output_queue,
response_future=response_future,
)
self.commit_event(message_event)
try:
chain = await asyncio.wait_for(response_future, timeout=30.0)
await loop.sock_sendall(
client_socket, self._serialize_success(request_id, chain)
)
except asyncio.TimeoutError:
await loop.sock_sendall(
client_socket, self._serialize_error(request_id, "Request timeout")
)
except Exception as e:
logger.error("[ERROR] Socket client handler error: %s", e)
import traceback
logger.error(traceback.format_exc())
finally:
client_socket.close()
logger.debug("[EXIT] _handle_socket_client return=None")这样可以在保持现有行为的同时,将这个较长的方法拆分为更小、可单独测试的步骤。
这些改动在保持现有特性和流程的前提下,可以:
- 将会话策略从消息转换逻辑中分离出来;
- 降低模式选择时的分支复杂度;
- 去除图片处理中的重复代码;
- 将套接字客户端处理逻辑拆分得更短、更清晰。
Original comment in English
issue (complexity): Consider refactoring the CLI adapter by extracting small helpers, centralizing shared logic, and mapping modes to handlers to make the code easier to follow without changing behavior.
You can reduce complexity meaningfully without changing behavior by extracting a few focused helpers and consolidating duplicated logic. Here are concrete, localized refactors that keep the existing design but make it easier to reason about.
1. Simplify mode selection with a strategy map
_run_loop currently mixes TTY detection and branching logic inline. You can pull the mode resolution into a small helper and map modes to callables:
# in __init__
self._mode_handlers: dict[str, callable[[], Awaitable[None]]] = {
"tty": self._run_tty_mode,
"file": self._run_file_mode,
"socket": self._run_socket_mode,
}
def _resolve_mode(self) -> str:
has_tty = sys.stdin.isatty()
if self.mode == "auto":
return "file" if not has_tty else "tty"
if self.mode in ("tty", "file", "socket"):
if self.mode == "tty" and not has_tty:
logger.warning(
"[PROCESS] TTY mode requested but no TTY detected. "
"CLI platform will not start."
)
return "" # or None
return self.mode
logger.error("[ERROR] Unknown mode: %s", self.mode)
return ""async def _run_loop(self) -> None:
logger.info("[PROCESS] Starting CLI loop")
if self.use_isolated_sessions:
self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())
mode = self._resolve_mode()
if not mode:
return
handler = self._mode_handlers.get(mode)
if handler:
logger.info("[PROCESS] Starting %s mode", mode)
await handler()This flattens the branching and localizes TTY logic without changing behavior.
2. Extract a small SessionTracker helper
_convert_input and _cleanup_expired_sessions both manipulate _session_timestamps. Move this into a tiny helper to decouple message construction from session lifecycle:
class _SessionTracker:
def __init__(self, ttl: int) -> None:
import time
self._ttl = ttl
self._timestamps: dict[str, float] = {}
self._time = time
def ensure_session(self, base_session_id: str, request_id: str | None) -> str:
if request_id is None:
return base_session_id
session_id = f"{base_session_id}_{request_id}"
if session_id not in self._timestamps:
self._timestamps[session_id] = self._time.time()
return session_id
def collect_expired(self) -> list[str]:
now = self._time.time()
expired = [
s for s, ts in list(self._timestamps.items())
if now - ts > self._ttl
]
for s in expired:
self._timestamps.pop(s, None)
return expiredWire it into the adapter:
# __init__
self._session_tracker = _SessionTracker(self.session_ttl)def _convert_input(self, text: str, request_id: str | None = None) -> AstrBotMessage:
...
if self.use_isolated_sessions and request_id:
message.session_id = self._session_tracker.ensure_session(
base_session_id="cli_session",
request_id=request_id,
)
else:
message.session_id = self.session_id
...async def _cleanup_expired_sessions(self) -> None:
logger.info("[ENTRY] _cleanup_expired_sessions started, TTL=%s seconds", self.session_ttl)
while self._running:
try:
await asyncio.sleep(10)
if not self.use_isolated_sessions:
continue
expired_sessions = self._session_tracker.collect_expired()
for session_id in expired_sessions:
logger.info("[PROCESS] Cleaning expired session: %s", session_id)
# TODO: DB cleanup if needed
if expired_sessions:
logger.info("[PROCESS] Cleaned %d expired sessions", len(expired_sessions))
except Exception as e:
logger.error("[ERROR] Session cleanup error: %s", e)This removes direct mutation of _session_timestamps from multiple places and keeps session policy in one spot.
3. Deduplicate image extraction/normalization
You already do image extraction and base64 conversion in _handle_socket_client. If CLIMessageEvent.send (or similar) has overlapping logic, centralize it in a reusable helper:
from astrbot.core.message.components import Image
def normalize_images_from_chain(message_chain: MessageChain) -> list[dict[str, Any]]:
import base64
images: list[dict[str, Any]] = []
for comp in message_chain.chain:
if not isinstance(comp, Image) or not comp.file:
continue
info: dict[str, Any] = {}
if comp.file.startswith("http"):
info["type"] = "url"
info["url"] = comp.file
elif comp.file.startswith("file:///"):
info["type"] = "file"
file_path = comp.file[8:]
info["path"] = file_path
try:
with open(file_path, "rb") as f:
raw = f.read()
info["base64_data"] = base64.b64encode(raw).decode("utf-8")
info["size"] = len(raw)
except Exception as e:
logger.error("[ERROR] Failed to read image file %s: %s", file_path, e)
info["error"] = str(e)
elif comp.file.startswith("base64://"):
info["type"] = "base64"
base64_data = comp.file[9:]
info["base64_data"] = base64_data
info["base64_length"] = len(base64_data)
images.append(info)
return imagesThen in _handle_socket_client:
from .image_utils import normalize_images_from_chain # or local helper
...
message_chain = await asyncio.wait_for(response_future, timeout=30.0)
response_text = message_chain.get_plain_text()
images = normalize_images_from_chain(message_chain)
response = json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
)
await loop.sock_sendall(client_socket, response.encode("utf-8"))And reuse the same helper wherever you currently replicate that logic in CLIMessageEvent or other code paths.
4. Factor _handle_socket_client into smaller steps
Without changing behavior, you can split the method into parsing / processing / serializing helpers to make it testable and shorter:
def _parse_socket_request(self, raw: bytes) -> tuple[dict[str, Any], str, str]:
import json
request = json.loads(raw.decode("utf-8"))
message_text = request.get("message", "")
request_id = request.get("request_id", str(uuid.uuid4()))
return request, message_text, request_id
def _serialize_error(self, request_id: str, msg: str) -> bytes:
import json
return json.dumps(
{"status": "error", "error": msg, "request_id": request_id},
ensure_ascii=False,
).encode("utf-8")
def _serialize_success(self, request_id: str, chain: MessageChain) -> bytes:
import json
response_text = chain.get_plain_text()
images = normalize_images_from_chain(chain)
return json.dumps(
{
"status": "success",
"response": response_text,
"images": images,
"request_id": request_id,
},
ensure_ascii=False,
).encode("utf-8")Use them in _handle_socket_client:
async def _handle_socket_client(self, client_socket) -> None:
import json
logger.debug("[ENTRY] _handle_socket_client")
loop = asyncio.get_event_loop()
try:
data = await loop.sock_recv(client_socket, 4096)
if not data:
logger.debug("[PROCESS] Empty request, closing connection")
return
try:
_, message_text, request_id = self._parse_socket_request(data)
except json.JSONDecodeError:
await loop.sock_sendall(
client_socket, self._serialize_error("", "Invalid JSON format")
)
return
response_future: asyncio.Future[MessageChain] = asyncio.Future()
message = self._convert_input(message_text, request_id=request_id)
message_event = CLIMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
output_queue=self._output_queue,
response_future=response_future,
)
self.commit_event(message_event)
try:
chain = await asyncio.wait_for(response_future, timeout=30.0)
await loop.sock_sendall(
client_socket, self._serialize_success(request_id, chain)
)
except asyncio.TimeoutError:
await loop.sock_sendall(
client_socket, self._serialize_error(request_id, "Request timeout")
)
except Exception as e:
logger.error("[ERROR] Socket client handler error: %s", e)
import traceback
logger.error(traceback.format_exc())
finally:
client_socket.close()
logger.debug("[EXIT] _handle_socket_client return=None")This keeps the same behavior but breaks the long method into smaller, independently testable pieces.
These changes preserve the existing feature set and flow, but:
- Isolate session policy from message conversion.
- Reduce branching complexity in mode selection.
- Remove duplication in image handling.
- Shorten and clarify the socket client handler into clear sub-steps.
- Implement automatic token generation on first startup - Add token validation in socket handler - Add token loading and sending in CLI client - Set token file permissions to 600 - Add security logging for rejected requests Token file: /AstrBot/data/.cli_token Algorithm: secrets.token_urlsafe(32)
- Replace hardcoded paths with dynamic path resolution - Use get_astrbot_data_path() for data directory - Use get_astrbot_temp_path() for temp directory - Support ASTRBOT_ROOT environment variable - Fix asyncio deprecation warnings - Replace asyncio.get_event_loop() with get_running_loop() - Improve Python 3.10+ compatibility - Add socket file permission control - Set socket permissions to 600 (owner-only) - Add security logging - Update astrbot-cli client - Add dynamic path resolution functions - Match server-side path logic - Improve cross-environment compatibility Addresses code review feedback from PR AstrBotDevs#4787
- Fix client to receive large responses by using loop recv instead of single recv(4096) - Save base64 images to temp files instead of exposing in JSON response - Implement adaptive delay mechanism for multi-round reply collection: * First send: 5s delay (fast response for simple text) * Subsequent sends: 10s delay (auto-switch for tool invocation) - Support tool invocation scenarios with multiple replies (text + images) - Add detailed logging for debugging multi-round reply collection
- Add repository check to dashboard_ci.yml - Only create release in main repository (AstrBotDevs/AstrBot) - Prevents token error in forked repositories
- Fix invalid parameter 'fetch-tag' in docker-image.yml - Should be 'fetch-tags' (plural) according to actions/checkout@v6 docs - Resolves 'Unexpected input' warning in GitHub Actions
Motivation / 动机
构建快速反馈循环,支持 Vibe Coding 开发模式
传统插件开发流程需要:编写代码 → 重启机器人 → 登录QQ → 发送测试消息 → 查看结果 → 修改代码...
这个流程存在的问题:
CLI Tester 解决方案:编写代码 → 命令行测试 → 立即查看结果 → 修改代码...
核心价值:
Modifications / 改动点
新增文件:
astrbot/core/platform/sources/cli/cli_adapter.py(27KB) - CLI测试器核心适配器astrbot/core/platform/sources/cli/cli_event.py(4.2KB) - CLI事件处理astrbot/core/platform/sources/cli/__init__.py(255B) - 模块导出astrbot-cli(3.9KB) - 命令行客户端工具修改文件:
astrbot/core/platform/manager.py- 添加CLI平台导入分支astrbot/core/pipeline/whitelist_check/stage.py- 添加CLI平台白名单豁免核心功能:
/tmp/astrbot.sock提供同步请求-响应技术特点:
📦 零外部依赖:仅使用Python标准库
💾 轻量级:总计约35KB
🔒 默认关闭:开发时手动启用,不影响生产环境
🏗️ 遵循Unix哲学:原子化模块、显式I/O、管道编排
This is NOT a breaking change. / 这不是一个破坏性变更。
Usage / 使用方法
1. 启用CLI Tester
在管理面板中启用CLI平台,或修改配置文件:
{ "type": "cli", "enable": true, "mode": "socket", "socket_path": "/tmp/astrbot.sock" }2. 基础测试
3. 创建全局命令(可选)
Screenshots or Test Results / 运行截图或测试结果
测试1:基础消息测试
测试2:插件命令测试
测试3:富媒体支持(图片)
测试4:配置文件加载
日志显示配置文件成功加载:
测试5:会话隔离与并发
Checklist / 检查清单
requirements.txt和pyproject.toml文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations inrequirements.txtandpyproject.toml.Additional Notes / 补充说明
设计理念:
CLI Tester 的设计灵感来自于现代开发工具的"快速反馈循环"理念。通过消除登录IM平台这一步骤,开发者可以:
这种工作流程特别适合:
未来扩展:
相关文档:
CLI_README.mdCLAUDE.mdSummary by Sourcery
引入一个基于 CLI 的平台适配器,用于在不连接到任何即时通讯(IM)平台的前提下,对 AstrBot 插件进行快速本地测试和调试。
新功能:
astrbot-cli),用于向 AstrBot 发送消息并接收结构化响应,包括在套接字接口上对富媒体负载的支持。增强改进:
Original summary in English
Summary by Sourcery
Introduce a CLI-based platform adapter to enable fast local testing and debugging of AstrBot plugins without connecting to an IM platform.
New Features:
Enhancements: