diff --git a/src/aish/llm.py b/src/aish/llm.py
index 8b23284..dfc6e40 100644
--- a/src/aish/llm.py
+++ b/src/aish/llm.py
@@ -1,10 +1,12 @@
import json
import logging
+import os
import threading
import time
import uuid
from dataclasses import dataclass
from enum import Enum
+from pathlib import Path
from typing import Callable, Optional
import anyio
@@ -14,13 +16,18 @@
from aish.config import ConfigModel
from aish.context_manager import ContextManager, MemoryType
from aish.exception import is_litellm_exception, redact_secrets
-from aish.i18n import t
from aish.interruption import ShellState
from aish.litellm_loader import load_litellm
from aish.providers.registry import get_provider_for_model
from aish.prompts import PromptManager
from aish.skills import SkillManager
-from aish.tools.base import ToolBase
+from aish.tools.base import (
+ ToolBase,
+ ToolExecutionContext,
+ ToolPanelSpec,
+ ToolPreflightAction,
+ ToolPreflightResult,
+)
from aish.tools.code_exec import BashTool, PythonTool
from aish.tools.fs_tools import EditFileTool, ReadFileTool, WriteFileTool
from aish.tools.result import ToolResult
@@ -83,6 +90,19 @@ class LLMEvent:
metadata: Optional[dict] = None
+class ToolDispatchStatus(Enum):
+ EXECUTED = "executed"
+ SHORT_CIRCUIT = "short_circuit"
+ REJECTED = "rejected"
+ CANCELLED = "cancelled"
+
+
+@dataclass
+class ToolDispatchOutcome:
+ status: ToolDispatchStatus
+ result: ToolResult
+
+
def normalize_tool_result(value: object) -> ToolResult:
if isinstance(value, ToolResult):
return value
@@ -851,203 +871,128 @@ async def execute_tool(
)
return tool_result
+ def _build_tool_panel_event_data(
+ self,
+ *,
+ tool: ToolBase,
+ tool_name: str,
+ tool_args: dict,
+ panel: ToolPanelSpec,
+ ) -> dict:
+ panel_payload = panel.to_event_payload()
+ data: dict = {
+ "tool_name": tool_name,
+ "tool_args": tool_args,
+ "description": tool.description,
+ "panel": panel_payload,
+ # Temporary top-level mirror for transition/debugging.
+ "panel_mode": panel.mode,
+ }
+ for key in (
+ "target",
+ "preview",
+ "analysis",
+ "allow_remember",
+ "remember_key",
+ "title",
+ ):
+ if key in panel_payload:
+ data[key] = panel_payload[key]
+ return data
+
async def pre_execute_tool(
self, tool_name: str, tool_args: dict
- ) -> tuple[LLMCallbackResult, ToolResult]:
+ ) -> ToolDispatchOutcome:
try:
if tool_name not in self.tools:
- return (
- LLMCallbackResult.CONTINUE,
- ToolResult(
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.REJECTED,
+ result=ToolResult(
ok=False,
output=f"Error: Invalid tool name: {tool_name}",
),
)
tool = self.tools[tool_name]
-
- # For BashTool, pass the code to check if confirmation is needed
- if tool_name == "bash_exec":
- arg = tool_args.get("code")
- elif tool_name == "write_file":
- arg = tool_args.get("content")
- elif tool_name == "edit_file":
- arg = tool_args
- else:
- arg = None
-
- # 使用 to_thread.run 来避免阻塞事件循环
- # 这样信号处理器可以在沙箱评估期间运行
- need_confirm = await anyio.to_thread.run_sync(
- tool.need_confirm_before_exec, arg
+ context = ToolExecutionContext(
+ cwd=Path(os.getcwd()).resolve(),
+ cancellation_token=self.cancellation_token,
+ interruption_manager=self.interruption_manager,
+ is_approved=self.is_command_approved,
)
+ preflight = await anyio.to_thread.run_sync(
+ tool.prepare_invocation, tool_args, context
+ )
+ if not isinstance(preflight, ToolPreflightResult):
+ preflight = ToolPreflightResult()
- confirmation_info: dict = {}
- security_analysis: dict = {}
- security_decision: dict = {}
- suppress_security_panels = False
-
- if tool_name == "bash_exec" and arg is not None:
- try:
- confirmation_info = tool.get_confirmation_info(arg) # type: ignore[assignment]
- security_analysis = (
- confirmation_info.get("security_analysis", {})
- if isinstance(confirmation_info, dict)
- else {}
- )
- security_decision = (
- confirmation_info.get("security_decision", {})
- if isinstance(confirmation_info, dict)
- else {}
- )
- except Exception:
- confirmation_info = {}
- security_analysis = {}
- security_decision = {}
-
- # 1) fail-open when sandbox is unavailable (cannot assess risks)
- if (
- isinstance(security_analysis, dict)
- and security_analysis.get("fail_open") is True
- ):
- need_confirm = False
- suppress_security_panels = True
-
- # 2) exact-match allowlist
- if callable(self.is_command_approved) and self.is_command_approved(arg):
- need_confirm = False
- suppress_security_panels = True
-
- if (
- isinstance(security_decision, dict)
- and security_decision.get("allow") is False
- ):
- need_confirm = False
-
- # 在沙箱评估后检查是否被取消
- # 如果被取消,直接返回 CANCEL,不继续执行
if self.cancellation_token and self.cancellation_token.is_cancelled():
if self.interruption_manager:
self.interruption_manager.set_state(ShellState.NORMAL)
- return (
- LLMCallbackResult.CANCEL,
- ToolResult(
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.CANCELLED,
+ result=ToolResult(
ok=False,
output="Operation cancelled during security evaluation",
),
)
- # 安全提示面板:对 AI 生成命令的风险评估结果进行展示。
- # - MEDIUM 且需要确认:走确认面板 + y/n
- # - LOW:展示 info 面板,不阻塞
- # - HIGH:展示 blocked 面板,不出现确认;仍由工具自身/安全系统直接阻断
- if (
- tool_name == "bash_exec"
- and arg is not None
- and not suppress_security_panels
- ):
- try:
- info = confirmation_info or tool.get_confirmation_info(arg)
- security_analysis = (
- info.get("security_analysis", {})
- if isinstance(info, dict)
- else {}
- )
- sandbox_info = (
- security_analysis.get("sandbox", {})
- if isinstance(security_analysis, dict)
- else {}
- )
- sandbox_reason = str(
- sandbox_info.get("reason", "")
- if isinstance(sandbox_info, dict)
- else ""
- )
- fallback_rule_matched = bool(
- security_analysis.get("fallback_rule_matched")
- ) if isinstance(security_analysis, dict) else False
- skip_confirmation_panel = sandbox_reason in {
- "sandbox_disabled",
- "sandbox_disabled_by_policy",
- } and not fallback_rule_matched
- risk_level = str(
- security_analysis.get("risk_level", "UNKNOWN")
- ).upper()
- panel_mode = (
- "confirm"
- if need_confirm
- else (
- "blocked" if risk_level in {"HIGH", "CRITICAL"} else "info"
- )
- )
-
- if panel_mode != "confirm" and not skip_confirmation_panel:
- notice_data = {
- "tool_name": tool_name,
- "tool_args": tool_args,
- "description": tool.description,
- "panel_mode": panel_mode,
- **(info if isinstance(info, dict) else {}),
- }
- self.emit_event(
- LLMEventType.TOOL_CONFIRMATION_REQUIRED, notice_data
- )
- except Exception:
- # 不影响工具执行;提示面板失败时静默跳过
- pass
-
- if (
- tool_name == "bash_exec"
- and isinstance(security_decision, dict)
- and security_decision.get("allow") is False
- ):
- reasons = (
- security_analysis.get("reasons", [])
- if isinstance(security_analysis, dict)
- else []
- )
- reason_text = ", ".join(str(reason) for reason in reasons[:5] if reason)
- blocked_msg = (
- t("security.command_blocked_with_reason", reason=reason_text)
- if reason_text
- else t("security.command_blocked")
+ panel = preflight.panel
+ if panel is not None and panel.mode == "info":
+ self.emit_event(
+ LLMEventType.TOOL_CONFIRMATION_REQUIRED,
+ self._build_tool_panel_event_data(
+ tool=tool,
+ tool_name=tool_name,
+ tool_args=tool_args,
+ panel=panel,
+ ),
)
- return (
- LLMCallbackResult.CONTINUE,
- ToolResult(
+
+ if preflight.action == ToolPreflightAction.SHORT_CIRCUIT:
+ if panel is not None and panel.mode == "blocked":
+ self.emit_event(
+ LLMEventType.TOOL_CONFIRMATION_REQUIRED,
+ self._build_tool_panel_event_data(
+ tool=tool,
+ tool_name=tool_name,
+ tool_args=tool_args,
+ panel=panel,
+ ),
+ )
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.SHORT_CIRCUIT,
+ result=preflight.result
+ or ToolResult(
ok=False,
- output=blocked_msg,
- code=126,
- meta={"kind": "security_blocked", "reasons": reasons},
+ output=f"Tool {tool_name} short-circuited without a result",
),
)
- if need_confirm:
- # Prepare confirmation data with security information
- confirmation_data = {
- "tool_name": tool_name,
- "tool_args": tool_args,
- "description": tool.description,
- "panel_mode": "confirm",
- **(confirmation_info or tool.get_confirmation_info(arg)),
- }
-
- # Request user confirmation
+ if preflight.action == ToolPreflightAction.CONFIRM:
+ confirm_panel = panel or ToolPanelSpec(mode="confirm")
goon = self.request_confirmation(
LLMEventType.TOOL_CONFIRMATION_REQUIRED,
- confirmation_data,
+ self._build_tool_panel_event_data(
+ tool=tool,
+ tool_name=tool_name,
+ tool_args=tool_args,
+ panel=confirm_panel,
+ ),
timeout_seconds=30.0,
default_on_timeout=LLMCallbackResult.DENY,
)
- # Handle confirmation result
if goon == LLMCallbackResult.APPROVE:
- return goon, await self.execute_tool(tool, tool_name, tool_args)
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.EXECUTED,
+ result=await self.execute_tool(tool, tool_name, tool_args),
+ )
- elif goon == LLMCallbackResult.DENY:
- return (
- goon,
- ToolResult(
+ if goon == LLMCallbackResult.DENY:
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.REJECTED,
+ result=ToolResult(
ok=False,
output=(
f"Tool {tool_name} execution denied by user, you may "
@@ -1056,45 +1001,41 @@ async def pre_execute_tool(
),
)
- elif goon == LLMCallbackResult.CANCEL:
- return (
- goon,
- ToolResult(
+ if goon == LLMCallbackResult.CANCEL:
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.CANCELLED,
+ result=ToolResult(
ok=False,
output=f"Tool {tool_name} execution cancelled by user",
),
)
- else:
- return (
- goon,
- ToolResult(
- ok=False,
- output=f"Invalid confirmation result: {goon}",
- ),
- )
- else:
- # No confirmation needed, execute directly
- return (
- LLMCallbackResult.APPROVE,
- await self.execute_tool(tool, tool_name, tool_args),
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.REJECTED,
+ result=ToolResult(
+ ok=False,
+ output=f"Invalid confirmation result: {goon}",
+ ),
)
+
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.EXECUTED,
+ result=await self.execute_tool(tool, tool_name, tool_args),
+ )
except KeyboardInterrupt:
- # 用户中断(例如在沙箱评估期间按 Ctrl+C)
- # 恢复状态并返回取消结果
if self.interruption_manager:
self.interruption_manager.set_state(ShellState.NORMAL)
- return (
- LLMCallbackResult.CANCEL,
- ToolResult(
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.CANCELLED,
+ result=ToolResult(
ok=False,
output="Operation cancelled by user",
),
)
except Exception as e:
- return (
- LLMCallbackResult.CONTINUE,
- ToolResult(
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.REJECTED,
+ result=ToolResult(
ok=False,
output=str(e),
meta={"exception_type": type(e).__name__},
@@ -1133,7 +1074,7 @@ def _build_skills_reminder_message(self) -> Optional[dict]:
return {
"role": "user",
"content": (
- "\n" f"{skills_reminder_text}\n" ""
+ f"\n{skills_reminder_text}\n"
),
}
@@ -1205,8 +1146,12 @@ async def _handle_tool_calls(
# TODO: For malformed/truncated tool arguments, add a model-side retry flow.
tool_args = json.loads(tool_call["function"]["arguments"])
- goon, tool_result = await self.pre_execute_tool(tool_name, tool_args)
- if goon == LLMCallbackResult.APPROVE:
+ dispatch = await self.pre_execute_tool(tool_name, tool_args)
+ tool_result = dispatch.result
+ if dispatch.status in {
+ ToolDispatchStatus.EXECUTED,
+ ToolDispatchStatus.SHORT_CIRCUIT,
+ }:
rendered_result = tool_result.render_for_llm()
tool_msg = {
"role": "tool",
@@ -1239,7 +1184,7 @@ async def _handle_tool_calls(
}
context_manager.add_memory(MemoryType.LLM, error_msg)
- if goon == LLMCallbackResult.CANCEL:
+ if dispatch.status == ToolDispatchStatus.CANCELLED:
# 触发 CANCELLED 事件,让 shell 能够显示取消消息
self.emit_event(
LLMEventType.CANCELLED, {"reason": "tool_cancelled"}
diff --git a/src/aish/shell_enhanced/shell_prompt_io.py b/src/aish/shell_enhanced/shell_prompt_io.py
index 9549c14..c73cb7b 100644
--- a/src/aish/shell_enhanced/shell_prompt_io.py
+++ b/src/aish/shell_enhanced/shell_prompt_io.py
@@ -369,16 +369,16 @@ def handle_tool_confirmation_required(shell: Any, event: LLMEvent) -> LLMCallbac
self._finalize_content_preview()
data = event.data
- panel_mode = str(data.get("panel_mode", "confirm")).lower()
+ panel = data.get("panel") if isinstance(data.get("panel"), dict) else {}
+ panel_mode = str(panel.get("mode") or data.get("panel_mode", "confirm")).lower()
# Display confirmation/security notice using rich formatting
self._display_security_panel(data, panel_mode=panel_mode)
# Only "confirm" requires interactive user input
if panel_mode == "confirm":
- tool_name = str(data.get("tool_name", ""))
- remember_command = data.get("command")
- allow_remember = bool(remember_command) and tool_name == "bash_exec"
+ remember_command = panel.get("remember_key", data.get("remember_key"))
+ allow_remember = bool(panel.get("allow_remember", data.get("allow_remember")))
return self._get_user_confirmation(
remember_command=remember_command,
allow_remember=allow_remember,
@@ -926,12 +926,21 @@ def get_content():
def display_security_panel(shell: Any, data: dict, panel_mode: str = "confirm") -> None:
"""Display rich security panel for AI tool calls."""
self = shell
- panel_mode = str(panel_mode).lower()
+ panel = data.get("panel") if isinstance(data.get("panel"), dict) else {}
+ panel_mode = str(panel.get("mode") or panel_mode).lower()
is_blocked = panel_mode == "blocked"
is_info = panel_mode == "info"
tool_name = str(data.get("tool_name", "unknown"))
- security_analysis = data.get("security_analysis", {})
+ security_analysis = panel.get("analysis")
+ if not isinstance(security_analysis, dict):
+ security_analysis = (
+ data.get("analysis")
+ if isinstance(data.get("analysis"), dict)
+ else data.get("security_analysis", {})
+ )
+ target = panel.get("target", data.get("target"))
+ preview = panel.get("preview", data.get("preview"))
def _sandbox_reason_value(analysis: object) -> str:
if not isinstance(analysis, dict):
@@ -997,6 +1006,8 @@ def _risk_level_value(analysis: object) -> str:
content.append(
f"[bold]{t('shell.security.label.command')}:[/bold] {data['command']}"
)
+ elif target:
+ content.append(f"[bold]{t('shell.security.label.target')}:[/bold] {target}")
# Fallback hint: sandbox failed to assess the command, so we cannot determine risk.
# In this case we ask users to confirm before executing the real command.
@@ -1019,32 +1030,17 @@ def _risk_level_value(analysis: object) -> str:
f"[bold]{t('shell.security.label.fallback_hint')}:[/bold] {hint}"
)
- # For non-bash tools (e.g. write_file), tool-specific confirmation info is carried
- # in generic fields like tool_args/content_preview/content_length.
- tool_args = data.get("tool_args")
- if isinstance(tool_args, dict):
- file_path = tool_args.get("file_path") or tool_args.get("path")
- if file_path:
- content.append(
- f"[bold]{t('shell.security.label.target')}:[/bold] {file_path}"
- )
-
- if "content" in tool_args:
+ if preview is None:
+ tool_args = data.get("tool_args")
+ if isinstance(tool_args, dict) and "content" in tool_args:
raw_content = tool_args.get("content")
+ if isinstance(raw_content, str):
+ preview = raw_content[:100] + "..." if len(raw_content) > 100 else raw_content
- # Prefer the tool-provided preview; otherwise derive a safe preview.
- content_preview = data.get("content_preview")
- if tool_name == "write_file" and isinstance(raw_content, str):
- content_preview = raw_content
- elif content_preview is None and isinstance(raw_content, str):
- content_preview = (
- raw_content[:100] + "..." if len(raw_content) > 100 else raw_content
- )
-
- if content_preview is not None:
- content.append(
- f"[bold]{t('shell.security.label.content_preview')}:[/bold] {content_preview}"
- )
+ if preview is not None:
+ content.append(
+ f"[bold]{t('shell.security.label.content_preview')}:[/bold] {preview}"
+ )
if security_analysis and (sandbox_enabled or fallback_rule_matched):
is_low_risk = risk_level_upper == "LOW"
diff --git a/src/aish/tools/base.py b/src/aish/tools/base.py
index ee1332b..2f00533 100644
--- a/src/aish/tools/base.py
+++ b/src/aish/tools/base.py
@@ -1,13 +1,65 @@
from __future__ import annotations
from abc import abstractmethod
-from typing import Any, Awaitable
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Awaitable, Callable
from pydantic import BaseModel, ConfigDict
from aish.tools.result import ToolResult
+class ToolPreflightAction(str, Enum):
+ EXECUTE = "execute"
+ CONFIRM = "confirm"
+ SHORT_CIRCUIT = "short_circuit"
+
+
+@dataclass
+class ToolExecutionContext:
+ cwd: Path
+ cancellation_token: Any | None = None
+ interruption_manager: Any | None = None
+ is_approved: Callable[[str], bool] | None = None
+
+
+@dataclass
+class ToolPanelSpec:
+ mode: str = "confirm"
+ target: str | None = None
+ preview: str | None = None
+ analysis: dict[str, Any] = field(default_factory=dict)
+ allow_remember: bool = False
+ remember_key: str | None = None
+ title: str | None = None
+
+ def to_event_payload(self) -> dict[str, Any]:
+ payload: dict[str, Any] = {
+ "mode": self.mode,
+ "allow_remember": self.allow_remember,
+ }
+ if self.target is not None:
+ payload["target"] = self.target
+ if self.preview is not None:
+ payload["preview"] = self.preview
+ if self.analysis:
+ payload["analysis"] = self.analysis
+ if self.remember_key is not None:
+ payload["remember_key"] = self.remember_key
+ if self.title is not None:
+ payload["title"] = self.title
+ return payload
+
+
+@dataclass
+class ToolPreflightResult:
+ action: ToolPreflightAction = ToolPreflightAction.EXECUTE
+ panel: ToolPanelSpec | None = None
+ result: ToolResult | None = None
+
+
class ToolBase(BaseModel):
model_config = ConfigDict(extra="allow")
@@ -32,10 +84,90 @@ def get_confirmation_info(self, *args, **kwargs) -> dict:
"""Get additional information for confirmation dialog"""
return {}
+ def get_pre_execute_subject(self, tool_args: dict[str, Any]) -> Any:
+ """Legacy adapter hook for tools that only inspect part of tool_args."""
+ return tool_args
+
+ def prepare_invocation(
+ self, tool_args: dict[str, Any], context: ToolExecutionContext
+ ) -> ToolPreflightResult:
+ """Prepare a tool invocation before execution.
+
+ New tools should override this method directly. The default implementation
+ adapts the legacy confirmation hooks for backward compatibility.
+ """
+
+ _ = context
+ subject = self.get_pre_execute_subject(tool_args)
+ need_confirm = self.need_confirm_before_exec(subject)
+ if not need_confirm:
+ return ToolPreflightResult(action=ToolPreflightAction.EXECUTE)
+
+ info = self.get_confirmation_info(subject)
+ return ToolPreflightResult(
+ action=ToolPreflightAction.CONFIRM,
+ panel=self._build_panel_from_legacy(tool_args, info),
+ )
+
def get_session_output(self, result: ToolResult) -> str | None:
"""Optionally expose a tool result as the session's fallback output."""
return None
+ def _build_panel_from_legacy(
+ self, tool_args: dict[str, Any], info: object
+ ) -> ToolPanelSpec:
+ info_dict = info if isinstance(info, dict) else {}
+ target: str | None = None
+ preview: str | None = None
+ analysis: dict[str, Any] = {}
+ remember_key: str | None = None
+ title: str | None = None
+
+ if isinstance(tool_args, dict):
+ raw_target = tool_args.get("file_path") or tool_args.get("path")
+ if raw_target is not None:
+ target = str(raw_target)
+
+ if isinstance(info_dict, dict):
+ raw_target = info_dict.get("target")
+ if raw_target is not None:
+ target = str(raw_target)
+
+ raw_preview = info_dict.get("preview")
+ if raw_preview is None:
+ raw_preview = info_dict.get("content_preview")
+ if raw_preview is not None:
+ preview = str(raw_preview)
+
+ raw_analysis = info_dict.get("analysis")
+ if isinstance(raw_analysis, dict):
+ analysis = raw_analysis
+ elif isinstance(info_dict.get("security_analysis"), dict):
+ analysis = info_dict["security_analysis"]
+
+ raw_remember_key = info_dict.get("remember_key")
+ if raw_remember_key is None:
+ raw_remember_key = info_dict.get("command")
+ if raw_remember_key is not None:
+ remember_key = str(raw_remember_key)
+
+ raw_title = info_dict.get("title")
+ if raw_title is not None:
+ title = str(raw_title)
+
+ mode = str(info_dict.get("panel_mode", "confirm"))
+ allow_remember = bool(info_dict.get("allow_remember", False))
+
+ return ToolPanelSpec(
+ mode=mode,
+ target=target,
+ preview=preview,
+ analysis=analysis,
+ allow_remember=allow_remember,
+ remember_key=remember_key,
+ title=title,
+ )
+
@abstractmethod
def __call__(
self, *args: Any, **kwargs: Any
diff --git a/src/aish/tools/code_exec.py b/src/aish/tools/code_exec.py
index 8f44f65..7658166 100644
--- a/src/aish/tools/code_exec.py
+++ b/src/aish/tools/code_exec.py
@@ -12,7 +12,8 @@
from aish.offload import render_bash_output
from aish.security.security_manager import (SecurityDecision,
SimpleSecurityManager)
-from aish.tools.base import ToolBase
+from aish.tools.base import (ToolBase, ToolExecutionContext, ToolPanelSpec,
+ ToolPreflightAction, ToolPreflightResult)
from aish.tools.bash_executor import UnifiedBashExecutor
from aish.tools.result import ToolResult
@@ -228,6 +229,115 @@ def need_confirm_before_exec(self, code: Optional[str] = None) -> bool:
return False
return bool(self._last_decision.require_confirmation)
+ def prepare_invocation(
+ self, tool_args: dict[str, object], context: ToolExecutionContext
+ ) -> ToolPreflightResult:
+ command = tool_args.get("code")
+ if not isinstance(command, str) or not command:
+ return ToolPreflightResult(action=ToolPreflightAction.EXECUTE)
+
+ interruption_manager = context.interruption_manager or self.interruption_manager
+ if interruption_manager:
+ interruption_manager.set_state(ShellState.SANDBOX_EVAL)
+
+ decision = self.security_manager.decide(
+ command,
+ is_ai_command=True,
+ cwd=context.cwd,
+ )
+ self._last_decision = decision
+
+ if interruption_manager:
+ interruption_manager.set_state(ShellState.NORMAL)
+
+ info = self.get_confirmation_info(command)
+ analysis_data = (
+ info.get("security_analysis", {}) if isinstance(info, dict) else {}
+ )
+ panel = ToolPanelSpec(
+ mode="confirm",
+ target=command,
+ analysis=analysis_data if isinstance(analysis_data, dict) else {},
+ allow_remember=True,
+ remember_key=command,
+ )
+
+ if (
+ isinstance(analysis_data, dict)
+ and analysis_data.get("fail_open") is True
+ ):
+ return ToolPreflightResult(action=ToolPreflightAction.EXECUTE)
+
+ if callable(context.is_approved) and context.is_approved(command):
+ return ToolPreflightResult(action=ToolPreflightAction.EXECUTE)
+
+ if not decision.allow:
+ reasons = (
+ analysis_data.get("reasons", [])
+ if isinstance(analysis_data, dict)
+ else []
+ )
+ reason_text = ", ".join(str(reason) for reason in reasons[:5] if reason)
+ blocked_msg = (
+ t("security.command_blocked_with_reason", reason=reason_text)
+ if reason_text
+ else t("security.command_blocked")
+ )
+ return ToolPreflightResult(
+ action=ToolPreflightAction.SHORT_CIRCUIT,
+ panel=ToolPanelSpec(
+ mode="blocked",
+ target=command,
+ analysis=panel.analysis,
+ allow_remember=True,
+ remember_key=command,
+ ),
+ result=ToolResult(
+ ok=False,
+ output=blocked_msg,
+ code=126,
+ meta={"kind": "security_blocked", "reasons": reasons},
+ stop_tool_chain=True,
+ ),
+ )
+
+ sandbox_info = (
+ analysis_data.get("sandbox", {}) if isinstance(analysis_data, dict) else {}
+ )
+ sandbox_reason = (
+ str(sandbox_info.get("reason", "")) if isinstance(sandbox_info, dict) else ""
+ )
+ fallback_rule_matched = bool(
+ analysis_data.get("fallback_rule_matched")
+ if isinstance(analysis_data, dict)
+ else False
+ )
+ skip_notice_panel = sandbox_reason in {
+ "sandbox_disabled",
+ "sandbox_disabled_by_policy",
+ } and not fallback_rule_matched
+
+ if decision.require_confirmation:
+ panel.mode = "confirm"
+ return ToolPreflightResult(
+ action=ToolPreflightAction.CONFIRM,
+ panel=panel,
+ )
+
+ risk_level = str(
+ analysis_data.get("risk_level", "UNKNOWN")
+ if isinstance(analysis_data, dict)
+ else "UNKNOWN"
+ ).upper()
+ panel.mode = "blocked" if risk_level in {"HIGH", "CRITICAL"} else "info"
+ if skip_notice_panel:
+ return ToolPreflightResult(action=ToolPreflightAction.EXECUTE)
+
+ return ToolPreflightResult(
+ action=ToolPreflightAction.EXECUTE,
+ panel=panel,
+ )
+
def get_confirmation_info(self, code: Optional[str] = None) -> dict:
"""Get security information for confirmation dialog"""
command = code
diff --git a/src/aish/tools/fs_tools.py b/src/aish/tools/fs_tools.py
index 664f107..d946f28 100644
--- a/src/aish/tools/fs_tools.py
+++ b/src/aish/tools/fs_tools.py
@@ -1,10 +1,16 @@
from pathlib import Path
from typing import ClassVar
-from aish.tools.base import ToolBase
+from aish.tools.base import (ToolBase, ToolExecutionContext, ToolPanelSpec,
+ ToolPreflightAction, ToolPreflightResult)
from aish.tools.result import ToolResult
+def _preview_text(value: object, limit: int = 100) -> str:
+ text = str(value) if value is not None else ""
+ return text[:limit] + "..." if len(text) > limit else text
+
+
# TODO: support images
class ReadFileTool(ToolBase):
MAX_READ_BYTES: ClassVar[int] = 32 * 1024
@@ -188,6 +194,21 @@ def __call__(self, file_path: str, content: str) -> ToolResult:
def need_confirm_before_exec(self, content: str) -> bool:
return True
+ def prepare_invocation(
+ self, tool_args: dict[str, object], context: ToolExecutionContext
+ ) -> ToolPreflightResult:
+ _ = context
+ file_path = tool_args.get("file_path")
+ content = tool_args.get("content", "")
+ return ToolPreflightResult(
+ action=ToolPreflightAction.CONFIRM,
+ panel=ToolPanelSpec(
+ mode="confirm",
+ target=str(file_path) if file_path is not None else None,
+ preview=content if isinstance(content, str) else _preview_text(content),
+ ),
+ )
+
def get_confirmation_info(self, content: str) -> dict:
# For write_file tool, we need to get the file_path from the tool_args
# Since we only receive the content string here, we'll return what we can
@@ -346,18 +367,33 @@ def __call__(
def need_confirm_before_exec(self, tool_args: dict | None = None) -> bool:
return True
+ def prepare_invocation(
+ self, tool_args: dict[str, object], context: ToolExecutionContext
+ ) -> ToolPreflightResult:
+ _ = context
+ replace_all = bool(tool_args.get("replace_all", False))
+ old_string = tool_args.get("old_string", "")
+ new_string = tool_args.get("new_string", "")
+ mode = "Replace all" if replace_all else "Replace"
+ return ToolPreflightResult(
+ action=ToolPreflightAction.CONFIRM,
+ panel=ToolPanelSpec(
+ mode="confirm",
+ target=str(tool_args.get("file_path"))
+ if tool_args.get("file_path") is not None
+ else None,
+ preview=f"{mode}: {_preview_text(old_string)} -> {_preview_text(new_string)}",
+ ),
+ )
+
def get_confirmation_info(self, tool_args: dict | None = None) -> dict:
if not isinstance(tool_args, dict):
return {}
- def _preview(value: object, limit: int = 100) -> str:
- text = str(value) if value is not None else ""
- return text[:limit] + "..." if len(text) > limit else text
-
replace_all = bool(tool_args.get("replace_all", False))
old_string = tool_args.get("old_string", "")
new_string = tool_args.get("new_string", "")
mode = "Replace all" if replace_all else "Replace"
return {
- "content_preview": f"{mode}: {_preview(old_string)} -> {_preview(new_string)}"
+ "content_preview": f"{mode}: {_preview_text(old_string)} -> {_preview_text(new_string)}"
}
diff --git a/tests/test_ask_user_tool.py b/tests/test_ask_user_tool.py
index 3eb7e09..41b389d 100644
--- a/tests/test_ask_user_tool.py
+++ b/tests/test_ask_user_tool.py
@@ -5,9 +5,12 @@
from aish.config import ConfigModel
from aish.context_manager import ContextManager
-from aish.llm import LLMCallbackResult, LLMSession
+from aish.llm import (LLMCallbackResult, LLMSession, ToolDispatchOutcome,
+ ToolDispatchStatus)
from aish.skills import SkillManager
from aish.tools.ask_user import AskUserTool
+from aish.tools.base import (ToolBase, ToolExecutionContext, ToolPanelSpec,
+ ToolPreflightAction, ToolPreflightResult)
from aish.tools.result import ToolResult
@@ -126,9 +129,9 @@ async def test_handle_tool_calls_ask_user_user_input_required_breaks(monkeypatch
async def fake_pre_execute_tool(tool_name, _tool_args):
if tool_name == "ask_user":
- return (
- LLMCallbackResult.APPROVE,
- ToolResult(
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.EXECUTED,
+ result=ToolResult(
ok=False,
output="paused",
meta={"kind": "user_input_required", "reason": "cancelled"},
@@ -170,9 +173,9 @@ async def test_handle_tool_calls_system_diagnose_agent_sets_session_output():
async def fake_pre_execute_tool(tool_name, _tool_args):
assert tool_name == "system_diagnose_agent"
- return (
- LLMCallbackResult.APPROVE,
- ToolResult(ok=True, output="diagnostic result"),
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.EXECUTED,
+ result=ToolResult(ok=True, output="diagnostic result"),
)
with patch.object(
@@ -208,9 +211,9 @@ async def test_handle_tool_calls_bash_security_blocked_clears_session_output():
async def fake_pre_execute_tool(tool_name, _tool_args):
assert tool_name == "bash_exec"
- return (
- LLMCallbackResult.APPROVE,
- ToolResult(
+ return ToolDispatchOutcome(
+ status=ToolDispatchStatus.SHORT_CIRCUIT,
+ result=ToolResult(
ok=False,
output="blocked",
meta={"kind": "security_blocked"},
@@ -237,24 +240,41 @@ async def test_pre_execute_tool_emits_blocked_panel_for_policy_fallback_rule(mon
config = ConfigModel(model="test-model", api_key="test-key")
session = LLMSession(config=config, skill_manager=SkillManager())
- class _DummyBashTool:
- description = "dummy bash"
- called = False
-
- def need_confirm_before_exec(self, _arg):
- return False
-
- def get_confirmation_info(self, arg):
- return {
- "command": arg,
- "security_decision": {"allow": False, "require_confirmation": False},
- "security_analysis": {
- "risk_level": "HIGH",
- "sandbox": {"enabled": False, "reason": "sandbox_disabled_by_policy"},
- "fallback_rule_matched": True,
- "reasons": ["系统配置目录,误修改会导致严重故障"],
- },
- }
+ class _DummyBashTool(ToolBase):
+ def __init__(self):
+ super().__init__(
+ name="bash_exec",
+ description="dummy bash",
+ parameters={"type": "object", "properties": {"code": {"type": "string"}}},
+ )
+ self.called = False
+
+ def prepare_invocation(
+ self, tool_args: dict[str, object], context: ToolExecutionContext
+ ) -> ToolPreflightResult:
+ _ = context
+ return ToolPreflightResult(
+ action=ToolPreflightAction.SHORT_CIRCUIT,
+ panel=ToolPanelSpec(
+ mode="blocked",
+ target=str(tool_args.get("code")),
+ analysis={
+ "risk_level": "HIGH",
+ "sandbox": {
+ "enabled": False,
+ "reason": "sandbox_disabled_by_policy",
+ },
+ "fallback_rule_matched": True,
+ "reasons": ["系统配置目录,误修改会导致严重故障"],
+ },
+ ),
+ result=ToolResult(
+ ok=False,
+ output="blocked",
+ meta={"kind": "security_blocked"},
+ stop_tool_chain=True,
+ ),
+ )
async def __call__(self, code: str):
self.called = True
@@ -264,10 +284,177 @@ async def __call__(self, code: str):
emitted: list[tuple[object, dict]] = []
monkeypatch.setattr(session, "emit_event", lambda event_type, data=None: emitted.append((event_type, data or {})))
- goon, _result = await session.pre_execute_tool("bash_exec", {"code": "sudo rm /etc/aish/123"})
+ outcome = await session.pre_execute_tool("bash_exec", {"code": "sudo rm /etc/aish/123"})
- assert goon == LLMCallbackResult.CONTINUE
+ assert outcome.status == ToolDispatchStatus.SHORT_CIRCUIT
assert emitted
- assert emitted[0][1].get("panel_mode") == "blocked"
- assert _result.meta.get("kind") == "security_blocked"
+ assert emitted[0][1].get("panel", {}).get("mode") == "blocked"
+ assert outcome.result.meta.get("kind") == "security_blocked"
assert session.tools["bash_exec"].called is False
+
+
+@pytest.mark.anyio
+async def test_pre_execute_tool_info_panel_executes_tool(monkeypatch):
+ config = ConfigModel(model="test-model", api_key="test-key")
+ session = LLMSession(config=config, skill_manager=SkillManager())
+
+ class _InfoTool(ToolBase):
+ def __init__(self):
+ super().__init__(
+ name="info_tool",
+ description="info tool",
+ parameters={"type": "object", "properties": {"value": {"type": "string"}}},
+ )
+
+ def prepare_invocation(
+ self, tool_args: dict[str, object], context: ToolExecutionContext
+ ) -> ToolPreflightResult:
+ _ = context
+ return ToolPreflightResult(
+ action=ToolPreflightAction.EXECUTE,
+ panel=ToolPanelSpec(mode="info", target=str(tool_args.get("value"))),
+ )
+
+ def __call__(self, value: str):
+ return ToolResult(ok=True, output=f"echo:{value}")
+
+ session.tools["info_tool"] = _InfoTool()
+ emitted: list[tuple[object, dict]] = []
+ monkeypatch.setattr(
+ session,
+ "emit_event",
+ lambda event_type, data=None: emitted.append((event_type, data or {})),
+ )
+
+ outcome = await session.pre_execute_tool("info_tool", {"value": "hello"})
+
+ assert outcome.status == ToolDispatchStatus.EXECUTED
+ assert outcome.result.output == "echo:hello"
+ assert emitted
+ assert emitted[0][1].get("panel", {}).get("mode") == "info"
+
+
+@pytest.mark.anyio
+async def test_pre_execute_tool_legacy_hooks_use_get_pre_execute_subject(monkeypatch):
+ config = ConfigModel(model="test-model", api_key="test-key")
+ session = LLMSession(config=config, skill_manager=SkillManager())
+
+ class _LegacyTool(ToolBase):
+ def __init__(self):
+ super().__init__(
+ name="legacy_tool",
+ description="legacy tool",
+ parameters={
+ "type": "object",
+ "properties": {
+ "dangerous": {"type": "string"},
+ "file_path": {"type": "string"},
+ },
+ },
+ )
+ self.seen_subject = None
+
+ def get_pre_execute_subject(self, tool_args: dict[str, object]) -> object:
+ return tool_args.get("dangerous")
+
+ def need_confirm_before_exec(self, subject: object) -> bool:
+ self.seen_subject = subject
+ return True
+
+ def get_confirmation_info(self, subject: object) -> dict:
+ return {
+ "target": "/tmp/demo.txt",
+ "preview": f"subject={subject}",
+ }
+
+ def __call__(self, dangerous: str, file_path: str):
+ return ToolResult(ok=True, output=f"{dangerous}:{file_path}")
+
+ session.tools["legacy_tool"] = _LegacyTool()
+ monkeypatch.setattr(session, "request_confirmation", lambda *_args, **_kwargs: LLMCallbackResult.APPROVE)
+
+ outcome = await session.pre_execute_tool(
+ "legacy_tool",
+ {"dangerous": "rm -rf", "file_path": "/tmp/demo.txt"},
+ )
+
+ assert outcome.status == ToolDispatchStatus.EXECUTED
+ assert session.tools["legacy_tool"].seen_subject == "rm -rf"
+ assert outcome.result.output == "rm -rf:/tmp/demo.txt"
+
+
+@pytest.mark.anyio
+async def test_pre_execute_tool_write_file_confirmation_uses_panel_target_preview(
+ monkeypatch,
+):
+ config = ConfigModel(model="test-model", api_key="test-key")
+ session = LLMSession(config=config, skill_manager=SkillManager())
+ captured: dict[str, object] = {}
+
+ def _request_confirmation(_event_type, data, **_kwargs):
+ captured.update(data)
+ return LLMCallbackResult.DENY
+
+ monkeypatch.setattr(session, "request_confirmation", _request_confirmation)
+
+ outcome = await session.pre_execute_tool(
+ "write_file",
+ {"file_path": "/tmp/demo.txt", "content": "hello world"},
+ )
+
+ assert outcome.status == ToolDispatchStatus.REJECTED
+ assert captured.get("panel", {}).get("target") == "/tmp/demo.txt"
+ assert captured.get("panel", {}).get("preview") == "hello world"
+
+
+@pytest.mark.anyio
+async def test_pre_execute_tool_write_file_confirmation_keeps_full_long_preview(
+ monkeypatch,
+):
+ config = ConfigModel(model="test-model", api_key="test-key")
+ session = LLMSession(config=config, skill_manager=SkillManager())
+ captured: dict[str, object] = {}
+ long_content = "A" * 150 + "TAIL"
+
+ def _request_confirmation(_event_type, data, **_kwargs):
+ captured.update(data)
+ return LLMCallbackResult.DENY
+
+ monkeypatch.setattr(session, "request_confirmation", _request_confirmation)
+
+ outcome = await session.pre_execute_tool(
+ "write_file",
+ {"file_path": "/tmp/demo.txt", "content": long_content},
+ )
+
+ assert outcome.status == ToolDispatchStatus.REJECTED
+ assert captured.get("panel", {}).get("preview") == long_content
+
+
+@pytest.mark.anyio
+async def test_pre_execute_tool_edit_file_confirmation_uses_panel_target_preview(
+ monkeypatch,
+):
+ config = ConfigModel(model="test-model", api_key="test-key")
+ session = LLMSession(config=config, skill_manager=SkillManager())
+ captured: dict[str, object] = {}
+
+ def _request_confirmation(_event_type, data, **_kwargs):
+ captured.update(data)
+ return LLMCallbackResult.DENY
+
+ monkeypatch.setattr(session, "request_confirmation", _request_confirmation)
+
+ outcome = await session.pre_execute_tool(
+ "edit_file",
+ {
+ "file_path": "/tmp/demo.txt",
+ "old_string": "old",
+ "new_string": "new",
+ "replace_all": True,
+ },
+ )
+
+ assert outcome.status == ToolDispatchStatus.REJECTED
+ assert captured.get("panel", {}).get("target") == "/tmp/demo.txt"
+ assert captured.get("panel", {}).get("preview") == "Replace all: old -> new"
diff --git a/tests/test_shell_prompt_io.py b/tests/test_shell_prompt_io.py
index 45b7e1c..f125270 100644
--- a/tests/test_shell_prompt_io.py
+++ b/tests/test_shell_prompt_io.py
@@ -7,7 +7,11 @@
from rich.console import Console
from aish.llm import LLMCallbackResult, LLMEvent, LLMEventType
-from aish.shell_enhanced.shell_prompt_io import display_security_panel, handle_ask_user_required
+from aish.shell_enhanced.shell_prompt_io import (
+ display_security_panel,
+ handle_ask_user_required,
+ handle_tool_confirmation_required,
+)
def _reset_i18n_cache() -> None:
@@ -117,6 +121,8 @@ def test_display_security_panel_shows_fallback_rule_details(monkeypatch):
assert "风险等级" in output
assert "原因" in output
assert "系统配置目录,误修改会导致严重故障" in output
+
+
def test_display_security_panel_for_fallback_rule_confirm_hides_generic_fallback_hint(
monkeypatch,
):
@@ -144,3 +150,32 @@ def test_display_security_panel_for_fallback_rule_confirm_hides_generic_fallback
assert "用户业务数据变更需人工确认" in output
assert "未能完成命令风险评估" not in output
+
+def test_handle_tool_confirmation_required_uses_panel_payload():
+ shell = _DummyShell()
+    captured: list[tuple[object, ...]] = []
+ shell._get_user_confirmation = lambda remember_command=None, allow_remember=False: (
+ captured.append((remember_command, allow_remember)) or LLMCallbackResult.APPROVE
+ )
+ shell._display_security_panel = lambda data, panel_mode="confirm": captured.append(
+ ("panel", panel_mode, data.get("panel", {}))
+ )
+
+ event = LLMEvent(
+ event_type=LLMEventType.TOOL_CONFIRMATION_REQUIRED,
+ data={
+ "tool_name": "bash_exec",
+ "panel": {
+ "mode": "confirm",
+ "target": "echo hi",
+ "allow_remember": True,
+ "remember_key": "echo hi",
+ },
+ },
+ timestamp=time.time(),
+ )
+
+ result = handle_tool_confirmation_required(shell, event)
+
+ assert result == LLMCallbackResult.APPROVE
+ assert ("echo hi", True) in captured