diff --git a/src/agent/__init__.py b/src/agent/__init__.py index 2fb64d7..0054527 100644 --- a/src/agent/__init__.py +++ b/src/agent/__init__.py @@ -34,6 +34,17 @@ LEGACY_AGENT_TOOL_NAME, ONE_SHOT_BUILTIN_AGENT_TYPES, ) +from .filter_agents_by_mcp import ( + filter_agents_by_mcp_requirements, + has_required_mcp_servers, +) +from .load_agents_dir import ( + clear_agent_definitions_cache, + get_active_agents_from_list, + get_agent_definitions_with_overrides, +) +from .load_plugin_agents import load_plugin_agents +from .parse_agent_markdown import parse_agent_from_markdown from .prompt import ( format_agent_line, get_agent_prompt, @@ -97,4 +108,12 @@ # Subagent context "SubagentContextOverrides", "create_subagent_context", + # Custom-agent discovery + "clear_agent_definitions_cache", + "filter_agents_by_mcp_requirements", + "get_active_agents_from_list", + "get_agent_definitions_with_overrides", + "has_required_mcp_servers", + "load_plugin_agents", + "parse_agent_from_markdown", ] diff --git a/src/agent/agent_definitions.py b/src/agent/agent_definitions.py index f8cfe4a..e3b1398 100644 --- a/src/agent/agent_definitions.py +++ b/src/agent/agent_definitions.py @@ -10,7 +10,7 @@ from ..permissions.types import PermissionMode -AgentSource = Literal["built-in", "user", "plugin", "dynamic"] +AgentSource = Literal["built-in", "user", "project", "managed", "plugin", "dynamic"] @dataclass diff --git a/src/agent/filter_agents_by_mcp.py b/src/agent/filter_agents_by_mcp.py new file mode 100644 index 0000000..0a2e0ac --- /dev/null +++ b/src/agent/filter_agents_by_mcp.py @@ -0,0 +1,49 @@ +"""Filter agents by their declared ``required_mcp_servers``. + +Port of ``hasRequiredMcpServers`` / ``filterAgentsByMcpRequirements`` in +typescript/src/tools/AgentTool/loadAgentsDir.ts:228-254. + +Built-in agents are never dropped — they're trusted regardless of MCP +availability. 
def has_required_mcp_servers(
    agent: AgentDefinition,
    available_servers: Iterable[str],
) -> bool:
    """Return True iff every required pattern matches an available server.

    Matching is a case-insensitive substring test: the pattern ``slack``
    matches the server name ``MySlackServer``. An agent whose
    ``required_mcp_servers`` is ``None`` or empty always passes.

    Args:
        agent: Definition whose ``required_mcp_servers`` is checked.
        available_servers: Names of the currently available MCP servers.
            Iterated exactly once, so a one-shot iterator is acceptable.

    Returns:
        ``True`` when there are no requirements, or when each required
        pattern is a substring of at least one available server name;
        ``False`` otherwise (including when no servers are available).
    """
    required = agent.required_mcp_servers
    if not required:
        return True
    available_lower = [server.lower() for server in available_servers]
    for pattern in required:
        # Lowercase each pattern once instead of once per candidate server.
        needle = pattern.lower()
        if not any(needle in server for server in available_lower):
            return False
    return True
+ +Last-wins merge order on duplicate ``agent_type``: + [built-in, plugin, user, project, managed] + +A module-level cache keyed on cwd avoids re-walking the filesystem on +every prompt build. Call ``clear_agent_definitions_cache()`` after a +known on-disk change (e.g., the user edits ``~/.claude/agents/foo.md``) +to force a refresh. +""" +from __future__ import annotations + +import logging +import os +from typing import Iterable + +from src.agent.agent_definitions import AgentDefinition, get_built_in_agents +from src.agent.parse_agent_markdown import parse_agent_from_markdown +from src.utils.markdown_config_loader import ( + SOURCE_MANAGED, + SOURCE_PROJECT, + SOURCE_USER, + load_markdown_files_for_subdir, +) + +logger = logging.getLogger(__name__) + + +# Each disk source maps to the matching ``AgentSource`` literal so +# downstream consumers can distinguish a managed-policy agent from a +# user one (e.g., to enforce "managed cannot be overridden by user"). +_SOURCE_TO_AGENT_SOURCE: dict[str, str] = { + SOURCE_MANAGED: "managed", + SOURCE_USER: "user", + SOURCE_PROJECT: "project", +} + +# Priority order for last-wins merge — earlier entries are overridden by +# later ones if they share an agent_type. +_MERGE_ORDER: tuple[str, ...] = ( + "built-in", + "plugin", + SOURCE_USER, + SOURCE_PROJECT, + SOURCE_MANAGED, +) + + +# Cache is keyed on ``os.path.realpath(cwd)`` so symlinked / trailing-slash +# variants of the same project collapse into a single entry. The cache is +# session-bound; SDK callers that hop between unrelated projects can grow +# it unboundedly — acceptable for now since per-cwd discovery is cheap. +_agent_dir_cache: dict[str, list[AgentDefinition]] = {} + + +def _cache_key(cwd: str) -> str: + try: + return os.path.realpath(cwd) + except (OSError, ValueError): + return cwd + + +def clear_agent_definitions_cache() -> None: + """Drop the discovery cache. 
def _load_custom_agents(cwd: str) -> dict[str, list[AgentDefinition]]:
    """Group disk-discovered agents by their disk source label.

    Loads every ``agents`` markdown file visible from ``cwd`` via
    ``load_markdown_files_for_subdir`` and parses each one with
    ``parse_agent_from_markdown``. Files the parser rejects (it returns
    ``None``) are skipped.

    Args:
        cwd: Directory from which discovery starts.

    Returns:
        A dict with exactly the keys ``SOURCE_USER``, ``SOURCE_PROJECT`` and
        ``SOURCE_MANAGED``, each mapping to the agents parsed from that
        source, in loader order.
    """
    grouped: dict[str, list[AgentDefinition]] = {
        SOURCE_USER: [],
        SOURCE_PROJECT: [],
        SOURCE_MANAGED: [],
    }
    for md in load_markdown_files_for_subdir("agents", cwd):
        # An unrecognised source label is treated as user-level for BOTH the
        # AgentSource tag and the bucket. Indexing ``grouped`` with the raw
        # label would raise KeyError, which the caller's blanket handler
        # turns into "discard every custom agent".
        bucket = md.source if md.source in grouped else SOURCE_USER
        agent = parse_agent_from_markdown(
            file_path=md.file_path,
            frontmatter=md.frontmatter,
            body=md.body,
            source=_SOURCE_TO_AGENT_SOURCE.get(bucket, "user"),  # type: ignore[arg-type]
            base_dir=md.base_dir,
        )
        if agent is not None:
            grouped[bucket].append(agent)
    return grouped
+ """ + key = _cache_key(cwd) + cached = _agent_dir_cache.get(key) + if cached is not None: + return list(cached) + + try: + builtins = list(get_built_in_agents()) + try: + from src.agent.load_plugin_agents import load_plugin_agents + from src.plugins import get_loaded_plugins + plugin_agents = load_plugin_agents(get_loaded_plugins()) + except Exception: + logger.exception("plugin agent loading failed; continuing without plugin agents") + plugin_agents = [] + + custom = _load_custom_agents(cwd) + + sources_in_order: dict[str, list[AgentDefinition]] = { + "built-in": builtins, + "plugin": plugin_agents, + SOURCE_USER: custom[SOURCE_USER], + SOURCE_PROJECT: custom[SOURCE_PROJECT], + SOURCE_MANAGED: custom[SOURCE_MANAGED], + } + flat: list[AgentDefinition] = [] + for source_key in _MERGE_ORDER: + flat.extend(sources_in_order.get(source_key, [])) + + active = get_active_agents_from_list(flat) + _agent_dir_cache[key] = active + return list(active) + except Exception: + logger.exception("agent discovery failed; falling back to built-ins") + return list(get_built_in_agents()) diff --git a/src/agent/load_plugin_agents.py b/src/agent/load_plugin_agents.py new file mode 100644 index 0000000..f7cab03 --- /dev/null +++ b/src/agent/load_plugin_agents.py @@ -0,0 +1,109 @@ +"""Load agent definitions exposed by enabled plugins. + +Mirrors ``loadPluginAgents`` in +typescript/src/utils/plugins/loadPluginAgents.ts. For each enabled +plugin with a non-empty ``agents_paths``, walks the directory +recursively for ``*.md`` files, parses each via ``parse_agent_from_markdown``, +and namespaces the resulting ``agent_type`` as +``"::"`` so nested folders cannot collide. + +Plugin agents intentionally drop ``permission_mode``, ``hooks``, and +``mcp_servers`` from the parsed definition — those grant capabilities +beyond install-time trust and must come from user-controlled settings, +not third-party plugin manifests. 
def _scan_md_files(directory: str) -> list[tuple[str, str]]:
    """Recursively list ``*.md`` files under ``directory``.

    Args:
        directory: Root directory to scan.

    Returns:
        ``(file_path, namespace)`` pairs sorted by path. ``file_path`` is
        rooted at ``directory`` as given (it is absolute only when
        ``directory`` is). ``namespace`` is the file's parent directory
        relative to ``directory`` with its components joined by ``:`` — a
        file at ``foo/bar.md`` below ``directory`` yields ``"foo"``, and a
        file directly under ``directory`` yields ``""``. A missing or
        non-directory ``directory`` yields ``[]``.
    """
    base = Path(directory)
    if not base.is_dir():
        return []
    found: list[tuple[str, str]] = []
    try:
        for path in base.rglob("*.md"):
            # Directories whose name happens to end in ``.md`` also match.
            if not path.is_file():
                continue
            namespace = ":".join(path.parent.relative_to(base).parts)
            found.append((str(path), namespace))
    except OSError as exc:
        # PermissionError is an OSError subclass, so one clause covers both.
        # Keep whatever was collected before the failure rather than hiding
        # every agent in the tree, and leave a trace for debugging.
        logger.debug("scan of %s stopped early: %s", directory, exc)
    return sorted(found)
+ """ + agents: list[AgentDefinition] = [] + for plugin in plugins: + if not plugin.enabled or not plugin.agents_paths: + continue + for agents_dir in plugin.agents_paths: + for file_path, namespace in _scan_md_files(agents_dir): + try: + content = Path(file_path).read_text(encoding="utf-8") + except (OSError, PermissionError, UnicodeDecodeError) as exc: + logger.debug( + "plugin %s: failed to read %s: %s", + plugin.name, file_path, exc, + ) + continue + parsed = parse_frontmatter(content) + agent = parse_agent_from_markdown( + file_path=file_path, + frontmatter=parsed.frontmatter, + body=parsed.body, + source="plugin", + base_dir=plugin.path, + ) + if agent is None: + continue + namespaced = replace( + agent, + agent_type=_build_namespaced_agent_type( + plugin.name, namespace, agent.agent_type, + ), + source="plugin", + # Strip elevated capabilities: plugins cannot grant + # permission overrides, hooks, or MCP servers. + permission_mode=None, + hooks=None, + mcp_servers=None, + ) + agents.append(namespaced) + return agents diff --git a/src/agent/parse_agent_markdown.py b/src/agent/parse_agent_markdown.py new file mode 100644 index 0000000..1a9320a --- /dev/null +++ b/src/agent/parse_agent_markdown.py @@ -0,0 +1,226 @@ +"""Parse a markdown agent definition into an ``AgentDefinition``. + +Mirrors ``parseAgentFromMarkdown`` in typescript/src/tools/AgentTool/loadAgentsDir.ts. 
+ +Field mapping (frontmatter → AgentDefinition): + name → agent_type (defaults to filename stem) + description → when_to_use (required) + tools → tools (None / ['*'] both mean "all") + disallowed-tools → disallowed_tools + disallowedTools → disallowed_tools (camelCase alias) + model → model ('inherit' kept as a literal) + permission-mode → permission_mode + permissionMode → permission_mode (camelCase alias) + max-turns → max_turns + maxTurns → max_turns (camelCase alias) + background → background + color → color + memory → memory + omit-claude-md → omit_claude_md + omitClaudeMd → omit_claude_md (camelCase alias) + hooks → hooks + skills → skills + isolation → isolation + required-mcp-servers → required_mcp_servers + requiredMcpServers → required_mcp_servers (camelCase alias) + mcp-servers → mcp_servers + mcpServers → mcp_servers (camelCase alias) + effort → effort + +The markdown body becomes the agent's system prompt, returned by +``agent.get_system_prompt()``. Missing required fields produce ``None`` +with a debug log; the loader silently drops the file rather than crash. 
+""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from src.agent.agent_definitions import AgentDefinition, AgentSource +from src.utils.frontmatter_validators import ( + parse_effort_value, + parse_hooks, + parse_permission_mode, + parse_positive_int, + parse_string_list, +) + +logger = logging.getLogger(__name__) + + +AGENT_COLORS: frozenset[str] = frozenset( + {"red", "blue", "green", "yellow", "purple", "orange", "pink", "cyan"} +) +VALID_MEMORY_SCOPES: frozenset[str] = frozenset({"user", "project", "local"}) +VALID_ISOLATION_MODES: frozenset[str] = frozenset({"worktree", "remote"}) + + +def _first(d: dict[str, Any], *keys: str) -> Any: + """Return the first non-``None`` value among ``keys`` in ``d``.""" + for key in keys: + if key in d and d[key] is not None: + return d[key] + return None + + +def _parse_tools(value: Any) -> list[str] | None: + """Parse a tools list. ``None`` / ``['*']`` both mean "all tools". + + Returns ``None`` to signal all-tools (matches AgentDefinition.tools + semantics: ``None`` or ``['*']`` both mean unrestricted). 
def _parse_model(value: Any) -> str | None:
    """Return ``'inherit'``, a concrete model string, or ``None``.

    Args:
        value: Raw frontmatter value of the ``model`` key.

    Returns:
        ``"inherit"`` when the value equals ``inherit`` in any letter case,
        the stripped string for any other non-blank string, and ``None``
        for a missing, blank, or non-string value.
    """
    # YAML may coerce ``model: true`` or ``model: 123`` to non-string types.
    # Stringifying those would register a bogus model id such as "True", so
    # they are treated as "no model configured", the same way a non-string
    # ``name`` is rejected by ``parse_agent_from_markdown``.
    if not isinstance(value, str):
        return None
    s = value.strip()
    if not s:
        return None
    return "inherit" if s.lower() == "inherit" else s
+ """ + raw_name = _first(frontmatter, "name") + if raw_name is not None and not isinstance(raw_name, str): + # Reject non-string names (TS does the same). YAML can coerce + # ``name: true`` to a bool or ``name: 12345`` to an int; treating + # those as the agent_type would silently register agents that + # can't be invoked via ``@agent-True`` mention syntax. + logger.debug( + "agent file %s: 'name' must be a string (got %s); using filename", + file_path, type(raw_name).__name__, + ) + raw_name = None + agent_type = (raw_name or Path(file_path).stem).strip() + if not agent_type: + logger.debug("agent file %s has empty name; skipping", file_path) + return None + + description = _first(frontmatter, "description") + if not description or not isinstance(description, str): + logger.debug( + "agent file %s is missing required 'description'; skipping", + file_path, + ) + return None + when_to_use = description.replace("\\n", "\n") + + tools = _parse_tools(_first(frontmatter, "tools")) + + disallowed_raw = _first(frontmatter, "disallowed-tools", "disallowedTools") + disallowed_tools = parse_string_list(disallowed_raw) if disallowed_raw is not None else None + + model = _parse_model(_first(frontmatter, "model")) + permission_mode = parse_permission_mode( + _first(frontmatter, "permission-mode", "permissionMode") + ) + max_turns = parse_positive_int(_first(frontmatter, "max-turns", "maxTurns")) + background = _parse_bool(_first(frontmatter, "background")) + color = _parse_color(_first(frontmatter, "color")) + memory = _parse_memory(_first(frontmatter, "memory"), file_path=file_path) + omit_claude_md = _parse_bool(_first(frontmatter, "omit-claude-md", "omitClaudeMd")) + hooks = parse_hooks(_first(frontmatter, "hooks"), owner_name=f"agent {agent_type}") + skills = parse_string_list(_first(frontmatter, "skills")) + isolation = _parse_isolation( + _first(frontmatter, "isolation"), file_path=file_path + ) + required_mcp_servers = parse_string_list( + _first(frontmatter, 
"required-mcp-servers", "requiredMcpServers") + ) + mcp_servers_raw = _first(frontmatter, "mcp-servers", "mcpServers") + mcp_servers: list[Any] | None = ( + list(mcp_servers_raw) if isinstance(mcp_servers_raw, list) else None + ) + effort = parse_effort_value(_first(frontmatter, "effort")) + + body_text = body.strip() + + def _get_system_prompt(**_kwargs: Any) -> str: + return body_text + + return AgentDefinition( + agent_type=agent_type, + when_to_use=when_to_use, + tools=tools, + source=source, + base_dir=base_dir, + model=model, + permission_mode=permission_mode, + max_turns=max_turns, + background=background, + color=color, + memory=memory, + omit_claude_md=omit_claude_md, + disallowed_tools=disallowed_tools, + hooks=hooks, + skills=skills or None, + isolation=isolation, # type: ignore[arg-type] + required_mcp_servers=required_mcp_servers or None, + mcp_servers=mcp_servers, + effort=effort, + get_system_prompt=_get_system_prompt, + ) diff --git a/src/plugins/loader.py b/src/plugins/loader.py index fed89e1..4653e10 100644 --- a/src/plugins/loader.py +++ b/src/plugins/loader.py @@ -77,6 +77,23 @@ def load_plugin_from_directory( version=raw.get("version", "1.0.0"), ) + agents_paths: list[str] = [] + single = raw.get("agentsPath") + if isinstance(single, str) and single.strip(): + agents_paths.append(single.strip()) + multi = raw.get("agentsPaths") + if isinstance(multi, list): + for item in multi: + if isinstance(item, str) and item.strip(): + agents_paths.append(item.strip()) + + resolved_agents_paths: list[str] = [] + for entry in agents_paths: + p = Path(entry) + resolved = str(p) if p.is_absolute() else str(plugin_dir / entry) + if resolved not in resolved_agents_paths: + resolved_agents_paths.append(resolved) + plugin = LoadedPlugin( name=manifest.name, manifest=manifest, @@ -86,6 +103,7 @@ def load_plugin_from_directory( enabled=raw.get("enabled", True), hooks_config=raw.get("hooks"), mcp_servers=raw.get("mcp_servers"), + agents_paths=resolved_agents_paths, 
) return plugin diff --git a/src/plugins/types.py b/src/plugins/types.py index e21fca7..3d7a276 100644 --- a/src/plugins/types.py +++ b/src/plugins/types.py @@ -22,6 +22,10 @@ class LoadedPlugin: is_builtin: bool = False hooks_config: dict[str, Any] | None = None mcp_servers: dict[str, Any] | None = None + # Directories (relative to ``path`` or absolute) holding ``*.md`` agent + # definitions exposed by this plugin. Populated from the manifest's + # ``agentsPath`` (single) and ``agentsPaths`` (list) keys. + agents_paths: list[str] = field(default_factory=list) @dataclass diff --git a/src/repl/core.py b/src/repl/core.py index 1ef0405..e132645 100644 --- a/src/repl/core.py +++ b/src/repl/core.py @@ -315,7 +315,20 @@ def __init__( self.provider.model ) - self.tool_registry = build_default_registry(provider=self.provider) + # Late-binding closure: ``tool_context`` is built below, but the + # Agent tool's prompt builder won't read this until much later, + # so reading ``self.tool_context.mcp_clients`` lazily is safe. + def _get_mcp_servers_for_prompt() -> list[str]: + ctx = getattr(self, "tool_context", None) + if ctx is None: + return [] + clients = getattr(ctx, "mcp_clients", None) or {} + return list(clients.keys()) + + self.tool_registry = build_default_registry( + provider=self.provider, + get_available_mcp_servers=_get_mcp_servers_for_prompt, + ) self._engine_messages: list[Any] = [] from src.permissions.types import ToolPermissionContext @@ -1388,23 +1401,38 @@ def _format_tool_result_preview( def _available_agents(self) -> list[Any]: """Return the list of agent definitions that can be invoked via ``@agent-...``. - Pulls built-in agents and merges any extras registered on - ``tool_context.options.agent_definitions`` so that user/plugin agents - participate in the same ``@agent-`` lookup that the TypeScript - ``processAgentMentions`` uses. 
+ Calls the on-disk loader so user / project / managed / plugin + agents participate in the same ``@agent-`` lookup the + TypeScript ``processAgentMentions`` performs. ``options.agent_definitions`` + is still honored as an SDK-side override and supports both the + canonical ``{"active_agents": [...]}`` shape and a legacy flat + list/dict form so existing harnesses keep working. """ try: from src.agent.agent_definitions import get_built_in_agents + from src.agent.load_agents_dir import ( + get_agent_definitions_with_overrides, + ) except Exception: return [] - agents = list(get_built_in_agents()) - extra = getattr(getattr(self.tool_context, "options", None), "agent_definitions", None) + extra = getattr( + getattr(self.tool_context, "options", None), + "agent_definitions", + None, + ) if isinstance(extra, dict): - agents.extend(extra.values()) - elif isinstance(extra, list): - agents.extend(extra) - return agents + active = extra.get("active_agents") + if isinstance(active, list) and active: + return list(active) + + try: + cwd = str( + self.tool_context.cwd or self.tool_context.workspace_root + ) + return list(get_agent_definitions_with_overrides(cwd)) + except Exception: + return list(get_built_in_agents()) def _enqueue_prompt(self, text: str) -> None: """Append a user-typed prompt to the queue from any thread.""" @@ -2712,7 +2740,17 @@ def _handle_relogin(self): self.provider_name = provider # Rebuild tool registry with new provider so Agent tool works - self.tool_registry = build_default_registry(provider=self.provider) + def _get_mcp_servers_for_prompt() -> list[str]: + ctx = getattr(self, "tool_context", None) + if ctx is None: + return [] + clients = getattr(ctx, "mcp_clients", None) or {} + return list(clients.keys()) + + self.tool_registry = build_default_registry( + provider=self.provider, + get_available_mcp_servers=_get_mcp_servers_for_prompt, + ) self.console.print("[green]✓ Provider reinitialized. 
You can continue chatting![/green]\n") diff --git a/src/tool_system/defaults.py b/src/tool_system/defaults.py index 1b95594..53425bb 100644 --- a/src/tool_system/defaults.py +++ b/src/tool_system/defaults.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any, Callable + from .registry import ToolRegistry from .tools import ALL_STATIC_TOOLS, make_agent_tool, make_tool_search_tool @@ -8,10 +10,17 @@ def build_default_registry( *, include_user_tools: bool = True, provider: "Any | None" = None, + get_available_mcp_servers: Callable[[], list[str]] | None = None, ) -> ToolRegistry: registry = ToolRegistry() for tool in ALL_STATIC_TOOLS: registry.register(tool) - registry.register(make_agent_tool(registry, provider=provider)) + registry.register( + make_agent_tool( + registry, + provider=provider, + get_available_mcp_servers=get_available_mcp_servers, + ) + ) registry.register(make_tool_search_tool(registry)) return registry diff --git a/src/tool_system/tools/agent.py b/src/tool_system/tools/agent.py index 23d164b..65befae 100644 --- a/src/tool_system/tools/agent.py +++ b/src/tool_system/tools/agent.py @@ -12,6 +12,7 @@ import asyncio import logging +import os import sys import time from typing import Any @@ -29,6 +30,8 @@ find_agent_by_type, get_built_in_agents, ) +from src.agent.filter_agents_by_mcp import filter_agents_by_mcp_requirements +from src.agent.load_agents_dir import get_agent_definitions_with_overrides from src.agent.agent_tool_utils import ( extract_partial_result, finalize_agent_tool, @@ -112,6 +115,7 @@ def make_agent_tool( registry: ToolRegistry, provider: Any | None = None, + get_available_mcp_servers: Any | None = None, ) -> Tool: """Build the Agent tool. @@ -121,15 +125,31 @@ def make_agent_tool( registry: Tool registry providing the available tool pool. provider: BaseProvider for API calls. If None, agent execution is a no-op (useful for testing tool registration only). 
+ get_available_mcp_servers: Optional zero-arg callable returning the + currently-available MCP server names. Used by the prompt builder so + agents declaring ``required_mcp_servers`` not present in the live + inventory are hidden from the tool description (matching the + per-call resolver). When ``None`` the prompt advertises every + discovered agent unfiltered. """ def _get_agent_definitions(context: ToolContext) -> list[AgentDefinition]: - """Get agent definitions from context options or built-in defaults.""" + """Resolve agents visible to this call. + + SDK / test callers can pre-populate ``options.agent_definitions + ["active_agents"]`` to override discovery. Otherwise we walk the + managed / user / project ``agents`` directories via + ``get_agent_definitions_with_overrides`` and apply the MCP filter + keyed on the context's available MCP server inventory. + """ agent_defs = getattr(context.options, "agent_definitions", None) if agent_defs and isinstance(agent_defs, dict): active = agent_defs.get("active_agents") if active and isinstance(active, list): return active - return get_built_in_agents() + cwd = str(context.cwd or context.workspace_root) + agents = get_agent_definitions_with_overrides(cwd) + available_mcp = list(context.mcp_clients.keys()) if context.mcp_clients else [] + return filter_agents_by_mcp_requirements(agents, available_mcp) def _agent_call(tool_input: dict[str, Any], context: ToolContext) -> ToolResult: prompt = tool_input.get("prompt", "") @@ -594,8 +614,30 @@ def _runner(_stop_event: Any) -> None: ) def _agent_prompt() -> str: - """Build the prompt for the Agent tool.""" - agents = get_built_in_agents() + """Build the prompt for the Agent tool. + + Includes built-in agents plus any custom agents discovered on + disk so the model sees the full set of valid ``subagent_type`` + values in the tool description. 
When ``get_available_mcp_servers`` + was supplied at tool construction, the MCP filter runs here too — + otherwise the prompt advertises every discovered agent and the + per-call resolver enforces availability at spawn time. + """ + try: + agents = get_agent_definitions_with_overrides(os.getcwd()) + except Exception: + logger.exception("agent discovery failed in tool prompt; using built-ins") + agents = list(get_built_in_agents()) + if get_available_mcp_servers is not None: + try: + available = list(get_available_mcp_servers() or []) + except Exception: + logger.exception( + "get_available_mcp_servers raised; treating as no MCPs " + "available — agents requiring MCP servers will be hidden" + ) + available = [] + agents = filter_agents_by_mcp_requirements(agents, available) return get_agent_prompt(agents) def _map_result_to_api(result: Any, tool_use_id: str) -> dict[str, Any]: diff --git a/src/utils/frontmatter_validators.py b/src/utils/frontmatter_validators.py new file mode 100644 index 0000000..927fc31 --- /dev/null +++ b/src/utils/frontmatter_validators.py @@ -0,0 +1,173 @@ +"""Frontmatter field validators shared across agent / skill / output-style loaders. + +Mirrors helpers from typescript/src/utils/frontmatterParser.ts, +typescript/src/utils/effort.ts, and +typescript/src/utils/permissions/PermissionMode.ts. + +Each parser is fail-open: invalid values log a debug warning and return +``None`` (or an empty list) rather than raising, so a single malformed +frontmatter field never prevents a config file from loading. +""" +from __future__ import annotations + +import logging +from typing import Any + +from src.permissions.types import EXTERNAL_PERMISSION_MODES, ExternalPermissionMode + +logger = logging.getLogger(__name__) + +EFFORT_LEVELS: frozenset[str] = frozenset({"low", "medium", "high", "max"}) + + +def parse_effort_value(value: Any) -> str | None: + """Port of TS ``parseEffortValue`` (typescript/src/utils/effort.ts). 
+ + Accepts: + * One of ``EFFORT_LEVELS`` (case-insensitive) → returned lowercased. + * An int or numeric string → returned as a stringified integer. + Anything else logs a warning and returns ``None``. + """ + if value is None or value == "": + return None + if isinstance(value, bool): + logger.warning("frontmatter effort=%r is not a valid level", value) + return None + if isinstance(value, int): + return str(value) + s = str(value).strip().lower() + if not s: + return None + if s in EFFORT_LEVELS: + return s + try: + return str(int(s, 10)) + except (TypeError, ValueError): + pass + logger.warning( + "frontmatter effort=%r is not a valid level (expected one of %s or an integer)", + value, + sorted(EFFORT_LEVELS), + ) + return None + + +def parse_positive_int(value: Any) -> int | None: + """Port of TS ``parsePositiveIntFromFrontmatter``. + + Accepts an int or numeric string. Returns ``None`` for missing, + non-positive, or non-numeric values. + """ + if value is None or value == "": + return None + if isinstance(value, bool): + return None + if isinstance(value, int): + return value if value > 0 else None + try: + n = int(str(value).strip(), 10) + except (TypeError, ValueError): + return None + return n if n > 0 else None + + +def parse_permission_mode(value: Any) -> ExternalPermissionMode | None: + """Validate a frontmatter ``permission-mode`` / ``permissionMode`` value. + + Only the external modes (``default | plan | acceptEdits | + bypassPermissions | dontAsk``) are accepted from user frontmatter — the + internal ``auto`` / ``bubble`` modes are runtime-only and must not be + declared on disk. 
def parse_string_list(value: Any, *, csv_ok: bool = True) -> list[str]:
    """Coerce a YAML frontmatter value into a ``list[str]``.

    Accepts:
      * A list (or tuple) → string entries are stripped and kept; blank and
        non-string entries are skipped.
      * A scalar string → split on commas when ``csv_ok=True``, else
        wrapped in a single-element list.
      * ``None``, empty, or any other type → ``[]``.

    Args:
        value: Raw frontmatter value.
        csv_ok: Whether a scalar string is split on commas.

    Returns:
        The stripped, non-empty string entries; never ``None``.
    """
    if value is None:
        return []
    if isinstance(value, str):
        s = value.strip()
        if not s:
            return []
        if csv_ok:
            return [part.strip() for part in s.split(",") if part.strip()]
        return [s]
    if isinstance(value, (list, tuple)):
        # Skip non-string entries instead of stringifying them: ``str(None)``
        # would otherwise surface as a literal "None" entry.
        return [
            item.strip()
            for item in value
            if isinstance(item, str) and item.strip()
        ]
    # Bools, numbers and mappings have no meaningful list-of-names reading;
    # stringifying (and comma-splitting) them only yields garbage entries.
    return []
+ """ + if value is None: + return None + if not isinstance(value, dict): + logger.debug( + "%s hooks: expected dict, got %s; dropping", + owner_name, type(value).__name__, + ) + return None + + try: + from src.hooks.hook_types import ALL_HOOK_EVENTS + valid_events = set(ALL_HOOK_EVENTS) + except Exception: + valid_events = set() + + for event_name, matchers in value.items(): + if valid_events and event_name not in valid_events: + logger.debug( + "%s hooks: unknown event %r; dropping all hooks", + owner_name, event_name, + ) + return None + if not isinstance(matchers, list): + logger.debug( + "%s hooks.%s: expected list of matchers, got %s", + owner_name, event_name, type(matchers).__name__, + ) + return None + for matcher in matchers: + if not isinstance(matcher, dict): + logger.debug( + "%s hooks.%s: matcher must be a dict", + owner_name, event_name, + ) + return None + inner = matcher.get("hooks") + if not isinstance(inner, list): + logger.debug( + "%s hooks.%s.hooks: required list missing or wrong type", + owner_name, event_name, + ) + return None + for cmd in inner: + if not isinstance(cmd, dict) or "type" not in cmd: + logger.debug( + "%s hooks.%s.hooks[]: each entry needs a `type` field", + owner_name, event_name, + ) + return None + return value diff --git a/src/utils/markdown_config_loader.py b/src/utils/markdown_config_loader.py new file mode 100644 index 0000000..604b9bb --- /dev/null +++ b/src/utils/markdown_config_loader.py @@ -0,0 +1,179 @@ +"""Generic markdown-config discovery for ``.claude/`` directories. + +Port of typescript/src/utils/markdownConfigLoader.ts. Walks managed, +user, and project directories (and ``.openclaude`` variants) to collect +``*.md`` files for a given subdir (``agents`` today; ``commands`` / +``output-styles`` later). + +Loader semantics: + * Managed dir: ``$CLAUDE_MANAGED_CONFIG_DIR/.claude/`` (default + ``/etc/claude``). + * User dir: ``$CLAUDE_CONFIG_DIR/`` (default ``~/.claude``). 
+ * Project dirs: walk ``cwd`` upward, stopping at the nearest ``.git`` + ancestor (or ``$HOME`` outside a git repo), collecting both + ``.claude/`` and ``.openclaude/`` at every level. + +Files are deduplicated by realpath so a symlinked ``~/.claude`` inside a +project tree doesn't produce duplicate entries. +""" +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from src.skills.frontmatter import parse_frontmatter + +logger = logging.getLogger(__name__) + +# Source labels used by downstream consumers to apply merge priority. +SOURCE_MANAGED = "managed" +SOURCE_USER = "user" +SOURCE_PROJECT = "project" + + +@dataclass(frozen=True) +class MarkdownFile: + file_path: str + frontmatter: dict[str, Any] + body: str + source: str + base_dir: str + + +def _get_global_config_dir() -> Path: + """Return ``$CLAUDE_CONFIG_DIR`` or ``~/.claude`` (resolved).""" + env_override = os.environ.get("CLAUDE_CONFIG_DIR") + if env_override: + return Path(env_override).expanduser().resolve() + return (Path.home() / ".claude").resolve() + + +def _get_managed_file_path() -> Path: + """Return ``$CLAUDE_MANAGED_CONFIG_DIR`` or ``/etc/claude``.""" + env_override = os.environ.get("CLAUDE_MANAGED_CONFIG_DIR") + if env_override: + return Path(env_override).expanduser().resolve() + return Path("/etc/claude") + + +def _find_git_root(cwd: Path) -> Path | None: + """Return the nearest ancestor containing ``.git``, or ``None``. + + Matches the boundary semantics of TS ``findGitRoot``: stops as soon as + a ``.git`` entry (file or directory) is found. Outside any git repo, + returns ``None`` so the walker falls back to home. + """ + for ancestor in (cwd, *cwd.parents): + if (ancestor / ".git").exists(): + return ancestor + return None + + +def _get_project_subdir_paths(cwd: str, subdir: str) -> list[str]: + """Walk from ``cwd`` upward, collecting ``.claude/`` per level. 
+ + Generalization of src/skills/loader.py:_get_project_skills_dirs that + matches the TS ``getProjectDirsUpToHome`` semantics: when ``cwd`` is + inside a git repository, stop at the repo root (so parent-of-repo + ``.claude/`` directories don't leak into the project). When not in a + git repo, walk all the way to ``$HOME``. For each visited directory + both ``.claude/`` and ``.openclaude/`` are appended so + projects using either convention are discovered. + """ + current = Path(cwd).expanduser().resolve() + home = Path.home().resolve() + git_root = _find_git_root(current) + dirs: list[str] = [] + + while True: + for config_dir_name in (".claude", ".openclaude"): + candidate = current / config_dir_name / subdir + dirs.append(str(candidate)) + if current == home or current.parent == current: + break + if git_root is not None and current == git_root: + break + current = current.parent + + return list(reversed(dirs)) + + +def _list_markdown_files(directory: str | Path) -> list[str]: + """Recursively list ``*.md`` files under ``directory``. + + Returns ``[]`` for missing or inaccessible directories. Symlinks are + followed (``Path.rglob`` follows them by default for the file scan, + not for cycle detection — broken symlinks are skipped silently when + we try to read them). 
+ """ + base = Path(directory) + if not base.is_dir(): + return [] + try: + return sorted(str(p) for p in base.rglob("*.md") if p.is_file()) + except (OSError, PermissionError): + return [] + + +def _read_and_parse(file_path: str) -> tuple[dict[str, Any], str] | None: + """Read a markdown file and return ``(frontmatter, body)`` or ``None``.""" + try: + content = Path(file_path).read_text(encoding="utf-8") + except (OSError, PermissionError, UnicodeDecodeError) as exc: + logger.debug("failed to read markdown file %s: %s", file_path, exc) + return None + result = parse_frontmatter(content) + return result.frontmatter, result.body + + +def _file_identity(file_path: str) -> str | None: + """Return ``os.path.realpath`` for dedup; ``None`` on errors (fail open).""" + try: + return os.path.realpath(file_path) + except (OSError, ValueError): + return None + + +def load_markdown_files_for_subdir(subdir: str, cwd: str) -> list[MarkdownFile]: + """Discover all markdown config files for ``subdir`` across sources. + + Returns the merged list in priority order: managed → user → project. + First-seen realpath wins (later duplicates are dropped). The caller is + responsible for applying source-priority overrides on parsed entries. 
+ """ + managed_dir = str(_get_managed_file_path() / ".claude" / subdir) + user_dir = str(_get_global_config_dir() / subdir) + project_dirs = _get_project_subdir_paths(cwd, subdir) + + seen: set[str] = set() + results: list[MarkdownFile] = [] + + def _collect(directory: str, source: str) -> None: + for path in _list_markdown_files(directory): + identity = _file_identity(path) or path + if identity in seen: + continue + seen.add(identity) + parsed = _read_and_parse(path) + if parsed is None: + continue + frontmatter, body = parsed + results.append( + MarkdownFile( + file_path=path, + frontmatter=frontmatter, + body=body, + source=source, + base_dir=directory, + ) + ) + + _collect(managed_dir, SOURCE_MANAGED) + _collect(user_dir, SOURCE_USER) + for project_dir in project_dirs: + _collect(project_dir, SOURCE_PROJECT) + + return results diff --git a/tests/agent/__init__.py b/tests/agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agent/test_load_agents_dir.py b/tests/agent/test_load_agents_dir.py new file mode 100644 index 0000000..ea55060 --- /dev/null +++ b/tests/agent/test_load_agents_dir.py @@ -0,0 +1,282 @@ +"""Tests for custom subagent discovery (src/agent/load_agents_dir.py). + +Mirrors the headline scenarios from +typescript/src/tools/AgentTool/loadAgentsDir.test.ts plus the +Python-specific cache + MCP filter behaviours. 
+""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from src.agent.agent_definitions import AgentDefinition, get_built_in_agents +from src.agent.filter_agents_by_mcp import filter_agents_by_mcp_requirements +from src.agent.load_agents_dir import ( + clear_agent_definitions_cache, + get_agent_definitions_with_overrides, +) + + +def _write_agent( + path: Path, + *, + name: str = "critic", + description: str = "Test critic agent", + extra_frontmatter: str = "", + body: str = "You are a critic.", +) -> Path: + path.parent.mkdir(parents=True, exist_ok=True) + frontmatter_parts = [f"name: {name}", f"description: {description}"] + if extra_frontmatter: + frontmatter_parts.append(extra_frontmatter) + content = ( + "---\n" + + "\n".join(frontmatter_parts) + + "\n---\n" + + body + + "\n" + ) + path.write_text(content, encoding="utf-8") + return path + + +@pytest.fixture(autouse=True) +def _isolated_config_dirs(tmp_path, monkeypatch): + user_dir = tmp_path / "claude_home" + managed_dir = tmp_path / "managed" + user_dir.mkdir() + managed_dir.mkdir() + monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(user_dir)) + monkeypatch.setenv("CLAUDE_MANAGED_CONFIG_DIR", str(managed_dir)) + clear_agent_definitions_cache() + yield {"user": user_dir, "managed": managed_dir, "tmp_path": tmp_path} + clear_agent_definitions_cache() + + +def _by_type(agents: list[AgentDefinition]) -> dict[str, AgentDefinition]: + return {a.agent_type: a for a in agents} + + +def test_user_dir_agent_loaded(_isolated_config_dirs, tmp_path): + """An agent in ~/.claude/agents/ is discoverable.""" + user_dir = _isolated_config_dirs["user"] + _write_agent(user_dir / "agents" / "critic.md") + agents = get_agent_definitions_with_overrides(str(tmp_path)) + assert "critic" in _by_type(agents) + assert _by_type(agents)["critic"].when_to_use == "Test critic agent" + + +def test_project_dir_walk_up_to_home(_isolated_config_dirs, tmp_path): + """A project agent is found when cwd is a nested 
subdir of the project.""" + proj = tmp_path / "proj" + nested_cwd = proj / "src" / "sub" + nested_cwd.mkdir(parents=True) + _write_agent( + proj / ".claude" / "agents" / "reviewer.md", + name="reviewer", + description="Project reviewer", + ) + agents = get_agent_definitions_with_overrides(str(nested_cwd)) + assert "reviewer" in _by_type(agents) + + +def test_project_overrides_user_same_agent_type(_isolated_config_dirs, tmp_path): + """A project-defined agent wins over a same-named user-defined one.""" + user_dir = _isolated_config_dirs["user"] + proj = tmp_path / "proj" + proj.mkdir() + _write_agent( + user_dir / "agents" / "foo.md", + name="foo", + description="from user", + ) + _write_agent( + proj / ".claude" / "agents" / "foo.md", + name="foo", + description="from project", + ) + agents = get_agent_definitions_with_overrides(str(proj)) + assert _by_type(agents)["foo"].when_to_use == "from project" + + +def test_managed_wins_over_project(_isolated_config_dirs, tmp_path): + """Managed/policy source has the highest priority among custom sources.""" + managed_dir = _isolated_config_dirs["managed"] + proj = tmp_path / "proj" + proj.mkdir() + _write_agent( + proj / ".claude" / "agents" / "foo.md", + name="foo", + description="from project", + ) + _write_agent( + managed_dir / ".claude" / "agents" / "foo.md", + name="foo", + description="from managed", + ) + agents = get_agent_definitions_with_overrides(str(proj)) + assert _by_type(agents)["foo"].when_to_use == "from managed" + + +def test_builtin_overridden_by_user_same_type(_isolated_config_dirs, tmp_path): + """A custom agent named Explore overrides the built-in Explore.""" + user_dir = _isolated_config_dirs["user"] + _write_agent( + user_dir / "agents" / "explore.md", + name="Explore", + description="my custom explore", + ) + agents = get_agent_definitions_with_overrides(str(tmp_path)) + explore = _by_type(agents).get("Explore") + assert explore is not None + assert explore.when_to_use == "my custom explore" + 
assert explore.source == "user" + + +def test_malformed_frontmatter_does_not_crash(_isolated_config_dirs, tmp_path): + """A file with broken YAML is silently dropped; siblings still load.""" + user_dir = _isolated_config_dirs["user"] + agents_dir = user_dir / "agents" + agents_dir.mkdir(parents=True) + # Broken frontmatter (unterminated quote) + (agents_dir / "bad.md").write_text( + '---\nname: bad\ndescription: "unterminated\n---\nbody\n', + encoding="utf-8", + ) + _write_agent(agents_dir / "good.md", name="good", description="good one") + agents = get_agent_definitions_with_overrides(str(tmp_path)) + types = _by_type(agents) + assert "good" in types + assert "bad" not in types + + +def test_builtin_priority_preserved_when_no_collision(_isolated_config_dirs, tmp_path): + """With no custom agents, built-ins are returned unchanged.""" + agents = get_agent_definitions_with_overrides(str(tmp_path)) + builtin_types = {a.agent_type for a in get_built_in_agents()} + discovered_types = _by_type(agents).keys() + assert builtin_types.issubset(discovered_types) + + +def test_mcp_filter_drops_agent_missing_required_server(_isolated_config_dirs, tmp_path): + """An agent declaring required-mcp-servers is filtered out when unavailable.""" + user_dir = _isolated_config_dirs["user"] + _write_agent( + user_dir / "agents" / "slack-bot.md", + name="slack-bot", + description="needs slack", + extra_frontmatter="required-mcp-servers:\n - slack", + ) + agents = get_agent_definitions_with_overrides(str(tmp_path)) + assert "slack-bot" in _by_type(agents) + + filtered_no_mcp = filter_agents_by_mcp_requirements(agents, []) + assert "slack-bot" not in _by_type(filtered_no_mcp) + + filtered_with_mcp = filter_agents_by_mcp_requirements(agents, ["slack"]) + assert "slack-bot" in _by_type(filtered_with_mcp) + + +def test_mcp_filter_keeps_builtins_regardless(_isolated_config_dirs, tmp_path): + """Built-ins survive the MCP filter even when the available set is empty.""" + agents = 
get_agent_definitions_with_overrides(str(tmp_path)) + filtered = filter_agents_by_mcp_requirements(agents, []) + builtin_types = {a.agent_type for a in get_built_in_agents()} + discovered = _by_type(filtered).keys() + assert builtin_types.issubset(discovered) + + +def test_cache_invalidation_picks_up_new_file(_isolated_config_dirs, tmp_path): + """The cache hides new files until clear_agent_definitions_cache() is called.""" + user_dir = _isolated_config_dirs["user"] + _write_agent(user_dir / "agents" / "first.md", name="first", description="first") + first_call = get_agent_definitions_with_overrides(str(tmp_path)) + assert "first" in _by_type(first_call) + + _write_agent(user_dir / "agents" / "second.md", name="second", description="second") + stale_call = get_agent_definitions_with_overrides(str(tmp_path)) + assert "second" not in _by_type(stale_call) + + clear_agent_definitions_cache() + fresh_call = get_agent_definitions_with_overrides(str(tmp_path)) + assert "second" in _by_type(fresh_call) + + +def test_git_root_boundary_blocks_parent_dir_leak(_isolated_config_dirs, tmp_path): + """Agents in dirs above the project's git-root must not leak in. 
+ + Layout: + tmp_path/parent/.claude/agents/leaky.md (must NOT appear) + tmp_path/parent/proj/.git/ (the project's git root) + tmp_path/parent/proj/src/ (cwd) + """ + parent = tmp_path / "parent" + proj = parent / "proj" + cwd = proj / "src" + cwd.mkdir(parents=True) + (proj / ".git").mkdir() + _write_agent(parent / ".claude" / "agents" / "leaky.md", name="leaky", description="parent") + agents = get_agent_definitions_with_overrides(str(cwd)) + assert "leaky" not in _by_type(agents) + + +def test_project_inside_git_root_still_loads(_isolated_config_dirs, tmp_path): + """The git-root boundary stops the walk AT the root, not before it.""" + proj = tmp_path / "proj" + (proj / ".git").mkdir(parents=True) + cwd = proj / "src" / "nested" + cwd.mkdir(parents=True) + _write_agent(proj / ".claude" / "agents" / "ok.md", name="ok", description="root-level agent") + agents = get_agent_definitions_with_overrides(str(cwd)) + assert "ok" in _by_type(agents) + + +def test_managed_source_label_preserved(_isolated_config_dirs, tmp_path): + """Agents loaded from the managed dir keep ``source='managed'``.""" + managed_dir = _isolated_config_dirs["managed"] + _write_agent( + managed_dir / ".claude" / "agents" / "policy.md", + name="policy", + description="from managed", + ) + agents = get_agent_definitions_with_overrides(str(tmp_path)) + policy = _by_type(agents).get("policy") + assert policy is not None + assert policy.source == "managed" + + +def test_at_agent_mention_resolves_custom_agent(_isolated_config_dirs, tmp_path): + """``@agent-`` mention syntax sees on-disk agents. + + Regression for the REPL ``_available_agents`` wire-up: the old code + extended the agent list with ``dict.values()`` from the wrapping + ``{"active_agents": [...]}`` shape, producing a single nested list + that ``expand_agent_mentions`` couldn't introspect — every + @agent- token was silently dropped. 
+ """ + from src.command_system.input_processing import expand_agent_mentions + + user_dir = _isolated_config_dirs["user"] + _write_agent( + user_dir / "agents" / "critic.md", + name="critic", + description="reviewer", + ) + agents = get_agent_definitions_with_overrides(str(tmp_path)) + attachments = expand_agent_mentions("@agent-critic please review", agents) + assert {"kind": "agent_mention", "agent_type": "critic"} in attachments + + +def test_cache_dedupes_path_aliases(_isolated_config_dirs, tmp_path): + """``cwd`` with a trailing slash hits the same cache entry.""" + proj = tmp_path / "proj" + proj.mkdir() + _write_agent( + proj / ".claude" / "agents" / "x.md", + name="x", + description="x", + ) + a = get_agent_definitions_with_overrides(str(proj)) + b = get_agent_definitions_with_overrides(str(proj) + "/") + assert [agent.agent_type for agent in a] == [agent.agent_type for agent in b] diff --git a/tests/agent/test_load_plugin_agents.py b/tests/agent/test_load_plugin_agents.py new file mode 100644 index 0000000..c8f04ec --- /dev/null +++ b/tests/agent/test_load_plugin_agents.py @@ -0,0 +1,96 @@ +"""Tests for plugin agent discovery + namespacing (src/agent/load_plugin_agents.py).""" +from __future__ import annotations + +from pathlib import Path + +from src.agent.load_plugin_agents import load_plugin_agents +from src.plugins.types import LoadedPlugin, PluginManifest + + +def _make_plugin(plugin_dir: Path, name: str = "myplugin") -> LoadedPlugin: + return LoadedPlugin( + name=name, + manifest=PluginManifest(name=name), + path=str(plugin_dir), + source="user", + enabled=True, + agents_paths=[str(plugin_dir / "agents")], + ) + + +def _write(path: Path, body: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(body, encoding="utf-8") + + +def test_plugin_agent_is_namespaced_with_plugin_name(tmp_path): + plugin_dir = tmp_path / "plug" + _write( + plugin_dir / "agents" / "review.md", + "---\nname: review\ndescription: Review 
code\n---\nbody\n", + ) + agents = load_plugin_agents([_make_plugin(plugin_dir)]) + types = {a.agent_type for a in agents} + assert "myplugin:review" in types + + +def test_nested_plugin_agents_get_distinct_namespaces(tmp_path): + """``foo/x.md`` and ``bar/x.md`` must NOT collide into ``plugin:x``.""" + plugin_dir = tmp_path / "plug" + _write( + plugin_dir / "agents" / "foo" / "x.md", + "---\nname: x\ndescription: foo-x\n---\nbody\n", + ) + _write( + plugin_dir / "agents" / "bar" / "x.md", + "---\nname: x\ndescription: bar-x\n---\nbody\n", + ) + agents = load_plugin_agents([_make_plugin(plugin_dir)]) + types = {a.agent_type for a in agents} + assert "myplugin:foo:x" in types + assert "myplugin:bar:x" in types + + +def test_plugin_agents_strip_elevated_capabilities(tmp_path): + """Plugin agents cannot declare permission_mode, hooks, or mcp_servers.""" + plugin_dir = tmp_path / "plug" + _write( + plugin_dir / "agents" / "evil.md", + ( + "---\n" + "name: evil\n" + "description: tries to elevate\n" + "permission-mode: bypassPermissions\n" + "mcp-servers:\n - foo\n" + "---\n" + "body\n" + ), + ) + agents = load_plugin_agents([_make_plugin(plugin_dir)]) + assert len(agents) == 1 + evil = agents[0] + assert evil.source == "plugin" + assert evil.permission_mode is None + assert evil.mcp_servers is None + assert evil.hooks is None + + +def test_disabled_plugin_contributes_no_agents(tmp_path): + plugin_dir = tmp_path / "plug" + _write( + plugin_dir / "agents" / "a.md", + "---\nname: a\ndescription: a\n---\nbody\n", + ) + plugin = _make_plugin(plugin_dir) + plugin.enabled = False + agents = load_plugin_agents([plugin]) + assert agents == [] + + +def test_plugin_with_no_agents_paths_contributes_nothing(tmp_path): + plugin_dir = tmp_path / "plug" + plugin_dir.mkdir() + plugin = _make_plugin(plugin_dir) + plugin.agents_paths = [] + agents = load_plugin_agents([plugin]) + assert agents == [] diff --git a/tests/agent/test_parse_agent_markdown.py 
b/tests/agent/test_parse_agent_markdown.py new file mode 100644 index 0000000..13d92e0 --- /dev/null +++ b/tests/agent/test_parse_agent_markdown.py @@ -0,0 +1,149 @@ +"""Tests for src/agent/parse_agent_markdown.py.""" +from __future__ import annotations + +from textwrap import dedent + +from src.agent.parse_agent_markdown import parse_agent_from_markdown +from src.skills.frontmatter import parse_frontmatter + + +def _parse(content: str, *, file_path: str = "/tmp/some.md"): + result = parse_frontmatter(content) + return parse_agent_from_markdown( + file_path=file_path, + frontmatter=result.frontmatter, + body=result.body, + source="user", + base_dir="/tmp", + ) + + +def test_parses_all_frontmatter_fields_to_agent_definition(): + content = dedent( + """\ + --- + name: kitchen-sink + description: An agent with every field set + tools: + - Read + - Grep + disallowed-tools: + - Write + model: claude-sonnet-4-6 + permission-mode: acceptEdits + max-turns: 12 + background: true + color: blue + memory: project + omit-claude-md: true + skills: + - my-skill + isolation: worktree + required-mcp-servers: + - slack + mcp-servers: + - some-server + effort: high + --- + You are the kitchen sink agent. 
+ """ + ) + agent = _parse(content) + assert agent is not None + assert agent.agent_type == "kitchen-sink" + assert agent.when_to_use == "An agent with every field set" + assert agent.tools == ["Read", "Grep"] + assert agent.disallowed_tools == ["Write"] + assert agent.model == "claude-sonnet-4-6" + assert agent.permission_mode == "acceptEdits" + assert agent.max_turns == 12 + assert agent.background is True + assert agent.color == "blue" + assert agent.memory == "project" + assert agent.omit_claude_md is True + assert agent.skills == ["my-skill"] + assert agent.isolation == "worktree" + assert agent.required_mcp_servers == ["slack"] + assert agent.mcp_servers == ["some-server"] + assert agent.effort == "high" + + +def test_filename_used_when_name_field_absent(): + content = dedent( + """\ + --- + description: Description-only agent + --- + body + """ + ) + agent = _parse(content, file_path="/tmp/critic.md") + assert agent is not None + assert agent.agent_type == "critic" + + +def test_body_becomes_system_prompt(): + body_text = "You are a critic.\nYou give critical reviews." 
+ content = "---\nname: c\ndescription: x\n---\n" + body_text + "\n" + agent = _parse(content) + assert agent is not None + assert agent.get_system_prompt() == body_text + + +def test_invalid_permission_mode_dropped_not_crashed(): + content = dedent( + """\ + --- + name: looseperms + description: tries to set garbage perms + permission-mode: nope + --- + body + """ + ) + agent = _parse(content) + assert agent is not None + assert agent.permission_mode is None + assert agent.agent_type == "looseperms" + + +def test_missing_description_returns_none(): + content = "---\nname: nope\n---\nbody\n" + agent = _parse(content) + assert agent is None + + +def test_tools_star_means_all(): + content = "---\nname: c\ndescription: x\ntools:\n - '*'\n---\nbody\n" + agent = _parse(content) + assert agent is not None + assert agent.tools is None + + +def test_non_string_name_falls_back_to_filename(): + """``name: true`` (YAML coerces to bool) must not register as agent_type 'True'.""" + content = "---\nname: true\ndescription: oops\n---\nbody\n" + agent = _parse(content, file_path="/tmp/realname.md") + assert agent is not None + assert agent.agent_type == "realname" + + +def test_camelcase_aliases_supported(): + """camelCase frontmatter keys parse the same as kebab-case.""" + content = ( + "---\n" + "name: cc\n" + "description: d\n" + "permissionMode: acceptEdits\n" + "maxTurns: 5\n" + "disallowedTools:\n - Write\n" + "requiredMcpServers:\n - slack\n" + "---\n" + "body\n" + ) + agent = _parse(content) + assert agent is not None + assert agent.permission_mode == "acceptEdits" + assert agent.max_turns == 5 + assert agent.disallowed_tools == ["Write"] + assert agent.required_mcp_servers == ["slack"] diff --git a/tests/agent/test_repl_available_agents.py b/tests/agent/test_repl_available_agents.py new file mode 100644 index 0000000..e64bcb2 --- /dev/null +++ b/tests/agent/test_repl_available_agents.py @@ -0,0 +1,102 @@ +"""Direct regression test for the REPL ``_available_agents`` 
wire-up. + +Targets the exact bug class the original ``_available_agents()`` had: + * dict-flattening — extending a list with ``dict.values()`` produced + ``[[agent1, agent2, ...]]`` (a single nested list). + * SDK override vs on-disk discovery — both paths must return a flat + ``list[AgentDefinition]`` whose entries each expose ``agent_type``. + +This is a unit test against an isolated, REPL-shaped object so it +catches the unwrap mistake without needing to boot the full REPL +(which has heavy provider / I/O dependencies). +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest + +from src.agent.agent_definitions import AgentDefinition +from src.agent.load_agents_dir import clear_agent_definitions_cache +from src.repl.core import ClawcodexREPL + + +def _write_user_agent( + user_dir: Path, name: str = "critic", description: str = "reviewer" +) -> None: + target = user_dir / "agents" / f"{name}.md" + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text( + f"---\nname: {name}\ndescription: {description}\n---\nbody\n", + encoding="utf-8", + ) + + +@dataclass +class _FakeREPL: + """Minimal stand-in for ``ClawcodexREPL`` that exposes only what + ``ClawcodexREPL._available_agents`` reads from ``self``. 
+ """ + tool_context: Any + + _available_agents = ClawcodexREPL._available_agents + + +@pytest.fixture(autouse=True) +def _isolate_disk(tmp_path, monkeypatch): + user_dir = tmp_path / "claude_home" + user_dir.mkdir() + monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(user_dir)) + monkeypatch.setenv("CLAUDE_MANAGED_CONFIG_DIR", str(tmp_path / "noop")) + clear_agent_definitions_cache() + yield user_dir + clear_agent_definitions_cache() + + +def _make_repl(workspace: Path) -> _FakeREPL: + options = SimpleNamespace(agent_definitions={}) + ctx = SimpleNamespace( + cwd=workspace, + workspace_root=workspace, + options=options, + ) + return _FakeREPL(tool_context=ctx) + + +def test_discovery_path_returns_flat_agent_list(_isolate_disk, tmp_path): + """No SDK override → loader runs; result is a flat list of AgentDefinitions.""" + _write_user_agent(_isolate_disk) + repl = _make_repl(tmp_path) + agents = repl._available_agents() + assert isinstance(agents, list) + for agent in agents: + assert isinstance(agent, AgentDefinition), ( + f"expected AgentDefinition, got {type(agent).__name__} — " + "this is the nested-list bug class" + ) + assert any(a.agent_type == "critic" for a in agents) + + +def test_sdk_override_returns_flat_agent_list(_isolate_disk, tmp_path): + """``options.agent_definitions["active_agents"]`` short-circuits and is returned flat.""" + sentinel = AgentDefinition( + agent_type="sentinel-agent", + when_to_use="sdk-injected", + get_system_prompt=lambda **_kw: "", + ) + repl = _make_repl(tmp_path) + repl.tool_context.options.agent_definitions = {"active_agents": [sentinel]} + agents = repl._available_agents() + assert agents == [sentinel] + + +def test_empty_active_agents_falls_back_to_discovery(_isolate_disk, tmp_path): + """Empty SDK override → discovery still runs; built-ins remain available.""" + repl = _make_repl(tmp_path) + repl.tool_context.options.agent_definitions = {"active_agents": []} + agents = repl._available_agents() + types = {a.agent_type for a in 
agents} + assert "general-purpose" in types # built-in survives