From 9b9c21867371620a55f403c926595134c08d1c7f Mon Sep 17 00:00:00 2001 From: Sian Cao Date: Sun, 22 Mar 2026 17:16:48 +0800 Subject: [PATCH] Refactor tool preflight protocol --- src/aish/llm.py | 335 +++++++++------------ src/aish/shell_enhanced/shell_prompt_io.py | 56 ++-- src/aish/tools/base.py | 134 ++++++++- src/aish/tools/code_exec.py | 112 ++++++- src/aish/tools/fs_tools.py | 48 ++- tests/test_ask_user_tool.py | 251 +++++++++++++-- tests/test_shell_prompt_io.py | 37 ++- 7 files changed, 707 insertions(+), 266 deletions(-) diff --git a/src/aish/llm.py b/src/aish/llm.py index 8b23284..dfc6e40 100644 --- a/src/aish/llm.py +++ b/src/aish/llm.py @@ -1,10 +1,12 @@ import json import logging +import os import threading import time import uuid from dataclasses import dataclass from enum import Enum +from pathlib import Path from typing import Callable, Optional import anyio @@ -14,13 +16,18 @@ from aish.config import ConfigModel from aish.context_manager import ContextManager, MemoryType from aish.exception import is_litellm_exception, redact_secrets -from aish.i18n import t from aish.interruption import ShellState from aish.litellm_loader import load_litellm from aish.providers.registry import get_provider_for_model from aish.prompts import PromptManager from aish.skills import SkillManager -from aish.tools.base import ToolBase +from aish.tools.base import ( + ToolBase, + ToolExecutionContext, + ToolPanelSpec, + ToolPreflightAction, + ToolPreflightResult, +) from aish.tools.code_exec import BashTool, PythonTool from aish.tools.fs_tools import EditFileTool, ReadFileTool, WriteFileTool from aish.tools.result import ToolResult @@ -83,6 +90,19 @@ class LLMEvent: metadata: Optional[dict] = None +class ToolDispatchStatus(Enum): + EXECUTED = "executed" + SHORT_CIRCUIT = "short_circuit" + REJECTED = "rejected" + CANCELLED = "cancelled" + + +@dataclass +class ToolDispatchOutcome: + status: ToolDispatchStatus + result: ToolResult + + def normalize_tool_result(value: object) -> ToolResult: if isinstance(value, ToolResult): return value @@ -851,203 +871,128 @@ async def execute_tool( ) return tool_result + def _build_tool_panel_event_data( + self, + *, + tool: ToolBase, + tool_name: str, + tool_args: dict, + panel: ToolPanelSpec, + ) -> dict: + panel_payload = panel.to_event_payload() + data: dict = { + "tool_name": tool_name, + "tool_args": tool_args, + "description": tool.description, + "panel": panel_payload, + # Temporary top-level mirror for transition/debugging. + "panel_mode": panel.mode, + } + for key in ( + "target", + "preview", + "analysis", + "allow_remember", + "remember_key", + "title", + ): + if key in panel_payload: + data[key] = panel_payload[key] + return data + async def pre_execute_tool( self, tool_name: str, tool_args: dict - ) -> tuple[LLMCallbackResult, ToolResult]: + ) -> ToolDispatchOutcome: try: if tool_name not in self.tools: - return ( - LLMCallbackResult.CONTINUE, - ToolResult( + return ToolDispatchOutcome( + status=ToolDispatchStatus.REJECTED, + result=ToolResult( ok=False, output=f"Error: Invalid tool name: {tool_name}", ), ) tool = self.tools[tool_name] - - # For BashTool, pass the code to check if confirmation is needed - if tool_name == "bash_exec": - arg = tool_args.get("code") - elif tool_name == "write_file": - arg = tool_args.get("content") - elif tool_name == "edit_file": - arg = tool_args - else: - arg = None - - # 使用 to_thread.run 来避免阻塞事件循环 - # 这样信号处理器可以在沙箱评估期间运行 - need_confirm = await anyio.to_thread.run_sync( - tool.need_confirm_before_exec, arg + context = ToolExecutionContext( + cwd=Path(os.getcwd()).resolve(), + cancellation_token=self.cancellation_token, + interruption_manager=self.interruption_manager, + is_approved=self.is_command_approved, ) + preflight = await anyio.to_thread.run_sync( + tool.prepare_invocation, tool_args, context + ) + if not isinstance(preflight, ToolPreflightResult): + preflight = ToolPreflightResult() - confirmation_info: dict = {} - security_analysis: dict = {} - security_decision: dict = {} - suppress_security_panels = False - - if tool_name == "bash_exec" and arg is not None: - try: - confirmation_info = tool.get_confirmation_info(arg) # type: ignore[assignment] - security_analysis = ( - confirmation_info.get("security_analysis", {}) - if isinstance(confirmation_info, dict) - else {} - ) - security_decision = ( - confirmation_info.get("security_decision", {}) - if isinstance(confirmation_info, dict) - else {} - ) - except Exception: - confirmation_info = {} - security_analysis = {} - security_decision = {} - - # 1) fail-open when sandbox is unavailable (cannot assess risks) - if ( - isinstance(security_analysis, dict) - and security_analysis.get("fail_open") is True - ): - need_confirm = False - suppress_security_panels = True - - # 2) exact-match allowlist - if callable(self.is_command_approved) and self.is_command_approved(arg): - need_confirm = False - suppress_security_panels = True - - if ( - isinstance(security_decision, dict) - and security_decision.get("allow") is False - ): - need_confirm = False - - # 在沙箱评估后检查是否被取消 - # 如果被取消,直接返回 CANCEL,不继续执行 if self.cancellation_token and self.cancellation_token.is_cancelled(): if self.interruption_manager: self.interruption_manager.set_state(ShellState.NORMAL) - return ( - LLMCallbackResult.CANCEL, - ToolResult( + return ToolDispatchOutcome( + status=ToolDispatchStatus.CANCELLED, + result=ToolResult( ok=False, output="Operation cancelled during security evaluation", ), ) - # 安全提示面板:对 AI 生成命令的风险评估结果进行展示。 - # - MEDIUM 且需要确认:走确认面板 + y/n - # - LOW:展示 info 面板,不阻塞 - # - HIGH:展示 blocked 面板,不出现确认;仍由工具自身/安全系统直接阻断 - if ( - tool_name == "bash_exec" - and arg is not None - and not suppress_security_panels - ): - try: - info = confirmation_info or tool.get_confirmation_info(arg) - security_analysis = ( - info.get("security_analysis", {}) - if isinstance(info, dict) - else {} - ) - sandbox_info = ( - security_analysis.get("sandbox", {}) - if isinstance(security_analysis, dict) - else {} - ) - sandbox_reason = str( - sandbox_info.get("reason", "") - if isinstance(sandbox_info, dict) - else "" - ) - fallback_rule_matched = bool( - security_analysis.get("fallback_rule_matched") - ) if isinstance(security_analysis, dict) else False - skip_confirmation_panel = sandbox_reason in { - "sandbox_disabled", - "sandbox_disabled_by_policy", - } and not fallback_rule_matched - risk_level = str( - security_analysis.get("risk_level", "UNKNOWN") - ).upper() - panel_mode = ( - "confirm" - if need_confirm - else ( - "blocked" if risk_level in {"HIGH", "CRITICAL"} else "info" - ) - ) - - if panel_mode != "confirm" and not skip_confirmation_panel: - notice_data = { - "tool_name": tool_name, - "tool_args": tool_args, - "description": tool.description, - "panel_mode": panel_mode, - **(info if isinstance(info, dict) else {}), - } - self.emit_event( - LLMEventType.TOOL_CONFIRMATION_REQUIRED, notice_data - ) - except Exception: - # 不影响工具执行;提示面板失败时静默跳过 - pass - - if ( - tool_name == "bash_exec" - and isinstance(security_decision, dict) - and security_decision.get("allow") is False - ): - reasons = ( - security_analysis.get("reasons", []) - if isinstance(security_analysis, dict) - else [] - ) - reason_text = ", ".join(str(reason) for reason in reasons[:5] if reason) - blocked_msg = ( - t("security.command_blocked_with_reason", reason=reason_text) - if reason_text - else t("security.command_blocked") + panel = preflight.panel + if panel is not None and panel.mode == "info": + self.emit_event( + LLMEventType.TOOL_CONFIRMATION_REQUIRED, + self._build_tool_panel_event_data( + tool=tool, + tool_name=tool_name, + tool_args=tool_args, + panel=panel, + ), ) - return ( - LLMCallbackResult.CONTINUE, - ToolResult( + + if preflight.action == ToolPreflightAction.SHORT_CIRCUIT: + if panel is not None and panel.mode == "blocked": + self.emit_event( + LLMEventType.TOOL_CONFIRMATION_REQUIRED, + self._build_tool_panel_event_data( + tool=tool, + tool_name=tool_name, + tool_args=tool_args, + panel=panel, + ), + ) + return ToolDispatchOutcome( + status=ToolDispatchStatus.SHORT_CIRCUIT, + result=preflight.result + or ToolResult( ok=False, - output=blocked_msg, - code=126, - meta={"kind": "security_blocked", "reasons": reasons}, + output=f"Tool {tool_name} short-circuited without a result", ), ) - if need_confirm: - # Prepare confirmation data with security information - confirmation_data = { - "tool_name": tool_name, - "tool_args": tool_args, - "description": tool.description, - "panel_mode": "confirm", - **(confirmation_info or tool.get_confirmation_info(arg)), - } - - # Request user confirmation + if preflight.action == ToolPreflightAction.CONFIRM: + confirm_panel = panel or ToolPanelSpec(mode="confirm") goon = self.request_confirmation( LLMEventType.TOOL_CONFIRMATION_REQUIRED, - confirmation_data, + self._build_tool_panel_event_data( + tool=tool, + tool_name=tool_name, + tool_args=tool_args, + panel=confirm_panel, + ), timeout_seconds=30.0, default_on_timeout=LLMCallbackResult.DENY, ) - # Handle confirmation result if goon == LLMCallbackResult.APPROVE: - return goon, await self.execute_tool(tool, tool_name, tool_args) + return ToolDispatchOutcome( + status=ToolDispatchStatus.EXECUTED, + result=await self.execute_tool(tool, tool_name, tool_args), + ) - elif goon == LLMCallbackResult.DENY: - return ( - goon, - ToolResult( + if goon == LLMCallbackResult.DENY: + return ToolDispatchOutcome( + status=ToolDispatchStatus.REJECTED, + result=ToolResult( ok=False, output=( f"Tool {tool_name} execution denied by user, you may " @@ -1056,45 +1001,41 @@ async def pre_execute_tool( ), ) - elif goon == LLMCallbackResult.CANCEL: - return ( - goon, - ToolResult( + if goon == LLMCallbackResult.CANCEL: + return ToolDispatchOutcome( + status=ToolDispatchStatus.CANCELLED, + result=ToolResult( ok=False, output=f"Tool {tool_name} execution cancelled by user", ), ) - else: - return ( - goon, - ToolResult( - ok=False, - output=f"Invalid confirmation result: {goon}", - ), - ) - else: - # No confirmation needed, execute directly - return ( - LLMCallbackResult.APPROVE, - await self.execute_tool(tool, tool_name, tool_args), + return ToolDispatchOutcome( + status=ToolDispatchStatus.REJECTED, + result=ToolResult( + ok=False, + output=f"Invalid confirmation result: {goon}", + ), ) + + return ToolDispatchOutcome( + status=ToolDispatchStatus.EXECUTED, + result=await self.execute_tool(tool, tool_name, tool_args), + ) except KeyboardInterrupt: - # 用户中断(例如在沙箱评估期间按 Ctrl+C) - # 恢复状态并返回取消结果 if self.interruption_manager: self.interruption_manager.set_state(ShellState.NORMAL) - return ( - LLMCallbackResult.CANCEL, - ToolResult( + return ToolDispatchOutcome( + status=ToolDispatchStatus.CANCELLED, + result=ToolResult( ok=False, output="Operation cancelled by user", ), ) except Exception as e: - return ( - LLMCallbackResult.CONTINUE, - ToolResult( + return ToolDispatchOutcome( + status=ToolDispatchStatus.REJECTED, + result=ToolResult( ok=False, output=str(e), meta={"exception_type": type(e).__name__}, @@ -1133,7 +1074,7 @@ def _build_skills_reminder_message(self) -> Optional[dict]: return { "role": "user", "content": ( - "\n" f"{skills_reminder_text}\n" "" + f"\n{skills_reminder_text}\n" ), } @@ -1205,8 +1146,12 @@ async def _handle_tool_calls( # TODO: For malformed/truncated tool arguments, add a model-side retry flow. tool_args = json.loads(tool_call["function"]["arguments"]) - goon, tool_result = await self.pre_execute_tool(tool_name, tool_args) - if goon == LLMCallbackResult.APPROVE: + dispatch = await self.pre_execute_tool(tool_name, tool_args) + tool_result = dispatch.result + if dispatch.status in { + ToolDispatchStatus.EXECUTED, + ToolDispatchStatus.SHORT_CIRCUIT, + }: rendered_result = tool_result.render_for_llm() tool_msg = { "role": "tool", @@ -1239,7 +1184,7 @@ async def _handle_tool_calls( } context_manager.add_memory(MemoryType.LLM, error_msg) - if goon == LLMCallbackResult.CANCEL: + if dispatch.status == ToolDispatchStatus.CANCELLED: # 触发 CANCELLED 事件,让 shell 能够显示取消消息 self.emit_event( LLMEventType.CANCELLED, {"reason": "tool_cancelled"} diff --git a/src/aish/shell_enhanced/shell_prompt_io.py b/src/aish/shell_enhanced/shell_prompt_io.py index 9549c14..c73cb7b 100644 --- a/src/aish/shell_enhanced/shell_prompt_io.py +++ b/src/aish/shell_enhanced/shell_prompt_io.py @@ -369,16 +369,16 @@ def handle_tool_confirmation_required(shell: Any, event: LLMEvent) -> LLMCallbac self._finalize_content_preview() data = event.data - panel_mode = str(data.get("panel_mode", "confirm")).lower() + panel = data.get("panel") if isinstance(data.get("panel"), dict) else {} + panel_mode = str(panel.get("mode") or data.get("panel_mode", "confirm")).lower() # Display confirmation/security notice using rich formatting self._display_security_panel(data, panel_mode=panel_mode) # Only "confirm" requires interactive user input if panel_mode == "confirm": - tool_name = str(data.get("tool_name", "")) - remember_command = data.get("command") - allow_remember = bool(remember_command) and tool_name == "bash_exec" + remember_command = panel.get("remember_key", data.get("remember_key")) + allow_remember = bool(panel.get("allow_remember", data.get("allow_remember"))) return self._get_user_confirmation( remember_command=remember_command, allow_remember=allow_remember, @@ -926,12 +926,21 @@ def get_content(): def display_security_panel(shell: Any, data: dict, panel_mode: str = "confirm") -> None: """Display rich security panel for AI tool calls.""" self = shell - panel_mode = str(panel_mode).lower() + panel = data.get("panel") if isinstance(data.get("panel"), dict) else {} + panel_mode = str(panel.get("mode") or panel_mode).lower() is_blocked = panel_mode == "blocked" is_info = panel_mode == "info" tool_name = str(data.get("tool_name", "unknown")) - security_analysis = data.get("security_analysis", {}) + security_analysis = panel.get("analysis") + if not isinstance(security_analysis, dict): + security_analysis = ( + data.get("analysis") + if isinstance(data.get("analysis"), dict) + else data.get("security_analysis", {}) + ) + target = panel.get("target", data.get("target")) + preview = panel.get("preview", data.get("preview")) def _sandbox_reason_value(analysis: object) -> str: if not isinstance(analysis, dict): @@ -997,6 +1006,8 @@ def _risk_level_value(analysis: object) -> str: content.append( f"[bold]{t('shell.security.label.command')}:[/bold] {data['command']}" ) + elif target: + content.append(f"[bold]{t('shell.security.label.target')}:[/bold] {target}") # Fallback hint: sandbox failed to assess the command, so we cannot determine risk. # In this case we ask users to confirm before executing the real command. @@ -1019,32 +1030,17 @@ def _risk_level_value(analysis: object) -> str: f"[bold]{t('shell.security.label.fallback_hint')}:[/bold] {hint}" ) - # For non-bash tools (e.g. write_file), tool-specific confirmation info is carried - # in generic fields like tool_args/content_preview/content_length. - tool_args = data.get("tool_args") - if isinstance(tool_args, dict): - file_path = tool_args.get("file_path") or tool_args.get("path") - if file_path: - content.append( - f"[bold]{t('shell.security.label.target')}:[/bold] {file_path}" - ) - - if "content" in tool_args: + if preview is None: + tool_args = data.get("tool_args") + if isinstance(tool_args, dict) and "content" in tool_args: raw_content = tool_args.get("content") + if isinstance(raw_content, str): + preview = raw_content[:100] + "..." if len(raw_content) > 100 else raw_content - # Prefer the tool-provided preview; otherwise derive a safe preview. - content_preview = data.get("content_preview") - if tool_name == "write_file" and isinstance(raw_content, str): - content_preview = raw_content - elif content_preview is None and isinstance(raw_content, str): - content_preview = ( - raw_content[:100] + "..." if len(raw_content) > 100 else raw_content - ) - - if content_preview is not None: - content.append( - f"[bold]{t('shell.security.label.content_preview')}:[/bold] {content_preview}" - ) + if preview is not None: + content.append( + f"[bold]{t('shell.security.label.content_preview')}:[/bold] {preview}" + ) if security_analysis and (sandbox_enabled or fallback_rule_matched): is_low_risk = risk_level_upper == "LOW" diff --git a/src/aish/tools/base.py b/src/aish/tools/base.py index ee1332b..2f00533 100644 --- a/src/aish/tools/base.py +++ b/src/aish/tools/base.py @@ -1,13 +1,65 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Awaitable +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Awaitable, Callable from pydantic import BaseModel, ConfigDict from aish.tools.result import ToolResult +class ToolPreflightAction(str, Enum): + EXECUTE = "execute" + CONFIRM = "confirm" + SHORT_CIRCUIT = "short_circuit" + + +@dataclass +class ToolExecutionContext: + cwd: Path + cancellation_token: Any | None = None + interruption_manager: Any | None = None + is_approved: Callable[[str], bool] | None = None + + +@dataclass +class ToolPanelSpec: + mode: str = "confirm" + target: str | None = None + preview: str | None = None + analysis: dict[str, Any] = field(default_factory=dict) + allow_remember: bool = False + remember_key: str | None = None + title: str | None = None + + def to_event_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "mode": self.mode, + "allow_remember": self.allow_remember, + } + if self.target is not None: + payload["target"] = self.target + if self.preview is not None: + payload["preview"] = self.preview + if self.analysis: + payload["analysis"] = self.analysis + if self.remember_key is not None: + payload["remember_key"] = self.remember_key + if self.title is not None: + payload["title"] = self.title + return payload + + +@dataclass +class ToolPreflightResult: + action: ToolPreflightAction = ToolPreflightAction.EXECUTE + panel: ToolPanelSpec | None = None + result: ToolResult | None = None + + class ToolBase(BaseModel): model_config = ConfigDict(extra="allow") @@ -32,10 +84,90 @@ def get_confirmation_info(self, *args, **kwargs) -> dict: """Get additional information for confirmation dialog""" return {} + def get_pre_execute_subject(self, tool_args: dict[str, Any]) -> Any: + """Legacy adapter hook for tools that only inspect part of tool_args.""" + return tool_args + + def prepare_invocation( + self, tool_args: dict[str, Any], context: ToolExecutionContext + ) -> ToolPreflightResult: + """Prepare a tool invocation before execution. + + New tools should override this method directly. The default implementation + adapts the legacy confirmation hooks for backward compatibility. + """ + + _ = context + subject = self.get_pre_execute_subject(tool_args) + need_confirm = self.need_confirm_before_exec(subject) + if not need_confirm: + return ToolPreflightResult(action=ToolPreflightAction.EXECUTE) + + info = self.get_confirmation_info(subject) + return ToolPreflightResult( + action=ToolPreflightAction.CONFIRM, + panel=self._build_panel_from_legacy(tool_args, info), + ) + def get_session_output(self, result: ToolResult) -> str | None: """Optionally expose a tool result as the session's fallback output.""" return None + def _build_panel_from_legacy( + self, tool_args: dict[str, Any], info: object + ) -> ToolPanelSpec: + info_dict = info if isinstance(info, dict) else {} + target: str | None = None + preview: str | None = None + analysis: dict[str, Any] = {} + remember_key: str | None = None + title: str | None = None + + if isinstance(tool_args, dict): + raw_target = tool_args.get("file_path") or tool_args.get("path") + if raw_target is not None: + target = str(raw_target) + + if isinstance(info_dict, dict): + raw_target = info_dict.get("target") + if raw_target is not None: + target = str(raw_target) + + raw_preview = info_dict.get("preview") + if raw_preview is None: + raw_preview = info_dict.get("content_preview") + if raw_preview is not None: + preview = str(raw_preview) + + raw_analysis = info_dict.get("analysis") + if isinstance(raw_analysis, dict): + analysis = raw_analysis + elif isinstance(info_dict.get("security_analysis"), dict): + analysis = info_dict["security_analysis"] + + raw_remember_key = info_dict.get("remember_key") + if raw_remember_key is None: + raw_remember_key = info_dict.get("command") + if raw_remember_key is not None: + remember_key = str(raw_remember_key) + + raw_title = info_dict.get("title") + if raw_title is not None: + title = str(raw_title) + + mode = str(info_dict.get("panel_mode", "confirm")) + allow_remember = bool(info_dict.get("allow_remember", False)) + + return ToolPanelSpec( + mode=mode, + target=target, + preview=preview, + analysis=analysis, + allow_remember=allow_remember, + remember_key=remember_key, + title=title, + ) + @abstractmethod def __call__( self, *args: Any, **kwargs: Any diff --git a/src/aish/tools/code_exec.py b/src/aish/tools/code_exec.py index 8f44f65..7658166 100644 --- a/src/aish/tools/code_exec.py +++ b/src/aish/tools/code_exec.py @@ -12,7 +12,8 @@ from aish.offload import render_bash_output from aish.security.security_manager import (SecurityDecision, SimpleSecurityManager) -from aish.tools.base import ToolBase +from aish.tools.base import (ToolBase, ToolExecutionContext, ToolPanelSpec, + ToolPreflightAction, ToolPreflightResult) from aish.tools.bash_executor import UnifiedBashExecutor from aish.tools.result import ToolResult @@ -228,6 +229,115 @@ def need_confirm_before_exec(self, code: Optional[str] = None) -> bool: return False return bool(self._last_decision.require_confirmation) + def prepare_invocation( + self, tool_args: dict[str, object], context: ToolExecutionContext + ) -> ToolPreflightResult: + command = tool_args.get("code") + if not isinstance(command, str) or not command: + return ToolPreflightResult(action=ToolPreflightAction.EXECUTE) + + interruption_manager = context.interruption_manager or self.interruption_manager + if interruption_manager: + interruption_manager.set_state(ShellState.SANDBOX_EVAL) + + decision = self.security_manager.decide( + command, + is_ai_command=True, + cwd=context.cwd, + ) + self._last_decision = decision + + if interruption_manager: + interruption_manager.set_state(ShellState.NORMAL) + + info = self.get_confirmation_info(command) + analysis_data = ( + info.get("security_analysis", {}) if isinstance(info, dict) else {} + ) + panel = ToolPanelSpec( + mode="confirm", + target=command, + analysis=analysis_data if isinstance(analysis_data, dict) else {}, + allow_remember=True, + remember_key=command, + ) + + if ( + isinstance(analysis_data, dict) + and analysis_data.get("fail_open") is True + ): + return ToolPreflightResult(action=ToolPreflightAction.EXECUTE) + + if callable(context.is_approved) and context.is_approved(command): + return ToolPreflightResult(action=ToolPreflightAction.EXECUTE) + + if not decision.allow: + reasons = ( + analysis_data.get("reasons", []) + if isinstance(analysis_data, dict) + else [] + ) + reason_text = ", ".join(str(reason) for reason in reasons[:5] if reason) + blocked_msg = ( + t("security.command_blocked_with_reason", reason=reason_text) + if reason_text + else t("security.command_blocked") + ) + return ToolPreflightResult( + action=ToolPreflightAction.SHORT_CIRCUIT, + panel=ToolPanelSpec( + mode="blocked", + target=command, + analysis=panel.analysis, + allow_remember=True, + remember_key=command, + ), + result=ToolResult( + ok=False, + output=blocked_msg, + code=126, + meta={"kind": "security_blocked", "reasons": reasons}, + stop_tool_chain=True, + ), + ) + + sandbox_info = ( + analysis_data.get("sandbox", {}) if isinstance(analysis_data, dict) else {} + ) + sandbox_reason = ( + str(sandbox_info.get("reason", "")) if isinstance(sandbox_info, dict) else "" + ) + fallback_rule_matched = bool( + analysis_data.get("fallback_rule_matched") + if isinstance(analysis_data, dict) + else False + ) + skip_notice_panel = sandbox_reason in { + "sandbox_disabled", + "sandbox_disabled_by_policy", + } and not fallback_rule_matched + + if decision.require_confirmation: + panel.mode = "confirm" + return ToolPreflightResult( + action=ToolPreflightAction.CONFIRM, + panel=panel, + ) + + risk_level = str( + analysis_data.get("risk_level", "UNKNOWN") + if isinstance(analysis_data, dict) + else "UNKNOWN" + ).upper() + panel.mode = "blocked" if risk_level in {"HIGH", "CRITICAL"} else "info" + if skip_notice_panel: + return ToolPreflightResult(action=ToolPreflightAction.EXECUTE) + + return ToolPreflightResult( + action=ToolPreflightAction.EXECUTE, + panel=panel, + ) + def get_confirmation_info(self, code: Optional[str] = None) -> dict: """Get security information for confirmation dialog""" command = code diff --git a/src/aish/tools/fs_tools.py b/src/aish/tools/fs_tools.py index 664f107..d946f28 100644 --- a/src/aish/tools/fs_tools.py +++ b/src/aish/tools/fs_tools.py @@ -1,10 +1,16 @@ from pathlib import Path from typing import ClassVar -from aish.tools.base import ToolBase +from aish.tools.base import (ToolBase, ToolExecutionContext, ToolPanelSpec, + ToolPreflightAction, ToolPreflightResult) from aish.tools.result import ToolResult +def _preview_text(value: object, limit: int = 100) -> str: + text = str(value) if value is not None else "" + return text[:limit] + "..." if len(text) > limit else text + + # TODO: support images class ReadFileTool(ToolBase): MAX_READ_BYTES: ClassVar[int] = 32 * 1024 @@ -188,6 +194,21 @@ def __call__(self, file_path: str, content: str) -> ToolResult: def need_confirm_before_exec(self, content: str) -> bool: return True + def prepare_invocation( + self, tool_args: dict[str, object], context: ToolExecutionContext + ) -> ToolPreflightResult: + _ = context + file_path = tool_args.get("file_path") + content = tool_args.get("content", "") + return ToolPreflightResult( + action=ToolPreflightAction.CONFIRM, + panel=ToolPanelSpec( + mode="confirm", + target=str(file_path) if file_path is not None else None, + preview=content if isinstance(content, str) else _preview_text(content), + ), + ) + def get_confirmation_info(self, content: str) -> dict: # For write_file tool, we need to get the file_path from the tool_args # Since we only receive the content string here, we'll return what we can @@ -346,18 +367,33 @@ def __call__( def need_confirm_before_exec(self, tool_args: dict | None = None) -> bool: return True + def prepare_invocation( + self, tool_args: dict[str, object], context: ToolExecutionContext + ) -> ToolPreflightResult: + _ = context + replace_all = bool(tool_args.get("replace_all", False)) + old_string = tool_args.get("old_string", "") + new_string = tool_args.get("new_string", "") + mode = "Replace all" if replace_all else "Replace" + return ToolPreflightResult( + action=ToolPreflightAction.CONFIRM, + panel=ToolPanelSpec( + mode="confirm", + target=str(tool_args.get("file_path")) + if tool_args.get("file_path") is not None + else None, + preview=f"{mode}: {_preview_text(old_string)} -> {_preview_text(new_string)}", + ), + ) + def get_confirmation_info(self, tool_args: dict | None = None) -> dict: if not isinstance(tool_args, dict): return {} - def _preview(value: object, limit: int = 100) -> str: - text = str(value) if value is not None else "" - return text[:limit] + "..." if len(text) > limit else text - replace_all = bool(tool_args.get("replace_all", False)) old_string = tool_args.get("old_string", "") new_string = tool_args.get("new_string", "") mode = "Replace all" if replace_all else "Replace" return { - "content_preview": f"{mode}: {_preview(old_string)} -> {_preview(new_string)}" + "content_preview": f"{mode}: {_preview_text(old_string)} -> {_preview_text(new_string)}" } diff --git a/tests/test_ask_user_tool.py b/tests/test_ask_user_tool.py index 3eb7e09..41b389d 100644 --- a/tests/test_ask_user_tool.py +++ b/tests/test_ask_user_tool.py @@ -5,9 +5,12 @@ from aish.config import ConfigModel from aish.context_manager import ContextManager -from aish.llm import LLMCallbackResult, LLMSession +from aish.llm import (LLMCallbackResult, LLMSession, ToolDispatchOutcome, + ToolDispatchStatus) from aish.skills import SkillManager from aish.tools.ask_user import AskUserTool +from aish.tools.base import (ToolBase, ToolExecutionContext, ToolPanelSpec, + ToolPreflightAction, ToolPreflightResult) from aish.tools.result import ToolResult @@ -126,9 +129,9 @@ async def test_handle_tool_calls_ask_user_user_input_required_breaks(monkeypatch async def fake_pre_execute_tool(tool_name, _tool_args): if tool_name == "ask_user": - return ( - LLMCallbackResult.APPROVE, - ToolResult( + return ToolDispatchOutcome( + status=ToolDispatchStatus.EXECUTED, + result=ToolResult( ok=False, output="paused", meta={"kind": "user_input_required", "reason": "cancelled"}, @@ -170,9 +173,9 @@ async def test_handle_tool_calls_system_diagnose_agent_sets_session_output(): async def fake_pre_execute_tool(tool_name, _tool_args): assert tool_name == "system_diagnose_agent" - return ( - LLMCallbackResult.APPROVE, - ToolResult(ok=True, output="diagnostic result"), + return ToolDispatchOutcome( + status=ToolDispatchStatus.EXECUTED, + result=ToolResult(ok=True, output="diagnostic result"), ) with patch.object( @@ -208,9 +211,9 @@ async def test_handle_tool_calls_bash_security_blocked_clears_session_output(): async def fake_pre_execute_tool(tool_name, _tool_args): assert tool_name == "bash_exec" - return ( - LLMCallbackResult.APPROVE, - ToolResult( + return ToolDispatchOutcome( + status=ToolDispatchStatus.SHORT_CIRCUIT, + result=ToolResult( ok=False, output="blocked", meta={"kind": "security_blocked"}, @@ -237,24 +240,41 @@ async def test_pre_execute_tool_emits_blocked_panel_for_policy_fallback_rule(mon config = ConfigModel(model="test-model", api_key="test-key") session = LLMSession(config=config, skill_manager=SkillManager()) - class _DummyBashTool: - description = "dummy bash" - called = False - - def need_confirm_before_exec(self, _arg): - return False - - def get_confirmation_info(self, arg): - return { - "command": arg, - "security_decision": {"allow": False, "require_confirmation": False}, - "security_analysis": { - "risk_level": "HIGH", - "sandbox": {"enabled": False, "reason": "sandbox_disabled_by_policy"}, - "fallback_rule_matched": True, - "reasons": ["系统配置目录,误修改会导致严重故障"], - }, - } + class _DummyBashTool(ToolBase): + def __init__(self): + super().__init__( + name="bash_exec", + description="dummy bash", + parameters={"type": "object", "properties": {"code": {"type": "string"}}}, + ) + self.called = False + + def prepare_invocation( + self, tool_args: dict[str, object], context: ToolExecutionContext + ) -> ToolPreflightResult: + _ = context + return ToolPreflightResult( + action=ToolPreflightAction.SHORT_CIRCUIT, + panel=ToolPanelSpec( + mode="blocked", + target=str(tool_args.get("code")), + analysis={ + "risk_level": "HIGH", + "sandbox": { + "enabled": False, + "reason": "sandbox_disabled_by_policy", + }, + "fallback_rule_matched": True, + "reasons": ["系统配置目录,误修改会导致严重故障"], + }, + ), + result=ToolResult( + ok=False, + output="blocked", + meta={"kind": "security_blocked"}, + stop_tool_chain=True, + ), + ) async def __call__(self, code: str): self.called = True @@ -264,10 +284,177 @@ async def __call__(self, code: str): emitted: list[tuple[object, dict]] = [] monkeypatch.setattr(session, "emit_event", lambda event_type, data=None: emitted.append((event_type, data or {}))) - goon, _result = await session.pre_execute_tool("bash_exec", {"code": "sudo rm /etc/aish/123"}) + outcome = await session.pre_execute_tool("bash_exec", {"code": "sudo rm /etc/aish/123"}) - assert goon == LLMCallbackResult.CONTINUE + assert outcome.status == ToolDispatchStatus.SHORT_CIRCUIT assert emitted - assert emitted[0][1].get("panel_mode") == "blocked" - assert _result.meta.get("kind") == "security_blocked" + assert emitted[0][1].get("panel", {}).get("mode") == "blocked" + assert outcome.result.meta.get("kind") == "security_blocked" assert session.tools["bash_exec"].called is False + + +@pytest.mark.anyio +async def test_pre_execute_tool_info_panel_executes_tool(monkeypatch): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + + class _InfoTool(ToolBase): + def __init__(self): + super().__init__( + name="info_tool", + description="info tool", + parameters={"type": "object", "properties": {"value": {"type": "string"}}}, + ) + + def prepare_invocation( + self, tool_args: dict[str, object], context: ToolExecutionContext + ) -> ToolPreflightResult: + _ = context + return ToolPreflightResult( + action=ToolPreflightAction.EXECUTE, + panel=ToolPanelSpec(mode="info", target=str(tool_args.get("value"))), + ) + + def __call__(self, value: str): + return ToolResult(ok=True, output=f"echo:{value}") + + session.tools["info_tool"] = _InfoTool() + emitted: list[tuple[object, dict]] = [] + monkeypatch.setattr( + session, + "emit_event", + lambda event_type, data=None: emitted.append((event_type, data or {})), + ) + + outcome = await session.pre_execute_tool("info_tool", {"value": "hello"}) + + assert outcome.status == ToolDispatchStatus.EXECUTED + assert outcome.result.output == "echo:hello" + assert emitted + assert emitted[0][1].get("panel", {}).get("mode") == "info" + + +@pytest.mark.anyio +async def test_pre_execute_tool_legacy_hooks_use_get_pre_execute_subject(monkeypatch): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + + class _LegacyTool(ToolBase): + def __init__(self): + super().__init__( + name="legacy_tool", + description="legacy tool", + parameters={ + "type": "object", + "properties": { + "dangerous": {"type": "string"}, + "file_path": {"type": "string"}, + }, + }, + ) + self.seen_subject = None + + def get_pre_execute_subject(self, tool_args: dict[str, object]) -> object: + return tool_args.get("dangerous") + + def need_confirm_before_exec(self, subject: object) -> bool: + self.seen_subject = subject + return True + + def get_confirmation_info(self, subject: object) -> dict: + return { + "target": "/tmp/demo.txt", + "preview": f"subject={subject}", + } + + def __call__(self, dangerous: str, file_path: str): + return ToolResult(ok=True, output=f"{dangerous}:{file_path}") + + session.tools["legacy_tool"] = _LegacyTool() + monkeypatch.setattr(session, "request_confirmation", lambda *_args, **_kwargs: LLMCallbackResult.APPROVE) + + outcome = await session.pre_execute_tool( + "legacy_tool", + {"dangerous": "rm -rf", "file_path": "/tmp/demo.txt"}, + ) + + assert outcome.status == ToolDispatchStatus.EXECUTED + assert session.tools["legacy_tool"].seen_subject == "rm -rf" + assert outcome.result.output == "rm -rf:/tmp/demo.txt" + + +@pytest.mark.anyio +async def test_pre_execute_tool_write_file_confirmation_uses_panel_target_preview( + monkeypatch, +): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + captured: dict[str, object] = {} + + def _request_confirmation(_event_type, data, **_kwargs): + captured.update(data) + return LLMCallbackResult.DENY + + monkeypatch.setattr(session, "request_confirmation", _request_confirmation) + + outcome = await session.pre_execute_tool( + "write_file", + {"file_path": "/tmp/demo.txt", "content": "hello world"}, + ) + + assert outcome.status == ToolDispatchStatus.REJECTED + assert captured.get("panel", {}).get("target") == "/tmp/demo.txt" + assert captured.get("panel", {}).get("preview") == "hello world" + + +@pytest.mark.anyio +async def test_pre_execute_tool_write_file_confirmation_keeps_full_long_preview( + monkeypatch, +): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + captured: dict[str, object] = {} + long_content = "A" * 150 + "TAIL" + + def _request_confirmation(_event_type, data, **_kwargs): + captured.update(data) + return LLMCallbackResult.DENY + + monkeypatch.setattr(session, "request_confirmation", _request_confirmation) + + outcome = await session.pre_execute_tool( + "write_file", + {"file_path": "/tmp/demo.txt", "content": long_content}, + ) + + assert outcome.status == ToolDispatchStatus.REJECTED + assert captured.get("panel", {}).get("preview") == long_content + + +@pytest.mark.anyio +async def test_pre_execute_tool_edit_file_confirmation_uses_panel_target_preview( + monkeypatch, +): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + captured: dict[str, object] = {} + + def _request_confirmation(_event_type, data, **_kwargs): + captured.update(data) + return LLMCallbackResult.DENY + + monkeypatch.setattr(session, "request_confirmation", _request_confirmation) + + outcome = await session.pre_execute_tool( + "edit_file", + { + "file_path": "/tmp/demo.txt", + "old_string": "old", + "new_string": "new", + "replace_all": True, + }, + ) + + assert outcome.status == ToolDispatchStatus.REJECTED + assert captured.get("panel", {}).get("target") == "/tmp/demo.txt" + assert captured.get("panel", {}).get("preview") == "Replace all: old -> new" diff --git a/tests/test_shell_prompt_io.py b/tests/test_shell_prompt_io.py index 45b7e1c..f125270 100644 --- a/tests/test_shell_prompt_io.py +++ b/tests/test_shell_prompt_io.py @@ -7,7 +7,11 @@ from rich.console import Console from aish.llm import LLMCallbackResult, LLMEvent, LLMEventType -from aish.shell_enhanced.shell_prompt_io import display_security_panel, handle_ask_user_required +from aish.shell_enhanced.shell_prompt_io import ( + display_security_panel, + handle_ask_user_required, + handle_tool_confirmation_required, +) def _reset_i18n_cache() -> None: @@ -117,6 +121,8 @@ def test_display_security_panel_shows_fallback_rule_details(monkeypatch): assert "风险等级" in output assert "原因" in output assert "系统配置目录,误修改会导致严重故障" in output + + def test_display_security_panel_for_fallback_rule_confirm_hides_generic_fallback_hint( monkeypatch, ): @@ -144,3 +150,32 @@ def test_display_security_panel_for_fallback_rule_confirm_hides_generic_fallback assert "用户业务数据变更需人工确认" in output assert "未能完成命令风险评估" not in output + +def test_handle_tool_confirmation_required_uses_panel_payload(): + shell = _DummyShell() + captured: list[tuple[object, object]] = [] + shell._get_user_confirmation = lambda remember_command=None, allow_remember=False: ( + captured.append((remember_command, allow_remember)) or LLMCallbackResult.APPROVE + ) + shell._display_security_panel = lambda data, panel_mode="confirm": captured.append( + ("panel", panel_mode, data.get("panel", {})) + ) + + event = LLMEvent( + event_type=LLMEventType.TOOL_CONFIRMATION_REQUIRED, + data={ + "tool_name": "bash_exec", + "panel": { + "mode": "confirm", + "target": "echo hi", + "allow_remember": True, + "remember_key": "echo hi", + }, + }, + timestamp=time.time(), + ) + + result = handle_tool_confirmation_required(shell, event) + + assert result == LLMCallbackResult.APPROVE + assert ("echo hi", True) in captured