From d7963600105ea5ccd93c275219632a579227666a Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Thu, 16 Apr 2026 19:46:16 +0800 Subject: [PATCH 1/7] feat(super_agent): Add new super agent module and related APIs This commit introduces the `super_agent` module with its core components including models, API handlers, clients, and streaming functionality. It also adds corresponding unit tests to ensure proper functionality of the new features. Key changes include: - Added `model.py`, `api/` directory with control and data modules - Introduced `agent.py`, `client.py`, `stream.py`, and template files - Updated `tool.py` to improve parameter description handling - Modified existing client methods in `agent_runtime/client.py` and `__client_async_template.py` - Extended `agent_runtime/model.py` to support SUPER_AGENT protocol - Improved test cases for better coverage - Updated package exports in `__init__.py` to expose new super agent classes and functions Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/__init__.py | 29 + .../agent_runtime/__client_async_template.py | 7 +- agentrun/agent_runtime/client.py | 14 +- agentrun/agent_runtime/model.py | 1 + agentrun/integration/utils/tool.py | 4 +- .../super_agent/__agent_async_template.py | 195 +++++ .../super_agent/__client_async_template.py | 514 ++++++++++++ agentrun/super_agent/__init__.py | 52 ++ agentrun/super_agent/agent.py | 203 +++++ agentrun/super_agent/agui.py | 136 ++++ .../super_agent/api/__data_async_template.py | 269 ++++++ agentrun/super_agent/api/__init__.py | 5 + agentrun/super_agent/api/control.py | 550 +++++++++++++ agentrun/super_agent/api/data.py | 275 +++++++ agentrun/super_agent/client.py | 524 ++++++++++++ agentrun/super_agent/model.py | 91 +++ agentrun/super_agent/stream.py | 180 ++++ .../test_langchain_agui_integration.py | 8 +- tests/unittests/super_agent/__init__.py | 0 tests/unittests/super_agent/test_agent.py | 249 ++++++ tests/unittests/super_agent/test_agui.py | 498 ++++++++++++ tests/unittests/super_agent/test_client.py | 768 ++++++++++++++++++ tests/unittests/super_agent/test_control.py | 385 +++++++++ tests/unittests/super_agent/test_data_api.py | 488 +++++++++++ .../unittests/super_agent/test_no_coupling.py | 59 ++ tests/unittests/super_agent/test_stream.py | 203 +++++ tests/unittests/toolset/api/test_openapi.py | 24 +- 27 files changed, 5702 insertions(+), 29 deletions(-) create mode 100644 agentrun/super_agent/__agent_async_template.py create mode 100644 agentrun/super_agent/__client_async_template.py create mode 100644 agentrun/super_agent/__init__.py create mode 100644 agentrun/super_agent/agent.py create mode 100644 agentrun/super_agent/agui.py create mode 100644 agentrun/super_agent/api/__data_async_template.py create mode 100644 agentrun/super_agent/api/__init__.py create mode 100644 agentrun/super_agent/api/control.py create mode 100644 agentrun/super_agent/api/data.py create mode 100644 agentrun/super_agent/client.py create mode 100644 agentrun/super_agent/model.py create mode 100644 agentrun/super_agent/stream.py create mode 100644 tests/unittests/super_agent/__init__.py create mode 100644 tests/unittests/super_agent/test_agent.py create mode 100644 tests/unittests/super_agent/test_agui.py create mode 100644 tests/unittests/super_agent/test_client.py create mode 100644 tests/unittests/super_agent/test_control.py create mode 100644 tests/unittests/super_agent/test_data_api.py create mode 100644 tests/unittests/super_agent/test_no_coupling.py create mode 100644 tests/unittests/super_agent/test_stream.py diff --git a/agentrun/__init__.py b/agentrun/__init__.py index bf3e750..9693b1b 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -113,6 +113,21 @@ SandboxClient, Template, ) +# Super Agent +from agentrun.super_agent import ( + ConversationInfo, + InvokeResponseData, + InvokeStream, +) +from agentrun.super_agent import Message as SuperAgentMessage +from agentrun.super_agent import ( + SSEEvent, + SuperAgent, + SuperAgentClient, + SuperAgentCreateInput, + SuperAgentListInput, + SuperAgentUpdateInput, +) # Tool from agentrun.tool import Tool as ToolResource from agentrun.tool import ToolClient as ToolResourceClient @@ -248,6 +263,20 @@ "ModelProxyCreateInput", "ModelProxyUpdateInput", "ModelProxyListInput", + ######## Super Agent ######## + # base + "SuperAgent", + "SuperAgentClient", + # inner model + "InvokeStream", + "SSEEvent", + "ConversationInfo", + "SuperAgentMessage", + # api model + "SuperAgentCreateInput", + "SuperAgentUpdateInput", + "SuperAgentListInput", + "InvokeResponseData", ######## Sandbox ######## "SandboxClient", "BrowserSandbox", diff --git a/agentrun/agent_runtime/__client_async_template.py b/agentrun/agent_runtime/__client_async_template.py index ac8759a..d72bac6 100644 --- a/agentrun/agent_runtime/__client_async_template.py +++ b/agentrun/agent_runtime/__client_async_template.py @@ -118,7 +118,12 @@ async def delete_async( result = await self.__control_api.delete_agent_runtime_async( id, config=config ) - return AgentRuntime.from_inner_object(result) + # Delete 响应只有 agentRuntimeId 有效,其他字段为空字符串/零值, + # 走 from_inner_object 会在 status/artifactType 等 Enum 字段上触发 + # 伪校验错误。这里直接构造一个只带 id 的极简对象。 + return AgentRuntime.model_construct( + agent_runtime_id=result.agent_runtime_id + ) except HTTPError as e: raise e.to_resource_error("AgentRuntime", id) from e diff --git a/agentrun/agent_runtime/client.py b/agentrun/agent_runtime/client.py index 826e047..6010463 100644 --- a/agentrun/agent_runtime/client.py +++ b/agentrun/agent_runtime/client.py @@ -171,7 +171,12 @@ async def delete_async( result = await self.__control_api.delete_agent_runtime_async( id, config=config ) - return AgentRuntime.from_inner_object(result) + # Delete 响应只有 agentRuntimeId 有效,其他字段为空字符串/零值, + # 走 from_inner_object 会在 status/artifactType 等 Enum 字段上触发 + # 伪校验错误。这里直接构造一个只带 id 的极简对象。 + return AgentRuntime.model_construct( + agent_runtime_id=result.agent_runtime_id + ) except HTTPError as e: raise e.to_resource_error("AgentRuntime", id) from e @@ -191,7 +196,12 @@ def delete(self, id: str, config: Optional[Config] = None) -> AgentRuntime: """ try: result = self.__control_api.delete_agent_runtime(id, config=config) - return AgentRuntime.from_inner_object(result) + # Delete 响应只有 agentRuntimeId 有效,其他字段为空字符串/零值, + # 走 from_inner_object 会在 status/artifactType 等 Enum 字段上触发 + # 伪校验错误。这里直接构造一个只带 id 的极简对象。 + return AgentRuntime.model_construct( + agent_runtime_id=result.agent_runtime_id + ) except HTTPError as e: raise e.to_resource_error("AgentRuntime", id) from e diff --git a/agentrun/agent_runtime/model.py b/agentrun/agent_runtime/model.py index f3d649e..965c1bf 100644 --- a/agentrun/agent_runtime/model.py +++ b/agentrun/agent_runtime/model.py @@ -185,6 +185,7 @@ class AgentRuntimeProtocolType(str, Enum): HTTP = "HTTP" MCP = "MCP" + SUPER_AGENT = "SUPER_AGENT" class AgentRuntimeProtocolConfig(BaseModel): diff --git a/agentrun/integration/utils/tool.py b/agentrun/integration/utils/tool.py index 1fd3666..c479846 100644 --- a/agentrun/integration/utils/tool.py +++ b/agentrun/integration/utils/tool.py @@ -1614,8 +1614,8 @@ def _build_openapi_schema( if isinstance(schema, dict): properties[name] = { **schema, - "description": ( - param.get("description") or schema.get("description", "") + "description": param.get("description") or schema.get( + "description", "" ), } if param.get("required"): diff --git a/agentrun/super_agent/__agent_async_template.py b/agentrun/super_agent/__agent_async_template.py new file mode 100644 index 0000000..a1f6d1a --- /dev/null +++ b/agentrun/super_agent/__agent_async_template.py @@ -0,0 +1,195 @@ +"""SuperAgent 实例 / Super Agent Instance + +``SuperAgent`` 是暴露给应用开发者的强类型实例对象, 承载 ``invoke`` / 会话管理 +两类方法 (仅异步; 见决策 14)。CRUDL 由 ``SuperAgentClient`` 管理。 + +本文件为模板 (``__agent_async_template.py``), codegen 会把 ``async def ...`` +转换成同步骨架; 实际第一版异步主路径 + 同步 NotImplementedError 占位。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.super_agent.model import ConversationInfo, Message +from agentrun.super_agent.stream import InvokeStream +from agentrun.utils.config import Config + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +@dataclass +class SuperAgent: + """超级 Agent 实例. + + 业务字段 (``prompt / agents / tools / ...``) 从 ``protocolSettings.config`` + 反解。系统字段 (``agent_runtime_id / arn / status / ...``) 来自 AgentRuntime。 + """ + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: List[str] = field(default_factory=list) + tools: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + sandboxes: List[str] = field(default_factory=list) + workspaces: List[str] = field(default_factory=list) + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + agent_runtime_id: str = "" + arn: str = "" + status: str = "" + created_at: str = "" + last_updated_at: str = "" + external_endpoint: str = "" + + _client: Any = field(default=None, repr=False, compare=False) + + def _resolve_config(self, config: Optional[Config]) -> Config: + client_cfg = ( + getattr(self._client, "config", None) if self._client else None + ) + return Config.with_configs(client_cfg, config) + + def _forwarded_business_fields(self) -> Dict[str, Any]: + """把 SuperAgent 实例字段打包成 ``forwardedProps`` 顶层业务字段 dict. + + 与 ``protocolSettings[0].config`` 写入时的结构保持对称: list 型用 ``[]`` + 代替 None, scalar 型保留 None (由 JSON 序列化为 ``null``)。服务端读取同 + 一份语义, 避免客户端/服务端对"未设置"产生歧义。 + """ + return { + "prompt": self.prompt, + "agents": list(self.agents), + "tools": list(self.tools), + "skills": list(self.skills), + "sandboxes": list(self.sandboxes), + "workspaces": list(self.workspaces), + "modelServiceName": self.model_service_name, + "modelName": self.model_name, + } + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + """Phase 1: POST /invoke; 返回包含 ``conversation_id`` 的 :class:`InvokeStream`. + + 首次 ``async for ev in stream`` 才触发 Phase 2 拉流 (lazy)。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + resp = await api.invoke_async( + messages, + conversation_id=conversation_id, + config=cfg, + forwarded_extras=self._forwarded_business_fields(), + ) + stream_url = resp.stream_url + stream_headers = dict(resp.stream_headers) + session_id = stream_headers.get("X-Super-Agent-Session-Id", "") + + async def _factory(): + return api.stream_async( + stream_url, stream_headers=stream_headers, config=cfg + ) + + return InvokeStream( + conversation_id=resp.conversation_id, + session_id=session_id, + stream_url=stream_url, + stream_headers=stream_headers, + _stream_factory=_factory, + ) + + def invoke( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + """GET /conversations/{id} → :class:`ConversationInfo` (缺字段用默认值).""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + data = await api.get_conversation_async(conversation_id, config=cfg) + return _conversation_info_from_dict( + data, fallback_conversation_id=conversation_id + ) + + def get_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + await api.delete_conversation_async(conversation_id, config=cfg) + + def delete_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +def _to_message(raw: Dict[str, Any]) -> Message: + return Message( + role=str(raw.get("role") or ""), + content=str(raw.get("content") or ""), + message_id=raw.get("messageId") or raw.get("message_id"), + created_at=raw.get("createdAt") or raw.get("created_at"), + ) + + +def _conversation_info_from_dict( + data: Dict[str, Any], *, fallback_conversation_id: str +) -> ConversationInfo: + data = data or {} + messages_raw = data.get("messages") or [] + messages = [_to_message(m) for m in messages_raw if isinstance(m, dict)] + return ConversationInfo( + conversation_id=str( + data.get("conversationId") or fallback_conversation_id + ), + agent_id=str(data.get("agentId") or data.get("agent_id") or ""), + title=data.get("title"), + main_user_id=data.get("mainUserId") or data.get("main_user_id"), + sub_user_id=data.get("subUserId") or data.get("sub_user_id"), + created_at=int(data.get("createdAt") or data.get("created_at") or 0), + updated_at=int(data.get("updatedAt") or data.get("updated_at") or 0), + error_message=data.get("errorMessage") or data.get("error_message"), + invoke_info=data.get("invokeInfo") or data.get("invoke_info"), + messages=messages, + params=data.get("params"), + ) + + +__all__ = ["SuperAgent"] diff --git a/agentrun/super_agent/__client_async_template.py b/agentrun/super_agent/__client_async_template.py new file mode 100644 index 0000000..0a9f1ba --- /dev/null +++ b/agentrun/super_agent/__client_async_template.py @@ -0,0 +1,514 @@ +"""SuperAgentClient / 超级 Agent 客户端 + +对外入口: CRUDL (create / get / update / delete / list / list_all) 同步 + 异步双写。 +内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 +转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 + +list 固定按 tag ``x-agentrun-super-agent`` 过滤, 不接受用户自定义 tag。 +""" + +import asyncio +import time +from typing import Any, List, Optional + +from alibabacloud_agentrun20250910.models import ( + CreateAgentRuntimeInput, + UpdateAgentRuntimeInput, +) + +from agentrun.agent_runtime.api import AgentRuntimeControlAPI +from agentrun.agent_runtime.client import AgentRuntimeClient +from agentrun.agent_runtime.model import AgentRuntimeListInput +from agentrun.agent_runtime.runtime import AgentRuntime +from agentrun.super_agent.agent import SuperAgent +from agentrun.super_agent.api.control import ( + from_agent_runtime, + is_super_agent, + SUPER_AGENT_TAG, + to_create_input, + to_update_input, +) +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import Status + +# 公开 API 签名故意保持 ``Optional[X] = None`` 对外简洁; +# ``_UNSET`` 仅用于内部区分 "未传" 与 "显式 None (= 清空)". +_UNSET: Any = object() + +# create/update 轮询默认参数 +_WAIT_INTERVAL_SECONDS = 3 +_WAIT_TIMEOUT_SECONDS = 300 + + +def _raise_if_failed(rt: AgentRuntime, action: str) -> None: + """若 rt 处于失败态, 抛出带 status_reason 的 RuntimeError.""" + status = getattr(rt, "status", None) + status_str = str(status) if status is not None else "" + if status_str in { + Status.CREATE_FAILED.value, + Status.UPDATE_FAILED.value, + Status.DELETE_FAILED.value, + }: + reason = getattr(rt, "status_reason", None) or "(no reason)" + name = getattr(rt, "agent_runtime_name", None) or "(unknown)" + raise RuntimeError( + f"Super agent {action} failed: name={name!r} status={status_str} " + f"reason={reason}" + ) + + +def _merge(current: dict, updates: dict) -> dict: + """把 ``updates`` 中非 ``_UNSET`` 的字段合并到 ``current`` (None 表示清空).""" + merged = dict(current) + for key, value in updates.items(): + if value is _UNSET: + continue + merged[key] = value + return merged + + +def _super_agent_to_business_dict(agent: SuperAgent) -> dict: + return { + "description": agent.description, + "prompt": agent.prompt, + "agents": list(agent.agents), + "tools": list(agent.tools), + "skills": list(agent.skills), + "sandboxes": list(agent.sandboxes), + "workspaces": list(agent.workspaces), + "model_service_name": agent.model_service_name, + "model_name": agent.model_name, + } + + +class SuperAgentClient: + """Super Agent CRUDL 客户端.""" + + def __init__(self, config: Optional[Config] = None) -> None: + self.config = config + self._rt = AgentRuntimeClient(config=config) + # create/update 绕过 AgentRuntimeClient 的 artifact_type 校验 (SUPER_AGENT 不需要 code/container), + # 并通过 ``ProtocolConfiguration`` 的 monkey-patch 保留 ``externalEndpoint`` 字段。 + self._rt_control = AgentRuntimeControlAPI(config=config) + + async def _wait_final_async( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """轮询 get 直到 status 进入最终态 (READY / *_FAILED).""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = await self._rt.get_async(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + await asyncio.sleep(interval_seconds) + + def _wait_final( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """同步版 _wait_final_async.""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = self._rt.get(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + time.sleep(interval_seconds) + + # ─── Create ────────────────────────────────────── + async def create_async( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """异步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = await self._rt_control.create_agent_runtime_async( + dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + # 轮询直到进入最终态; 失败则抛出带 status_reason 的错误。 + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id: + rt = await self._wait_final_async(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def create( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """同步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = self._rt_control.create_agent_runtime(dara_input, config=cfg) + rt = AgentRuntime.from_inner_object(result) + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id: + rt = self._wait_final(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Get ────────────────────────────────────────── + # Aliyun 控制面 get/delete/update 接口只认 ``agent_runtime_id`` (URN), + # 不认 resource_name; ``_find_rt_by_name*`` 用 list + 名称匹配来解析 id. + def _find_rt_by_name(self, name: str, config: Optional[Config]) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = self._rt.list( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def _find_rt_by_name_async( + self, name: str, config: Optional[Config] + ) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = await self._rt.list_async( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def get_async( + self, name: str, *, config: Optional[Config] = None + ) -> SuperAgent: + """异步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def get(self, name: str, *, config: Optional[Config] = None) -> SuperAgent: + """同步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Update (read-merge-write) ───────────────────── + async def update_async( + self, + name: str, + *, + description: Any = _UNSET, + prompt: Any = _UNSET, + agents: Any = _UNSET, + tools: Any = _UNSET, + skills: Any = _UNSET, + sandboxes: Any = _UNSET, + workspaces: Any = _UNSET, + model_service_name: Any = _UNSET, + model_name: Any = _UNSET, + config: Optional[Config] = None, + ) -> SuperAgent: + """异步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = await self._rt_control.update_agent_runtime_async( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = await self._wait_final_async(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def update( + self, + name: str, + *, + description: Any = _UNSET, + prompt: Any = _UNSET, + agents: Any = _UNSET, + tools: Any = _UNSET, + skills: Any = _UNSET, + sandboxes: Any = _UNSET, + workspaces: Any = _UNSET, + model_service_name: Any = _UNSET, + model_name: Any = _UNSET, + config: Optional[Config] = None, + ) -> SuperAgent: + """同步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = self._rt_control.update_agent_runtime( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = self._wait_final(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Delete ─────────────────────────────────────── + async def delete_async( + self, name: str, *, config: Optional[Config] = None + ) -> None: + """异步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + await self._rt.delete_async(agent_id, config=cfg) + + def delete(self, name: str, *, config: Optional[Config] = None) -> None: + """同步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + self._rt.delete(agent_id, config=cfg) + + # ─── List ───────────────────────────────────────── + async def list_async( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """异步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ) + runtimes = await self._rt.list_async(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + def list( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """同步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ) + runtimes = self._rt.list(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + async def list_all_async( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """异步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = await self.list_async( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + def list_all( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """同步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = self.list( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + +__all__ = ["SuperAgentClient"] diff --git a/agentrun/super_agent/__init__.py b/agentrun/super_agent/__init__.py new file mode 100644 index 0000000..5afb68e --- /dev/null +++ b/agentrun/super_agent/__init__.py @@ -0,0 +1,52 @@ +"""Super Agent 模块 / Super Agent Module + +独立的超级 Agent SDK, 面向写应用的开发者。用户无需感知底层 AgentRuntime 概念, +只需: + +.. code-block:: python + + from agentrun.super_agent import SuperAgentClient + + client = SuperAgentClient() + agent = await client.create_async(name="my-agent", prompt="你好") + stream = await agent.invoke_async(messages=[{"role": "user", "content": "hi"}]) + print(stream.conversation_id) + async for ev in stream: + print(ev.event, ev.data) + +可选的 AG-UI 强类型适配, 显式导入: + +.. code-block:: python + + from agentrun.super_agent.agui import as_agui_events + + async for event in as_agui_events(stream): + ... # event 是 ``ag_ui.core.BaseEvent`` 子类实例 + +详见 ``openspec/changes/add-super-agent-sdk/`` 中的 proposal、design、spec。 +""" + +from .agent import SuperAgent +from .client import SuperAgentClient +from .model import ( + ConversationInfo, + InvokeResponseData, + Message, + SuperAgentCreateInput, + SuperAgentListInput, + SuperAgentUpdateInput, +) +from .stream import InvokeStream, SSEEvent + +__all__ = [ + "SuperAgentClient", + "SuperAgent", + "InvokeStream", + "SSEEvent", + "SuperAgentCreateInput", + "SuperAgentUpdateInput", + "SuperAgentListInput", + "ConversationInfo", + "Message", + "InvokeResponseData", +] diff --git a/agentrun/super_agent/agent.py b/agentrun/super_agent/agent.py new file mode 100644 index 0000000..fac7ff6 --- /dev/null +++ b/agentrun/super_agent/agent.py @@ -0,0 +1,203 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/super_agent/__agent_async_template.py + +SuperAgent 实例 / Super Agent Instance + +``SuperAgent`` 是暴露给应用开发者的强类型实例对象, 承载 ``invoke`` / 会话管理 +两类方法 (仅异步; 见决策 14)。CRUDL 由 ``SuperAgentClient`` 管理。 +同步方法保留 ``NotImplementedError`` 占位, 以备未来扩展。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.super_agent.model import ConversationInfo, Message +from agentrun.super_agent.stream import InvokeStream +from agentrun.utils.config import Config + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +@dataclass +class SuperAgent: + """超级 Agent 实例. + + 业务字段 (``prompt / agents / tools / ...``) 从 ``protocolSettings.config`` + 反解。系统字段 (``agent_runtime_id / arn / status / ...``) 来自 AgentRuntime。 + """ + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: List[str] = field(default_factory=list) + tools: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + sandboxes: List[str] = field(default_factory=list) + workspaces: List[str] = field(default_factory=list) + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + agent_runtime_id: str = "" + arn: str = "" + status: str = "" + created_at: str = "" + last_updated_at: str = "" + external_endpoint: str = "" + + _client: Any = field(default=None, repr=False, compare=False) + + def _resolve_config(self, config: Optional[Config]) -> Config: + client_cfg = ( + getattr(self._client, "config", None) if self._client else None + ) + return Config.with_configs(client_cfg, config) + + def _forwarded_business_fields(self) -> Dict[str, Any]: + """把 SuperAgent 实例字段打包成 ``forwardedProps`` 顶层业务字段 dict. + + 与 ``protocolSettings[0].config`` 写入时的结构保持对称: list 型用 ``[]`` + 代替 None, scalar 型保留 None (由 JSON 序列化为 ``null``)。服务端读取同 + 一份语义, 避免客户端/服务端对"未设置"产生歧义。 + """ + return { + "prompt": self.prompt, + "agents": list(self.agents), + "tools": list(self.tools), + "skills": list(self.skills), + "sandboxes": list(self.sandboxes), + "workspaces": list(self.workspaces), + "modelServiceName": self.model_service_name, + "modelName": self.model_name, + } + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + """Phase 1: POST /invoke; 返回包含 ``conversation_id`` 的 :class:`InvokeStream`. + + 首次 ``async for ev in stream`` 才触发 Phase 2 拉流 (lazy)。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + resp = await api.invoke_async( + messages, + conversation_id=conversation_id, + config=cfg, + forwarded_extras=self._forwarded_business_fields(), + ) + stream_url = resp.stream_url + stream_headers = dict(resp.stream_headers) + session_id = stream_headers.get("X-Super-Agent-Session-Id", "") + + async def _factory(): + return api.stream_async( + stream_url, stream_headers=stream_headers, config=cfg + ) + + return InvokeStream( + conversation_id=resp.conversation_id, + session_id=session_id, + stream_url=stream_url, + stream_headers=stream_headers, + _stream_factory=_factory, + ) + + def invoke( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + """GET /conversations/{id} → :class:`ConversationInfo` (缺字段用默认值).""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + data = await api.get_conversation_async(conversation_id, config=cfg) + return _conversation_info_from_dict( + data, fallback_conversation_id=conversation_id + ) + + def get_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + await api.delete_conversation_async(conversation_id, config=cfg) + + def delete_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +def _to_message(raw: Dict[str, Any]) -> Message: + return Message( + role=str(raw.get("role") or ""), + content=str(raw.get("content") or ""), + message_id=raw.get("messageId") or raw.get("message_id"), + created_at=raw.get("createdAt") or raw.get("created_at"), + ) + + +def _conversation_info_from_dict( + data: Dict[str, Any], *, fallback_conversation_id: str +) -> ConversationInfo: + data = data or {} + messages_raw = data.get("messages") or [] + messages = [_to_message(m) for m in messages_raw if isinstance(m, dict)] + return ConversationInfo( + conversation_id=str( + data.get("conversationId") or fallback_conversation_id + ), + agent_id=str(data.get("agentId") or data.get("agent_id") or ""), + title=data.get("title"), + main_user_id=data.get("mainUserId") or data.get("main_user_id"), + sub_user_id=data.get("subUserId") or data.get("sub_user_id"), + created_at=int(data.get("createdAt") or data.get("created_at") or 0), + updated_at=int(data.get("updatedAt") or data.get("updated_at") or 0), + error_message=data.get("errorMessage") or data.get("error_message"), + invoke_info=data.get("invokeInfo") or data.get("invoke_info"), + messages=messages, + params=data.get("params"), + ) + + +__all__ = ["SuperAgent"] diff --git a/agentrun/super_agent/agui.py b/agentrun/super_agent/agui.py new file mode 100644 index 0000000..ac615a8 --- /dev/null +++ b/agentrun/super_agent/agui.py @@ -0,0 +1,136 @@ +"""Super Agent AG-UI 适配器 / Super Agent AG-UI Adapter + +把 :class:`InvokeStream` 的原始 :class:`SSEEvent` 解码为 ``ag_ui.core.BaseEvent`` +强类型事件 (**client 侧解码**)。 + +与 ``agentrun/server/agui_protocol.py`` (server 侧编码方向) 是反向关系, 互不依赖: +本文件只消费 ``ag_ui.core`` 的事件类, server 侧则负责产生它们。 + +使用: + +.. code-block:: python + + from agentrun.super_agent.agui import as_agui_events + + async for event in as_agui_events(stream): + # event 是 ag_ui.core.BaseEvent 子类 (如 TextMessageContentEvent) + ... + + # 想跳过未知事件 (如过渡期兼容), 传 on_unknown="skip" + async for event in as_agui_events(stream, on_unknown="skip"): + ... +""" + +from __future__ import annotations + +from typing import AsyncIterator, Dict, Literal, Optional, Type + +from ag_ui.core import ( + BaseEvent, + CustomEvent, + MessagesSnapshotEvent, + RawEvent, + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + StateDeltaEvent, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) + +from .stream import InvokeStream, SSEEvent + +_EVENT_TYPE_TO_CLASS: Dict[str, Type[BaseEvent]] = { + "RUN_STARTED": RunStartedEvent, + "RUN_FINISHED": RunFinishedEvent, + "RUN_ERROR": RunErrorEvent, + "TEXT_MESSAGE_START": TextMessageStartEvent, + "TEXT_MESSAGE_CONTENT": TextMessageContentEvent, + "TEXT_MESSAGE_END": TextMessageEndEvent, + "TOOL_CALL_START": ToolCallStartEvent, + "TOOL_CALL_ARGS": ToolCallArgsEvent, + "TOOL_CALL_END": ToolCallEndEvent, + "TOOL_CALL_RESULT": ToolCallResultEvent, + "STATE_SNAPSHOT": StateSnapshotEvent, + "STATE_DELTA": StateDeltaEvent, + "MESSAGES_SNAPSHOT": MessagesSnapshotEvent, + "RAW": RawEvent, + "CUSTOM": CustomEvent, +} + + +UnknownMode = Literal["raise", "skip"] + + +def _data_preview(data: str, limit: int = 200) -> str: + if len(data) <= limit: + return data + return data[:limit] + "..." + + +def decode_sse_to_agui( + sse_event: SSEEvent, + *, + on_unknown: UnknownMode = "raise", +) -> Optional[BaseEvent]: + """把单个 :class:`SSEEvent` 解码为 AG-UI 事件. + + - 空 ``data`` (keepalive) → 返回 ``None`` + - 未知 ``event`` + ``on_unknown='raise'`` → 抛 ``ValueError`` + - 未知 ``event`` + ``on_unknown='skip'`` → 返回 ``None`` + - ``data`` 不合法 JSON 或 Pydantic 校验失败 → 抛 ``ValueError`` + (不论 ``on_unknown``, 因为这通常是真实错误) + """ + if sse_event.data == "": + return None + + event_name = sse_event.event + if not event_name or event_name not in _EVENT_TYPE_TO_CLASS: + if on_unknown == "skip": + return None + raise ValueError( + f"Unknown AG-UI event type {event_name!r}; data prefix:" + f" {_data_preview(sse_event.data)}" + ) + + cls = _EVENT_TYPE_TO_CLASS[event_name] + try: + return cls.model_validate_json(sse_event.data) + except Exception as exc: # JSONDecodeError, ValidationError, etc. + raise ValueError( + f"Failed to decode AG-UI {event_name!r} event: {exc};" + f" data prefix: {_data_preview(sse_event.data)}" + ) from exc + + +async def as_agui_events( + stream: InvokeStream, + *, + on_unknown: UnknownMode = "raise", +) -> AsyncIterator[BaseEvent]: + """把 :class:`InvokeStream` 中的原始 :class:`SSEEvent` 解码为强类型流. + + 无论正常消费结束、中途异常、解码异常, 都保证 ``await stream.aclose()`` 被调用 + 以释放 httpx 连接。 + """ + try: + async for sse_event in stream: + agui_event = decode_sse_to_agui(sse_event, on_unknown=on_unknown) + if agui_event is None: + continue + yield agui_event + finally: + await stream.aclose() + + +__all__ = [ + "as_agui_events", + "decode_sse_to_agui", + "_EVENT_TYPE_TO_CLASS", +] diff --git a/agentrun/super_agent/api/__data_async_template.py b/agentrun/super_agent/api/__data_async_template.py new file mode 100644 index 0000000..3549842 --- /dev/null +++ b/agentrun/super_agent/api/__data_async_template.py @@ -0,0 +1,269 @@ +"""Super Agent 数据面 API / Super Agent Data Plane API + +继承 :class:`agentrun.utils.data_api.DataAPI`, 复用: + +- ``_get_ram_data_endpoint``: RAM ``-ram`` 前缀改写 +- ``get_base_url`` / ``with_path``: URL 拼接 (含版本号) +- ``auth``: RAM 签名头生成 (``Agentrun-Authorization`` / ``x-acs-*``) + +本文件为 **模板** (``__data_async_template.py``); +当前 ``codegen`` 的 async→sync 转换不支持 async generator (``async for + yield``) +以及 ``async with ... as x`` 里的作用域保留, 所以 ``api/data.py`` 的同步骨架 +(同步方法第一版仅保留 ``NotImplementedError`` 占位) 目前 **手工维护**。 +运行 ``make codegen`` 会生成不可运行的版本, 请不要直接覆盖。 +""" + +import json +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + +import httpx + +from agentrun.super_agent.model import InvokeResponseData +from agentrun.super_agent.stream import parse_sse_async, SSEEvent +from agentrun.utils.config import Config +from agentrun.utils.data_api import DataAPI, ResourceType +from agentrun.utils.log import logger + +API_VERSION = "2025-09-10" +SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" +SUPER_AGENT_NAMESPACE = ( + f"{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" +) + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +class SuperAgentDataAPI(DataAPI): + """Super Agent 数据面 API (异步主路径 + 同步占位).""" + + def __init__( + self, + agent_runtime_name: str, + config: Optional[Config] = None, + ): + super().__init__( + resource_name=agent_runtime_name, + resource_type=ResourceType.Runtime, + namespace=SUPER_AGENT_NAMESPACE, + config=config, + ) + self.agent_runtime_name = agent_runtime_name + + def _build_invoke_body( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str], + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + # ``forwarded_extras`` 承载从 AgentRuntime 元数据读出的业务字段 + # (prompt/agents/tools/skills/sandboxes/workspaces/modelServiceName/modelName), + # 由上层 ``SuperAgent.invoke_async`` 注入。``metadata`` 和 ``conversationId`` + # 由 SDK 管理, 不允许 extras 覆盖。 + forwarded: Dict[str, Any] = dict(forwarded_extras or {}) + forwarded["metadata"] = {"agentRuntimeName": self.agent_runtime_name} + if conversation_id is not None: + forwarded["conversationId"] = conversation_id + return {"messages": list(messages), "forwardedProps": forwarded} + + def _parse_invoke_response( + self, payload: Dict[str, Any] + ) -> InvokeResponseData: + if not isinstance(payload, dict): + raise ValueError( + "Invalid invoke response: expected object, got" + f" {type(payload).__name__}" + ) + data = payload.get("data") + if not isinstance(data, dict): + raise ValueError("Invalid invoke response: missing data field") + conversation_id = data.get("conversationId") + if not conversation_id: + raise ValueError( + "Invalid invoke response: missing conversationId field" + ) + stream_url = data.get("url") + if not stream_url: + raise ValueError("Invalid invoke response: missing url field") + raw_headers = data.get("headers") or {} + stream_headers = {str(k): str(v) for k, v in dict(raw_headers).items()} + return InvokeResponseData( + conversation_id=str(conversation_id), + stream_url=str(stream_url), + stream_headers=stream_headers, + ) + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + """Phase 1: POST /invoke, 返回 Phase 2 的 URL 与 headers.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path("invoke", config=cfg) + body = self._build_invoke_body( + messages, conversation_id, forwarded_extras + ) + body_bytes = json.dumps(body).encode("utf-8") + _, signed_headers, _ = self.auth( + url=url, + method="POST", + headers=cfg.get_headers(), + body=body_bytes, + config=cfg, + ) + signed_headers.setdefault("Content-Type", "application/json") + logger.debug("super_agent invoke request: POST %s body=%s", url, body) + + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.post( + url, headers=signed_headers, content=body_bytes + ) + resp.raise_for_status() + payload = resp.json() + logger.debug( + "super_agent invoke response: status=%d payload=%s", + resp.status_code, + payload, + ) + return self._parse_invoke_response(payload) + + def invoke( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def stream_async( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> AsyncIterator[SSEEvent]: + """Phase 2: GET stream_url → 流式 yield SSEEvent. + + 合并优先级 (低 → 高): ``cfg.get_headers()`` → ``stream_headers`` → RAM 签名头。 + """ + cfg = Config.with_configs(self.config, config) + _, signed_headers, _ = self.auth( + url=stream_url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + merged_headers: Dict[str, str] = {} + merged_headers.update(cfg.get_headers()) + if stream_headers: + merged_headers.update(stream_headers) + merged_headers.update(signed_headers) + + logger.debug("super_agent stream request: GET %s", stream_url) + timeout = httpx.Timeout(cfg.get_timeout(), read=None) + client = httpx.AsyncClient(timeout=timeout) + try: + ctx = client.stream("GET", stream_url, headers=merged_headers) + response = await ctx.__aenter__() + try: + response.raise_for_status() + logger.debug( + "super_agent stream response: status=%d headers=%s", + response.status_code, + dict(response.headers), + ) + async for event in parse_sse_async(response): + logger.debug("super_agent stream event: %s", event) + yield event + finally: + await ctx.__aexit__(None, None, None) + finally: + await client.aclose() + + def stream( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> Iterator[SSEEvent]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """GET /conversations/{id} → 返回服务端 ``data`` 字段 dict.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent get_conversation request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent get_conversation response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return {} + data = payload.get("data") + return data if isinstance(data, dict) else {} + + def get_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="DELETE", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent delete_conversation request: DELETE %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.delete(url, headers=signed_headers) + resp.raise_for_status() + logger.debug( + "super_agent delete_conversation response: status=%d body=%s", + resp.status_code, + resp.text, + ) + + def delete_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +__all__ = [ + "API_VERSION", + "SUPER_AGENT_RESOURCE_PATH", + "SUPER_AGENT_NAMESPACE", + "SuperAgentDataAPI", +] diff --git a/agentrun/super_agent/api/__init__.py b/agentrun/super_agent/api/__init__.py new file mode 100644 index 0000000..28825ad --- /dev/null +++ b/agentrun/super_agent/api/__init__.py @@ -0,0 +1,5 @@ +"""Super Agent 内部 API 模块 / Super Agent internal API module""" + +from .data import SuperAgentDataAPI + +__all__ = ["SuperAgentDataAPI"] diff --git a/agentrun/super_agent/api/control.py b/agentrun/super_agent/api/control.py new file mode 100644 index 0000000..4c7eafa --- /dev/null +++ b/agentrun/super_agent/api/control.py @@ -0,0 +1,550 @@ +"""Super Agent 控制面辅助函数 / Super Agent Control Plane Helpers + +本模块包含: +- 常量: API 版本号 / 协议类型 / 标签 / 资源路径 / RAM 数据域名列表 +- URL 工具: ``_add_ram_prefix_to_host`` / ``build_super_agent_endpoint`` +- AgentRuntime ↔ SuperAgent 的双向转换: + ``to_create_input`` / ``to_update_input`` / ``from_agent_runtime`` + / ``is_super_agent`` / ``parse_super_agent_config`` +- 为承载 ``externalEndpoint`` 的 Pydantic 与 Dara 层扩展类 + +不使用模板生成,保持单一来源,避免同步/异步重复维护。 +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional +from urllib.parse import urlparse, urlunparse + +from alibabacloud_agentrun20250910.client import Client as _DaraClient +from alibabacloud_agentrun20250910.models import ( + CreateAgentRuntimeInput as _DaraCreateAgentRuntimeInput, +) +from alibabacloud_agentrun20250910.models import ( + ListAgentRuntimesRequest as _DaraListAgentRuntimesRequest, +) +from alibabacloud_agentrun20250910.models import ( + ProtocolConfiguration as _DaraProtocolConfiguration, +) +from alibabacloud_agentrun20250910.models import ( + UpdateAgentRuntimeInput as _DaraUpdateAgentRuntimeInput, +) +from pydantic import Field + +from agentrun.agent_runtime.model import ( + AgentRuntimeArtifact, + AgentRuntimeContainer, + AgentRuntimeCreateInput, + AgentRuntimeProtocolConfig, + AgentRuntimeUpdateInput, +) +from agentrun.agent_runtime.runtime import AgentRuntime +from agentrun.utils.config import Config +from agentrun.utils.model import NetworkConfig, NetworkMode + +# ─── 常量 ───────────────────────────────────────────── +API_VERSION = "2025-09-10" +SUPER_AGENT_PROTOCOL_TYPE = "SUPER_AGENT" +# ``SUPER_AGENT_TAG`` 标识下游 AgentRuntime 是超级 Agent, 用于 list 过滤。 +SUPER_AGENT_TAG = "x-agentrun-super-agent" +# ``EXTERNAL_TAG`` 标识下游 AgentRuntime 由外部 (SuperAgent) 托管调用, 不由 AgentRun 直接托管。 +EXTERNAL_TAG = "x-agentrun-external" +# 创建下游 AgentRuntime 时固定写入的 tag 列表: ``[EXTERNAL_TAG, SUPER_AGENT_TAG]``。 +SUPER_AGENT_CREATE_TAGS = [EXTERNAL_TAG, SUPER_AGENT_TAG] +SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" +SUPER_AGENT_INVOKE_PATH = "/invoke" +SUPER_AGENT_NAMESPACE = ( + f"{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" +) + +_RAM_DATA_DOMAINS = ("agentrun-data", "funagent-data-pre") + + +# ─── URL 工具 ────────────────────────────────────────── + + +def _add_ram_prefix_to_host(base_url: str) -> str: + """给已知 data host 加 ``-ram`` 前缀. + + 仅当 host 第二段命中 :data:`_RAM_DATA_DOMAINS` 时改写为 + ``-ram.<其余>``, 其他情况原样返回。 + 与 :meth:`agentrun.utils.data_api.DataAPI._get_ram_data_endpoint` 同源。 + """ + parsed = urlparse(base_url) + if not parsed.netloc: + return base_url + if not any(f".{d}." in parsed.netloc for d in _RAM_DATA_DOMAINS): + return base_url + parts = parsed.netloc.split(".", 1) + if len(parts) != 2: + return base_url + ram_netloc = parts[0] + "-ram." + parts[1] + return urlunparse(( + parsed.scheme, + ram_netloc, + parsed.path or "", + parsed.params, + parsed.query, + parsed.fragment, + )) + + +def build_super_agent_endpoint(cfg: Optional[Config] = None) -> str: + """构造 ``protocolConfiguration.externalEndpoint`` 的存储值 (不带版本号). + + 基于 :meth:`Config.get_data_endpoint` + :func:`_add_ram_prefix_to_host` + + 追加 ``/super-agents/__SUPER_AGENT__``, 自动适配生产 / 预发 / 自定义网关。 + """ + cfg = Config.with_configs(cfg) + base = cfg.get_data_endpoint() + ram_base = _add_ram_prefix_to_host(base) + return f"{ram_base.rstrip('/')}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" + + +# ─── Pydantic 扩展类 ──────────────────────────────────── +class SuperAgentProtocolConfig(AgentRuntimeProtocolConfig): + """承载 ``protocol_settings`` + ``external_endpoint`` 的 Pydantic 扩展. + + 基类 ``AgentRuntimeProtocolConfig`` 的 ``type`` 字段是 ``HTTP / MCP`` 枚举, + 本子类通过 ``model_construct`` 绕过校验存入字符串 ``"SUPER_AGENT"``。 + """ + + protocol_settings: Optional[List[Dict[str, Any]]] = Field( + alias="protocolSettings", default=None + ) + external_endpoint: Optional[str] = Field( + alias="externalEndpoint", default=None + ) + + +class _SuperAgentCreateInput(AgentRuntimeCreateInput): + """默认使用 ``serialize_as_any=True`` 的 create input, 保留子类 extras. + + ``external_agent_endpoint_url`` 是基类 ``AgentRuntimeMutableProps`` 没有覆盖 + 的顶层字段, 但在 ``x-agentrun-external`` tag 下服务端强制要求填入, 这里显式 + 补齐 (alias 由 BaseModel 的 ``to_camel_case`` 生成 → ``externalAgentEndpointUrl``)。 + """ + + external_agent_endpoint_url: Optional[str] = None + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + +class _SuperAgentUpdateInput(AgentRuntimeUpdateInput): + """默认使用 ``serialize_as_any=True`` 的 update input, 保留子类 extras.""" + + external_agent_endpoint_url: Optional[str] = None + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + +# ─── Dara 模型猴补丁 ────────────────────────────────────── +# Dara 的 ``ProtocolConfiguration`` 当前版本没有 ``externalEndpoint`` 字段; +# ``AgentRuntimeClient.create_async/update_async`` 内部做 +# ``CreateAgentRuntimeInput().from_map(pydantic.model_dump())`` 的 roundtrip, +# 会在 Dara 层丢失此字段。这里做一次加性 patch: 仅追加读写 ``externalEndpoint``, +# 不改变任何现有字段行为, 用模块级哨兵属性保证幂等。 + +if not getattr(_DaraProtocolConfiguration, "_super_agent_patched", False): + _orig_to_map = _DaraProtocolConfiguration.to_map + _orig_from_map = _DaraProtocolConfiguration.from_map + + def _patched_to_map(self: _DaraProtocolConfiguration) -> Dict[str, Any]: + result = _orig_to_map(self) + ee = getattr(self, "external_endpoint", None) + if ee is not None: + result["externalEndpoint"] = ee + return result + + def _patched_from_map( + self: _DaraProtocolConfiguration, m: Optional[Dict[str, Any]] = None + ) -> _DaraProtocolConfiguration: + _orig_from_map(self, m) + if m and m.get("externalEndpoint") is not None: + self.external_endpoint = m.get("externalEndpoint") + return self + + _DaraProtocolConfiguration.to_map = _patched_to_map # type: ignore[assignment] + _DaraProtocolConfiguration.from_map = _patched_from_map # type: ignore[assignment] + _DaraProtocolConfiguration._super_agent_patched = True # type: ignore[attr-defined] + + +# Dara 的 ``CreateAgentRuntimeInput`` / ``UpdateAgentRuntimeInput`` 当前版本没有 +# ``tags`` 字段, 与 ``ProtocolConfiguration`` 同理会在 Pydantic → Dara 的 roundtrip +# 中被静默丢弃. 这里沿用同款加性 patch, 只补齐 ``tags`` 字段的读写. +def _patch_dara_tags(cls: Any) -> None: + if getattr(cls, "_super_agent_tags_patched", False): + return + _orig_to_map = cls.to_map + _orig_from_map = cls.from_map + + def _patched_to_map(self: Any) -> Dict[str, Any]: + result = _orig_to_map(self) + tags = getattr(self, "tags", None) + if tags is not None: + result["tags"] = tags + return result + + def _patched_from_map(self: Any, m: Optional[Dict[str, Any]] = None) -> Any: + _orig_from_map(self, m) + if m and m.get("tags") is not None: + self.tags = m.get("tags") + return self + + cls.to_map = _patched_to_map # type: ignore[assignment] + cls.from_map = _patched_from_map # type: ignore[assignment] + cls._super_agent_tags_patched = True # type: ignore[attr-defined] + + +_patch_dara_tags(_DaraCreateAgentRuntimeInput) +_patch_dara_tags(_DaraUpdateAgentRuntimeInput) +# ``ListAgentRuntimesRequest`` 同样没有 ``tags`` 字段: 补上 from_map/to_map 以保留 +# 属性; 真正让服务端生效的查询参数注入由下面的 client 级补丁完成。 +_patch_dara_tags(_DaraListAgentRuntimesRequest) + + +# ─── Dara 客户端猴补丁: list 请求 query 注入 tags ─────────────── +# 现版 Dara ``Client.list_agent_runtimes_with_options`` 不读 ``request.tags`` +# 构造 query, 导致即便 Pydantic 侧把 tags 传下来, 服务端也收不到。这里一次性 +# 包裹同步 / 异步两个方法: 若 request 带有 ``tags`` 就在底层 ``call_api`` 调用 +# 前把 ``tags`` (列表 → 逗号分隔) 追加到 ``req.query``。 +# 每个 API 调用都会 ``_get_client()`` 新建 ``Client`` 实例, 实例属性级别的替换 +# 在并发下是安全的。 + + +def _tags_query_value(tags: Any) -> Optional[str]: + if tags is None: + return None + if isinstance(tags, str): + return tags + if isinstance(tags, (list, tuple)): + return ",".join(str(t) for t in tags) + return str(tags) + + +def _patch_dara_client_list_tags() -> None: + if getattr(_DaraClient, "_super_agent_list_tags_patched", False): + return + + _orig_sync = _DaraClient.list_agent_runtimes_with_options + _orig_async = _DaraClient.list_agent_runtimes_with_options_async + + def _patched_sync( + self: Any, request: Any, headers: Any, runtime: Any + ) -> Any: + tags_value = _tags_query_value(getattr(request, "tags", None)) + if tags_value is None: + return _orig_sync(self, request, headers, runtime) + orig_call_api = self.call_api + + def _injecting(params: Any, req: Any, rt: Any) -> Any: + if req.query is None: + req.query = {} + req.query["tags"] = tags_value + return orig_call_api(params, req, rt) + + self.call_api = _injecting + try: + return _orig_sync(self, request, headers, runtime) + finally: + try: + del self.call_api + except AttributeError: + pass + + async def _patched_async( + self: Any, request: Any, headers: Any, runtime: Any + ) -> Any: + tags_value = _tags_query_value(getattr(request, "tags", None)) + if tags_value is None: + return await _orig_async(self, request, headers, runtime) + orig_call_api_async = self.call_api_async + + async def _injecting(params: Any, req: Any, rt: Any) -> Any: + if req.query is None: + req.query = {} + req.query["tags"] = tags_value + return await orig_call_api_async(params, req, rt) + + self.call_api_async = _injecting + try: + return await _orig_async(self, request, headers, runtime) + finally: + try: + del self.call_api_async + except AttributeError: + pass + + _DaraClient.list_agent_runtimes_with_options = _patched_sync # type: ignore[assignment] + _DaraClient.list_agent_runtimes_with_options_async = _patched_async # type: ignore[assignment] + _DaraClient._super_agent_list_tags_patched = True # type: ignore[attr-defined] + + +_patch_dara_client_list_tags() + + +# ─── AgentRuntime ↔ SuperAgent 转换 ──────────────────────── +def _business_fields_from_args( + *, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, +) -> Dict[str, Any]: + """把业务字段 (None 保留为 None) 收拢成 dict, 供 ``protocolSettings.config`` 使用.""" + return { + "prompt": prompt, + "agents": agents if agents is not None else [], + "tools": tools if tools is not None else [], + "skills": skills if skills is not None else [], + "sandboxes": sandboxes if sandboxes is not None else [], + "workspaces": workspaces if workspaces is not None else [], + "modelServiceName": model_service_name, + "modelName": model_name, + } + + +def _build_protocol_settings_config( + *, name: str, business: Dict[str, Any] +) -> str: + """构造 ``protocolSettings[0].config`` 的 JSON 字符串.""" + cfg_dict: Dict[str, Any] = { + "path": SUPER_AGENT_INVOKE_PATH, + "prompt": business.get("prompt"), + "agents": business.get("agents") or [], + "tools": business.get("tools") or [], + "skills": business.get("skills") or [], + "sandboxes": business.get("sandboxes") or [], + "workspaces": business.get("workspaces") or [], + "modelServiceName": business.get("modelServiceName"), + "modelName": business.get("modelName"), + "metadata": {"agentRuntimeName": name}, + } + return json.dumps(cfg_dict, ensure_ascii=False) + + +def _build_protocol_configuration( + *, + name: str, + business: Dict[str, Any], + cfg: Optional[Config], +) -> SuperAgentProtocolConfig: + """构造超级 Agent 的 ``protocolConfiguration`` Pydantic 模型.""" + config_json = _build_protocol_settings_config(name=name, business=business) + settings: List[Dict[str, Any]] = [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "name": name, + "path": SUPER_AGENT_INVOKE_PATH, + "config": config_json, + }] + pc = SuperAgentProtocolConfig.model_construct( + type=SUPER_AGENT_PROTOCOL_TYPE, + protocol_settings=settings, + external_endpoint=build_super_agent_endpoint(cfg), + ) + return pc + + +def to_create_input( + name: str, + *, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + cfg: Optional[Config] = None, +) -> AgentRuntimeCreateInput: + """把超级 Agent 业务字段转为 :class:`AgentRuntimeCreateInput`.""" + business = _business_fields_from_args( + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + ) + pc = _build_protocol_configuration(name=name, business=business, cfg=cfg) + # SUPER_AGENT 是平台托管运行时, 不跑用户代码/容器, 但服务端仍要求 + # artifact_type / network_configuration 非空. 这里给占位默认值即可. + return _SuperAgentCreateInput.model_construct( + agent_runtime_name=name, + description=description, + protocol_configuration=pc, + tags=list(SUPER_AGENT_CREATE_TAGS), + # 带 ``x-agentrun-external`` tag 时服务端强制要求 externalAgentEndpointUrl 非空, + # 对超级 Agent 而言即数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 + external_agent_endpoint_url=build_super_agent_endpoint(cfg), + # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 + artifact_type=AgentRuntimeArtifact.CONTAINER, + container_configuration=AgentRuntimeContainer( + image="registry.cn-hangzhou.aliyuncs.com/agentrun/super-agent-placeholder:v1" + ), + network_configuration=NetworkConfig(network_mode=NetworkMode.PUBLIC), + ) + + +def to_update_input( + name: str, + merged: Dict[str, Any], + cfg: Optional[Config] = None, +) -> AgentRuntimeUpdateInput: + """把合并后的业务字段转为 :class:`AgentRuntimeUpdateInput` (全量替换).""" + business = _business_fields_from_args( + prompt=merged.get("prompt"), + agents=merged.get("agents"), + tools=merged.get("tools"), + skills=merged.get("skills"), + sandboxes=merged.get("sandboxes"), + workspaces=merged.get("workspaces"), + model_service_name=merged.get("model_service_name"), + model_name=merged.get("model_name"), + ) + pc = _build_protocol_configuration(name=name, business=business, cfg=cfg) + return _SuperAgentUpdateInput.model_construct( + agent_runtime_name=name, + description=merged.get("description"), + protocol_configuration=pc, + tags=list(SUPER_AGENT_CREATE_TAGS), + # 带 ``x-agentrun-external`` tag 时服务端强制要求 externalAgentEndpointUrl 非空。 + external_agent_endpoint_url=build_super_agent_endpoint(cfg), + # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 + artifact_type=AgentRuntimeArtifact.CONTAINER, + container_configuration=AgentRuntimeContainer( + image="registry.cn-hangzhou.aliyuncs.com/agentrun/super-agent-placeholder:v1" + ), + network_configuration=NetworkConfig(network_mode=NetworkMode.PUBLIC), + ) + + +def _extract_protocol_configuration(rt: AgentRuntime) -> Dict[str, Any]: + """统一把 rt.protocol_configuration 转为 dict (兼容 dict / pydantic 两种形态).""" + pc = getattr(rt, "protocol_configuration", None) + if pc is None: + return {} + if isinstance(pc, dict): + return pc + # Pydantic 模型: 用 model_dump(serialize_as_any=True) 保留 extras + try: + return pc.model_dump(by_alias=True, serialize_as_any=True) + except TypeError: + return pc.model_dump(by_alias=True) + + +def _extract_protocol_settings(pc_dict: Dict[str, Any]) -> List[Dict[str, Any]]: + """从 protocolConfiguration dict 中取出 protocolSettings 列表.""" + for key in ("protocolSettings", "protocol_settings"): + v = pc_dict.get(key) + if isinstance(v, list): + return v + return [] + + +def is_super_agent(rt: AgentRuntime) -> bool: + """判断一个 AgentRuntime 是否为超级 Agent.""" + pc_dict = _extract_protocol_configuration(rt) + if not pc_dict: + return False + settings = _extract_protocol_settings(pc_dict) + if not settings: + return False + first = settings[0] if isinstance(settings[0], dict) else {} + return first.get("type") == SUPER_AGENT_PROTOCOL_TYPE + + +def parse_super_agent_config(rt: AgentRuntime) -> Dict[str, Any]: + """反解 ``protocolSettings[0].config`` 为业务字段 dict. + + 如果 config 缺失或非法 JSON, 返回空 dict (不抛异常)。 + """ + pc_dict = _extract_protocol_configuration(rt) + if not pc_dict: + return {} + settings = _extract_protocol_settings(pc_dict) + if not settings: + return {} + first = settings[0] if isinstance(settings[0], dict) else {} + raw_config = first.get("config") + if not raw_config: + return {} + if isinstance(raw_config, dict): + return raw_config + if isinstance(raw_config, str): + try: + parsed = json.loads(raw_config) + return parsed if isinstance(parsed, dict) else {} + except (TypeError, ValueError): + return {} + return {} + + +def _get_external_endpoint(rt: AgentRuntime) -> str: + pc_dict = _extract_protocol_configuration(rt) + return ( + pc_dict.get("externalEndpoint") + or pc_dict.get("external_endpoint", "") + or "" + ) + + +def from_agent_runtime(rt: AgentRuntime) -> "SuperAgent": # noqa: F821 + """反解 AgentRuntime → SuperAgent 实例 (不注入 ``_client``).""" + # 延迟导入避免循环 + from agentrun.super_agent.agent import SuperAgent + + business = parse_super_agent_config(rt) + return SuperAgent( + name=getattr(rt, "agent_runtime_name", None) or "", + description=getattr(rt, "description", None), + prompt=business.get("prompt"), + agents=list(business.get("agents") or []), + tools=list(business.get("tools") or []), + skills=list(business.get("skills") or []), + sandboxes=list(business.get("sandboxes") or []), + workspaces=list(business.get("workspaces") or []), + model_service_name=business.get("modelServiceName"), + model_name=business.get("modelName"), + agent_runtime_id=getattr(rt, "agent_runtime_id", None) or "", + arn=getattr(rt, "agent_runtime_arn", None) or "", + status=str(getattr(rt, "status", "") or ""), + created_at=getattr(rt, "created_at", None) or "", + last_updated_at=getattr(rt, "last_updated_at", None) or "", + external_endpoint=_get_external_endpoint(rt), + ) + + +__all__ = [ + "API_VERSION", + "SUPER_AGENT_PROTOCOL_TYPE", + "SUPER_AGENT_TAG", + "EXTERNAL_TAG", + "SUPER_AGENT_CREATE_TAGS", + "SUPER_AGENT_RESOURCE_PATH", + "SUPER_AGENT_INVOKE_PATH", + "SUPER_AGENT_NAMESPACE", + "SuperAgentProtocolConfig", + "_SuperAgentCreateInput", + "_SuperAgentUpdateInput", + "build_super_agent_endpoint", + "_add_ram_prefix_to_host", + "to_create_input", + "to_update_input", + "from_agent_runtime", + "is_super_agent", + "parse_super_agent_config", +] diff --git a/agentrun/super_agent/api/data.py b/agentrun/super_agent/api/data.py new file mode 100644 index 0000000..ec6b2bb --- /dev/null +++ b/agentrun/super_agent/api/data.py @@ -0,0 +1,275 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/super_agent/api/__data_async_template.py + +Super Agent 数据面 API / Super Agent Data Plane API + +继承 :class:`agentrun.utils.data_api.DataAPI`, 复用: + +- ``_get_ram_data_endpoint``: RAM ``-ram`` 前缀改写 +- ``get_base_url`` / ``with_path``: URL 拼接 (含版本号) +- ``auth``: RAM 签名头生成 (``Agentrun-Authorization`` / ``x-acs-*``) + +同步方法第一版抛 ``NotImplementedError``, 仅保留骨架以备未来扩展。 +""" + +import json +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + +import httpx + +from agentrun.super_agent.model import InvokeResponseData +from agentrun.super_agent.stream import parse_sse_async, SSEEvent +from agentrun.utils.config import Config +from agentrun.utils.data_api import DataAPI, ResourceType +from agentrun.utils.log import logger + +API_VERSION = "2025-09-10" +SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" +SUPER_AGENT_NAMESPACE = ( + f"{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" +) + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +class SuperAgentDataAPI(DataAPI): + """Super Agent 数据面 API (异步主路径 + 同步占位).""" + + def __init__( + self, + agent_runtime_name: str, + config: Optional[Config] = None, + ): + super().__init__( + resource_name=agent_runtime_name, + resource_type=ResourceType.Runtime, + namespace=SUPER_AGENT_NAMESPACE, + config=config, + ) + self.agent_runtime_name = agent_runtime_name + + def _build_invoke_body( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str], + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + # ``forwarded_extras`` 承载从 AgentRuntime 元数据读出的业务字段 + # (prompt/agents/tools/skills/sandboxes/workspaces/modelServiceName/modelName), + # 由上层 ``SuperAgent.invoke_async`` 注入。``metadata`` 和 ``conversationId`` + # 由 SDK 管理, 不允许 extras 覆盖。 + forwarded: Dict[str, Any] = dict(forwarded_extras or {}) + forwarded["metadata"] = {"agentRuntimeName": self.agent_runtime_name} + if conversation_id is not None: + forwarded["conversationId"] = conversation_id + return {"messages": list(messages), "forwardedProps": forwarded} + + def _parse_invoke_response( + self, payload: Dict[str, Any] + ) -> InvokeResponseData: + if not isinstance(payload, dict): + raise ValueError( + "Invalid invoke response: expected object, got" + f" {type(payload).__name__}" + ) + data = payload.get("data") + if not isinstance(data, dict): + raise ValueError("Invalid invoke response: missing data field") + conversation_id = data.get("conversationId") + if not conversation_id: + raise ValueError( + "Invalid invoke response: missing conversationId field" + ) + stream_url = data.get("url") + if not stream_url: + raise ValueError("Invalid invoke response: missing url field") + raw_headers = data.get("headers") or {} + stream_headers = {str(k): str(v) for k, v in dict(raw_headers).items()} + return InvokeResponseData( + conversation_id=str(conversation_id), + stream_url=str(stream_url), + stream_headers=stream_headers, + ) + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + """Phase 1: POST /invoke, 返回 Phase 2 的 URL 与 headers.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path("invoke", config=cfg) + body = self._build_invoke_body( + messages, conversation_id, forwarded_extras + ) + body_bytes = json.dumps(body).encode("utf-8") + _, signed_headers, _ = self.auth( + url=url, + method="POST", + headers=cfg.get_headers(), + body=body_bytes, + config=cfg, + ) + signed_headers.setdefault("Content-Type", "application/json") + logger.debug("super_agent invoke request: POST %s body=%s", url, body) + + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.post( + url, headers=signed_headers, content=body_bytes + ) + resp.raise_for_status() + payload = resp.json() + logger.debug( + "super_agent invoke response: status=%d payload=%s", + resp.status_code, + payload, + ) + return self._parse_invoke_response(payload) + + def invoke( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def stream_async( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> AsyncIterator[SSEEvent]: + """Phase 2: GET stream_url → 流式 yield SSEEvent. + + 合并优先级 (低 → 高): ``cfg.get_headers()`` → ``stream_headers`` → RAM 签名头。 + """ + cfg = Config.with_configs(self.config, config) + _, signed_headers, _ = self.auth( + url=stream_url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + merged_headers: Dict[str, str] = {} + merged_headers.update(cfg.get_headers()) + if stream_headers: + merged_headers.update(stream_headers) + merged_headers.update(signed_headers) + + logger.debug("super_agent stream request: GET %s", stream_url) + timeout = httpx.Timeout(cfg.get_timeout(), read=None) + client = httpx.AsyncClient(timeout=timeout) + try: + ctx = client.stream("GET", stream_url, headers=merged_headers) + response = await ctx.__aenter__() + try: + response.raise_for_status() + logger.debug( + "super_agent stream response: status=%d headers=%s", + response.status_code, + dict(response.headers), + ) + async for event in parse_sse_async(response): + logger.debug("super_agent stream event: %s", event) + yield event + finally: + await ctx.__aexit__(None, None, None) + finally: + await client.aclose() + + def stream( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> Iterator[SSEEvent]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """GET /conversations/{id} → 返回服务端 ``data`` 字段 dict.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent get_conversation request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent get_conversation response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return {} + data = payload.get("data") + return data if isinstance(data, dict) else {} + + def get_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="DELETE", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent delete_conversation request: DELETE %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.delete(url, headers=signed_headers) + resp.raise_for_status() + logger.debug( + "super_agent delete_conversation response: status=%d body=%s", + resp.status_code, + resp.text, + ) + + def delete_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +__all__ = [ + "API_VERSION", + "SUPER_AGENT_RESOURCE_PATH", + "SUPER_AGENT_NAMESPACE", + "SuperAgentDataAPI", +] diff --git a/agentrun/super_agent/client.py b/agentrun/super_agent/client.py new file mode 100644 index 0000000..8ac6c84 --- /dev/null +++ b/agentrun/super_agent/client.py @@ -0,0 +1,524 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/super_agent/__client_async_template.py + +SuperAgentClient / 超级 Agent 客户端 + +对外入口: CRUDL (create / get / update / delete / list / list_all) 同步 + 异步双写。 +内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 +转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 + +list 固定按 tag ``x-agentrun-super-agent`` 过滤, 不接受用户自定义 tag。 +""" + +import asyncio +import time +from typing import Any, List, Optional + +from alibabacloud_agentrun20250910.models import ( + CreateAgentRuntimeInput, + UpdateAgentRuntimeInput, +) + +from agentrun.agent_runtime.api import AgentRuntimeControlAPI +from agentrun.agent_runtime.client import AgentRuntimeClient +from agentrun.agent_runtime.model import AgentRuntimeListInput +from agentrun.agent_runtime.runtime import AgentRuntime +from agentrun.super_agent.agent import SuperAgent +from agentrun.super_agent.api.control import ( + from_agent_runtime, + is_super_agent, + SUPER_AGENT_TAG, + to_create_input, + to_update_input, +) +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import Status + +# 公开 API 签名故意保持 ``Optional[X] = None`` 对外简洁; +# ``_UNSET`` 仅用于内部区分 "未传" 与 "显式 None (= 清空)". +_UNSET: Any = object() + +# create/update 轮询默认参数 +_WAIT_INTERVAL_SECONDS = 3 +_WAIT_TIMEOUT_SECONDS = 300 + + +def _raise_if_failed(rt: AgentRuntime, action: str) -> None: + """若 rt 处于失败态, 抛出带 status_reason 的 RuntimeError.""" + status = getattr(rt, "status", None) + status_str = str(status) if status is not None else "" + if status_str in { + Status.CREATE_FAILED.value, + Status.UPDATE_FAILED.value, + Status.DELETE_FAILED.value, + }: + reason = getattr(rt, "status_reason", None) or "(no reason)" + name = getattr(rt, "agent_runtime_name", None) or "(unknown)" + raise RuntimeError( + f"Super agent {action} failed: name={name!r} status={status_str} " + f"reason={reason}" + ) + + +def _merge(current: dict, updates: dict) -> dict: + """把 ``updates`` 中非 ``_UNSET`` 的字段合并到 ``current`` (None 表示清空).""" + merged = dict(current) + for key, value in updates.items(): + if value is _UNSET: + continue + merged[key] = value + return merged + + +def _super_agent_to_business_dict(agent: SuperAgent) -> dict: + return { + "description": agent.description, + "prompt": agent.prompt, + "agents": list(agent.agents), + "tools": list(agent.tools), + "skills": list(agent.skills), + "sandboxes": list(agent.sandboxes), + "workspaces": list(agent.workspaces), + "model_service_name": agent.model_service_name, + "model_name": agent.model_name, + } + + +class SuperAgentClient: + """Super Agent CRUDL 客户端.""" + + def __init__(self, config: Optional[Config] = None) -> None: + self.config = config + self._rt = AgentRuntimeClient(config=config) + # create/update 绕过 AgentRuntimeClient 的 artifact_type 校验 (SUPER_AGENT 不需要 code/container), + # 并通过 ``ProtocolConfiguration`` 的 monkey-patch 保留 ``externalEndpoint`` 字段。 + self._rt_control = AgentRuntimeControlAPI(config=config) + + async def _wait_final_async( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """轮询 get 直到 status 进入最终态 (READY / *_FAILED).""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = await self._rt.get_async(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + await asyncio.sleep(interval_seconds) + + def _wait_final( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """同步版 _wait_final_async.""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = self._rt.get(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + time.sleep(interval_seconds) + + # ─── Create ────────────────────────────────────── + async def create_async( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """异步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = await self._rt_control.create_agent_runtime_async( + dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + # 轮询直到进入最终态; 失败则抛出带 status_reason 的错误。 + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id and not Status.is_final_status(getattr(rt, "status", None)): + rt = await self._wait_final_async(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def create( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """同步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = self._rt_control.create_agent_runtime(dara_input, config=cfg) + rt = AgentRuntime.from_inner_object(result) + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id: + rt = self._wait_final(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Get ────────────────────────────────────────── + # Aliyun 控制面 get/delete/update 接口只认 ``agent_runtime_id`` (URN), + # 不认 resource_name; ``_find_rt_by_name*`` 用 list + 名称匹配来解析 id. + def _find_rt_by_name(self, name: str, config: Optional[Config]) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = self._rt.list( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def _find_rt_by_name_async( + self, name: str, config: Optional[Config] + ) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = await self._rt.list_async( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def get_async( + self, name: str, *, config: Optional[Config] = None + ) -> SuperAgent: + """异步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def get(self, name: str, *, config: Optional[Config] = None) -> SuperAgent: + """同步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Update (read-merge-write) ───────────────────── + async def update_async( + self, + name: str, + *, + description: Any = _UNSET, + prompt: Any = _UNSET, + agents: Any = _UNSET, + tools: Any = _UNSET, + skills: Any = _UNSET, + sandboxes: Any = _UNSET, + workspaces: Any = _UNSET, + model_service_name: Any = _UNSET, + model_name: Any = _UNSET, + config: Optional[Config] = None, + ) -> SuperAgent: + """异步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = await self._rt_control.update_agent_runtime_async( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = await self._wait_final_async(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def update( + self, + name: str, + *, + description: Any = _UNSET, + prompt: Any = _UNSET, + agents: Any = _UNSET, + tools: Any = _UNSET, + skills: Any = _UNSET, + sandboxes: Any = _UNSET, + workspaces: Any = _UNSET, + model_service_name: Any = _UNSET, + model_name: Any = _UNSET, + config: Optional[Config] = None, + ) -> SuperAgent: + """同步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = self._rt_control.update_agent_runtime( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = self._wait_final(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Delete ─────────────────────────────────────── + async def delete_async( + self, name: str, *, config: Optional[Config] = None + ) -> None: + """异步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + await self._rt.delete_async(agent_id, config=cfg) + + def delete(self, name: str, *, config: Optional[Config] = None) -> None: + """同步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + self._rt.delete(agent_id, config=cfg) + + # ─── List ───────────────────────────────────────── + async def list_async( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """异步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ) + runtimes = await self._rt.list_async(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + def list( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """同步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + tags=SUPER_AGENT_TAG, + ) + runtimes = self._rt.list(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + async def list_all_async( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """异步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = await self.list_async( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + def list_all( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """同步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = self.list( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + +__all__ = ["SuperAgentClient"] diff --git a/agentrun/super_agent/model.py b/agentrun/super_agent/model.py new file mode 100644 index 0000000..c710263 --- /dev/null +++ b/agentrun/super_agent/model.py @@ -0,0 +1,91 @@ +"""Super Agent 数据模型 / Super Agent Data Models + +此模块定义超级 Agent SDK 的输入输出数据模型。 +This module defines the input/output data models for the Super Agent SDK. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class SuperAgentCreateInput: + """超级 Agent 创建输入 / Super Agent creation input""" + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: List[str] = field(default_factory=list) + tools: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + sandboxes: List[str] = field(default_factory=list) + workspaces: List[str] = field(default_factory=list) + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + +@dataclass +class SuperAgentUpdateInput: + """超级 Agent 更新输入 / Super Agent update input + + 仅传想修改的字段, 其他保持不变。 + Only pass the fields to modify; others remain unchanged. + """ + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: Optional[List[str]] = None + tools: Optional[List[str]] = None + skills: Optional[List[str]] = None + sandboxes: Optional[List[str]] = None + workspaces: Optional[List[str]] = None + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + +@dataclass +class SuperAgentListInput: + """超级 Agent 列表查询输入 / Super Agent list query input""" + + page_number: int = 1 + page_size: int = 20 + + +@dataclass +class InvokeResponseData: + """Phase 1 响应 data 字段的强类型表示。 + + Strongly-typed representation of the data field in the phase 1 response. + """ + + conversation_id: str + stream_url: str + stream_headers: Dict[str, str] + + +@dataclass +class Message: + """会话消息 / Conversation message""" + + role: str + content: str + message_id: Optional[str] = None + created_at: Optional[int] = None + + +@dataclass +class ConversationInfo: + """服务端会话信息 / Server-side conversation info""" + + conversation_id: str + agent_id: str + title: Optional[str] = None + main_user_id: Optional[str] = None + sub_user_id: Optional[str] = None + created_at: int = 0 + updated_at: int = 0 + error_message: Optional[str] = None + invoke_info: Optional[Dict[str, Any]] = None + messages: List[Message] = field(default_factory=list) + params: Optional[Dict[str, Any]] = None diff --git a/agentrun/super_agent/stream.py b/agentrun/super_agent/stream.py new file mode 100644 index 0000000..d1c76f5 --- /dev/null +++ b/agentrun/super_agent/stream.py @@ -0,0 +1,180 @@ +"""Super Agent SSE 流 / Super Agent SSE Stream + +- :class:`SSEEvent`: SSE 协议单个事件的原始表示, 不做业务 normalize。 +- :func:`parse_sse_async`: 从 ``httpx.Response.aiter_lines`` 提取 ``SSEEvent``。 +- :class:`InvokeStream`: Phase 1 状态载体 + Phase 2 懒触发的异步流。 + +故意不引入 ``httpx-sse`` 等额外依赖;约 30 行的解析器足以覆盖 +``event / data / id / retry`` 四字段、注释行、多行 ``data:``、流末尾 flush。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import json +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional + +import httpx + + +@dataclass +class SSEEvent: + """SSE 协议单个事件的原始表示. + + SDK 不做高阶 normalize, 调用方按需用 :meth:`data_json` 解析。 + """ + + event: Optional[str] = None + data: str = "" + id: Optional[str] = None + retry: Optional[int] = None + + def data_json(self) -> Optional[Any]: + """尝试把 ``data`` 解析为 JSON, 失败或为空返回 ``None``.""" + if not self.data: + return None + try: + return json.loads(self.data) + except (TypeError, ValueError): + return None + + +async def parse_sse_async( + response: httpx.Response, +) -> AsyncIterator[SSEEvent]: + """按行解析 SSE, 逐个 yield :class:`SSEEvent`. + + 规则: + + - 空行 = 事件边界, flush 当前字段 (允许空 event + 空 data 的空心跳事件被跳过) + - ``:`` 开头的行 = 注释, 忽略 + - ``field: value`` 形式, ``:`` 后第一个空格被去除 + - 多行 ``data:`` 用 ``\\n`` 拼接 + - ``retry`` 非整数时忽略 + - 未知字段忽略 (向前兼容) + - 流结束时若仍有未 flush 的字段, flush 一次 + """ + + event: Optional[str] = None + data_lines: List[str] = [] + sse_id: Optional[str] = None + retry: Optional[int] = None + + def _has_content() -> bool: + return bool(data_lines) or event is not None or sse_id is not None + + async for raw_line in response.aiter_lines(): + # httpx 的 aiter_lines 已去除换行符; 空字符串表示事件边界 + line = raw_line.rstrip("\r") + if line == "": + if _has_content(): + yield SSEEvent( + event=event, + data="\n".join(data_lines), + id=sse_id, + retry=retry, + ) + event = None + data_lines = [] + sse_id = None + retry = None + continue + + if line.startswith(":"): + continue + + if ":" in line: + field_name, _, value = line.partition(":") + if value.startswith(" "): + value = value[1:] + else: + field_name, value = line, "" + + if field_name == "event": + event = value + elif field_name == "data": + data_lines.append(value) + elif field_name == "id": + sse_id = value + elif field_name == "retry": + try: + retry = int(value) + except (TypeError, ValueError): + pass + # 未知字段忽略 + + if _has_content(): + yield SSEEvent( + event=event, + data="\n".join(data_lines), + id=sse_id, + retry=retry, + ) + + +StreamCallable = Callable[[], Awaitable[AsyncIterator[SSEEvent]]] +"""Phase 2 拉流回调: 返回 (awaitable of) 异步迭代器.""" + + +@dataclass +class InvokeStream: + """Phase 1 已完成的状态载体, 同时是 Phase 2 SSE 流的异步可迭代器. + + ``await SuperAgent.invoke_async(...)`` 完成后即可读: + - :attr:`conversation_id` + - :attr:`session_id` + - :attr:`stream_url` + - :attr:`stream_headers` + + 只在首次 ``async for`` 或 ``__aiter__`` 调用时才触发 Phase 2 GET。 + """ + + conversation_id: str + session_id: str + stream_url: str + stream_headers: Dict[str, str] + _stream_factory: StreamCallable + _iterator: Optional[AsyncIterator[SSEEvent]] = field( + default=None, init=False, repr=False + ) + _closed: bool = field(default=False, init=False, repr=False) + + async def _ensure_iterator(self) -> AsyncIterator[SSEEvent]: + if self._iterator is None: + self._iterator = await self._stream_factory() + return self._iterator + + def __aiter__(self) -> "InvokeStream": + return self + + async def __anext__(self) -> SSEEvent: + if self._closed: + raise StopAsyncIteration + iterator = await self._ensure_iterator() + try: + return await iterator.__anext__() + except StopAsyncIteration: + self._closed = True + raise + + async def aclose(self) -> None: + """提前关闭底层 HTTP 连接, 释放资源.""" + self._closed = True + iterator = self._iterator + self._iterator = None + if iterator is not None: + close = getattr(iterator, "aclose", None) + if close is not None: + try: + await close() + except Exception: # pragma: no cover - best effort cleanup + pass + + async def __aenter__(self) -> "InvokeStream": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.aclose() + + +__all__ = ["SSEEvent", "parse_sse_async", "InvokeStream"] diff --git a/tests/unittests/integration/test_langchain_agui_integration.py b/tests/unittests/integration/test_langchain_agui_integration.py index 500c86c..ef0c076 100644 --- a/tests/unittests/integration/test_langchain_agui_integration.py +++ b/tests/unittests/integration/test_langchain_agui_integration.py @@ -689,9 +689,7 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": ( - "查询当前的时间,并获取天气信息,同时输出我的密钥信息" - ), + "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", }], "stream": True, }, @@ -757,9 +755,7 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": ( - "查询当前的时间,并获取天气信息,同时输出我的密钥信息" - ), + "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", }], "stream": True, }, diff --git a/tests/unittests/super_agent/__init__.py b/tests/unittests/super_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/super_agent/test_agent.py b/tests/unittests/super_agent/test_agent.py new file mode 100644 index 0000000..2ae2418 --- /dev/null +++ b/tests/unittests/super_agent/test_agent.py @@ -0,0 +1,249 @@ +"""Unit tests for ``agentrun.super_agent.agent.SuperAgent``.""" + +import asyncio +import inspect +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.super_agent.agent import SuperAgent +from agentrun.super_agent.model import InvokeResponseData +from agentrun.super_agent.stream import InvokeStream + + +def _make_agent() -> SuperAgent: + return SuperAgent(name="demo") + + +def _mock_data_api(invoke_result: InvokeResponseData = None): + """Return a MagicMock replacing SuperAgentDataAPI constructor.""" + instance = MagicMock() + instance.invoke_async = AsyncMock(return_value=invoke_result) + instance.stream_async = MagicMock() + instance.get_conversation_async = AsyncMock() + instance.delete_conversation_async = AsyncMock() + factory = MagicMock(return_value=instance) + return factory, instance + + +# ─── invoke_async ─────────────────────────────────────────── + + +async def test_invoke_async_returns_invoke_stream_with_conversation_id(): + resp = InvokeResponseData( + conversation_id="c1", + stream_url="https://stream/", + stream_headers={"X-Super-Agent-Session-Id": "sess"}, + ) + factory, _ = _mock_data_api(resp) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + stream = await agent.invoke_async([{"role": "user", "content": "hi"}]) + assert stream.conversation_id == "c1" + assert stream.session_id == "sess" + assert stream.stream_url == "https://stream/" + assert stream.stream_headers == {"X-Super-Agent-Session-Id": "sess"} + + +async def test_invoke_async_no_tools_kwarg_raises(): + agent = _make_agent() + with pytest.raises(TypeError): + await agent.invoke_async([], tools=["t"]) # type: ignore[call-arg] + + +async def test_invoke_async_forwards_business_fields(): + """SuperAgent 实例字段 MUST 作为 forwarded_extras 透传给 DataAPI.""" + resp = InvokeResponseData( + conversation_id="c", + stream_url="https://x/", + stream_headers={}, + ) + invoke_mock = AsyncMock(return_value=resp) + instance = MagicMock() + instance.invoke_async = invoke_mock + factory = MagicMock(return_value=instance) + agent = SuperAgent( + name="demo", + prompt="p", + agents=["a1"], + tools=["t1", "t2"], + skills=["s1"], + sandboxes=["sb1"], + workspaces=["w1"], + model_service_name="svc", + model_name="mod", + ) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + await agent.invoke_async([{"role": "user", "content": "hi"}]) + + extras = invoke_mock.await_args.kwargs["forwarded_extras"] + assert extras == { + "prompt": "p", + "agents": ["a1"], + "tools": ["t1", "t2"], + "skills": ["s1"], + "sandboxes": ["sb1"], + "workspaces": ["w1"], + "modelServiceName": "svc", + "modelName": "mod", + } + + +async def test_invoke_async_forwards_business_fields_defaults(): + """没设置的 scalar 字段保留 None, list 字段为 [].""" + resp = InvokeResponseData( + conversation_id="c", stream_url="https://x/", stream_headers={} + ) + invoke_mock = AsyncMock(return_value=resp) + instance = MagicMock() + instance.invoke_async = invoke_mock + factory = MagicMock(return_value=instance) + agent = SuperAgent(name="demo") + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + await agent.invoke_async([]) + extras = invoke_mock.await_args.kwargs["forwarded_extras"] + assert extras == { + "prompt": None, + "agents": [], + "tools": [], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": None, + "modelName": None, + } + + +async def test_invoke_async_concurrent_streams_independent(): + responses = [ + InvokeResponseData( + conversation_id=f"c{i}", + stream_url=f"https://s{i}", + stream_headers={}, + ) + for i in range(2) + ] + invoke_mock = AsyncMock(side_effect=responses) + instance = MagicMock() + instance.invoke_async = invoke_mock + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + s0, s1 = await asyncio.gather( + agent.invoke_async([]), + agent.invoke_async([]), + ) + ids = {s0.conversation_id, s1.conversation_id} + assert ids == {"c0", "c1"} + + +async def test_invoke_async_phase2_lazy(): + resp = InvokeResponseData( + conversation_id="c", + stream_url="https://x/", + stream_headers={}, + ) + invoke_mock = AsyncMock(return_value=resp) + stream_mock = MagicMock() + instance = MagicMock() + instance.invoke_async = invoke_mock + instance.stream_async = stream_mock + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + stream = await agent.invoke_async([]) + # At this point Phase 2 must NOT have been called + stream_mock.assert_not_called() + # The stream factory stored inside InvokeStream should only invoke stream_async when iteration starts + assert isinstance(stream, InvokeStream) + + +# ─── get_conversation_async ───────────────────────────────── + + +async def test_get_conversation_async_returns_conversation_info(): + instance = MagicMock() + instance.get_conversation_async = AsyncMock( + return_value={ + "conversationId": "c1", + "agentId": "ag", + "title": "t", + "mainUserId": "u1", + "subUserId": "u2", + "createdAt": 100, + "updatedAt": 200, + "errorMessage": None, + "invokeInfo": {"foo": "bar"}, + "messages": [ + {"role": "user", "content": "hi", "messageId": "m1"}, + {"role": "assistant", "content": "hello"}, + ], + "params": {"a": 1}, + } + ) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + info = await _make_agent().get_conversation_async("c1") + assert info.conversation_id == "c1" + assert info.agent_id == "ag" + assert info.title == "t" + assert info.main_user_id == "u1" + assert info.sub_user_id == "u2" + assert info.created_at == 100 + assert info.updated_at == 200 + assert info.invoke_info == {"foo": "bar"} + assert len(info.messages) == 2 + assert info.messages[0].message_id == "m1" + assert info.params == {"a": 1} + + +async def test_get_conversation_async_partial_fields(): + instance = MagicMock() + instance.get_conversation_async = AsyncMock(return_value={"agentId": "x"}) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + info = await _make_agent().get_conversation_async("c1") + assert info.conversation_id == "c1" # fallback from argument + assert info.title is None + assert info.main_user_id is None + assert info.created_at == 0 + + +async def test_get_conversation_async_empty_messages(): + instance = MagicMock() + instance.get_conversation_async = AsyncMock(return_value={"messages": []}) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + info = await _make_agent().get_conversation_async("c1") + assert info.messages == [] + + +async def test_delete_conversation_async_returns_none(): + instance = MagicMock() + instance.delete_conversation_async = AsyncMock(return_value=None) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + assert await _make_agent().delete_conversation_async("c") is None + + +# ─── sync methods → NotImplementedError ───────────────────── + + +def test_sync_methods_not_implemented(): + agent = _make_agent() + with pytest.raises(NotImplementedError): + agent.invoke([]) + with pytest.raises(NotImplementedError): + agent.get_conversation("c") + with pytest.raises(NotImplementedError): + agent.delete_conversation("c") + + +def test_invoke_async_signature_only_messages_and_conversation_id(): + sig = inspect.signature(SuperAgent.invoke_async) + params = list(sig.parameters.keys()) + # self, messages, then KEYWORD_ONLY: conversation_id, config + assert params[:2] == ["self", "messages"] + assert "conversation_id" in params + assert "config" in params + assert "tools" not in params diff --git a/tests/unittests/super_agent/test_agui.py b/tests/unittests/super_agent/test_agui.py new file mode 100644 index 0000000..d171f45 --- /dev/null +++ b/tests/unittests/super_agent/test_agui.py @@ -0,0 +1,498 @@ +"""Unit tests for ``agentrun.super_agent.agui``.""" + +import ast +import json +from pathlib import Path +from typing import List +from unittest.mock import AsyncMock, MagicMock + +from ag_ui.core import ( + BaseEvent, + CustomEvent, + MessagesSnapshotEvent, + RawEvent, + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + StateDeltaEvent, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) +import pytest + +from agentrun.super_agent import agui as agui_mod +from agentrun.super_agent.agui import ( + _EVENT_TYPE_TO_CLASS, + as_agui_events, + decode_sse_to_agui, +) +from agentrun.super_agent.stream import InvokeStream, SSEEvent + + +def _make_stream_from_sse(events: List[SSEEvent]) -> InvokeStream: + async def _gen(): + for ev in events: + yield ev + + async def _factory(): + return _gen() + + return InvokeStream( + conversation_id="c", + session_id="s", + stream_url="https://x.com/s", + stream_headers={}, + _stream_factory=_factory, + ) + + +# ─── decode_sse_to_agui ────────────────────────────────────── + + +def test_decode_sse_text_message_content(): + payload = { + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m1", + "delta": "hi", + } + ev = SSEEvent(event="TEXT_MESSAGE_CONTENT", data=json.dumps(payload)) + result = decode_sse_to_agui(ev) + assert isinstance(result, TextMessageContentEvent) + + +def test_decode_sse_run_started(): + payload = {"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"} + ev = SSEEvent(event="RUN_STARTED", data=json.dumps(payload)) + assert isinstance(decode_sse_to_agui(ev), RunStartedEvent) + + +def test_decode_sse_run_finished(): + payload = {"type": "RUN_FINISHED", "threadId": "t1", "runId": "r1"} + ev = SSEEvent(event="RUN_FINISHED", data=json.dumps(payload)) + assert isinstance(decode_sse_to_agui(ev), RunFinishedEvent) + + +def test_decode_sse_run_error(): + payload = {"type": "RUN_ERROR", "message": "oops"} + ev = SSEEvent(event="RUN_ERROR", data=json.dumps(payload)) + assert isinstance(decode_sse_to_agui(ev), RunErrorEvent) + + +def test_decode_sse_text_message_lifecycle(): + cases = [ + ( + "TEXT_MESSAGE_START", + { + "type": "TEXT_MESSAGE_START", + "messageId": "m", + "role": "assistant", + }, + TextMessageStartEvent, + ), + ( + "TEXT_MESSAGE_CONTENT", + { + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m", + "delta": "x", + }, + TextMessageContentEvent, + ), + ( + "TEXT_MESSAGE_END", + { + "type": "TEXT_MESSAGE_END", + "messageId": "m", + }, + TextMessageEndEvent, + ), + ] + for name, payload, cls in cases: + result = decode_sse_to_agui( + SSEEvent(event=name, data=json.dumps(payload)) + ) + assert isinstance(result, cls), name + + +def test_decode_sse_tool_call_lifecycle(): + cases = [ + ( + "TOOL_CALL_START", + { + "type": "TOOL_CALL_START", + "toolCallId": "tc", + "toolCallName": "fn", + }, + ToolCallStartEvent, + ), + ( + "TOOL_CALL_ARGS", + { + "type": "TOOL_CALL_ARGS", + "toolCallId": "tc", + "delta": "arg", + }, + ToolCallArgsEvent, + ), + ( + "TOOL_CALL_END", + { + "type": "TOOL_CALL_END", + "toolCallId": "tc", + }, + ToolCallEndEvent, + ), + ( + "TOOL_CALL_RESULT", + { + "type": "TOOL_CALL_RESULT", + "toolCallId": "tc", + "messageId": "m", + "content": "r", + }, + ToolCallResultEvent, + ), + ] + for name, payload, cls in cases: + result = decode_sse_to_agui( + SSEEvent(event=name, data=json.dumps(payload)) + ) + assert isinstance(result, cls), name + + +def test_decode_sse_state_events(): + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="STATE_SNAPSHOT", + data=json.dumps({"type": "STATE_SNAPSHOT", "snapshot": {}}), + ) + ), + StateSnapshotEvent, + ) + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="STATE_DELTA", + data=json.dumps({"type": "STATE_DELTA", "delta": []}), + ) + ), + StateDeltaEvent, + ) + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="MESSAGES_SNAPSHOT", + data=json.dumps({ + "type": "MESSAGES_SNAPSHOT", + "messages": [], + }), + ) + ), + MessagesSnapshotEvent, + ) + + +def test_decode_sse_raw_and_custom(): + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="RAW", + data=json.dumps({"type": "RAW", "event": {"k": "v"}}), + ) + ), + RawEvent, + ) + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="CUSTOM", + data=json.dumps({"type": "CUSTOM", "name": "n", "value": 1}), + ) + ), + CustomEvent, + ) + + +def test_decode_sse_empty_data_returns_none(): + assert decode_sse_to_agui(SSEEvent(event=None, data="")) is None + assert decode_sse_to_agui(SSEEvent(event="RUN_STARTED", data="")) is None + + +def test_decode_sse_unknown_event_raise(): + with pytest.raises(ValueError) as exc: + decode_sse_to_agui(SSEEvent(event="UNKNOWN_X", data="{}")) + assert "UNKNOWN_X" in str(exc.value) + assert "{}" in str(exc.value) + + +def test_decode_sse_unknown_event_skip(): + result = decode_sse_to_agui( + SSEEvent(event="UNKNOWN_X", data="{}"), on_unknown="skip" + ) + assert result is None + + +def test_decode_sse_invalid_json_raises(): + with pytest.raises(ValueError) as exc: + decode_sse_to_agui( + SSEEvent(event="TEXT_MESSAGE_CONTENT", data="not json") + ) + assert "TEXT_MESSAGE_CONTENT" in str(exc.value) + assert "not json" in str(exc.value) + + +def test_decode_sse_pydantic_validation_failure_raises(): + with pytest.raises(ValueError): + decode_sse_to_agui( + SSEEvent( + event="TEXT_MESSAGE_CONTENT", + data='{"unrelated":"x"}', + ) + ) + + +def test_decode_sse_data_with_newlines(): + # Embedded escaped newlines are fine inside JSON + payload = json.dumps({ + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m", + "delta": "hi\nthere", + }) + result = decode_sse_to_agui( + SSEEvent(event="TEXT_MESSAGE_CONTENT", data=payload) + ) + assert isinstance(result, TextMessageContentEvent) + assert result.delta == "hi\nthere" + + +# ─── as_agui_events ────────────────────────────────────────── + + +async def test_as_agui_events_yields_typed_events(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent( + event="TEXT_MESSAGE_CONTENT", + data=json.dumps({ + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m", + "delta": "x", + }), + ), + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + collected: List[BaseEvent] = [ev async for ev in as_agui_events(stream)] + assert [type(e).__name__ for e in collected] == [ + "RunStartedEvent", + "TextMessageContentEvent", + "RunFinishedEvent", + ] + + +async def test_as_agui_events_skips_empty_data(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent(event="RUN_STARTED", data=""), # keepalive + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + collected = [ev async for ev in as_agui_events(stream)] + assert len(collected) == 2 + + +async def test_as_agui_events_unknown_skip_mode(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent(event="UNKNOWN_X", data="{}"), + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + collected = [ev async for ev in as_agui_events(stream, on_unknown="skip")] + assert len(collected) == 2 + + +async def test_as_agui_events_unknown_raise_mode(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent(event="UNKNOWN_X", data="{}"), + ] + stream = _make_stream_from_sse(events) + it = as_agui_events(stream) + first = await it.__anext__() + assert isinstance(first, RunStartedEvent) + with pytest.raises(ValueError): + await it.__anext__() + + +async def test_as_agui_events_closes_stream_on_normal_end(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + close = AsyncMock() + stream.aclose = close + async for _ in as_agui_events(stream): + pass + close.assert_awaited_once() + + +async def test_as_agui_events_closes_stream_on_consumer_exception(): + """当消费循环里抛异常, 只要消费者用 ``aclosing`` 包裹 (或手动 aclose), + 适配器的 ``finally`` 就能跑到 ``stream.aclose`` — 这是异步生成器清理的 + 标准用法 (Python async gen 的清理不会在异常透传时自动同步执行). + """ + from contextlib import aclosing + + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + close = AsyncMock() + stream.aclose = close + with pytest.raises(RuntimeError): + async with aclosing(as_agui_events(stream)) as gen: + async for _ in gen: + raise RuntimeError("consumer blew up") + close.assert_awaited_once() + + +async def test_as_agui_events_closes_stream_on_decode_exception(): + events = [ + SSEEvent( + event="TEXT_MESSAGE_CONTENT", + data="not json", + ), + ] + stream = _make_stream_from_sse(events) + close = AsyncMock() + stream.aclose = close + with pytest.raises(ValueError): + async for _ in as_agui_events(stream): + pass + close.assert_awaited_once() + + +# ─── map completeness / module hygiene ────────────────────── + + +def test_event_type_to_class_map_completeness(): + required = { + "RUN_STARTED", + "RUN_FINISHED", + "RUN_ERROR", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "TOOL_CALL_RESULT", + "STATE_SNAPSHOT", + "STATE_DELTA", + "MESSAGES_SNAPSHOT", + "RAW", + "CUSTOM", + } + assert required <= set(_EVENT_TYPE_TO_CLASS.keys()) + + +def test_no_sync_as_agui_events_export(): + attrs = dir(agui_mod) + assert "as_agui_events_sync" not in attrs + assert "sync_as_agui_events" not in attrs + + +def test_agui_module_only_imports_ag_ui_core(): + src = Path(agui_mod.__file__).read_text() + tree = ast.parse(src) + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + if node.module.startswith("ag_ui"): + assert ( + node.module == "ag_ui.core" + ), f"Disallowed ag-ui import: {node.module}" + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name.startswith("ag_ui"): + assert ( + alias.name == "ag_ui.core" + ), f"Disallowed ag-ui import: {alias.name}" diff --git a/tests/unittests/super_agent/test_client.py b/tests/unittests/super_agent/test_client.py new file mode 100644 index 0000000..0f7da4d --- /dev/null +++ b/tests/unittests/super_agent/test_client.py @@ -0,0 +1,768 @@ +"""Unit tests for ``agentrun.super_agent.client.SuperAgentClient``.""" + +import asyncio +import inspect +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.super_agent.api.control import ( + SUPER_AGENT_PROTOCOL_TYPE, + SUPER_AGENT_TAG, +) +from agentrun.super_agent.client import SuperAgentClient +from agentrun.utils.config import Config + + +def _client_config() -> Config: + return Config( + access_key_id="AK", + access_key_secret="SK", + account_id="123", + region_id="cn-hangzhou", + ) + + +class _DaraResult: + """Fake Dara-level AgentRuntime returned from the control API.""" + + def __init__(self, map_dict: dict): + self._map = map_dict + + def to_map(self): + return self._map + + +def _rt_to_dara_result(rt: SimpleNamespace) -> _DaraResult: + """Turn the internal ``_fake_rt`` into a Dara-style object with ``to_map``.""" + return _DaraResult({ + "agentRuntimeName": rt.agent_runtime_name, + "agentRuntimeId": rt.agent_runtime_id, + "agentRuntimeArn": rt.agent_runtime_arn, + "status": rt.status, + "createdAt": rt.created_at, + "lastUpdatedAt": rt.last_updated_at, + "description": rt.description, + "protocolConfiguration": rt.protocol_configuration, + }) + + +def _fake_rt( + *, + name: str = "n", + prompt: str = "old", + tools=None, + description=None, + protocol_type: str = SUPER_AGENT_PROTOCOL_TYPE, +) -> SimpleNamespace: + """Build a minimal AgentRuntime-like object for ``from_agent_runtime``.""" + cfg_dict = { + "path": "/invoke", + "prompt": prompt, + "agents": [], + "tools": tools if tools is not None else [], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": None, + "modelName": None, + "metadata": {"agentRuntimeName": name}, + } + pc = { + "type": protocol_type, + "protocolSettings": [{ + "type": protocol_type, + "name": name, + "path": "/invoke", + "config": json.dumps(cfg_dict), + }], + "externalEndpoint": "https://x.com/super-agents/__SUPER_AGENT__", + } + return SimpleNamespace( + agent_runtime_name=name, + agent_runtime_id="rid", + agent_runtime_arn="arn", + status="READY", + created_at="t1", + last_updated_at="t2", + description=description, + protocol_configuration=pc, + ) + + +# ─── create ────────────────────────────────────────────────── + + +async def test_create_async_calls_runtime_with_correct_input(): + captured_input = {} + + async def _create_async(dara_input, config=None): + captured_input["dara"] = dara_input + return _rt_to_dara_result(_fake_rt(name="alpha", prompt="new")) + + ctrl = MagicMock() + ctrl.create_agent_runtime_async = _create_async + with patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.create_async(name="alpha", prompt="new") + + dara = captured_input["dara"] + # Dara-level model uses snake_case attributes + assert dara.agent_runtime_name == "alpha" + # 注: alibabacloud-agentrun20250910 的 Dara CreateAgentRuntimeInput 目前 + # 不包含 ``tags`` 字段, pydantic → Dara roundtrip 会丢弃. 校验 pydantic 侧 + # 的 rt_input 是否含 tags 在 test_control.py::test_to_create_input_tags_fixed + # 已覆盖. + pc = dara.protocol_configuration + # externalEndpoint preserved via the additive Dara monkey-patch + assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") + first = pc.protocol_settings[0] + assert first.type == SUPER_AGENT_PROTOCOL_TYPE + cfg_json = json.loads(first.config) + assert cfg_json["prompt"] == "new" + assert agent.name == "alpha" + + +async def test_create_async_returns_super_agent_with_client_handle(): + ctrl = MagicMock() + ctrl.create_agent_runtime_async = AsyncMock( + return_value=_rt_to_dara_result(_fake_rt(name="alpha")) + ) + with patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.create_async(name="alpha") + assert agent._client is client + + +# ─── get ──────────────────────────────────────────────────── + + +async def test_get_async_normal(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[_fake_rt(name="alpha")]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.get_async("alpha") + assert agent.name == "alpha" + assert agent._client is client + + +async def test_get_async_not_super_agent_raises(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock( + return_value=[_fake_rt(name="alpha", protocol_type="HTTP")] + ) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(ValueError) as exc: + await client.get_async("alpha") + assert "is not a super agent" in str(exc.value) + + +# ─── update (read-merge-write) ─────────────────────────────── + + +async def test_update_async_partial_modify_prompt_only(): + existing = _fake_rt(name="x", prompt="old", tools=["t1"]) + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[existing]) + rt_client.get_async = AsyncMock(return_value=existing) + ctrl = MagicMock() + captured = {} + + async def _update(agent_id, dara_input, config=None): + captured["dara"] = dara_input + return _rt_to_dara_result( + _fake_rt(name="x", prompt="new", tools=["t1"]) + ) + + ctrl.update_agent_runtime_async = _update + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + await client.update_async("x", prompt="new") + cfg_json = json.loads( + captured["dara"].protocol_configuration.protocol_settings[0].config + ) + assert cfg_json["prompt"] == "new" + assert cfg_json["tools"] == ["t1"] + + +async def test_update_async_explicit_none_clears_field(): + existing = _fake_rt(name="x", prompt="old") + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[existing]) + rt_client.get_async = AsyncMock(return_value=existing) + ctrl = MagicMock() + captured = {} + + async def _update(agent_id, dara_input, config=None): + captured["dara"] = dara_input + return _rt_to_dara_result(_fake_rt(name="x")) + + ctrl.update_agent_runtime_async = _update + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + await client.update_async("x", prompt=None) + cfg_json = json.loads( + captured["dara"].protocol_configuration.protocol_settings[0].config + ) + assert cfg_json["prompt"] is None + + +async def test_update_async_multiple_fields(): + existing = _fake_rt(name="x", prompt="old") + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[existing]) + rt_client.get_async = AsyncMock(return_value=existing) + ctrl = MagicMock() + captured = {} + + async def _update(agent_id, dara_input, config=None): + captured["dara"] = dara_input + return _rt_to_dara_result(_fake_rt(name="x")) + + ctrl.update_agent_runtime_async = _update + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + await client.update_async( + "x", prompt="p", tools=["a", "b"], description="d" + ) + cfg_json = json.loads( + captured["dara"].protocol_configuration.protocol_settings[0].config + ) + assert cfg_json["prompt"] == "p" + assert cfg_json["tools"] == ["a", "b"] + assert captured["dara"].description == "d" + + +async def test_update_async_target_not_super_agent_raises(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock( + return_value=[_fake_rt(name="x", protocol_type="HTTP")] + ) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(ValueError): + await client.update_async("x", prompt="p") + + +# ─── delete ───────────────────────────────────────────────── + + +async def test_delete_async_calls_runtime(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[_fake_rt(name="alpha")]) + rt_client.delete_async = AsyncMock(return_value=None) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = await client.delete_async("alpha") + assert result is None + rt_client.delete_async.assert_awaited_once() + called_with = rt_client.delete_async.await_args.args + # list_async returns rt with agent_runtime_id="rid"; delete_async 用 id 调用 + assert called_with[0] == "rid" + + +# ─── list ─────────────────────────────────────────────────── + + +async def test_list_async_default_pagination(): + rt_client = MagicMock() + captured = {} + + async def _list(inp=None, config=None): + captured["inp"] = inp + return [] + + rt_client.list_async = _list + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + await client.list_async() + assert captured["inp"].page_number == 1 + assert captured["inp"].page_size == 20 + assert captured["inp"].tags == SUPER_AGENT_TAG + + +async def test_list_async_custom_pagination(): + rt_client = MagicMock() + captured = {} + + async def _list(inp=None, config=None): + captured["inp"] = inp + return [] + + rt_client.list_async = _list + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + await client.list_async(page_number=2, page_size=50) + assert captured["inp"].page_number == 2 + assert captured["inp"].page_size == 50 + + +async def test_list_async_rejects_tags_kwarg(): + client = SuperAgentClient() + with pytest.raises(TypeError): + await client.list_async(tags=["x"]) # type: ignore[call-arg] + + +async def test_list_async_filters_non_super_agent(): + items = [ + _fake_rt(name="a"), + _fake_rt(name="b", protocol_type="HTTP"), + _fake_rt(name="c"), + ] + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=items) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = await client.list_async() + assert [a.name for a in result] == ["a", "c"] + + +async def test_list_all_async_auto_pagination(): + # page_size=50 is the default for list_all_async; craft pages accordingly + page1 = [_fake_rt(name=f"a{i}") for i in range(50)] + page2 = [_fake_rt(name=f"a{i}") for i in range(50, 85)] + pages = [page1, page2, []] + rt_client = MagicMock() + rt_client.list_async = AsyncMock(side_effect=pages) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = await client.list_all_async() + names = [a.name for a in result] + assert len(names) == 85 + assert len(set(names)) == 85 + + +# ─── sync mirrors async ───────────────────────────────────── + + +def test_sync_methods_exist_and_mirror_async(): + rt_client = MagicMock() + rt_client.get = MagicMock(return_value=_fake_rt(name="alpha")) + rt_client.delete = MagicMock(return_value=None) + rt_client.list = MagicMock(return_value=[_fake_rt(name="alpha")]) + ctrl = MagicMock() + ctrl.create_agent_runtime = MagicMock( + return_value=_rt_to_dara_result(_fake_rt(name="alpha")) + ) + ctrl.update_agent_runtime = MagicMock( + return_value=_rt_to_dara_result(_fake_rt(name="alpha")) + ) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + agent = client.create(name="alpha") + assert agent.name == "alpha" + assert client.get("alpha").name == "alpha" + assert client.update("alpha", prompt="p").name == "alpha" + assert client.delete("alpha") is None + assert len(client.list()) == 1 + # list_all hits list() once, then empty page + rt_client.list = MagicMock(side_effect=[[_fake_rt(name="a")], []]) + result = client.list_all() + assert len(result) == 1 + + +# ─── not-found / not-super error paths (async + sync) ────── + + +def _make_client_with_list(list_items, *, sync=False) -> SuperAgentClient: + """Helper: build a SuperAgentClient where _rt.list(_async) returns ``list_items``.""" + rt_client = MagicMock() + if sync: + rt_client.list = MagicMock(return_value=list_items) + else: + rt_client.list_async = AsyncMock(return_value=list_items) + patcher = patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ) + patcher.start() + client = SuperAgentClient(config=_client_config()) + client._patcher = patcher # type: ignore[attr-defined] + client._rt_client = rt_client # type: ignore[attr-defined] + return client + + +async def test_get_async_not_found_raises(): + client = _make_client_with_list([]) + try: + with pytest.raises(ValueError, match="not found"): + await client.get_async("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_get_sync_not_found_raises(): + client = _make_client_with_list([], sync=True) + try: + with pytest.raises(ValueError, match="not found"): + client.get("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_get_sync_not_super_agent_raises(): + client = _make_client_with_list( + [_fake_rt(name="x", protocol_type="HTTP")], sync=True + ) + try: + with pytest.raises(ValueError, match="is not a super agent"): + client.get("x") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +async def test_update_async_not_found_raises(): + client = _make_client_with_list([]) + try: + with pytest.raises(ValueError, match="not found"): + await client.update_async("missing", prompt="p") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_update_sync_not_found_raises(): + client = _make_client_with_list([], sync=True) + try: + with pytest.raises(ValueError, match="not found"): + client.update("missing", prompt="p") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_update_sync_not_super_agent_raises(): + client = _make_client_with_list( + [_fake_rt(name="x", protocol_type="HTTP")], sync=True + ) + try: + with pytest.raises(ValueError, match="is not a super agent"): + client.update("x", prompt="p") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +async def test_delete_async_not_found_raises(): + client = _make_client_with_list([]) + try: + with pytest.raises(ValueError, match="not found"): + await client.delete_async("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_delete_sync_not_found_raises(): + client = _make_client_with_list([], sync=True) + try: + with pytest.raises(ValueError, match="not found"): + client.delete("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +# ─── _find_rt_by_name_* 多页分页 ─────────────────────────────── + + +async def test_find_rt_by_name_async_paginates_until_match(): + """page1 全是非匹配 item (满 50), page2 才有 target → 需要翻页.""" + page1 = [_fake_rt(name=f"other{i}") for i in range(50)] + page2 = [_fake_rt(name="target")] + rt_client = MagicMock() + rt_client.list_async = AsyncMock(side_effect=[page1, page2]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.get_async("target") + assert agent.name == "target" + assert rt_client.list_async.await_count == 2 + + +def test_find_rt_by_name_sync_paginates_until_match(): + page1 = [_fake_rt(name=f"other{i}") for i in range(50)] + page2 = [_fake_rt(name="target")] + rt_client = MagicMock() + rt_client.list = MagicMock(side_effect=[page1, page2]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + agent = client.get("target") + assert agent.name == "target" + assert rt_client.list.call_count == 2 + + +# ─── _wait_final / _raise_if_failed ────────────────────────── + + +def test_raise_if_failed_raises_on_failed_status(): + from agentrun.super_agent.client import _raise_if_failed + + rt = SimpleNamespace( + status="CREATE_FAILED", + status_reason="disk full", + agent_runtime_name="x", + ) + with pytest.raises(RuntimeError) as exc: + _raise_if_failed(rt, action="create") + assert "disk full" in str(exc.value) + assert "CREATE_FAILED" in str(exc.value) + + +def test_raise_if_failed_noop_on_ready(): + from agentrun.super_agent.client import _raise_if_failed + + rt = SimpleNamespace(status="READY") + _raise_if_failed(rt, action="update") # no raise + + +async def test_wait_final_async_timeout(): + rt_pending = _fake_rt(name="x") + rt_pending.status = "CREATING" + rt_client = MagicMock() + rt_client.get_async = AsyncMock(return_value=rt_pending) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(TimeoutError): + await client._wait_final_async( + "rid", interval_seconds=0, timeout_seconds=-1 + ) + + +def test_wait_final_sync_timeout(): + rt_pending = _fake_rt(name="x") + rt_pending.status = "CREATING" + rt_client = MagicMock() + rt_client.get = MagicMock(return_value=rt_pending) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(TimeoutError): + client._wait_final("rid", interval_seconds=0, timeout_seconds=-1) + + +async def test_wait_final_async_retries_then_ready(): + """第一次 get 返回 CREATING → await asyncio.sleep → 第二次 READY.""" + pending = _fake_rt(name="x") + pending.status = "CREATING" + ready = _fake_rt(name="x") # status=READY by default + rt_client = MagicMock() + rt_client.get_async = AsyncMock(side_effect=[pending, ready]) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch("agentrun.super_agent.client.asyncio.sleep", AsyncMock()), + ): + client = SuperAgentClient(config=_client_config()) + result = await client._wait_final_async( + "rid", interval_seconds=0, timeout_seconds=60 + ) + assert getattr(result, "status", None) == "READY" + assert rt_client.get_async.await_count == 2 + + +def test_wait_final_sync_retries_then_ready(): + pending = _fake_rt(name="x") + pending.status = "CREATING" + ready = _fake_rt(name="x") + rt_client = MagicMock() + rt_client.get = MagicMock(side_effect=[pending, ready]) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch("agentrun.super_agent.client.time.sleep", MagicMock()), + ): + client = SuperAgentClient(config=_client_config()) + result = client._wait_final( + "rid", interval_seconds=0, timeout_seconds=60 + ) + assert getattr(result, "status", None) == "READY" + assert rt_client.get.call_count == 2 + + +# ─── create: 非 final 状态触发 _wait_final ──────────────────── + + +async def test_create_async_non_final_status_triggers_wait(): + creating = _fake_rt(name="alpha") + creating.status = "CREATING" + ready = _fake_rt(name="alpha") # READY + rt_client = MagicMock() + rt_client.get_async = AsyncMock(return_value=ready) + ctrl = MagicMock() + ctrl.create_agent_runtime_async = AsyncMock( + return_value=_rt_to_dara_result(creating) + ) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.create_async(name="alpha") + assert agent.name == "alpha" + rt_client.get_async.assert_awaited() # _wait_final_async 确实调了 get + + +# ─── sync list / list_all 未覆盖分支 ───────────────────────── + + +def test_list_sync_filters_non_super_agent(): + items = [ + _fake_rt(name="a"), + _fake_rt(name="b", protocol_type="HTTP"), + _fake_rt(name="c"), + ] + rt_client = MagicMock() + rt_client.list = MagicMock(return_value=items) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = client.list() + assert [a.name for a in result] == ["a", "c"] + + +def test_list_all_sync_multi_page(): + page1 = [_fake_rt(name=f"a{i}") for i in range(50)] + page2 = [_fake_rt(name=f"a{i}") for i in range(50, 85)] + pages = [page1, page2] + rt_client = MagicMock() + rt_client.list = MagicMock(side_effect=pages) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = client.list_all() + assert len(result) == 85 + + +def test_list_all_async_empty_first_page_breaks(): + """list_async 首页直接空, list_all 立刻 break.""" + + async def _list(*args, **kwargs): + return [] + + rt_client = MagicMock() + rt_client.list_async = _list + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = asyncio.run(client.list_all_async()) + assert result == [] + + +def test_list_all_sync_empty_first_page_breaks(): + rt_client = MagicMock() + rt_client.list = MagicMock(return_value=[]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = client.list_all() + assert result == [] + + +def test_no_agent_runtime_in_public_signatures(): + """No public SuperAgentClient method exposes AgentRuntime-related types.""" + public_methods = [m for m in dir(SuperAgentClient) if not m.startswith("_")] + for name in public_methods: + attr = getattr(SuperAgentClient, name) + if not callable(attr): + continue + sig = inspect.signature(attr) + all_annotations = [p.annotation for p in sig.parameters.values()] + all_annotations.append(sig.return_annotation) + rendered = " ".join(str(a) for a in all_annotations) + assert ( + "AgentRuntime" not in rendered + ), f"{name} exposes AgentRuntime in its signature: {rendered}" diff --git a/tests/unittests/super_agent/test_control.py b/tests/unittests/super_agent/test_control.py new file mode 100644 index 0000000..9f83427 --- /dev/null +++ b/tests/unittests/super_agent/test_control.py @@ -0,0 +1,385 @@ +"""Unit tests for ``agentrun.super_agent.api.control``.""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from agentrun.super_agent.api.control import ( + _add_ram_prefix_to_host, + API_VERSION, + build_super_agent_endpoint, + EXTERNAL_TAG, + from_agent_runtime, + is_super_agent, + parse_super_agent_config, + SUPER_AGENT_PROTOCOL_TYPE, + SUPER_AGENT_RESOURCE_PATH, + SUPER_AGENT_TAG, + to_create_input, + to_update_input, +) +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.utils.config import Config + +# ─── build_super_agent_endpoint ──────────────────────────────── + + +def test_build_super_agent_endpoint_production(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + ep = build_super_agent_endpoint(cfg) + assert ( + ep + == "https://123-ram.agentrun-data.cn-hangzhou.aliyuncs.com/super-agents/__SUPER_AGENT__" + ) + + +def test_build_super_agent_endpoint_pre_environment(): + cfg = Config( + data_endpoint=( + "http://1431999136518149.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + ) + ep = build_super_agent_endpoint(cfg) + assert ( + ep + == "http://1431999136518149-ram.funagent-data-pre.cn-hangzhou.aliyuncs.com/super-agents/__SUPER_AGENT__" + ) + + +def test_build_super_agent_endpoint_custom_gateway(): + cfg = Config(data_endpoint="https://my-gateway.example.com") + ep = build_super_agent_endpoint(cfg) + assert ep == "https://my-gateway.example.com/super-agents/__SUPER_AGENT__" + + +def test_build_super_agent_endpoint_unknown_first_segment(): + # `agentrun-data` in the first segment → no `-ram` rewrite + cfg = Config(data_endpoint="https://agentrun-data.example.com") + ep = build_super_agent_endpoint(cfg) + assert ( + ep == "https://agentrun-data.example.com/super-agents/__SUPER_AGENT__" + ) + + +# ─── _add_ram_prefix_to_host ────────────────────────────────── + + +def test_add_ram_prefix_to_host_no_netloc(): + assert _add_ram_prefix_to_host("") == "" + assert _add_ram_prefix_to_host("/path/only") == "/path/only" + + +def test_add_ram_prefix_to_host_single_segment_host(): + # Host has a single segment → no rewrite + assert ( + _add_ram_prefix_to_host("https://localhost:8080") + == "https://localhost:8080" + ) + + +def test_add_ram_prefix_to_host_unknown_domain(): + assert ( + _add_ram_prefix_to_host("https://foo.example.com") + == "https://foo.example.com" + ) + + +# ─── SuperAgentDataAPI URL 含版本号 ────────────────────────── +def test_build_data_url_via_with_path_includes_version(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + api = SuperAgentDataAPI("demo", config=cfg) + url = api.with_path("invoke") + assert url.endswith( + f"/{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}/invoke" + ) + + +# ─── to_create_input ────────────────────────────────────────── + + +def test_to_create_input_minimal(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input("alpha", cfg=cfg) + assert inp.agent_runtime_name == "alpha" + assert inp.tags == [EXTERNAL_TAG, SUPER_AGENT_TAG] + pc = inp.protocol_configuration + assert pc.type == SUPER_AGENT_PROTOCOL_TYPE + assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") + settings = pc.protocol_settings + assert len(settings) == 1 + cfg_dict = json.loads(settings[0]["config"]) + assert cfg_dict["path"] == "/invoke" + assert cfg_dict["agents"] == [] + assert cfg_dict["metadata"] == {"agentRuntimeName": "alpha"} + + +def test_to_create_input_full(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input( + "bravo", + prompt="hello", + agents=["a1"], + tools=["t1", "t2"], + skills=["s1"], + sandboxes=["sb1"], + workspaces=["ws1"], + model_service_name="foo", + model_name="bar", + cfg=cfg, + ) + pc_dict = inp.model_dump()["protocolConfiguration"] + settings_cfg = json.loads(pc_dict["protocolSettings"][0]["config"]) + assert settings_cfg["prompt"] == "hello" + assert settings_cfg["agents"] == ["a1"] + assert settings_cfg["tools"] == ["t1", "t2"] + assert settings_cfg["skills"] == ["s1"] + assert settings_cfg["sandboxes"] == ["sb1"] + assert settings_cfg["workspaces"] == ["ws1"] + assert settings_cfg["modelServiceName"] == "foo" + assert settings_cfg["modelName"] == "bar" + + +def test_to_create_input_tags_fixed(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input("c", cfg=cfg) + assert inp.tags == [EXTERNAL_TAG, SUPER_AGENT_TAG] + + +def test_to_create_input_metadata_only_agent_runtime_name(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input("d", cfg=cfg) + settings_cfg = json.loads( + inp.protocol_configuration.protocol_settings[0]["config"] + ) + assert settings_cfg["metadata"] == {"agentRuntimeName": "d"} + + +def test_to_create_input_uses_pre_environment_endpoint(): + cfg = Config( + data_endpoint=( + "http://1431999136518149.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + ) + inp = to_create_input("pre-agent", cfg=cfg) + ep = inp.protocol_configuration.external_endpoint + assert "funagent-data-pre" in ep + assert "-ram" in ep + + +# ─── is_super_agent / parse_super_agent_config ───────────────── + + +def _make_rt(**kwargs): + """Minimal fake AgentRuntime-like object.""" + defaults = { + "agent_runtime_name": "n", + "agent_runtime_id": "rid", + "agent_runtime_arn": "arn", + "status": "READY", + "created_at": "2026-01-01", + "last_updated_at": "2026-01-02", + "description": None, + "protocol_configuration": None, + } + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +def test_from_agent_runtime(): + config_json = json.dumps({ + "prompt": "hi", + "agents": ["a"], + "tools": ["t"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": "svc", + "modelName": "mod", + "metadata": {"agentRuntimeName": "foo"}, + }) + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": config_json, + "name": "foo", + "path": "/invoke", + }], + "externalEndpoint": "https://x.com/super-agents/__SUPER_AGENT__", + } + rt = _make_rt(agent_runtime_name="foo", protocol_configuration=pc) + agent = from_agent_runtime(rt) + assert agent.name == "foo" + assert agent.prompt == "hi" + assert agent.agents == ["a"] + assert agent.tools == ["t"] + assert agent.model_service_name == "svc" + assert agent.model_name == "mod" + assert ( + agent.external_endpoint == "https://x.com/super-agents/__SUPER_AGENT__" + ) + + +def test_is_super_agent_true(): + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{"type": SUPER_AGENT_PROTOCOL_TYPE}], + } + assert is_super_agent(_make_rt(protocol_configuration=pc)) + + +def test_is_super_agent_false(): + for type_name in ("HTTP", "MCP", "OTHER"): + pc = {"type": type_name, "protocolSettings": [{"type": type_name}]} + assert not is_super_agent(_make_rt(protocol_configuration=pc)) + # No protocol_configuration + assert not is_super_agent(_make_rt(protocol_configuration=None)) + + +def test_parse_super_agent_config_invalid_json_returns_empty(): + pc = { + "protocolSettings": [ + {"type": SUPER_AGENT_PROTOCOL_TYPE, "config": "not-json"} + ] + } + assert parse_super_agent_config(_make_rt(protocol_configuration=pc)) == {} + + +def test_parse_super_agent_config_missing_config_returns_empty(): + pc = {"protocolSettings": [{"type": SUPER_AGENT_PROTOCOL_TYPE}]} + assert parse_super_agent_config(_make_rt(protocol_configuration=pc)) == {} + + +# ─── to_update_input ────────────────────────────────────────── + + +def test_to_update_input_full_protocol_replace(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_update_input( + "alpha", + { + "description": "new", + "prompt": "p", + "agents": [], + "tools": ["t"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "model_service_name": None, + "model_name": None, + }, + cfg=cfg, + ) + assert inp.description == "new" + settings = inp.protocol_configuration.protocol_settings + assert len(settings) == 1 + assert ( + json.loads(settings[0]["config"])["metadata"]["agentRuntimeName"] + == "alpha" + ) + + +# ─── Dara ListAgentRuntimesRequest tags 补丁 ────────────────── +# 同时确保 import agentrun.super_agent.api.control 已应用补丁 (已在文件顶部导入)。 + + +def test_list_request_from_map_preserves_tags(): + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + + req = ListAgentRuntimesRequest().from_map({ + "tags": SUPER_AGENT_TAG, + "pageNumber": 1, + "pageSize": 20, + }) + assert getattr(req, "tags", None) == SUPER_AGENT_TAG + + +def test_list_request_to_map_preserves_tags(): + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + + req = ListAgentRuntimesRequest() + req.tags = SUPER_AGENT_TAG + assert req.to_map().get("tags") == SUPER_AGENT_TAG + + +def _invoke_list_patch(tags_value): + """调用打过补丁的 ``list_agent_runtimes_with_options``, 捕获 call_api 的 query.""" + from alibabacloud_agentrun20250910.client import Client as _DaraClient + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + from darabonba.runtime import RuntimeOptions + + captured = {} + + def _fake_call_api(self, params, req, rt): + captured["query"] = dict(req.query) if req.query else {} + raise RuntimeError("_stop_after_query_capture_") + + client = _DaraClient.__new__(_DaraClient) + client._endpoint = "x" + # 绑定实例级 call_api (优先于类方法) + client.call_api = _fake_call_api.__get__(client, _DaraClient) + + req = ListAgentRuntimesRequest(page_number=1, page_size=20) + req.tags = tags_value + with pytest.raises(RuntimeError, match="_stop_after_query_capture_"): + client.list_agent_runtimes_with_options(req, {}, RuntimeOptions()) + return captured["query"] + + +def test_list_with_options_injects_tags_str(): + query = _invoke_list_patch(SUPER_AGENT_TAG) + assert query.get("tags") == SUPER_AGENT_TAG + assert query.get("pageNumber") == "1" + + +def test_list_with_options_injects_tags_list_comma_join(): + query = _invoke_list_patch([EXTERNAL_TAG, SUPER_AGENT_TAG]) + assert query.get("tags") == f"{EXTERNAL_TAG},{SUPER_AGENT_TAG}" + + +def test_list_with_options_no_tags_no_injection(): + from alibabacloud_agentrun20250910.client import Client as _DaraClient + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + from darabonba.runtime import RuntimeOptions + + captured = {} + + def _fake_call_api(self, params, req, rt): + captured["query"] = dict(req.query) if req.query else {} + raise RuntimeError("_stop_") + + client = _DaraClient.__new__(_DaraClient) + client._endpoint = "x" + client.call_api = _fake_call_api.__get__(client, _DaraClient) + + req = ListAgentRuntimesRequest(page_number=1, page_size=20) + with pytest.raises(RuntimeError, match="_stop_"): + client.list_agent_runtimes_with_options(req, {}, RuntimeOptions()) + assert "tags" not in captured["query"] + + +@pytest.mark.asyncio +async def test_list_with_options_async_injects_tags(): + from alibabacloud_agentrun20250910.client import Client as _DaraClient + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + from darabonba.runtime import RuntimeOptions + + captured = {} + + async def _fake_call_api_async(self, params, req, rt): + captured["query"] = dict(req.query) if req.query else {} + raise RuntimeError("_stop_") + + client = _DaraClient.__new__(_DaraClient) + client._endpoint = "x" + client.call_api_async = _fake_call_api_async.__get__(client, _DaraClient) + + req = ListAgentRuntimesRequest(page_number=1, page_size=20) + req.tags = SUPER_AGENT_TAG + with pytest.raises(RuntimeError, match="_stop_"): + await client.list_agent_runtimes_with_options_async( + req, {}, RuntimeOptions() + ) + assert captured["query"].get("tags") == SUPER_AGENT_TAG diff --git a/tests/unittests/super_agent/test_data_api.py b/tests/unittests/super_agent/test_data_api.py new file mode 100644 index 0000000..406a818 --- /dev/null +++ b/tests/unittests/super_agent/test_data_api.py @@ -0,0 +1,488 @@ +"""Unit tests for ``agentrun.super_agent.api.data.SuperAgentDataAPI``.""" + +import json +import re + +import httpx +import pytest +import respx + +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.utils.config import Config + + +def _auth_cfg(**overrides) -> Config: + """Config with RAM AK/SK so ``DataAPI.auth`` actually signs.""" + base = dict( + access_key_id="AK", + access_key_secret="SK", + account_id="123", + region_id="cn-hangzhou", + ) + base.update(overrides) + return Config(**base) + + +# ─── URL construction (production / pre / custom gateway) ──── + + +@respx.mock +async def test_invoke_async_phase1_url_includes_version_production(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("agent-prod", config=cfg) + route = respx.post( + re.compile( + r"https://123-ram\.agentrun-data\.cn-hangzhou\.aliyuncs\.com" + r"/2025-09-10/super-agents/__SUPER_AGENT__/invoke" + ) + ).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c1", + "url": "https://x/stream", + "headers": {}, + } + }, + ) + ) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert route.called + + +@respx.mock +async def test_invoke_async_phase1_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint=( + "http://1431999136518149.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + ) + api = SuperAgentDataAPI("agent-pre", config=cfg) + route = respx.post( + re.compile( + r"http://1431999136518149-ram\.funagent-data-pre\.cn-hangzhou\.aliyuncs\.com" + r"/2025-09-10/super-agents/__SUPER_AGENT__/invoke" + ) + ).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://stream", + "headers": {}, + } + }, + ) + ) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert route.called + + +@respx.mock +async def test_invoke_async_phase1_url_custom_gateway_no_ram(): + cfg = _auth_cfg(data_endpoint="https://my-gateway.example.com") + api = SuperAgentDataAPI("agent-cust", config=cfg) + route = respx.post( + "https://my-gateway.example.com/2025-09-10/super-agents/__SUPER_AGENT__/invoke" + ).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + ) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert route.called + + +@respx.mock +async def test_get_conversation_async_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint="http://111.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + api = SuperAgentDataAPI("n", config=cfg) + route = respx.get( + "http://111-ram.funagent-data-pre.cn-hangzhou.aliyuncs.com" + "/2025-09-10/super-agents/__SUPER_AGENT__/conversations/cid" + ).mock(return_value=httpx.Response(200, json={"data": {}})) + await api.get_conversation_async("cid") + assert route.called + + +@respx.mock +async def test_delete_conversation_async_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint="http://111.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + api = SuperAgentDataAPI("n", config=cfg) + route = respx.delete( + "http://111-ram.funagent-data-pre.cn-hangzhou.aliyuncs.com" + "/2025-09-10/super-agents/__SUPER_AGENT__/conversations/cid" + ).mock(return_value=httpx.Response(200)) + await api.delete_conversation_async("cid") + assert route.called + + +# ─── body shape ─────────────────────────────────────────────── + + +@respx.mock +async def test_invoke_async_body_new_conversation(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert captured["body"]["messages"] == [{"role": "user", "content": "hi"}] + assert captured["body"]["forwardedProps"]["metadata"] == { + "agentRuntimeName": "demo" + } + assert "conversationId" not in captured["body"]["forwardedProps"] + + +@respx.mock +async def test_invoke_async_body_continue_conversation(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "abc", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async( + [{"role": "user", "content": "hi"}], conversation_id="abc" + ) + assert captured["body"]["forwardedProps"]["conversationId"] == "abc" + + +@respx.mock +async def test_invoke_async_body_forwarded_extras_passthrough(): + """``forwarded_extras`` 业务字段 MUST 出现在 forwardedProps 顶层.""" + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async( + [{"role": "user", "content": "hi"}], + forwarded_extras={ + "prompt": "p", + "agents": ["a1"], + "modelServiceName": "svc", + "modelName": "mod", + }, + ) + fp = captured["body"]["forwardedProps"] + assert fp["prompt"] == "p" + assert fp["agents"] == ["a1"] + assert fp["modelServiceName"] == "svc" + assert fp["modelName"] == "mod" + # SDK 托管字段不受 extras 影响 + assert fp["metadata"] == {"agentRuntimeName": "demo"} + + +@respx.mock +async def test_invoke_async_body_extras_cannot_override_sdk_fields(): + """extras 里带 metadata/conversationId 也不能覆盖 SDK 托管字段.""" + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "real", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async( + [{"role": "user", "content": "hi"}], + conversation_id="real", + forwarded_extras={ + "metadata": {"agentRuntimeName": "SPOOFED"}, + "conversationId": "SPOOFED", + "prompt": "p", + }, + ) + fp = captured["body"]["forwardedProps"] + assert fp["metadata"] == {"agentRuntimeName": "demo"} + assert fp["conversationId"] == "real" + assert fp["prompt"] == "p" + + +# ─── signing ────────────────────────────────────────────────── + + +@respx.mock +async def test_invoke_async_request_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async([{"role": "user"}]) + h = captured["headers"] + assert any(k.lower() == "agentrun-authorization" for k in h) + assert any(k.lower() == "x-acs-date" for k in h) + assert any(k.lower() == "x-acs-content-sha256" for k in h) + ct = next((v for k, v in h.items() if k.lower() == "content-type"), None) + assert ct == "application/json" + + +# ─── returns InvokeResponseData ────────────────────────────── + + +@respx.mock +async def test_invoke_async_returns_invoke_response_data(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.post(re.compile(r".*/invoke$")).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://stream", + "headers": {"X-Super-Agent-Session-Id": "s"}, + } + }, + ) + ) + resp = await api.invoke_async([]) + assert resp.conversation_id == "c" + assert resp.stream_url == "https://stream" + assert resp.stream_headers == {"X-Super-Agent-Session-Id": "s"} + + +@respx.mock +async def test_invoke_async_missing_conversation_id_raises(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.post(re.compile(r".*/invoke$")).mock( + return_value=httpx.Response( + 200, json={"data": {"url": "https://s", "headers": {}}} + ) + ) + with pytest.raises(ValueError) as exc: + await api.invoke_async([]) + assert "missing" in str(exc.value) + assert "conversationId" in str(exc.value) + + +@respx.mock +async def test_invoke_async_5xx_raises_http_error(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.post(re.compile(r".*/invoke$")).mock( + return_value=httpx.Response(500, text="boom") + ) + with pytest.raises(httpx.HTTPStatusError): + await api.invoke_async([]) + + +@respx.mock +async def test_invoke_async_user_headers_merged(): + cfg = _auth_cfg(headers={"X-Custom": "v"}) + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async([]) + assert captured["headers"].get("x-custom") == "v" + assert any( + k.lower() == "agentrun-authorization" for k in captured["headers"] + ) + + +# ─── stream_async ──────────────────────────────────────────── + + +@respx.mock +async def test_stream_async_yields_sse_events(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + sse_body = b"event: m\ndata: hello\n\nevent: m\ndata: world\n\n" + respx.get("https://stream.example.com/flow").mock( + return_value=httpx.Response(200, content=sse_body) + ) + events = [] + async for ev in api.stream_async("https://stream.example.com/flow"): + events.append(ev) + assert [(e.event, e.data) for e in events] == [ + ("m", "hello"), + ("m", "world"), + ] + + +@respx.mock +async def test_stream_async_request_includes_phase1_headers_and_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response(200, content=b":keep\n\n") + + respx.get("https://stream.example.com/go").mock(side_effect=_responder) + it = api.stream_async( + "https://stream.example.com/go", + stream_headers={"X-Super-Agent-Session-Id": "sess1"}, + ) + async for _ in it: + pass + h = captured["headers"] + assert h.get("x-super-agent-session-id") == "sess1" + assert any(k.lower() == "agentrun-authorization" for k in h) + + +# ─── get/delete conversation ───────────────────────────────── + + +@respx.mock +async def test_get_conversation_async_url_and_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"data": {"conversationId": "c1"}}) + + respx.get( + re.compile( + r".*/2025-09-10/super-agents/__SUPER_AGENT__/conversations/c1$" + ) + ).mock(side_effect=_responder) + result = await api.get_conversation_async("c1") + assert result == {"conversationId": "c1"} + assert any( + k.lower() == "agentrun-authorization" for k in captured["headers"] + ) + + +@respx.mock +async def test_get_conversation_async_returns_empty_on_missing_data(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations/.*")).mock( + return_value=httpx.Response(200, json={"code": "ok"}) + ) + assert await api.get_conversation_async("c1") == {} + + +@respx.mock +async def test_delete_conversation_async_returns_none(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + route = respx.delete(re.compile(r".*/conversations/c1$")).mock( + return_value=httpx.Response(200) + ) + result = await api.delete_conversation_async("c1") + assert result is None + assert route.called + + +@respx.mock +async def test_delete_conversation_async_404_raises(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.delete(re.compile(r".*/conversations/missing$")).mock( + return_value=httpx.Response(404) + ) + with pytest.raises(httpx.HTTPStatusError): + await api.delete_conversation_async("missing") + + +# ─── sync stubs are NotImplementedError ────────────────────── + + +def test_sync_methods_not_implemented(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + for fn in ( + lambda: api.invoke([]), + lambda: api.stream("url"), + lambda: api.get_conversation("c"), + lambda: api.delete_conversation("c"), + ): + with pytest.raises(NotImplementedError): + fn() diff --git a/tests/unittests/super_agent/test_no_coupling.py b/tests/unittests/super_agent/test_no_coupling.py new file mode 100644 index 0000000..02fa4c8 --- /dev/null +++ b/tests/unittests/super_agent/test_no_coupling.py @@ -0,0 +1,59 @@ +"""Verify the super_agent module does not depend on conversation_service.""" + +import ast +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +import agentrun.super_agent as super_agent_pkg + + +def _python_files_under(path: Path): + return [p for p in path.rglob("*.py") if "__pycache__" not in p.parts] + + +def test_no_import_conversation_service(): + base = Path(super_agent_pkg.__file__).parent + offenders = [] + for file in _python_files_under(base): + tree = ast.parse(file.read_text()) + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + if node.module.startswith("agentrun.conversation_service"): + offenders.append(str(file.relative_to(base))) + elif isinstance(node, ast.Import): + for alias in node.names: + if alias.name.startswith("agentrun.conversation_service"): + offenders.append(str(file.relative_to(base))) + assert not offenders, ( + "super_agent must not import conversation_service, found in:" + f" {offenders}" + ) + + +async def test_get_conversation_does_not_touch_session_store(): + """Calling agent.get_conversation_async must not touch SessionStore.""" + from agentrun.super_agent.agent import SuperAgent + + session_store_mock = MagicMock() + # Patch in case code accidentally imported it + with patch( + "agentrun.conversation_service.SessionStore", + session_store_mock, + create=True, + ): + # Stub the data-plane call so we don't hit the network + with patch( + "agentrun.super_agent.agent.SuperAgentDataAPI" + ) as data_api_factory: + + async def _noop(*args, **kwargs): + return {} + + data_api_factory.return_value.get_conversation_async = _noop + agent = SuperAgent(name="demo") + await agent.get_conversation_async("c1") + + # SessionStore should never have been called/instantiated + session_store_mock.assert_not_called() diff --git a/tests/unittests/super_agent/test_stream.py b/tests/unittests/super_agent/test_stream.py new file mode 100644 index 0000000..54c9590 --- /dev/null +++ b/tests/unittests/super_agent/test_stream.py @@ -0,0 +1,203 @@ +"""Unit tests for ``agentrun.super_agent.stream``.""" + +from typing import List +from unittest.mock import MagicMock + +import pytest + +from agentrun.super_agent.stream import InvokeStream, parse_sse_async, SSEEvent + + +class _FakeResponse: + """Replays a pre-canned list of SSE lines via ``aiter_lines``.""" + + def __init__(self, lines: List[str]): + self._lines = lines + + async def aiter_lines(self): + for line in self._lines: + yield line + + +async def _collect(lines: List[str]) -> List[SSEEvent]: + return [ev async for ev in parse_sse_async(_FakeResponse(lines))] + + +# ─── parse_sse_async ───────────────────────────────────────── + + +async def test_parse_sse_simple_event(): + events = await _collect(["event: m", "data: hi", "id: 1", ""]) + assert len(events) == 1 + ev = events[0] + assert ev.event == "m" + assert ev.data == "hi" + assert ev.id == "1" + + +async def test_parse_sse_multiline_data(): + events = await _collect(["data: a", "data: b", ""]) + assert len(events) == 1 + assert events[0].data == "a\nb" + + +async def test_parse_sse_comment_ignored(): + events = await _collect([": comment", "data: x", ""]) + assert len(events) == 1 + assert events[0].data == "x" + + +async def test_parse_sse_unknown_field_ignored(): + events = await _collect(["unknown: v", "data: x", ""]) + assert len(events) == 1 + assert events[0].data == "x" + + +async def test_parse_sse_retry_invalid_ignored(): + events = await _collect(["retry: not-a-number", "data: x", ""]) + assert len(events) == 1 + assert events[0].retry is None + + +async def test_parse_sse_retry_valid(): + events = await _collect(["retry: 5000", "data: x", ""]) + assert events[0].retry == 5000 + + +async def test_parse_sse_strip_leading_space_after_colon(): + events = await _collect(["data: hello", ""]) + assert events[0].data == "hello" + + +async def test_parse_sse_field_without_colon(): + events = await _collect(["data", ""]) + assert len(events) == 1 + assert events[0].data == "" + + +async def test_parse_sse_flush_at_stream_end(): + events = await _collect(["data: final"]) + assert len(events) == 1 + assert events[0].data == "final" + + +async def test_parse_sse_multiple_events(): + events = await _collect([ + "event: a", + "data: 1", + "", + "event: b", + "data: 2", + "", + ]) + assert len(events) == 2 + assert events[0].event == "a" + assert events[1].event == "b" + + +async def test_parse_sse_empty_line_without_content_skipped(): + # Two consecutive empty lines → no duplicate events + events = await _collect(["", "data: x", "", ""]) + assert len(events) == 1 + + +# ─── SSEEvent.data_json ────────────────────────────────────── + + +def test_sse_event_data_json_success(): + ev = SSEEvent(event="x", data='{"k":1}') + assert ev.data_json() == {"k": 1} + + +def test_sse_event_data_json_failure(): + assert SSEEvent(event="x", data="not json").data_json() is None + + +def test_sse_event_data_json_empty(): + assert SSEEvent(event="x", data="").data_json() is None + + +# ─── InvokeStream ──────────────────────────────────────────── + + +async def _make_stream(events: List[SSEEvent]) -> InvokeStream: + async def _gen(): + for ev in events: + yield ev + + async def _factory(): + return _gen() + + return InvokeStream( + conversation_id="c1", + session_id="s1", + stream_url="https://x.com/stream", + stream_headers={"X-Super-Agent-Session-Id": "s1"}, + _stream_factory=_factory, + ) + + +async def test_invoke_stream_async_iter(): + events = [ + SSEEvent(event="m", data="1"), + SSEEvent(event="m", data="2"), + SSEEvent(event="m", data="3"), + ] + stream = await _make_stream(events) + collected = [ev async for ev in stream] + assert len(collected) == 3 + assert [ev.data for ev in collected] == ["1", "2", "3"] + + +async def test_invoke_stream_aclose(): + closed = {"v": False} + + async def _gen(): + try: + yield SSEEvent(event="m", data="x") + finally: + closed["v"] = True + + async def _factory(): + return _gen() + + stream = InvokeStream( + conversation_id="c", + session_id="s", + stream_url="u", + stream_headers={}, + _stream_factory=_factory, + ) + # Advance one step to open the iterator + it = stream.__aiter__() + await it.__anext__() + await stream.aclose() + assert closed["v"] is True + # After close, iteration is terminated + with pytest.raises(StopAsyncIteration): + await it.__anext__() + + +async def test_invoke_stream_async_with(): + closed = {"v": False} + + async def _gen(): + try: + yield SSEEvent(event="m", data="x") + finally: + closed["v"] = True + + async def _factory(): + return _gen() + + stream = InvokeStream( + conversation_id="c", + session_id="s", + stream_url="u", + stream_headers={}, + _stream_factory=_factory, + ) + async with stream as s: + async for _ev in s: + break + assert closed["v"] is True diff --git a/tests/unittests/toolset/api/test_openapi.py b/tests/unittests/toolset/api/test_openapi.py index 0f2e82d..bb32eac 100644 --- a/tests/unittests/toolset/api/test_openapi.py +++ b/tests/unittests/toolset/api/test_openapi.py @@ -548,9 +548,7 @@ def test_post_with_ref_schema(self): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/CreateOrderRequest" - ) + "$ref": "#/components/schemas/CreateOrderRequest" } } }, @@ -761,9 +759,7 @@ def test_invalid_ref_gracefully_handled(self): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/NonExistent" - ) + "$ref": "#/components/schemas/NonExistent" } } } @@ -796,9 +792,7 @@ def test_external_ref_not_resolved(self): "content": { "application/json": { "schema": { - "$ref": ( - "https://example.com/schemas/external.json" - ) + "$ref": "https://example.com/schemas/external.json" } } } @@ -918,9 +912,7 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/CreateOrderRequest" - ) + "$ref": "#/components/schemas/CreateOrderRequest" } } }, @@ -956,9 +948,7 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/UpdateOrderStatusRequest" - ) + "$ref": "#/components/schemas/UpdateOrderStatusRequest" } } }, @@ -1229,9 +1219,7 @@ def test_tool_schema(self): "openapi": "3.0.1", "info": {"title": "Test", "version": "1.0"}, "servers": [{ - "url": ( - "https://1431999136518149.agentrun-data.cn-hangzhou.aliyuncs.com/tools/test/" - ) + "url": "https://1431999136518149.agentrun-data.cn-hangzhou.aliyuncs.com/tools/test/" }], "paths": { "/invoke": { From cbc8fe96cc30d7ecfb65e66a6247669765fe0fb6 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Thu, 16 Apr 2026 20:34:16 +0800 Subject: [PATCH 2/7] refactor(super_agent): Apply Dara SDK patches lazily and improve type hints This change refactors the application of monkey patches to the Dara SDK models and clients, applying them only when needed rather than at module load time. It also improves type hint accuracy for several parameters in both synchronous and asynchronous client methods. The key changes include: - Introducing `ensure_super_agent_patches_applied()` function that applies all necessary patches lazily - Moving placeholder image definition into its own constant - Updating docstrings for clarity - Improving parameter types across multiple client method signatures Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- .../super_agent/__client_async_template.py | 43 +++++---- agentrun/super_agent/agui.py | 4 +- agentrun/super_agent/api/control.py | 87 +++++++++++-------- agentrun/super_agent/client.py | 43 +++++---- tests/unittests/super_agent/test_control.py | 7 +- 5 files changed, 110 insertions(+), 74 deletions(-) diff --git a/agentrun/super_agent/__client_async_template.py b/agentrun/super_agent/__client_async_template.py index 0a9f1ba..533a7c0 100644 --- a/agentrun/super_agent/__client_async_template.py +++ b/agentrun/super_agent/__client_async_template.py @@ -22,6 +22,7 @@ from agentrun.agent_runtime.runtime import AgentRuntime from agentrun.super_agent.agent import SuperAgent from agentrun.super_agent.api.control import ( + ensure_super_agent_patches_applied, from_agent_runtime, is_super_agent, SUPER_AGENT_TAG, @@ -86,6 +87,9 @@ class SuperAgentClient: """Super Agent CRUDL 客户端.""" def __init__(self, config: Optional[Config] = None) -> None: + # 按需打 Dara SDK 兼容补丁 (幂等)。放在本构造函数里, 让 "仅 import + # agentrun.super_agent" 的调用方不被动承担全局 SDK 副作用。 + ensure_super_agent_patches_applied() self.config = config self._rt = AgentRuntimeClient(config=config) # create/update 绕过 AgentRuntimeClient 的 artifact_type 校验 (SUPER_AGENT 不需要 code/container), @@ -302,19 +306,22 @@ def get(self, name: str, *, config: Optional[Config] = None) -> SuperAgent: return agent # ─── Update (read-merge-write) ───────────────────── + # 参数默认值 ``_UNSET`` 是内部哨兵 (object())。为保留 IDE 自动补全与 mypy + # 类型检查, 签名保持精确类型标注, 对 ``= _UNSET`` 的赋值加 ``type: ignore``。 + # 未传 = 保持不变, 显式传 None = 清空字段。 async def update_async( self, name: str, *, - description: Any = _UNSET, - prompt: Any = _UNSET, - agents: Any = _UNSET, - tools: Any = _UNSET, - skills: Any = _UNSET, - sandboxes: Any = _UNSET, - workspaces: Any = _UNSET, - model_service_name: Any = _UNSET, - model_name: Any = _UNSET, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] config: Optional[Config] = None, ) -> SuperAgent: """异步更新超级 Agent (read-merge-write).""" @@ -356,15 +363,15 @@ def update( self, name: str, *, - description: Any = _UNSET, - prompt: Any = _UNSET, - agents: Any = _UNSET, - tools: Any = _UNSET, - skills: Any = _UNSET, - sandboxes: Any = _UNSET, - workspaces: Any = _UNSET, - model_service_name: Any = _UNSET, - model_name: Any = _UNSET, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] config: Optional[Config] = None, ) -> SuperAgent: """同步更新超级 Agent (read-merge-write).""" diff --git a/agentrun/super_agent/agui.py b/agentrun/super_agent/agui.py index ac615a8..b6904d1 100644 --- a/agentrun/super_agent/agui.py +++ b/agentrun/super_agent/agui.py @@ -23,7 +23,7 @@ from __future__ import annotations -from typing import AsyncIterator, Dict, Literal, Optional, Type +from typing import AsyncGenerator, Dict, Literal, Optional, Type from ag_ui.core import ( BaseEvent, @@ -113,7 +113,7 @@ async def as_agui_events( stream: InvokeStream, *, on_unknown: UnknownMode = "raise", -) -> AsyncIterator[BaseEvent]: +) -> AsyncGenerator[BaseEvent, None]: """把 :class:`InvokeStream` 中的原始 :class:`SSEEvent` 解码为强类型流. 无论正常消费结束、中途异常、解码异常, 都保证 ``await stream.aclose()`` 被调用 diff --git a/agentrun/super_agent/api/control.py b/agentrun/super_agent/api/control.py index 4c7eafa..469bbe9 100644 --- a/agentrun/super_agent/api/control.py +++ b/agentrun/super_agent/api/control.py @@ -14,9 +14,12 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TYPE_CHECKING from urllib.parse import urlparse, urlunparse +if TYPE_CHECKING: + from agentrun.super_agent.agent import SuperAgent + from alibabacloud_agentrun20250910.client import Client as _DaraClient from alibabacloud_agentrun20250910.models import ( CreateAgentRuntimeInput as _DaraCreateAgentRuntimeInput, @@ -60,6 +63,12 @@ _RAM_DATA_DOMAINS = ("agentrun-data", "funagent-data-pre") +# SUPER_AGENT 不跑用户 container/code, 但服务端强制要求 artifact/container_configuration 非空, +# 这里给一个占位镜像地址即可。region 取杭州仅为了格式合法, 服务端不会实际 pull。 +_PLACEHOLDER_IMAGE = ( + "registry.cn-hangzhou.aliyuncs.com/agentrun/super-agent-placeholder:v1" +) + # ─── URL 工具 ────────────────────────────────────────── @@ -143,14 +152,23 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]: return super().model_dump(**kwargs) -# ─── Dara 模型猴补丁 ────────────────────────────────────── -# Dara 的 ``ProtocolConfiguration`` 当前版本没有 ``externalEndpoint`` 字段; -# ``AgentRuntimeClient.create_async/update_async`` 内部做 -# ``CreateAgentRuntimeInput().from_map(pydantic.model_dump())`` 的 roundtrip, -# 会在 Dara 层丢失此字段。这里做一次加性 patch: 仅追加读写 ``externalEndpoint``, -# 不改变任何现有字段行为, 用模块级哨兵属性保证幂等。 +# ─── Dara 模型/客户端猴补丁 ────────────────────────────────────── +# 当前版 Dara SDK 缺 ``ProtocolConfiguration.externalEndpoint`` 和 +# ``CreateAgentRuntimeInput/UpdateAgentRuntimeInput/ListAgentRuntimesRequest.tags`` +# 字段, 会在 Pydantic ↔ Dara ``from_map / to_map`` roundtrip 中静默丢失; 且 +# ``Client.list_agent_runtimes_with_options{,_async}`` 不会把 ``tags`` 写到 query。 +# +# 所有补丁都延迟到 ``SuperAgentClient`` 实例化时 (见 +# ``ensure_super_agent_patches_applied``) 才触发, 避免仅 import 本模块的调用方 +# 被动承担全局副作用。补丁本身用哨兵属性保证幂等, 重复调用安全。 +# TODO: 等 Dara SDK 原生支持后删除。 + + +def _patch_dara_protocol_configuration() -> None: + """补齐 ``ProtocolConfiguration.externalEndpoint`` 的 from_map/to_map 读写.""" + if getattr(_DaraProtocolConfiguration, "_super_agent_patched", False): + return -if not getattr(_DaraProtocolConfiguration, "_super_agent_patched", False): _orig_to_map = _DaraProtocolConfiguration.to_map _orig_from_map = _DaraProtocolConfiguration.from_map @@ -174,10 +192,8 @@ def _patched_from_map( _DaraProtocolConfiguration._super_agent_patched = True # type: ignore[attr-defined] -# Dara 的 ``CreateAgentRuntimeInput`` / ``UpdateAgentRuntimeInput`` 当前版本没有 -# ``tags`` 字段, 与 ``ProtocolConfiguration`` 同理会在 Pydantic → Dara 的 roundtrip -# 中被静默丢弃. 这里沿用同款加性 patch, 只补齐 ``tags`` 字段的读写. def _patch_dara_tags(cls: Any) -> None: + """给 Dara model 补齐 ``tags`` 字段的 from_map/to_map 读写.""" if getattr(cls, "_super_agent_tags_patched", False): return _orig_to_map = cls.to_map @@ -201,22 +217,6 @@ def _patched_from_map(self: Any, m: Optional[Dict[str, Any]] = None) -> Any: cls._super_agent_tags_patched = True # type: ignore[attr-defined] -_patch_dara_tags(_DaraCreateAgentRuntimeInput) -_patch_dara_tags(_DaraUpdateAgentRuntimeInput) -# ``ListAgentRuntimesRequest`` 同样没有 ``tags`` 字段: 补上 from_map/to_map 以保留 -# 属性; 真正让服务端生效的查询参数注入由下面的 client 级补丁完成。 -_patch_dara_tags(_DaraListAgentRuntimesRequest) - - -# ─── Dara 客户端猴补丁: list 请求 query 注入 tags ─────────────── -# 现版 Dara ``Client.list_agent_runtimes_with_options`` 不读 ``request.tags`` -# 构造 query, 导致即便 Pydantic 侧把 tags 传下来, 服务端也收不到。这里一次性 -# 包裹同步 / 异步两个方法: 若 request 带有 ``tags`` 就在底层 ``call_api`` 调用 -# 前把 ``tags`` (列表 → 逗号分隔) 追加到 ``req.query``。 -# 每个 API 调用都会 ``_get_client()`` 新建 ``Client`` 实例, 实例属性级别的替换 -# 在并发下是安全的。 - - def _tags_query_value(tags: Any) -> Optional[str]: if tags is None: return None @@ -228,6 +228,13 @@ def _tags_query_value(tags: Any) -> Optional[str]: def _patch_dara_client_list_tags() -> None: + """包裹 ``Client.list_agent_runtimes_with_options{,_async}``: 若 request 带 + ``tags`` 就在底层 ``call_api`` 调用前把 ``tags`` (列表 → 逗号分隔) 追加到 + ``req.query``。 + + 每次 API 调用由 ``_get_client()`` 新建 ``Client`` 实例, 实例属性级别的 + ``self.call_api = _injecting`` 替换在并发下是安全的。 + """ if getattr(_DaraClient, "_super_agent_list_tags_patched", False): return @@ -285,7 +292,20 @@ async def _injecting(params: Any, req: Any, rt: Any) -> Any: _DaraClient._super_agent_list_tags_patched = True # type: ignore[attr-defined] -_patch_dara_client_list_tags() +def ensure_super_agent_patches_applied() -> None: + """按需应用全部 Dara SDK 兼容补丁 (幂等)。 + + 由 ``SuperAgentClient.__init__`` 调用。如果调用方直接使用 + ``to_create_input`` / ``to_update_input`` 并自己构造 ``CreateAgentRuntimeInput`` + / ``ListAgentRuntimesRequest``, 也应在 Pydantic → Dara 转换前调用一次本函数。 + """ + _patch_dara_protocol_configuration() + _patch_dara_tags(_DaraCreateAgentRuntimeInput) + _patch_dara_tags(_DaraUpdateAgentRuntimeInput) + # ``ListAgentRuntimesRequest`` 补齐 from_map/to_map 保留属性; 真正让服务端 + # 生效的 query 注入由 ``_patch_dara_client_list_tags`` 完成。 + _patch_dara_tags(_DaraListAgentRuntimesRequest) + _patch_dara_client_list_tags() # ─── AgentRuntime ↔ SuperAgent 转换 ──────────────────────── @@ -392,9 +412,7 @@ def to_create_input( external_agent_endpoint_url=build_super_agent_endpoint(cfg), # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 artifact_type=AgentRuntimeArtifact.CONTAINER, - container_configuration=AgentRuntimeContainer( - image="registry.cn-hangzhou.aliyuncs.com/agentrun/super-agent-placeholder:v1" - ), + container_configuration=AgentRuntimeContainer(image=_PLACEHOLDER_IMAGE), network_configuration=NetworkConfig(network_mode=NetworkMode.PUBLIC), ) @@ -425,9 +443,7 @@ def to_update_input( external_agent_endpoint_url=build_super_agent_endpoint(cfg), # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 artifact_type=AgentRuntimeArtifact.CONTAINER, - container_configuration=AgentRuntimeContainer( - image="registry.cn-hangzhou.aliyuncs.com/agentrun/super-agent-placeholder:v1" - ), + container_configuration=AgentRuntimeContainer(image=_PLACEHOLDER_IMAGE), network_configuration=NetworkConfig(network_mode=NetworkMode.PUBLIC), ) @@ -502,7 +518,7 @@ def _get_external_endpoint(rt: AgentRuntime) -> str: ) -def from_agent_runtime(rt: AgentRuntime) -> "SuperAgent": # noqa: F821 +def from_agent_runtime(rt: AgentRuntime) -> "SuperAgent": """反解 AgentRuntime → SuperAgent 实例 (不注入 ``_client``).""" # 延迟导入避免循环 from agentrun.super_agent.agent import SuperAgent @@ -547,4 +563,5 @@ def from_agent_runtime(rt: AgentRuntime) -> "SuperAgent": # noqa: F821 "from_agent_runtime", "is_super_agent", "parse_super_agent_config", + "ensure_super_agent_patches_applied", ] diff --git a/agentrun/super_agent/client.py b/agentrun/super_agent/client.py index 8ac6c84..0a2daae 100644 --- a/agentrun/super_agent/client.py +++ b/agentrun/super_agent/client.py @@ -32,6 +32,7 @@ from agentrun.agent_runtime.runtime import AgentRuntime from agentrun.super_agent.agent import SuperAgent from agentrun.super_agent.api.control import ( + ensure_super_agent_patches_applied, from_agent_runtime, is_super_agent, SUPER_AGENT_TAG, @@ -96,6 +97,9 @@ class SuperAgentClient: """Super Agent CRUDL 客户端.""" def __init__(self, config: Optional[Config] = None) -> None: + # 按需打 Dara SDK 兼容补丁 (幂等)。放在本构造函数里, 让 "仅 import + # agentrun.super_agent" 的调用方不被动承担全局 SDK 副作用。 + ensure_super_agent_patches_applied() self.config = config self._rt = AgentRuntimeClient(config=config) # create/update 绕过 AgentRuntimeClient 的 artifact_type 校验 (SUPER_AGENT 不需要 code/container), @@ -312,19 +316,22 @@ def get(self, name: str, *, config: Optional[Config] = None) -> SuperAgent: return agent # ─── Update (read-merge-write) ───────────────────── + # 参数默认值 ``_UNSET`` 是内部哨兵 (object())。为保留 IDE 自动补全与 mypy + # 类型检查, 签名保持精确类型标注, 对 ``= _UNSET`` 的赋值加 ``type: ignore``。 + # 未传 = 保持不变, 显式传 None = 清空字段。 async def update_async( self, name: str, *, - description: Any = _UNSET, - prompt: Any = _UNSET, - agents: Any = _UNSET, - tools: Any = _UNSET, - skills: Any = _UNSET, - sandboxes: Any = _UNSET, - workspaces: Any = _UNSET, - model_service_name: Any = _UNSET, - model_name: Any = _UNSET, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] config: Optional[Config] = None, ) -> SuperAgent: """异步更新超级 Agent (read-merge-write).""" @@ -366,15 +373,15 @@ def update( self, name: str, *, - description: Any = _UNSET, - prompt: Any = _UNSET, - agents: Any = _UNSET, - tools: Any = _UNSET, - skills: Any = _UNSET, - sandboxes: Any = _UNSET, - workspaces: Any = _UNSET, - model_service_name: Any = _UNSET, - model_name: Any = _UNSET, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] config: Optional[Config] = None, ) -> SuperAgent: """同步更新超级 Agent (read-merge-write).""" diff --git a/tests/unittests/super_agent/test_control.py b/tests/unittests/super_agent/test_control.py index 9f83427..1f3adc3 100644 --- a/tests/unittests/super_agent/test_control.py +++ b/tests/unittests/super_agent/test_control.py @@ -10,6 +10,7 @@ _add_ram_prefix_to_host, API_VERSION, build_super_agent_endpoint, + ensure_super_agent_patches_applied, EXTERNAL_TAG, from_agent_runtime, is_super_agent, @@ -23,6 +24,10 @@ from agentrun.super_agent.api.data import SuperAgentDataAPI from agentrun.utils.config import Config +# 本文件部分测试 (list request tags 补丁) 依赖 Dara SDK 已被打过补丁, +# 显式在模块加载时触发补丁 (幂等, 与 SuperAgentClient.__init__ 内触发点一致)。 +ensure_super_agent_patches_applied() + # ─── build_super_agent_endpoint ──────────────────────────────── @@ -282,7 +287,7 @@ def test_to_update_input_full_protocol_replace(): # ─── Dara ListAgentRuntimesRequest tags 补丁 ────────────────── -# 同时确保 import agentrun.super_agent.api.control 已应用补丁 (已在文件顶部导入)。 +# 补丁已在模块顶部通过 ensure_super_agent_patches_applied() 显式触发。 def test_list_request_from_map_preserves_tags(): From 46e9249fdc079af2ec037174564d07162de72f32 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Thu, 16 Apr 2026 22:47:44 +0800 Subject: [PATCH 3/7] feat(super_agent): add list_conversations_async method This commit introduces the `list_conversations_async` method to both the `SuperAgent` and template classes. The method allows listing conversations asynchronously with optional metadata filtering. By default, it filters by agent runtime name unless explicitly overridden. Key changes include: - Added async implementation in `agent.py`, `__agent_async_template.py`, and API modules - Includes proper error handling and type validation - Maintains backward compatibility through synchronous placeholder methods that raise NotImplementedError Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- .../super_agent/__agent_async_template.py | 35 ++++ agentrun/super_agent/agent.py | 35 ++++ .../super_agent/api/__data_async_template.py | 49 +++++ agentrun/super_agent/api/data.py | 49 +++++ tests/unittests/super_agent/test_agent.py | 93 +++++++++ tests/unittests/super_agent/test_data_api.py | 179 ++++++++++++++++++ 6 files changed, 440 insertions(+) diff --git a/agentrun/super_agent/__agent_async_template.py b/agentrun/super_agent/__agent_async_template.py index a1f6d1a..9f5ba8b 100644 --- a/agentrun/super_agent/__agent_async_template.py +++ b/agentrun/super_agent/__agent_async_template.py @@ -118,6 +118,41 @@ def invoke( ) -> InvokeStream: raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def list_conversations_async( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + """GET /conversations → ``List[ConversationInfo]``. + + 默认按 ``{"agentRuntimeName": self.name}`` 过滤 (仅返回当前 agent 的会话); + 传入 ``metadata={}`` 或自定义 dict 可覆盖默认过滤条件。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + effective_metadata = ( + {"agentRuntimeName": self.name} if metadata is None else metadata + ) + raw_list = await api.list_conversations_async( + metadata=effective_metadata, config=cfg + ) + return [ + _conversation_info_from_dict( + raw, + fallback_conversation_id=str(raw.get("conversationId") or ""), + ) + for raw in raw_list + ] + + def list_conversations( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def get_conversation_async( self, conversation_id: str, diff --git a/agentrun/super_agent/agent.py b/agentrun/super_agent/agent.py index fac7ff6..2460041 100644 --- a/agentrun/super_agent/agent.py +++ b/agentrun/super_agent/agent.py @@ -126,6 +126,41 @@ def invoke( ) -> InvokeStream: raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def list_conversations_async( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + """GET /conversations → ``List[ConversationInfo]``. + + 默认按 ``{"agentRuntimeName": self.name}`` 过滤 (仅返回当前 agent 的会话); + 传入 ``metadata={}`` 或自定义 dict 可覆盖默认过滤条件。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + effective_metadata = ( + {"agentRuntimeName": self.name} if metadata is None else metadata + ) + raw_list = await api.list_conversations_async( + metadata=effective_metadata, config=cfg + ) + return [ + _conversation_info_from_dict( + raw, + fallback_conversation_id=str(raw.get("conversationId") or ""), + ) + for raw in raw_list + ] + + def list_conversations( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def get_conversation_async( self, conversation_id: str, diff --git a/agentrun/super_agent/api/__data_async_template.py b/agentrun/super_agent/api/__data_async_template.py index 3549842..9a7549e 100644 --- a/agentrun/super_agent/api/__data_async_template.py +++ b/agentrun/super_agent/api/__data_async_template.py @@ -193,6 +193,55 @@ def stream( ) -> Iterator[SSEEvent]: raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def list_conversations_async( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + """GET /conversations → 返回服务端 ``data.conversations`` 数组 (缺失时返回 [])。 + + ``metadata`` 若非 None, 会以 JSON 编码后通过 ``metadata`` query 参数下发; + 服务端按该 metadata 过滤 (例如 ``{"agentRuntimeName": "..."}``)。 + 不传则由服务端按当前 sub uid 过滤。 + """ + cfg = Config.with_configs(self.config, config) + query: Optional[Dict[str, Any]] = None + if metadata is not None: + query = {"metadata": json.dumps(metadata, ensure_ascii=False)} + url = self.with_path("conversations", query=query, config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent list_conversations request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent list_conversations response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return [] + data = payload.get("data") + if not isinstance(data, dict): + return [] + raw_list = data.get("conversations") + if not isinstance(raw_list, list): + return [] + return [item for item in raw_list if isinstance(item, dict)] + + def list_conversations( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def get_conversation_async( self, conversation_id: str, diff --git a/agentrun/super_agent/api/data.py b/agentrun/super_agent/api/data.py index ec6b2bb..947e0b4 100644 --- a/agentrun/super_agent/api/data.py +++ b/agentrun/super_agent/api/data.py @@ -199,6 +199,55 @@ def stream( ) -> Iterator[SSEEvent]: raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def list_conversations_async( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + """GET /conversations → 返回服务端 ``data.conversations`` 数组 (缺失时返回 [])。 + + ``metadata`` 若非 None, 会以 JSON 编码后通过 ``metadata`` query 参数下发; + 服务端按该 metadata 过滤 (例如 ``{"agentRuntimeName": "..."}``)。 + 不传则由服务端按当前 sub uid 过滤。 + """ + cfg = Config.with_configs(self.config, config) + query: Optional[Dict[str, Any]] = None + if metadata is not None: + query = {"metadata": json.dumps(metadata, ensure_ascii=False)} + url = self.with_path("conversations", query=query, config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent list_conversations request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent list_conversations response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return [] + data = payload.get("data") + if not isinstance(data, dict): + return [] + raw_list = data.get("conversations") + if not isinstance(raw_list, list): + return [] + return [item for item in raw_list if isinstance(item, dict)] + + def list_conversations( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + async def get_conversation_async( self, conversation_id: str, diff --git a/tests/unittests/super_agent/test_agent.py b/tests/unittests/super_agent/test_agent.py index 2ae2418..b975fee 100644 --- a/tests/unittests/super_agent/test_agent.py +++ b/tests/unittests/super_agent/test_agent.py @@ -226,6 +226,97 @@ async def test_delete_conversation_async_returns_none(): assert await _make_agent().delete_conversation_async("c") is None +# ─── list_conversations_async ────────────────────────────── + + +async def test_list_conversations_async_default_filters_by_agent_name(): + """不传 metadata 时, 默认按 ``{"agentRuntimeName": self.name}`` 过滤.""" + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = SuperAgent(name="my-agent") + result = await agent.list_conversations_async() + assert result == [] + assert instance.list_conversations_async.await_args.kwargs["metadata"] == { + "agentRuntimeName": "my-agent" + } + + +async def test_list_conversations_async_explicit_metadata_passthrough(): + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + await agent.list_conversations_async(metadata={"foo": "bar"}) + assert instance.list_conversations_async.await_args.kwargs["metadata"] == { + "foo": "bar" + } + + +async def test_list_conversations_async_empty_metadata_overrides_default(): + """传入空 dict 明确表示「不按 agent 过滤」, SDK MUST 不再注入默认值.""" + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + await agent.list_conversations_async(metadata={}) + assert instance.list_conversations_async.await_args.kwargs["metadata"] == {} + + +async def test_list_conversations_async_returns_conversation_info_list(): + instance = MagicMock() + instance.list_conversations_async = AsyncMock( + return_value=[ + { + "conversationId": "c1", + "agentId": "ag", + "title": "first", + "createdAt": 1, + "updatedAt": 2, + "messages": [{"role": "user", "content": "hi"}], + }, + { + "conversationId": "c2", + "agentId": "ag", + "title": "second", + "messages": [], + }, + ] + ) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + result = await _make_agent().list_conversations_async() + assert [c.conversation_id for c in result] == ["c1", "c2"] + assert result[0].title == "first" + assert len(result[0].messages) == 1 + assert result[0].messages[0].content == "hi" + assert result[1].title == "second" + + +async def test_list_conversations_async_empty_list(): + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + assert await _make_agent().list_conversations_async() == [] + + +async def test_list_conversations_async_item_missing_conversation_id_uses_empty_fallback(): + """对单条会话里 ``conversationId`` 缺失的情况, fallback 保持空串 (不会用 agent name).""" + instance = MagicMock() + instance.list_conversations_async = AsyncMock( + return_value=[{"agentId": "ag"}] + ) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + result = await _make_agent().list_conversations_async() + assert len(result) == 1 + assert result[0].conversation_id == "" + + # ─── sync methods → NotImplementedError ───────────────────── @@ -237,6 +328,8 @@ def test_sync_methods_not_implemented(): agent.get_conversation("c") with pytest.raises(NotImplementedError): agent.delete_conversation("c") + with pytest.raises(NotImplementedError): + agent.list_conversations() def test_invoke_async_signature_only_messages_and_conversation_id(): diff --git a/tests/unittests/super_agent/test_data_api.py b/tests/unittests/super_agent/test_data_api.py index 406a818..ec2d7e2 100644 --- a/tests/unittests/super_agent/test_data_api.py +++ b/tests/unittests/super_agent/test_data_api.py @@ -102,6 +102,24 @@ async def test_invoke_async_phase1_url_custom_gateway_no_ram(): assert route.called +@respx.mock +async def test_list_conversations_async_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint="http://111.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + api = SuperAgentDataAPI("n", config=cfg) + route = respx.get( + re.compile( + r"http://111-ram\.funagent-data-pre\.cn-hangzhou\.aliyuncs\.com" + r"/2025-09-10/super-agents/__SUPER_AGENT__/conversations(\?.*)?$" + ) + ).mock( + return_value=httpx.Response(200, json={"data": {"conversations": []}}) + ) + await api.list_conversations_async() + assert route.called + + @respx.mock async def test_get_conversation_async_url_pre_environment(): cfg = _auth_cfg( @@ -472,6 +490,166 @@ async def test_delete_conversation_async_404_raises(): await api.delete_conversation_async("missing") +# ─── list_conversations ────────────────────────────────────── + + +@respx.mock +async def test_list_conversations_async_parses_data_array(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations(\?.*)?$")).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversations": [ + {"conversationId": "c1", "title": "t1"}, + {"conversationId": "c2", "title": "t2"}, + ] + }, + "success": True, + }, + ) + ) + result = await api.list_conversations_async() + assert [c["conversationId"] for c in result] == ["c1", "c2"] + + +@respx.mock +async def test_list_conversations_async_metadata_query_encoded(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["url"] = str(request.url) + return httpx.Response(200, json={"data": {"conversations": []}}) + + respx.get(re.compile(r".*/conversations.*")).mock(side_effect=_responder) + await api.list_conversations_async(metadata={"agentRuntimeName": "demo"}) + assert "metadata=" in captured["url"] + # metadata is json-encoded then URL-encoded + assert "agentRuntimeName" in captured["url"] + + +@respx.mock +async def test_list_conversations_async_without_metadata_no_query(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["url"] = str(request.url) + return httpx.Response(200, json={"data": {"conversations": []}}) + + respx.get(re.compile(r".*/conversations.*")).mock(side_effect=_responder) + await api.list_conversations_async() + assert "metadata" not in captured["url"] + + +@respx.mock +async def test_list_conversations_async_request_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"data": {"conversations": []}}) + + respx.get(re.compile(r".*/conversations.*")).mock(side_effect=_responder) + await api.list_conversations_async() + assert any( + k.lower() == "agentrun-authorization" for k in captured["headers"] + ) + + +@respx.mock +async def test_list_conversations_async_missing_data_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, json={"success": True}) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_data_not_dict_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, json={"data": "unexpected"}) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_conversations_not_list_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response( + 200, json={"data": {"conversations": "bad"}} + ) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_filters_non_dict_items(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversations": [ + {"conversationId": "c1"}, + "invalid", + None, + {"conversationId": "c2"}, + ] + } + }, + ) + ) + result = await api.list_conversations_async() + assert [c["conversationId"] for c in result] == ["c1", "c2"] + + +@respx.mock +async def test_list_conversations_async_payload_not_dict_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, json=[1, 2, 3]) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_empty_body_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, content=b"") + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_5xx_raises(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(500, text="boom") + ) + with pytest.raises(httpx.HTTPStatusError): + await api.list_conversations_async() + + # ─── sync stubs are NotImplementedError ────────────────────── @@ -481,6 +659,7 @@ def test_sync_methods_not_implemented(): for fn in ( lambda: api.invoke([]), lambda: api.stream("url"), + lambda: api.list_conversations(), lambda: api.get_conversation("c"), lambda: api.delete_conversation("c"), ): From 36099649d7b7908deaddc36967c70093593e232d Mon Sep 17 00:00:00 2001 From: OhYee Date: Thu, 16 Apr 2026 23:01:33 +0800 Subject: [PATCH 4/7] fix(ram_signature): use unquote_plus for query parameter decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updated query parameter decoding to use unquote_plus instead of unquote for better handling of '+' characters in URLs. This change ensures that query parameters are decoded correctly, aligning with standard URL encoding practices. 修复(ram_signature):使用 unquote_plus 解码查询参数 更新查询参数解码,改用 unquote_plus 代替 unquote,以更好地处理 URL 中的 '+' 字符。此更改确保查询参数正确解码,符合标准 URL 编码实践。 Change-Id: I6d8af91c3ac3512cf8b9a96dda4501a643e8fa66 Signed-off-by: OhYee --- agentrun/utils/ram_signature/signer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agentrun/utils/ram_signature/signer.py b/agentrun/utils/ram_signature/signer.py index c91ddf8..0f0bb70 100644 --- a/agentrun/utils/ram_signature/signer.py +++ b/agentrun/utils/ram_signature/signer.py @@ -12,7 +12,7 @@ import hashlib import hmac from typing import Optional -from urllib.parse import quote, unquote, urlparse +from urllib.parse import quote, unquote_plus, urlparse ALGORITHM = "AGENTRUN4-HMAC-SHA256" UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" @@ -156,7 +156,7 @@ def get_agentrun_signed_headers( for pair in parsed.query.split("&"): if "=" in pair: k, v = pair.split("=", 1) - query_params[unquote(k)] = unquote(v) + query_params[unquote_plus(k)] = unquote_plus(v) now = sign_time if sign_time is not None else datetime.now(timezone.utc) if now.tzinfo is None: From c8847e08d016cfadab46858b6510abfe2f16b4de Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Mon, 20 Apr 2026 13:59:11 +0800 Subject: [PATCH 5/7] refactor(super_agent): Update tag constants and simplify creation process This change updates the `SUPER_AGENT_TAG` constant from `"x-agentrun-super-agent"` to `"x-agentrun-super"`, removes the `EXTERNAL_TAG` from the default tags when creating a super agent runtime, and adds clarifying comments about the purpose of `external_agent_endpoint_url`. The corresponding unit tests have also been updated accordingly. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/super_agent/__client_async_template.py | 2 +- agentrun/super_agent/api/control.py | 16 ++++++++-------- agentrun/super_agent/client.py | 2 +- tests/unittests/super_agent/test_control.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/agentrun/super_agent/__client_async_template.py b/agentrun/super_agent/__client_async_template.py index 533a7c0..4beb65a 100644 --- a/agentrun/super_agent/__client_async_template.py +++ b/agentrun/super_agent/__client_async_template.py @@ -4,7 +4,7 @@ 内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 -list 固定按 tag ``x-agentrun-super-agent`` 过滤, 不接受用户自定义 tag。 +list 固定按 tag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 """ import asyncio diff --git a/agentrun/super_agent/api/control.py b/agentrun/super_agent/api/control.py index 469bbe9..cf444e4 100644 --- a/agentrun/super_agent/api/control.py +++ b/agentrun/super_agent/api/control.py @@ -50,11 +50,12 @@ API_VERSION = "2025-09-10" SUPER_AGENT_PROTOCOL_TYPE = "SUPER_AGENT" # ``SUPER_AGENT_TAG`` 标识下游 AgentRuntime 是超级 Agent, 用于 list 过滤。 -SUPER_AGENT_TAG = "x-agentrun-super-agent" +SUPER_AGENT_TAG = "x-agentrun-super" # ``EXTERNAL_TAG`` 标识下游 AgentRuntime 由外部 (SuperAgent) 托管调用, 不由 AgentRun 直接托管。 +# 保留常量以便外部消费者引用; 创建超级 Agent 时不再写入此 tag。 EXTERNAL_TAG = "x-agentrun-external" -# 创建下游 AgentRuntime 时固定写入的 tag 列表: ``[EXTERNAL_TAG, SUPER_AGENT_TAG]``。 -SUPER_AGENT_CREATE_TAGS = [EXTERNAL_TAG, SUPER_AGENT_TAG] +# 创建下游 AgentRuntime 时固定写入的 tag 列表: ``[SUPER_AGENT_TAG]`` (仅一个)。 +SUPER_AGENT_CREATE_TAGS = [SUPER_AGENT_TAG] SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" SUPER_AGENT_INVOKE_PATH = "/invoke" SUPER_AGENT_NAMESPACE = ( @@ -131,8 +132,8 @@ class _SuperAgentCreateInput(AgentRuntimeCreateInput): """默认使用 ``serialize_as_any=True`` 的 create input, 保留子类 extras. ``external_agent_endpoint_url`` 是基类 ``AgentRuntimeMutableProps`` 没有覆盖 - 的顶层字段, 但在 ``x-agentrun-external`` tag 下服务端强制要求填入, 这里显式 - 补齐 (alias 由 BaseModel 的 ``to_camel_case`` 生成 → ``externalAgentEndpointUrl``)。 + 的顶层字段, 这里显式补齐 (alias 由 BaseModel 的 ``to_camel_case`` 生成 → + ``externalAgentEndpointUrl``), 用于承载超级 Agent 的数据面入口地址。 """ external_agent_endpoint_url: Optional[str] = None @@ -407,8 +408,7 @@ def to_create_input( description=description, protocol_configuration=pc, tags=list(SUPER_AGENT_CREATE_TAGS), - # 带 ``x-agentrun-external`` tag 时服务端强制要求 externalAgentEndpointUrl 非空, - # 对超级 Agent 而言即数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 + # 超级 Agent 的数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 external_agent_endpoint_url=build_super_agent_endpoint(cfg), # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 artifact_type=AgentRuntimeArtifact.CONTAINER, @@ -439,7 +439,7 @@ def to_update_input( description=merged.get("description"), protocol_configuration=pc, tags=list(SUPER_AGENT_CREATE_TAGS), - # 带 ``x-agentrun-external`` tag 时服务端强制要求 externalAgentEndpointUrl 非空。 + # 超级 Agent 的数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 external_agent_endpoint_url=build_super_agent_endpoint(cfg), # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 artifact_type=AgentRuntimeArtifact.CONTAINER, diff --git a/agentrun/super_agent/client.py b/agentrun/super_agent/client.py index 0a2daae..421f54d 100644 --- a/agentrun/super_agent/client.py +++ b/agentrun/super_agent/client.py @@ -14,7 +14,7 @@ 内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 -list 固定按 tag ``x-agentrun-super-agent`` 过滤, 不接受用户自定义 tag。 +list 固定按 tag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 """ import asyncio diff --git a/tests/unittests/super_agent/test_control.py b/tests/unittests/super_agent/test_control.py index 1f3adc3..5c7f114 100644 --- a/tests/unittests/super_agent/test_control.py +++ b/tests/unittests/super_agent/test_control.py @@ -108,7 +108,7 @@ def test_to_create_input_minimal(): cfg = Config(account_id="123", region_id="cn-hangzhou") inp = to_create_input("alpha", cfg=cfg) assert inp.agent_runtime_name == "alpha" - assert inp.tags == [EXTERNAL_TAG, SUPER_AGENT_TAG] + assert inp.tags == [SUPER_AGENT_TAG] pc = inp.protocol_configuration assert pc.type == SUPER_AGENT_PROTOCOL_TYPE assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") @@ -149,7 +149,7 @@ def test_to_create_input_full(): def test_to_create_input_tags_fixed(): cfg = Config(account_id="123", region_id="cn-hangzhou") inp = to_create_input("c", cfg=cfg) - assert inp.tags == [EXTERNAL_TAG, SUPER_AGENT_TAG] + assert inp.tags == [SUPER_AGENT_TAG] def test_to_create_input_metadata_only_agent_runtime_name(): From c61ead9776a349f1c752cc3a076fb6a2cde6ccf8 Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Mon, 20 Apr 2026 20:03:40 +0800 Subject: [PATCH 6/7] refactor(super_agent): Update config structure and parsing logic This change updates the internal configuration structure to include nested `headers`, `body`, and `forwardedProps`. It also refactors the parsing logic in `_flatten_protocol_config()` to handle both old and new formats seamlessly. The unit tests have been updated accordingly. Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/super_agent/api/control.py | 41 +++++- tests/unittests/super_agent/test_client.py | 37 +++-- tests/unittests/super_agent/test_control.py | 145 ++++++++++++++++---- 3 files changed, 178 insertions(+), 45 deletions(-) diff --git a/agentrun/super_agent/api/control.py b/agentrun/super_agent/api/control.py index cf444e4..fe3aecf 100644 --- a/agentrun/super_agent/api/control.py +++ b/agentrun/super_agent/api/control.py @@ -337,9 +337,12 @@ def _business_fields_from_args( def _build_protocol_settings_config( *, name: str, business: Dict[str, Any] ) -> str: - """构造 ``protocolSettings[0].config`` 的 JSON 字符串.""" - cfg_dict: Dict[str, Any] = { - "path": SUPER_AGENT_INVOKE_PATH, + """构造 ``protocolSettings[0].config`` 的 JSON 字符串. + + 新结构: 顶层 ``path`` / ``headers`` / ``body``, 业务字段收拢到 + ``body.forwardedProps`` (开放字典, 语义 "any, merge")。 + """ + forwarded_props: Dict[str, Any] = { "prompt": business.get("prompt"), "agents": business.get("agents") or [], "tools": business.get("tools") or [], @@ -350,6 +353,11 @@ def _build_protocol_settings_config( "modelName": business.get("modelName"), "metadata": {"agentRuntimeName": name}, } + cfg_dict: Dict[str, Any] = { + "path": SUPER_AGENT_INVOKE_PATH, + "headers": {}, + "body": {"forwardedProps": forwarded_props}, + } return json.dumps(cfg_dict, ensure_ascii=False) @@ -483,10 +491,31 @@ def is_super_agent(rt: AgentRuntime) -> bool: return first.get("type") == SUPER_AGENT_PROTOCOL_TYPE +def _flatten_protocol_config(cfg: Any) -> Dict[str, Any]: + """把 ``protocolSettings[0].config`` 解析结果压平为扁平业务字段 dict. + + 兼容两种物理布局: + - 新: ``{"path": ..., "headers": ..., "body": {"forwardedProps": {...}}}`` + - 旧: 业务字段直接在根 (历史 AgentRuntime, 迁移前写入) + + 两种结构都返回扁平的业务字段 dict, 上游 + :func:`from_agent_runtime` 无需感知物理布局差异。 + """ + if not isinstance(cfg, dict): + return {} + body = cfg.get("body") + if isinstance(body, dict): + forwarded = body.get("forwardedProps") + if isinstance(forwarded, dict): + return forwarded + return cfg + + def parse_super_agent_config(rt: AgentRuntime) -> Dict[str, Any]: - """反解 ``protocolSettings[0].config`` 为业务字段 dict. + """反解 ``protocolSettings[0].config`` 为扁平业务字段 dict. 如果 config 缺失或非法 JSON, 返回空 dict (不抛异常)。 + 新旧嵌套布局由 :func:`_flatten_protocol_config` 统一拍平。 """ pc_dict = _extract_protocol_configuration(rt) if not pc_dict: @@ -499,13 +528,13 @@ def parse_super_agent_config(rt: AgentRuntime) -> Dict[str, Any]: if not raw_config: return {} if isinstance(raw_config, dict): - return raw_config + return _flatten_protocol_config(raw_config) if isinstance(raw_config, str): try: parsed = json.loads(raw_config) - return parsed if isinstance(parsed, dict) else {} except (TypeError, ValueError): return {} + return _flatten_protocol_config(parsed) return {} diff --git a/tests/unittests/super_agent/test_client.py b/tests/unittests/super_agent/test_client.py index 0f7da4d..604b4d6 100644 --- a/tests/unittests/super_agent/test_client.py +++ b/tests/unittests/super_agent/test_client.py @@ -60,15 +60,20 @@ def _fake_rt( """Build a minimal AgentRuntime-like object for ``from_agent_runtime``.""" cfg_dict = { "path": "/invoke", - "prompt": prompt, - "agents": [], - "tools": tools if tools is not None else [], - "skills": [], - "sandboxes": [], - "workspaces": [], - "modelServiceName": None, - "modelName": None, - "metadata": {"agentRuntimeName": name}, + "headers": {}, + "body": { + "forwardedProps": { + "prompt": prompt, + "agents": [], + "tools": tools if tools is not None else [], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": None, + "modelName": None, + "metadata": {"agentRuntimeName": name}, + } + }, } pc = { "type": protocol_type, @@ -124,7 +129,7 @@ async def _create_async(dara_input, config=None): first = pc.protocol_settings[0] assert first.type == SUPER_AGENT_PROTOCOL_TYPE cfg_json = json.loads(first.config) - assert cfg_json["prompt"] == "new" + assert cfg_json["body"]["forwardedProps"]["prompt"] == "new" assert agent.name == "alpha" @@ -206,8 +211,9 @@ async def _update(agent_id, dara_input, config=None): cfg_json = json.loads( captured["dara"].protocol_configuration.protocol_settings[0].config ) - assert cfg_json["prompt"] == "new" - assert cfg_json["tools"] == ["t1"] + forwarded = cfg_json["body"]["forwardedProps"] + assert forwarded["prompt"] == "new" + assert forwarded["tools"] == ["t1"] async def test_update_async_explicit_none_clears_field(): @@ -238,7 +244,7 @@ async def _update(agent_id, dara_input, config=None): cfg_json = json.loads( captured["dara"].protocol_configuration.protocol_settings[0].config ) - assert cfg_json["prompt"] is None + assert cfg_json["body"]["forwardedProps"]["prompt"] is None async def test_update_async_multiple_fields(): @@ -271,8 +277,9 @@ async def _update(agent_id, dara_input, config=None): cfg_json = json.loads( captured["dara"].protocol_configuration.protocol_settings[0].config ) - assert cfg_json["prompt"] == "p" - assert cfg_json["tools"] == ["a", "b"] + forwarded = cfg_json["body"]["forwardedProps"] + assert forwarded["prompt"] == "p" + assert forwarded["tools"] == ["a", "b"] assert captured["dara"].description == "d" diff --git a/tests/unittests/super_agent/test_control.py b/tests/unittests/super_agent/test_control.py index 5c7f114..3831a2a 100644 --- a/tests/unittests/super_agent/test_control.py +++ b/tests/unittests/super_agent/test_control.py @@ -116,8 +116,10 @@ def test_to_create_input_minimal(): assert len(settings) == 1 cfg_dict = json.loads(settings[0]["config"]) assert cfg_dict["path"] == "/invoke" - assert cfg_dict["agents"] == [] - assert cfg_dict["metadata"] == {"agentRuntimeName": "alpha"} + assert cfg_dict["headers"] == {} + forwarded = cfg_dict["body"]["forwardedProps"] + assert forwarded["agents"] == [] + assert forwarded["metadata"] == {"agentRuntimeName": "alpha"} def test_to_create_input_full(): @@ -136,14 +138,17 @@ def test_to_create_input_full(): ) pc_dict = inp.model_dump()["protocolConfiguration"] settings_cfg = json.loads(pc_dict["protocolSettings"][0]["config"]) - assert settings_cfg["prompt"] == "hello" - assert settings_cfg["agents"] == ["a1"] - assert settings_cfg["tools"] == ["t1", "t2"] - assert settings_cfg["skills"] == ["s1"] - assert settings_cfg["sandboxes"] == ["sb1"] - assert settings_cfg["workspaces"] == ["ws1"] - assert settings_cfg["modelServiceName"] == "foo" - assert settings_cfg["modelName"] == "bar" + assert settings_cfg["path"] == "/invoke" + assert settings_cfg["headers"] == {} + forwarded = settings_cfg["body"]["forwardedProps"] + assert forwarded["prompt"] == "hello" + assert forwarded["agents"] == ["a1"] + assert forwarded["tools"] == ["t1", "t2"] + assert forwarded["skills"] == ["s1"] + assert forwarded["sandboxes"] == ["sb1"] + assert forwarded["workspaces"] == ["ws1"] + assert forwarded["modelServiceName"] == "foo" + assert forwarded["modelName"] == "bar" def test_to_create_input_tags_fixed(): @@ -158,7 +163,9 @@ def test_to_create_input_metadata_only_agent_runtime_name(): settings_cfg = json.loads( inp.protocol_configuration.protocol_settings[0]["config"] ) - assert settings_cfg["metadata"] == {"agentRuntimeName": "d"} + assert settings_cfg["body"]["forwardedProps"]["metadata"] == { + "agentRuntimeName": "d" + } def test_to_create_input_uses_pre_environment_endpoint(): @@ -194,15 +201,21 @@ def _make_rt(**kwargs): def test_from_agent_runtime(): config_json = json.dumps({ - "prompt": "hi", - "agents": ["a"], - "tools": ["t"], - "skills": [], - "sandboxes": [], - "workspaces": [], - "modelServiceName": "svc", - "modelName": "mod", - "metadata": {"agentRuntimeName": "foo"}, + "path": "/invoke", + "headers": {}, + "body": { + "forwardedProps": { + "prompt": "hi", + "agents": ["a"], + "tools": ["t"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": "svc", + "modelName": "mod", + "metadata": {"agentRuntimeName": "foo"}, + } + }, }) pc = { "type": SUPER_AGENT_PROTOCOL_TYPE, @@ -227,6 +240,89 @@ def test_from_agent_runtime(): ) +def test_from_agent_runtime_legacy_flat_config(): + """旧结构兼容: config 是扁平 dict, 业务字段直接在根 (历史 AgentRuntime).""" + config_json = json.dumps({ + "prompt": "legacy", + "agents": ["la"], + "tools": ["lt"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": "legacy-svc", + "modelName": "legacy-mod", + "metadata": {"agentRuntimeName": "legacy"}, + }) + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": config_json, + "name": "legacy", + "path": "/invoke", + }], + "externalEndpoint": "https://x.com/super-agents/__SUPER_AGENT__", + } + rt = _make_rt(agent_runtime_name="legacy", protocol_configuration=pc) + agent = from_agent_runtime(rt) + assert agent.prompt == "legacy" + assert agent.agents == ["la"] + assert agent.model_service_name == "legacy-svc" + + +def test_parse_super_agent_config_dict_config_new_structure(): + """config 已经是 dict (非字符串) 时也能拍平.""" + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": { + "path": "/invoke", + "headers": {}, + "body": { + "forwardedProps": { + "prompt": "p", + "agents": [], + } + }, + }, + }], + } + business = parse_super_agent_config(_make_rt(protocol_configuration=pc)) + assert business["prompt"] == "p" + assert business["agents"] == [] + + +def test_parse_super_agent_config_dict_config_legacy_flat(): + """config 是 dict + 旧扁平结构时走 fallback, 原样返回.""" + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": {"prompt": "legacy-dict", "agents": ["la"]}, + }], + } + business = parse_super_agent_config(_make_rt(protocol_configuration=pc)) + assert business == {"prompt": "legacy-dict", "agents": ["la"]} + + +def test_flatten_protocol_config_non_dict_returns_empty(): + """非 dict 输入 (防御分支) 返回空 dict.""" + from agentrun.super_agent.api.control import _flatten_protocol_config + + assert _flatten_protocol_config(None) == {} + assert _flatten_protocol_config("not-a-dict") == {} + assert _flatten_protocol_config([1, 2, 3]) == {} + + +def test_flatten_protocol_config_body_without_forwarded_props(): + """body 存在但缺 forwardedProps → fallback 到整个 cfg (旧结构).""" + from agentrun.super_agent.api.control import _flatten_protocol_config + + cfg = {"body": {"other": "x"}, "prompt": "flat"} + assert _flatten_protocol_config(cfg) == cfg + + def test_is_super_agent_true(): pc = { "type": SUPER_AGENT_PROTOCOL_TYPE, @@ -280,10 +376,11 @@ def test_to_update_input_full_protocol_replace(): assert inp.description == "new" settings = inp.protocol_configuration.protocol_settings assert len(settings) == 1 - assert ( - json.loads(settings[0]["config"])["metadata"]["agentRuntimeName"] - == "alpha" - ) + cfg_json = json.loads(settings[0]["config"]) + forwarded = cfg_json["body"]["forwardedProps"] + assert forwarded["metadata"]["agentRuntimeName"] == "alpha" + assert forwarded["prompt"] == "p" + assert forwarded["tools"] == ["t"] # ─── Dara ListAgentRuntimesRequest tags 补丁 ────────────────── From 52cf6d1bda0ba140d8d9b8369d4611991099f55e Mon Sep 17 00:00:00 2001 From: Sodawyx Date: Tue, 21 Apr 2026 14:04:56 +0800 Subject: [PATCH 7/7] refactor(super_agent): Replace `tags` with `system_tags` across multiple modules This change updates the codebase to use `system_tags` instead of `tags`, aligning with changes in the underlying SDK where `systemTags` is now natively supported. The update affects both the control logic and related unit tests. Key changes include: - Replacing references to `tags` with `system_tags` throughout the super_agent module - Updating corresponding unit tests to reflect this change - Modifying docstrings that previously mentioned filtering by tag to mention systemTag instead - Bumping the required version of alibabacloud-agentrun20250910 to >=5.6.1 which includes native support for systemTags Co-developed-by: Aone Copilot Signed-off-by: Sodawyx --- agentrun/agent_runtime/api/control.py | 9 ++ agentrun/agent_runtime/model.py | 4 + .../super_agent/__client_async_template.py | 14 +- agentrun/super_agent/api/control.py | 146 ++---------------- agentrun/super_agent/client.py | 14 +- pyproject.toml | 2 +- tests/unittests/super_agent/test_client.py | 7 +- tests/unittests/super_agent/test_control.py | 94 +++-------- 8 files changed, 63 insertions(+), 227 deletions(-) diff --git a/agentrun/agent_runtime/api/control.py b/agentrun/agent_runtime/api/control.py index 8b76d45..c3a3847 100644 --- a/agentrun/agent_runtime/api/control.py +++ b/agentrun/agent_runtime/api/control.py @@ -21,6 +21,9 @@ CreateAgentRuntimeEndpointRequest, CreateAgentRuntimeInput, CreateAgentRuntimeRequest, + DeleteAgentRuntimeEndpointRequest, + DeleteAgentRuntimeRequest, + GetAgentRuntimeEndpointRequest, GetAgentRuntimeRequest, ListAgentRuntimeEndpointsOutput, ListAgentRuntimeEndpointsRequest, @@ -193,6 +196,7 @@ def delete_agent_runtime( client = self._get_client(config) response = client.delete_agent_runtime_with_options( agent_id, + DeleteAgentRuntimeRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -248,6 +252,7 @@ async def delete_agent_runtime_async( client = self._get_client(config) response = await client.delete_agent_runtime_with_options_async( agent_id, + DeleteAgentRuntimeRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -778,6 +783,7 @@ def delete_agent_runtime_endpoint( response = client.delete_agent_runtime_endpoint_with_options( agent_id, endpoint_id, + DeleteAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -838,6 +844,7 @@ async def delete_agent_runtime_endpoint_async( await client.delete_agent_runtime_endpoint_with_options_async( agent_id, endpoint_id, + DeleteAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -1028,6 +1035,7 @@ def get_agent_runtime_endpoint( response = client.get_agent_runtime_endpoint_with_options( agent_id, endpoint_id, + GetAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -1088,6 +1096,7 @@ async def get_agent_runtime_endpoint_async( await client.get_agent_runtime_endpoint_with_options_async( agent_id, endpoint_id, + GetAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) diff --git a/agentrun/agent_runtime/model.py b/agentrun/agent_runtime/model.py index 965c1bf..a169ef3 100644 --- a/agentrun/agent_runtime/model.py +++ b/agentrun/agent_runtime/model.py @@ -252,6 +252,8 @@ class AgentRuntimeMutableProps(BaseModel): """会话空闲超时时间,单位:秒""" tags: Optional[List[str]] = None """标签列表""" + system_tags: Optional[List[str]] = None + """系统标签列表 (由平台内部使用, 例如 SuperAgent 用来标识下游 AgentRuntime)""" class AgentRuntimeImmutableProps(BaseModel): @@ -323,6 +325,8 @@ class AgentRuntimeListInput(PageableInput): """Agent Runtime 名称""" tags: Optional[str] = None """标签过滤,多个标签用逗号分隔""" + system_tags: Optional[str] = None + """系统标签过滤, 多个标签用逗号分隔""" search_mode: Optional[str] = None """搜索模式""" diff --git a/agentrun/super_agent/__client_async_template.py b/agentrun/super_agent/__client_async_template.py index 4beb65a..51b5723 100644 --- a/agentrun/super_agent/__client_async_template.py +++ b/agentrun/super_agent/__client_async_template.py @@ -4,7 +4,7 @@ 内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 -list 固定按 tag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 +list 固定按 systemTag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 """ import asyncio @@ -246,7 +246,7 @@ def _find_rt_by_name(self, name: str, config: Optional[Config]) -> Any: AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ), config=cfg, ) @@ -268,7 +268,7 @@ async def _find_rt_by_name_async( AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ), config=cfg, ) @@ -438,12 +438,12 @@ async def list_async( page_size: int = 20, config: Optional[Config] = None, ) -> List[SuperAgent]: - """异步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + """异步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" cfg = Config.with_configs(self.config, config) rt_input = AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ) runtimes = await self._rt.list_async(rt_input, config=cfg) result: List[SuperAgent] = [] @@ -462,12 +462,12 @@ def list( page_size: int = 20, config: Optional[Config] = None, ) -> List[SuperAgent]: - """同步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + """同步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" cfg = Config.with_configs(self.config, config) rt_input = AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ) runtimes = self._rt.list(rt_input, config=cfg) result: List[SuperAgent] = [] diff --git a/agentrun/super_agent/api/control.py b/agentrun/super_agent/api/control.py index fe3aecf..afaf665 100644 --- a/agentrun/super_agent/api/control.py +++ b/agentrun/super_agent/api/control.py @@ -20,19 +20,9 @@ if TYPE_CHECKING: from agentrun.super_agent.agent import SuperAgent -from alibabacloud_agentrun20250910.client import Client as _DaraClient -from alibabacloud_agentrun20250910.models import ( - CreateAgentRuntimeInput as _DaraCreateAgentRuntimeInput, -) -from alibabacloud_agentrun20250910.models import ( - ListAgentRuntimesRequest as _DaraListAgentRuntimesRequest, -) from alibabacloud_agentrun20250910.models import ( ProtocolConfiguration as _DaraProtocolConfiguration, ) -from alibabacloud_agentrun20250910.models import ( - UpdateAgentRuntimeInput as _DaraUpdateAgentRuntimeInput, -) from pydantic import Field from agentrun.agent_runtime.model import ( @@ -50,11 +40,12 @@ API_VERSION = "2025-09-10" SUPER_AGENT_PROTOCOL_TYPE = "SUPER_AGENT" # ``SUPER_AGENT_TAG`` 标识下游 AgentRuntime 是超级 Agent, 用于 list 过滤。 +# 写入 ``systemTags`` 字段 (由服务端原生支持), create/update/list 的 system_tags 参数统一使用。 SUPER_AGENT_TAG = "x-agentrun-super" # ``EXTERNAL_TAG`` 标识下游 AgentRuntime 由外部 (SuperAgent) 托管调用, 不由 AgentRun 直接托管。 # 保留常量以便外部消费者引用; 创建超级 Agent 时不再写入此 tag。 EXTERNAL_TAG = "x-agentrun-external" -# 创建下游 AgentRuntime 时固定写入的 tag 列表: ``[SUPER_AGENT_TAG]`` (仅一个)。 +# 创建下游 AgentRuntime 时固定写入的 systemTags 列表: ``[SUPER_AGENT_TAG]`` (仅一个)。 SUPER_AGENT_CREATE_TAGS = [SUPER_AGENT_TAG] SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" SUPER_AGENT_INVOKE_PATH = "/invoke" @@ -153,16 +144,15 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]: return super().model_dump(**kwargs) -# ─── Dara 模型/客户端猴补丁 ────────────────────────────────────── -# 当前版 Dara SDK 缺 ``ProtocolConfiguration.externalEndpoint`` 和 -# ``CreateAgentRuntimeInput/UpdateAgentRuntimeInput/ListAgentRuntimesRequest.tags`` -# 字段, 会在 Pydantic ↔ Dara ``from_map / to_map`` roundtrip 中静默丢失; 且 -# ``Client.list_agent_runtimes_with_options{,_async}`` 不会把 ``tags`` 写到 query。 +# ─── Dara 模型猴补丁 ────────────────────────────────────── +# 当前版 Dara SDK 缺 ``ProtocolConfiguration.externalEndpoint`` 字段, 会在 +# Pydantic ↔ Dara ``from_map / to_map`` roundtrip 中静默丢失。补丁延迟到 +# ``SuperAgentClient`` 实例化时 (见 ``ensure_super_agent_patches_applied``) 才 +# 触发, 避免仅 import 本模块的调用方被动承担全局副作用。补丁用哨兵属性保证 +# 幂等, 重复调用安全。TODO: 等 Dara SDK 原生支持后删除。 # -# 所有补丁都延迟到 ``SuperAgentClient`` 实例化时 (见 -# ``ensure_super_agent_patches_applied``) 才触发, 避免仅 import 本模块的调用方 -# 被动承担全局副作用。补丁本身用哨兵属性保证幂等, 重复调用安全。 -# TODO: 等 Dara SDK 原生支持后删除。 +# ``tags`` (原 hack 写入) 已由 SDK 原生 ``systemTags`` 字段替代, 不再需要任何 +# 补丁; create/update/list 统一走 ``system_tags`` 参数。 def _patch_dara_protocol_configuration() -> None: @@ -193,120 +183,14 @@ def _patched_from_map( _DaraProtocolConfiguration._super_agent_patched = True # type: ignore[attr-defined] -def _patch_dara_tags(cls: Any) -> None: - """给 Dara model 补齐 ``tags`` 字段的 from_map/to_map 读写.""" - if getattr(cls, "_super_agent_tags_patched", False): - return - _orig_to_map = cls.to_map - _orig_from_map = cls.from_map - - def _patched_to_map(self: Any) -> Dict[str, Any]: - result = _orig_to_map(self) - tags = getattr(self, "tags", None) - if tags is not None: - result["tags"] = tags - return result - - def _patched_from_map(self: Any, m: Optional[Dict[str, Any]] = None) -> Any: - _orig_from_map(self, m) - if m and m.get("tags") is not None: - self.tags = m.get("tags") - return self - - cls.to_map = _patched_to_map # type: ignore[assignment] - cls.from_map = _patched_from_map # type: ignore[assignment] - cls._super_agent_tags_patched = True # type: ignore[attr-defined] - - -def _tags_query_value(tags: Any) -> Optional[str]: - if tags is None: - return None - if isinstance(tags, str): - return tags - if isinstance(tags, (list, tuple)): - return ",".join(str(t) for t in tags) - return str(tags) - - -def _patch_dara_client_list_tags() -> None: - """包裹 ``Client.list_agent_runtimes_with_options{,_async}``: 若 request 带 - ``tags`` 就在底层 ``call_api`` 调用前把 ``tags`` (列表 → 逗号分隔) 追加到 - ``req.query``。 - - 每次 API 调用由 ``_get_client()`` 新建 ``Client`` 实例, 实例属性级别的 - ``self.call_api = _injecting`` 替换在并发下是安全的。 - """ - if getattr(_DaraClient, "_super_agent_list_tags_patched", False): - return - - _orig_sync = _DaraClient.list_agent_runtimes_with_options - _orig_async = _DaraClient.list_agent_runtimes_with_options_async - - def _patched_sync( - self: Any, request: Any, headers: Any, runtime: Any - ) -> Any: - tags_value = _tags_query_value(getattr(request, "tags", None)) - if tags_value is None: - return _orig_sync(self, request, headers, runtime) - orig_call_api = self.call_api - - def _injecting(params: Any, req: Any, rt: Any) -> Any: - if req.query is None: - req.query = {} - req.query["tags"] = tags_value - return orig_call_api(params, req, rt) - - self.call_api = _injecting - try: - return _orig_sync(self, request, headers, runtime) - finally: - try: - del self.call_api - except AttributeError: - pass - - async def _patched_async( - self: Any, request: Any, headers: Any, runtime: Any - ) -> Any: - tags_value = _tags_query_value(getattr(request, "tags", None)) - if tags_value is None: - return await _orig_async(self, request, headers, runtime) - orig_call_api_async = self.call_api_async - - async def _injecting(params: Any, req: Any, rt: Any) -> Any: - if req.query is None: - req.query = {} - req.query["tags"] = tags_value - return await orig_call_api_async(params, req, rt) - - self.call_api_async = _injecting - try: - return await _orig_async(self, request, headers, runtime) - finally: - try: - del self.call_api_async - except AttributeError: - pass - - _DaraClient.list_agent_runtimes_with_options = _patched_sync # type: ignore[assignment] - _DaraClient.list_agent_runtimes_with_options_async = _patched_async # type: ignore[assignment] - _DaraClient._super_agent_list_tags_patched = True # type: ignore[attr-defined] - - def ensure_super_agent_patches_applied() -> None: - """按需应用全部 Dara SDK 兼容补丁 (幂等)。 + """按需应用 Dara SDK 兼容补丁 (幂等)。 由 ``SuperAgentClient.__init__`` 调用。如果调用方直接使用 - ``to_create_input`` / ``to_update_input`` 并自己构造 ``CreateAgentRuntimeInput`` - / ``ListAgentRuntimesRequest``, 也应在 Pydantic → Dara 转换前调用一次本函数。 + ``to_create_input`` / ``to_update_input`` 并自己构造 Dara 输入, 也应在 + Pydantic → Dara 转换前调用一次本函数。 """ _patch_dara_protocol_configuration() - _patch_dara_tags(_DaraCreateAgentRuntimeInput) - _patch_dara_tags(_DaraUpdateAgentRuntimeInput) - # ``ListAgentRuntimesRequest`` 补齐 from_map/to_map 保留属性; 真正让服务端 - # 生效的 query 注入由 ``_patch_dara_client_list_tags`` 完成。 - _patch_dara_tags(_DaraListAgentRuntimesRequest) - _patch_dara_client_list_tags() # ─── AgentRuntime ↔ SuperAgent 转换 ──────────────────────── @@ -415,7 +299,7 @@ def to_create_input( agent_runtime_name=name, description=description, protocol_configuration=pc, - tags=list(SUPER_AGENT_CREATE_TAGS), + system_tags=list(SUPER_AGENT_CREATE_TAGS), # 超级 Agent 的数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 external_agent_endpoint_url=build_super_agent_endpoint(cfg), # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 @@ -446,7 +330,7 @@ def to_update_input( agent_runtime_name=name, description=merged.get("description"), protocol_configuration=pc, - tags=list(SUPER_AGENT_CREATE_TAGS), + system_tags=list(SUPER_AGENT_CREATE_TAGS), # 超级 Agent 的数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 external_agent_endpoint_url=build_super_agent_endpoint(cfg), # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 diff --git a/agentrun/super_agent/client.py b/agentrun/super_agent/client.py index 421f54d..2f0da6b 100644 --- a/agentrun/super_agent/client.py +++ b/agentrun/super_agent/client.py @@ -14,7 +14,7 @@ 内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 -list 固定按 tag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 +list 固定按 systemTag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 """ import asyncio @@ -256,7 +256,7 @@ def _find_rt_by_name(self, name: str, config: Optional[Config]) -> Any: AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ), config=cfg, ) @@ -278,7 +278,7 @@ async def _find_rt_by_name_async( AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ), config=cfg, ) @@ -448,12 +448,12 @@ async def list_async( page_size: int = 20, config: Optional[Config] = None, ) -> List[SuperAgent]: - """异步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + """异步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" cfg = Config.with_configs(self.config, config) rt_input = AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ) runtimes = await self._rt.list_async(rt_input, config=cfg) result: List[SuperAgent] = [] @@ -472,12 +472,12 @@ def list( page_size: int = 20, config: Optional[Config] = None, ) -> List[SuperAgent]: - """同步列出超级 Agent (固定 tag 过滤, 过滤非 SUPER_AGENT).""" + """同步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" cfg = Config.with_configs(self.config, config) rt_input = AgentRuntimeListInput( page_number=page_number, page_size=page_size, - tags=SUPER_AGENT_TAG, + system_tags=SUPER_AGENT_TAG, ) runtimes = self._rt.list(rt_input, config=cfg) result: List[SuperAgent] = [] diff --git a/pyproject.toml b/pyproject.toml index 1c1b78a..251e0d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "litellm>=1.79.3", "alibabacloud-devs20230714>=2.4.1", "pydash>=8.0.5", - "alibabacloud-agentrun20250910>=5.6.0", + "alibabacloud-agentrun20250910>=5.6.1", "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", "agentrun-mem0ai>=0.0.10", diff --git a/tests/unittests/super_agent/test_client.py b/tests/unittests/super_agent/test_client.py index 604b4d6..5c1bc80 100644 --- a/tests/unittests/super_agent/test_client.py +++ b/tests/unittests/super_agent/test_client.py @@ -119,10 +119,7 @@ async def _create_async(dara_input, config=None): dara = captured_input["dara"] # Dara-level model uses snake_case attributes assert dara.agent_runtime_name == "alpha" - # 注: alibabacloud-agentrun20250910 的 Dara CreateAgentRuntimeInput 目前 - # 不包含 ``tags`` 字段, pydantic → Dara roundtrip 会丢弃. 校验 pydantic 侧 - # 的 rt_input 是否含 tags 在 test_control.py::test_to_create_input_tags_fixed - # 已覆盖. + assert dara.system_tags == [SUPER_AGENT_TAG] pc = dara.protocol_configuration # externalEndpoint preserved via the additive Dara monkey-patch assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") @@ -337,7 +334,7 @@ async def _list(inp=None, config=None): await client.list_async() assert captured["inp"].page_number == 1 assert captured["inp"].page_size == 20 - assert captured["inp"].tags == SUPER_AGENT_TAG + assert captured["inp"].system_tags == SUPER_AGENT_TAG async def test_list_async_custom_pagination(): diff --git a/tests/unittests/super_agent/test_control.py b/tests/unittests/super_agent/test_control.py index 3831a2a..33d36a1 100644 --- a/tests/unittests/super_agent/test_control.py +++ b/tests/unittests/super_agent/test_control.py @@ -11,7 +11,6 @@ API_VERSION, build_super_agent_endpoint, ensure_super_agent_patches_applied, - EXTERNAL_TAG, from_agent_runtime, is_super_agent, parse_super_agent_config, @@ -24,7 +23,7 @@ from agentrun.super_agent.api.data import SuperAgentDataAPI from agentrun.utils.config import Config -# 本文件部分测试 (list request tags 补丁) 依赖 Dara SDK 已被打过补丁, +# 本文件部分测试依赖 Dara ProtocolConfiguration 已被打过补丁 (externalEndpoint), # 显式在模块加载时触发补丁 (幂等, 与 SuperAgentClient.__init__ 内触发点一致)。 ensure_super_agent_patches_applied() @@ -108,7 +107,7 @@ def test_to_create_input_minimal(): cfg = Config(account_id="123", region_id="cn-hangzhou") inp = to_create_input("alpha", cfg=cfg) assert inp.agent_runtime_name == "alpha" - assert inp.tags == [SUPER_AGENT_TAG] + assert inp.system_tags == [SUPER_AGENT_TAG] pc = inp.protocol_configuration assert pc.type == SUPER_AGENT_PROTOCOL_TYPE assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") @@ -151,10 +150,10 @@ def test_to_create_input_full(): assert forwarded["modelName"] == "bar" -def test_to_create_input_tags_fixed(): +def test_to_create_input_system_tags_fixed(): cfg = Config(account_id="123", region_id="cn-hangzhou") inp = to_create_input("c", cfg=cfg) - assert inp.tags == [SUPER_AGENT_TAG] + assert inp.system_tags == [SUPER_AGENT_TAG] def test_to_create_input_metadata_only_agent_runtime_name(): @@ -383,65 +382,31 @@ def test_to_update_input_full_protocol_replace(): assert forwarded["tools"] == ["t"] -# ─── Dara ListAgentRuntimesRequest tags 补丁 ────────────────── -# 补丁已在模块顶部通过 ensure_super_agent_patches_applied() 显式触发。 +# ─── Dara ListAgentRuntimesRequest systemTags 原生字段 ────────────── +# ``systemTags`` 已由 Dara SDK 原生支持, 无需补丁。以下测试只校验 pydantic → +# Dara roundtrip 能把 ``system_tags`` 保留到请求 query。 -def test_list_request_from_map_preserves_tags(): +def test_list_request_from_map_preserves_system_tags(): from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest req = ListAgentRuntimesRequest().from_map({ - "tags": SUPER_AGENT_TAG, + "systemTags": SUPER_AGENT_TAG, "pageNumber": 1, "pageSize": 20, }) - assert getattr(req, "tags", None) == SUPER_AGENT_TAG + assert req.system_tags == SUPER_AGENT_TAG -def test_list_request_to_map_preserves_tags(): +def test_list_request_to_map_preserves_system_tags(): from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest req = ListAgentRuntimesRequest() - req.tags = SUPER_AGENT_TAG - assert req.to_map().get("tags") == SUPER_AGENT_TAG + req.system_tags = SUPER_AGENT_TAG + assert req.to_map().get("systemTags") == SUPER_AGENT_TAG -def _invoke_list_patch(tags_value): - """调用打过补丁的 ``list_agent_runtimes_with_options``, 捕获 call_api 的 query.""" - from alibabacloud_agentrun20250910.client import Client as _DaraClient - from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest - from darabonba.runtime import RuntimeOptions - - captured = {} - - def _fake_call_api(self, params, req, rt): - captured["query"] = dict(req.query) if req.query else {} - raise RuntimeError("_stop_after_query_capture_") - - client = _DaraClient.__new__(_DaraClient) - client._endpoint = "x" - # 绑定实例级 call_api (优先于类方法) - client.call_api = _fake_call_api.__get__(client, _DaraClient) - - req = ListAgentRuntimesRequest(page_number=1, page_size=20) - req.tags = tags_value - with pytest.raises(RuntimeError, match="_stop_after_query_capture_"): - client.list_agent_runtimes_with_options(req, {}, RuntimeOptions()) - return captured["query"] - - -def test_list_with_options_injects_tags_str(): - query = _invoke_list_patch(SUPER_AGENT_TAG) - assert query.get("tags") == SUPER_AGENT_TAG - assert query.get("pageNumber") == "1" - - -def test_list_with_options_injects_tags_list_comma_join(): - query = _invoke_list_patch([EXTERNAL_TAG, SUPER_AGENT_TAG]) - assert query.get("tags") == f"{EXTERNAL_TAG},{SUPER_AGENT_TAG}" - - -def test_list_with_options_no_tags_no_injection(): +def test_list_with_options_writes_system_tags_query(): from alibabacloud_agentrun20250910.client import Client as _DaraClient from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest from darabonba.runtime import RuntimeOptions @@ -456,32 +421,9 @@ def _fake_call_api(self, params, req, rt): client._endpoint = "x" client.call_api = _fake_call_api.__get__(client, _DaraClient) - req = ListAgentRuntimesRequest(page_number=1, page_size=20) + req = ListAgentRuntimesRequest( + page_number=1, page_size=20, system_tags=SUPER_AGENT_TAG + ) with pytest.raises(RuntimeError, match="_stop_"): client.list_agent_runtimes_with_options(req, {}, RuntimeOptions()) - assert "tags" not in captured["query"] - - -@pytest.mark.asyncio -async def test_list_with_options_async_injects_tags(): - from alibabacloud_agentrun20250910.client import Client as _DaraClient - from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest - from darabonba.runtime import RuntimeOptions - - captured = {} - - async def _fake_call_api_async(self, params, req, rt): - captured["query"] = dict(req.query) if req.query else {} - raise RuntimeError("_stop_") - - client = _DaraClient.__new__(_DaraClient) - client._endpoint = "x" - client.call_api_async = _fake_call_api_async.__get__(client, _DaraClient) - - req = ListAgentRuntimesRequest(page_number=1, page_size=20) - req.tags = SUPER_AGENT_TAG - with pytest.raises(RuntimeError, match="_stop_"): - await client.list_agent_runtimes_with_options_async( - req, {}, RuntimeOptions() - ) - assert captured["query"].get("tags") == SUPER_AGENT_TAG + assert captured["query"].get("systemTags") == SUPER_AGENT_TAG